Files
ProStock/src/training/pipeline.py

310 lines
10 KiB
Python
Raw Normal View History

"""数据流水线
完整的数据处理流程
1. 因子注册和数据准备
2. 应用过滤器STFilter
3. 股票池筛选自定义函数
4. 数据质量检查
5. 数据划分train/val/test
6. 数据预处理fit_transform/transform
"""
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import polars as pl
import numpy as np
from src.factors import FactorEngine
from src.training.factor_manager import FactorManager
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager
class DataPipeline:
"""数据流水线
执行完整的数据处理流程返回标准化的数据字典
Attributes:
factor_manager: 因子管理器
filters: 类形式的过滤器列表 STFilter
stock_pool_filter_func: 函数形式的股票池筛选器
processor_configs: 数据处理器配置列表+参数
stock_pool_required_columns: 股票池筛选所需的额外列
fitted_processors: 已拟合的处理器列表训练后填充
"""
def __init__(
self,
factor_manager: FactorManager,
processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
filters: Optional[List[Any]] = None,
stock_pool_filter_func: Optional[Callable] = None,
stock_pool_required_columns: Optional[List[str]] = None,
):
"""初始化数据流水线
Args:
factor_manager: 因子管理器实例
processor_configs: 数据处理器配置列表每个元素为 (ProcessorClass, kwargs)
例如[(NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99})]
filters: 类形式的过滤器列表 [STFilter]
stock_pool_filter_func: 函数形式的股票池筛选器
stock_pool_required_columns: 股票池筛选所需的额外列
"""
self.factor_manager = factor_manager
self.processor_configs = processor_configs or []
self.filters = filters or []
self.stock_pool_filter_func = stock_pool_filter_func
self.stock_pool_required_columns = stock_pool_required_columns or []
self.fitted_processors: List[BaseProcessor] = []
def prepare_data(
self,
engine: FactorEngine,
date_range: Dict[str, Tuple[str, str]],
label_name: str,
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""执行完整数据流程
流程
1. 注册因子并准备数据
2. 应用类过滤器STFilter
3. 应用股票池筛选函数形式
4. 数据质量检查
5. 数据划分
6. 数据预处理
Args:
engine: FactorEngine 实例
date_range: 日期范围字典 {"train": (start, end), "val": ..., "test": ...}
label_name: Label 列名
verbose: 是否打印处理信息
Returns:
标准化的数据字典
"""
if verbose:
print("\n" + "=" * 80)
print("数据流水线")
print("=" * 80)
# Step 1: 注册因子并准备数据
if verbose:
print("\n[1/6] 注册因子并准备数据...")
feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)
# 计算完整日期范围
all_start = min(
date_range["train"][0], date_range["val"][0], date_range["test"][0]
)
all_end = max(
date_range["train"][1], date_range["val"][1], date_range["test"][1]
)
# 准备数据
data = engine.compute(
factor_names=feature_cols + [label_name],
start_date=all_start,
end_date=all_end,
)
if verbose:
print(f" 原始数据规模: {data.shape}")
print(f" 特征数: {len(feature_cols)}")
# Step 2: 应用类过滤器STFilter
if self.filters:
if verbose:
print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")
for filter_obj in self.filters:
data_before = len(data)
data = filter_obj.filter(data)
data_after = len(data)
if verbose:
print(f" {filter_obj.__class__.__name__}:")
print(f" 过滤前: {data_before}, 过滤后: {data_after}")
print(f" 删除: {data_before - data_after}")
# Step 3: 应用股票池筛选(函数形式)
if self.stock_pool_filter_func:
if verbose:
print(f"\n[3/6] 股票池筛选...")
data_before = len(data)
# 创建 StockPoolManager
pool_manager = StockPoolManager(
filter_func=self.stock_pool_filter_func,
required_columns=self.stock_pool_required_columns,
data_router=engine.router,
)
data = pool_manager.filter_and_select_daily(data)
data_after = len(data)
if verbose:
print(f" 筛选前: {data_before}, 筛选后: {data_after}")
print(f" 删除: {data_before - data_after}")
# Step 4: 数据质量检查
if verbose:
print(f"\n[4/6] 数据质量检查...")
self._check_data_quality(data, feature_cols, verbose=verbose)
# Step 5: 数据划分
if verbose:
print(f"\n[5/6] 数据划分...")
split_data = self._split_data(
data, date_range, feature_cols, label_name, verbose=verbose
)
# Step 6: 数据预处理
if verbose:
print(f"\n[6/6] 数据预处理...")
split_data = self._preprocess(split_data, feature_cols, verbose=verbose)
if verbose:
print("\n" + "=" * 80)
print("数据流水线完成")
print("=" * 80)
return split_data
def _check_data_quality(
self,
data: pl.DataFrame,
feature_cols: List[str],
verbose: bool = True,
) -> None:
"""检查数据质量
Args:
data: 数据框
feature_cols: 特征列名列表
verbose: 是否打印信息
"""
# 检查缺失值
null_counts = {}
for col in feature_cols[:10]: # 只检查前10个特征
null_count = data[col].null_count()
if null_count > 0:
null_counts[col] = null_count
if null_counts and verbose:
print(f" [警告] 发现缺失值仅显示前10个特征:")
for col, count in list(null_counts.items())[:5]:
pct = count / len(data) * 100
print(f" {col}: {count} ({pct:.2f}%)")
def _split_data(
self,
data: pl.DataFrame,
date_range: Dict[str, Tuple[str, str]],
feature_cols: List[str],
label_name: str,
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""划分数据集
Args:
data: 完整数据
date_range: 日期范围字典
feature_cols: 特征列名
label_name: Label 列名
verbose: 是否打印信息
Returns:
划分后的数据字典
"""
result = {}
for split_name, (start, end) in date_range.items():
mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
split_df = data.filter(mask)
result[split_name] = {
"X": split_df.select(feature_cols),
"y": split_df[label_name],
"raw_data": split_df,
"feature_cols": feature_cols,
}
if verbose:
print(f" {split_name}: {len(split_df)} 条记录")
return result
def _preprocess(
self,
split_data: Dict[str, Dict[str, Any]],
feature_cols: List[str],
verbose: bool = True,
) -> Dict[str, Dict[str, Any]]:
"""预处理数据
训练集使用 fit_transform验证集和测试集使用 transform
Args:
split_data: 划分后的数据字典
feature_cols: 特征列名列表
verbose: 是否打印信息
Returns:
预处理后的数据字典
"""
if not self.processor_configs:
return split_data
self.fitted_processors = []
# 实例化 processors传入 feature_cols
processors = []
for proc_class, proc_kwargs in self.processor_configs:
proc_kwargs_with_cols = {**proc_kwargs, "feature_cols": feature_cols}
processors.append(proc_class(**proc_kwargs_with_cols))
# 训练集fit_transform
if verbose:
print(f" 训练集预处理fit_transform...")
train_data = split_data["train"]["raw_data"]
for processor in processors:
train_data = processor.fit_transform(train_data)
self.fitted_processors.append(processor)
# 更新训练集
split_data["train"]["raw_data"] = train_data
split_data["train"]["X"] = train_data.select(feature_cols)
split_data["train"]["y"] = train_data[split_data["train"]["y"].name]
# 验证集和测试集transform
for split_name in ["val", "test"]:
if split_name in split_data:
if verbose:
print(f" {split_name}集预处理transform...")
split_df = split_data[split_name]["raw_data"]
for processor in self.fitted_processors:
split_df = processor.transform(split_df)
split_data[split_name]["raw_data"] = split_df
split_data[split_name]["X"] = split_df.select(feature_cols)
split_data[split_name]["y"] = split_df[split_data[split_name]["y"].name]
return split_data
def get_fitted_processors(self) -> List[BaseProcessor]:
"""获取已拟合的处理器列表
Returns:
已拟合的处理器列表用于模型保存
"""
return self.fitted_processors