feat(models): implement the machine learning model training framework

- Add core abstractions: Processor, Model, Splitter, and Metric base classes
- Implement stage awareness (TRAIN/TEST/ALL) to prevent data leakage
- Ship 8 built-in data processors and 3 time-series split strategies
- Support LightGBM and CatBoost models
- Plugin-style architecture via PluginRegistry decorator registration
- 22 unit tests

2026-02-23 01:37:34 +08:00
parent e58b39970c
commit 9f95be56a0
16 changed files with 3774 additions and 865 deletions

README.md (292 lines)

@@ -1,40 +1,300 @@
# ProStock
A-share quantitative investing framework - a complete solution from data acquisition to model training
## Features
### 1. Data layer (src/data/)
- **Multi-source data access**: Tushare API integration covering daily bars, stock basics, and the trading calendar
- **DuckDB storage**: high-performance embedded database with SQL query pushdown
- **Smart sync**: incremental/full sync strategies with automatic detection of what needs updating
- **Rate control**: API throttling via a token-bucket algorithm
- **Concurrency**: multi-threaded data fetching with ThreadPoolExecutor
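The token-bucket throttle mentioned above fits in a few lines. This is a hedged, minimal sketch of the algorithm only; the class and parameter names here are invented for illustration and do not mirror the actual `src/data/rate_limiter.py`:

```python
import threading
import time


class TokenBucket:
    """Minimal token-bucket limiter: tokens refill at `rate` per second up
    to `capacity`; each acquire() consumes one token or blocks until one
    is available."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill based on elapsed time, capped at capacity.
                self.tokens = min(
                    self.capacity, self.tokens + (now - self.last) * self.rate
                )
                self.last = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                wait = (1 - self.tokens) / self.rate
            time.sleep(wait)


bucket = TokenBucket(rate=100.0, capacity=100)  # e.g. at most ~100 calls/second
bucket.acquire()  # returns immediately while tokens remain
```

Capping the bucket at `capacity` allows short bursts while keeping the long-run rate bounded, which is the usual reason to pick token bucket over a fixed-interval sleep.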
### 2. Factor layer (src/factors/)
- **Type safety**: strict distinction between cross-sectional and time-series factors
- **Leak prevention**: framework-level guards against future-data and cross-stock leakage
- **Factor composition**: supports factor addition, subtraction, multiplication, division, and scalar operations
- **High-performance compute**: vectorized Polars operations with zero-copy data export
- **Easy extension**: base-class abstractions make custom factors straightforward
### 3. Model layer (src/models/)
- **Plugin architecture**: decorator-based registration; new models are plug-and-play
- **Stage awareness**: train/test stage separation to prevent data leakage
- **Multi-model support**: a unified interface for LightGBM, CatBoost, and more
- **Data processing**: missing-value handling, winsorization, standardization, neutralization, etc.
- **Time-series splits**: WalkForward, ExpandingWindow, and other time-series split strategies
## Project Structure
```
ProStock/
├── src/
│   ├── config/                     # Configuration management
│   │   ├── settings.py             # pydantic-settings configuration
│   │   └── __init__.py
│   │
│   ├── data/                       # Data acquisition and storage
│   │   ├── api_wrappers/           # Tushare API wrappers
│   │   │   ├── api_daily.py        # Daily bar endpoint
│   │   │   ├── api_stock_basic.py  # Stock basic info
│   │   │   └── api_trade_cal.py    # Trading calendar
│   │   ├── client.py               # Tushare client (with rate limiting)
│   │   ├── config.py               # Data-module configuration
│   │   ├── db_manager.py           # DuckDB table management and sync
│   │   ├── db_inspector.py         # Database inspection utilities
│   │   ├── rate_limiter.py         # Token-bucket rate limiter
│   │   ├── storage.py              # DuckDB storage core
│   │   ├── sync.py                 # Main data-sync logic
│   │   └── __init__.py
│   │
│   ├── factors/                    # Factor computation framework
│   │   ├── base.py                 # Factor base classes (cross-sectional / time-series)
│   │   ├── composite.py            # Composite factors and scalar operations
│   │   ├── data_loader.py          # DuckDB data loader
│   │   ├── data_spec.py            # Data spec definitions
│   │   ├── engine.py               # Factor execution engine
│   │   └── __init__.py
│   │
│   ├── models/                     # Model training framework
│   │   ├── core/                   # Core abstractions
│   │   │   ├── base.py             # Processor/model/splitter base classes
│   │   │   └── splitter.py         # Time-series split strategies
│   │   ├── models/                 # Model implementations
│   │   │   └── models.py           # LightGBM, CatBoost
│   │   ├── processors/             # Data processors
│   │   │   └── processors.py       # Standardization, winsorization, neutralization, etc.
│   │   ├── pipeline.py             # Processing pipeline
│   │   ├── registry.py             # Plugin registry
│   │   └── __init__.py
│   │
│   └── __init__.py
├── docs/                           # Documentation
│   ├── factor_framework_design.md  # Factor framework design
│   ├── ml_framework_design.md      # Model framework design
│   ├── db_sync_guide.md            # Data sync guide
│   └── ...
├── data/                           # Data storage (DuckDB)
│   ├── prostock.db                 # Main database file
│   └── stock_basic.csv             # Stock basic info cache
├── config/                         # Configuration files
│   └── .env.local                  # Environment variables (API token, etc.)
└── tests/                          # Tests
    ├── test_sync.py
    └── factors/
```
## Quick Start
### 1. Install dependencies
**⚠️ This project mandates uv as the Python package manager; do not invoke `python` or `pip` directly.**
```bash
# Install uv (if not already installed)
pip install uv
# Install project dependencies
uv pip install -e .
```
### 2. Configure environment variables
Create a `config/.env.local` file:
```bash
TUSHARE_TOKEN=your_tushare_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```
### 3. Sync data
```bash
# First sync - full sync starting from 20180101
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# Daily sync - incremental (resumes automatically from the latest date)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# Preview sync (check how much data needs syncing)
uv run python -c "from src.data.sync import preview_sync; preview_sync()"
# Custom thread count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```
### 4. Inspect database status
```bash
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
```
## Usage Examples
### Factor computation
```python
from src.factors import FactorEngine, DataLoader, DataSpec
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
import polars as pl

# Custom cross-sectional factor: PE rank
class PERankFactor(CrossSectionalFactor):
    name = "pe_rank"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]

    def compute(self, data) -> pl.Series:
        cs = data.get_cross_section()
        return cs["pe"].rank()

# Custom time-series factor: 20-day moving average
class MA20Factor(TimeSeriesFactor):
    name = "ma20"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]

    def compute(self, data) -> pl.Series:
        return data.get_column("close").rolling_mean(window_size=20)

# Run the computation
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)

# Cross-sectional factor
pe_rank = PERankFactor()
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")

# Time-series factor
ma20 = MA20Factor()
result2 = engine.compute(ma20, stock_codes=["000001.SZ"],
                         start_date="20240101", end_date="20240131")

# Factor composition
combined = 0.5 * pe_rank + 0.3 * ma20
```
### Model training
```python
from src.models import PluginRegistry, ProcessingPipeline
from src.models.core import PipelineStage
import polars as pl

# Build a processing pipeline
pipeline = ProcessingPipeline([
    PluginRegistry.get_processor("dropna")(),
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    PluginRegistry.get_processor("standard_scaler")(),
])

# Prepare data
data = pl.read_csv("features.csv")  # features and labels

# Split into train/test sets
from src.models.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=252, test_window=21)

# Get the LightGBM model
ModelClass = PluginRegistry.get_model("lightgbm")
model = ModelClass(task_type="regression", params={"n_estimators": 100})

# Training loop
for train_idx, test_idx in splitter.split(data):
    train_data = data[train_idx]
    test_data = data[test_idx]

    # Data processing
    X_train = pipeline.fit_transform(train_data.drop("target"))
    X_test = pipeline.transform(test_data.drop("target"))
    y_train = train_data["target"]
    y_test = test_data["target"]

    # Train the model
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
```
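A loop like this usually pairs with a metric accumulated across folds. As a hedged illustration of the `update()`/`get_mean()` accumulator pattern the framework's metric base class uses (the `RMSEMetric` name itself is invented for this example, not a shipped class):

```python
import numpy as np


class RMSEMetric:
    """Fold-wise accumulator mirroring the update()/get_mean() pattern:
    each call to update() appends one fold's score, get_mean() averages
    the collected scores."""

    def __init__(self):
        self._values = []

    def compute(self, y_true, y_pred) -> float:
        # Root mean squared error for one fold.
        diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
        return float(np.sqrt(np.mean(diff ** 2)))

    def update(self, y_true, y_pred) -> "RMSEMetric":
        self._values.append(self.compute(y_true, y_pred))
        return self

    def get_mean(self) -> float:
        return float(np.mean(self._values)) if self._values else 0.0


metric = RMSEMetric()
metric.update([1.0, 2.0], [1.0, 2.0])  # perfect fold, RMSE 0
metric.update([0.0, 0.0], [3.0, 4.0])  # imperfect fold
```

Inside the loop above one would call `metric.update(y_test, predictions)` per fold and read `metric.get_mean()` at the end.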
## Core Design
### 1. Data leak prevention
**Cross-sectional factors (CrossSectionalFactor)**:
- Prevent date leakage: each day sees only `[T-lookback+1, T]` data
- Allow cross-stock comparison: all stocks for the current day are passed in
- Typical uses: PE rank, market-cap quantile, same-day return rank

**Time-series factors (TimeSeriesFactor)**:
- Prevent stock leakage: each stock is computed in isolation
- Allow historical access: the full time series for that stock is passed in
- Typical uses: moving averages, RSI, historical volatility
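The two slicing rules can be shown with plain Python. This is a toy sketch of the idea only, not the engine's actual implementation; both function names are hypothetical:

```python
# Toy rows: (trade_date, ts_code, close)
rows = [
    ("20240101", "000001.SZ", 10.0), ("20240101", "000002.SZ", 20.0),
    ("20240102", "000001.SZ", 11.0), ("20240102", "000002.SZ", 19.0),
    ("20240103", "000001.SZ", 12.0), ("20240103", "000002.SZ", 18.0),
]


def cross_section_view(rows, t, lookback):
    """Cross-sectional rule: all stocks, but only dates in [T-lookback+1, T]."""
    dates = sorted({r[0] for r in rows})
    i = dates.index(t)
    window = set(dates[max(0, i - lookback + 1): i + 1])
    return [r for r in rows if r[0] in window]


def time_series_view(rows, stock):
    """Time-series rule: full history, but only one stock."""
    return [r for r in rows if r[1] == stock]
```

On 20240102 with `lookback=2`, the cross-sectional view contains both stocks but never the 20240103 rows; the time-series view for `000001.SZ` contains all three dates but never `000002.SZ`.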
### 2. Plugin registration
```python
from src.models.registry import PluginRegistry

# Register a custom processor
@PluginRegistry.register_processor("my_processor")
class MyProcessor(BaseProcessor):
    stage = PipelineStage.TRAIN

    def fit(self, data):
        # Learn parameters
        return self

    def transform(self, data):
        # Transform the data
        return data

# Use it
processor_class = PluginRegistry.get_processor("my_processor")
processor = processor_class()
```
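For intuition, a decorator-based registry of this shape can be built in a few lines. This is a minimal sketch under the assumption of a simple name-to-class dict; the real `src.models.registry.PluginRegistry` may differ in details such as error handling or separate model/processor namespaces:

```python
class PluginRegistry:
    """Minimal decorator-registry sketch: register_processor() returns a
    decorator that records the class under a name; get_processor() looks
    the name back up."""

    _processors: dict = {}

    @classmethod
    def register_processor(cls, name: str):
        def decorator(processor_cls):
            if name in cls._processors:
                raise ValueError(f"processor '{name}' already registered")
            cls._processors[name] = processor_cls
            return processor_cls  # class is returned unchanged
        return decorator

    @classmethod
    def get_processor(cls, name: str):
        try:
            return cls._processors[name]
        except KeyError:
            raise KeyError(f"unknown processor: {name}")


@PluginRegistry.register_processor("identity")
class Identity:
    def fit(self, data):
        return self

    def transform(self, data):
        return data
```

Because registration happens at class-definition time, merely importing a module containing `@PluginRegistry.register_processor(...)` classes makes them available, which is why the package `__init__` imports all built-in processors.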
### 3. Data sync strategy
**Smart incremental sync**:
```python
from src.data.db_manager import SyncManager

manager = SyncManager()
result = manager.sync(
    table_name="daily",
    fetch_func=get_daily,
    start_date="20240101",
    end_date="20240131"
)
# Auto-detects: table missing -> full sync; table exists -> incremental
```
## Documentation
- [Data sync module](docs/data_sync.md) - detailed data-sync usage
- [Factor framework design](docs/factor_framework_design.md) - factor computation architecture
- [Model framework design](docs/ml_framework_design.md) - model training architecture
- [Data sync guide](docs/db_sync_guide.md) - DuckDB data-sync API reference
- [Code review report](docs/code_review_factors_20260222.md) - factor framework code review
## Development Standards
- **Python version**: 3.10+
- **Code style**: Google-style docstrings
- **Type hints**: type annotations required
- **Testing**: pytest
- **Package management**: uv (direct pip/python use is forbidden)
## Tech Stack
- **Data processing**: Polars, Pandas, NumPy
- **Data storage**: DuckDB (embedded OLAP database)
- **API**: Tushare Pro
- **Machine learning**: LightGBM, CatBoost, scikit-learn
- **Configuration**: pydantic-settings
## License
MIT License


@@ -1,846 +0,0 @@
# ProStock Factor Framework Implementation Plan
## Directory Structure
```
src/factors/
├── __init__.py           # Exports the main classes
├── data_spec.py          # Phase 1: data type definitions
├── base.py               # Phase 2: factor base classes
├── composite.py          # Phase 2: composite factors
├── data_loader.py        # Phase 3: data loading
├── engine.py             # Phase 4: execution engine
└── builtin/              # Phase 5: built-in factor library
    ├── __init__.py
    ├── momentum.py       # Cross-sectional momentum factors
    ├── technical.py      # Time-series technical indicators
    └── value.py          # Cross-sectional valuation factors
tests/factors/            # Phase 6-7: tests
├── __init__.py
├── test_data_spec.py     # Data type tests
├── test_base.py          # Factor base-class tests
├── test_composite.py     # Composite factor tests
├── test_data_loader.py   # Data loader tests
├── test_engine.py        # Engine tests
├── test_builtin.py       # Built-in factor tests
└── test_integration.py   # Integration tests
```
---
## Phase 1: Data Type Definitions (data_spec.py)
### 1.1 DataSpec - Data Requirement Spec
**Implementation requirements:**
```python
@dataclass(frozen=True)
class DataSpec:
    """Data requirement specification.

    Args:
        source: H5 file name (e.g. "daily", "fundamental")
        columns: list of required column names; must include "ts_code" and "trade_date"
        lookback_days: number of days to look back (inclusive of the current day)
            - 1 means only the current day [T]
            - 5 means [T-4, T], 5 days total
            - 20 means [T-19, T], 20 days total
    """
    source: str
    columns: List[str]
    lookback_days: int = 1
```
**Constraint validation:**
- `lookback_days >= 1` (at least the current day)
- `columns` must include `ts_code` and `trade_date`
- `source` must not be an empty string

**Test requirements:**
- [ ] Valid DataSpec creation
- [ ] `lookback_days < 1` raises ValueError
- [ ] Missing `ts_code` or `trade_date` raises ValueError
- [ ] Empty `source` raises ValueError
- [ ] Frozen behavior (immutable after creation)
---
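The constraints listed above map naturally onto a frozen dataclass with `__post_init__` validation. A minimal sketch (the actual `data_spec.py` may differ in wording and check order):

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class DataSpec:
    """Data requirement spec with the three constraints enforced on creation."""
    source: str
    columns: List[str]
    lookback_days: int = 1

    def __post_init__(self):
        # source must not be an empty string
        if not self.source:
            raise ValueError("source must be a non-empty string")
        # at least the current day
        if self.lookback_days < 1:
            raise ValueError("lookback_days must be >= 1")
        # key columns are mandatory
        missing = {"ts_code", "trade_date"} - set(self.columns)
        if missing:
            raise ValueError(f"columns must include {sorted(missing)}")
```

`frozen=True` makes assignment after construction raise `FrozenInstanceError`, giving the immutability the test list asks for; `__post_init__` still runs on a frozen dataclass because it only reads fields.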
### 1.2 FactorContext - Computation Context
**Implementation requirements:**
```python
@dataclass
class FactorContext:
    """Factor computation context.

    Injected automatically by FactorEngine; factor authors can access it
    via data.context.

    Attributes:
        current_date: current computation date, YYYYMMDD (used by cross-sectional factors)
        current_stock: current stock code (used by time-series factors)
        trade_dates: trading-calendar list (optional, for alignment)
    """
    current_date: Optional[str] = None
    current_stock: Optional[str] = None
    trade_dates: Optional[List[str]] = None
```
**Test requirements:**
- [ ] Creation with default values
- [ ] Creation with all parameters
- [ ] Dataclass-generated methods
---
### 1.3 FactorData - Data Container
**Implementation requirements:**
```python
class FactorData:
    """Data container handed to factors.

    Wraps the underlying Polars DataFrame and exposes a safe data-access
    interface.
    """
    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context

    def get_column(self, col: str) -> pl.Series:
        """Get the data for a given column.

        - Cross-sectional factors: that column for all stocks on the current day
        - Time-series factors: that column across the stock's time series

        Args:
            col: column name
        Returns:
            Polars Series
        Raises:
            KeyError: column does not exist
        """
        pass

    def filter_by_date(self, date: str) -> "FactorData":
        """Filter by date, returning a new FactorData.

        Mainly used by cross-sectional factors to fetch a specific date.

        Args:
            date: date in YYYYMMDD format
        Returns:
            the filtered FactorData
        """
        pass

    def get_cross_section(self) -> pl.DataFrame:
        """Get the cross-section for the current date.

        Cross-sectional factors only; returns all stocks on current_date.

        Returns:
            DataFrame containing all stocks for the current date
        Raises:
            ValueError: current_date is unset (non-cross-sectional context)
        """
        pass

    def to_polars(self) -> pl.DataFrame:
        """Get the underlying Polars DataFrame (advanced usage)."""
        pass

    @property
    def context(self) -> FactorContext:
        """Get the computation context."""
        pass

    def __len__(self) -> int:
        """Return the number of rows."""
        pass
```
**Test requirements:**
- [ ] `get_column()` returns the correct Series
- [ ] `get_column()` raises KeyError for a missing column
- [ ] `filter_by_date()` returns the correct filtered result
- [ ] `filter_by_date()` returns an empty DataFrame for a missing date
- [ ] `get_cross_section()` returns the data for current_date
- [ ] `get_cross_section()` raises ValueError when current_date is None
- [ ] `to_polars()` returns the original DataFrame
- [ ] `context` property returns the correct context
- [ ] `__len__()` returns the correct row count
---
## Phase 2: Factor Base Classes (base.py, composite.py)
### 2.1 BaseFactor - Abstract Base Class
**Implementation requirements:**
```python
class BaseFactor(ABC):
    """Factor base class - defines the common interface.

    Every factor must inherit from this class and declare:
    - name: unique factor identifier (snake_case)
    - factor_type: "cross_sectional" or "time_series"
    - data_specs: List[DataSpec], the data requirements

    Optional declarations:
    - category: factor category (default "default")
    - description: factor description
    """
    # Required class attributes
    name: str = ""
    factor_type: str = ""  # "cross_sectional" | "time_series"
    data_specs: List[DataSpec] = []
    # Optional class attributes
    category: str = "default"
    description: str = ""

    def __init_subclass__(cls, **kwargs):
        """Validate required attributes at subclass creation.

        Checks:
        1. name must be a non-empty string
        2. factor_type must be "cross_sectional" or "time_series"
        3. data_specs must be a non-empty list
        """
        pass

    def __init__(self, **params):
        """Initialize factor parameters.

        Subclasses can accept parameterized configuration via __init__,
        e.g. MA(period=20).
        """
        self.params = params
        self._validate_params()

    def _validate_params(self):
        """Validate parameters.

        Subclasses may override for custom validation.
        """
        pass

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Core computation - subclasses must implement.

        Args:
            data: safe data container, already trimmed per factor type
        Returns:
            the computed factor values as a Series
        """
        pass

    # ========== Factor composition operators ==========
    def __add__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor addition: f1 + f2 (same type required)."""
        pass

    def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor subtraction: f1 - f2 (same type required)."""
        pass

    def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor multiplication: f1 * f2 (same type required)."""
        pass

    def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor division: f1 / f2 (same type required)."""
        pass

    def __rmul__(self, scalar: float) -> "ScalarFactor":
        """Scalar multiplication: 0.5 * f1."""
        pass
```
**Test requirements:**
- [ ] Valid subclass passes validation
- [ ] Missing `name` raises ValueError
- [ ] Empty-string `name` raises ValueError
- [ ] Missing `factor_type` raises ValueError
- [ ] Invalid `factor_type` (neither cs nor ts) raises ValueError
- [ ] Missing `data_specs` raises ValueError
- [ ] Empty-list `data_specs` raises ValueError
- [ ] Abstract `compute()` forces subclasses to implement
- [ ] Parameterized init stores `params` correctly
- [ ] `_validate_params()` is called
---
### 2.2 CrossSectionalFactor - Daily Cross-Sectional Factor
**Implementation requirements:**
```python
class CrossSectionalFactor(BaseFactor):
    """Daily cross-sectional factor base class.

    Computation model: on each trading day, compute across all stocks.

    Leak-prevention boundary:
    - ❌ No access to future dates (date leakage)
    - ✅ Access to all stocks on the current date

    Data passed in:
    - compute() receives data for [T-lookback+1, T]
    - includes lookback_days of history (for time-series work before the cross-section)
    """
    factor_type: str = "cross_sectional"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute cross-sectional factor values.

        Args:
            data: FactorData containing cross-sectional data for [T-lookback+1, T]
                  Layout: DataFrame[ts_code, trade_date, col1, col2, ...]
        Returns:
            pl.Series: factor values for all stocks on the current date
                       (length = number of stocks that day)

        Example:
            def compute(self, data):
                # Get the current-date cross-section
                cs = data.get_cross_section()
                # Rank by market cap
                return cs['market_cap'].rank()
        """
        pass
```
**Test requirements:**
- [ ] `factor_type` is automatically "cross_sectional"
- [ ] Subclasses must implement `compute()`
- [ ] `compute()` returns pl.Series
---
### 2.3 TimeSeriesFactor - Time-Series Factor
**Implementation requirements:**
```python
class TimeSeriesFactor(BaseFactor):
    """Time-series factor base class (per-stock cross-section).

    Computation model: for each stock, compute along its own time series.

    Leak-prevention boundary:
    - ❌ No access to other stocks' data (stock leakage)
    - ✅ Access to that stock's full history

    Data passed in:
    - compute() receives the full time series of a single stock
    - covers all of that stock's data in [start_date, end_date]
    """
    factor_type: str = "time_series"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute time-series factor values.

        Args:
            data: FactorData containing a single stock's full time series
                  Layout: DataFrame[ts_code, trade_date, col1, col2, ...]
        Returns:
            pl.Series: the stock's factor value per date (length = number of dates)

        Example:
            def compute(self, data):
                series = data.get_column("close")
                return series.rolling_mean(window_size=self.params['period'])
        """
        pass
```
**Test requirements:**
- [ ] `factor_type` is automatically "time_series"
- [ ] Subclasses must implement `compute()`
- [ ] `compute()` returns pl.Series
---
### 2.4 CompositeFactor - Composite Factor (composite.py)
**Implementation requirements:**
```python
class CompositeFactor(BaseFactor):
    """Composite factor - implements arithmetic between factors.

    Constraint: both operands must be the same type (both cross-sectional
    or both time-series).
    """
    def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
        """Create a composite factor.

        Args:
            left: left operand factor
            right: right operand factor
            op: operator, one of '+', '-', '*', '/'
        Raises:
            ValueError: operand factor types differ
            ValueError: unsupported operator
        """
        pass

    def _merge_data_specs(self) -> List[DataSpec]:
        """Merge the data requirements of both operands.

        Strategy:
        1. DataSpecs with the same source and columns are merged
        2. lookback_days takes the maximum
        """
        pass

    def compute(self, data: FactorData) -> pl.Series:
        """Execute the composite operation.

        Flow:
        1. Compute left and right values separately
        2. Apply op
        3. Return the result
        """
        pass
```
**Test requirements:**
- [ ] Same-type composition succeeds (cs + cs)
- [ ] Same-type composition succeeds (ts + ts)
- [ ] Mixed-type composition raises ValueError (cs + ts)
- [ ] Invalid operator raises ValueError
- [ ] `_merge_data_specs()` merges correctly (same source)
- [ ] `_merge_data_specs()` merges correctly (different sources)
- [ ] `_merge_data_specs()` takes the max lookback
- [ ] `compute()` applies the correct arithmetic
---
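The merge strategy for `_merge_data_specs` can be sketched with plain tuples `(source, columns, lookback)` standing in for DataSpec. Unioning the column lists of same-source specs is an assumption of this sketch; the plan only states that same-source specs merge and that lookback takes the max:

```python
def merge_data_specs(left_specs, right_specs):
    """Merge two lists of (source, columns, lookback) tuples: specs
    sharing a source are combined, columns are unioned, and the
    lookback takes the maximum of the two."""
    merged = {}
    for source, columns, lookback in left_specs + right_specs:
        if source in merged:
            cols, lb = merged[source]
            merged[source] = (cols | set(columns), max(lb, lookback))
        else:
            merged[source] = (set(columns), lookback)
    # Sort columns for a deterministic result
    return [(s, sorted(c), lb) for s, (c, lb) in merged.items()]
```

Taking the max lookback is the safe choice: the merged spec must load enough history to satisfy whichever operand needs more.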
### 2.5 ScalarFactor - Scalar Operation Factor (composite.py)
**Implementation requirements:**
```python
class ScalarFactor(BaseFactor):
    """Scalar operation factor.

    Supports: scalar * factor and factor * scalar (via __rmul__).
    """
    def __init__(self, factor: BaseFactor, scalar: float, op: str):
        """Create a scalar operation factor.

        Args:
            factor: the base factor
            scalar: the scalar value
            op: operator, one of '*', '+'
        """
        pass

    def compute(self, data: FactorData) -> pl.Series:
        """Apply the scalar operation."""
        pass
```
**Test requirements:**
- [ ] Scalar multiplication `0.5 * factor`
- [ ] Scalar multiplication `factor * 0.5`
- [ ] Scalar addition (if supported)
- [ ] Inherits the base factor's data_specs
- [ ] `compute()` returns correctly scaled values
---
## Phase 3: Data Loading (data_loader.py)
### 3.1 DataLoader
**Implementation requirements:**
```python
class DataLoader:
    """Data loader - safely loads data from HDF5.

    Features:
    1. Multi-file aggregation: merge data from several H5 files
    2. Column selection: load only the needed columns
    3. Raw-data caching: avoid repeated reads
    """
    def __init__(self, data_dir: str):
        """Initialize the DataLoader.

        Args:
            data_dir: directory containing the HDF5 files
        """
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}

    def load(
        self,
        specs: List[DataSpec],
        date_range: Optional[Tuple[str, str]] = None
    ) -> pl.DataFrame:
        """Load and aggregate data from multiple H5 files.

        Flow:
        1. For each DataSpec:
           a. Check the cache; use it on a hit
           b. On a miss, read the HDF5 file (via pandas)
           c. Convert to a Polars DataFrame
           d. Filter by date_range
           e. Store in the cache
        2. Join the DataFrames on trade_date and ts_code

        Args:
            specs: list of data requirement specs
            date_range: optional (start_date, end_date) restriction
        Returns:
            the merged Polars DataFrame
        Raises:
            FileNotFoundError: H5 file missing
            KeyError: column missing from the file
        """
        pass

    def clear_cache(self):
        """Clear the cache."""
        pass

    def _read_h5(self, source: str) -> pl.DataFrame:
        """Read a single H5 file.

        Implementation: pandas.read_hdf(), then pl.from_pandas().
        """
        pass
```
**Test requirements:**
- [ ] Load data from a single H5 file
- [ ] Load and merge from multiple H5 files
- [ ] Column selection (only needed columns loaded)
- [ ] Caching (second load is faster)
- [ ] `clear_cache()` empties the cache
- [ ] Filtering by date_range
- [ ] Missing file raises FileNotFoundError
- [ ] Missing column raises KeyError
---
## Phase 4: Execution Engine (engine.py)
### 4.1 FactorEngine
**Implementation requirements:**
```python
class FactorEngine:
    """Factor execution engine - applies a leak-prevention strategy per
    factor type.

    Core responsibilities:
    1. CrossSectionalFactor: prevent date leakage; pass [T-lookback+1, T] each day
    2. TimeSeriesFactor: prevent stock leakage; pass each stock its full series
    """
    def __init__(self, data_loader: DataLoader):
        """Initialize the engine.

        Args:
            data_loader: the data loader instance
        """
        self.data_loader = data_loader

    def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
        """Unified computation entry point.

        Dispatches on factor_type:
        - "cross_sectional" -> _compute_cross_sectional()
        - "time_series" -> _compute_time_series()

        Args:
            factor: the factor to compute
            **kwargs: extra arguments depending on the factor type:
                - cross-sectional: start_date, end_date
                - time-series: stock_codes, start_date, end_date
        Returns:
            DataFrame[trade_date, ts_code, factor_name]
        """
        pass
```
**Test requirements:**
- [ ] `compute()` dispatches to the cross-sectional path
- [ ] `compute()` dispatches to the time-series path
- [ ] Invalid factor_type raises ValueError
---
### 4.2 Cross-Sectional Computation (Preventing Date Leakage)
**Implementation requirements:**
```python
def _compute_cross_sectional(
    self,
    factor: CrossSectionalFactor,
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """Run daily cross-sectional computation.

    Leak prevention:
    - Date leakage: each day receives only [T-lookback+1, T] (no future data)
    - Cross-stock comparison allowed: all stocks for the day are passed in

    Flow:
    1. Compute max_lookback to determine the data start date
    2. Load [start-max_lookback+1, end] once
    3. For each date T in [start_date, end_date]:
       a. Trim data to [T-lookback+1, T]
       b. Create FactorData (current_date=T)
       c. Call factor.compute()
       d. Collect results
    4. Concatenate results across dates

    Returned DataFrame layout:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 0.5          │
    │ 20240101   │ 000002.SZ│ 0.3          │
    └────────────┴──────────┴──────────────┘
    """
    pass
```
**Test requirements (leak-prevention checks):**
- [ ] Data is trimmed correctly ([T-lookback+1, T] passed in)
- [ ] No data from the future date T+1
- [ ] Each date is computed independently
- [ ] Results cover all dates and all stocks
- [ ] Result DataFrame layout is correct
- [ ] With multiple DataSpecs, lookback takes the maximum
---
### 4.3 Time-Series Computation (Preventing Stock Leakage)
**Implementation requirements:**
```python
def _compute_time_series(
    self,
    factor: TimeSeriesFactor,
    stock_codes: List[str],
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """Run time-series computation.

    Leak prevention:
    - Stock leakage: each stock is computed separately on its own full series
    - Historical access allowed: time-series computation needs history

    Flow:
    1. Compute max_lookback to determine the data start date
    2. Load [start-max_lookback+1, end] once
    3. For each stock S in stock_codes:
       a. Filter to S's rows (prevents stock leakage)
       b. Create FactorData (current_stock=S)
       c. Call factor.compute() (vectorized over the whole series)
       d. Collect results
    4. Concatenate results across stocks

    Performance notes:
    - Uses Polars vectorized ops such as rolling_mean
    - Each stock is computed exactly once; no repeated work

    Returned DataFrame layout:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 10.5         │
    │ 20240102   │ 000001.SZ│ 10.6         │
    └────────────┴──────────┴──────────────┘
    """
    pass
```
**Test requirements (leak-prevention checks):**
- [ ] Each stock sees only its own data
- [ ] No data from other stocks
- [ ] The full time series is passed in (vectorized computation)
- [ ] Results cover all stocks and all dates
- [ ] Result DataFrame layout is correct
- [ ] Stocks absent from the data are skipped (or filled with null)
---
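The per-stock flow above can be mimicked in plain Python. This is a toy sketch only: `rolling_mean` here is a stand-in for Polars' vectorized version, and all names are illustrative rather than the engine's actual API:

```python
def rolling_mean(values, window):
    """Plain-Python stand-in for Polars' rolling_mean: None until the
    window fills, then the trailing-window average."""
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)
        else:
            out.append(sum(values[i + 1 - window: i + 1]) / window)
    return out


def compute_per_stock(rows, window=2):
    """Per-stock loop from the flow above: group rows (date, stock, close)
    by stock, sort each group chronologically, compute within the group.
    Each stock sees only its own series, so cross-stock leakage is
    impossible by construction."""
    by_stock = {}
    for date, stock, close in rows:
        by_stock.setdefault(stock, []).append((date, close))
    result = []
    for stock, series in by_stock.items():
        series.sort()  # chronological order within the stock
        ma = rolling_mean([c for _, c in series], window)
        result.extend((d, stock, v) for (d, _), v in zip(series, ma))
    return result
```

The leading `None` values correspond to the "first N days are null" boundary case in the test plan.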
## Phase 5: Built-in Factor Library (builtin/)
### 5.1 momentum.py - Cross-Sectional Momentum Factors
**Factors to implement:**
1. **ReturnRankFactor** - same-day return rank
```python
class ReturnRankFactor(CrossSectionalFactor):
    """Same-day return rank factor."""
    name = "return_rank"
    data_specs = [DataSpec("daily", ["close"], lookback_days=2)]  # 2 days needed for returns

    def compute(self, data):
        # Get the current-date cross-section
        cs = data.get_cross_section()
        # Needs the prior day's and current day's close; lookback=2
        # guarantees the data covers [T-1, T]. The data already includes
        # history here; the actual computation needs a groupby.
        pass
```
**Test requirements:**
- [ ] Return computation is correct
- [ ] Rank computation is correct
- [ ] Returns null when data is missing
2. **MomentumFactor** - trailing N-day return rank
---
### 5.2 technical.py - Time-Series Technical Indicators
**Factors to implement:**
1. **MovingAverageFactor** - moving average
```python
class MovingAverageFactor(TimeSeriesFactor):
    """Moving average factor."""
    name = "ma"

    def __init__(self, period: int = 20):
        super().__init__(period=period)
        self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]

    def compute(self, data):
        return data.get_column("close").rolling_mean(self.params["period"])
```
**Test requirements:**
- [ ] MA20 computed correctly
- [ ] First 19 days return null (Polars default behavior)
- [ ] `period` parameter takes effect
2. **RSIFactor** - RSI indicator
3. **MACDFactor** - MACD indicator
---
### 5.3 value.py - Cross-Sectional Valuation Factors
**Factors to implement:**
1. **PERankFactor** - PE industry quantile
2. **PBFactor** - PB rank
## Phase 6-7: Testing Strategy
### Test Pyramid
```
              /\
             /  \
            /Int. \           tests/factors/test_integration.py
           /────────\
          /  Engine  \        tests/factors/test_engine.py
         /────────────\
        /Base/composite\      tests/factors/test_base.py, test_composite.py
       /────────────────\
      / Data load/types  \    tests/factors/test_data_loader.py, test_data_spec.py
     /────────────────────\
```
### Test Fixtures
Create a `tests/fixtures/` directory containing:
- `sample_daily.h5`: daily data for a handful of stocks (for tests)
- `sample_fundamental.h5`: fundamental data
### Key Test Scenarios
1. **Leak-prevention tests (core)**
   - Cross-sectional factors: verify compute() cannot access future dates
   - Time-series factors: verify compute() cannot access other stocks
2. **Boundary tests**
   - lookback_days = 1 (the minimum)
   - Start of the data (first N days are null)
   - Empty data / trading halts
3. **Performance tests (optional)**
   - Memory usage on large datasets
   - Cache hit rate
---
## Implementation Status
| Phase | Status | Completed | Test Coverage |
|-------|--------|-----------|---------------|
| Phase 1: data type definitions | ✅ Done | 2026-02-21 | 27 tests passed |
| Phase 2: factor base classes | ✅ Done | 2026-02-21 | 49 tests passed |
| Phase 3: data loading | ✅ Done | 2026-02-21 | 11 tests passed |
| Phase 4: execution engine | ✅ Done | 2026-02-22 | 10 tests passed |
| Phase 5: built-in factor library | 📝 Pending | - | - |
| Phase 6-7: tests & docs | ✅ Done | 2026-02-22 | 76 tests total |
---
## Suggested Implementation Order
1. **Week 1**: Phase 1-2 (data types + base classes)
2. **Week 2**: Phase 3-4 (DataLoader + Engine) **done**
3. **Week 3**: Phase 5 (built-in factors)
4. **Week 4**: Phase 6-7 (tests + docs)
Run the corresponding tests after each phase to keep quality in check.


@@ -1,10 +1,17 @@
# ProStock HDF5-to-DuckDB Migration Plan
**Document version**: v1.1
**Created**: 2026-02-22
**Completed**: 2026-02-22
**Status**: ✅ Done
**Scope**: data module, factors module, and related docs
## Related Documents
- [DuckDB data sync guide](./db_sync_guide.md) - sync API usage
- [Migration test report](./test_report_duckdb_migration.md) - test validation results
---
## Table of Contents

docs/ml_framework_design.md (new file, 1472 lines)

File diff suppressed because it is too large


@@ -1,6 +1,8 @@
# ProStock HDF5-to-DuckDB Migration Test Report
**Report generated**: 2026-02-22
**Completed**: 2026-02-22
**Status**: ✅ Done
**Migration doc**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md)
**Test data range**: January-March 2024 (3 months)

src/models/__init__.py (new file, 86 lines)

@@ -0,0 +1,86 @@
"""ProStock 模型训练框架
组件化、低耦合、插件式的机器学习训练框架。
示例:
>>> from src.models import (
... PluginRegistry, ProcessingPipeline,
... PipelineStage, BaseProcessor
... )
>>> # 获取注册的处理器
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
>>> scaler = scaler_class()
>>> # 创建处理流水线
>>> pipeline = ProcessingPipeline([
... PluginRegistry.get_processor("dropna")(),
... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
... PluginRegistry.get_processor("standard_scaler")(),
... ])
"""
# 导入核心抽象类和划分策略
from src.models.core import (
PipelineStage,
TaskType,
BaseProcessor,
BaseModel,
BaseSplitter,
BaseMetric,
TimeSeriesSplit,
WalkForwardSplit,
ExpandingWindowSplit,
)
# 导入注册中心
from src.models.registry import PluginRegistry
# 导入处理流水线
from src.models.pipeline import ProcessingPipeline
# 导入并注册内置处理器
from src.models.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,
StandardScaler,
MinMaxScaler,
RankTransformer,
Neutralizer,
)
# 导入并注册内置模型
from src.models.models.models import (
LightGBMModel,
CatBoostModel,
)
__all__ = [
# 核心抽象
"PipelineStage",
"TaskType",
"BaseProcessor",
"BaseModel",
"BaseSplitter",
"BaseMetric",
# 划分策略
"TimeSeriesSplit",
"WalkForwardSplit",
"ExpandingWindowSplit",
# 注册中心
"PluginRegistry",
# 处理流水线
"ProcessingPipeline",
# 处理器
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
# 模型
"LightGBMModel",
"CatBoostModel",
]


@@ -0,0 +1,30 @@
"""Core module exports."""
from src.models.core.base import (
    PipelineStage,
    TaskType,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    BaseMetric,
)
from src.models.core.splitter import (
    TimeSeriesSplit,
    WalkForwardSplit,
    ExpandingWindowSplit,
)

__all__ = [
    # Base abstractions
    "PipelineStage",
    "TaskType",
    "BaseProcessor",
    "BaseModel",
    "BaseSplitter",
    "BaseMetric",
    # Split strategies
    "TimeSeriesSplit",
    "WalkForwardSplit",
    "ExpandingWindowSplit",
]

src/models/core/base.py (new file, 351 lines)

@@ -0,0 +1,351 @@
"""Core abstract classes for the model training framework.

Defines base classes for processors, models, split strategies, and
evaluation metrics.
"""
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal

import polars as pl
import numpy as np

# Task type
TaskType = Literal["classification", "regression", "ranking"]


class PipelineStage(Enum):
    """Pipeline stage marker.

    Marks the stages in which a processor applies, to prevent data leakage.

    Attributes:
        ALL: applies to all stages (train, test, validation)
        TRAIN: training stage only
        TEST: test stage only
        VALIDATION: validation stage only
    """
    ALL = auto()
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
class BaseProcessor(ABC):
    """Base class for data processors.

    All data processors must inherit from this class. The key feature is
    the stage attribute, which controls the stages in which the processor
    applies.

    Stage marking rules:
    - ALL: train and test stages use the same parameters
    - TRAIN: parameters (quantiles, means, etc.) are learned only during
      training; the test stage reuses them
    - TEST: runs only during the test stage
    """
    # Subclasses must declare the applicable stage
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[List[str]] = None, **params):
        """Initialize the processor.

        Args:
            columns: columns to process; None means all numeric columns
            **params: processor-specific parameters
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters on the training data.

        Called exactly once, during training. Learned parameters are
        stored in self._fitted_params.

        Args:
            data: training data
        Returns:
            self (supports chaining)
        """
        pass

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform the data.

        Called in both training and test stages, using the parameters
        learned in fit().

        Args:
            data: input data
        Returns:
            the transformed data
        """
        pass

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience method: fit, then transform.

        Args:
            data: training data
        Returns:
            the transformed data
        """
        return self.fit(data).transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """Get the learned parameters (for saving/loading).

        Returns:
            the learned parameter dict
        """
        return self._fitted_params.copy()

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Set the learned parameters (for restoring from a checkpoint).

        Args:
            params: parameter dict
        Returns:
            self (supports chaining)
        """
        self._fitted_params = params.copy()
        self._is_fitted = True
        return self
class BaseModel(ABC):
    """Base class for machine learning models.

    A unified interface supporting multiple model backends (LightGBM,
    CatBoost, XGBoost, ...) and multiple task types (classification,
    regression, ranking).
    """
    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Initialize the model.

        Args:
            task_type: "classification", "regression", or "ranking"
            params: model-specific parameters
            name: model name (for logging and reports)
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "BaseModel":
        """Train the model.

        Args:
            X: feature data
            y: target variable
            X_val: validation features (optional)
            y_val: validation targets (optional)
            **fit_params: extra fit parameters
        Returns:
            self (supports chaining)
        """
        pass

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict.

        Args:
            X: feature data
        Returns:
            prediction array:
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """
        pass

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict probabilities (classification only).

        Args:
            X: feature data
        Returns:
            class probability array [n_samples, n_classes]
        Raises:
            NotImplementedError: for non-classification tasks
        """
        raise NotImplementedError(
            "predict_proba only available for classification tasks"
        )

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Get feature importances (if the model supports them).

        Returns:
            DataFrame[feature, importance] or None
        """
        return None

    def save(self, path: str) -> None:
        """Save the model to a file.

        Args:
            path: destination path
        """
        import pickle
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file.

        Args:
            path: model file path
        Returns:
            the loaded model instance
        """
        import pickle
        with open(path, "rb") as f:
            return pickle.load(f)
class BaseSplitter(ABC):
    """Base class for data split strategies.

    Split strategies specific to time-series data, designed to prevent
    future leakage.
    """
    @abstractmethod
    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test indices.

        Args:
            data: the full dataset
            date_col: the date column name
        Yields:
            (train_indices, test_indices) tuples
        """
        pass

    @abstractmethod
    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Get the split date ranges.

        Args:
            data: the full dataset
            date_col: the date column name
        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass
class BaseMetric(ABC):
    """Base class for evaluation metrics.

    All metrics must inherit from this class. Supports both one-shot and
    accumulated computation.
    """
    def __init__(self, name: Optional[str] = None):
        """Initialize the metric.

        Args:
            name: metric name
        """
        self.name = name or self.__class__.__name__
        self._values: List[float] = []

    @abstractmethod
    def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the metric value.

        Args:
            y_true: ground-truth values
            y_pred: predicted values
        Returns:
            the metric value
        """
        pass

    def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric":
        """Update the accumulator.

        Args:
            y_true: ground-truth values
            y_pred: predicted values
        Returns:
            self (supports chaining)
        """
        self._values.append(self.compute(y_true, y_pred))
        return self

    def get_mean(self) -> float:
        """Mean of the accumulated values.

        Returns:
            the mean
        """
        if not self._values:
            return 0.0
        return float(np.mean(self._values))

    def get_std(self) -> float:
        """Standard deviation of the accumulated values.

        Returns:
            the standard deviation
        """
        if not self._values:
            return 0.0
        return float(np.std(self._values))

    def reset(self) -> "BaseMetric":
        """Reset the accumulator.

        Returns:
            self (supports chaining)
        """
        self._values = []
        return self


__all__ = [
    "PipelineStage",
    "TaskType",
    "BaseProcessor",
    "BaseModel",
    "BaseSplitter",
    "BaseMetric",
]

src/models/core/splitter.py (new file, 222 lines)

@@ -0,0 +1,222 @@
"""Time-series data split strategies.

Split strategies tailored to financial time series, designed to prevent
future leakage.
"""
from typing import Iterator, List, Tuple

import polars as pl

from src.models.core.base import BaseSplitter
class TimeSeriesSplit(BaseSplitter):
    """Time-series split - training data always precedes test data.

    Splits into K folds in chronological order; each fold's training data
    lies entirely before its test data. The gap parameter leaves a buffer
    between train and test sets to prevent leakage.

    Args:
        n_splits: number of folds
        gap: days between train and test sets (leakage guard)
        min_train_size: minimum training-set size (in days)
    """
    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test indices."""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            if test_end_idx > n_dates:
                break
            train_dates = dates[:train_end_idx]
            test_dates = dates[test_start_idx:test_end_idx]
            train_mask = data[date_col].is_in(train_dates.to_list())
            test_mask = data[date_col].is_in(test_dates.to_list())
            train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
            test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Get the split date ranges."""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits
        result = []
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size
            if test_end_idx > n_dates:
                break
            result.append(
                (
                    str(dates[0]),
                    str(dates[train_end_idx - 1]),
                    str(dates[test_start_idx]),
                    str(dates[test_end_idx - 1]),
                )
            )
        return result
class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation - a fixed-size training window rolls forward.

    Args:
        train_window: training window size (days)
        test_window: test window size (days)
        gap: days between train and test sets
    """
    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test indices."""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window
            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]
            train_mask = data[date_col].is_in(train_dates.to_list())
            test_mask = data[date_col].is_in(test_dates.to_list())
            train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
            test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
            yield train_idx, test_idx
            start_idx += self.test_window

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Get the split date ranges."""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        result = []
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window
            result.append(
                (
                    str(dates[train_start]),
                    str(dates[train_end - 1]),
                    str(dates[test_start]),
                    str(dates[test_end - 1]),
                )
            )
            start_idx += self.test_window
        return result
class ExpandingWindowSplit(BaseSplitter):
"""扩展窗口划分 - 训练集不断扩大
Args:
initial_train_size: 初始训练集大小(天数)
test_window: 测试集窗口大小(天数)
gap: 训练集和测试集之间的间隔天数
"""
def __init__(
self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5
):
self.initial_train_size = initial_train_size
self.test_window = test_window
self.gap = gap
def split(
self, data: pl.DataFrame, date_col: str = "trade_date"
) -> Iterator[Tuple[List[int], List[int]]]:
"""生成训练/测试索引"""
dates = data[date_col].unique().sort()
n_dates = len(dates)
train_end_idx = self.initial_train_size
while train_end_idx + self.gap + self.test_window <= n_dates:
train_dates = dates[:train_end_idx]
test_start = train_end_idx + self.gap
test_end = test_start + self.test_window
test_dates = dates[test_start:test_end]
train_mask = data[date_col].is_in(train_dates.to_list())
test_mask = data[date_col].is_in(test_dates.to_list())
train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
yield train_idx, test_idx
train_end_idx += self.test_window
def get_split_dates(
self, data: pl.DataFrame, date_col: str = "trade_date"
) -> List[Tuple[str, str, str, str]]:
"""获取划分日期范围"""
dates = data[date_col].unique().sort()
n_dates = len(dates)
result = []
train_end_idx = self.initial_train_size
while train_end_idx + self.gap + self.test_window <= n_dates:
test_start = train_end_idx + self.gap
test_end = test_start + self.test_window
result.append(
(
str(dates[0]),
str(dates[train_end_idx - 1]),
str(dates[test_start]),
str(dates[test_end - 1]),
)
)
train_end_idx += self.test_window
return result
__all__ = [
"TimeSeriesSplit",
"WalkForwardSplit",
"ExpandingWindowSplit",
]
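WalkForwardSplit 的滚动窗口推进(固定训练窗口、每次前移 `test_window`)也可以用一个纯 Python 草图说明;`walk_forward_windows` 为示意用的假设函数,不在源码中:

```python
# 纯 Python 草图:复现 WalkForwardSplit 的窗口推进逻辑。
# 返回 ((train_start, train_end), (test_start, test_end)) 半开区间列表。
def walk_forward_windows(n_dates: int, train_window: int, test_window: int, gap: int):
    windows = []
    start = train_window
    while start + gap + test_window <= n_dates:
        windows.append(
            ((start - train_window, start), (start + gap, start + gap + test_window))
        )
        start += test_window  # 训练窗口整体前移一个测试窗口
    return windows
```

例如 12 个交易日、`train_window=5, test_window=2, gap=1` 时得到 3 个窗口,训练窗口大小恒为 5,与单元测试的断言一致。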


@@ -0,0 +1,11 @@
"""模型模块"""
from src.models.models.models import (
LightGBMModel,
CatBoostModel,
)
__all__ = [
"LightGBMModel",
"CatBoostModel",
]

src/models/models/models.py

@@ -0,0 +1,210 @@
"""内置机器学习模型
提供 LightGBM、CatBoost 等模型的统一接口包装器。
"""
from typing import Optional, Dict, Any
import polars as pl
import numpy as np
from src.models.core import BaseModel, TaskType
from src.models.registry import PluginRegistry
@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
"""LightGBM 模型包装器
支持分类、回归、排序三种任务类型。
"""
def __init__(
self,
task_type: TaskType,
params: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params,
) -> "LightGBMModel":
"""训练模型"""
try:
import lightgbm as lgb
except ImportError:
raise ImportError(
"lightgbm is required. Install with: uv pip install lightgbm"
)
X_arr = X.to_numpy()
y_arr = y.to_numpy()
train_data = lgb.Dataset(X_arr, label=y_arr)
valid_sets = [train_data]
valid_names = ["train"]
if X_val is not None and y_val is not None:
valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
valid_sets.append(valid_data)
valid_names.append("valid")
default_params = {
"objective": self._get_objective(),
"metric": self._get_metric(),
"boosting_type": "gbdt",
"num_leaves": 31,
"learning_rate": 0.05,
"feature_fraction": 0.9,
"bagging_fraction": 0.8,
"bagging_freq": 5,
"verbose": -1,
}
default_params.update(self.params)
callbacks = []
if len(valid_sets) > 1:
callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False))
self._model = lgb.train(
default_params,
train_data,
num_boost_round=fit_params.get("num_boost_round", 100),
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=callbacks,
)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测"""
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_numpy())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
"""预测概率(仅分类任务)"""
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
probs = self.predict(X)
if len(probs.shape) == 1:
return np.vstack([1 - probs, probs]).T
return probs
def get_feature_importance(self) -> Optional[pl.DataFrame]:
"""获取特征重要性"""
if self._model is None:
return None
importance = self._model.feature_importance(importance_type="gain")
feature_names = getattr(
self._model,
"feature_name",
lambda: [f"feature_{i}" for i in range(len(importance))],
)()
return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
"importance", descending=True
)
def _get_objective(self) -> str:
objectives = {
"classification": "binary",
"regression": "regression",
"ranking": "lambdarank",
}
return objectives.get(self.task_type, "regression")
def _get_metric(self) -> str:
metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
return metrics.get(self.task_type, "rmse")
@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
"""CatBoost 模型包装器"""
def __init__(
self,
task_type: TaskType,
params: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params,
) -> "CatBoostModel":
"""训练模型"""
try:
from catboost import CatBoostClassifier, CatBoostRegressor
except ImportError:
raise ImportError(
"catboost is required. Install with: uv pip install catboost"
)
if self.task_type == "classification":
model_class = CatBoostClassifier
default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
elif self.task_type == "regression":
model_class = CatBoostRegressor
default_params = {"loss_function": "RMSE"}
else:
model_class = CatBoostRegressor
default_params = {"loss_function": "QueryRMSE"}
default_params.update(self.params)
default_params["verbose"] = False
self._model = model_class(**default_params)
        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())
        fit_kwargs: Dict[str, Any] = {"eval_set": eval_set, "verbose": False}
        if eval_set is not None:
            # 仅在提供验证集时启用早停,避免缺少 eval_set 时早停无意义或报错
            fit_kwargs["early_stopping_rounds"] = fit_params.get(
                "early_stopping_rounds", 10
            )
        self._model.fit(X.to_pandas(), y.to_pandas(), **fit_kwargs)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测"""
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_pandas())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
"""预测概率"""
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
return self._model.predict_proba(X.to_pandas())
def get_feature_importance(self) -> Optional[pl.DataFrame]:
"""获取特征重要性"""
if self._model is None:
return None
return pl.DataFrame(
{
"feature": self._model.feature_names_,
"importance": self._model.feature_importances_,
}
).sort("importance", descending=True)
__all__ = ["LightGBMModel", "CatBoostModel"]

src/models/pipeline.py

@@ -0,0 +1,70 @@
"""数据处理流水线
管理多个处理器的顺序执行,支持阶段感知处理。
"""
from typing import List, Dict
import polars as pl
from src.models.core import BaseProcessor, PipelineStage
class ProcessingPipeline:
"""数据处理流水线
按顺序执行多个处理器,自动处理阶段标记。
关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。
"""
def __init__(self, processors: List[BaseProcessor]):
"""初始化流水线
Args:
processors: 处理器列表(按执行顺序)
"""
self.processors = processors
self._fitted_processors: Dict[int, BaseProcessor] = {}
def fit_transform(
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
) -> pl.DataFrame:
"""在训练数据上fit所有处理器并transform"""
result = data
for i, processor in enumerate(self.processors):
if processor.stage in [PipelineStage.ALL, stage]:
result = processor.fit_transform(result)
self._fitted_processors[i] = processor
elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
processor.fit(result)
self._fitted_processors[i] = processor
return result
def transform(
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
) -> pl.DataFrame:
"""在测试数据上应用已fit的处理器"""
result = data
for i, processor in enumerate(self.processors):
if processor.stage in [PipelineStage.ALL, stage]:
if i in self._fitted_processors:
result = self._fitted_processors[i].transform(result)
else:
result = processor.transform(result)
return result
def save_processors(self, path: str) -> None:
"""保存所有已fit的处理器状态"""
import pickle
with open(path, "wb") as f:
pickle.dump(self._fitted_processors, f)
def load_processors(self, path: str) -> None:
"""加载处理器状态"""
import pickle
with open(path, "rb") as f:
self._fitted_processors = pickle.load(f)
__all__ = ["ProcessingPipeline"]
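上述防泄露契约(测试集只复用训练阶段拟合出的统计量,绝不在测试数据上重新计算)可以用一个仅依赖标准库的最小草图说明;`TinyScaler` 为示意用的假设类,不属于源码:

```python
from statistics import mean, stdev

# 标准库草图:训练阶段学习 mean/std,测试阶段原样复用,
# 对应上文流水线中 fit_transform / transform 的职责划分。
class TinyScaler:
    def fit(self, xs):
        # 样本标准差(ddof=1),与 Polars 的 std() 默认行为一致
        self.mu, self.sigma = mean(xs), stdev(xs)
        return self

    def transform(self, xs):
        return [(x - self.mu) / self.sigma for x in xs]

scaler = TinyScaler().fit([1.0, 2.0, 3.0])   # 训练集:mu=2.0, sigma=1.0
scaled = scaler.transform([4.0, 5.0, 6.0])   # 测试集用训练统计量,而非自身分布
```

这正是单元测试 `test_pipeline_transform_uses_fitted_params` 所验证的行为:4 被标准化为 (4-2)/1 = 2。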


@@ -0,0 +1,21 @@
"""处理器模块"""
from src.models.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,
StandardScaler,
MinMaxScaler,
RankTransformer,
Neutralizer,
)
__all__ = [
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
]


@@ -0,0 +1,238 @@
"""内置数据处理器
提供常用的数据预处理和转换处理器。
"""
from typing import List, Optional
import polars as pl
from src.models.core import BaseProcessor, PipelineStage
from src.models.registry import PluginRegistry
# 数值类型列表(含整数与浮点)
NUMERIC_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]
def _get_numeric_columns(
    data: pl.DataFrame, columns: Optional[List[str]] = None
) -> List[str]:
    """获取数值列"""
    if columns is not None:
        return columns
    return [c for c in data.columns if data[c].dtype in NUMERIC_TYPES]
@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
"""缺失值删除处理器"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
cols = self.columns or data.columns
return data.drop_nulls(subset=cols)
@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
"""缺失值填充处理器(只在训练阶段计算填充值)"""
stage = PipelineStage.TRAIN
def __init__(self, columns: Optional[List[str]] = None, method: str = "median"):
super().__init__(columns)
if method not in ["median", "mean", "zero"]:
raise ValueError(f"Unknown fill method: {method}")
self.method = method
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
cols = _get_numeric_columns(data, self.columns)
fill_values = {}
for col in cols:
if self.method == "median":
fill_values[col] = data[col].median()
elif self.method == "mean":
fill_values[col] = data[col].mean()
elif self.method == "zero":
fill_values[col] = 0.0
self._fitted_params = {"fill_values": fill_values, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, val in self._fitted_params.get("fill_values", {}).items():
if col in result.columns:
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
return result
@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
"""缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)"""
stage = PipelineStage.TRAIN
def __init__(
self,
columns: Optional[List[str]] = None,
lower: float = 0.01,
upper: float = 0.99,
):
super().__init__(columns)
self.lower = lower
self.upper = upper
def fit(self, data: pl.DataFrame) -> "Winsorizer":
cols = _get_numeric_columns(data, self.columns)
bounds = {}
for col in cols:
bounds[col] = {
"lower": data[col].quantile(self.lower),
"upper": data[col].quantile(self.upper),
}
self._fitted_params = {"bounds": bounds, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, bounds in self._fitted_params.get("bounds", {}).items():
if col in result.columns:
result = result.with_columns(
pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col)
)
return result
@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
"""标准化处理器 - Z-score标准化"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "StandardScaler":
cols = _get_numeric_columns(data, self.columns)
stats = {}
for col in cols:
stats[col] = {"mean": data[col].mean(), "std": data[col].std()}
self._fitted_params = {"stats": stats, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, stats in self._fitted_params.get("stats", {}).items():
if col in result.columns and stats["std"] is not None and stats["std"] > 0:
result = result.with_columns(
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col)
)
return result
@PluginRegistry.register_processor("minmax_scaler")
class MinMaxScaler(BaseProcessor):
"""归一化处理器 - 缩放到[0, 1]范围"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "MinMaxScaler":
cols = _get_numeric_columns(data, self.columns)
stats = {}
for col in cols:
stats[col] = {"min": data[col].min(), "max": data[col].max()}
self._fitted_params = {"stats": stats, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
        for col, stats in self._fitted_params.get("stats", {}).items():
            # 先检查 min/max 是否为 None(全空列),再做减法,避免 TypeError
            if (
                col in result.columns
                and stats["min"] is not None
                and stats["max"] is not None
            ):
                range_val = stats["max"] - stats["min"]
                if range_val > 0:
                    result = result.with_columns(
                        ((pl.col(col) - stats["min"]) / range_val).alias(col)
                    )
        return result
@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
"""排名转换处理器 - 转换为截面排名"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "RankTransformer":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
cols = self.columns or _get_numeric_columns(data)
for col in cols:
if col in result.columns:
result = result.with_columns(
pl.col(col).rank().over("trade_date").alias(col)
)
return result
@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
"""中性化处理器 - 行业/市值中性化"""
stage = PipelineStage.ALL
def __init__(
self,
columns: Optional[List[str]] = None,
group_col: str = "industry",
exclude_cols: Optional[List[str]] = None,
):
super().__init__(columns)
self.group_col = group_col
self.exclude_cols = exclude_cols or []
def fit(self, data: pl.DataFrame) -> "Neutralizer":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
cols = self.columns or _get_numeric_columns(data)
for col in cols:
if col in result.columns and col not in self.exclude_cols:
result = result.with_columns(
(
pl.col(col)
- pl.col(col).mean().over(["trade_date", self.group_col])
).alias(col)
)
return result
__all__ = [
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
]
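Neutralizer 的分组去均值(`pl.col(col) - pl.col(col).mean().over([...])`)可以用下面的标准库草图独立验证;`demean_by_group` 为示意用的假设函数,不在源码中:

```python
from collections import defaultdict
from statistics import mean

# 标准库草图:按分组键去均值,使每组残差的均值为 0,
# 对应 Neutralizer 按 (trade_date, group_col) 的中性化。
def demean_by_group(rows):
    """rows: [(group_key, value), ...] -> 去均值后的 value 列表(保持原顺序)"""
    groups = defaultdict(list)
    for key, v in rows:
        groups[key].append(v)
    means = {k: mean(vs) for k, vs in groups.items()}
    return [v - means[k] for k, v in rows]
```

与 `test_neutralizer` 的场景一致:A 组 [10, 20] 和 B 组 [30, 50] 去均值后各组均值为 0。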

src/models/registry.py

@@ -0,0 +1,297 @@
"""插件注册中心
提供装饰器方式注册处理器、模型、划分策略等组件。
实现真正的插件式架构 - 新功能只需注册即可使用。
示例:
>>> @PluginRegistry.register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>> # 使用
>>> scaler = PluginRegistry.get_processor("standard_scaler")()
"""
from typing import Type, Dict, List, Optional
import contextlib
from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
class PluginRegistry:
"""插件注册中心
管理所有组件的注册和获取。使用装饰器方式注册新组件。
Attributes:
_processors: 已注册的处理器字典
_models: 已注册的模型字典
_splitters: 已注册的划分策略字典
_metrics: 已注册的评估指标字典
"""
_processors: Dict[str, Type[BaseProcessor]] = {}
_models: Dict[str, Type[BaseModel]] = {}
_splitters: Dict[str, Type[BaseSplitter]] = {}
_metrics: Dict[str, Type[BaseMetric]] = {}
@classmethod
@contextlib.contextmanager
def temp_registry(cls):
"""临时注册上下文管理器
在上下文管理器内部注册的组件会在退出时自动清理,
避免测试之间的状态污染。
示例:
>>> with PluginRegistry.temp_registry():
... @PluginRegistry.register_processor("temp_processor")
... class TempProcessor(BaseProcessor):
... pass
... # 在此处可以使用 temp_processor
... # 退出后自动清理
"""
original_state = {
"_processors": cls._processors.copy(),
"_models": cls._models.copy(),
"_splitters": cls._splitters.copy(),
"_metrics": cls._metrics.copy(),
}
try:
yield cls
finally:
cls._processors = original_state["_processors"]
cls._models = original_state["_models"]
cls._splitters = original_state["_splitters"]
cls._metrics = original_state["_metrics"]
@classmethod
def register_processor(cls, name: Optional[str] = None):
"""注册处理器装饰器
用于装饰器方式注册数据处理器。
示例:
>>> @PluginRegistry.register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>> # 获取并使用
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
>>> scaler = scaler_class()
Args:
name: 注册名称,默认为类名
Returns:
装饰器函数
"""
def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
key = name or processor_class.__name__
cls._processors[key] = processor_class
processor_class._registry_name = key
return processor_class
return decorator
@classmethod
def register_model(cls, name: Optional[str] = None):
"""注册模型装饰器
用于装饰器方式注册机器学习模型。
示例:
>>> @PluginRegistry.register_model("lightgbm")
... class LightGBMModel(BaseModel):
... pass
Args:
name: 注册名称,默认为类名
Returns:
装饰器函数
"""
def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
key = name or model_class.__name__
cls._models[key] = model_class
model_class._registry_name = key
return model_class
return decorator
@classmethod
def register_splitter(cls, name: Optional[str] = None):
"""注册划分策略装饰器
用于装饰器方式注册数据划分策略。
示例:
>>> @PluginRegistry.register_splitter("time_series")
... class TimeSeriesSplit(BaseSplitter):
... pass
Args:
name: 注册名称,默认为类名
Returns:
装饰器函数
"""
def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
key = name or splitter_class.__name__
cls._splitters[key] = splitter_class
splitter_class._registry_name = key
return splitter_class
return decorator
@classmethod
def register_metric(cls, name: Optional[str] = None):
"""注册评估指标装饰器
用于装饰器方式注册评估指标。
示例:
>>> @PluginRegistry.register_metric("ic")
... class ICMetric(BaseMetric):
... pass
Args:
name: 注册名称,默认为类名
Returns:
装饰器函数
"""
def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
key = name or metric_class.__name__
cls._metrics[key] = metric_class
metric_class._registry_name = key
return metric_class
return decorator
@classmethod
def get_processor(cls, name: str) -> Type[BaseProcessor]:
"""获取处理器类
Args:
name: 处理器注册名称
Returns:
处理器类
Raises:
KeyError: 处理器不存在时抛出
"""
if name not in cls._processors:
available = list(cls._processors.keys())
raise KeyError(f"Processor '{name}' not found. Available: {available}")
return cls._processors[name]
@classmethod
def get_model(cls, name: str) -> Type[BaseModel]:
"""获取模型类
Args:
name: 模型注册名称
Returns:
模型类
Raises:
KeyError: 模型不存在时抛出
"""
if name not in cls._models:
available = list(cls._models.keys())
raise KeyError(f"Model '{name}' not found. Available: {available}")
return cls._models[name]
@classmethod
def get_splitter(cls, name: str) -> Type[BaseSplitter]:
"""获取划分策略类
Args:
name: 划分策略注册名称
Returns:
划分策略类
Raises:
KeyError: 划分策略不存在时抛出
"""
if name not in cls._splitters:
available = list(cls._splitters.keys())
raise KeyError(f"Splitter '{name}' not found. Available: {available}")
return cls._splitters[name]
@classmethod
def get_metric(cls, name: str) -> Type[BaseMetric]:
"""获取评估指标类
Args:
name: 评估指标注册名称
Returns:
评估指标类
Raises:
KeyError: 评估指标不存在时抛出
"""
if name not in cls._metrics:
available = list(cls._metrics.keys())
raise KeyError(f"Metric '{name}' not found. Available: {available}")
return cls._metrics[name]
@classmethod
def list_processors(cls) -> List[str]:
"""列出所有可用处理器
Returns:
处理器名称列表
"""
return list(cls._processors.keys())
@classmethod
def list_models(cls) -> List[str]:
"""列出所有可用模型
Returns:
模型名称列表
"""
return list(cls._models.keys())
@classmethod
def list_splitters(cls) -> List[str]:
"""列出所有可用划分策略
Returns:
划分策略名称列表
"""
return list(cls._splitters.keys())
@classmethod
def list_metrics(cls) -> List[str]:
"""列出所有可用评估指标
Returns:
评估指标名称列表
"""
return list(cls._metrics.keys())
@classmethod
def clear_all(cls) -> None:
"""清除所有注册(主要用于测试)"""
cls._processors.clear()
cls._models.clear()
cls._splitters.clear()
cls._metrics.clear()
__all__ = ["PluginRegistry"]
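PluginRegistry 所用的装饰器注册模式可以压缩为下面的最小草图:装饰器把类按名字存入类级字典并原样返回该类;`TinyRegistry`、`Dummy` 均为示意用的假设名称,不属于源码:

```python
# 最小化草图:装饰器注册模式。
# register(name) 返回一个装饰器,它登记类后原样返回,类定义本身不受影响。
class TinyRegistry:
    _items: dict = {}

    @classmethod
    def register(cls, name=None):
        def decorator(klass):
            cls._items[name or klass.__name__] = klass  # 默认用类名注册
            return klass
        return decorator

    @classmethod
    def get(cls, name):
        if name not in cls._items:
            raise KeyError(f"'{name}' not found. Available: {list(cls._items)}")
        return cls._items[name]

@TinyRegistry.register("dummy")
class Dummy:
    pass
```

因为装饰器返回原类,被注册的类仍可直接实例化,注册只是多了一条按名查找的途径。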

tests/models/test_core.py

@@ -0,0 +1,478 @@
"""模型框架核心测试
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
"""
import pytest
import polars as pl
import numpy as np
# 确保导入时注册所有组件
from src.models import (
PluginRegistry,
PipelineStage,
BaseProcessor,
BaseModel,
BaseSplitter,
ProcessingPipeline,
)
from src.models.core import TaskType
# ========== 测试核心抽象类 ==========
class TestPipelineStage:
"""测试阶段枚举"""
def test_stage_values(self):
assert PipelineStage.ALL.name == "ALL"
assert PipelineStage.TRAIN.name == "TRAIN"
assert PipelineStage.TEST.name == "TEST"
assert PipelineStage.VALIDATION.name == "VALIDATION"
class TestBaseProcessor:
"""测试处理器基类"""
def test_processor_initialization(self):
"""测试处理器初始化"""
class DummyProcessor(BaseProcessor):
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "DummyProcessor":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
return data
processor = DummyProcessor(columns=["col1", "col2"])
assert processor.columns == ["col1", "col2"]
assert processor.stage == PipelineStage.ALL
assert not processor._is_fitted
def test_processor_fit_transform(self):
"""测试 fit_transform 方法"""
class AddOneProcessor(BaseProcessor):
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data.clone()
for col in self.columns or []:
result = result.with_columns((pl.col(col) + 1).alias(col))
return result
processor = AddOneProcessor(columns=["value"])
df = pl.DataFrame({"value": [1, 2, 3]})
result = processor.fit_transform(df)
assert processor._is_fitted
assert result["value"].to_list() == [2, 3, 4]
class TestBaseModel:
"""测试模型基类"""
def test_model_initialization(self):
"""测试模型初始化"""
class DummyModel(BaseModel):
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
self._is_fitted = True
return self
def predict(self, X):
return np.zeros(len(X))
model = DummyModel(
task_type="regression", params={"lr": 0.01}, name="test_model"
)
assert model.task_type == "regression"
assert model.params == {"lr": 0.01}
assert model.name == "test_model"
assert not model._is_fitted
def test_predict_proba_not_implemented(self):
"""测试未实现 predict_proba 时抛出异常"""
class DummyModel(BaseModel):
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
return self
def predict(self, X):
return np.zeros(len(X))
model = DummyModel(task_type="regression")
df = pl.DataFrame({"feature": [1, 2, 3]})
with pytest.raises(NotImplementedError):
model.predict_proba(df)
class TestBaseSplitter:
"""测试划分策略基类"""
def test_splitter_interface(self):
"""测试划分策略接口"""
class DummySplitter(BaseSplitter):
def split(self, data, date_col="trade_date"):
yield [0, 1], [2, 3]
def get_split_dates(self, data, date_col="trade_date"):
return [("20200101", "20201231", "20210101", "20211231")]
splitter = DummySplitter()
df = pl.DataFrame(
{"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
)
splits = list(splitter.split(df))
assert len(splits) == 1
assert splits[0] == ([0, 1], [2, 3])
dates = splitter.get_split_dates(df)
assert dates == [("20200101", "20201231", "20210101", "20211231")]
# ========== 测试插件注册中心 ==========
class TestPluginRegistry:
"""测试插件注册中心"""
def setup_method(self):
"""每个测试前清除注册"""
PluginRegistry.clear_all()
def test_register_and_get_processor(self):
"""测试注册和获取处理器"""
@PluginRegistry.register_processor("test_processor")
class TestProcessor(BaseProcessor):
stage = PipelineStage.ALL
def fit(self, data):
return self
def transform(self, data):
return data
processor_class = PluginRegistry.get_processor("test_processor")
assert processor_class == TestProcessor
assert "test_processor" in PluginRegistry.list_processors()
def test_register_and_get_model(self):
"""测试注册和获取模型"""
@PluginRegistry.register_model("test_model")
class TestModel(BaseModel):
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
return self
def predict(self, X):
return np.zeros(len(X))
model_class = PluginRegistry.get_model("test_model")
assert model_class == TestModel
assert "test_model" in PluginRegistry.list_models()
def test_register_and_get_splitter(self):
"""测试注册和获取划分策略"""
@PluginRegistry.register_splitter("test_splitter")
class TestSplitter(BaseSplitter):
def split(self, data, date_col="trade_date"):
yield [], []
def get_split_dates(self, data, date_col="trade_date"):
return []
splitter_class = PluginRegistry.get_splitter("test_splitter")
assert splitter_class == TestSplitter
assert "test_splitter" in PluginRegistry.list_splitters()
def test_get_nonexistent_processor(self):
"""测试获取不存在的处理器时抛出异常"""
with pytest.raises(KeyError) as exc_info:
PluginRegistry.get_processor("nonexistent")
assert "nonexistent" in str(exc_info.value)
def test_register_with_default_name(self):
"""测试使用默认名称注册"""
@PluginRegistry.register_processor()
class MyCustomProcessor(BaseProcessor):
stage = PipelineStage.ALL
def fit(self, data):
return self
def transform(self, data):
return data
assert "MyCustomProcessor" in PluginRegistry.list_processors()
# ========== 测试内置处理器 ==========
class TestBuiltInProcessors:
"""测试内置处理器"""
def test_dropna_processor(self):
"""测试缺失值删除处理器"""
from src.models.processors import DropNAProcessor
processor = DropNAProcessor(columns=["a", "b"])
df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
result = processor.fit_transform(df)
# 只有第一行没有缺失值
assert len(result) == 1
assert result["a"].to_list() == [1]
assert result["b"].to_list() == [4]
def test_fillna_processor(self):
"""测试缺失值填充处理器"""
from src.models.processors import FillNAProcessor
processor = FillNAProcessor(columns=["a"], method="mean")
df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
result = processor.fit_transform(df)
# 均值 = (1+2+4)/3 = 2.333...
assert result["a"][2] == pytest.approx(2.333, rel=0.01)
def test_standard_scaler(self):
"""测试标准化处理器"""
from src.models.processors import StandardScaler
processor = StandardScaler(columns=["value"])
df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
result = processor.fit_transform(df)
# Z-score 标准化后均值为0标准差为1
assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
assert result["value"].std() == pytest.approx(1.0, rel=0.01)
def test_winsorizer(self):
"""测试缩尾处理器"""
from src.models.processors import Winsorizer
processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
df = pl.DataFrame(
{
"value": list(range(100)) # 0-99
}
)
result = processor.fit_transform(df)
# 10%和90%分位数应该是10和89Polars的quantile行为
assert result["value"].min() == 10
assert result["value"].max() == 89
def test_rank_transformer(self):
"""测试排名转换处理器"""
from src.models.processors import RankTransformer
processor = RankTransformer(columns=["value"])
df = pl.DataFrame(
{"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
)
result = processor.fit_transform(df)
# 排名应该是 1, 3, 2, 5, 4
assert result["value"].to_list() == [1, 3, 2, 5, 4]
def test_neutralizer(self):
"""测试中性化处理器"""
from src.models.processors import Neutralizer
processor = Neutralizer(columns=["value"], group_col="industry")
df = pl.DataFrame(
{
"trade_date": ["20200101", "20200101", "20200101", "20200101"],
"industry": ["A", "A", "B", "B"],
"value": [10, 20, 30, 50],
}
)
result = processor.fit_transform(df)
# 分组去均值后每组的均值为0
group_a = result.filter(pl.col("industry") == "A")
group_b = result.filter(pl.col("industry") == "B")
assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)
# ========== 测试处理流水线 ==========
class TestProcessingPipeline:
"""测试处理流水线"""
def test_pipeline_fit_transform(self):
"""测试流水线的 fit_transform"""
from src.models.processors import StandardScaler
scaler1 = StandardScaler(columns=["a"])
scaler2 = StandardScaler(columns=["b"])
pipeline = ProcessingPipeline([scaler1, scaler2])
df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})
result = pipeline.fit_transform(df)
# 两个列都应该被标准化
assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)
def test_pipeline_transform_uses_fitted_params(self):
"""测试 transform 使用已 fit 的参数"""
from src.models.processors import StandardScaler
scaler = StandardScaler(columns=["value"])
pipeline = ProcessingPipeline([scaler])
# 训练数据
train_df = pl.DataFrame(
{
"value": [1.0, 2.0, 3.0] # 均值=2标准差=1
}
)
# 测试数据(不同的分布)
test_df = pl.DataFrame(
{
"value": [4.0, 5.0, 6.0] # 如果重新计算应该是均值=5
}
)
pipeline.fit_transform(train_df)
result = pipeline.transform(test_df)
# 使用训练数据的均值=2和标准差=1进行标准化
# 4 -> (4-2)/1 = 2
assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)
# ========== 测试划分策略 ==========
class TestSplitters:
"""测试划分策略"""
def test_time_series_split(self):
"""测试时间序列划分"""
from src.models.core import TimeSeriesSplit
splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
# 10天的数据
df = pl.DataFrame(
{
"trade_date": [f"202001{i:02d}" for i in range(1, 11)],
"value": list(range(10)),
}
)
splits = list(splitter.split(df))
# 应该有两折
assert len(splits) == 2
# 检查每折训练集在测试集之前
for train_idx, test_idx in splits:
assert max(train_idx) < min(test_idx)
def test_walk_forward_split(self):
"""测试滚动前向划分"""
from src.models.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
df = pl.DataFrame(
{
"trade_date": [f"202001{i:02d}" for i in range(1, 13)],
"value": list(range(12)),
}
)
splits = list(splitter.split(df))
# 检查训练集大小固定
for train_idx, test_idx in splits:
assert len(train_idx) == 5
assert len(test_idx) == 2
def test_expanding_window_split(self):
"""测试扩展窗口划分"""
from src.models.core import ExpandingWindowSplit
splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
df = pl.DataFrame(
{
"trade_date": [f"202001{i:02d}" for i in range(1, 15)],
"value": list(range(14)),
}
)
splits = list(splitter.split(df))
# 训练集应该逐渐增大
train_sizes = [len(train_idx) for train_idx, _ in splits]
assert train_sizes[0] == 3
assert train_sizes[1] == 5 # 3 + 2
assert train_sizes[2] == 7 # 5 + 2
# ========== 测试内置模型(可选,需要安装依赖) ==========
class TestModels:
"""测试内置模型(标记为跳过如果依赖未安装)"""
    def test_lightgbm_model(self):
        """测试 LightGBM 模型"""
        pytest.importorskip("lightgbm")  # 依赖未安装时自动跳过,而非无条件 skip
        from src.models.models import LightGBMModel
model = LightGBMModel(task_type="regression", params={"n_estimators": 10})
X = pl.DataFrame(
{
"feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
"feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
}
)
y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)
model.fit(X, y)
predictions = model.predict(X)
assert len(predictions) == len(X)
assert model._is_fitted
if __name__ == "__main__":
pytest.main([__file__, "-v"])