feat(training): 实现 Trainer 模块化重构 (Trainer V2)
- 新增 FactorManager 组件:统一管理多种来源因子 - 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理) - 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask - 新增 ResultAnalyzer 组件:特征重要性分析和结果组装 - 新增 TrainerV2:作为纯调度引擎协调各组件 - 支持回归和排序学习两种训练模式 - 采用组合模式解耦训练流程,消除代码重复
This commit is contained in:
309
src/training/pipeline.py
Normal file
309
src/training/pipeline.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""数据流水线
|
||||
|
||||
完整的数据处理流程:
|
||||
1. 因子注册和数据准备
|
||||
2. 应用过滤器(STFilter 等)
|
||||
3. 股票池筛选(自定义函数)
|
||||
4. 数据质量检查
|
||||
5. 数据划分(train/val/test)
|
||||
6. 数据预处理(fit_transform/transform)
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training.factor_manager import FactorManager
|
||||
from src.training.components.base import BaseProcessor
|
||||
from src.training.core.stock_pool_manager import StockPoolManager
|
||||
|
||||
|
||||
class DataPipeline:
|
||||
"""数据流水线
|
||||
|
||||
执行完整的数据处理流程,返回标准化的数据字典。
|
||||
|
||||
Attributes:
|
||||
factor_manager: 因子管理器
|
||||
filters: 类形式的过滤器列表(如 STFilter)
|
||||
stock_pool_filter_func: 函数形式的股票池筛选器
|
||||
processor_configs: 数据处理器配置列表(类+参数)
|
||||
stock_pool_required_columns: 股票池筛选所需的额外列
|
||||
fitted_processors: 已拟合的处理器列表(训练后填充)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factor_manager: FactorManager,
|
||||
processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
|
||||
filters: Optional[List[Any]] = None,
|
||||
stock_pool_filter_func: Optional[Callable] = None,
|
||||
stock_pool_required_columns: Optional[List[str]] = None,
|
||||
):
|
||||
"""初始化数据流水线
|
||||
|
||||
Args:
|
||||
factor_manager: 因子管理器实例
|
||||
processor_configs: 数据处理器配置列表,每个元素为 (ProcessorClass, kwargs)
|
||||
例如:[(NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99})]
|
||||
filters: 类形式的过滤器列表(如 [STFilter])
|
||||
stock_pool_filter_func: 函数形式的股票池筛选器
|
||||
stock_pool_required_columns: 股票池筛选所需的额外列
|
||||
"""
|
||||
self.factor_manager = factor_manager
|
||||
self.processor_configs = processor_configs or []
|
||||
self.filters = filters or []
|
||||
self.stock_pool_filter_func = stock_pool_filter_func
|
||||
self.stock_pool_required_columns = stock_pool_required_columns or []
|
||||
self.fitted_processors: List[BaseProcessor] = []
|
||||
|
||||
def prepare_data(
|
||||
self,
|
||||
engine: FactorEngine,
|
||||
date_range: Dict[str, Tuple[str, str]],
|
||||
label_name: str,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""执行完整数据流程
|
||||
|
||||
流程:
|
||||
1. 注册因子并准备数据
|
||||
2. 应用类过滤器(STFilter)
|
||||
3. 应用股票池筛选(函数形式)
|
||||
4. 数据质量检查
|
||||
5. 数据划分
|
||||
6. 数据预处理
|
||||
|
||||
Args:
|
||||
engine: FactorEngine 实例
|
||||
date_range: 日期范围字典 {"train": (start, end), "val": ..., "test": ...}
|
||||
label_name: Label 列名
|
||||
verbose: 是否打印处理信息
|
||||
|
||||
Returns:
|
||||
标准化的数据字典
|
||||
"""
|
||||
if verbose:
|
||||
print("\n" + "=" * 80)
|
||||
print("数据流水线")
|
||||
print("=" * 80)
|
||||
|
||||
# Step 1: 注册因子并准备数据
|
||||
if verbose:
|
||||
print("\n[1/6] 注册因子并准备数据...")
|
||||
|
||||
feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)
|
||||
|
||||
# 计算完整日期范围
|
||||
all_start = min(
|
||||
date_range["train"][0], date_range["val"][0], date_range["test"][0]
|
||||
)
|
||||
all_end = max(
|
||||
date_range["train"][1], date_range["val"][1], date_range["test"][1]
|
||||
)
|
||||
|
||||
# 准备数据
|
||||
data = engine.compute(
|
||||
factor_names=feature_cols + [label_name],
|
||||
start_date=all_start,
|
||||
end_date=all_end,
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print(f" 原始数据规模: {data.shape}")
|
||||
print(f" 特征数: {len(feature_cols)}")
|
||||
|
||||
# Step 2: 应用类过滤器(STFilter)
|
||||
if self.filters:
|
||||
if verbose:
|
||||
print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")
|
||||
|
||||
for filter_obj in self.filters:
|
||||
data_before = len(data)
|
||||
data = filter_obj.filter(data)
|
||||
data_after = len(data)
|
||||
|
||||
if verbose:
|
||||
print(f" {filter_obj.__class__.__name__}:")
|
||||
print(f" 过滤前: {data_before}, 过滤后: {data_after}")
|
||||
print(f" 删除: {data_before - data_after}")
|
||||
|
||||
# Step 3: 应用股票池筛选(函数形式)
|
||||
if self.stock_pool_filter_func:
|
||||
if verbose:
|
||||
print(f"\n[3/6] 股票池筛选...")
|
||||
|
||||
data_before = len(data)
|
||||
|
||||
# 创建 StockPoolManager
|
||||
pool_manager = StockPoolManager(
|
||||
filter_func=self.stock_pool_filter_func,
|
||||
required_columns=self.stock_pool_required_columns,
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
data = pool_manager.filter_and_select_daily(data)
|
||||
data_after = len(data)
|
||||
|
||||
if verbose:
|
||||
print(f" 筛选前: {data_before}, 筛选后: {data_after}")
|
||||
print(f" 删除: {data_before - data_after}")
|
||||
|
||||
# Step 4: 数据质量检查
|
||||
if verbose:
|
||||
print(f"\n[4/6] 数据质量检查...")
|
||||
|
||||
self._check_data_quality(data, feature_cols, verbose=verbose)
|
||||
|
||||
# Step 5: 数据划分
|
||||
if verbose:
|
||||
print(f"\n[5/6] 数据划分...")
|
||||
|
||||
split_data = self._split_data(
|
||||
data, date_range, feature_cols, label_name, verbose=verbose
|
||||
)
|
||||
|
||||
# Step 6: 数据预处理
|
||||
if verbose:
|
||||
print(f"\n[6/6] 数据预处理...")
|
||||
|
||||
split_data = self._preprocess(split_data, feature_cols, verbose=verbose)
|
||||
|
||||
if verbose:
|
||||
print("\n" + "=" * 80)
|
||||
print("数据流水线完成")
|
||||
print("=" * 80)
|
||||
|
||||
return split_data
|
||||
|
||||
def _check_data_quality(
|
||||
self,
|
||||
data: pl.DataFrame,
|
||||
feature_cols: List[str],
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
"""检查数据质量
|
||||
|
||||
Args:
|
||||
data: 数据框
|
||||
feature_cols: 特征列名列表
|
||||
verbose: 是否打印信息
|
||||
"""
|
||||
# 检查缺失值
|
||||
null_counts = {}
|
||||
for col in feature_cols[:10]: # 只检查前10个特征
|
||||
null_count = data[col].null_count()
|
||||
if null_count > 0:
|
||||
null_counts[col] = null_count
|
||||
|
||||
if null_counts and verbose:
|
||||
print(f" [警告] 发现缺失值(仅显示前10个特征):")
|
||||
for col, count in list(null_counts.items())[:5]:
|
||||
pct = count / len(data) * 100
|
||||
print(f" {col}: {count} ({pct:.2f}%)")
|
||||
|
||||
def _split_data(
|
||||
self,
|
||||
data: pl.DataFrame,
|
||||
date_range: Dict[str, Tuple[str, str]],
|
||||
feature_cols: List[str],
|
||||
label_name: str,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""划分数据集
|
||||
|
||||
Args:
|
||||
data: 完整数据
|
||||
date_range: 日期范围字典
|
||||
feature_cols: 特征列名
|
||||
label_name: Label 列名
|
||||
verbose: 是否打印信息
|
||||
|
||||
Returns:
|
||||
划分后的数据字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
for split_name, (start, end) in date_range.items():
|
||||
mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
|
||||
split_df = data.filter(mask)
|
||||
|
||||
result[split_name] = {
|
||||
"X": split_df.select(feature_cols),
|
||||
"y": split_df[label_name],
|
||||
"raw_data": split_df,
|
||||
"feature_cols": feature_cols,
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(f" {split_name}: {len(split_df)} 条记录")
|
||||
|
||||
return result
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
split_data: Dict[str, Dict[str, Any]],
|
||||
feature_cols: List[str],
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""预处理数据
|
||||
|
||||
训练集使用 fit_transform,验证集和测试集使用 transform
|
||||
|
||||
Args:
|
||||
split_data: 划分后的数据字典
|
||||
feature_cols: 特征列名列表
|
||||
verbose: 是否打印信息
|
||||
|
||||
Returns:
|
||||
预处理后的数据字典
|
||||
"""
|
||||
if not self.processor_configs:
|
||||
return split_data
|
||||
|
||||
self.fitted_processors = []
|
||||
|
||||
# 实例化 processors(传入 feature_cols)
|
||||
processors = []
|
||||
for proc_class, proc_kwargs in self.processor_configs:
|
||||
proc_kwargs_with_cols = {**proc_kwargs, "feature_cols": feature_cols}
|
||||
processors.append(proc_class(**proc_kwargs_with_cols))
|
||||
|
||||
# 训练集:fit_transform
|
||||
if verbose:
|
||||
print(f" 训练集预处理(fit_transform)...")
|
||||
|
||||
train_data = split_data["train"]["raw_data"]
|
||||
for processor in processors:
|
||||
train_data = processor.fit_transform(train_data)
|
||||
self.fitted_processors.append(processor)
|
||||
|
||||
# 更新训练集
|
||||
split_data["train"]["raw_data"] = train_data
|
||||
split_data["train"]["X"] = train_data.select(feature_cols)
|
||||
split_data["train"]["y"] = train_data[split_data["train"]["y"].name]
|
||||
|
||||
# 验证集和测试集:transform
|
||||
for split_name in ["val", "test"]:
|
||||
if split_name in split_data:
|
||||
if verbose:
|
||||
print(f" {split_name}集预处理(transform)...")
|
||||
|
||||
split_df = split_data[split_name]["raw_data"]
|
||||
for processor in self.fitted_processors:
|
||||
split_df = processor.transform(split_df)
|
||||
|
||||
split_data[split_name]["raw_data"] = split_df
|
||||
split_data[split_name]["X"] = split_df.select(feature_cols)
|
||||
split_data[split_name]["y"] = split_df[split_data[split_name]["y"].name]
|
||||
|
||||
return split_data
|
||||
|
||||
def get_fitted_processors(self) -> List[BaseProcessor]:
|
||||
"""获取已拟合的处理器列表
|
||||
|
||||
Returns:
|
||||
已拟合的处理器列表(用于模型保存)
|
||||
"""
|
||||
return self.fitted_processors
|
||||
Reference in New Issue
Block a user