310 lines
10 KiB
Python
310 lines
10 KiB
Python
|
|
"""数据流水线
|
|||
|
|
|
|||
|
|
完整的数据处理流程:
|
|||
|
|
1. 因子注册和数据准备
|
|||
|
|
2. 应用过滤器(STFilter 等)
|
|||
|
|
3. 股票池筛选(自定义函数)
|
|||
|
|
4. 数据质量检查
|
|||
|
|
5. 数据划分(train/val/test)
|
|||
|
|
6. 数据预处理(fit_transform/transform)
|
|||
|
|
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import numpy as np
import polars as pl

from src.factors import FactorEngine
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.factor_manager import FactorManager


|
class DataPipeline:
    """Data pipeline.

    Runs the complete data-processing flow and returns a standardized
    data dictionary keyed by split name ("train"/"val"/"test").

    Attributes:
        factor_manager: Factor manager used to register factors to the engine.
        filters: Class-style filter objects (e.g. STFilter); each is assumed
            to expose a ``filter(df) -> df`` method — confirm against filter API.
        stock_pool_filter_func: Function-style stock-pool filter.
        processor_configs: Processor configurations (class + kwargs).
        stock_pool_required_columns: Extra columns the stock-pool filter needs.
        fitted_processors: Fitted processors (populated after training).
    """

    def __init__(
        self,
        factor_manager: FactorManager,
        processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
        filters: Optional[List[Any]] = None,
        stock_pool_filter_func: Optional[Callable] = None,
        stock_pool_required_columns: Optional[List[str]] = None,
    ):
        """Initialize the data pipeline.

        Args:
            factor_manager: FactorManager instance.
            processor_configs: Processor configuration list; each element is
                ``(ProcessorClass, kwargs)``, e.g.
                ``[(NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99})]``
            filters: Class-style filter list (e.g. ``[STFilter]``).
            stock_pool_filter_func: Function-style stock-pool filter.
            stock_pool_required_columns: Extra columns required by the
                stock-pool filter.
        """
        self.factor_manager = factor_manager
        # Normalize optional containers so downstream code can iterate freely.
        self.processor_configs = processor_configs or []
        self.filters = filters or []
        self.stock_pool_filter_func = stock_pool_filter_func
        self.stock_pool_required_columns = stock_pool_required_columns or []
        # Filled by _preprocess() after fitting on the training split;
        # exposed via get_fitted_processors() for model persistence.
        self.fitted_processors: List[BaseProcessor] = []

    def prepare_data(
        self,
        engine: FactorEngine,
        date_range: Dict[str, Tuple[str, str]],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Execute the full data flow.

        Steps:
            1. Register factors and prepare data
            2. Apply class-style filters (STFilter)
            3. Apply stock-pool filtering (function-style)
            4. Data quality check
            5. Data split
            6. Data preprocessing

        Args:
            engine: FactorEngine instance.
            date_range: Date-range dict, e.g.
                ``{"train": (start, end), "val": ..., "test": ...}``.
                Must contain "train", "val" and "test" keys.
            label_name: Label column name.
            verbose: Whether to print progress information.

        Returns:
            Standardized data dictionary keyed by split name.
        """
        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线")
            print("=" * 80)

        # Step 1: register factors and prepare data
        if verbose:
            print("\n[1/6] 注册因子并准备数据...")

        feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)

        # Compute one date span covering all three splits so the engine
        # computes factors in a single pass; splits are carved out later.
        all_start = min(
            date_range["train"][0], date_range["val"][0], date_range["test"][0]
        )
        all_end = max(
            date_range["train"][1], date_range["val"][1], date_range["test"][1]
        )

        # Prepare data (features + label together).
        data = engine.compute(
            factor_names=feature_cols + [label_name],
            start_date=all_start,
            end_date=all_end,
        )

        if verbose:
            print(f" 原始数据规模: {data.shape}")
            print(f" 特征数: {len(feature_cols)}")

        # Step 2: apply class-style filters (e.g. STFilter)
        if self.filters:
            if verbose:
                print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")

            for filter_obj in self.filters:
                data_before = len(data)
                data = filter_obj.filter(data)
                data_after = len(data)

                if verbose:
                    print(f" {filter_obj.__class__.__name__}:")
                    print(f" 过滤前: {data_before}, 过滤后: {data_after}")
                    print(f" 删除: {data_before - data_after}")

        # Step 3: stock-pool filtering (function-style)
        if self.stock_pool_filter_func:
            if verbose:
                # was an f-string with no placeholders (ruff F541)
                print("\n[3/6] 股票池筛选...")

            data_before = len(data)

            # Delegate the per-day selection to StockPoolManager.
            pool_manager = StockPoolManager(
                filter_func=self.stock_pool_filter_func,
                required_columns=self.stock_pool_required_columns,
                data_router=engine.router,
            )

            data = pool_manager.filter_and_select_daily(data)
            data_after = len(data)

            if verbose:
                print(f" 筛选前: {data_before}, 筛选后: {data_after}")
                print(f" 删除: {data_before - data_after}")

        # Step 4: data quality check (report-only; does not mutate data)
        if verbose:
            print("\n[4/6] 数据质量检查...")

        self._check_data_quality(data, feature_cols, verbose=verbose)

        # Step 5: split into train/val/test by date range
        if verbose:
            print("\n[5/6] 数据划分...")

        split_data = self._split_data(
            data, date_range, feature_cols, label_name, verbose=verbose
        )

        # Step 6: preprocessing (fit on train, transform val/test)
        if verbose:
            print("\n[6/6] 数据预处理...")

        split_data = self._preprocess(split_data, feature_cols, verbose=verbose)

        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线完成")
            print("=" * 80)

        return split_data

    def _check_data_quality(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        verbose: bool = True,
    ) -> None:
        """Check data quality (null scan over the first 10 features).

        Report-only: prints a warning when nulls are found, never mutates
        or raises.

        Args:
            data: Data frame to inspect.
            feature_cols: Feature column names.
            verbose: Whether to print findings.
        """
        # Only scan the first 10 features to keep the check cheap.
        null_counts = {}
        for col in feature_cols[:10]:
            null_count = data[col].null_count()
            if null_count > 0:
                null_counts[col] = null_count

        if null_counts and verbose:
            print(" [警告] 发现缺失值(仅显示前10个特征):")
            # Fix: previously truncated to the first 5 offending columns,
            # contradicting the "first 10 features" message above; print
            # every offending column found (at most 10 were scanned).
            for col, count in null_counts.items():
                pct = count / len(data) * 100
                print(f" {col}: {count} ({pct:.2f}%)")

    def _split_data(
        self,
        data: pl.DataFrame,
        date_range: Dict[str, Tuple[str, str]],
        feature_cols: List[str],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Split the dataset by date range.

        Args:
            data: Full dataset.
            date_range: Date-range dict mapping split name to (start, end).
            feature_cols: Feature column names.
            label_name: Label column name.
            verbose: Whether to print split sizes.

        Returns:
            Dict mapping split name to {"X", "y", "raw_data", "feature_cols"}.
        """
        result = {}

        for split_name, (start, end) in date_range.items():
            # Inclusive date-range filter; pl.col expressions let polars
            # evaluate the predicate lazily instead of building eager
            # boolean Series masks.
            split_df = data.filter(
                (pl.col("trade_date") >= start) & (pl.col("trade_date") <= end)
            )

            result[split_name] = {
                "X": split_df.select(feature_cols),
                "y": split_df[label_name],
                "raw_data": split_df,
                "feature_cols": feature_cols,
            }

            if verbose:
                print(f" {split_name}: {len(split_df)} 条记录")

        return result

    def _preprocess(
        self,
        split_data: Dict[str, Dict[str, Any]],
        feature_cols: List[str],
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Preprocess the split data.

        The training split uses ``fit_transform``; the validation and test
        splits reuse the fitted processors via ``transform`` to avoid data
        leakage. ``split_data`` is updated in place and also returned.

        Args:
            split_data: Split data dict (must contain a "train" entry when
                any processors are configured).
            feature_cols: Feature column names.
            verbose: Whether to print progress information.

        Returns:
            Preprocessed data dictionary.
        """
        if not self.processor_configs:
            return split_data

        # Reset so repeated calls don't accumulate stale processors.
        self.fitted_processors = []

        # Instantiate processors, injecting feature_cols into each config.
        processors = []
        for proc_class, proc_kwargs in self.processor_configs:
            proc_kwargs_with_cols = {**proc_kwargs, "feature_cols": feature_cols}
            processors.append(proc_class(**proc_kwargs_with_cols))

        # Training split: fit_transform
        if verbose:
            print(" 训练集预处理(fit_transform)...")

        train_data = split_data["train"]["raw_data"]
        for processor in processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)

        # Refresh the training entry from the transformed frame.
        split_data["train"]["raw_data"] = train_data
        split_data["train"]["X"] = train_data.select(feature_cols)
        split_data["train"]["y"] = train_data[split_data["train"]["y"].name]

        # Validation and test splits: transform only (no refitting).
        for split_name in ["val", "test"]:
            if split_name in split_data:
                if verbose:
                    print(f" {split_name}集预处理(transform)...")

                split_df = split_data[split_name]["raw_data"]
                for processor in self.fitted_processors:
                    split_df = processor.transform(split_df)

                split_data[split_name]["raw_data"] = split_df
                split_data[split_name]["X"] = split_df.select(feature_cols)
                split_data[split_name]["y"] = split_df[split_data[split_name]["y"].name]

        return split_data

    def get_fitted_processors(self) -> List[BaseProcessor]:
        """Return the fitted processors.

        Returns:
            Fitted processor list (for model persistence); empty until
            ``prepare_data``/``_preprocess`` has run.
        """
        return self.fitted_processors