# ProStock 模型训练框架设计文档

## 1. 设计目标与原则

### 1.1 核心目标

- **组件化**:每个阶段(数据获取、处理、训练、评估)都是独立组件
- **低耦合**:组件间通过标准接口交互,不依赖具体实现
- **插件式**:新功能通过插件注册,无需修改核心代码
- **阶段感知**:数据处理区分训练阶段和测试阶段,防止数据泄露
- **多模型支持**:统一接口支持 LightGBM、CatBoost 等多种模型
- **多任务支持**:分类、回归、排序三种任务类型

### 1.2 设计原则

| 原则 | 说明 |
|------|------|
| **单一职责** | 每个组件只做一件事,做好一件事 |
| **开闭原则** | 对扩展开放(插件),对修改封闭(核心) |
| **依赖倒置** | 依赖抽象接口,而非具体实现 |
| **显式优于隐式** | 阶段标记、处理逻辑必须显式声明 |
| **配置驱动** | 通过配置文件或代码配置定义流程,减少硬编码 |

---

## 2. 整体架构

### 2.1 架构概览

```
┌─────────────────────────────────────────────────────────────┐
│                  ML Pipeline Orchestrator                   │
│               (流水线编排器 - 配置驱动执行)                 │
└─────────────────────────────────────────────────────────────┘
         │                    │                    │
         ▼                    ▼                    ▼
┌───────────────┐    ┌───────────────┐    ┌───────────────┐
│  Data Source  │    │  Data Source  │    │  Data Source  │
│  (因子数据)   │    │  (行情数据)   │    │  (标签数据)   │
└───────┬───────┘    └───────┬───────┘    └───────┬───────┘
        └────────────────────┼────────────────────┘
                             ▼
┌─────────────────────────────────────────────────────────────┐
│                Feature Store (特征存储层)                   │
│   FactorLoader │ LabelLoader │ DataMerger │ CacheMgr        │
└─────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────┐
│              Processing Pipeline (处理流水线)               │
│                                                             │
│  Processor(阶段:ALL) -> Processor(阶段:TRAIN)               │
│                      -> Processor(阶段:TEST) -> ...         │
│                                                             │
│  处理器类型:                                               │
│  - FeatureEncoder: 特征编码(类别编码、数值缩放等)         │
│  - FeatureSelector: 特征选择(相关性过滤、重要性筛选等)    │
│  - OutlierHandler: 异常值处理                               │
│  - MissingValueHandler: 缺失值处理                          │
│  - CustomTransformer: 自定义转换器                          │
└─────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────┐
│               Train/Test Split (数据划分)                   │
│                                                             │
│  支持多种划分策略:                                         │
│  - TimeSeriesSplit: 时间序列划分(防止未来泄露)            │
│  - PurgedKFold: 清除重叠样本的K折交叉验证                   │
│  - EmbargoSplit: embargo 延迟验证                           │
│  - CustomSplit: 自定义划分策略                              │
└─────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────┐
│               Model Training (模型训练层)                   │
│                                                             │
│  Model Registry:                                            │
│   LightGBM │ CatBoost │ XGBoost │ Custom Model ...          │
│                                                             │
│  任务类型:                                                 │
│  - Classification: 分类任务(上涨/下跌预测)                │
│  - Regression: 回归任务(收益率预测)                       │
│  - Ranking: 排序任务(股票排序/选股)                       │
└─────────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────────┐
│                   Evaluation (评估层)                       │
│                                                             │
│  Metric(IC/IR) │ Metric(Sharpe) │ Metric(Accuracy)          │
│  Analyzer(回测)                                             │
│                                                             │
│  ResultStore(模型存储) │ Report(报告生成)                   │
│  Visualizer(可视化)                                         │
└─────────────────────────────────────────────────────────────┘
```

### 2.2 数据流向图

```
因子DataFrame (Polars)
        │
        ▼
┌──────────────────────┐
│    Feature Store     │  1. 加载并合并因子、标签、辅助数据
│  - 列选择            │  2. 支持按日期/股票过滤
│  - 数据对齐          │  3. 缓存机制避免重复加载
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│ Processing Pipeline  │  顺序执行多个处理器
│                      │  每个处理器标记适用阶段 (ALL/TRAIN/TEST)
│  for processor in pipeline:
│      if processor.stage in [current_stage, ALL]:
│          data = processor.transform(data)
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│    Data Splitter     │  时间序列感知的划分策略
│  - X_train, y_train  │  防止未来泄露
│  - X_test, y_test    │
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│    Model Training    │  统一接口,支持多种模型
│  - fit(X_train)      │  任务类型: classification/regression/ranking
│  - predict(X_test)   │
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│      Evaluation      │  多维度评估
│  - 预测指标          │  - IC/IR
│  - 回测指标          │  - 分组收益
│  - 可视化            │  - 累计收益曲线
└──────────────────────┘
```

---

## 3. 核心组件设计

### 3.1 基础抽象类

#### 3.1.1 PipelineStage (流水线阶段枚举)

```python
from enum import Enum, auto

class PipelineStage(Enum):
    """流水线阶段标记"""
    ALL = auto()         # 适用于所有阶段
    TRAIN = auto()       # 仅训练阶段
    TEST = auto()        # 仅测试阶段
    VALIDATION = auto()  # 仅验证阶段
```

#### 3.1.2 BaseProcessor (处理器基类)

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import polars as pl

class BaseProcessor(ABC):
    """数据处理器基类

    所有数据处理器必须继承此类。
    关键特性:通过 stage 属性控制处理器在哪些阶段生效。

    示例:
        >>> class StandardScaler(BaseProcessor):
        ...     stage = PipelineStage.ALL  # 训练和测试都使用
        ...
        ...     def fit(self, data: pl.DataFrame) -> None:
        ...         self.mean = data[self.columns].mean()
        ...         self.std = data[self.columns].std()
        ...
        ...     def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        ...         return (data - self.mean) / self.std
    """

    # 子类必须定义适用阶段
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[list] = None, **params):
        """初始化处理器

        Args:
            columns: 要处理的列,None表示所有数值列
            **params: 处理器特定参数
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """在训练数据上学习参数

        此方法只在训练阶段调用一次。
        学习到的参数存储在 self._fitted_params 中。

        Args:
            data: 训练数据

        Returns:
            self (支持链式调用)
        """
        pass

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """转换数据

        在训练和测试阶段都会被调用。
        使用 fit() 阶段学习到的参数进行转换。

        Args:
            data: 输入数据

        Returns:
            转换后的数据
        """
        pass

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """先fit再transform的便捷方法"""
        return self.fit(data).transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """获取学习到的参数(用于保存/加载)"""
        return self._fitted_params.copy()

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """设置学习到的参数(用于从checkpoint恢复)"""
        self._fitted_params = params.copy()
        self._is_fitted = True
        return self
```

#### 3.1.3 BaseModel (模型基类)

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional

import numpy as np
import polars as pl

TaskType = Literal["classification", "regression", "ranking"]

class BaseModel(ABC):
    """机器学习模型基类

    统一接口支持多种模型(LightGBM, CatBoost, XGBoost等)
    和多种任务类型(分类、回归、排序)。

    示例:
        >>> model = LightGBMModel(
        ...     task_type="classification",
        ...     params={"n_estimators": 100}
        ... )
        >>> model.fit(X_train, y_train)
        >>> predictions = model.predict(X_test)
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None
    ):
        """初始化模型

        Args:
            task_type: 任务类型 - "classification", "regression", "ranking"
            params: 模型特定参数
            name: 模型名称(用于日志和报告)
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "BaseModel":
        """训练模型

        Args:
            X: 特征数据
            y: 目标变量
            X_val: 验证集特征(可选)
            y_val: 验证集目标(可选)
            **fit_params: 额外的fit参数

        Returns:
            self (支持链式调用)
        """
        pass

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """预测

        Args:
            X: 特征数据

        Returns:
            预测结果数组
            - classification: 类别标签或概率
            - regression: 连续值
            - ranking: 排序分数
        """
        pass

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """预测概率(仅分类任务)

        Args:
            X: 特征数据

        Returns:
            类别概率数组 [n_samples, n_classes]
        """
        raise NotImplementedError("predict_proba only available for classification tasks")

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """获取特征重要性(如果模型支持)

        Returns:
            DataFrame[feature, importance] 或 None
        """
        return None

    def save(self, path: str) -> None:
        """保存模型到文件"""
        import pickle
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """从文件加载模型"""
        import pickle
        with open(path, 'rb') as f:
            return pickle.load(f)
```

#### 3.1.4 BaseSplitter (数据划分基类)

```python
from abc import ABC, abstractmethod
from typing import Iterator, List, Tuple

import polars as pl

class BaseSplitter(ABC):
    """数据划分策略基类

    针对时间序列数据的特殊划分策略,防止未来泄露。

    示例:
        >>> splitter = TimeSeriesSplit(n_splits=5, gap=5)
        >>> for train_idx, test_idx in splitter.split(data):
        ...     X_train, X_test = X[train_idx], X[test_idx]
    """

    @abstractmethod
    def split(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """生成训练/测试索引

        Args:
            data: 完整数据集
            date_col: 日期列名

        Yields:
            (train_indices, test_indices) 元组
        """
        pass

    @abstractmethod
    def get_split_dates(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """获取划分日期范围

        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass
```

---

### 3.2 核心组件

#### 3.2.1 FeatureStore (特征存储)

```python
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import polars as pl

class FeatureStore:
    """特征存储管理器

    负责加载、合并、缓存因子数据。
    支持从多个数据源(因子、标签、行情)加载并合并。
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}

    def load_factors(
        self,
        factor_names: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        stock_codes: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """加载因子数据

        Args:
            factor_names: 因子名称列表
            start_date: 开始日期 YYYYMMDD
            end_date: 结束日期 YYYYMMDD
            stock_codes: 股票代码列表(可选)

        Returns:
            DataFrame[trade_date, ts_code, factor1, factor2, ...]
        """
        pass

    def load_labels(
        self,
        label_name: str,
        forward_period: int = 5,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> pl.DataFrame:
        """加载标签数据(未来收益)

        Args:
            label_name: 标签名称(如 "return", "rank")
            forward_period: 前瞻期(如5天后收益)
            start_date: 开始日期
            end_date: 结束日期

        Returns:
            DataFrame[trade_date, ts_code, label]
        """
        pass

    def build_dataset(
        self,
        factor_names: List[str],
        label_config: Dict,
        date_range: Tuple[str, str],
        stock_codes: Optional[List[str]] = None,
        additional_cols: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """构建完整数据集

        合并因子、标签、辅助列,并对齐数据。

        Args:
            factor_names: 因子列表
            label_config: 标签配置 {"name": str, "forward_period": int}
            date_range: (start_date, end_date)
            stock_codes: 限定股票列表
            additional_cols: 额外列(如 industry, market_cap)

        Returns:
            DataFrame[trade_date, ts_code, factor_cols..., label]
        """
        pass
```

#### 3.2.2 ProcessingPipeline (处理流水线)

```python
from typing import Dict, List

import polars as pl

class ProcessingPipeline:
    """数据处理流水线

    按顺序执行多个处理器,自动处理阶段标记。
    关键特性:在测试阶段使用训练阶段学习到的参数。
    """

    def __init__(self, processors: List[BaseProcessor]):
        """初始化流水线

        Args:
            processors: 处理器列表(按执行顺序)
        """
        self.processors = processors
        self._fitted_processors: Dict[int, BaseProcessor] = {}

    def fit_transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """在训练数据上fit所有处理器并transform

        Args:
            data: 训练数据
            stage: 当前阶段标记

        Returns:
            处理后的数据
        """
        result = data
        for i, processor in enumerate(self.processors):
            # 检查处理器是否适用于当前阶段
            if processor.stage in [PipelineStage.ALL, stage]:
                # fit并transform
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN:
                # 即使不适用于TRAIN阶段,也要fit(为TEST阶段准备)
                if processor.stage == PipelineStage.TEST:
                    processor.fit(result)
                    self._fitted_processors[i] = processor
        return result

    def transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """在测试数据上应用已fit的处理器

        使用训练阶段学习到的参数,防止数据泄露。

        Args:
            data: 测试数据
            stage: 当前阶段标记

        Returns:
            处理后的数据
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                if i in self._fitted_processors:
                    # 使用已fit的处理器
                    result = self._fitted_processors[i].transform(result)
                else:
                    # 未fit的处理器(ALL阶段但train时没执行到)
                    result = processor.transform(result)
        return result

    def save_processors(self, path: str) -> None:
        """保存所有已fit的处理器状态"""
        import pickle
        with open(path, 'wb') as f:
            pickle.dump(self._fitted_processors, f)

    def load_processors(self, path: str) -> None:
        """加载处理器状态"""
        import pickle
        with open(path, 'rb') as f:
            self._fitted_processors = pickle.load(f)
```

---

## 4. 插件系统

### 4.1 注册器模式

```python
from typing import Dict, List, Optional, Type

class PluginRegistry:
    """插件注册中心

    提供装饰器方式注册处理器、模型、划分策略等组件。
    实现真正的插件式架构 - 新功能只需注册即可使用。
    """

    _processors: Dict[str, Type[BaseProcessor]] = {}
    _models: Dict[str, Type[BaseModel]] = {}
    _splitters: Dict[str, Type[BaseSplitter]] = {}
    _metrics: Dict[str, Type["BaseMetric"]] = {}

    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """注册处理器装饰器

        示例:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass

            >>> # 使用
            >>> scaler = PluginRegistry.get_processor("standard_scaler")()
        """
        def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
            key = name or processor_class.__name__
            cls._processors[key] = processor_class
            processor_class._registry_name = key
            return processor_class
        return decorator

    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """注册模型装饰器"""
        def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
            key = name or model_class.__name__
            cls._models[key] = model_class
            model_class._registry_name = key
            return model_class
        return decorator

    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """注册划分策略装饰器"""
        def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
            key = name or splitter_class.__name__
            cls._splitters[key] = splitter_class
            return splitter_class
        return decorator

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """获取处理器类"""
        if name not in cls._processors:
            raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}")
        return cls._processors[name]

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """获取模型类"""
        if name not in cls._models:
            raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}")
        return cls._models[name]

    @classmethod
    def get_splitter(cls, name: str) -> Type[BaseSplitter]:
        """获取划分策略类"""
        if name not in cls._splitters:
            raise KeyError(f"Splitter '{name}' not found. Available: {list(cls._splitters.keys())}")
        return cls._splitters[name]

    @classmethod
    def list_processors(cls) -> List[str]:
        """列出所有可用处理器"""
        return list(cls._processors.keys())

    @classmethod
    def list_models(cls) -> List[str]:
        """列出所有可用模型"""
        return list(cls._models.keys())
```

### 4.2 内置插件

```python
# 数值列类型集合(供各处理器筛选数值列使用)
FLOAT_TYPES = (pl.Float32, pl.Float64)

# ========== 内置处理器 ==========

@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """标准缩放处理器 - Z-score标准化"""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "StandardScaler":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "mean": {c: data[c].mean() for c in cols},
            "std": {c: data[c].std() for c in cols},
            "columns": cols
        }
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            mean = self._fitted_params["mean"][col]
            std = self._fitted_params["std"][col]
            if std > 0:
                result = result.with_columns(
                    ((pl.col(col) - mean) / std).alias(col)
                )
        return result

@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """缩尾处理器 - 防止极端值影响"""
    stage = PipelineStage.ALL  # 分位数仅在训练数据上fit,训练/测试均应用

    def __init__(self, columns=None, lower=0.01, upper=0.99):
        super().__init__(columns)
        self.lower = lower
        self.upper = upper

    def fit(self, data: pl.DataFrame) -> "Winsorizer":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "lower": {c: data[c].quantile(self.lower) for c in cols},
            "upper": {c: data[c].quantile(self.upper) for c in cols},
            "columns": cols
        }
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            lower = self._fitted_params["lower"][col]
            upper = self._fitted_params["upper"][col]
            result = result.with_columns(
                pl.col(col).clip(lower, upper).alias(col)
            )
        return result

@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
    """行业/市值中性化处理器"""
    stage = PipelineStage.ALL

    def __init__(self, columns=None, group_col="industry", exclude_cols=None):
        super().__init__(columns)
        self.group_col = group_col
        self.exclude_cols = exclude_cols or []

    def fit(self, data: pl.DataFrame) -> "Neutralizer":
        # 中性化通常在每个截面独立进行,不需要全局fit
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        # 按日期分组,对每个截面进行中性化
        result = data
        for col in self.columns or []:
            if col in self.exclude_cols:
                continue
            # 分组去均值
            result = result.with_columns(
                (pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col]))
                .alias(col)
            )
        return result

@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
    """缺失值删除处理器"""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        cols = self.columns or data.columns
        return data.drop_nulls(subset=cols)

@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
    """缺失值填充处理器"""
    stage = PipelineStage.ALL  # 填充值仅在训练数据上fit,训练/测试均应用

    def __init__(self, columns=None, method="median"):
        super().__init__(columns)
        self.method = method

    def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        fill_values = {}
        for col in cols:
            if self.method == "median":
                fill_values[col] = data[col].median()
            elif self.method == "mean":
                fill_values[col] = data[col].mean()
            elif self.method == "zero":
                fill_values[col] = 0
        self._fitted_params = {"fill_values": fill_values, "columns": cols}
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col, val in self._fitted_params["fill_values"].items():
            result = result.with_columns(pl.col(col).fill_null(val).alias(col))
        return result

@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
    """排名转换处理器 - 转换为截面排名"""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "RankTransformer":
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            # 按日期分组计算排名
            result = result.with_columns(
                pl.col(col).rank().over("trade_date").alias(col)
            )
        return result

# ========== 内置模型 ==========

@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM模型包装器"""

    def __init__(self, task_type: TaskType, params: Optional[Dict] = None,
                 name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "LightGBMModel":
        import lightgbm as lgb

        # 转换数据格式
        X_arr = X.to_numpy()
        y_arr = y.to_numpy()

        # 构建数据集
        train_data = lgb.Dataset(X_arr, label=y_arr)
        valid_sets = [train_data]
        if X_val is not None and y_val is not None:
            valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
            valid_sets.append(valid_data)

        # 设置默认参数
        default_params = {
            "objective": self._get_objective(),
            "metric": self._get_metric(),
            "boosting_type": "gbdt",
            "num_leaves": 31,
            "learning_rate": 0.05,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 5,
            "verbose": -1
        }
        default_params.update(self.params)

        # 训练
        self._model = lgb.train(
            default_params,
            train_data,
            num_boost_round=fit_params.get("num_boost_round", 100),
            valid_sets=valid_sets,
            callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)]
                      if len(valid_sets) > 1 else []
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_numpy())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        probs = self.predict(X)
        if len(probs.shape) == 1:
            return np.vstack([1 - probs, probs]).T
        return probs

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        importance = self._model.feature_importance(importance_type="gain")
        return pl.DataFrame({
            "feature": self._model.feature_name(),
            "importance": importance
        }).sort("importance", descending=True)

    def _get_objective(self) -> str:
        if self.task_type == "classification":
            return "binary"
        elif self.task_type == "regression":
            return "regression"
        elif self.task_type == "ranking":
            return "lambdarank"
        return "regression"

    def _get_metric(self) -> str:
        if self.task_type == "classification":
            return "auc"
        elif self.task_type == "regression":
            return "rmse"
        elif self.task_type == "ranking":
            return "ndcg"
        return "rmse"

@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
    """CatBoost模型包装器"""

    def __init__(self, task_type: TaskType, params: Optional[Dict] = None,
                 name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "CatBoostModel":
        from catboost import CatBoostClassifier, CatBoostRegressor

        # 选择模型类型
        if self.task_type == "classification":
            model_class = CatBoostClassifier
            default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
        elif self.task_type == "regression":
            model_class = CatBoostRegressor
            default_params = {"loss_function": "RMSE"}
        else:  # ranking
            model_class = CatBoostRegressor
            default_params = {"loss_function": "QueryRMSE"}

        default_params.update(self.params)
        default_params["verbose"] = False
        self._model = model_class(**default_params)

        # 准备验证集
        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())

        # 训练
        self._model.fit(
            X.to_pandas(),
            y.to_pandas(),
            eval_set=eval_set,
            early_stopping_rounds=10,
            verbose=False
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_pandas())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        return self._model.predict_proba(X.to_pandas())

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        return pl.DataFrame({
            "feature": self._model.feature_names_,
            "importance": self._model.feature_importances_
        }).sort("importance", descending=True)

# ========== 内置划分策略 ==========

@PluginRegistry.register_splitter("time_series")
class TimeSeriesSplit(BaseSplitter):
    """时间序列划分 - 确保训练数据在测试数据之前"""

    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        # 计算每个split的测试集大小
        test_size = (n_dates - self.min_train_size) // self.n_splits

        for i in range(self.n_splits):
            # 训练集结束位置
            train_end_idx = self.min_train_size + i * test_size
            # 测试集开始位置(留gap防止泄露)
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size

            if test_end_idx > n_dates:
                break

            train_dates = dates[:train_end_idx]
            test_dates = dates[test_start_idx:test_end_idx]

            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)

            train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list()

            yield train_idx, test_idx

    def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits

        result = []
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size

            if test_end_idx > n_dates:
                break

            result.append((
                dates[0],
                dates[train_end_idx - 1],
                dates[test_start_idx],
                dates[test_end_idx - 1]
            ))
        return result

@PluginRegistry.register_splitter("walk_forward")
class WalkForwardSplit(BaseSplitter):
    """滚动前向验证 - 固定训练窗口随时间向前滚动"""

    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window

            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]

            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)

            train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list()

            yield train_idx, test_idx

            start_idx += self.test_window

    def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"):
        # 补全抽象方法:与 split 相同的滚动逻辑,仅返回日期边界
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        result = []
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window
            result.append((
                dates[train_start],
                dates[start_idx - 1],
                dates[test_start],
                dates[test_end - 1]
            ))
            start_idx += self.test_window
        return result
```

---

## 5. 使用示例

### 5.1 基础用法

```python
from src.models import (
    FeatureStore,
    ProcessingPipeline,
    PluginRegistry,
    PipelineStage,
    MLPipeline
)

# 1. 创建数据存储
store = FeatureStore(data_dir="data")

# 2. 构建数据集
dataset = store.build_dataset(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 3. 创建处理流水线
processors = [
    # 删除缺失值
    PluginRegistry.get_processor("dropna")(),
    # 异常值处理(分位数仅在训练数据上拟合)
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    # 中性化(行业和市值中性化)
    PluginRegistry.get_processor("neutralizer")(group_col="industry"),
    # 标准化(训练和测试都使用)
    PluginRegistry.get_processor("standard_scaler")(),
]
pipeline = ProcessingPipeline(processors)

# 4. 创建划分策略
splitter = PluginRegistry.get_splitter("time_series")(
    n_splits=5, gap=5, min_train_size=252
)

# 5. 创建模型
model = PluginRegistry.get_model("lightgbm")(
    task_type="regression",
    params={"n_estimators": 200, "learning_rate": 0.03}
)

# 6. 运行完整流程
ml_pipeline = MLPipeline(
    feature_store=store,
    processing_pipeline=pipeline,
    splitter=splitter,
    model=model
)

results = ml_pipeline.run(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 7. 查看结果
print(results.metrics)             # 各折的评估指标
print(results.feature_importance)  # 特征重要性
print(results.predictions)         # 预测结果
```

### 5.2 配置驱动用法(推荐)

```yaml
# config.yaml
experiment:
  name: "momentum_factor_regression"

data:
  factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"]
  label:
    name: "forward_return"
    forward_period: 5
  date_range: ["20200101", "20241231"]

processing:
  - name: "dropna"
    params: {}
    stage: "all"
  - name: "winsorizer"
    params:
      lower: 0.01
      upper: 0.99
    stage: "all"  # 分位数仅在训练数据上拟合
  - name: "neutralizer"
    params:
      group_col: "industry"
    stage: "all"
  - name: "standard_scaler"
    params: {}
    stage: "all"

splitting:
  strategy: "time_series"
  params:
    n_splits: 5
    gap: 5
    min_train_size: 252

model:
  name: "lightgbm"
  task_type: "regression"
  params:
    n_estimators: 200
    learning_rate: 0.03
    max_depth: 6

evaluation:
  metrics: ["ic", "rank_ic", "mse", "mae"]
  output_dir: "results/momentum_experiment"
```

```python
# 代码中使用配置
from src.models import MLPipeline

pipeline = MLPipeline.from_config("config.yaml")
results = pipeline.run()

# 保存结果
results.save("results/momentum_experiment")
```

### 5.3 自定义插件

```python
# 1. 创建自定义处理器
@PluginRegistry.register_processor("my_transformer")
class MyTransformer(BaseProcessor):
    """自定义转换器示例"""
    stage = PipelineStage.ALL

    def __init__(self, columns=None, multiplier=2.0):
        super().__init__(columns)
        self.multiplier = multiplier

    def fit(self, data: pl.DataFrame) -> "MyTransformer":
        # 学习参数(如有需要)
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            result = result.with_columns(
                (pl.col(col) * self.multiplier).alias(col)
            )
        return result

# 2. 创建自定义模型
@PluginRegistry.register_model("my_model")
class MyModel(BaseModel):
    """自定义模型示例"""

    def fit(self, X, y, X_val=None, y_val=None, **kwargs):
        # 实现训练逻辑
        self._model = ...
        return self

    def predict(self, X):
        # 实现预测逻辑
        return self._model.predict(X)
```

```yaml
# 3. 在配置中使用
# config.yaml
processing:
  - name: "my_transformer"
    params:
      multiplier: 3.0
    stage: "all"

model:
  name: "my_model"
  task_type: "regression"
```

---

## 6. 目录结构

```
src/
├── models/                      # 模型训练框架
│   ├── __init__.py              # 导出主要类
│   ├── core/                    # 核心抽象和基类
│   │   ├── __init__.py
│   │   ├── processor.py         # BaseProcessor, PipelineStage
│   │   ├── model.py             # BaseModel, TaskType
│   │   ├── splitter.py          # BaseSplitter
│   │   ├── metric.py            # BaseMetric
│   │   └── pipeline.py          # MLPipeline (编排器)
│   │
│   ├── registry.py              # PluginRegistry 插件注册中心
│   │
│   ├── data/                    # 数据相关
│   │   ├── __init__.py
│   │   ├── feature_store.py     # FeatureStore 特征存储
│   │   ├── label_generator.py   # LabelGenerator 标签生成
│   │   └── dataset.py           # Dataset 数据集包装
│   │
│   ├── processors/              # 内置处理器
│   │   ├── __init__.py          # 自动注册所有处理器
│   │   ├── scaler.py            # StandardScaler
│   │   ├── winsorizer.py        # Winsorizer
│   │   ├── neutralizer.py       # Neutralizer
│   │   ├── imputer.py           # FillNAProcessor
│   │   ├── selector.py          # FeatureSelector
│   │   └── custom.py            # 其他处理器
│   │
│   ├── models/                  # 内置模型
│   │   ├── __init__.py          # 自动注册所有模型
│   │   ├── lightgbm_model.py    # LightGBMModel
│   │   ├── catboost_model.py    # CatBoostModel
│   │   └── sklearn_model.py     # SklearnModel (LR, RF等)
│   │
│   ├── splitters/               # 划分策略
│   │   ├── __init__.py
│   │   ├── time_series.py       # TimeSeriesSplit
│   │   ├── walk_forward.py      # WalkForwardSplit
│   │   └── purged.py            # PurgedKFold
│   │
│   ├── metrics/                 # 评估指标
│   │   ├── __init__.py
│   │   ├── ic.py                # IC, RankIC
│   │   ├── returns.py           # 收益指标
│   │   └── classification.py    # 分类指标
│   │
│   ├── evaluation/              # 评估和报告
│   │   ├── __init__.py
│   │   ├── evaluator.py         # ModelEvaluator
│   │   ├── report.py            # ReportGenerator
│   │   └── visualizer.py        # ResultVisualizer
│   │
│   └── config/                  # 配置解析
│       ├── __init__.py
│       └── parser.py            # ConfigParser
│
├── factors/                     # 已有因子框架
│   └── ...

tests/
├── models/                      # 模型框架测试
│   ├── __init__.py
│   ├── test_processors.py       # 处理器测试
│   ├── test_models.py           # 模型测试
│   ├── test_pipeline.py         # 流水线集成测试
│   └── test_registry.py         # 注册器测试
│
└── factors/                     # 已有因子测试
    └── ...

configs/                         # 配置文件目录
├── momentum_regression.yaml
├── value_classification.yaml
└── ranking_lambdamart.yaml

experiments/                     # 实验结果目录
└── {experiment_name}/
    ├── config.yaml              # 实验配置
    ├── model.pkl                # 保存的模型
    ├── processors.pkl           # 保存的处理器状态
    ├── predictions.parquet      # 预测结果
    ├── metrics.json             # 评估指标
    ├── feature_importance.csv   # 特征重要性
    └── report.html              # 可视化报告
```

---

## 7. 开发计划

### Phase 1: 核心基础设施 (Week 1-2)
- [ ] 设计并实现 `BaseProcessor`, `BaseModel`, `BaseSplitter` 抽象类
- [ ] 实现 `PluginRegistry` 注册中心
- [ ] 实现 `PipelineStage` 阶段管理
- [ ] 编写基础单元测试

### Phase 2: 数据层 (Week 2-3)
- [ ] 实现 `FeatureStore` 特征存储
- [ ] 实现 `LabelGenerator` 标签生成器
- [ ] 实现 `Dataset` 数据集包装
- [ ] 集成现有因子框架输出

### Phase 3: 处理器 (Week 3-4)
- [ ] 实现 `StandardScaler` 标准化处理器
- [ ] 实现 `Winsorizer` 缩尾处理器
- [ ] 实现 `Neutralizer` 中性化处理器
- [ ] 实现 `FillNAProcessor` 缺失值处理器
- [ ] 实现 `DropNAProcessor` 缺失值删除处理器
- [ ] 实现 `FeatureSelector` 特征选择器
- [ ] 实现 `ProcessingPipeline` 流水线

### Phase 4: 模型层 (Week 4-5)
- [ ] 实现 `LightGBMModel` LightGBM包装
- [ ] 实现 `CatBoostModel` CatBoost包装
- [ ] 实现 `SklearnModel` sklearn模型支持
- [ ] 支持 classification/regression/ranking 三种任务

### Phase 5: 划分策略 (Week 5)
- [ ] 实现 `TimeSeriesSplit` 时间序列划分
- [ ] 实现 `WalkForwardSplit` 滚动前向验证
- [ ] 实现 `PurgedKFold` 清除重叠样本

### Phase 6: 评估层 (Week 5-6)
- [ ] 实现 IC/RankIC 指标
- [ ] 实现收益分析指标
- [ ] 实现分类指标
- [ ] 实现 `ModelEvaluator` 评估器
- [ ] 实现 `ReportGenerator` 报告生成

### Phase 7: 配置和编排 (Week 6)
- [ ] 实现配置解析器
- [ ] 实现 `MLPipeline` 编排器
- [ ] 支持配置驱动执行

### Phase 8: 集成测试和文档 (Week 7)
- [ ] 编写完整集成测试
- [ ] 编写使用文档
- [ ] 编写示例代码
- [ ] 性能基准测试

---

## 8. 关键设计决策

| 决策点 | 选择 | 理由 |
|--------|------|------|
| **数据处理阶段标记** | `PipelineStage` 枚举 | 显式、类型安全、易于扩展 |
| **插件注册方式** | 装饰器模式 | Pythonic、简洁、自动发现 |
| **数据格式** | Polars DataFrame | 与因子框架一致、高性能 |
| **模型接口** | `fit/predict` 统一接口 | 行业标准、易于替换模型 |
| **配置格式** | YAML | 人类可读、支持复杂结构 |
| **处理器状态保存** | pickle | 简单、Python原生、支持大部分对象 |
| **特征存储** | 从因子框架直接读取 | 避免数据冗余、保持一致性 |

---

## 9. 防数据泄露检查清单

- [x] 处理器明确标记适用阶段 (`stage` 属性)
- [x] 处理器只在训练数据上 `fit`,学习到的参数随流水线保存
- [x] `TEST` 阶段使用训练阶段学习到的参数
- [x] 划分策略支持时间序列感知 (`TimeSeriesSplit`, `WalkForwardSplit`)
- [x] 划分时支持 `gap` 参数防止相邻样本泄露
- [x] 特征存储从已计算的因子加载(不访问未来数据)
- [x] 标签生成使用预定义的前瞻期(明确的future data)

---

*文档版本: v1.0*
*最后更新: 2026-02-23*
*设计状态: 草案 - 待评审*
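
---

## 附录A:阶段感知语义示例

第 3.2.2 节的核心语义——"TRAIN 阶段 fit + transform,TEST 阶段仅用已学参数 transform"——可以用下面这个不依赖 polars 的极简草图来演示。其中 `ToyScaler`、`run_stage` 均为示意名称,并非框架实际 API;仅用于说明阶段过滤如何防止测试数据参与参数学习。

```python
from enum import Enum, auto

class PipelineStage(Enum):
    ALL = auto()
    TRAIN = auto()
    TEST = auto()

class ToyScaler:
    """极简标准化器:fit 学习均值/标准差,transform 复用这些参数"""
    stage = PipelineStage.ALL

    def fit(self, xs):
        n = len(xs)
        self.mean = sum(xs) / n
        self.std = (sum((x - self.mean) ** 2 for x in xs) / n) ** 0.5
        return self

    def transform(self, xs):
        return [(x - self.mean) / self.std for x in xs]

def run_stage(processors, xs, stage):
    """按阶段过滤处理器:TRAIN 时 fit+transform,TEST 时只 transform"""
    for p in processors:
        if p.stage in (PipelineStage.ALL, stage):
            if stage is PipelineStage.TRAIN:
                p.fit(xs)  # 只有训练阶段才学习参数
            xs = p.transform(xs)
    return xs

pipeline = [ToyScaler()]
train_out = run_stage(pipeline, [1.0, 2.0, 3.0], PipelineStage.TRAIN)
# TEST 阶段复用训练期学到的 mean=2.0、std≈0.816,不在测试数据上重新fit
test_out = run_stage(pipeline, [2.0], PipelineStage.TEST)
```

若测试阶段误调用 `fit`,均值/标准差会被测试分布污染,这正是 `stage` 标记要防住的泄露路径。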
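
## 附录B:注册器模式示例

第 4.1 节的装饰器注册机制,其最小可运行形态大致如下。`MiniRegistry`、`Doubler` 为示意名称,非框架实际实现;它只保留"装饰器写入类表 + 按名称查找 + 未注册时报错并列出可用项"这三个要点。

```python
from typing import Dict, Optional, Type

class MiniRegistry:
    """极简插件注册中心草图(示意,非框架实际实现)"""
    _processors: Dict[str, type] = {}

    @classmethod
    def register(cls, name: Optional[str] = None):
        """装饰器:把类按名称(缺省用类名)登记到注册表"""
        def decorator(klass: type) -> type:
            cls._processors[name or klass.__name__] = klass
            return klass  # 原样返回,不改变被装饰的类
        return decorator

    @classmethod
    def get(cls, name: str) -> type:
        """按名称查找;未注册时报错并提示可用插件"""
        if name not in cls._processors:
            raise KeyError(f"'{name}' not found. Available: {list(cls._processors)}")
        return cls._processors[name]

@MiniRegistry.register("double")
class Doubler:
    def transform(self, xs):
        return [x * 2 for x in xs]

# 与正文 5.1 节相同的用法:get(...)() 先取类再实例化
out = MiniRegistry.get("double")().transform([1, 2, 3])
```

返回类而非实例,是为了让调用方(或配置解析器)自行决定构造参数,这与正文中 `PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99)` 的用法一致。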
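
## 附录C:滚动前向划分的索引推演

第 4.2 节 `WalkForwardSplit` 的窗口滚动逻辑可以抽象成纯索引运算来核对:固定长度的训练窗口向前滚动,训练区间与测试区间之间始终隔 `gap` 个日期。下面的 `walk_forward` 是一个脱离 DataFrame 的示意函数(非框架 API),返回左闭右开的索引区间。

```python
def walk_forward(n_dates, train_window, test_window, gap):
    """滚动前向划分草图:返回 [(train区间, test区间), ...],区间为左闭右开"""
    splits = []
    start = train_window  # 第一个训练窗口的右端
    while start + gap + test_window <= n_dates:
        splits.append((
            (start - train_window, start),            # 训练区间
            (start + gap, start + gap + test_window)  # 测试区间,与训练区间隔 gap
        ))
        start += test_window  # 窗口整体前移一个测试窗口
    return splits

# 100 个交易日、训练窗口 50、测试窗口 20、gap 5:
# 第1折 train [0,50) / test [55,75),第2折 train [20,70) / test [75,95)
splits = walk_forward(100, 50, 20, 5)
```

每折的测试起点减去训练终点恒等于 `gap`,这就是检查清单中"划分时支持 gap 参数防止相邻样本泄露"的量化含义。
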