
ProStock Model Training Framework Design Document

1. Design Goals and Principles

1.1 Core Goals

  • Componentized: each stage (data loading, processing, training, evaluation) is an independent component
  • Loosely coupled: components interact only through standard interfaces, never through concrete implementations
  • Plugin-based: new functionality is added by registering plugins, with no changes to the core code
  • Stage-aware: data processing distinguishes the training and test stages to prevent data leakage
  • Multi-model: a unified interface supports LightGBM, CatBoost, and other models
  • Multi-task: supports classification, regression, and ranking tasks

1.2 Design Principles

| Principle | Description |
| --- | --- |
| Single responsibility | Each component does one thing and does it well |
| Open/closed | Open for extension (plugins), closed for modification (the core) |
| Dependency inversion | Depend on abstract interfaces, not concrete implementations |
| Explicit over implicit | Stage markers and processing logic must be declared explicitly |
| Configuration-driven | Workflows are defined through config files or code-level configuration, minimizing hard-coding |

2. Overall Architecture

2.1 Architecture Overview

┌─────────────────────────────────────────────────────────────────────────┐
│                         ML Pipeline Orchestrator                         │
│              (pipeline orchestrator, config-driven execution)            │
└─────────────────────────────────────────────────────────────────────────┘
                                    │
        ┌───────────────────────────┼───────────────────────────┐
        ▼                           ▼                           ▼
┌───────────────┐          ┌───────────────┐          ┌───────────────┐
│  Data Source  │          │  Data Source  │          │  Data Source  │
│ (factor data) │          │ (market data) │          │ (label data)  │
└───────┬───────┘          └───────┬───────┘          └───────┬───────┘
        │                          │                          │
        └──────────────────────────┼──────────────────────────┘
                                   ▼
┌─────────────────────────────────────────────────────────────────────────┐
│                     Feature Store (feature storage layer)                │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  │
│  │ FactorLoader │  │  LabelLoader │  │  DataMerger  │  │  CacheMgr    │  │
│  └──────────────┘  └──────────────┘  └──────────────┘  └──────────────┘  │
└─────────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────────┐
│                   Processing Pipeline (processor chain)                  │
│                                                                         │
│   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐    ┌──────────┐ │
│   │  Processor  │ -> │  Processor  │ -> │  Processor  │ -> │  ...     │ │
│   │ (stage:ALL) │    │(stage:TRAIN)│    │ (stage:TEST)│    │          │ │
│   └─────────────┘    └─────────────┘    └─────────────┘    └──────────┘ │
│                                                                         │
│   Processor types:                                                      │
│   - FeatureEncoder: feature encoding (category encoding, scaling)       │
│   - FeatureSelector: feature selection (correlation, importance)        │
│   - OutlierHandler: outlier handling                                    │
│   - MissingValueHandler: missing-value handling                         │
│   - CustomTransformer: custom transformers                              │
└─────────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────────┐
│                     Train/Test Split (data splitting)                    │
│                                                                         │
│   Supported splitting strategies:                                       │
│   - TimeSeriesSplit: time-series split (prevents future leakage)        │
│   - PurgedKFold: K-fold CV with overlapping samples purged              │
│   - EmbargoSplit: validation with an embargo delay                      │
│   - CustomSplit: user-defined splitting strategy                        │
└─────────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────────┐
│                    Model Training (model training layer)                 │
│                                                                         │
│   ┌─────────────────────────────────────────────────────────────────┐   │
│   │                     Model Registry                              │   │
│   │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐        │   │
│   │  │ LightGBM │  │CatBoost  │  │ XGBoost  │  │ Custom   │ ...    │   │
│   │  │  Model   │  │  Model   │  │  Model   │  │  Model   │        │   │
│   │  └──────────┘  └──────────┘  └──────────┘  └──────────┘        │   │
│   └─────────────────────────────────────────────────────────────────┘   │
│                                                                         │
│   Task types:                                                           │
│   - Classification: up/down (direction) prediction                      │
│   - Regression: return prediction                                       │
│   - Ranking: stock ranking / selection                                  │
└─────────────────────────────────────────────────────────────────────────┘
                                   │
                                   ▼
┌─────────────────────────────────────────────────────────────────────────┐
│                       Evaluation (evaluation layer)                      │
│                                                                         │
│   ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌────────────┐  │
│   │  Metric      │  │  Metric      │  │  Metric      │  │  Analyzer  │  │
│  (IC/IR)     │  │  (Sharpe)    │  │  (Accuracy)  │  │ (backtest) │  │
│   └──────────────┘  └──────────────┘  └──────────────┘  └────────────┘  │
│                                                                         │
│   ┌──────────────┐  ┌──────────────┐  ┌──────────────┐                  │
│   │  ResultStore │  │  Report      │  │  Visualizer  │                  │
│ (model store)│  │  (reports)   │  │   (plots)    │                  │
│   └──────────────┘  └──────────────┘  └──────────────┘                  │
└─────────────────────────────────────────────────────────────────────────┘

2.2 Data Flow

Factor DataFrame (Polars)
    │
    ▼
┌──────────────────────┐
│  Feature Store       │  1. Load and merge factor, label, and auxiliary data
│  - column selection  │  2. Filter by date / stock code
│  - data alignment    │  3. Cache to avoid repeated loading
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│  Processing Pipeline │  Run the processors in order;
│                      │  each one declares the stage it applies to (ALL/TRAIN/TEST)
│  for processor in pipeline:
│      if processor.stage in [current_stage, ALL]:
│          data = processor.transform(data)
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│  Data Splitter       │  Time-series-aware splitting strategies
│  - X_train, y_train  │  that prevent future leakage
│  - X_test, y_test    │
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│  Model Training      │  Unified interface, multiple model backends
│  - fit(X_train)      │  Task types: classification/regression/ranking
│  - predict(X_test)   │
└──────────┬───────────┘
           │
           ▼
┌──────────────────────┐
│  Evaluation          │  Multi-dimensional evaluation
│  - prediction metrics│  - IC/IR
│  - backtest metrics  │  - grouped returns
│  - visualization     │  - cumulative return curve
└──────────────────────┘

3. Core Component Design

3.1 Base Abstract Classes

3.1.1 PipelineStage (pipeline stage enum)

from enum import Enum, auto

class PipelineStage(Enum):
    """Pipeline stage marker"""
    ALL = auto()      # applies to every stage
    TRAIN = auto()    # training stage only
    TEST = auto()     # test stage only
    VALIDATION = auto()  # validation stage only

3.1.2 BaseProcessor (processor base class)

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import polars as pl

class BaseProcessor(ABC):
    """Base class for data processors.
    
    Every data processor must inherit from this class.
    Key feature: the stage attribute controls which pipeline stages the processor runs in.
    
    Example:
        >>> class StandardScaler(BaseProcessor):
        ...     stage = PipelineStage.ALL  # used in both training and testing
        ...     
        ...     def fit(self, data: pl.DataFrame) -> None:
        ...         self.mean = data[self.columns].mean()
        ...         self.std = data[self.columns].std()
        ...     
        ...     def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        ...         return (data - self.mean) / self.std
    """
    
    # Subclasses must declare the stage they apply to
    stage: PipelineStage = PipelineStage.ALL
    
    def __init__(self, columns: Optional[list] = None, **params):
        """Initialize the processor.
        
        Args:
            columns: columns to process (None means all numeric columns)
            **params: processor-specific parameters
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}
    
    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters from the training data.
        
        Called only once, during the training stage.
        The learned parameters are stored in self._fitted_params.
        
        Args:
            data: training data
            
        Returns:
            self (to allow chaining)
        """
        pass
    
    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform the data.
        
        Called in both the training and the test stage,
        using the parameters learned in fit().
        
        Args:
            data: input data
            
        Returns:
            the transformed data
        """
        pass
    
    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience method: fit, then transform."""
        return self.fit(data).transform(data)
    
    def get_fitted_params(self) -> Dict[str, Any]:
        """Return the learned parameters (for saving/loading)."""
        return self._fitted_params.copy()
    
    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Set the learned parameters (for restoring from a checkpoint)."""
        self._fitted_params = params.copy()
        self._is_fitted = True
        return self

3.1.3 BaseModel (model base class)

from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional
import polars as pl
import numpy as np

TaskType = Literal["classification", "regression", "ranking"]

class BaseModel(ABC):
    """Base class for machine learning models.
    
    A unified interface for multiple model backends (LightGBM, CatBoost, XGBoost, ...)
    and multiple task types (classification, regression, ranking).
    
    Example:
        >>> model = LightGBMModel(
        ...     task_type="classification",
        ...     params={"n_estimators": 100}
        ... )
        >>> model.fit(X_train, y_train)
        >>> predictions = model.predict(X_test)
    """
    
    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None
    ):
        """Initialize the model.
        
        Args:
            task_type: task type - "classification", "regression", or "ranking"
            params: model-specific parameters
            name: model name (used in logs and reports)
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False
    
    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "BaseModel":
        """Train the model.
        
        Args:
            X: feature data
            y: target variable
            X_val: validation features (optional)
            y_val: validation targets (optional)
            **fit_params: extra fit parameters
            
        Returns:
            self (to allow chaining)
        """
        pass
    
    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict.
        
        Args:
            X: feature data
            
        Returns:
            array of predictions
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """
        pass
    
    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification tasks only).
        
        Args:
            X: feature data
            
        Returns:
            array of class probabilities [n_samples, n_classes]
        """
        raise NotImplementedError("predict_proba only available for classification tasks")
    
    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances, if the model supports them.
        
        Returns:
            DataFrame[feature, importance], or None
        """
        return None
    
    def save(self, path: str) -> None:
        """Save the model to a file."""
        import pickle
        with open(path, 'wb') as f:
            pickle.dump(self, f)
    
    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file."""
        import pickle
        with open(path, 'rb') as f:
            return pickle.load(f)

3.1.4 BaseSplitter (data splitting base class)

from abc import ABC, abstractmethod
from typing import Iterator, Tuple, List
import polars as pl

class BaseSplitter(ABC):
    """Base class for data splitting strategies.
    
    Splitting strategies tailored to time-series data, designed to prevent future leakage.
    
    Example:
        >>> splitter = TimeSeriesSplit(n_splits=5, gap=5)
        >>> for train_idx, test_idx in splitter.split(data):
        ...     X_train, X_test = X[train_idx], X[test_idx]
    """
    
    @abstractmethod
    def split(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test indices.
        
        Args:
            data: the full dataset
            date_col: name of the date column
            
        Yields:
            (train_indices, test_indices) tuples
        """
        pass
    
    @abstractmethod
    def get_split_dates(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split.
        
        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass

3.2 Core Components

3.2.1 FeatureStore (feature store)

from typing import Dict, List, Optional, Tuple
import polars as pl
from pathlib import Path

class FeatureStore:
    """Feature store manager.
    
    Loads, merges, and caches factor data.
    Supports loading from multiple sources (factors, labels, market data) and merging them.
    """
    
    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}
    
    def load_factors(
        self,
        factor_names: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        stock_codes: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """Load factor data.
        
        Args:
            factor_names: list of factor names
            start_date: start date, YYYYMMDD
            end_date: end date, YYYYMMDD
            stock_codes: list of stock codes (optional)
            
        Returns:
            DataFrame[trade_date, ts_code, factor1, factor2, ...]
        """
        pass
    
    def load_labels(
        self,
        label_name: str,
        forward_period: int = 5,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> pl.DataFrame:
        """Load label data (forward returns).
        
        Args:
            label_name: label name (e.g. "return", "rank")
            forward_period: forward horizon (e.g. the return 5 days ahead)
            start_date: start date
            end_date: end date
            
        Returns:
            DataFrame[trade_date, ts_code, label]
        """
        pass
    
    def build_dataset(
        self,
        factor_names: List[str],
        label_config: Dict,
        date_range: Tuple[str, str],
        stock_codes: Optional[List[str]] = None,
        additional_cols: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """Build the full dataset.
        
        Merges factors, labels, and auxiliary columns, and aligns them.
        
        Args:
            factor_names: list of factors
            label_config: label config {"name": str, "forward_period": int}
            date_range: (start_date, end_date)
            stock_codes: restrict to these stocks
            additional_cols: extra columns (e.g. industry, market_cap)
            
        Returns:
            DataFrame[trade_date, ts_code, factor_cols..., label]
        """
        pass
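
The methods above are intentionally left as interface stubs. As a rough sketch of the intended behaviour, the following shows how load_labels and build_dataset might look on top of a single wide Parquet file (a hypothetical data/factors.parquet with trade_date, ts_code, close, and one column per factor); the file layout and the helper names are assumptions for illustration, not part of the design.

from typing import List

import polars as pl

# Sketch only: assumes data_dir contains one wide file factors.parquet
# with columns [trade_date, ts_code, close, <factor columns...>].
def _load_labels_sketch(store: FeatureStore, forward_period: int = 5) -> pl.DataFrame:
    prices = pl.read_parquet(store.data_dir / "factors.parquet").select(
        ["trade_date", "ts_code", "close"]
    )
    return (
        prices.sort(["ts_code", "trade_date"])
        .with_columns(
            # forward return over `forward_period` trading days, computed per stock
            (pl.col("close").shift(-forward_period).over("ts_code") / pl.col("close") - 1)
            .alias("label")
        )
        .drop("close")
    )


def _build_dataset_sketch(
    store: FeatureStore, factor_names: List[str], forward_period: int = 5
) -> pl.DataFrame:
    factors = pl.read_parquet(store.data_dir / "factors.parquet").select(
        ["trade_date", "ts_code", *factor_names]
    )
    labels = _load_labels_sketch(store, forward_period)
    # an inner join keeps only (date, stock) pairs that have both factors and a label
    return factors.join(labels, on=["trade_date", "ts_code"], how="inner")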

3.2.2 ProcessingPipeline (processing pipeline)

from typing import Dict, List
import polars as pl

class ProcessingPipeline:
    """Data processing pipeline.
    
    Runs a list of processors in order, honouring their stage markers.
    Key feature: the test stage reuses the parameters learned during training.
    """
    
    def __init__(self, processors: List[BaseProcessor]):
        """Initialize the pipeline.
        
        Args:
            processors: list of processors, in execution order
        """
        self.processors = processors
        self._fitted_processors: Dict[int, BaseProcessor] = {}
    
    def fit_transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit all applicable processors on the training data and transform it.
        
        Args:
            data: training data
            stage: current stage marker
            
        Returns:
            the processed data
        """
        result = data
        for i, processor in enumerate(self.processors):
            # does this processor apply to the current stage?
            if processor.stage in [PipelineStage.ALL, stage]:
                # fit, then transform
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN:
                # TEST-only processors are still fitted here so that their
                # parameters are ready when the test stage runs
                if processor.stage == PipelineStage.TEST:
                    processor.fit(result)
                    self._fitted_processors[i] = processor
        return result
    
    def transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply the already-fitted processors to test data.
        
        Uses the parameters learned during training, preventing data leakage.
        
        Args:
            data: test data
            stage: current stage marker
            
        Returns:
            the processed data
        """
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                if i in self._fitted_processors:
                    # reuse the fitted processor
                    result = self._fitted_processors[i].transform(result)
                else:
                    # processor was never fitted (e.g. a stateless ALL-stage
                    # processor that training did not reach)
                    result = processor.transform(result)
        return result
    
    def save_processors(self, path: str) -> None:
        """Persist the state of all fitted processors."""
        import pickle
        with open(path, 'wb') as f:
            pickle.dump(self._fitted_processors, f)
    
    def load_processors(self, path: str) -> None:
        """Load previously saved processor state."""
        import pickle
        with open(path, 'rb') as f:
            self._fitted_processors = pickle.load(f)

4. Plugin System

4.1 Registry Pattern

from typing import Dict, List, Optional, Type

class PluginRegistry:
    """Plugin registry.
    
    Registers processors, models, splitters, and metrics via decorators.
    This is what makes the architecture truly plugin-based: new functionality
    only needs to be registered to become available.
    """
    
    _processors: Dict[str, Type[BaseProcessor]] = {}
    _models: Dict[str, Type[BaseModel]] = {}
    _splitters: Dict[str, Type[BaseSplitter]] = {}
    _metrics: Dict[str, Type["BaseMetric"]] = {}
    
    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """Decorator that registers a processor.
        
        Example:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass
            
            >>> # usage
            >>> scaler = PluginRegistry.get_processor("standard_scaler")()
        """
        def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
            key = name or processor_class.__name__
            cls._processors[key] = processor_class
            processor_class._registry_name = key
            return processor_class
        return decorator
    
    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """Decorator that registers a model."""
        def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
            key = name or model_class.__name__
            cls._models[key] = model_class
            model_class._registry_name = key
            return model_class
        return decorator
    
    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """Decorator that registers a splitting strategy."""
        def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
            key = name or splitter_class.__name__
            cls._splitters[key] = splitter_class
            return splitter_class
        return decorator
    
    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Look up a processor class by name."""
        if name not in cls._processors:
            raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}")
        return cls._processors[name]
    
    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Look up a model class by name."""
        if name not in cls._models:
            raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}")
        return cls._models[name]
    
    @classmethod
    def get_splitter(cls, name: str) -> Type[BaseSplitter]:
        """Look up a splitter class by name."""
        if name not in cls._splitters:
            raise KeyError(f"Splitter '{name}' not found. Available: {list(cls._splitters.keys())}")
        return cls._splitters[name]
    
    @classmethod
    def list_processors(cls) -> List[str]:
        """List all registered processors."""
        return list(cls._processors.keys())
    
    @classmethod
    def list_models(cls) -> List[str]:
        """List all registered models."""
        return list(cls._models.keys())

4.2 Built-in Plugins

from typing import Dict, Optional

import numpy as np
import polars as pl

# numeric dtypes handled by the built-in processors
FLOAT_TYPES = (pl.Float32, pl.Float64)

# ========== Built-in processors ==========

@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Standard scaler: z-score normalization."""
    stage = PipelineStage.ALL
    
    def fit(self, data: pl.DataFrame) -> "StandardScaler":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "mean": {c: data[c].mean() for c in cols},
            "std": {c: data[c].std() for c in cols},
            "columns": cols
        }
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            mean = self._fitted_params["mean"][col]
            std = self._fitted_params["std"][col]
            if std > 0:
                result = result.with_columns(
                    ((pl.col(col) - mean) / std).alias(col)
                )
        return result


@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorizer: clip extreme values to limit their influence."""
    stage = PipelineStage.TRAIN  # quantiles are computed during the training stage only
    
    def __init__(self, columns=None, lower=0.01, upper=0.99):
        super().__init__(columns)
        self.lower = lower
        self.upper = upper
    
    def fit(self, data: pl.DataFrame) -> "Winsorizer":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "lower": {c: data[c].quantile(self.lower) for c in cols},
            "upper": {c: data[c].quantile(self.upper) for c in cols},
            "columns": cols
        }
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            lower = self._fitted_params["lower"][col]
            upper = self._fitted_params["upper"][col]
            result = result.with_columns(
                pl.col(col).clip(lower, upper).alias(col)
            )
        return result


@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
    """Industry / market-cap neutralization processor."""
    stage = PipelineStage.ALL
    
    def __init__(self, columns=None, group_col="industry", exclude_cols=None):
        super().__init__(columns)
        self.group_col = group_col
        self.exclude_cols = exclude_cols or []
    
    def fit(self, data: pl.DataFrame) -> "Neutralizer":
        # neutralization is applied per cross-section, so there is nothing to fit globally
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        # group by trade date and neutralize each cross-section
        result = data
        for col in self.columns or []:
            if col in self.exclude_cols:
                continue
            # demean within each (trade_date, group) bucket
            result = result.with_columns(
                (pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col]))
                .alias(col)
            )
        return result


@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
    """Drop rows that contain missing values."""
    stage = PipelineStage.ALL
    
    def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        cols = self.columns or data.columns
        return data.drop_nulls(subset=cols)


@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
    """Fill missing values."""
    stage = PipelineStage.TRAIN
    
    def __init__(self, columns=None, method="median"):
        super().__init__(columns)
        self.method = method
    
    def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        fill_values = {}
        for col in cols:
            if self.method == "median":
                fill_values[col] = data[col].median()
            elif self.method == "mean":
                fill_values[col] = data[col].mean()
            elif self.method == "zero":
                fill_values[col] = 0
        self._fitted_params = {"fill_values": fill_values, "columns": cols}
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col, val in self._fitted_params["fill_values"].items():
            result = result.with_columns(pl.col(col).fill_null(val).alias(col))
        return result


@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
    """Rank transformer: convert values to cross-sectional ranks."""
    stage = PipelineStage.ALL
    
    def fit(self, data: pl.DataFrame) -> "RankTransformer":
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            # rank within each trading day
            result = result.with_columns(
                pl.col(col).rank().over("trade_date").alias(col)
            )
        return result


# ========== Built-in models ==========

@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM model wrapper."""
    
    def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None
    
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "LightGBMModel":
        import lightgbm as lgb
        
        # convert to numpy
        X_arr = X.to_numpy()
        y_arr = y.to_numpy()
        
        # build the lightgbm datasets
        train_data = lgb.Dataset(X_arr, label=y_arr)
        valid_sets = [train_data]
        
        if X_val is not None and y_val is not None:
            valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
            valid_sets.append(valid_data)
        
        # default parameters (overridden by self.params)
        default_params = {
            "objective": self._get_objective(),
            "metric": self._get_metric(),
            "boosting_type": "gbdt",
            "num_leaves": 31,
            "learning_rate": 0.05,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 5,
            "verbose": -1
        }
        default_params.update(self.params)
        
        # train
        self._model = lgb.train(
            default_params,
            train_data,
            num_boost_round=fit_params.get("num_boost_round", 100),
            valid_sets=valid_sets,
            callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)] if len(valid_sets) > 1 else []
        )
        self._is_fitted = True
        return self
    
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_numpy())
    
    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        probs = self.predict(X)
        if len(probs.shape) == 1:
            return np.vstack([1 - probs, probs]).T
        return probs
    
    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        importance = self._model.feature_importance(importance_type="gain")
        return pl.DataFrame({
            "feature": self._model.feature_name(),
            "importance": importance
        }).sort("importance", descending=True)
    
    def _get_objective(self) -> str:
        if self.task_type == "classification":
            return "binary"
        elif self.task_type == "regression":
            return "regression"
        elif self.task_type == "ranking":
            return "lambdarank"
        return "regression"
    
    def _get_metric(self) -> str:
        if self.task_type == "classification":
            return "auc"
        elif self.task_type == "regression":
            return "rmse"
        elif self.task_type == "ranking":
            return "ndcg"
        return "rmse"


@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
    """CatBoost model wrapper."""
    
    def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None
    
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "CatBoostModel":
        from catboost import CatBoostClassifier, CatBoostRegressor
        
        # pick the estimator class for the task type
        if self.task_type == "classification":
            model_class = CatBoostClassifier
            default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
        elif self.task_type == "regression":
            model_class = CatBoostRegressor
            default_params = {"loss_function": "RMSE"}
        else:  # ranking
            model_class = CatBoostRegressor
            default_params = {"loss_function": "QueryRMSE"}
        
        default_params.update(self.params)
        default_params["verbose"] = False
        
        self._model = model_class(**default_params)
        
        # optional validation set
        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())
        
        # train
        self._model.fit(
            X.to_pandas(),
            y.to_pandas(),
            eval_set=eval_set,
            early_stopping_rounds=10,
            verbose=False
        )
        self._is_fitted = True
        return self
    
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_pandas())
    
    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        return self._model.predict_proba(X.to_pandas())
    
    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        return pl.DataFrame({
            "feature": self._model.feature_names_,
            "importance": self._model.feature_importances_
        }).sort("importance", descending=True)


# ========== Built-in splitting strategies ==========

@PluginRegistry.register_splitter("time_series")
class TimeSeriesSplit(BaseSplitter):
    """Time-series split: training data always precedes test data."""
    
    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size
    
    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        
        # test-set size (in trading days) for each split
        test_size = (n_dates - self.min_train_size) // self.n_splits
        
        for i in range(self.n_splits):
            # end of the training window
            train_end_idx = self.min_train_size + i * test_size
            # the test window starts after a gap, to prevent leakage
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            
            if test_end_idx > n_dates:
                break
            
            train_dates = dates[:train_end_idx]
            test_dates = dates[test_start_idx:test_end_idx]
            
            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)
            
            train_idx = data.with_row_count("row_count").filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count("row_count").filter(test_mask)["row_count"].to_list()
            
            yield train_idx, test_idx
    
    def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits
        
        result = []
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            
            if test_end_idx > n_dates:
                break
            
            result.append((
                dates[0],
                dates[train_end_idx - 1],
                dates[test_start_idx],
                dates[test_end_idx - 1]
            ))
        return result


@PluginRegistry.register_splitter("walk_forward")
class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation: a fixed-size training window rolls forward through time."""
    
    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap
    
    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window
            
            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]
            
            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)
            
            train_idx = data.with_row_count("row_count").filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count("row_count").filter(test_mask)["row_count"].to_list()
            
            yield train_idx, test_idx
            start_idx += self.test_window

5. Usage Examples

5.1 Basic Usage

from src.models import (
    FeatureStore, ProcessingPipeline, PluginRegistry,
    PipelineStage, MLPipeline
)

# 1. Create the feature store
store = FeatureStore(data_dir="data")

# 2. Build the dataset
dataset = store.build_dataset(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 3. Create the processing pipeline
processors = [
    # drop missing values
    PluginRegistry.get_processor("dropna")(),
    
    # outlier handling (quantiles computed during training only)
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    
    # neutralization (by industry / market cap)
    PluginRegistry.get_processor("neutralizer")(group_col="industry"),
    
    # standardization (applied in both training and testing)
    PluginRegistry.get_processor("standard_scaler")(),
]
pipeline = ProcessingPipeline(processors)

# 4. Create the splitting strategy
splitter = PluginRegistry.get_splitter("time_series")(
    n_splits=5,
    gap=5,
    min_train_size=252
)

# 5. Create the model
model = PluginRegistry.get_model("lightgbm")(
    task_type="regression",
    params={"n_estimators": 200, "learning_rate": 0.03}
)

# 6. Run the full workflow
ml_pipeline = MLPipeline(
    feature_store=store,
    processing_pipeline=pipeline,
    splitter=splitter,
    model=model
)

results = ml_pipeline.run(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 7. Inspect the results
print(results.metrics)  # evaluation metrics per fold
print(results.feature_importance)  # feature importances
print(results.predictions)  # predictions

5.2 Configuration-Driven Usage (Recommended)

# config.yaml
experiment:
  name: "momentum_factor_regression"
  
data:
  factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"]
  label:
    name: "forward_return"
    forward_period: 5
  date_range: ["20200101", "20241231"]
  
processing:
  - name: "dropna"
    params: {}
    stage: "all"
    
  - name: "winsorizer"
    params:
      lower: 0.01
      upper: 0.99
    stage: "train"  # quantiles are computed during the training stage only
    
  - name: "neutralizer"
    params:
      group_col: "industry"
    stage: "all"
    
  - name: "standard_scaler"
    params: {}
    stage: "all"

splitting:
  strategy: "time_series"
  params:
    n_splits: 5
    gap: 5
    min_train_size: 252

model:
  name: "lightgbm"
  task_type: "regression"
  params:
    n_estimators: 200
    learning_rate: 0.03
    max_depth: 6

evaluation:
  metrics: ["ic", "rank_ic", "mse", "mae"]
  output_dir: "results/momentum_experiment"

# Using the config from code
from src.models import MLPipeline

pipeline = MLPipeline.from_config("config.yaml")
results = pipeline.run()

# Save the results
results.save("results/momentum_experiment")
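
MLPipeline.from_config is used above but not specified in this document. A minimal sketch of the config-to-component mapping it implies is shown below; PyYAML as the loader, the helper name, and the optional data_dir key are assumptions, and the per-processor stage field from the YAML is ignored here because the built-in processors already declare their stage as a class attribute.

import yaml  # PyYAML, assumed as the config loader


def build_pipeline_from_config(path: str) -> MLPipeline:
    """Sketch of the mapping behind MLPipeline.from_config."""
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    
    # processors: look each entry up in the registry by name
    processors = [
        PluginRegistry.get_processor(p["name"])(**p.get("params", {}))
        for p in cfg["processing"]
    ]
    
    splitter = PluginRegistry.get_splitter(cfg["splitting"]["strategy"])(
        **cfg["splitting"].get("params", {})
    )
    model = PluginRegistry.get_model(cfg["model"]["name"])(
        task_type=cfg["model"]["task_type"],
        params=cfg["model"].get("params", {}),
    )
    
    return MLPipeline(
        feature_store=FeatureStore(data_dir=cfg.get("data_dir", "data")),
        processing_pipeline=ProcessingPipeline(processors),
        splitter=splitter,
        model=model,
    )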

5.3 Custom Plugins

# 1. Create a custom processor
@PluginRegistry.register_processor("my_transformer")
class MyTransformer(BaseProcessor):
    """Example of a custom transformer."""
    stage = PipelineStage.ALL
    
    def __init__(self, columns=None, multiplier=2.0):
        super().__init__(columns)
        self.multiplier = multiplier
    
    def fit(self, data: pl.DataFrame) -> "MyTransformer":
        # learn parameters here if needed
        return self
    
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            result = result.with_columns(
                (pl.col(col) * self.multiplier).alias(col)
            )
        return result


# 2. Create a custom model
@PluginRegistry.register_model("my_model")
class MyModel(BaseModel):
    """Example of a custom model."""
    
    def fit(self, X, y, X_val=None, y_val=None, **kwargs):
        # training logic goes here
        self._model = ...
        return self
    
    def predict(self, X):
        # prediction logic goes here
        return self._model.predict(X)


# 3. Use them from the config
# config.yaml
processing:
  - name: "my_transformer"
    params:
      multiplier: 3.0
    stage: "all"

model:
  name: "my_model"
  task_type: "regression"

6. Directory Structure

src/
├── models/                          # model training framework
│   ├── __init__.py                  # exports the main classes
│   ├── core/                        # core abstractions and base classes
│   │   ├── __init__.py
│   │   ├── processor.py             # BaseProcessor, PipelineStage
│   │   ├── model.py                 # BaseModel, TaskType
│   │   ├── splitter.py              # BaseSplitter
│   │   ├── metric.py                # BaseMetric
│   │   └── pipeline.py              # MLPipeline (orchestrator)
│   │
│   ├── registry.py                  # PluginRegistry plugin registry
│   │
│   ├── data/                        # data layer
│   │   ├── __init__.py
│   │   ├── feature_store.py         # FeatureStore feature store
│   │   ├── label_generator.py       # LabelGenerator label generation
│   │   └── dataset.py               # Dataset wrapper
│   │
│   ├── processors/                  # built-in processors
│   │   ├── __init__.py              # auto-registers all processors
│   │   ├── scaler.py                # StandardScaler
│   │   ├── winsorizer.py            # Winsorizer
│   │   ├── neutralizer.py           # Neutralizer
│   │   ├── imputer.py               # FillNAProcessor
│   │   ├── selector.py              # FeatureSelector
│   │   └── custom.py                # other processors
│   │
│   ├── models/                      # built-in models
│   │   ├── __init__.py              # auto-registers all models
│   │   ├── lightgbm_model.py        # LightGBMModel
│   │   ├── catboost_model.py        # CatBoostModel
│   │   └── sklearn_model.py         # SklearnModel (LR, RF, etc.)
│   │
│   ├── splitters/                   # splitting strategies
│   │   ├── __init__.py
│   │   ├── time_series.py           # TimeSeriesSplit
│   │   ├── walk_forward.py          # WalkForwardSplit
│   │   └── purged.py                # PurgedKFold
│   │
│   ├── metrics/                     # evaluation metrics
│   │   ├── __init__.py
│   │   ├── ic.py                    # IC, RankIC
│   │   ├── returns.py               # return metrics
│   │   └── classification.py        # classification metrics
│   │
│   ├── evaluation/                  # evaluation and reporting
│   │   ├── __init__.py
│   │   ├── evaluator.py             # ModelEvaluator
│   │   ├── report.py                # ReportGenerator
│   │   └── visualizer.py            # ResultVisualizer
│   │
│   └── config/                      # config parsing
│       ├── __init__.py
│       └── parser.py                # ConfigParser
│
├── factors/                         # existing factor framework
│   └── ...
│
tests/
├── models/                          # model framework tests
│   ├── __init__.py
│   ├── test_processors.py           # processor tests
│   ├── test_models.py               # model tests
│   ├── test_pipeline.py             # pipeline integration tests
│   └── test_registry.py             # registry tests
│
└── factors/                         # existing factor tests
    └── ...

configs/                             # config files
├── momentum_regression.yaml
├── value_classification.yaml
└── ranking_lambdamart.yaml

experiments/                         # experiment outputs
└── {experiment_name}/
    ├── config.yaml                  # experiment config
    ├── model.pkl                    # saved model
    ├── processors.pkl               # saved processor state
    ├── predictions.parquet          # predictions
    ├── metrics.json                 # evaluation metrics
    ├── feature_importance.csv       # feature importances
    └── report.html                  # visual report
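
As a sketch of how one run's artifacts could be written into this layout, assuming the results object exposes predictions and feature_importance as Polars DataFrames and metrics as a dict (these attribute names are assumptions):

import json
from pathlib import Path


def save_experiment(results, model, pipeline, out_dir: str) -> None:
    """Persist one experiment into the layout above (attribute names assumed)."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    
    model.save(str(out / "model.pkl"))                     # BaseModel.save
    pipeline.save_processors(str(out / "processors.pkl"))  # ProcessingPipeline.save_processors
    results.predictions.write_parquet(out / "predictions.parquet")
    results.feature_importance.write_csv(out / "feature_importance.csv")
    with open(out / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(results.metrics, f, ensure_ascii=False, indent=2)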

7. Development Plan

Phase 1: Core infrastructure (weeks 1-2)

  • Design and implement the BaseProcessor, BaseModel, and BaseSplitter abstract classes
  • Implement the PluginRegistry registration center
  • Implement PipelineStage stage management
  • Write basic unit tests

Phase 2: Data layer (weeks 2-3)

  • Implement the FeatureStore feature store
  • Implement the LabelGenerator label generator
  • Implement the Dataset wrapper
  • Integrate the output of the existing factor framework

Phase 3: Processors (weeks 3-4)

  • Implement StandardScaler (standardization)
  • Implement Winsorizer (winsorization)
  • Implement Neutralizer (neutralization)
  • Implement FillNAProcessor (missing-value imputation)
  • Implement DropNAProcessor (missing-value removal)
  • Implement FeatureSelector (feature selection)
  • Implement the ProcessingPipeline

Phase 4: Model layer (weeks 4-5)

  • Implement LightGBMModel (LightGBM wrapper)
  • Implement CatBoostModel (CatBoost wrapper)
  • Implement SklearnModel (scikit-learn support)
  • Support the classification/regression/ranking task types

Phase 5: Splitting strategies (week 5)

  • Implement TimeSeriesSplit (time-series split)
  • Implement WalkForwardSplit (walk-forward validation)
  • Implement PurgedKFold (purging of overlapping samples)

Phase 6: Evaluation layer (weeks 5-6)

  • Implement the IC/RankIC metrics
  • Implement return-analysis metrics
  • Implement classification metrics
  • Implement the ModelEvaluator
  • Implement the ReportGenerator

Phase 7: Configuration and orchestration (week 6)

  • Implement the config parser
  • Implement the MLPipeline orchestrator
  • Support configuration-driven execution

Phase 8: Integration tests and documentation (week 7)

  • Write full integration tests
  • Write user documentation
  • Write example code
  • Run performance benchmarks

8. Key Design Decisions

| Decision point | Choice | Rationale |
| --- | --- | --- |
| Stage marking for data processing | PipelineStage enum | Explicit, type-safe, easy to extend |
| Plugin registration | Decorator pattern | Pythonic, concise, supports automatic discovery |
| Data format | Polars DataFrame | Consistent with the factor framework, high performance |
| Model interface | Unified fit/predict | Industry standard, easy to swap models |
| Config format | YAML | Human-readable, supports nested structures |
| Processor state persistence | pickle | Simple, native to Python, handles most objects |
| Feature storage | Read directly from the factor framework | Avoids data duplication, keeps data consistent |

9. Data-Leakage Prevention Checklist

  • Every processor explicitly declares the stages it applies to (the stage attribute)
  • TRAIN-stage processors are fitted on training data only (see the sketch below)
  • The TEST stage reuses the parameters learned during training
  • Splitting strategies are time-series aware (TimeSeriesSplit, WalkForwardSplit)
  • Splits support a gap parameter to prevent leakage between adjacent samples
  • The feature store loads pre-computed factors only and never touches future data
  • Labels are generated with a predefined forward horizon, so any use of future data is explicit
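
A minimal sketch of the fit-on-train / reuse-on-test discipline described above, using the built-in StandardScaler and the ProcessingPipeline from section 3.2.2 (the toy data is illustrative only):

import polars as pl

# toy panel: two stocks, four days, one factor column
data = pl.DataFrame({
    "trade_date": ["20240101", "20240102", "20240103", "20240104"] * 2,
    "ts_code": ["000001.SZ"] * 4 + ["000002.SZ"] * 4,
    "momentum_20": [0.1, 0.2, 0.3, 0.4, -0.1, -0.2, -0.3, -0.4],
})

pipeline = ProcessingPipeline([StandardScaler(columns=["momentum_20"])])

train = data.filter(pl.col("trade_date") <= "20240102")
test = data.filter(pl.col("trade_date") > "20240102")

# mean/std are learned on the training slice only ...
train_scaled = pipeline.fit_transform(train, stage=PipelineStage.TRAIN)
# ... and reused unchanged on the test slice: no test-set statistics leak into training
test_scaled = pipeline.transform(test, stage=PipelineStage.TEST)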

Document version: v1.0 · Last updated: 2026-02-23 · Design status: draft, pending review