ProStock/docs/ml_framework_design.md
liaozhaorun 9f95be56a0 feat(models): implement the machine learning model training framework
- Add core abstractions: Processor, Model, Splitter, Metric base classes
- Implement stage awareness (TRAIN/TEST/ALL) to prevent data leakage
- Ship 8 built-in data processors and 3 time-series split strategies
- Support LightGBM and CatBoost models
- PluginRegistry decorator-based registration, plugin architecture
- 22 unit tests
2026-02-23 01:37:34 +08:00
# ProStock Model Training Framework Design
## 1. Goals and Principles
### 1.1 Core Goals
- **Componentized**: every stage (data loading, processing, training, evaluation) is an independent component
- **Loosely coupled**: components interact through standard interfaces, not concrete implementations
- **Plugin-based**: new functionality registers as a plugin; the core never needs modification
- **Stage-aware**: data processing distinguishes the training and test stages, preventing data leakage
- **Multi-model**: one interface supports LightGBM, CatBoost, and other backends
- **Multi-task**: classification, regression, and ranking tasks
### 1.2 Design Principles
| Principle | Description |
|-----------|-------------|
| **Single responsibility** | each component does one thing, and does it well |
| **Open/closed** | open for extension (plugins), closed for modification (core) |
| **Dependency inversion** | depend on abstract interfaces, not concrete implementations |
| **Explicit over implicit** | stage markers and processing logic must be declared explicitly |
| **Config-driven** | pipelines are defined via config files or code-level config, minimizing hard-coding |
---
## 2. Architecture
### 2.1 Overview
```
┌─────────────────────────────────────────────────────────────────────────┐
│                        ML Pipeline Orchestrator                         │
│                  (config-driven pipeline execution)                     │
└─────────────────────────────────────────────────────────────────────────┘
            ┌───────────────────────────┼───────────────────────────┐
            ▼                           ▼                           ▼
    ┌───────────────┐          ┌───────────────┐          ┌───────────────┐
    │  Data Source  │          │  Data Source  │          │  Data Source  │
    │ (factor data) │          │ (market data) │          │ (label data)  │
    └───────┬───────┘          └───────┬───────┘          └───────┬───────┘
            │                          │                          │
            └──────────────────────────┼──────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│                       Feature Store (feature layer)                     │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐ │
│  │ FactorLoader │  │ LabelLoader  │  │  DataMerger  │  │   CacheMgr   │ │
│  └──────────────┘  └──────────────┘  └──────────────┘  └──────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│                          Processing Pipeline                            │
│                                                                         │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐   ┌──────────┐   │
│  │  Processor  │ -> │  Processor  │ -> │  Processor  │ ->│   ...    │   │
│  │ (stage:ALL) │    │(stage:TRAIN)│    │ (stage:TEST)│   │          │   │
│  └─────────────┘    └─────────────┘    └─────────────┘   └──────────┘   │
│                                                                         │
│  Processor types:                                                       │
│  - FeatureEncoder: feature encoding (category encoding, scaling, ...)   │
│  - FeatureSelector: feature selection (correlation filter, importance)  │
│  - OutlierHandler: outlier handling                                     │
│  - MissingValueHandler: missing-value handling                          │
│  - CustomTransformer: user-defined transforms                           │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│                           Train/Test Split                              │
│                                                                         │
│  Supported split strategies:                                            │
│  - TimeSeriesSplit: time-ordered splits (no future leakage)             │
│  - PurgedKFold: K-fold CV with overlapping samples purged               │
│  - EmbargoSplit: validation with an embargo delay                       │
│  - CustomSplit: user-defined strategies                                 │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│                            Model Training                               │
│                                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                         Model Registry                          │    │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐         │    │
│  │  │ LightGBM │  │ CatBoost │  │ XGBoost  │  │  Custom  │  ...    │    │
│  │  │  Model   │  │  Model   │  │  Model   │  │  Model   │         │    │
│  │  └──────────┘  └──────────┘  └──────────┘  └──────────┘         │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                         │
│  Task types:                                                            │
│  - Classification: up/down prediction                                   │
│  - Regression: return prediction                                        │
│  - Ranking: stock ranking / selection                                   │
└─────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────┐
│                              Evaluation                                 │
│                                                                         │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌────────────┐   │
│  │    Metric    │  │    Metric    │  │    Metric    │  │  Analyzer  │   │
│  │   (IC/IR)    │  │   (Sharpe)   │  │  (Accuracy)  │  │ (backtest) │   │
│  └──────────────┘  └──────────────┘  └──────────────┘  └────────────┘   │
│                                                                         │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐                   │
│  │ ResultStore  │  │    Report    │  │  Visualizer  │                   │
│  │(model store) │  │  (reports)   │  │   (plots)    │                   │
│  └──────────────┘  └──────────────┘  └──────────────┘                   │
└─────────────────────────────────────────────────────────────────────────┘
```
### 2.2 Data Flow
```
factor DataFrame (Polars)
┌──────────────────────┐
│    Feature Store     │  1. load and merge factors, labels, auxiliary data
│  - column selection  │  2. filter by date / stock
│  - data alignment    │  3. cache to avoid repeated loads
└──────────┬───────────┘
┌──────────────────────┐
│ Processing Pipeline  │  run processors in order;
│                      │  each is tagged with a stage (ALL/TRAIN/TEST)
│  for processor in pipeline:
│      if processor.stage in [current_stage, ALL]:
│          data = processor.transform(data)
└──────────┬───────────┘
┌──────────────────────┐
│    Data Splitter     │  time-series-aware split strategies
│  - X_train, y_train  │  (no future leakage)
│  - X_test, y_test    │
└──────────┬───────────┘
┌──────────────────────┐
│    Model Training    │  unified interface, multiple model backends
│  - fit(X_train)      │  task types: classification/regression/ranking
│  - predict(X_test)   │
└──────────┬───────────┘
┌──────────────────────┐
│      Evaluation      │  multi-dimensional evaluation
│  - prediction metrics│  - IC/IR
│  - backtest metrics  │  - grouped returns
│  - visualization     │  - cumulative return curves
└──────────────────────┘
```
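The stage-gating loop in the diagram can be run end to end in a few lines. This is an illustrative, standalone sketch: dict-based toy processors stand in for the framework's `BaseProcessor` objects.

```python
from enum import Enum, auto

class PipelineStage(Enum):
    ALL = auto()
    TRAIN = auto()
    TEST = auto()

def run_stage(processors, data, current_stage):
    """Apply only the processors whose stage matches the current run."""
    for proc in processors:
        if proc["stage"] in (PipelineStage.ALL, current_stage):
            data = proc["transform"](data)
    return data

# Two toy processors: one active in every stage, one train-only.
processors = [
    {"stage": PipelineStage.ALL,   "transform": lambda d: d + ["scaled"]},
    {"stage": PipelineStage.TRAIN, "transform": lambda d: d + ["winsorized"]},
]

print(run_stage(processors, [], PipelineStage.TRAIN))  # ['scaled', 'winsorized']
print(run_stage(processors, [], PipelineStage.TEST))   # ['scaled']
```

The same membership test (`stage in [ALL, current_stage]`) is what `ProcessingPipeline` uses below.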
---
## 3. Core Components
### 3.1 Base Abstractions
#### 3.1.1 PipelineStage
```python
from enum import Enum, auto


class PipelineStage(Enum):
    """Pipeline stage marker."""
    ALL = auto()         # applies to every stage
    TRAIN = auto()       # training stage only
    TEST = auto()        # test stage only
    VALIDATION = auto()  # validation stage only
```
#### 3.1.2 BaseProcessor
```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import polars as pl


class BaseProcessor(ABC):
    """Base class for data processors.

    Every data processor must inherit from this class.
    Key feature: the `stage` attribute controls in which stages
    the processor is applied.

    Example:
        >>> class StandardScaler(BaseProcessor):
        ...     stage = PipelineStage.ALL  # used in both train and test
        ...
        ...     def fit(self, data: pl.DataFrame) -> "StandardScaler":
        ...         self._fitted_params = {
        ...             "mean": {c: data[c].mean() for c in self.columns},
        ...             "std": {c: data[c].std() for c in self.columns},
        ...         }
        ...         return self
        ...
        ...     def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        ...         return data.with_columns(
        ...             (pl.col(c) - self._fitted_params["mean"][c])
        ...             / self._fitted_params["std"][c]
        ...             for c in self.columns
        ...         )
    """

    # Subclasses must declare the stage they apply to
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[list] = None, **params):
        """Initialize the processor.

        Args:
            columns: columns to process; None means all numeric columns
            **params: processor-specific parameters
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters from the training data.

        Called exactly once, on training data only.
        Learned parameters are stored in `self._fitted_params`.

        Args:
            data: training data

        Returns:
            self (supports chaining)
        """
        pass

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform the data.

        Called during both training and testing, using the
        parameters learned in `fit()`.

        Args:
            data: input data

        Returns:
            transformed data
        """
        pass

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience method: fit, then transform."""
        return self.fit(data).transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """Return the learned parameters (for saving/loading)."""
        return self._fitted_params.copy()

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Set learned parameters (for restoring from a checkpoint)."""
        self._fitted_params = params.copy()
        self._is_fitted = True
        return self
```
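The fit/transform contract can be exercised with a minimal standalone scaler. Plain Python lists stand in for Polars frames here, and `ToyStandardScaler` is an illustrative name, not part of the framework; the point is that the test data is transformed with statistics learned on the training data only.

```python
import statistics

class ToyStandardScaler:
    """Learns mean/std on the training data only, then reuses them everywhere."""

    def __init__(self):
        self._fitted_params = {}

    def fit(self, values):
        self._fitted_params = {
            "mean": statistics.fmean(values),
            "std": statistics.pstdev(values),
        }
        return self  # chainable, as in BaseProcessor

    def transform(self, values):
        m, s = self._fitted_params["mean"], self._fitted_params["std"]
        return [(v - m) / s for v in values]

scaler = ToyStandardScaler().fit([1.0, 2.0, 3.0])  # statistics from train data
test_out = scaler.transform([4.0])                 # test reuses train mean/std
# mean=2.0, population std=sqrt(2/3), so (4-2)/std is about 2.449
```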
#### 3.1.3 BaseModel
```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional

import numpy as np
import polars as pl

TaskType = Literal["classification", "regression", "ranking"]


class BaseModel(ABC):
    """Base class for machine learning models.

    A unified interface for multiple model backends (LightGBM,
    CatBoost, XGBoost, ...) and task types (classification,
    regression, ranking).

    Example:
        >>> model = LightGBMModel(
        ...     task_type="classification",
        ...     params={"n_estimators": 100}
        ... )
        >>> model.fit(X_train, y_train)
        >>> predictions = model.predict(X_test)
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None
    ):
        """Initialize the model.

        Args:
            task_type: "classification", "regression", or "ranking"
            params: model-specific parameters
            name: model name (used in logs and reports)
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "BaseModel":
        """Train the model.

        Args:
            X: feature data
            y: target variable
            X_val: validation features (optional)
            y_val: validation targets (optional)
            **fit_params: extra fit parameters

        Returns:
            self (supports chaining)
        """
        pass

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict.

        Args:
            X: feature data

        Returns:
            prediction array
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """
        pass

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification only).

        Args:
            X: feature data

        Returns:
            probability array of shape [n_samples, n_classes]
        """
        raise NotImplementedError("predict_proba only available for classification tasks")

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances, if the model supports them.

        Returns:
            DataFrame[feature, importance] or None
        """
        return None

    def save(self, path: str) -> None:
        """Save the model to a file."""
        import pickle
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file."""
        import pickle
        with open(path, "rb") as f:
            return pickle.load(f)
```
#### 3.1.4 BaseSplitter
```python
from abc import ABC, abstractmethod
from typing import Iterator, List, Tuple

import polars as pl


class BaseSplitter(ABC):
    """Base class for data split strategies.

    Split strategies tailored to time-series data, designed to
    prevent future leakage.

    Example:
        >>> splitter = TimeSeriesSplit(n_splits=5, gap=5)
        >>> for train_idx, test_idx in splitter.split(data):
        ...     X_train, X_test = X[train_idx], X[test_idx]
    """

    @abstractmethod
    def split(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test indices.

        Args:
            data: full dataset
            date_col: name of the date column

        Yields:
            (train_indices, test_indices) tuples
        """
        pass

    @abstractmethod
    def get_split_dates(
        self,
        data: pl.DataFrame,
        date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split.

        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass
```
---
### 3.2 Core Components
#### 3.2.1 FeatureStore
```python
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import polars as pl


class FeatureStore:
    """Feature store manager.

    Loads, merges, and caches factor data, combining multiple
    data sources (factors, labels, market data).
    """

    def __init__(self, data_dir: str):
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}

    def load_factors(
        self,
        factor_names: List[str],
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        stock_codes: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """Load factor data.

        Args:
            factor_names: list of factor names
            start_date: start date, YYYYMMDD
            end_date: end date, YYYYMMDD
            stock_codes: list of stock codes (optional)

        Returns:
            DataFrame[trade_date, ts_code, factor1, factor2, ...]
        """
        pass

    def load_labels(
        self,
        label_name: str,
        forward_period: int = 5,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> pl.DataFrame:
        """Load label data (future returns).

        Args:
            label_name: label name (e.g. "return", "rank")
            forward_period: forward horizon (e.g. 5 for the 5-day-ahead return)
            start_date: start date
            end_date: end date

        Returns:
            DataFrame[trade_date, ts_code, label]
        """
        pass

    def build_dataset(
        self,
        factor_names: List[str],
        label_config: Dict,
        date_range: Tuple[str, str],
        stock_codes: Optional[List[str]] = None,
        additional_cols: Optional[List[str]] = None
    ) -> pl.DataFrame:
        """Build the full dataset.

        Merges and aligns factors, labels, and auxiliary columns.

        Args:
            factor_names: list of factors
            label_config: label config {"name": str, "forward_period": int}
            date_range: (start_date, end_date)
            stock_codes: restrict to these stocks
            additional_cols: extra columns (e.g. industry, market_cap)

        Returns:
            DataFrame[trade_date, ts_code, factor_cols..., label]
        """
        pass
```
#### 3.2.2 ProcessingPipeline
```python
from typing import Dict, List

import polars as pl


class ProcessingPipeline:
    """Data processing pipeline.

    Runs processors in order and honors their stage markers.
    Key property: the test stage reuses parameters learned during
    the training stage.
    """

    def __init__(self, processors: List[BaseProcessor]):
        """Initialize the pipeline.

        Args:
            processors: processors in execution order
        """
        self.processors = processors
        self._fitted_processors: Dict[int, BaseProcessor] = {}

    def fit_transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit all processors on training data and transform it.

        Args:
            data: training data
            stage: current stage marker

        Returns:
            processed data
        """
        result = data
        for i, processor in enumerate(self.processors):
            # Does this processor apply to the current stage?
            if processor.stage in [PipelineStage.ALL, stage]:
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN:
                # TEST-only processors still fit on training data here,
                # so their parameters are ready for the test stage
                if processor.stage == PipelineStage.TEST:
                    processor.fit(result)
                    self._fitted_processors[i] = processor
        return result

    def transform(
        self,
        data: pl.DataFrame,
        stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply already-fitted processors to test data.

        Uses parameters learned during training, preventing data leakage.

        Args:
            data: test data
            stage: current stage marker

        Returns:
            processed data
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                if i in self._fitted_processors:
                    # use the fitted processor
                    result = self._fitted_processors[i].transform(result)
                else:
                    # not fitted (ALL-stage processor skipped during training)
                    result = processor.transform(result)
        return result

    def save_processors(self, path: str) -> None:
        """Save the state of all fitted processors."""
        import pickle
        with open(path, "wb") as f:
            pickle.dump(self._fitted_processors, f)

    def load_processors(self, path: str) -> None:
        """Load processor state."""
        import pickle
        with open(path, "rb") as f:
            self._fitted_processors = pickle.load(f)
```
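The `save_processors`/`load_processors` round trip reduces to pickling the fitted-state mapping `{processor_index: fitted processor}`. A minimal sketch, with plain dicts standing in for the fitted processor objects:

```python
import os
import pickle
import tempfile

# Fitted state keyed by processor index, as ProcessingPipeline stores it
# (plain dicts stand in for processor objects in this sketch).
fitted = {
    0: {"mean": {"pe": 12.3}, "std": {"pe": 4.5}},
    1: {"lower": {"pe": 2.0}, "upper": {"pe": 60.0}},
}

path = os.path.join(tempfile.mkdtemp(), "processors.pkl")
with open(path, "wb") as f:
    pickle.dump(fitted, f)       # save after training

with open(path, "rb") as f:
    restored = pickle.load(f)    # load before inference

assert restored == fitted        # train-time statistics survive the round trip
```

Because real processor objects are pickled, loading requires the same class definitions to be importable, which is one reason processors auto-register on import (see the plugin system below).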
---
## 4. Plugin System
### 4.1 Registry Pattern
```python
from typing import Dict, List, Optional, Type


class PluginRegistry:
    """Central plugin registry.

    Registers processors, models, split strategies, and metrics via
    decorators. This is what makes the architecture plugin-based:
    new components only need to register themselves to become usable.
    """

    _processors: Dict[str, Type[BaseProcessor]] = {}
    _models: Dict[str, Type[BaseModel]] = {}
    _splitters: Dict[str, Type[BaseSplitter]] = {}
    _metrics: Dict[str, Type["BaseMetric"]] = {}

    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """Decorator that registers a processor.

        Example:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass
            >>> # usage
            >>> scaler = PluginRegistry.get_processor("standard_scaler")()
        """
        def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
            key = name or processor_class.__name__
            cls._processors[key] = processor_class
            processor_class._registry_name = key
            return processor_class
        return decorator

    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """Decorator that registers a model."""
        def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
            key = name or model_class.__name__
            cls._models[key] = model_class
            model_class._registry_name = key
            return model_class
        return decorator

    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """Decorator that registers a split strategy."""
        def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
            key = name or splitter_class.__name__
            cls._splitters[key] = splitter_class
            return splitter_class
        return decorator

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Look up a processor class by name."""
        if name not in cls._processors:
            raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}")
        return cls._processors[name]

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Look up a model class by name."""
        if name not in cls._models:
            raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}")
        return cls._models[name]

    @classmethod
    def get_splitter(cls, name: str) -> Type[BaseSplitter]:
        """Look up a split strategy class by name."""
        if name not in cls._splitters:
            raise KeyError(f"Splitter '{name}' not found. Available: {list(cls._splitters.keys())}")
        return cls._splitters[name]

    @classmethod
    def list_processors(cls) -> List[str]:
        """List all available processors."""
        return list(cls._processors.keys())

    @classmethod
    def list_models(cls) -> List[str]:
        """List all available models."""
        return list(cls._models.keys())
```
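The registration mechanics can be demonstrated without the framework. `MiniRegistry` below is a self-contained toy reimplementation of the same decorator idea, not the class above:

```python
class MiniRegistry:
    """Toy registry: decorator-based registration, lookup by name."""
    _processors = {}

    @classmethod
    def register_processor(cls, name=None):
        def decorator(klass):
            # register under the explicit name, or fall back to the class name
            cls._processors[name or klass.__name__] = klass
            return klass
        return decorator

    @classmethod
    def get_processor(cls, name):
        if name not in cls._processors:
            raise KeyError(f"Processor '{name}' not found. "
                           f"Available: {list(cls._processors)}")
        return cls._processors[name]

@MiniRegistry.register_processor("identity")
class Identity:
    def transform(self, data):
        return data

proc = MiniRegistry.get_processor("identity")()
print(proc.transform([1, 2, 3]))  # [1, 2, 3]
```

The decorator returns the class unchanged, so registration is a side effect of importing the module that defines the plugin; that is why the built-in `processors/__init__.py` can "auto-register" everything.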
### 4.2 Built-in Plugins
```python
# Numeric dtypes treated as "float columns" by the default column selection
FLOAT_TYPES = {pl.Float32, pl.Float64}

# ========== Built-in processors ==========

@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Z-score standardization."""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "StandardScaler":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "mean": {c: data[c].mean() for c in cols},
            "std": {c: data[c].std() for c in cols},
            "columns": cols
        }
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            mean = self._fitted_params["mean"][col]
            std = self._fitted_params["std"][col]
            if std > 0:
                result = result.with_columns(
                    ((pl.col(col) - mean) / std).alias(col)
                )
        return result


@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization - caps extreme values."""
    stage = PipelineStage.TRAIN  # quantiles are computed on training data only

    def __init__(self, columns=None, lower=0.01, upper=0.99):
        super().__init__(columns)
        self.lower = lower
        self.upper = upper

    def fit(self, data: pl.DataFrame) -> "Winsorizer":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        self._fitted_params = {
            "lower": {c: data[c].quantile(self.lower) for c in cols},
            "upper": {c: data[c].quantile(self.upper) for c in cols},
            "columns": cols
        }
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self._fitted_params["columns"]:
            lower = self._fitted_params["lower"][col]
            upper = self._fitted_params["upper"][col]
            result = result.with_columns(
                pl.col(col).clip(lower, upper).alias(col)
            )
        return result


@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
    """Industry / market-cap neutralization."""
    stage = PipelineStage.ALL

    def __init__(self, columns=None, group_col="industry", exclude_cols=None):
        super().__init__(columns)
        self.group_col = group_col
        self.exclude_cols = exclude_cols or []

    def fit(self, data: pl.DataFrame) -> "Neutralizer":
        # Neutralization is applied per cross-section; no global fit needed
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        # Group by date and neutralize each cross-section
        result = data
        for col in self.columns or []:
            if col in self.exclude_cols:
                continue
            # demean within each (date, group) bucket
            result = result.with_columns(
                (pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col]))
                .alias(col)
            )
        return result


@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
    """Drops rows with missing values."""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        cols = self.columns or data.columns
        return data.drop_nulls(subset=cols)


@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
    """Fills missing values."""
    stage = PipelineStage.TRAIN

    def __init__(self, columns=None, method="median"):
        super().__init__(columns)
        self.method = method

    def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
        cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
        fill_values = {}
        for col in cols:
            if self.method == "median":
                fill_values[col] = data[col].median()
            elif self.method == "mean":
                fill_values[col] = data[col].mean()
            elif self.method == "zero":
                fill_values[col] = 0
        self._fitted_params = {"fill_values": fill_values, "columns": cols}
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col, val in self._fitted_params["fill_values"].items():
            result = result.with_columns(pl.col(col).fill_null(val).alias(col))
        return result


@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
    """Converts values to cross-sectional ranks."""
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "RankTransformer":
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            # rank within each trading date
            result = result.with_columns(
                pl.col(col).rank().over("trade_date").alias(col)
            )
        return result
# ========== Built-in models ==========

@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM wrapper."""

    def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "LightGBMModel":
        import lightgbm as lgb

        # convert to numpy
        X_arr = X.to_numpy()
        y_arr = y.to_numpy()

        # build datasets
        train_data = lgb.Dataset(X_arr, label=y_arr)
        valid_sets = [train_data]
        if X_val is not None and y_val is not None:
            valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
            valid_sets.append(valid_data)

        # default parameters
        default_params = {
            "objective": self._get_objective(),
            "metric": self._get_metric(),
            "boosting_type": "gbdt",
            "num_leaves": 31,
            "learning_rate": 0.05,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 5,
            "verbose": -1
        }
        default_params.update(self.params)

        # train
        self._model = lgb.train(
            default_params,
            train_data,
            num_boost_round=fit_params.get("num_boost_round", 100),
            valid_sets=valid_sets,
            callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)] if len(valid_sets) > 1 else []
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_numpy())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        probs = self.predict(X)
        if len(probs.shape) == 1:
            return np.vstack([1 - probs, probs]).T
        return probs

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        importance = self._model.feature_importance(importance_type="gain")
        return pl.DataFrame({
            "feature": self._model.feature_name(),
            "importance": importance
        }).sort("importance", descending=True)

    def _get_objective(self) -> str:
        if self.task_type == "classification":
            return "binary"
        elif self.task_type == "regression":
            return "regression"
        elif self.task_type == "ranking":
            return "lambdarank"
        return "regression"

    def _get_metric(self) -> str:
        if self.task_type == "classification":
            return "auc"
        elif self.task_type == "regression":
            return "rmse"
        elif self.task_type == "ranking":
            return "ndcg"
        return "rmse"


@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
    """CatBoost wrapper."""

    def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params
    ) -> "CatBoostModel":
        from catboost import CatBoostClassifier, CatBoostRegressor

        # choose the model class for the task type
        if self.task_type == "classification":
            model_class = CatBoostClassifier
            default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
        elif self.task_type == "regression":
            model_class = CatBoostRegressor
            default_params = {"loss_function": "RMSE"}
        else:  # ranking
            model_class = CatBoostRegressor
            default_params = {"loss_function": "QueryRMSE"}
        default_params.update(self.params)
        default_params["verbose"] = False
        self._model = model_class(**default_params)

        # validation set
        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())

        # train
        self._model.fit(
            X.to_pandas(),
            y.to_pandas(),
            eval_set=eval_set,
            early_stopping_rounds=10,
            verbose=False
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_pandas())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        return self._model.predict_proba(X.to_pandas())

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        if self._model is None:
            return None
        return pl.DataFrame({
            "feature": self._model.feature_names_,
            "importance": self._model.feature_importances_
        }).sort("importance", descending=True)
# ========== Built-in split strategies ==========

@PluginRegistry.register_splitter("time_series")
class TimeSeriesSplit(BaseSplitter):
    """Time-ordered split - training data always precedes test data."""

    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        # test-set size per split
        test_size = (n_dates - self.min_train_size) // self.n_splits
        for i in range(self.n_splits):
            # end of the (expanding) training window
            train_end_idx = self.min_train_size + i * test_size
            # test window starts after a gap, to prevent leakage
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            if test_end_idx > n_dates:
                break
            train_dates = dates[:train_end_idx]
            test_dates = dates[test_start_idx:test_end_idx]
            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)
            train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list()
            yield train_idx, test_idx

    def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits
        result = []
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            if test_end_idx > n_dates:
                break
            result.append((
                dates[0],
                dates[train_end_idx - 1],
                dates[test_start_idx],
                dates[test_end_idx - 1]
            ))
        return result


@PluginRegistry.register_splitter("walk_forward")
class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation with a fixed-size rolling training window."""

    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window
            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]
            train_mask = data[date_col].is_in(train_dates)
            test_mask = data[date_col].is_in(test_dates)
            train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list()
            test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list()
            yield train_idx, test_idx
            start_idx += self.test_window
```
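The window arithmetic in `WalkForwardSplit` is easiest to check on a toy axis of integer "dates". This standalone sketch mirrors the same index math with small, readable windows:

```python
def walk_forward(n_dates, train_window, test_window, gap):
    """Yield (train_dates, test_dates) index ranges, oldest first."""
    start = train_window
    while start + gap + test_window <= n_dates:
        yield (range(start - train_window, start),               # rolling train window
               range(start + gap, start + gap + test_window))    # test window after the gap
        start += test_window

splits = list(walk_forward(n_dates=20, train_window=8, test_window=4, gap=2))
# split 0: train days 0..7,  test days 10..13  (days 8-9 fall in the gap)
# split 1: train days 4..11, test days 14..17
```

The gap days belong to neither set, so a label with a forward horizon up to `gap` days cannot bleed training information into the test window.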
---
## 5. Usage Examples
### 5.1 Basic Usage
```python
from src.models import (
    FeatureStore, ProcessingPipeline, PluginRegistry,
    PipelineStage, MLPipeline
)

# 1. Create the feature store
store = FeatureStore(data_dir="data")

# 2. Build the dataset
dataset = store.build_dataset(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 3. Create the processing pipeline
processors = [
    # drop rows with missing values
    PluginRegistry.get_processor("dropna")(),
    # outlier handling (quantiles computed on training data only)
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    # neutralization (by industry here; market cap also supported)
    PluginRegistry.get_processor("neutralizer")(group_col="industry"),
    # standardization (applied in both train and test)
    PluginRegistry.get_processor("standard_scaler")(),
]
pipeline = ProcessingPipeline(processors)

# 4. Create the split strategy
splitter = PluginRegistry.get_splitter("time_series")(
    n_splits=5,
    gap=5,
    min_train_size=252
)

# 5. Create the model
model = PluginRegistry.get_model("lightgbm")(
    task_type="regression",
    params={"n_estimators": 200, "learning_rate": 0.03}
)

# 6. Run the full pipeline
ml_pipeline = MLPipeline(
    feature_store=store,
    processing_pipeline=pipeline,
    splitter=splitter,
    model=model
)
results = ml_pipeline.run(
    factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
    label_config={"name": "forward_return", "forward_period": 5},
    date_range=("20200101", "20241231")
)

# 7. Inspect the results
print(results.metrics)             # per-fold evaluation metrics
print(results.feature_importance)  # feature importances
print(results.predictions)         # predictions
```
### 5.2 Config-Driven Usage (recommended)
```yaml
# config.yaml
experiment:
  name: "momentum_factor_regression"
data:
  factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"]
  label:
    name: "forward_return"
    forward_period: 5
  date_range: ["20200101", "20241231"]
processing:
  - name: "dropna"
    params: {}
    stage: "all"
  - name: "winsorizer"
    params:
      lower: 0.01
      upper: 0.99
    stage: "train"  # quantiles computed during training only
  - name: "neutralizer"
    params:
      group_col: "industry"
    stage: "all"
  - name: "standard_scaler"
    params: {}
    stage: "all"
splitting:
  strategy: "time_series"
  params:
    n_splits: 5
    gap: 5
    min_train_size: 252
model:
  name: "lightgbm"
  task_type: "regression"
  params:
    n_estimators: 200
    learning_rate: 0.03
    max_depth: 6
evaluation:
  metrics: ["ic", "rank_ic", "mse", "mae"]
  output_dir: "results/momentum_experiment"
```
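Under the hood, `MLPipeline.from_config` has to map each entry of the `processing` section onto a registered class. A sketch of that mapping (the `from_config` internals are not specified in this document, and the registry is simplified to a plain dict of factories here):

```python
STAGE_MAP = {"all": "ALL", "train": "TRAIN", "test": "TEST"}

# Simplified stand-in for PluginRegistry: name -> factory
registry = {
    "dropna": lambda **p: ("dropna", p),
    "winsorizer": lambda **p: ("winsorizer", p),
}

# The "processing" section of config.yaml, already parsed into a dict
config = {
    "processing": [
        {"name": "dropna", "params": {}, "stage": "all"},
        {"name": "winsorizer",
         "params": {"lower": 0.01, "upper": 0.99}, "stage": "train"},
    ]
}

def build_processors(cfg):
    """Instantiate each configured processor by its registered name."""
    built = []
    for spec in cfg["processing"]:
        factory = registry[spec["name"]]          # resolve by registered name
        built.append((factory(**spec["params"]), STAGE_MAP[spec["stage"]]))
    return built

pipeline = build_processors(config)
```

An unknown `name` fails fast at build time with a `KeyError`, mirroring the lookup errors raised by `PluginRegistry.get_processor`.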
```python
# using the config in code
from src.models import MLPipeline

pipeline = MLPipeline.from_config("config.yaml")
results = pipeline.run()

# save the results
results.save("results/momentum_experiment")
```
### 5.3 Custom Plugins
```python
# 1. Create a custom processor
@PluginRegistry.register_processor("my_transformer")
class MyTransformer(BaseProcessor):
    """Example custom transformer."""
    stage = PipelineStage.ALL

    def __init__(self, columns=None, multiplier=2.0):
        super().__init__(columns)
        self.multiplier = multiplier

    def fit(self, data: pl.DataFrame) -> "MyTransformer":
        # learn parameters here if needed
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        for col in self.columns or []:
            result = result.with_columns(
                (pl.col(col) * self.multiplier).alias(col)
            )
        return result


# 2. Create a custom model
@PluginRegistry.register_model("my_model")
class MyModel(BaseModel):
    """Example custom model."""

    def fit(self, X, y, X_val=None, y_val=None, **kwargs):
        # training logic goes here
        self._model = ...
        return self

    def predict(self, X):
        # prediction logic goes here
        return self._model.predict(X)
```

```yaml
# 3. Reference the new plugins in config.yaml
processing:
  - name: "my_transformer"
    params:
      multiplier: 3.0
    stage: "all"
model:
  name: "my_model"
  task_type: "regression"
```
---
## 6. Directory Layout
```
src/
├── models/                        # model training framework
│   ├── __init__.py                # exports the main classes
│   ├── core/                      # core abstractions and base classes
│   │   ├── __init__.py
│   │   ├── processor.py           # BaseProcessor, PipelineStage
│   │   ├── model.py               # BaseModel, TaskType
│   │   ├── splitter.py            # BaseSplitter
│   │   ├── metric.py              # BaseMetric
│   │   └── pipeline.py            # MLPipeline (orchestrator)
│   │
│   ├── registry.py                # PluginRegistry plugin registry
│   │
│   ├── data/                      # data layer
│   │   ├── __init__.py
│   │   ├── feature_store.py       # FeatureStore
│   │   ├── label_generator.py     # LabelGenerator
│   │   └── dataset.py             # Dataset wrapper
│   │
│   ├── processors/                # built-in processors
│   │   ├── __init__.py            # auto-registers all processors
│   │   ├── scaler.py              # StandardScaler
│   │   ├── winsorizer.py          # Winsorizer
│   │   ├── neutralizer.py         # Neutralizer
│   │   ├── imputer.py             # FillNAProcessor
│   │   ├── selector.py            # FeatureSelector
│   │   └── custom.py              # other processors
│   │
│   ├── models/                    # built-in models
│   │   ├── __init__.py            # auto-registers all models
│   │   ├── lightgbm_model.py      # LightGBMModel
│   │   ├── catboost_model.py      # CatBoostModel
│   │   └── sklearn_model.py       # SklearnModel (LR, RF, ...)
│   │
│   ├── splitters/                 # split strategies
│   │   ├── __init__.py
│   │   ├── time_series.py         # TimeSeriesSplit
│   │   ├── walk_forward.py        # WalkForwardSplit
│   │   └── purged.py              # PurgedKFold
│   │
│   ├── metrics/                   # evaluation metrics
│   │   ├── __init__.py
│   │   ├── ic.py                  # IC, RankIC
│   │   ├── returns.py             # return metrics
│   │   └── classification.py      # classification metrics
│   │
│   ├── evaluation/                # evaluation and reporting
│   │   ├── __init__.py
│   │   ├── evaluator.py           # ModelEvaluator
│   │   ├── report.py              # ReportGenerator
│   │   └── visualizer.py          # ResultVisualizer
│   │
│   └── config/                    # config parsing
│       ├── __init__.py
│       └── parser.py              # ConfigParser
├── factors/                       # existing factor framework
│   └── ...
tests/
├── models/                        # model framework tests
│   ├── __init__.py
│   ├── test_processors.py         # processor tests
│   ├── test_models.py             # model tests
│   ├── test_pipeline.py           # pipeline integration tests
│   └── test_registry.py           # registry tests
└── factors/                       # existing factor tests
    └── ...
configs/                           # config files
├── momentum_regression.yaml
├── value_classification.yaml
└── ranking_lambdamart.yaml
experiments/                       # experiment outputs
└── {experiment_name}/
    ├── config.yaml                # experiment config
    ├── model.pkl                  # saved model
    ├── processors.pkl             # saved processor state
    ├── predictions.parquet        # predictions
    ├── metrics.json               # evaluation metrics
    ├── feature_importance.csv     # feature importances
    └── report.html                # visual report
```
---
## 7. Development Plan
### Phase 1: Core Infrastructure (Weeks 1-2)
- [ ] Design and implement the `BaseProcessor`, `BaseModel`, `BaseSplitter` abstract classes
- [ ] Implement the `PluginRegistry` registry
- [ ] Implement `PipelineStage` stage management
- [ ] Write basic unit tests
### Phase 2: Data Layer (Weeks 2-3)
- [ ] Implement the `FeatureStore` feature store
- [ ] Implement the `LabelGenerator` label generator
- [ ] Implement the `Dataset` wrapper
- [ ] Integrate the output of the existing factor framework
### Phase 3: Processors (Weeks 3-4)
- [ ] Implement `StandardScaler`
- [ ] Implement `Winsorizer`
- [ ] Implement `Neutralizer`
- [ ] Implement `FillNAProcessor`
- [ ] Implement `DropNAProcessor`
- [ ] Implement `FeatureSelector`
- [ ] Implement `ProcessingPipeline`
### Phase 4: Model Layer (Weeks 4-5)
- [ ] Implement the `LightGBMModel` wrapper
- [ ] Implement the `CatBoostModel` wrapper
- [ ] Implement `SklearnModel` for scikit-learn models
- [ ] Support classification/regression/ranking tasks
### Phase 5: Split Strategies (Week 5)
- [ ] Implement `TimeSeriesSplit`
- [ ] Implement `WalkForwardSplit`
- [ ] Implement `PurgedKFold`
### Phase 6: Evaluation Layer (Weeks 5-6)
- [ ] Implement IC/RankIC metrics
- [ ] Implement return-analysis metrics
- [ ] Implement classification metrics
- [ ] Implement the `ModelEvaluator`
- [ ] Implement the `ReportGenerator`
### Phase 7: Config and Orchestration (Week 6)
- [ ] Implement the config parser
- [ ] Implement the `MLPipeline` orchestrator
- [ ] Support config-driven execution
### Phase 8: Integration Tests and Docs (Week 7)
- [ ] Write full integration tests
- [ ] Write user documentation
- [ ] Write example code
- [ ] Run performance benchmarks
---
## 8. Key Design Decisions
| Decision | Choice | Rationale |
|----------|--------|-----------|
| **Stage marking for processing** | `PipelineStage` enum | explicit, type-safe, easy to extend |
| **Plugin registration** | decorator pattern | Pythonic, concise, auto-discoverable |
| **Data format** | Polars DataFrame | consistent with the factor framework, fast |
| **Model interface** | unified `fit/predict` | industry standard, models are swappable |
| **Config format** | YAML | human-readable, handles nested structure |
| **Processor state persistence** | pickle | simple, Python-native, handles most objects |
| **Feature storage** | read directly from the factor framework | no data duplication, stays consistent |
---
## 9. Data-Leakage Prevention Checklist
- [x] Processors explicitly declare their stage (`stage` attribute)
- [x] `TRAIN`-stage processors `fit` on training data only
- [x] The `TEST` stage reuses parameters learned during training
- [x] Split strategies are time-series-aware (`TimeSeriesSplit`, `WalkForwardSplit`)
- [x] Splits support a `gap` parameter against adjacent-sample leakage
- [x] The feature store loads precomputed factors (never touches future data)
- [x] Label generation uses a predefined forward period (use of future data is explicit)
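The gap requirement in this checklist can be enforced by a small guard in the test suite. A sketch with integer dates standing in for trading days:

```python
def check_no_leakage(splits, gap):
    """Assert every split keeps at least `gap` dates between train and test."""
    for train_dates, test_dates in splits:
        assert max(train_dates) + gap < min(test_dates), "gap violated"

# Two well-separated splits: 3+5 < 9 and 5+5 < 11, so both pass.
splits = [([0, 1, 2, 3], [9, 10]), ([0, 1, 2, 3, 4, 5], [11, 12])]
check_no_leakage(splits, gap=5)
```

Running such a check over the output of every registered splitter turns several checklist items into an automated regression test instead of a manual review step.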
---
*文档版本: v1.0*
*最后更新: 2026-02-23*
*设计状态: 草案 - 待评审*