# 变更说明:
# - 新增 FactorManager 组件:统一管理多种来源因子
# - 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理)
# - 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask
# - 新增 ResultAnalyzer 组件:特征重要性分析和结果组装
# - 新增 TrainerV2:作为纯调度引擎协调各组件
# - 支持回归和排序学习两种训练模式
# - 采用组合模式解耦训练流程,消除代码重复
"""数据流水线
|
||
|
||
完整的数据处理流程:
|
||
1. 因子注册和数据准备
|
||
2. 应用过滤器(STFilter 等)
|
||
3. 股票池筛选(自定义函数)
|
||
4. 数据质量检查
|
||
5. 数据划分(train/val/test)
|
||
6. 数据预处理(fit_transform/transform)
|
||
"""
|
||
|
||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np
import polars as pl

from src.factors import FactorEngine
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.factor_manager import FactorManager


class DataPipeline:
    """Data pipeline: runs the full data-preparation flow.

    Executes the complete processing sequence (see :meth:`prepare_data`)
    and returns a standardized data dictionary keyed by split name.

    Attributes:
        factor_manager: Factor manager used to register factors on the engine.
        filters: Class-style filter objects (e.g. STFilter); each must expose
            a ``filter(df)`` method returning a filtered frame.
        stock_pool_filter_func: Optional function-style stock-pool filter.
        processor_configs: Processor configs as ``(ProcessorClass, kwargs)`` pairs.
        stock_pool_required_columns: Extra columns the pool filter needs.
        fitted_processors: Processors fitted on the train split; empty until
            :meth:`prepare_data` has run preprocessing.
    """

    def __init__(
        self,
        factor_manager: FactorManager,
        processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
        filters: Optional[List[Any]] = None,
        stock_pool_filter_func: Optional[Callable] = None,
        stock_pool_required_columns: Optional[List[str]] = None,
    ):
        """Initialize the data pipeline.

        Args:
            factor_manager: Factor manager instance.
            processor_configs: Processor config list; each element is
                ``(ProcessorClass, kwargs)``, e.g.
                ``[(NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99})]``.
            filters: Class-style filters (e.g. ``[STFilter]``).
            stock_pool_filter_func: Function-style stock-pool filter.
            stock_pool_required_columns: Extra columns required by the pool filter.
        """
        self.factor_manager = factor_manager
        # Defensive "or []": tolerate an explicit None for the list arguments.
        self.processor_configs = processor_configs or []
        self.filters = filters or []
        self.stock_pool_filter_func = stock_pool_filter_func
        self.stock_pool_required_columns = stock_pool_required_columns or []
        self.fitted_processors: List[BaseProcessor] = []

    def prepare_data(
        self,
        engine: FactorEngine,
        date_range: Dict[str, Tuple[str, str]],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Run the full data flow.

        Steps:
            1. Register factors and compute raw data.
            2. Apply class-style filters (STFilter).
            3. Apply the function-style stock-pool filter.
            4. Data-quality check.
            5. Split into train/val/test by date range.
            6. Preprocess (fit_transform on train, transform on val/test).

        Args:
            engine: FactorEngine instance.
            date_range: Split date ranges,
                ``{"train": (start, end), "val": ..., "test": ...}``;
                must contain the "train", "val" and "test" keys.
            label_name: Label column name.
            verbose: Whether to print progress information.

        Returns:
            Standardized data dict keyed by split name; each value holds
            "X", "y", "raw_data" and "feature_cols".
        """
        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线")
            print("=" * 80)

        # Step 1: register factors and compute raw data.
        if verbose:
            print("\n[1/6] 注册因子并准备数据...")

        feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)

        # Overall date span covering all three splits.
        # NOTE(review): min/max on strings assumes ISO-format dates so that
        # lexicographic order equals chronological order — confirm upstream.
        all_start = min(
            date_range["train"][0], date_range["val"][0], date_range["test"][0]
        )
        all_end = max(
            date_range["train"][1], date_range["val"][1], date_range["test"][1]
        )

        # Compute features + label over the whole span in one pass.
        data = engine.compute(
            factor_names=feature_cols + [label_name],
            start_date=all_start,
            end_date=all_end,
        )

        if verbose:
            print(f" 原始数据规模: {data.shape}")
            print(f" 特征数: {len(feature_cols)}")

        # Step 2: class-style filters (e.g. STFilter), applied sequentially.
        if self.filters:
            if verbose:
                print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")

            for filter_obj in self.filters:
                data_before = len(data)
                data = filter_obj.filter(data)
                data_after = len(data)

                if verbose:
                    print(f" {filter_obj.__class__.__name__}:")
                    print(f" 过滤前: {data_before}, 过滤后: {data_after}")
                    print(f" 删除: {data_before - data_after}")

        # Step 3: function-style stock-pool filter.
        if self.stock_pool_filter_func:
            if verbose:
                # F541 fix: no placeholder, so no f-prefix needed.
                print("\n[3/6] 股票池筛选...")

            data_before = len(data)

            pool_manager = StockPoolManager(
                filter_func=self.stock_pool_filter_func,
                required_columns=self.stock_pool_required_columns,
                data_router=engine.router,
            )

            data = pool_manager.filter_and_select_daily(data)
            data_after = len(data)

            if verbose:
                print(f" 筛选前: {data_before}, 筛选后: {data_after}")
                print(f" 删除: {data_before - data_after}")

        # Step 4: data-quality check (null-count sampling, print-only).
        if verbose:
            print("\n[4/6] 数据质量检查...")

        self._check_data_quality(data, feature_cols, verbose=verbose)

        # Step 5: split by per-split date ranges.
        if verbose:
            print("\n[5/6] 数据划分...")

        split_data = self._split_data(
            data, date_range, feature_cols, label_name, verbose=verbose
        )

        # Step 6: preprocessing — fit on train, transform val/test.
        if verbose:
            print("\n[6/6] 数据预处理...")

        split_data = self._preprocess(split_data, feature_cols, verbose=verbose)

        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线完成")
            print("=" * 80)

        return split_data

    def _check_data_quality(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        verbose: bool = True,
    ) -> None:
        """Report null counts for a sample of feature columns.

        Only the first 10 features are inspected and at most 5 are printed —
        a cheap sanity check rather than an exhaustive audit. Nothing is
        raised or returned; findings are print-only.

        Args:
            data: Data frame to inspect.
            feature_cols: Feature column names.
            verbose: Whether to print findings.
        """
        null_counts = {}
        for col in feature_cols[:10]:  # sample only the first 10 features
            null_count = data[col].null_count()
            if null_count > 0:
                null_counts[col] = null_count

        if null_counts and verbose:
            # F541 fix: plain string, no placeholders.
            print(" [警告] 发现缺失值(仅显示前10个特征):")
            for col, count in list(null_counts.items())[:5]:
                pct = count / len(data) * 100
                print(f" {col}: {count} ({pct:.2f}%)")

    def _split_data(
        self,
        data: pl.DataFrame,
        date_range: Dict[str, Tuple[str, str]],
        feature_cols: List[str],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Split the data set by date range.

        Ranges are inclusive on both ends; overlapping ranges will duplicate
        rows across splits (no overlap check is performed here).

        Args:
            data: Full data frame (must contain a "trade_date" column).
            date_range: Split-name -> (start, end) mapping.
            feature_cols: Feature column names.
            label_name: Label column name.
            verbose: Whether to print split sizes.

        Returns:
            Dict keyed by split name with "X", "y", "raw_data",
            "feature_cols" entries.
        """
        result = {}

        for split_name, (start, end) in date_range.items():
            mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
            split_df = data.filter(mask)

            result[split_name] = {
                "X": split_df.select(feature_cols),
                "y": split_df[label_name],
                "raw_data": split_df,
                "feature_cols": feature_cols,
            }

            if verbose:
                print(f" {split_name}: {len(split_df)} 条记录")

        return result

    def _preprocess(
        self,
        split_data: Dict[str, Dict[str, Any]],
        feature_cols: List[str],
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Preprocess the split data.

        The train split is processed with ``fit_transform``; val/test splits
        reuse the fitted processors via ``transform``. ``split_data`` is
        mutated in place and also returned. Requires a "train" key when any
        processors are configured.

        Args:
            split_data: Split data dict (from :meth:`_split_data`).
            feature_cols: Feature column names.
            verbose: Whether to print progress.

        Returns:
            The (mutated) split data dict.
        """
        if not self.processor_configs:
            return split_data

        self.fitted_processors = []

        # Instantiate processors, injecting feature_cols into each config.
        processors = []
        for proc_class, proc_kwargs in self.processor_configs:
            proc_kwargs_with_cols = {**proc_kwargs, "feature_cols": feature_cols}
            processors.append(proc_class(**proc_kwargs_with_cols))

        # Train split: fit_transform, recording each fitted processor.
        if verbose:
            # F541 fix: plain string, no placeholders.
            print(" 训练集预处理(fit_transform)...")

        train_data = split_data["train"]["raw_data"]
        for processor in processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)

        # Refresh the train entries from the transformed frame.
        split_data["train"]["raw_data"] = train_data
        split_data["train"]["X"] = train_data.select(feature_cols)
        split_data["train"]["y"] = train_data[split_data["train"]["y"].name]

        # Val/test splits: transform only, using the processors fitted above.
        for split_name in ["val", "test"]:
            if split_name in split_data:
                if verbose:
                    print(f" {split_name}集预处理(transform)...")

                split_df = split_data[split_name]["raw_data"]
                for processor in self.fitted_processors:
                    split_df = processor.transform(split_df)

                split_data[split_name]["raw_data"] = split_df
                split_data[split_name]["X"] = split_df.select(feature_cols)
                split_data[split_name]["y"] = split_df[split_data[split_name]["y"].name]

        return split_data

    def get_fitted_processors(self) -> List[BaseProcessor]:
        """Return the processors fitted on the train split.

        Returns:
            Fitted processor list (for model saving); empty if
            preprocessing has not run yet.
        """
        return self.fitted_processors