- 新增 FactorManager 组件:统一管理多种来源因子
- 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理)
- 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask
- 新增 ResultAnalyzer 组件:特征重要性分析和结果组装
- 新增 TrainerV2:作为纯调度引擎协调各组件
- 支持回归和排序学习两种训练模式
- 采用组合模式解耦训练流程,消除代码重复
91 lines
2.2 KiB
Python
91 lines
2.2 KiB
Python
"""Training module of the ProStock quantitative investment framework.

Provides the end-to-end workflow for model training, data processing and
prediction. This package re-exports the public API of its submodules:

- base abstractions: ``BaseModel``, ``BaseProcessor``
- registries: ``ModelRegistry``, ``ProcessorRegistry``,
  ``register_model``, ``register_processor``
- data splitting: ``DateSplitter``
- data processors: ``NullFiller``, ``StandardScaler``,
  ``CrossSectionalStandardScaler``, ``Winsorizer``
- data filters: ``BaseFilter``, ``STFilter``
- models: ``LightGBMModel``
- training core: ``StockPoolManager``, ``Trainer``
- utilities: ``check_data_quality``
- configuration: ``TrainingConfig``
- modular trainer components: ``FactorManager``, ``DataPipeline``,
  ``ResultAnalyzer``, ``BaseTask``, ``RegressionTask``, ``RankTask``
"""

# Base abstract classes
from src.training.components.base import BaseModel, BaseProcessor

# Registries
from src.training.registry import (
    ModelRegistry,
    ProcessorRegistry,
    register_model,
    register_processor,
)

# Data splitters
from src.training.components.splitters import DateSplitter

# Stock pool selector configs have moved to StockPoolManager; the
# src.training.components.selectors module is kept only as a placeholder and
# nothing is imported from it here.

# Data processors
from src.training.components.processors import (
    CrossSectionalStandardScaler,
    NullFiller,
    StandardScaler,
    Winsorizer,
)

# Models
from src.training.components.models import LightGBMModel

# Data filters
from src.training.components.filters import BaseFilter, STFilter

# Training core
from src.training.core import StockPoolManager, Trainer

# Utility functions
from src.training.utils import check_data_quality

# Configuration
from src.training.config import TrainingConfig

# Modular Trainer components
from src.training.factor_manager import FactorManager
from src.training.pipeline import DataPipeline
from src.training.result_analyzer import ResultAnalyzer
from src.training.tasks import BaseTask, RegressionTask, RankTask

__all__ = [
    # Base abstract classes
    "BaseModel",
    "BaseProcessor",
    # Registries
    "ModelRegistry",
    "ProcessorRegistry",
    "register_model",
    "register_processor",
    # Data splitters
    "DateSplitter",
    # Removed stock pool selector configs (no longer exported):
    #   StockFilterConfig       -> use StockPoolManager + filter_func
    #   MarketCapSelectorConfig -> use StockPoolManager + required_factors
    # Data processors
    "NullFiller",
    "StandardScaler",
    "CrossSectionalStandardScaler",
    "Winsorizer",
    # Data filters
    "BaseFilter",
    "STFilter",
    # Models
    "LightGBMModel",
    # Training core
    "StockPoolManager",
    "Trainer",
    # Utility functions
    "check_data_quality",
    # Configuration
    "TrainingConfig",
    # Modular Trainer components
    "FactorManager",
    "DataPipeline",
    "ResultAnalyzer",
    "BaseTask",
    "RegressionTask",
    "RankTask",
]
|