"""Data pipeline.

End-to-end data processing flow:

1. Register factors and compute the raw data
2. Apply class-style filters (e.g. STFilter)
3. Stock-pool selection (user-supplied function)
4. Data quality check
5. Split the data into the named date ranges (train/val/test, ...)
6. Preprocess (fit_transform on train, transform on every other split)
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import polars as pl
import numpy as np

from src.factors import FactorEngine
from src.training.factor_manager import FactorManager
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager


class DataPipeline:
    """Data pipeline.

    Runs the full data-processing flow and returns a dict keyed by split name,
    where each value is ``{"X", "y", "raw_data", "feature_cols"}``.

    Attributes:
        factor_manager: Factor manager whose factors are registered on the engine.
        filters: Class-style filters; each must expose ``filter(df) -> df``
            (e.g. STFilter).
        stock_pool_filter_func: Function-style stock-pool selector, handed to
            ``StockPoolManager``.
        processor_configs: Processor configs as ``(ProcessorClass, kwargs)`` pairs.
        stock_pool_required_columns: Extra columns the stock-pool selector needs.
        fitted_processors: Processors fitted on the train split (populated by
            ``prepare_data`` when processors are configured).
    """

    def __init__(
        self,
        factor_manager: FactorManager,
        processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
        filters: Optional[List[Any]] = None,
        stock_pool_filter_func: Optional[Callable] = None,
        stock_pool_required_columns: Optional[List[str]] = None,
    ):
        """Initialise the pipeline.

        Args:
            factor_manager: FactorManager instance.
            processor_configs: List of ``(ProcessorClass, kwargs)``; each class is
                instantiated with ``kwargs`` plus ``feature_cols``. Example:
                ``[(NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99})]``.
            filters: Class-style filters (e.g. ``[STFilter]``).
            stock_pool_filter_func: Function-style stock-pool selector.
            stock_pool_required_columns: Extra columns needed by the selector.
        """
        self.factor_manager = factor_manager
        self.processor_configs = processor_configs or []
        self.filters = filters or []
        self.stock_pool_filter_func = stock_pool_filter_func
        self.stock_pool_required_columns = stock_pool_required_columns or []
        self.fitted_processors: List[BaseProcessor] = []

    def prepare_data(
        self,
        engine: FactorEngine,
        date_range: Dict[str, Tuple[str, str]],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Run the full data flow.

        Steps:
            1. Register factors and compute data over the union of all split ranges
            2. Apply class-style filters (STFilter, ...)
            3. Apply the function-style stock-pool selection
            4. Data quality check
            5. Split into the named date ranges
            6. Preprocess (fit on train, transform the other splits)

        Args:
            engine: FactorEngine instance.
            date_range: Mapping ``split_name -> (start, end)``, both bounds
                inclusive, e.g. ``{"train": (s, e), "val": ..., "test": ...}``.
                Any split names are accepted; a ``"train"`` split is required
                only when processors are configured.
            label_name: Name of the label column.
            verbose: Whether to print progress information.

        Returns:
            Dict keyed by split name, each value holding ``X``, ``y``,
            ``raw_data`` and ``feature_cols``.

        Raises:
            ValueError: If ``date_range`` is empty.
            KeyError: If processors are configured but there is no ``"train"``
                split to fit them on.
        """
        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线")
            print("=" * 80)

        # Step 1: register factors and compute the raw data.
        if verbose:
            print("\n[1/6] 注册因子并准备数据...")
        feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)

        # A single engine query covering every split; rows are assigned to
        # splits later by trade_date.
        all_start, all_end = self._overall_date_bounds(date_range)
        data = engine.compute(
            factor_names=feature_cols + [label_name],
            start_date=all_start,
            end_date=all_end,
        )
        if verbose:
            print(f" 原始数据规模: {data.shape}")
            print(f" 特征数: {len(feature_cols)}")

        # Step 2: class-style filters.
        data = self._apply_filters(data, verbose=verbose)

        # Step 3: function-style stock-pool selection.
        data = self._apply_stock_pool(data, engine, verbose=verbose)

        # Step 4: data quality check (report only).
        if verbose:
            print("\n[4/6] 数据质量检查...")
        self._check_data_quality(data, feature_cols, verbose=verbose)

        # Step 5: split by date range.
        if verbose:
            print("\n[5/6] 数据划分...")
        split_data = self._split_data(
            data, date_range, feature_cols, label_name, verbose=verbose
        )

        # Step 6: preprocessing.
        if verbose:
            print("\n[6/6] 数据预处理...")
        split_data = self._preprocess(split_data, feature_cols, verbose=verbose)

        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线完成")
            print("=" * 80)

        return split_data

    @staticmethod
    def _overall_date_bounds(
        date_range: Dict[str, Tuple[str, str]],
    ) -> Tuple[str, str]:
        """Return ``(earliest start, latest end)`` across every split in ``date_range``."""
        if not date_range:
            raise ValueError("date_range 不能为空")
        starts = [start for start, _ in date_range.values()]
        ends = [end for _, end in date_range.values()]
        return min(starts), max(ends)

    def _apply_filters(self, data: pl.DataFrame, verbose: bool = True) -> pl.DataFrame:
        """Apply each class-style filter in order (step 2); no-op when none are configured."""
        if not self.filters:
            return data

        if verbose:
            print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")
        for filter_obj in self.filters:
            data_before = len(data)
            data = filter_obj.filter(data)
            data_after = len(data)
            if verbose:
                print(f" {filter_obj.__class__.__name__}:")
                print(f" 过滤前: {data_before}, 过滤后: {data_after}")
                print(f" 删除: {data_before - data_after}")
        return data

    def _apply_stock_pool(
        self,
        data: pl.DataFrame,
        engine: FactorEngine,
        verbose: bool = True,
    ) -> pl.DataFrame:
        """Apply the function-style stock-pool selection (step 3); no-op when unset."""
        if not self.stock_pool_filter_func:
            return data

        if verbose:
            print("\n[3/6] 股票池筛选...")
        data_before = len(data)
        pool_manager = StockPoolManager(
            filter_func=self.stock_pool_filter_func,
            required_columns=self.stock_pool_required_columns,
            data_router=engine.router,
        )
        data = pool_manager.filter_and_select_daily(data)
        data_after = len(data)
        if verbose:
            print(f" 筛选前: {data_before}, 筛选后: {data_after}")
            print(f" 删除: {data_before - data_after}")
        return data

    def _check_data_quality(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        verbose: bool = True,
        max_check_cols: int = 10,
        max_report: int = 5,
    ) -> None:
        """Warn about missing values in the leading feature columns.

        Report-only: nothing is returned and ``data`` is not modified.

        Args:
            data: Data frame to inspect.
            feature_cols: Feature column names; only the first
                ``max_check_cols`` are inspected.
            verbose: Whether to print; when False the check is skipped entirely.
            max_check_cols: Number of leading feature columns to inspect.
            max_report: Maximum number of offending columns to print.
        """
        if not verbose:
            # The findings are only ever printed, so there is nothing to do.
            return

        # NOTE(review): polars `null_count` counts nulls only, not float NaN —
        # confirm whether NaN values should be flagged here as well.
        null_counts = {}
        for col in feature_cols[:max_check_cols]:
            null_count = data[col].null_count()
            if null_count > 0:
                null_counts[col] = null_count

        if not null_counts:
            return

        print(
            f" [警告] 发现缺失值(仅检查前{max_check_cols}个特征, 最多显示{max_report}个):"
        )
        for col, count in list(null_counts.items())[:max_report]:
            # len(data) > 0 here because some column has a positive null count.
            pct = count / len(data) * 100
            print(f" {col}: {count} ({pct:.2f}%)")

    def _split_data(
        self,
        data: pl.DataFrame,
        date_range: Dict[str, Tuple[str, str]],
        feature_cols: List[str],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Split the data by ``trade_date`` into the named ranges.

        Args:
            data: Full data frame containing a ``trade_date`` column.
            date_range: Mapping ``split_name -> (start, end)``, both inclusive.
            feature_cols: Feature column names.
            label_name: Label column name.
            verbose: Whether to print per-split row counts.

        Returns:
            Dict keyed by split name with ``X``, ``y``, ``raw_data`` and
            ``feature_cols`` for each split.
        """
        result = {}
        for split_name, (start, end) in date_range.items():
            mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
            split_df = data.filter(mask)
            result[split_name] = {
                "X": split_df.select(feature_cols),
                "y": split_df[label_name],
                "raw_data": split_df,
                "feature_cols": feature_cols,
            }
            if verbose:
                print(f" {split_name}: {len(split_df)} 条记录")
        return result

    @staticmethod
    def _refresh_split(
        split: Dict[str, Any],
        processed_df: pl.DataFrame,
        feature_cols: List[str],
    ) -> None:
        """Store a processed frame into a split dict and rebuild its ``X``/``y`` views in place."""
        # Read the label name from the previous ``y`` before it is replaced.
        label_name = split["y"].name
        split["raw_data"] = processed_df
        split["X"] = processed_df.select(feature_cols)
        split["y"] = processed_df[label_name]

    def _preprocess(
        self,
        split_data: Dict[str, Dict[str, Any]],
        feature_cols: List[str],
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Preprocess every split.

        Processors are fitted on the train split with ``fit_transform``; every
        other split (val, test, or any additional split name) is processed with
        ``transform`` using those train-fitted processors.

        Args:
            split_data: Split dict produced by ``_split_data``; updated in place.
            feature_cols: Feature column names.
            verbose: Whether to print progress information.

        Returns:
            The same ``split_data`` dict, with processed frames.

        Raises:
            KeyError: If processors are configured but ``split_data`` has no
                ``"train"`` split.
        """
        if not self.processor_configs:
            return split_data
        if "train" not in split_data:
            raise KeyError("预处理需要 'train' 数据集用于拟合处理器")

        self.fitted_processors = []

        # Every processor also receives the feature columns it operates on.
        processors = [
            proc_class(**{**proc_kwargs, "feature_cols": feature_cols})
            for proc_class, proc_kwargs in self.processor_configs
        ]

        # Train split: fit_transform, chaining each processor's output into the next.
        if verbose:
            print(" 训练集预处理(fit_transform)...")
        train_df = split_data["train"]["raw_data"]
        for processor in processors:
            train_df = processor.fit_transform(train_df)
            self.fitted_processors.append(processor)
        self._refresh_split(split_data["train"], train_df, feature_cols)

        # All other splits reuse the train-fitted processors (transform only),
        # so no processor is ever fitted on evaluation data.
        for split_name, split in split_data.items():
            if split_name == "train":
                continue
            if verbose:
                print(f" {split_name}集预处理(transform)...")
            split_df = split["raw_data"]
            for processor in self.fitted_processors:
                split_df = processor.transform(split_df)
            self._refresh_split(split, split_df, feature_cols)

        return split_data

    def get_fitted_processors(self) -> List[BaseProcessor]:
        """Return the processors fitted on the train split (e.g. for saving with the model)."""
        return self.fitted_processors