feat(training): 实现 Trainer 模块化重构 (Trainer V2)

- 新增 FactorManager 组件:统一管理多种来源因子
- 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理)
- 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask
- 新增 ResultAnalyzer 组件:特征重要性分析和结果组装
- 新增 TrainerV2:作为纯调度引擎协调各组件
- 支持回归和排序学习两种训练模式
- 采用组合模式解耦训练流程,消除代码重复
This commit is contained in:
2026-03-24 23:35:31 +08:00
parent bace4cc5f4
commit e41a128ca3
13 changed files with 4045 additions and 1509 deletions

309
src/training/pipeline.py Normal file
View File

@@ -0,0 +1,309 @@
"""数据流水线
完整的数据处理流程:
1. 因子注册和数据准备
2. 应用过滤器STFilter 等)
3. 股票池筛选(自定义函数)
4. 数据质量检查
5. 数据划分train/val/test
6. 数据预处理fit_transform/transform
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import polars as pl
import numpy as np
from src.factors import FactorEngine
from src.training.factor_manager import FactorManager
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager
class DataPipeline:
"""数据流水线
执行完整的数据处理流程,返回标准化的数据字典。
Attributes:
factor_manager: 因子管理器
filters: 类形式的过滤器列表(如 STFilter
stock_pool_filter_func: 函数形式的股票池筛选器
processor_configs: 数据处理器配置列表(类+参数)
stock_pool_required_columns: 股票池筛选所需的额外列
fitted_processors: 已拟合的处理器列表(训练后填充)
"""
def __init__(
self,
factor_manager: FactorManager,
processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
filters: Optional[List[Any]] = None,
stock_pool_filter_func: Optional[Callable] = None,
stock_pool_required_columns: Optional[List[str]] = None,
):
"""初始化数据流水线
Args:
factor_manager: 因子管理器实例
processor_configs: 数据处理器配置列表,每个元素为 (ProcessorClass, kwargs)
例如:[(NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99})]
filters: 类形式的过滤器列表(如 [STFilter]
stock_pool_filter_func: 函数形式的股票池筛选器
stock_pool_required_columns: 股票池筛选所需的额外列
"""
self.factor_manager = factor_manager
self.processor_configs = processor_configs or []
self.filters = filters or []
self.stock_pool_filter_func = stock_pool_filter_func
self.stock_pool_required_columns = stock_pool_required_columns or []
self.fitted_processors: List[BaseProcessor] = []
def prepare_data(
self,
engine: FactorEngine,
date_range: Dict[str, Tuple[str, str]],
label_name: str,
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""执行完整数据流程
流程:
1. 注册因子并准备数据
2. 应用类过滤器STFilter
3. 应用股票池筛选(函数形式)
4. 数据质量检查
5. 数据划分
6. 数据预处理
Args:
engine: FactorEngine 实例
date_range: 日期范围字典 {"train": (start, end), "val": ..., "test": ...}
label_name: Label 列名
verbose: 是否打印处理信息
Returns:
标准化的数据字典
"""
if verbose:
print("\n" + "=" * 80)
print("数据流水线")
print("=" * 80)
# Step 1: 注册因子并准备数据
if verbose:
print("\n[1/6] 注册因子并准备数据...")
feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)
# 计算完整日期范围
all_start = min(
date_range["train"][0], date_range["val"][0], date_range["test"][0]
)
all_end = max(
date_range["train"][1], date_range["val"][1], date_range["test"][1]
)
# 准备数据
data = engine.compute(
factor_names=feature_cols + [label_name],
start_date=all_start,
end_date=all_end,
)
if verbose:
print(f" 原始数据规模: {data.shape}")
print(f" 特征数: {len(feature_cols)}")
# Step 2: 应用类过滤器STFilter
if self.filters:
if verbose:
print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")
for filter_obj in self.filters:
data_before = len(data)
data = filter_obj.filter(data)
data_after = len(data)
if verbose:
print(f" {filter_obj.__class__.__name__}:")
print(f" 过滤前: {data_before}, 过滤后: {data_after}")
print(f" 删除: {data_before - data_after}")
# Step 3: 应用股票池筛选(函数形式)
if self.stock_pool_filter_func:
if verbose:
print(f"\n[3/6] 股票池筛选...")
data_before = len(data)
# 创建 StockPoolManager
pool_manager = StockPoolManager(
filter_func=self.stock_pool_filter_func,
required_columns=self.stock_pool_required_columns,
data_router=engine.router,
)
data = pool_manager.filter_and_select_daily(data)
data_after = len(data)
if verbose:
print(f" 筛选前: {data_before}, 筛选后: {data_after}")
print(f" 删除: {data_before - data_after}")
# Step 4: 数据质量检查
if verbose:
print(f"\n[4/6] 数据质量检查...")
self._check_data_quality(data, feature_cols, verbose=verbose)
# Step 5: 数据划分
if verbose:
print(f"\n[5/6] 数据划分...")
split_data = self._split_data(
data, date_range, feature_cols, label_name, verbose=verbose
)
# Step 6: 数据预处理
if verbose:
print(f"\n[6/6] 数据预处理...")
split_data = self._preprocess(split_data, feature_cols, verbose=verbose)
if verbose:
print("\n" + "=" * 80)
print("数据流水线完成")
print("=" * 80)
return split_data
def _check_data_quality(
self,
data: pl.DataFrame,
feature_cols: List[str],
verbose: bool = True,
) -> None:
"""检查数据质量
Args:
data: 数据框
feature_cols: 特征列名列表
verbose: 是否打印信息
"""
# 检查缺失值
null_counts = {}
for col in feature_cols[:10]: # 只检查前10个特征
null_count = data[col].null_count()
if null_count > 0:
null_counts[col] = null_count
if null_counts and verbose:
print(f" [警告] 发现缺失值仅显示前10个特征:")
for col, count in list(null_counts.items())[:5]:
pct = count / len(data) * 100
print(f" {col}: {count} ({pct:.2f}%)")
def _split_data(
self,
data: pl.DataFrame,
date_range: Dict[str, Tuple[str, str]],
feature_cols: List[str],
label_name: str,
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""划分数据集
Args:
data: 完整数据
date_range: 日期范围字典
feature_cols: 特征列名
label_name: Label 列名
verbose: 是否打印信息
Returns:
划分后的数据字典
"""
result = {}
for split_name, (start, end) in date_range.items():
mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
split_df = data.filter(mask)
result[split_name] = {
"X": split_df.select(feature_cols),
"y": split_df[label_name],
"raw_data": split_df,
"feature_cols": feature_cols,
}
if verbose:
print(f" {split_name}: {len(split_df)} 条记录")
return result
def _preprocess(
self,
split_data: Dict[str, Dict[str, Any]],
feature_cols: List[str],
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""预处理数据
训练集使用 fit_transform验证集和测试集使用 transform
Args:
split_data: 划分后的数据字典
feature_cols: 特征列名列表
verbose: 是否打印信息
Returns:
预处理后的数据字典
"""
if not self.processor_configs:
return split_data
self.fitted_processors = []
# 实例化 processors传入 feature_cols
processors = []
for proc_class, proc_kwargs in self.processor_configs:
proc_kwargs_with_cols = {**proc_kwargs, "feature_cols": feature_cols}
processors.append(proc_class(**proc_kwargs_with_cols))
# 训练集fit_transform
if verbose:
print(f" 训练集预处理fit_transform...")
train_data = split_data["train"]["raw_data"]
for processor in processors:
train_data = processor.fit_transform(train_data)
self.fitted_processors.append(processor)
# 更新训练集
split_data["train"]["raw_data"] = train_data
split_data["train"]["X"] = train_data.select(feature_cols)
split_data["train"]["y"] = train_data[split_data["train"]["y"].name]
# 验证集和测试集transform
for split_name in ["val", "test"]:
if split_name in split_data:
if verbose:
print(f" {split_name}集预处理transform...")
split_df = split_data[split_name]["raw_data"]
for processor in self.fitted_processors:
split_df = processor.transform(split_df)
split_data[split_name]["raw_data"] = split_df
split_data[split_name]["X"] = split_df.select(feature_cols)
split_data[split_name]["y"] = split_df[split_data[split_name]["y"].name]
return split_data
def get_fitted_processors(self) -> List[BaseProcessor]:
"""获取已拟合的处理器列表
Returns:
已拟合的处理器列表(用于模型保存)
"""
return self.fitted_processors