- 添加核心抽象:Processor、Model、Splitter、Metric 基类
- 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露
- 内置 8 个数据处理器和 3 种时序划分策略
- 支持 LightGBM、CatBoost 模型
- PluginRegistry 装饰器注册,插件式架构
- 22 个单元测试
ProStock 模型训练框架设计文档
1. 设计目标与原则
1.1 核心目标
- 组件化:每个阶段(数据获取、处理、训练、评估)都是独立组件
- 低耦合:组件间通过标准接口交互,不依赖具体实现
- 插件式:新功能通过插件注册,无需修改核心代码
- 阶段感知:数据处理区分训练阶段和测试阶段,防止数据泄露
- 多模型支持:统一接口支持 LightGBM、CatBoost 等多种模型
- 多任务支持:分类、回归、排序三种任务类型
1.2 设计原则
| 原则 | 说明 |
|---|---|
| 单一职责 | 每个组件只做一件事,做好一件事 |
| 开闭原则 | 对扩展开放(插件),对修改封闭(核心) |
| 依赖倒置 | 依赖抽象接口,而非具体实现 |
| 显式优于隐式 | 阶段标记、处理逻辑必须显式声明 |
| 配置驱动 | 通过配置文件或代码配置定义流程,减少硬编码 |
2. 整体架构
2.1 架构概览
┌─────────────────────────────────────────────────────────────────────────┐
│ ML Pipeline Orchestrator │
│ (流水线编排器 - 配置驱动执行) │
└─────────────────────────────────────────────────────────────────────────┘
│
┌───────────────────────────┼───────────────────────────┐
▼ ▼ ▼
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
│ Data Source │ │ Data Source │ │ Data Source │
│ (因子数据) │ │ (行情数据) │ │ (标签数据) │
└───────┬───────┘ └───────┬───────┘ └───────┬───────┘
│ │ │
└──────────────────────────┼──────────────────────────┘
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Feature Store (特征存储层) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ FactorLoader │ │ LabelLoader │ │ DataMerger │ │ CacheMgr │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Processing Pipeline (处理流水线) │
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │
│ │ Processor │ -> │ Processor │ -> │ Processor │ -> │ ... │ │
│ │ (阶段:ALL) │ │ (阶段:TRAIN)│ │ (阶段:TEST) │ │ │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ └──────────┘ │
│ │
│ 处理器类型: │
│ - FeatureEncoder: 特征编码(类别编码、数值缩放等) │
│ - FeatureSelector: 特征选择(相关性过滤、重要性筛选等) │
│ - OutlierHandler: 异常值处理 │
│ - MissingValueHandler: 缺失值处理 │
│ - CustomTransformer: 自定义转换器 │
└─────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Train/Test Split (数据划分) │
│ │
│ 支持多种划分策略: │
│ - TimeSeriesSplit: 时间序列划分(防止未来泄露) │
│ - PurgedKFold: 清除重叠样本的K折交叉验证 │
│ - EmbargoSplit: embargo 延迟验证 │
│ - CustomSplit: 自定义划分策略 │
└─────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Model Training (模型训练层) │
│ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Model Registry │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
│ │ │ LightGBM │ │CatBoost │ │ XGBoost │ │ Custom │ ... │ │
│ │ │ Model │ │ Model │ │ Model │ │ Model │ │ │
│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
│ 任务类型: │
│ - Classification: 分类任务(上涨/下跌预测) │
│ - Regression: 回归任务(收益率预测) │
│ - Ranking: 排序任务(股票排序/选股) │
└─────────────────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────────────┐
│ Evaluation (评估层) │
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌────────────┐ │
│ │ Metric │ │ Metric │ │ Metric │ │ Analyzer │ │
│ │ (IC/IR) │ │ (Sharpe) │ │ (Accuracy) │ │ (回测) │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ └────────────┘ │
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ ResultStore │ │ Report │ │ Visualizer │ │
│ │ (模型存储) │ │ (报告生成) │ │ (可视化) │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
2.2 数据流向图
因子DataFrame (Polars)
│
▼
┌──────────────────────┐
│ Feature Store │ 1. 加载并合并因子、标签、辅助数据
│ - 列选择 │ 2. 支持按日期/股票过滤
│ - 数据对齐 │ 3. 缓存机制避免重复加载
└──────────┬───────────┘
│
▼
┌──────────────────────┐
│ Processing Pipeline │ 顺序执行多个处理器
│ │ 每个处理器标记适用阶段 (ALL/TRAIN/TEST)
│ for processor in pipeline:
│ if processor.stage in [current_stage, ALL]:
│ data = processor.transform(data)
└──────────┬───────────┘
│
▼
┌──────────────────────┐
│ Data Splitter │ 时间序列感知的划分策略
│ - X_train, y_train │ 防止未来泄露
│ - X_test, y_test │
└──────────┬───────────┘
│
▼
┌──────────────────────┐
│ Model Training │ 统一接口,支持多种模型
│ - fit(X_train) │ 任务类型: classification/regression/ranking
│ - predict(X_test) │
└──────────┬───────────┘
│
▼
┌──────────────────────┐
│ Evaluation │ 多维度评估
│ - 预测指标 │ - IC/IR
│ - 回测指标 │ - 分组收益
│ - 可视化 │ - 累计收益曲线
└──────────────────────┘
3. 核心组件设计
3.1 基础抽象类
3.1.1 PipelineStage (流水线阶段枚举)
from enum import Enum, auto
class PipelineStage(Enum):
"""流水线阶段标记"""
ALL = auto() # 适用于所有阶段
TRAIN = auto() # 仅训练阶段
TEST = auto() # 仅测试阶段
VALIDATION = auto() # 仅验证阶段
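该枚举最典型的用法是判断某个处理器是否在当前运行阶段生效。下面是一个自包含的小草图(其中 `applies` 是本文档为演示引入的辅助函数,并非框架 API):

```python
from enum import Enum, auto

class PipelineStage(Enum):
    ALL = auto()         # 适用于所有阶段
    TRAIN = auto()       # 仅训练阶段
    TEST = auto()        # 仅测试阶段
    VALIDATION = auto()  # 仅验证阶段

def applies(processor_stage: PipelineStage, current_stage: PipelineStage) -> bool:
    # 处理器生效条件:标记为 ALL,或与当前运行阶段一致
    return processor_stage in (PipelineStage.ALL, current_stage)
```

例如 `applies(PipelineStage.TEST, PipelineStage.TRAIN)` 返回 `False`,即仅测试阶段的处理器在训练阶段会被跳过。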
3.1.2 BaseProcessor (处理器基类)
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import polars as pl
class BaseProcessor(ABC):
"""数据处理器基类
所有数据处理器必须继承此类。
关键特性:通过 stage 属性控制处理器在哪些阶段生效。
示例:
>>> class StandardScaler(BaseProcessor):
... stage = PipelineStage.ALL # 训练和测试都使用
...
... def fit(self, data: pl.DataFrame) -> None:
... self.mean = data[self.columns].mean()
... self.std = data[self.columns].std()
...
... def transform(self, data: pl.DataFrame) -> pl.DataFrame:
... return (data - self.mean) / self.std
"""
# 子类必须定义适用阶段
stage: PipelineStage = PipelineStage.ALL
def __init__(self, columns: Optional[list] = None, **params):
"""初始化处理器
Args:
columns: 要处理的列,None表示所有数值列
**params: 处理器特定参数
"""
self.columns = columns
self.params = params
self._is_fitted = False
self._fitted_params: Dict[str, Any] = {}
@abstractmethod
def fit(self, data: pl.DataFrame) -> "BaseProcessor":
"""在训练数据上学习参数
此方法只在训练阶段调用一次。
学习到的参数存储在 self._fitted_params 中。
Args:
data: 训练数据
Returns:
self (支持链式调用)
"""
pass
@abstractmethod
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
"""转换数据
在训练和测试阶段都会被调用。
使用 fit() 阶段学习到的参数进行转换。
Args:
data: 输入数据
Returns:
转换后的数据
"""
pass
def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
"""先fit再transform的便捷方法"""
return self.fit(data).transform(data)
def get_fitted_params(self) -> Dict[str, Any]:
"""获取学习到的参数(用于保存/加载)"""
return self._fitted_params.copy()
def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
"""设置学习到的参数(用于从checkpoint恢复)"""
self._fitted_params = params.copy()
self._is_fitted = True
return self
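fit/transform 契约的核心是:参数只在训练数据上学习,测试数据复用这些参数。为了脱离 polars 环境演示这一点,下面用纯 Python 列表实现一个最小的 `MeanCenter` 处理器(仅为示意,不是框架内置插件):

```python
from statistics import mean

class MeanCenter:
    """最小处理器示意:fit 记录训练集均值,transform 始终复用该均值(防泄露的关键)"""
    def __init__(self):
        self._fitted_params = {}

    def fit(self, values):
        self._fitted_params["mean"] = mean(values)
        return self

    def transform(self, values):
        m = self._fitted_params["mean"]  # 使用训练期学到的参数,而非当前数据的均值
        return [v - m for v in values]

proc = MeanCenter().fit([1.0, 2.0, 3.0])  # 训练集均值 2.0
test_out = proc.transform([4.0, 5.0])     # 测试集仍减去 2.0,而非自身均值 4.5
```

如果在测试集上重新 fit,就会把测试期的统计量泄露进特征,这正是 BaseProcessor 把 fit 与 transform 分离的原因。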
3.1.3 BaseModel (模型基类)
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional
import polars as pl
import numpy as np
TaskType = Literal["classification", "regression", "ranking"]
class BaseModel(ABC):
"""机器学习模型基类
统一接口支持多种模型(LightGBM, CatBoost, XGBoost等)
和多种任务类型(分类、回归、排序)。
示例:
>>> model = LightGBMModel(
... task_type="classification",
... params={"n_estimators": 100}
... )
>>> model.fit(X_train, y_train)
>>> predictions = model.predict(X_test)
"""
def __init__(
self,
task_type: TaskType,
params: Optional[Dict[str, Any]] = None,
name: Optional[str] = None
):
"""初始化模型
Args:
task_type: 任务类型 - "classification", "regression", "ranking"
params: 模型特定参数
name: 模型名称(用于日志和报告)
"""
self.task_type = task_type
self.params = params or {}
self.name = name or self.__class__.__name__
self._model: Any = None
self._is_fitted = False
@abstractmethod
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params
) -> "BaseModel":
"""训练模型
Args:
X: 特征数据
y: 目标变量
X_val: 验证集特征(可选)
y_val: 验证集目标(可选)
**fit_params: 额外的fit参数
Returns:
self (支持链式调用)
"""
pass
@abstractmethod
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测
Args:
X: 特征数据
Returns:
预测结果数组
- classification: 类别标签或概率
- regression: 连续值
- ranking: 排序分数
"""
pass
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
"""预测概率(仅分类任务)
Args:
X: 特征数据
Returns:
类别概率数组 [n_samples, n_classes]
"""
raise NotImplementedError("predict_proba only available for classification tasks")
def get_feature_importance(self) -> Optional[pl.DataFrame]:
"""获取特征重要性(如果模型支持)
Returns:
DataFrame[feature, importance] 或 None
"""
return None
def save(self, path: str) -> None:
"""保存模型到文件"""
import pickle
with open(path, 'wb') as f:
pickle.dump(self, f)
@classmethod
def load(cls, path: str) -> "BaseModel":
"""从文件加载模型"""
import pickle
with open(path, 'rb') as f:
return pickle.load(f)
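save/load 的语义就是对整个模型对象(含已学参数)做一次 pickle 往返。下面用内存缓冲代替落盘做一个最小验证(`DummyModel` 为示意类,非框架 API):

```python
import io
import pickle

class DummyModel:
    """示意模型:save/load 即 pickle 整个对象,已学参数随对象一起恢复"""
    def __init__(self, weight):
        self.weight = weight

    def predict(self, x):
        return x * self.weight

buf = io.BytesIO()                 # 正文写文件路径,这里用内存缓冲示意同一过程
pickle.dump(DummyModel(3.0), buf)
buf.seek(0)
restored = pickle.load(buf)        # 反序列化后 weight 原样恢复
```

注意 pickle 按引用序列化类,加载端必须能导入同名类定义,这也是设计决策表中选择 pickle 时需要接受的约束。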
3.1.4 BaseSplitter (数据划分基类)
from abc import ABC, abstractmethod
from typing import Iterator, Tuple, List
import polars as pl
class BaseSplitter(ABC):
"""数据划分策略基类
针对时间序列数据的特殊划分策略,防止未来泄露。
示例:
>>> splitter = TimeSeriesSplit(n_splits=5, gap=5)
>>> for train_idx, test_idx in splitter.split(data):
... X_train, X_test = X[train_idx], X[test_idx]
"""
@abstractmethod
def split(
self,
data: pl.DataFrame,
date_col: str = "trade_date"
) -> Iterator[Tuple[List[int], List[int]]]:
"""生成训练/测试索引
Args:
data: 完整数据集
date_col: 日期列名
Yields:
(train_indices, test_indices) 元组
"""
pass
@abstractmethod
def get_split_dates(
self,
data: pl.DataFrame,
date_col: str = "trade_date"
) -> List[Tuple[str, str, str, str]]:
"""获取划分日期范围
Returns:
[(train_start, train_end, test_start, test_end), ...]
"""
pass
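划分策略里 gap 的作用可以用一个纯整数下标的小例子验证(参数均为示意值,与后文 TimeSeriesSplit 的实现同构):

```python
# 12 个交易日,min_train_size=6, gap=2, n_splits=2(示意参数)
n_dates, min_train, gap, n_splits = 12, 6, 2, 2
test_size = (n_dates - min_train) // n_splits  # 每折测试集 3 天
folds = []
for i in range(n_splits):
    train_end = min_train + i * test_size
    test_start = train_end + gap        # 留出 gap,防止相邻样本泄露
    test_end = test_start + test_size
    if test_end > n_dates:              # gap 占用日期,可能截断最后一折
        break
    folds.append((list(range(train_end)), list(range(test_start, test_end))))
```

注意 test_size 的计算未扣除 gap,因此末尾的折可能放不下而被丢弃:本例 n_splits=2 实际只产出 1 折,使用时应据此核对折数。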
3.2 核心组件
3.2.1 FeatureStore (特征存储)
from typing import Dict, List, Optional, Tuple
import polars as pl
from pathlib import Path
class FeatureStore:
"""特征存储管理器
负责加载、合并、缓存因子数据。
支持从多个数据源(因子、标签、行情)加载并合并。
"""
def __init__(self, data_dir: str):
self.data_dir = Path(data_dir)
self._cache: Dict[str, pl.DataFrame] = {}
def load_factors(
self,
factor_names: List[str],
start_date: Optional[str] = None,
end_date: Optional[str] = None,
stock_codes: Optional[List[str]] = None
) -> pl.DataFrame:
"""加载因子数据
Args:
factor_names: 因子名称列表
start_date: 开始日期 YYYYMMDD
end_date: 结束日期 YYYYMMDD
stock_codes: 股票代码列表(可选)
Returns:
DataFrame[trade_date, ts_code, factor1, factor2, ...]
"""
pass
def load_labels(
self,
label_name: str,
forward_period: int = 5,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> pl.DataFrame:
"""加载标签数据(未来收益)
Args:
label_name: 标签名称(如 "return", "rank")
forward_period: 前瞻期(如5天后收益)
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame[trade_date, ts_code, label]
"""
pass
def build_dataset(
self,
factor_names: List[str],
label_config: Dict,
date_range: Tuple[str, str],
stock_codes: Optional[List[str]] = None,
additional_cols: Optional[List[str]] = None
) -> pl.DataFrame:
"""构建完整数据集
合并因子、标签、辅助列,并对齐数据。
Args:
factor_names: 因子列表
label_config: 标签配置 {"name": str, "forward_period": int}
date_range: (start_date, end_date)
stock_codes: 限定股票列表
additional_cols: 额外列(如 industry, market_cap)
Returns:
DataFrame[trade_date, ts_code, factor_cols..., label]
"""
pass
3.2.2 ProcessingPipeline (处理流水线)
from typing import Dict, List
import polars as pl
class ProcessingPipeline:
"""数据处理流水线
按顺序执行多个处理器,自动处理阶段标记。
关键特性:在测试阶段使用训练阶段学习到的参数。
"""
def __init__(self, processors: List[BaseProcessor]):
"""初始化流水线
Args:
processors: 处理器列表(按执行顺序)
"""
self.processors = processors
self._fitted_processors: Dict[int, BaseProcessor] = {}
def fit_transform(
self,
data: pl.DataFrame,
stage: PipelineStage = PipelineStage.TRAIN
) -> pl.DataFrame:
"""在训练数据上fit所有处理器并transform
Args:
data: 训练数据
stage: 当前阶段标记
Returns:
处理后的数据
"""
result = data
for i, processor in enumerate(self.processors):
# 检查处理器是否适用于当前阶段
if processor.stage in [PipelineStage.ALL, stage]:
# fit并transform
result = processor.fit_transform(result)
self._fitted_processors[i] = processor
elif stage == PipelineStage.TRAIN:
# 即使不适用于TRAIN阶段,也要fit(为TEST阶段准备)
if processor.stage == PipelineStage.TEST:
processor.fit(result)
self._fitted_processors[i] = processor
return result
def transform(
self,
data: pl.DataFrame,
stage: PipelineStage = PipelineStage.TEST
) -> pl.DataFrame:
"""在测试数据上应用已fit的处理器
使用训练阶段学习到的参数,防止数据泄露。
Args:
data: 测试数据
stage: 当前阶段标记
Returns:
处理后的数据
"""
result = data
for i, processor in enumerate(self.processors):
if processor.stage in [PipelineStage.ALL, stage]:
if i in self._fitted_processors:
# 使用已fit的处理器
result = self._fitted_processors[i].transform(result)
else:
# 未fit的处理器(ALL阶段但train时没执行到)
result = processor.transform(result)
return result
def save_processors(self, path: str) -> None:
"""保存所有已fit的处理器状态"""
import pickle
with open(path, 'wb') as f:
pickle.dump(self._fitted_processors, f)
def load_processors(self, path: str) -> None:
"""加载处理器状态"""
import pickle
with open(path, 'rb') as f:
self._fitted_processors = pickle.load(f)
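阶段感知流水线的行为可以用一个不依赖 polars 的最小草图复现:TRAIN 阶段边 fit 边 transform,TEST 阶段只复用已 fit 的处理器,且仅训练阶段的处理器(如删样本)在测试时被跳过。其中 `Stage`、`Center`、`DropNeg`、`run` 均为示意名:

```python
from enum import Enum, auto
from statistics import mean

class Stage(Enum):
    ALL = auto()
    TRAIN = auto()
    TEST = auto()

class Center:
    stage = Stage.ALL                      # 训练和测试都生效
    def fit(self, xs):
        self.m = mean(xs)
        return self
    def transform(self, xs):
        return [x - self.m for x in xs]

class DropNeg:
    stage = Stage.TRAIN                    # 仅训练阶段:删除负样本,测试集保持完整
    def fit(self, xs):
        return self
    def transform(self, xs):
        return [x for x in xs if x >= 0]

def run(processors, xs, stage, fitted=None):
    # fitted=None 表示训练阶段(边 fit 边 transform);否则直接复用已 fit 的处理器
    for p in processors:
        if p.stage in (Stage.ALL, stage):
            xs = (p.fit(xs) if fitted is None else p).transform(xs)
    return xs

procs = [Center(), DropNeg()]
train_out = run(procs, [1.0, 2.0, 3.0], Stage.TRAIN)    # 学到均值 2.0,并去掉负值
test_out = run(procs, [4.0], Stage.TEST, fitted=procs)  # 复用均值 2.0,不删样本
```

这里 `test_out == [2.0]`:测试数据只减训练均值,DropNeg 被跳过,与正文 fit_transform/transform 的分工一致。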
4. 插件系统
4.1 注册器模式
from typing import Dict, List, Optional, Type
class PluginRegistry:
"""插件注册中心
提供装饰器方式注册处理器、模型、划分策略等组件。
实现真正的插件式架构 - 新功能只需注册即可使用。
"""
_processors: Dict[str, Type[BaseProcessor]] = {}
_models: Dict[str, Type[BaseModel]] = {}
_splitters: Dict[str, Type[BaseSplitter]] = {}
_metrics: Dict[str, Type["BaseMetric"]] = {}
@classmethod
def register_processor(cls, name: Optional[str] = None):
"""注册处理器装饰器
示例:
>>> @PluginRegistry.register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>> # 使用
>>> scaler = PluginRegistry.get_processor("standard_scaler")()
"""
def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
key = name or processor_class.__name__
cls._processors[key] = processor_class
processor_class._registry_name = key
return processor_class
return decorator
@classmethod
def register_model(cls, name: Optional[str] = None):
"""注册模型装饰器"""
def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
key = name or model_class.__name__
cls._models[key] = model_class
model_class._registry_name = key
return model_class
return decorator
@classmethod
def register_splitter(cls, name: Optional[str] = None):
"""注册划分策略装饰器"""
def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
key = name or splitter_class.__name__
cls._splitters[key] = splitter_class
return splitter_class
return decorator
@classmethod
def get_processor(cls, name: str) -> Type[BaseProcessor]:
"""获取处理器类"""
if name not in cls._processors:
raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}")
return cls._processors[name]
@classmethod
def get_model(cls, name: str) -> Type[BaseModel]:
"""获取模型类"""
if name not in cls._models:
raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}")
return cls._models[name]
@classmethod
def get_splitter(cls, name: str) -> Type[BaseSplitter]:
"""获取划分策略类"""
if name not in cls._splitters:
raise KeyError(f"Splitter '{name}' not found. Available: {list(cls._splitters.keys())}")
return cls._splitters[name]
@classmethod
def list_processors(cls) -> List[str]:
"""列出所有可用处理器"""
return list(cls._processors.keys())
@classmethod
def list_models(cls) -> List[str]:
"""列出所有可用模型"""
return list(cls._models.keys())
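装饰器注册的完整闭环(注册、取回、回填 `_registry_name`)可以用一个自包含的迷你版验证,`MiniRegistry`、`Identity` 均为示意名,结构与正文 PluginRegistry 一致:

```python
class MiniRegistry:
    _processors = {}

    @classmethod
    def register_processor(cls, name=None):
        def decorator(klass):
            key = name or klass.__name__
            cls._processors[key] = klass      # 注册:类对象存入字典
            klass._registry_name = key        # 回填注册名,便于日志/序列化
            return klass
        return decorator

    @classmethod
    def get_processor(cls, name):
        if name not in cls._processors:
            raise KeyError(f"Processor '{name}' not found")
        return cls._processors[name]

@MiniRegistry.register_processor("identity")
class Identity:
    def transform(self, data):
        return data
```

注册返回的是类本身(而非实例),因此 `get_processor("identity")()` 每次都能以独立参数构造新实例。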
4.2 内置插件
# ========== 内置处理器 ==========
# 数值列类型集合,各处理器在 columns=None 时用它筛选默认处理列
FLOAT_TYPES = (pl.Float32, pl.Float64)
@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
"""标准缩放处理器 - Z-score标准化"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "StandardScaler":
cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
self._fitted_params = {
"mean": {c: data[c].mean() for c in cols},
"std": {c: data[c].std() for c in cols},
"columns": cols
}
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col in self._fitted_params["columns"]:
mean = self._fitted_params["mean"][col]
std = self._fitted_params["std"][col]
if std > 0:
result = result.with_columns(
((pl.col(col) - mean) / std).alias(col)
)
return result
@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
"""缩尾处理器 - 防止极端值影响"""
stage = PipelineStage.TRAIN # 只在训练阶段计算分位数
def __init__(self, columns=None, lower=0.01, upper=0.99):
super().__init__(columns)
self.lower = lower
self.upper = upper
def fit(self, data: pl.DataFrame) -> "Winsorizer":
cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
self._fitted_params = {
"lower": {c: data[c].quantile(self.lower) for c in cols},
"upper": {c: data[c].quantile(self.upper) for c in cols},
"columns": cols
}
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col in self._fitted_params["columns"]:
lower = self._fitted_params["lower"][col]
upper = self._fitted_params["upper"][col]
result = result.with_columns(
pl.col(col).clip(lower, upper).alias(col)
)
return result
@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
"""行业/市值中性化处理器"""
stage = PipelineStage.ALL
def __init__(self, columns=None, group_col="industry", exclude_cols=None):
super().__init__(columns)
self.group_col = group_col
self.exclude_cols = exclude_cols or []
def fit(self, data: pl.DataFrame) -> "Neutralizer":
# 中性化通常在每个截面独立进行,不需要全局fit
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
# 按日期分组,对每个截面进行中性化
result = data
for col in self.columns or []:
if col in self.exclude_cols:
continue
# 分组去均值
result = result.with_columns(
(pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col]))
.alias(col)
)
return result
@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
"""缺失值删除处理器"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
cols = self.columns or data.columns
return data.drop_nulls(subset=cols)
@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
"""缺失值填充处理器"""
stage = PipelineStage.TRAIN
def __init__(self, columns=None, method="median"):
super().__init__(columns)
self.method = method
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
fill_values = {}
for col in cols:
if self.method == "median":
fill_values[col] = data[col].median()
elif self.method == "mean":
fill_values[col] = data[col].mean()
elif self.method == "zero":
fill_values[col] = 0
self._fitted_params = {"fill_values": fill_values, "columns": cols}
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, val in self._fitted_params["fill_values"].items():
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
return result
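FillNAProcessor 的防泄露语义是:填充值(如中位数)只在训练集上计算,测试集的缺失用同一个值填。用纯 Python 的最小示意(`MedianFiller` 为演示类,非内置插件):

```python
from statistics import median

class MedianFiller:
    """示意 fillna 语义:中位数只在训练集上计算,测试集复用"""
    def fit(self, values):
        self.fill = median(v for v in values if v is not None)
        return self

    def transform(self, values):
        return [self.fill if v is None else v for v in values]

filler = MedianFiller().fit([1.0, None, 3.0, 5.0])  # 训练中位数 3.0
filled = filler.transform([None, 7.0])              # 测试缺失仍填 3.0
```

若在测试集上重算中位数,填充值就包含了测试期信息,属于典型的数据泄露。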
@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
"""排名转换处理器 - 转换为截面排名"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "RankTransformer":
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col in self.columns or []:
# 按日期分组计算排名
result = result.with_columns(
pl.col(col).rank().over("trade_date").alias(col)
)
return result
# ========== 内置模型 ==========
@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
"""LightGBM模型包装器"""
def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params
) -> "LightGBMModel":
import lightgbm as lgb
# 转换数据格式
X_arr = X.to_numpy()
y_arr = y.to_numpy()
# 构建数据集
train_data = lgb.Dataset(X_arr, label=y_arr)
valid_sets = [train_data]
if X_val is not None and y_val is not None:
valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
valid_sets.append(valid_data)
# 设置默认参数
default_params = {
"objective": self._get_objective(),
"metric": self._get_metric(),
"boosting_type": "gbdt",
"num_leaves": 31,
"learning_rate": 0.05,
"feature_fraction": 0.9,
"bagging_fraction": 0.8,
"bagging_freq": 5,
"verbose": -1
}
default_params.update(self.params)
# 训练
self._model = lgb.train(
default_params,
train_data,
num_boost_round=fit_params.get("num_boost_round", 100),
valid_sets=valid_sets,
callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)] if len(valid_sets) > 1 else []
)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_numpy())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
probs = self.predict(X)
if len(probs.shape) == 1:
return np.vstack([1 - probs, probs]).T
return probs
def get_feature_importance(self) -> Optional[pl.DataFrame]:
if self._model is None:
return None
importance = self._model.feature_importance(importance_type="gain")
return pl.DataFrame({
"feature": self._model.feature_name(),
"importance": importance
}).sort("importance", descending=True)
def _get_objective(self) -> str:
if self.task_type == "classification":
return "binary"
elif self.task_type == "regression":
return "regression"
elif self.task_type == "ranking":
return "lambdarank"
return "regression"
def _get_metric(self) -> str:
if self.task_type == "classification":
return "auc"
elif self.task_type == "regression":
return "rmse"
elif self.task_type == "ranking":
return "ndcg"
return "rmse"
@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
"""CatBoost模型包装器"""
def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params
) -> "CatBoostModel":
from catboost import CatBoostClassifier, CatBoostRegressor
# 选择模型类型
if self.task_type == "classification":
model_class = CatBoostClassifier
default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
elif self.task_type == "regression":
model_class = CatBoostRegressor
default_params = {"loss_function": "RMSE"}
else: # ranking
model_class = CatBoostRegressor
default_params = {"loss_function": "QueryRMSE"}
default_params.update(self.params)
default_params["verbose"] = False
self._model = model_class(**default_params)
# 准备验证集
eval_set = None
if X_val is not None and y_val is not None:
eval_set = (X_val.to_pandas(), y_val.to_pandas())
# 训练
self._model.fit(
X.to_pandas(),
y.to_pandas(),
eval_set=eval_set,
early_stopping_rounds=10,
verbose=False
)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_pandas())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
return self._model.predict_proba(X.to_pandas())
def get_feature_importance(self) -> Optional[pl.DataFrame]:
if self._model is None:
return None
return pl.DataFrame({
"feature": self._model.feature_names_,
"importance": self._model.feature_importances_
}).sort("importance", descending=True)
# ========== 内置划分策略 ==========
@PluginRegistry.register_splitter("time_series")
class TimeSeriesSplit(BaseSplitter):
"""时间序列划分 - 确保训练数据在测试数据之前"""
def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
self.n_splits = n_splits
self.gap = gap
self.min_train_size = min_train_size
def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
dates = data[date_col].unique().sort()
n_dates = len(dates)
# 计算每个split的测试集大小
test_size = (n_dates - self.min_train_size) // self.n_splits
for i in range(self.n_splits):
# 训练集结束位置
train_end_idx = self.min_train_size + i * test_size
# 测试集开始位置(留gap防止泄露)
test_start_idx = train_end_idx + self.gap
test_end_idx = test_start_idx + test_size
if test_end_idx > n_dates:
break
train_dates = dates[:train_end_idx]
test_dates = dates[test_start_idx:test_end_idx]
train_mask = data[date_col].is_in(train_dates)
test_mask = data[date_col].is_in(test_dates)
# 注意 with_row_count 默认列名为 "row_nr",这里显式命名避免取错列
indexed = data.with_row_count("row_idx")
train_idx = indexed.filter(train_mask)["row_idx"].to_list()
test_idx = indexed.filter(test_mask)["row_idx"].to_list()
yield train_idx, test_idx
def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"):
dates = data[date_col].unique().sort()
n_dates = len(dates)
test_size = (n_dates - self.min_train_size) // self.n_splits
result = []
for i in range(self.n_splits):
train_end_idx = self.min_train_size + i * test_size
test_start_idx = train_end_idx + self.gap
test_end_idx = test_start_idx + test_size
if test_end_idx > n_dates:
break
result.append((
dates[0],
dates[train_end_idx - 1],
dates[test_start_idx],
dates[test_end_idx - 1]
))
return result
@PluginRegistry.register_splitter("walk_forward")
class WalkForwardSplit(BaseSplitter):
"""滚动前向验证 - 固定长度训练窗口随时间向前滚动"""
def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
self.train_window = train_window
self.test_window = test_window
self.gap = gap
def split(self, data: pl.DataFrame, date_col: str = "trade_date"):
dates = data[date_col].unique().sort()
n_dates = len(dates)
start_idx = self.train_window
while start_idx + self.gap + self.test_window <= n_dates:
train_start = start_idx - self.train_window
train_end = start_idx
test_start = start_idx + self.gap
test_end = test_start + self.test_window
train_dates = dates[train_start:train_end]
test_dates = dates[test_start:test_end]
train_mask = data[date_col].is_in(train_dates)
test_mask = data[date_col].is_in(test_dates)
# 同 TimeSeriesSplit:显式命名行号列,避免依赖默认列名 "row_nr"
indexed = data.with_row_count("row_idx")
train_idx = indexed.filter(train_mask)["row_idx"].to_list()
test_idx = indexed.filter(test_mask)["row_idx"].to_list()
yield train_idx, test_idx
start_idx += self.test_window
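WalkForwardSplit 的窗口推进(固定训练窗,每次前移 test_window)可以用纯整数下标验证,参数为示意值:

```python
# train_window=5, test_window=2, gap=1,共 12 个交易日(示意参数)
n_dates, train_w, test_w, gap = 12, 5, 2, 1
windows = []
start = train_w
while start + gap + test_w <= n_dates:
    # (train_start, train_end, test_start, test_end) 左闭右开下标
    windows.append((start - train_w, start, start + gap, start + gap + test_w))
    start += test_w   # 每次整体前移一个测试窗
```

每个窗口的训练段长度恒为 5、测试段恒为 2,且测试段始终在训练段之后隔 gap 天开始,与扩展式的 TimeSeriesSplit 形成对照。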
5. 使用示例
5.1 基础用法
from src.models import (
FeatureStore, ProcessingPipeline, PluginRegistry,
PipelineStage, MLPipeline
)
# 1. 创建数据存储
store = FeatureStore(data_dir="data")
# 2. 构建数据集
dataset = store.build_dataset(
factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
label_config={"name": "forward_return", "forward_period": 5},
date_range=("20200101", "20241231")
)
# 3. 创建处理流水线
processors = [
# 删除缺失值
PluginRegistry.get_processor("dropna")(),
# 异常值处理(只在训练阶段计算分位数)
PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
# 中性化(行业和市值中性化)
PluginRegistry.get_processor("neutralizer")(group_col="industry"),
# 标准化(训练和测试都使用)
PluginRegistry.get_processor("standard_scaler")(),
]
pipeline = ProcessingPipeline(processors)
# 4. 创建划分策略
splitter = PluginRegistry.get_splitter("time_series")(
n_splits=5,
gap=5,
min_train_size=252
)
# 5. 创建模型
model = PluginRegistry.get_model("lightgbm")(
task_type="regression",
params={"n_estimators": 200, "learning_rate": 0.03}
)
# 6. 运行完整流程
ml_pipeline = MLPipeline(
feature_store=store,
processing_pipeline=pipeline,
splitter=splitter,
model=model
)
results = ml_pipeline.run(
factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"],
label_config={"name": "forward_return", "forward_period": 5},
date_range=("20200101", "20241231")
)
# 7. 查看结果
print(results.metrics) # 各折的评估指标
print(results.feature_importance) # 特征重要性
print(results.predictions) # 预测结果
5.2 配置驱动用法(推荐)
# config.yaml
experiment:
name: "momentum_factor_regression"
data:
factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"]
label:
name: "forward_return"
forward_period: 5
date_range: ["20200101", "20241231"]
processing:
- name: "dropna"
params: {}
stage: "all"
- name: "winsorizer"
params:
lower: 0.01
upper: 0.99
stage: "train" # 只在训练阶段计算分位数
- name: "neutralizer"
params:
group_col: "industry"
stage: "all"
- name: "standard_scaler"
params: {}
stage: "all"
splitting:
strategy: "time_series"
params:
n_splits: 5
gap: 5
min_train_size: 252
model:
name: "lightgbm"
task_type: "regression"
params:
n_estimators: 200
learning_rate: 0.03
max_depth: 6
evaluation:
metrics: ["ic", "rank_ic", "mse", "mae"]
output_dir: "results/momentum_experiment"
# 代码中使用配置
from src.models import MLPipeline
pipeline = MLPipeline.from_config("config.yaml")
results = pipeline.run()
# 保存结果
results.save("results/momentum_experiment")
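配置驱动的本质是:把 YAML 解析成的字典逐项送入注册中心实例化。下面是这一步的最小草图,其中 `registry`、`register`、`ScaleBy` 均为示意名,非框架实际 API:

```python
# 用 dict 模拟 YAML 解析结果,演示 "配置 -> 注册中心 -> 处理器实例" 的组装
registry = {}

def register(name):
    def deco(klass):
        registry[name] = klass
        return klass
    return deco

@register("scale_by")
class ScaleBy:
    def __init__(self, factor=1.0):
        self.factor = factor

    def transform(self, xs):
        return [x * self.factor for x in xs]

config = {"processing": [{"name": "scale_by", "params": {"factor": 2.0}}]}
processors = [
    registry[item["name"]](**item.get("params", {}))  # params 直接展开为构造参数
    for item in config["processing"]
]
out = processors[0].transform([1.0, 2.0])
```

真实的 ConfigParser 还需处理 stage 字段与未知插件名的报错,但组装逻辑与此一致。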
5.3 自定义插件
# 1. 创建自定义处理器
@PluginRegistry.register_processor("my_transformer")
class MyTransformer(BaseProcessor):
"""自定义转换器示例"""
stage = PipelineStage.ALL
def __init__(self, columns=None, multiplier=2.0):
super().__init__(columns)
self.multiplier = multiplier
def fit(self, data: pl.DataFrame) -> "MyTransformer":
# 学习参数(如有需要)
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col in self.columns or []:
result = result.with_columns(
(pl.col(col) * self.multiplier).alias(col)
)
return result
# 2. 创建自定义模型
@PluginRegistry.register_model("my_model")
class MyModel(BaseModel):
"""自定义模型示例"""
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
# 实现训练逻辑
self._model = ...
return self
def predict(self, X):
# 实现预测逻辑
return self._model.predict(X)
# 3. 在配置中使用
# config.yaml
processing:
- name: "my_transformer"
params:
multiplier: 3.0
stage: "all"
model:
name: "my_model"
task_type: "regression"
6. 目录结构
src/
├── models/ # 模型训练框架
│ ├── __init__.py # 导出主要类
│ ├── core/ # 核心抽象和基类
│ │ ├── __init__.py
│ │ ├── processor.py # BaseProcessor, PipelineStage
│ │ ├── model.py # BaseModel, TaskType
│ │ ├── splitter.py # BaseSplitter
│ │ ├── metric.py # BaseMetric
│ │ └── pipeline.py # MLPipeline (编排器)
│ │
│ ├── registry.py # PluginRegistry 插件注册中心
│ │
│ ├── data/ # 数据相关
│ │ ├── __init__.py
│ │ ├── feature_store.py # FeatureStore 特征存储
│ │ ├── label_generator.py # LabelGenerator 标签生成
│ │ └── dataset.py # Dataset 数据集包装
│ │
│ ├── processors/ # 内置处理器
│ │ ├── __init__.py # 自动注册所有处理器
│ │ ├── scaler.py # StandardScaler
│ │ ├── winsorizer.py # Winsorizer
│ │ ├── neutralizer.py # Neutralizer
│ │ ├── imputer.py # FillNAProcessor
│ │ ├── selector.py # FeatureSelector
│ │ └── custom.py # 其他处理器
│ │
│ ├── models/ # 内置模型
│ │ ├── __init__.py # 自动注册所有模型
│ │ ├── lightgbm_model.py # LightGBMModel
│ │ ├── catboost_model.py # CatBoostModel
│ │ └── sklearn_model.py # SklearnModel (LR, RF等)
│ │
│ ├── splitters/ # 划分策略
│ │ ├── __init__.py
│ │ ├── time_series.py # TimeSeriesSplit
│ │ ├── walk_forward.py # WalkForwardSplit
│ │ └── purged.py # PurgedKFold
│ │
│ ├── metrics/ # 评估指标
│ │ ├── __init__.py
│ │ ├── ic.py # IC, RankIC
│ │ ├── returns.py # 收益指标
│ │ └── classification.py # 分类指标
│ │
│ ├── evaluation/ # 评估和报告
│ │ ├── __init__.py
│ │ ├── evaluator.py # ModelEvaluator
│ │ ├── report.py # ReportGenerator
│ │ └── visualizer.py # ResultVisualizer
│ │
│ └── config/ # 配置解析
│ ├── __init__.py
│ └── parser.py # ConfigParser
│
├── factors/ # 已有因子框架
│ └── ...
│
tests/
├── models/ # 模型框架测试
│ ├── __init__.py
│ ├── test_processors.py # 处理器测试
│ ├── test_models.py # 模型测试
│ ├── test_pipeline.py # 流水线集成测试
│ └── test_registry.py # 注册器测试
│
└── factors/ # 已有因子测试
└── ...
configs/ # 配置文件目录
├── momentum_regression.yaml
├── value_classification.yaml
└── ranking_lambdamart.yaml
experiments/ # 实验结果目录
└── {experiment_name}/
├── config.yaml # 实验配置
├── model.pkl # 保存的模型
├── processors.pkl # 保存的处理器状态
├── predictions.parquet # 预测结果
├── metrics.json # 评估指标
├── feature_importance.csv # 特征重要性
└── report.html # 可视化报告
7. 开发计划
Phase 1: 核心基础设施 (Week 1-2)
- 设计并实现 BaseProcessor、BaseModel、BaseSplitter 抽象类
- 实现 PluginRegistry 注册中心
- 实现 PipelineStage 阶段管理
- 编写基础单元测试
Phase 2: 数据层 (Week 2-3)
- 实现 FeatureStore 特征存储
- 实现 LabelGenerator 标签生成器
- 实现 Dataset 数据集包装
- 集成现有因子框架输出
Phase 3: 处理器 (Week 3-4)
- 实现 StandardScaler 标准化处理器
- 实现 Winsorizer 缩尾处理器
- 实现 Neutralizer 中性化处理器
- 实现 FillNAProcessor 缺失值填充处理器
- 实现 DropNAProcessor 缺失值删除处理器
- 实现 FeatureSelector 特征选择器
- 实现 ProcessingPipeline 流水线
Phase 4: 模型层 (Week 4-5)
- 实现 LightGBMModel LightGBM 包装
- 实现 CatBoostModel CatBoost 包装
- 实现 SklearnModel sklearn 模型支持 (LR、RF 等)
- 支持 classification/regression/ranking 三种任务
Phase 5: 划分策略 (Week 5)
- 实现 TimeSeriesSplit 时间序列划分
- 实现 WalkForwardSplit 滚动前向验证
- 实现 PurgedKFold 清除重叠样本的K折
Phase 6: 评估层 (Week 5-6)
- 实现 IC/RankIC 指标
- 实现收益分析指标
- 实现分类指标
- 实现 ModelEvaluator 评估器
- 实现 ReportGenerator 报告生成
Phase 7: 配置和编排 (Week 6)
- 实现配置解析器
- 实现 MLPipeline 编排器
- 支持配置驱动执行
Phase 8: 集成测试和文档 (Week 7)
- 编写完整集成测试
- 编写使用文档
- 编写示例代码
- 性能基准测试
8. 关键设计决策
| 决策点 | 选择 | 理由 |
|---|---|---|
| 数据处理阶段标记 | PipelineStage 枚举 | 显式、类型安全、易于扩展 |
| 插件注册方式 | 装饰器模式 | Pythonic、简洁、自动发现 |
| 数据格式 | Polars DataFrame | 与因子框架一致、高性能 |
| 模型接口 | fit/predict 统一接口 | 行业标准、易于替换模型 |
| 配置格式 | YAML | 人类可读、支持复杂结构 |
| 处理器状态保存 | pickle | 简单、Python 原生、支持大部分对象 |
| 特征存储 | 从因子框架直接读取 | 避免数据冗余、保持一致性 |
9. 防数据泄露检查清单
- 处理器明确标记适用阶段(stage 属性)
- TRAIN 阶段处理器只在训练数据上 fit
- TEST 阶段使用训练阶段学习到的参数
- 划分策略支持时间序列感知(TimeSeriesSplit、WalkForwardSplit)
- 划分时支持 gap 参数防止相邻样本泄露
- 特征存储从已计算的因子加载(不访问未来数据)
- 标签生成使用预定义的前瞻期(未来信息只进入标签,且前瞻期显式声明)
文档版本: v1.0 最后更新: 2026-02-23 设计状态: 草案 - 待评审