feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类 - 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露 - 内置 8 个数据处理器和 3 种时序划分策略 - 支持 LightGBM、CatBoost 模型 - PluginRegistry 装饰器注册,插件式架构 - 22 个单元测试
This commit is contained in:
292
README.md
292
README.md
@@ -1,40 +1,300 @@
|
|||||||
# ProStock
|
# ProStock
|
||||||
|
|
||||||
A股量化投资框架
|
A股量化投资框架 - 从数据获取到模型训练的完整解决方案
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
### 1. 数据层 (src/data/)
|
||||||
|
- **多源数据接入**: Tushare API 集成,支持日线、股票基础信息、交易日历
|
||||||
|
- **DuckDB 存储**: 高性能嵌入式数据库,支持 SQL 查询下推
|
||||||
|
- **智能同步**: 增量/全量同步策略,自动检测数据更新需求
|
||||||
|
- **速率控制**: 令牌桶算法实现 API 限流
|
||||||
|
- **并发优化**: ThreadPoolExecutor 多线程数据获取
|
||||||
|
|
||||||
|
### 2. 因子层 (src/factors/)
|
||||||
|
- **类型安全**: 严格的截面因子 vs 时序因子区分
|
||||||
|
- **防泄露机制**: 框架层面防止未来数据和跨股票数据泄露
|
||||||
|
- **因子组合**: 支持因子加减乘除和标量运算
|
||||||
|
- **高性能计算**: Polars 向量化操作,零拷贝数据导出
|
||||||
|
- **灵活扩展**: 基类抽象便于自定义因子
|
||||||
|
|
||||||
|
### 3. 模型层 (src/models/)
|
||||||
|
- **插件架构**: 装饰器注册机制,新模型即插即用
|
||||||
|
- **阶段感知**: 训练/测试阶段区分,防止数据泄露
|
||||||
|
- **多模型支持**: LightGBM、CatBoost 等模型统一接口
|
||||||
|
- **数据处理**: 缺失值处理、缩尾、标准化、中性化等
|
||||||
|
- **时序划分**: WalkForward、ExpandingWindow 等时间序列划分策略
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
ProStock/
|
||||||
|
├── src/
|
||||||
|
│ ├── config/ # 配置管理
|
||||||
|
│ │ ├── settings.py # pydantic-settings 配置
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ │
|
||||||
|
│ ├── data/ # 数据获取与存储
|
||||||
|
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||||
|
│ │ │ ├── api_daily.py # 日线数据接口
|
||||||
|
│ │ │ ├── api_stock_basic.py # 股票基础信息
|
||||||
|
│ │ │ └── api_trade_cal.py # 交易日历
|
||||||
|
│ │ ├── client.py # Tushare 客户端(含限流)
|
||||||
|
│ │ ├── config.py # 数据模块配置
|
||||||
|
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
||||||
|
│ │ ├── db_inspector.py # 数据库信息查看工具
|
||||||
|
│ │ ├── rate_limiter.py # 令牌桶限流器
|
||||||
|
│ │ ├── storage.py # DuckDB 存储核心
|
||||||
|
│ │ ├── sync.py # 数据同步主逻辑
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ │
|
||||||
|
│ ├── factors/ # 因子计算框架
|
||||||
|
│ │ ├── base.py # 因子基类(截面/时序)
|
||||||
|
│ │ ├── composite.py # 组合因子和标量运算
|
||||||
|
│ │ ├── data_loader.py # DuckDB 数据加载器
|
||||||
|
│ │ ├── data_spec.py # 数据规格定义
|
||||||
|
│ │ ├── engine.py # 因子执行引擎
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ │
|
||||||
|
│ ├── models/ # 模型训练框架
|
||||||
|
│ │ ├── core/ # 核心抽象
|
||||||
|
│ │ │ ├── base.py # 处理器/模型/划分基类
|
||||||
|
│ │ │ └── splitter.py # 时间序列划分策略
|
||||||
|
│ │ ├── models/ # 模型实现
|
||||||
|
│ │ │ └── models.py # LightGBM、CatBoost
|
||||||
|
│ │ ├── processors/ # 数据处理器
|
||||||
|
│ │ │ └── processors.py # 标准化、缩尾、中性化等
|
||||||
|
│ │ ├── pipeline.py # 处理流水线
|
||||||
|
│ │ ├── registry.py # 插件注册中心
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ │
|
||||||
|
│ └── __init__.py
|
||||||
|
│
|
||||||
|
├── docs/ # 文档
|
||||||
|
│ ├── factor_framework_design.md # 因子框架设计
|
||||||
|
│ ├── ml_framework_design.md # 模型框架设计
|
||||||
|
│ ├── db_sync_guide.md # 数据同步指南
|
||||||
|
│ └── ...
|
||||||
|
│
|
||||||
|
├── data/ # 数据存储(DuckDB)
|
||||||
|
│ ├── prostock.db # 主数据库文件
|
||||||
|
│ └── stock_basic.csv # 股票基础信息缓存
|
||||||
|
│
|
||||||
|
├── config/ # 配置文件
|
||||||
|
│ └── .env.local # 环境变量(API Token等)
|
||||||
|
│
|
||||||
|
└── tests/ # 测试文件
|
||||||
|
├── test_sync.py
|
||||||
|
└── factors/
|
||||||
|
```
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
||||||
### 安装依赖
|
### 1. 安装依赖
|
||||||
|
|
||||||
**⚠️ 本项目强制使用 uv 作为 Python 包管理器,禁止直接使用 `python` 或 `pip` 命令。**
|
**⚠️ 本项目强制使用 uv 作为 Python 包管理器**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 使用 uv 安装(必须)
|
# 安装 uv (如果尚未安装)
|
||||||
|
pip install uv
|
||||||
|
|
||||||
|
# 安装项目依赖
|
||||||
uv pip install -e .
|
uv pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
### 数据同步
|
### 2. 配置环境变量
|
||||||
|
|
||||||
|
创建 `config/.env.local` 文件:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 增量同步(自动从最新日期开始)
|
TUSHARE_TOKEN=your_tushare_token_here
|
||||||
|
DATA_PATH=data
|
||||||
|
RATE_LIMIT=100
|
||||||
|
THREADS=10
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 数据同步
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 首次同步 - 全量同步(从20180101开始)
|
||||||
|
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
||||||
|
|
||||||
|
# 日常同步 - 增量同步(自动从最新日期开始)
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all()"
|
uv run python -c "from src.data.sync import sync_all; sync_all()"
|
||||||
|
|
||||||
# 全量同步(从 20180101 开始)
|
# 预览同步(检查需要同步的数据量)
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
uv run python -c "from src.data.sync import preview_sync; preview_sync()"
|
||||||
|
|
||||||
# 自定义线程数
|
# 自定义线程数
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 4. 查看数据库状态
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
### 因子计算
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factors import FactorEngine, DataLoader, DataSpec
|
||||||
|
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
# 自定义截面因子:PE排名
|
||||||
|
class PERankFactor(CrossSectionalFactor):
|
||||||
|
name = "pe_rank"
|
||||||
|
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]
|
||||||
|
|
||||||
|
def compute(self, data) -> pl.Series:
|
||||||
|
cs = data.get_cross_section()
|
||||||
|
return cs["pe"].rank()
|
||||||
|
|
||||||
|
# 自定义时序因子:20日移动平均
|
||||||
|
class MA20Factor(TimeSeriesFactor):
|
||||||
|
name = "ma20"
|
||||||
|
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
|
||||||
|
|
||||||
|
def compute(self, data) -> pl.Series:
|
||||||
|
return data.get_column("close").rolling_mean(window_size=20)
|
||||||
|
|
||||||
|
# 执行计算
|
||||||
|
loader = DataLoader(data_dir="data")
|
||||||
|
engine = FactorEngine(loader)
|
||||||
|
|
||||||
|
# 计算截面因子
|
||||||
|
pe_rank = PERankFactor()
|
||||||
|
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")
|
||||||
|
|
||||||
|
# 计算时序因子
|
||||||
|
ma20 = MA20Factor()
|
||||||
|
result2 = engine.compute(ma20, stock_codes=["000001.SZ"],
|
||||||
|
start_date="20240101", end_date="20240131")
|
||||||
|
|
||||||
|
# 因子组合
|
||||||
|
combined = 0.5 * pe_rank + 0.3 * ma20
|
||||||
|
```
|
||||||
|
|
||||||
|
### 模型训练
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.models import PluginRegistry, ProcessingPipeline
|
||||||
|
from src.models.core import PipelineStage
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
# 创建处理流水线
|
||||||
|
pipeline = ProcessingPipeline([
|
||||||
|
PluginRegistry.get_processor("dropna")(),
|
||||||
|
PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
|
||||||
|
PluginRegistry.get_processor("standard_scaler")(),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
data = pl.read_csv("features.csv") # 包含特征和标签
|
||||||
|
|
||||||
|
# 划分训练/测试集
|
||||||
|
from src.models.core import WalkForwardSplit
|
||||||
|
splitter = WalkForwardSplit(train_window=252, test_window=21)
|
||||||
|
|
||||||
|
# 获取 LightGBM 模型
|
||||||
|
ModelClass = PluginRegistry.get_model("lightgbm")
|
||||||
|
model = ModelClass(task_type="regression", params={"n_estimators": 100})
|
||||||
|
|
||||||
|
# 训练循环
|
||||||
|
for train_idx, test_idx in splitter.split(data):
|
||||||
|
train_data = data[train_idx]
|
||||||
|
test_data = data[test_idx]
|
||||||
|
|
||||||
|
# 数据处理
|
||||||
|
X_train = pipeline.fit_transform(train_data.drop("target"))
|
||||||
|
X_test = pipeline.transform(test_data.drop("target"))
|
||||||
|
y_train = train_data["target"]
|
||||||
|
y_test = test_data["target"]
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
predictions = model.predict(X_test)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心设计
|
||||||
|
|
||||||
|
### 1. 数据防泄露机制
|
||||||
|
|
||||||
|
**截面因子 (CrossSectionalFactor)**:
|
||||||
|
- 防止日期泄露:每天只传入 `[T-lookback+1, T]` 数据
|
||||||
|
- 允许股票间比较:传入当天所有股票数据
|
||||||
|
- 典型应用:PE排名、市值分位数、当日收益率排名
|
||||||
|
|
||||||
|
**时序因子 (TimeSeriesFactor)**:
|
||||||
|
- 防止股票泄露:每只股票单独计算
|
||||||
|
- 允许历史数据访问:传入完整时间序列
|
||||||
|
- 典型应用:移动平均线、RSI、历史波动率
|
||||||
|
|
||||||
|
### 2. 插件注册机制
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.models.registry import PluginRegistry
|
||||||
|
|
||||||
|
# 注册自定义处理器
|
||||||
|
@PluginRegistry.register_processor("my_processor")
|
||||||
|
class MyProcessor(BaseProcessor):
|
||||||
|
stage = PipelineStage.TRAIN
|
||||||
|
|
||||||
|
def fit(self, data):
|
||||||
|
# 学习参数
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
# 转换数据
|
||||||
|
return data
|
||||||
|
|
||||||
|
# 使用
|
||||||
|
processor_class = PluginRegistry.get_processor("my_processor")
|
||||||
|
processor = processor_class()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 数据同步策略
|
||||||
|
|
||||||
|
**智能增量同步**:
|
||||||
|
```python
|
||||||
|
from src.data.db_manager import SyncManager
|
||||||
|
|
||||||
|
manager = SyncManager()
|
||||||
|
result = manager.sync(
|
||||||
|
table_name="daily",
|
||||||
|
fetch_func=get_daily,
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131"
|
||||||
|
)
|
||||||
|
# 自动检测:表不存在→全量,表存在→增量
|
||||||
|
```
|
||||||
|
|
||||||
## 文档
|
## 文档
|
||||||
|
|
||||||
- [数据同步模块](docs/data_sync.md) - 详细的数据同步使用说明
|
- [因子框架设计](docs/factor_framework_design.md) - 因子计算架构详解
|
||||||
|
- [模型框架设计](docs/ml_framework_design.md) - 模型训练架构详解
|
||||||
|
- [数据同步指南](docs/db_sync_guide.md) - DuckDB 数据同步 API 说明
|
||||||
|
- [代码审查报告](docs/code_review_factors_20260222.md) - 因子框架代码审查
|
||||||
|
|
||||||
## 模块
|
## 开发规范
|
||||||
|
|
||||||
- `data/` - 数据获取
|
- **Python 版本**: 3.10+
|
||||||
- `factors/` - 因子生成
|
- **代码风格**: Google 风格文档字符串
|
||||||
- `models/` - 模型训练
|
- **类型提示**: 强制类型注解
|
||||||
- `backtest/` - 回测分析
|
- **测试**: pytest 框架
|
||||||
- `utils/` - 工具函数
|
- **包管理**: uv (禁止直接使用 pip/python)
|
||||||
- `scripts/` - 运行脚本
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- **数据处理**: Polars, Pandas, NumPy
|
||||||
|
- **数据存储**: DuckDB (嵌入式 OLAP 数据库)
|
||||||
|
- **API 接口**: Tushare Pro
|
||||||
|
- **机器学习**: LightGBM, CatBoost, scikit-learn
|
||||||
|
- **配置管理**: pydantic-settings
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|||||||
@@ -1,846 +0,0 @@
|
|||||||
# ProStock 因子框架实现计划
|
|
||||||
|
|
||||||
## 目录结构
|
|
||||||
|
|
||||||
```
|
|
||||||
src/factors/
|
|
||||||
├── __init__.py # 导出主要类
|
|
||||||
├── data_spec.py # Phase 1: 数据类型定义
|
|
||||||
├── base.py # Phase 2: 因子基类
|
|
||||||
├── composite.py # Phase 2: 组合因子
|
|
||||||
├── data_loader.py # Phase 3: 数据加载
|
|
||||||
├── engine.py # Phase 4: 执行引擎
|
|
||||||
└── builtin/ # Phase 5: 内置因子库
|
|
||||||
├── __init__.py
|
|
||||||
├── momentum.py # 截面动量因子
|
|
||||||
├── technical.py # 时序技术指标
|
|
||||||
└── value.py # 截面估值因子
|
|
||||||
|
|
||||||
tests/factors/ # Phase 6-7: 测试
|
|
||||||
├── __init__.py
|
|
||||||
├── test_data_spec.py # 数据类型测试
|
|
||||||
├── test_base.py # 因子基类测试
|
|
||||||
├── test_composite.py # 组合因子测试
|
|
||||||
├── test_data_loader.py # 数据加载测试
|
|
||||||
├── test_engine.py # 引擎测试
|
|
||||||
├── test_builtin.py # 内置因子测试
|
|
||||||
└── test_integration.py # 集成测试
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 1: 数据类型定义 (data_spec.py)
|
|
||||||
|
|
||||||
### 1.1 DataSpec - 数据需求规格
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class DataSpec:
|
|
||||||
"""
|
|
||||||
数据需求规格说明
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source: H5 文件名(如 "daily", "fundamental")
|
|
||||||
columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date"
|
|
||||||
lookback_days: 需要回看的天数(包含当日)
|
|
||||||
- 1 表示只需要当日数据 [T]
|
|
||||||
- 5 表示需要 [T-4, T] 共5天
|
|
||||||
- 20 表示需要 [T-19, T] 共20天
|
|
||||||
"""
|
|
||||||
source: str
|
|
||||||
columns: List[str]
|
|
||||||
lookback_days: int = 1
|
|
||||||
```
|
|
||||||
|
|
||||||
**约束验证:**
|
|
||||||
- `lookback_days >= 1`(至少包含当日)
|
|
||||||
- `columns` 必须包含 `ts_code` 和 `trade_date`
|
|
||||||
- `source` 不能为空字符串
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试有效 DataSpec 创建
|
|
||||||
- [ ] 测试 `lookback_days < 1` 时抛出 ValueError
|
|
||||||
- [ ] 测试缺少 `ts_code` 或 `trade_date` 时抛出 ValueError
|
|
||||||
- [ ] 测试空 `source` 时抛出 ValueError
|
|
||||||
- [ ] 测试 frozen 特性(创建后不可修改)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.2 FactorContext - 计算上下文
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
@dataclass
|
|
||||||
class FactorContext:
|
|
||||||
"""
|
|
||||||
因子计算上下文
|
|
||||||
|
|
||||||
由 FactorEngine 自动注入,因子开发者可通过 data.context 访问
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
current_date: 当前计算日期 YYYYMMDD(截面因子使用)
|
|
||||||
current_stock: 当前计算股票代码(时序因子使用)
|
|
||||||
trade_dates: 交易日历列表(可选,用于对齐)
|
|
||||||
"""
|
|
||||||
current_date: Optional[str] = None
|
|
||||||
current_stock: Optional[str] = None
|
|
||||||
trade_dates: Optional[List[str]] = None
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试默认值创建
|
|
||||||
- [ ] 测试完整参数创建
|
|
||||||
- [ ] 测试 dataclass 自动生成的方法
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.3 FactorData - 数据容器
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class FactorData:
|
|
||||||
"""
|
|
||||||
提供给因子的数据容器
|
|
||||||
|
|
||||||
封装底层 Polars DataFrame,提供安全的数据访问接口
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, df: pl.DataFrame, context: FactorContext):
|
|
||||||
self._df = df
|
|
||||||
self._context = context
|
|
||||||
|
|
||||||
def get_column(self, col: str) -> pl.Series:
|
|
||||||
"""
|
|
||||||
获取指定列的数据
|
|
||||||
|
|
||||||
- 截面因子:获取当天所有股票的该列值
|
|
||||||
- 时序因子:获取该股票时间序列的该列值
|
|
||||||
|
|
||||||
Args:
|
|
||||||
col: 列名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Polars Series
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: 列不存在
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def filter_by_date(self, date: str) -> "FactorData":
|
|
||||||
"""
|
|
||||||
按日期过滤数据,返回新的 FactorData
|
|
||||||
|
|
||||||
主要用于截面因子获取特定日期的数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
date: YYYYMMDD 格式的日期
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
过滤后的 FactorData
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_cross_section(self) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
获取当前日期的截面数据
|
|
||||||
|
|
||||||
仅适用于截面因子,返回 current_date 当天的所有股票数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame 包含当前日期的所有股票
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: current_date 未设置(非截面因子场景)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def to_polars(self) -> pl.DataFrame:
|
|
||||||
"""获取底层的 Polars DataFrame(高级用法)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def context(self) -> FactorContext:
|
|
||||||
"""获取计算上下文"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""返回数据行数"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试 `get_column()` 返回正确 Series
|
|
||||||
- [ ] 测试 `get_column()` 列不存在时抛出 KeyError
|
|
||||||
- [ ] 测试 `filter_by_date()` 返回正确过滤结果
|
|
||||||
- [ ] 测试 `filter_by_date()` 日期不存在时返回空 DataFrame
|
|
||||||
- [ ] 测试 `get_cross_section()` 返回 current_date 当天的数据
|
|
||||||
- [ ] 测试 `get_cross_section()` current_date 为 None 时抛出 ValueError
|
|
||||||
- [ ] 测试 `to_polars()` 返回原始 DataFrame
|
|
||||||
- [ ] 测试 `context` 属性返回正确上下文
|
|
||||||
- [ ] 测试 `__len__()` 返回正确行数
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 2: 因子基类 (base.py, composite.py)
|
|
||||||
|
|
||||||
### 2.1 BaseFactor - 抽象基类
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class BaseFactor(ABC):
|
|
||||||
"""
|
|
||||||
因子基类 - 定义通用接口
|
|
||||||
|
|
||||||
所有因子必须继承此类,并声明以下类属性:
|
|
||||||
- name: 因子唯一标识(snake_case)
|
|
||||||
- factor_type: "cross_sectional" 或 "time_series"
|
|
||||||
- data_specs: List[DataSpec] 数据需求列表
|
|
||||||
|
|
||||||
可选声明:
|
|
||||||
- category: 因子分类(默认 "default")
|
|
||||||
- description: 因子描述
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 必须声明的类属性
|
|
||||||
name: str = ""
|
|
||||||
factor_type: str = "" # "cross_sectional" | "time_series"
|
|
||||||
data_specs: List[DataSpec] = field(default_factory=list)
|
|
||||||
|
|
||||||
# 可选声明的类属性
|
|
||||||
category: str = "default"
|
|
||||||
description: str = ""
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
|
||||||
"""
|
|
||||||
子类创建时验证必须属性
|
|
||||||
|
|
||||||
验证项:
|
|
||||||
1. name 必须是非空字符串
|
|
||||||
2. factor_type 必须是 "cross_sectional" 或 "time_series"
|
|
||||||
3. data_specs 必须是非空列表
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __init__(self, **params):
|
|
||||||
"""
|
|
||||||
初始化因子参数
|
|
||||||
|
|
||||||
子类可通过 __init__ 接收参数化配置,如 MA(period=20)
|
|
||||||
"""
|
|
||||||
self.params = params
|
|
||||||
self._validate_params()
|
|
||||||
|
|
||||||
def _validate_params(self):
|
|
||||||
"""
|
|
||||||
验证参数有效性
|
|
||||||
|
|
||||||
子类可覆盖此方法进行自定义验证
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def compute(self, data: FactorData) -> pl.Series:
|
|
||||||
"""
|
|
||||||
核心计算逻辑 - 子类必须实现
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 安全的数据容器,已根据因子类型裁剪
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
计算得到的因子值 Series
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ========== 因子组合运算符 ==========
|
|
||||||
|
|
||||||
def __add__(self, other: "BaseFactor") -> "CompositeFactor":
|
|
||||||
"""因子相加:f1 + f2(要求同类型)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
|
|
||||||
"""因子相减:f1 - f2(要求同类型)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
|
|
||||||
"""因子相乘:f1 * f2(要求同类型)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
|
|
||||||
"""因子相除:f1 / f2(要求同类型)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __rmul__(self, scalar: float) -> "ScalarFactor":
|
|
||||||
"""标量乘法:0.5 * f1"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试有效子类创建通过验证
|
|
||||||
- [ ] 测试缺少 `name` 时抛出 ValueError
|
|
||||||
- [ ] 测试 `name` 为空字符串时抛出 ValueError
|
|
||||||
- [ ] 测试缺少 `factor_type` 时抛出 ValueError
|
|
||||||
- [ ] 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError
|
|
||||||
- [ ] 测试缺少 `data_specs` 时抛出 ValueError
|
|
||||||
- [ ] 测试 `data_specs` 为空列表时抛出 ValueError
|
|
||||||
- [ ] 测试 `compute()` 抽象方法强制子类实现
|
|
||||||
- [ ] 测试参数化初始化 `params` 正确存储
|
|
||||||
- [ ] 测试 `_validate_params()` 被调用
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2.2 CrossSectionalFactor - 日期截面因子
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class CrossSectionalFactor(BaseFactor):
|
|
||||||
"""
|
|
||||||
日期截面因子基类
|
|
||||||
|
|
||||||
计算逻辑:在每个交易日,对所有股票进行横向计算
|
|
||||||
|
|
||||||
防泄露边界:
|
|
||||||
- ❌ 禁止访问未来日期的数据(日期泄露)
|
|
||||||
- ✅ 允许访问当前日期的所有股票数据
|
|
||||||
|
|
||||||
数据传入:
|
|
||||||
- compute() 接收的是 [T-lookback+1, T] 的数据
|
|
||||||
- 包含 lookback_days 的历史数据(用于时序计算后再截面)
|
|
||||||
"""
|
|
||||||
|
|
||||||
factor_type: str = "cross_sectional"
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def compute(self, data: FactorData) -> pl.Series:
|
|
||||||
"""
|
|
||||||
计算截面因子值
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: FactorData,包含 [T-lookback+1, T] 的截面数据
|
|
||||||
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
|
|
||||||
|
|
||||||
示例:
|
|
||||||
def compute(self, data):
|
|
||||||
# 获取当前日期的截面
|
|
||||||
cs = data.get_cross_section()
|
|
||||||
# 计算市值排名
|
|
||||||
return cs['market_cap'].rank()
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试 `factor_type` 自动设置为 "cross_sectional"
|
|
||||||
- [ ] 测试子类必须实现 `compute()`
|
|
||||||
- [ ] 测试 `compute()` 返回类型为 pl.Series
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2.3 TimeSeriesFactor - 时间序列因子
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class TimeSeriesFactor(BaseFactor):
|
|
||||||
"""
|
|
||||||
时间序列因子基类(股票截面)
|
|
||||||
|
|
||||||
计算逻辑:对每只股票,在其时间序列上进行纵向计算
|
|
||||||
|
|
||||||
防泄露边界:
|
|
||||||
- ❌ 禁止访问其他股票的数据(股票泄露)
|
|
||||||
- ✅ 允许访问该股票的完整历史数据
|
|
||||||
|
|
||||||
数据传入:
|
|
||||||
- compute() 接收的是单只股票的完整时间序列
|
|
||||||
- 包含该股票在 [start_date, end_date] 范围内的所有数据
|
|
||||||
"""
|
|
||||||
|
|
||||||
factor_type: str = "time_series"
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def compute(self, data: FactorData) -> pl.Series:
|
|
||||||
"""
|
|
||||||
计算时间序列因子值
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: FactorData,包含单只股票的完整时间序列
|
|
||||||
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
|
|
||||||
|
|
||||||
示例:
|
|
||||||
def compute(self, data):
|
|
||||||
series = data.get_column("close")
|
|
||||||
return series.rolling_mean(window_size=self.params['period'])
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试 `factor_type` 自动设置为 "time_series"
|
|
||||||
- [ ] 测试子类必须实现 `compute()`
|
|
||||||
- [ ] 测试 `compute()` 返回类型为 pl.Series
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2.4 CompositeFactor - 组合因子 (composite.py)
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class CompositeFactor(BaseFactor):
|
|
||||||
"""
|
|
||||||
组合因子 - 用于实现因子间的数学运算
|
|
||||||
|
|
||||||
约束:左右因子必须是同类型(同为截面或同为时序)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
|
|
||||||
"""
|
|
||||||
创建组合因子
|
|
||||||
|
|
||||||
Args:
|
|
||||||
left: 左操作数因子
|
|
||||||
right: 右操作数因子
|
|
||||||
op: 运算符,支持 '+', '-', '*', '/'
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 左右因子类型不一致
|
|
||||||
ValueError: 不支持的运算符
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _merge_data_specs(self) -> List[DataSpec]:
|
|
||||||
"""
|
|
||||||
合并左右因子的数据需求
|
|
||||||
|
|
||||||
策略:
|
|
||||||
1. 相同 source 和 columns 的 DataSpec 合并
|
|
||||||
2. lookback_days 取最大值
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def compute(self, data: FactorData) -> pl.Series:
|
|
||||||
"""
|
|
||||||
执行组合运算
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 分别计算 left 和 right 的值
|
|
||||||
2. 根据 op 执行运算
|
|
||||||
3. 返回结果
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试同类型因子组合成功(cs + cs)
|
|
||||||
- [ ] 测试同类型因子组合成功(ts + ts)
|
|
||||||
- [ ] 测试不同类型因子组合抛出 ValueError(cs + ts)
|
|
||||||
- [ ] 测试无效运算符抛出 ValueError
|
|
||||||
- [ ] 测试 `_merge_data_specs()` 正确合并(相同 source)
|
|
||||||
- [ ] 测试 `_merge_data_specs()` 正确合并(不同 source)
|
|
||||||
- [ ] 测试 `_merge_data_specs()` lookback 取最大值
|
|
||||||
- [ ] 测试 `compute()` 执行正确的数学运算
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 2.5 ScalarFactor - 标量运算因子 (composite.py)
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class ScalarFactor(BaseFactor):
|
|
||||||
"""
|
|
||||||
标量运算因子
|
|
||||||
|
|
||||||
支持:scalar * factor, factor * scalar(通过 __rmul__)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, factor: BaseFactor, scalar: float, op: str):
|
|
||||||
"""
|
|
||||||
创建标量运算因子
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factor: 基础因子
|
|
||||||
scalar: 标量值
|
|
||||||
op: 运算符,支持 '*', '+'
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def compute(self, data: FactorData) -> pl.Series:
|
|
||||||
"""执行标量运算"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试标量乘法 `0.5 * factor`
|
|
||||||
- [ ] 测试标量乘法 `factor * 0.5`
|
|
||||||
- [ ] 测试标量加法(如支持)
|
|
||||||
- [ ] 测试继承基础因子的 data_specs
|
|
||||||
- [ ] 测试 `compute()` 返回正确缩放后的值
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3: 数据加载 (data_loader.py)
|
|
||||||
|
|
||||||
### 3.1 DataLoader - 数据加载器
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class DataLoader:
|
|
||||||
"""
|
|
||||||
数据加载器 - 负责从 HDF5 安全加载数据
|
|
||||||
|
|
||||||
功能:
|
|
||||||
1. 多文件聚合:合并多个 H5 文件的数据
|
|
||||||
2. 列选择:只加载需要的列
|
|
||||||
3. 原始数据缓存:避免重复读取
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_dir: str):
|
|
||||||
"""
|
|
||||||
初始化 DataLoader
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_dir: HDF5 文件所在目录
|
|
||||||
"""
|
|
||||||
self.data_dir = Path(data_dir)
|
|
||||||
self._cache: Dict[str, pl.DataFrame] = {}
|
|
||||||
|
|
||||||
def load(
|
|
||||||
self,
|
|
||||||
specs: List[DataSpec],
|
|
||||||
date_range: Optional[Tuple[str, str]] = None
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
加载并聚合多个 H5 文件的数据
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. 对每个 DataSpec:
|
|
||||||
a. 检查缓存,命中则直接使用
|
|
||||||
b. 未命中则读取 HDF5(通过 pandas)
|
|
||||||
c. 转换为 Polars DataFrame
|
|
||||||
d. 按 date_range 过滤
|
|
||||||
e. 存入缓存
|
|
||||||
2. 合并多个 DataFrame(按 trade_date 和 ts_code join)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
specs: 数据需求规格列表
|
|
||||||
date_range: 日期范围限制 (start_date, end_date),可选
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
合并后的 Polars DataFrame
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: H5 文件不存在
|
|
||||||
KeyError: 列不存在于文件中
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
"""清空缓存"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _read_h5(self, source: str) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
读取单个 H5 文件
|
|
||||||
|
|
||||||
实现:使用 pandas.read_hdf(),然后 pl.from_pandas()
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试从单个 H5 文件加载数据
|
|
||||||
- [ ] 测试从多个 H5 文件加载并合并
|
|
||||||
- [ ] 测试列选择(只加载需要的列)
|
|
||||||
- [ ] 测试缓存机制(第二次加载更快)
|
|
||||||
- [ ] 测试 `clear_cache()` 清空缓存
|
|
||||||
- [ ] 测试按 date_range 过滤
|
|
||||||
- [ ] 测试文件不存在时抛出 FileNotFoundError
|
|
||||||
- [ ] 测试列不存在时抛出 KeyError
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 4: 执行引擎 (engine.py)
|
|
||||||
|
|
||||||
### 4.1 FactorEngine - 因子执行引擎
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
class FactorEngine:
|
|
||||||
"""
|
|
||||||
因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
|
|
||||||
|
|
||||||
核心职责:
|
|
||||||
1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据
|
|
||||||
2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_loader: DataLoader):
|
|
||||||
"""
|
|
||||||
初始化引擎
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_loader: 数据加载器实例
|
|
||||||
"""
|
|
||||||
self.data_loader = data_loader
|
|
||||||
|
|
||||||
def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
统一的计算入口
|
|
||||||
|
|
||||||
根据 factor_type 分发到具体方法:
|
|
||||||
- "cross_sectional" -> _compute_cross_sectional()
|
|
||||||
- "time_series" -> _compute_time_series()
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factor: 要计算的因子
|
|
||||||
**kwargs: 额外参数,根据因子类型不同:
|
|
||||||
- 截面因子: start_date, end_date
|
|
||||||
- 时序因子: stock_codes, start_date, end_date
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DataFrame[trade_date, ts_code, factor_name]
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试 `compute()` 正确分发给截面计算
|
|
||||||
- [ ] 测试 `compute()` 正确分发给时序计算
|
|
||||||
- [ ] 测试无效 factor_type 时抛出 ValueError
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 4.2 截面计算(防止日期泄露)
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
def _compute_cross_sectional(
|
|
||||||
self,
|
|
||||||
factor: CrossSectionalFactor,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
执行日期截面计算
|
|
||||||
|
|
||||||
防泄露策略:
|
|
||||||
- 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来)
|
|
||||||
- 允许股票间比较:传入当天所有股票的数据
|
|
||||||
|
|
||||||
计算流程:
|
|
||||||
1. 计算 max_lookback,确定数据起始日期
|
|
||||||
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
|
|
||||||
3. 对每个日期 T in [start_date, end_date]:
|
|
||||||
a. 裁剪数据到 [T-lookback+1, T]
|
|
||||||
b. 创建 FactorData(current_date=T)
|
|
||||||
c. 调用 factor.compute()
|
|
||||||
d. 收集结果
|
|
||||||
4. 合并所有日期的结果
|
|
||||||
|
|
||||||
返回 DataFrame 格式:
|
|
||||||
┌────────────┬──────────┬──────────────┐
|
|
||||||
│ trade_date │ ts_code │ factor_name │
|
|
||||||
├────────────┼──────────┼──────────────┤
|
|
||||||
│ 20240101 │ 000001.SZ│ 0.5 │
|
|
||||||
│ 20240101 │ 000002.SZ│ 0.3 │
|
|
||||||
└────────────┴──────────┴──────────────┘
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求(防泄露验证):**
|
|
||||||
- [ ] 测试数据裁剪正确(传入 [T-lookback+1, T])
|
|
||||||
- [ ] 测试不包含未来日期 T+1 的数据
|
|
||||||
- [ ] 测试每个日期独立计算
|
|
||||||
- [ ] 测试结果包含所有日期和所有股票
|
|
||||||
- [ ] 测试结果 DataFrame 格式正确
|
|
||||||
- [ ] 测试多个 DataSpec 时 lookback 取最大值
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 4.3 时序计算(防止股票泄露)
|
|
||||||
|
|
||||||
**实现要求:**
|
|
||||||
```python
|
|
||||||
def _compute_time_series(
|
|
||||||
self,
|
|
||||||
factor: TimeSeriesFactor,
|
|
||||||
stock_codes: List[str],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""
|
|
||||||
执行时间序列计算
|
|
||||||
|
|
||||||
防泄露策略:
|
|
||||||
- 防止股票泄露:每只股票单独计算,传入该股票的完整序列
|
|
||||||
- 允许访问历史数据:时序计算需要历史数据
|
|
||||||
|
|
||||||
计算流程:
|
|
||||||
1. 计算 max_lookback,确定数据起始日期
|
|
||||||
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
|
|
||||||
3. 对每只股票 S in stock_codes:
|
|
||||||
a. 过滤出 S 的数据(防止股票泄露)
|
|
||||||
b. 创建 FactorData(current_stock=S)
|
|
||||||
c. 调用 factor.compute()(向量化计算整个序列)
|
|
||||||
d. 收集结果
|
|
||||||
4. 合并所有股票的结果
|
|
||||||
|
|
||||||
性能优势:
|
|
||||||
- 使用 Polars 的 rolling_mean 等向量化操作
|
|
||||||
- 每只股票只计算一次,无重复计算
|
|
||||||
|
|
||||||
返回 DataFrame 格式:
|
|
||||||
┌────────────┬──────────┬──────────────┐
|
|
||||||
│ trade_date │ ts_code │ factor_name │
|
|
||||||
├────────────┼──────────┼──────────────┤
|
|
||||||
│ 20240101 │ 000001.SZ│ 10.5 │
|
|
||||||
│ 20240102 │ 000001.SZ│ 10.6 │
|
|
||||||
└────────────┴──────────┴──────────────┘
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求(防泄露验证):**
|
|
||||||
- [ ] 测试每只股票只看到自己的数据
|
|
||||||
- [ ] 测试不包含其他股票的数据
|
|
||||||
- [ ] 测试传入的是完整时间序列(向量化计算)
|
|
||||||
- [ ] 测试结果包含所有股票和所有日期
|
|
||||||
- [ ] 测试结果 DataFrame 格式正确
|
|
||||||
- [ ] 测试股票不在数据中时跳过(或填充 null)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 5: 内置因子库 (builtin/)
|
|
||||||
|
|
||||||
### 5.1 momentum.py - 截面动量因子
|
|
||||||
|
|
||||||
**实现因子:**
|
|
||||||
|
|
||||||
1. **ReturnRankFactor** - 当日收益率排名
|
|
||||||
```python
|
|
||||||
class ReturnRankFactor(CrossSectionalFactor):
|
|
||||||
"""当日收益率排名因子"""
|
|
||||||
name = "return_rank"
|
|
||||||
data_specs = [DataSpec("daily", ["close"], lookback_days=2)] # 需要2天计算收益率
|
|
||||||
|
|
||||||
def compute(self, data):
|
|
||||||
# 获取当前日期截面
|
|
||||||
cs = data.get_cross_section()
|
|
||||||
# 需要前1天和当天的收盘价,lookback=2 保证数据包含 [T-1, T]
|
|
||||||
# 这里假设 data 已经包含历史,实际计算需要 groupby 处理
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试收益率计算正确
|
|
||||||
- [ ] 测试排名计算正确
|
|
||||||
- [ ] 测试无数据时返回 null
|
|
||||||
|
|
||||||
2. **MomentumFactor** - 过去 N 日涨幅排名
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 5.2 technical.py - 时序技术指标
|
|
||||||
|
|
||||||
**实现因子:**
|
|
||||||
|
|
||||||
1. **MovingAverageFactor** - 移动平均线
|
|
||||||
```python
|
|
||||||
class MovingAverageFactor(TimeSeriesFactor):
|
|
||||||
"""移动平均线因子"""
|
|
||||||
name = "ma"
|
|
||||||
|
|
||||||
def __init__(self, period: int = 20):
|
|
||||||
super().__init__(period=period)
|
|
||||||
self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
|
|
||||||
|
|
||||||
def compute(self, data):
|
|
||||||
return data.get_column("close").rolling_mean(self.params["period"])
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试需求:**
|
|
||||||
- [ ] 测试 MA20 计算正确
|
|
||||||
- [ ] 测试前19天返回 null(Polars 默认行为)
|
|
||||||
- [ ] 测试参数 period 生效
|
|
||||||
|
|
||||||
2. **RSIFactor** - RSI 指标
|
|
||||||
3. **MACDFactor** - MACD 指标
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 5.3 value.py - 截面估值因子
|
|
||||||
|
|
||||||
**实现因子:**
|
|
||||||
1. **PERankFactor** - PE 行业分位数
|
|
||||||
2. **PBFactor** - PB 排名
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 6-7: 测试策略
|
|
||||||
|
|
||||||
### 测试金字塔
|
|
||||||
|
|
||||||
```
|
|
||||||
/\
|
|
||||||
/ \
|
|
||||||
/ 集成\ tests/factors/test_integration.py
|
|
||||||
/────────\
|
|
||||||
/ 引擎 \ tests/factors/test_engine.py
|
|
||||||
/────────────\
|
|
||||||
/ 基类/组合因子 \ tests/factors/test_base.py, test_composite.py
|
|
||||||
/────────────────\
|
|
||||||
/ 数据加载/类型 \ tests/factors/test_data_loader.py, test_data_spec.py
|
|
||||||
/──────────────────────\
|
|
||||||
```
|
|
||||||
|
|
||||||
### 测试数据准备
|
|
||||||
|
|
||||||
创建 `tests/fixtures/` 目录,包含:
|
|
||||||
- `sample_daily.h5`: 少量股票的日线数据(用于测试)
|
|
||||||
- `sample_fundamental.h5`: 基本面数据
|
|
||||||
|
|
||||||
### 关键测试场景
|
|
||||||
|
|
||||||
1. **防泄露测试(核心)**
|
|
||||||
- 截面因子:验证 compute() 中无法访问未来日期
|
|
||||||
- 时序因子:验证 compute() 中无法访问其他股票
|
|
||||||
|
|
||||||
2. **边界测试**
|
|
||||||
- lookback_days = 1(最小值)
|
|
||||||
- 数据起始点(前 N 天为 null)
|
|
||||||
- 空数据/停牌处理
|
|
||||||
|
|
||||||
3. **性能测试(可选)**
|
|
||||||
- 大数据量下的内存占用
|
|
||||||
- 缓存命中率
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 实现状态
|
|
||||||
|
|
||||||
| Phase | 状态 | 完成日期 | 测试覆盖 |
|
|
||||||
|-------|------|----------|----------|
|
|
||||||
| Phase 1: 数据类型定义 | ✅ 已完成 | 2026-02-21 | 27 tests passed |
|
|
||||||
| Phase 2: 因子基类 | ✅ 已完成 | 2026-02-21 | 49 tests passed |
|
|
||||||
| Phase 3: 数据加载 | ✅ 已完成 | 2026-02-21 | 11 tests passed |
|
|
||||||
| Phase 4: 执行引擎 | ✅ 已完成 | 2026-02-22 | 10 tests passed |
|
|
||||||
| Phase 5: 内置因子库 | 📝 待开发 | - | - |
|
|
||||||
| Phase 6-7: 测试文档 | ✅ 已完成 | 2026-02-22 | 76 tests total |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 实现顺序建议
|
|
||||||
|
|
||||||
1. **Week 1**: Phase 1-2(数据类型 + 基类)
|
|
||||||
2. **Week 2**: Phase 3-4(DataLoader + Engine)✅ **已完成**
|
|
||||||
3. **Week 3**: Phase 5(内置因子)
|
|
||||||
4. **Week 4**: Phase 6-7(测试 + 文档)
|
|
||||||
|
|
||||||
每个 Phase 完成后运行对应测试,确保质量。
|
|
||||||
@@ -1,10 +1,17 @@
|
|||||||
# ProStock HDF5 到 DuckDB 迁移方案与计划
|
# ProStock HDF5 到 DuckDB 迁移方案
|
||||||
|
|
||||||
**文档版本**: v1.0
|
**文档版本**: v1.1
|
||||||
**创建日期**: 2026-02-22
|
**创建日期**: 2026-02-22
|
||||||
**状态**: 待审批
|
**完成日期**: 2026-02-22
|
||||||
|
**状态**: ✅ 已完成
|
||||||
**影响范围**: data 模块、factors 模块、相关文档
|
**影响范围**: data 模块、factors 模块、相关文档
|
||||||
|
|
||||||
|
## 相关文档
|
||||||
|
|
||||||
|
[DuckDB 数据同步指南](./db_sync_guide.md) - 同步 API 使用说明
|
||||||
|
[迁移测试报告](./test_report_duckdb_migration.md) - 测试验证结果
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|||||||
1472
docs/ml_framework_design.md
Normal file
1472
docs/ml_framework_design.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,8 @@
|
|||||||
# ProStock HDF5 到 DuckDB 迁移测试报告
|
# ProStock HDF5 到 DuckDB 迁移测试报告
|
||||||
|
|
||||||
**报告生成时间**: 2026-02-22
|
**报告生成时间**: 2026-02-22
|
||||||
|
**完成时间**: 2026-02-22
|
||||||
|
**状态**: ✅ 已完成
|
||||||
**迁移文档**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md)
|
**迁移文档**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md)
|
||||||
**测试数据范围**: 2024年1月-3月(3个月)
|
**测试数据范围**: 2024年1月-3月(3个月)
|
||||||
|
|
||||||
|
|||||||
86
src/models/__init__.py
Normal file
86
src/models/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""ProStock 模型训练框架
|
||||||
|
|
||||||
|
组件化、低耦合、插件式的机器学习训练框架。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> from src.models import (
|
||||||
|
... PluginRegistry, ProcessingPipeline,
|
||||||
|
... PipelineStage, BaseProcessor
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # 获取注册的处理器
|
||||||
|
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
|
||||||
|
>>> scaler = scaler_class()
|
||||||
|
|
||||||
|
>>> # 创建处理流水线
|
||||||
|
>>> pipeline = ProcessingPipeline([
|
||||||
|
... PluginRegistry.get_processor("dropna")(),
|
||||||
|
... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
|
||||||
|
... PluginRegistry.get_processor("standard_scaler")(),
|
||||||
|
... ])
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 导入核心抽象类和划分策略
|
||||||
|
from src.models.core import (
|
||||||
|
PipelineStage,
|
||||||
|
TaskType,
|
||||||
|
BaseProcessor,
|
||||||
|
BaseModel,
|
||||||
|
BaseSplitter,
|
||||||
|
BaseMetric,
|
||||||
|
TimeSeriesSplit,
|
||||||
|
WalkForwardSplit,
|
||||||
|
ExpandingWindowSplit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入注册中心
|
||||||
|
from src.models.registry import PluginRegistry
|
||||||
|
|
||||||
|
# 导入处理流水线
|
||||||
|
from src.models.pipeline import ProcessingPipeline
|
||||||
|
|
||||||
|
# 导入并注册内置处理器
|
||||||
|
from src.models.processors.processors import (
|
||||||
|
DropNAProcessor,
|
||||||
|
FillNAProcessor,
|
||||||
|
Winsorizer,
|
||||||
|
StandardScaler,
|
||||||
|
MinMaxScaler,
|
||||||
|
RankTransformer,
|
||||||
|
Neutralizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入并注册内置模型
|
||||||
|
from src.models.models.models import (
|
||||||
|
LightGBMModel,
|
||||||
|
CatBoostModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 核心抽象
|
||||||
|
"PipelineStage",
|
||||||
|
"TaskType",
|
||||||
|
"BaseProcessor",
|
||||||
|
"BaseModel",
|
||||||
|
"BaseSplitter",
|
||||||
|
"BaseMetric",
|
||||||
|
# 划分策略
|
||||||
|
"TimeSeriesSplit",
|
||||||
|
"WalkForwardSplit",
|
||||||
|
"ExpandingWindowSplit",
|
||||||
|
# 注册中心
|
||||||
|
"PluginRegistry",
|
||||||
|
# 处理流水线
|
||||||
|
"ProcessingPipeline",
|
||||||
|
# 处理器
|
||||||
|
"DropNAProcessor",
|
||||||
|
"FillNAProcessor",
|
||||||
|
"Winsorizer",
|
||||||
|
"StandardScaler",
|
||||||
|
"MinMaxScaler",
|
||||||
|
"RankTransformer",
|
||||||
|
"Neutralizer",
|
||||||
|
# 模型
|
||||||
|
"LightGBMModel",
|
||||||
|
"CatBoostModel",
|
||||||
|
]
|
||||||
30
src/models/core/__init__.py
Normal file
30
src/models/core/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""核心模块导出"""
|
||||||
|
|
||||||
|
from src.models.core.base import (
|
||||||
|
PipelineStage,
|
||||||
|
TaskType,
|
||||||
|
BaseProcessor,
|
||||||
|
BaseModel,
|
||||||
|
BaseSplitter,
|
||||||
|
BaseMetric,
|
||||||
|
)
|
||||||
|
|
||||||
|
from src.models.core.splitter import (
|
||||||
|
TimeSeriesSplit,
|
||||||
|
WalkForwardSplit,
|
||||||
|
ExpandingWindowSplit,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 基础抽象
|
||||||
|
"PipelineStage",
|
||||||
|
"TaskType",
|
||||||
|
"BaseProcessor",
|
||||||
|
"BaseModel",
|
||||||
|
"BaseSplitter",
|
||||||
|
"BaseMetric",
|
||||||
|
# 划分策略
|
||||||
|
"TimeSeriesSplit",
|
||||||
|
"WalkForwardSplit",
|
||||||
|
"ExpandingWindowSplit",
|
||||||
|
]
|
||||||
351
src/models/core/base.py
Normal file
351
src/models/core/base.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
"""模型训练框架核心抽象类
|
||||||
|
|
||||||
|
提供处理器、模型、划分策略和评估指标的基类定义。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 任务类型
|
||||||
|
TaskType = Literal["classification", "regression", "ranking"]
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage(Enum):
|
||||||
|
"""流水线阶段标记
|
||||||
|
|
||||||
|
用于标记处理器在哪些阶段生效,防止数据泄露。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
ALL: 适用于所有阶段(训练、测试、验证)
|
||||||
|
TRAIN: 仅训练阶段
|
||||||
|
TEST: 仅测试阶段
|
||||||
|
VALIDATION: 仅验证阶段
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALL = auto()
|
||||||
|
TRAIN = auto()
|
||||||
|
TEST = auto()
|
||||||
|
VALIDATION = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProcessor(ABC):
|
||||||
|
"""数据处理器基类
|
||||||
|
|
||||||
|
所有数据处理器必须继承此类。关键特性是通过 stage 属性控制处理器在哪些阶段生效。
|
||||||
|
|
||||||
|
阶段标记规则:
|
||||||
|
- ALL: 训练和测试阶段都使用相同的参数
|
||||||
|
- TRAIN: 只在训练阶段计算参数(如分位数、均值等),测试阶段使用训练阶段学到的参数
|
||||||
|
- TEST: 只在测试阶段执行
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 子类必须定义适用阶段
|
||||||
|
stage: PipelineStage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def __init__(self, columns: Optional[List[str]] = None, **params):
|
||||||
|
"""初始化处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
columns: 要处理的列,None表示所有数值列
|
||||||
|
**params: 处理器特定参数
|
||||||
|
"""
|
||||||
|
self.columns = columns
|
||||||
|
self.params = params
|
||||||
|
self._is_fitted = False
|
||||||
|
self._fitted_params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fit(self, data: pl.DataFrame) -> "BaseProcessor":
|
||||||
|
"""在训练数据上学习参数
|
||||||
|
|
||||||
|
此方法只在训练阶段调用一次。学习到的参数存储在 self._fitted_params 中。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""转换数据
|
||||||
|
|
||||||
|
在训练和测试阶段都会被调用。使用 fit() 阶段学习到的参数进行转换。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的数据
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""先fit再transform的便捷方法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转换后的数据
|
||||||
|
"""
|
||||||
|
return self.fit(data).transform(data)
|
||||||
|
|
||||||
|
def get_fitted_params(self) -> Dict[str, Any]:
|
||||||
|
"""获取学习到的参数(用于保存/加载)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
学习到的参数字典
|
||||||
|
"""
|
||||||
|
return self._fitted_params.copy()
|
||||||
|
|
||||||
|
def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
|
||||||
|
"""设置学习到的参数(用于从checkpoint恢复)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 参数字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
self._fitted_params = params.copy()
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(ABC):
|
||||||
|
"""机器学习模型基类
|
||||||
|
|
||||||
|
统一接口支持多种模型(LightGBM, CatBoost, XGBoost等)
|
||||||
|
和多种任务类型(分类、回归、排序)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_type: TaskType,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""初始化模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: 任务类型 - "classification", "regression", "ranking"
|
||||||
|
params: 模型特定参数
|
||||||
|
name: 模型名称(用于日志和报告)
|
||||||
|
"""
|
||||||
|
self.task_type = task_type
|
||||||
|
self.params = params or {}
|
||||||
|
self.name = name or self.__class__.__name__
|
||||||
|
self._model: Any = None
|
||||||
|
self._is_fitted = False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X: pl.DataFrame,
|
||||||
|
y: pl.Series,
|
||||||
|
X_val: Optional[pl.DataFrame] = None,
|
||||||
|
y_val: Optional[pl.Series] = None,
|
||||||
|
**fit_params,
|
||||||
|
) -> "BaseModel":
|
||||||
|
"""训练模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 特征数据
|
||||||
|
y: 目标变量
|
||||||
|
X_val: 验证集特征(可选)
|
||||||
|
y_val: 验证集目标(可选)
|
||||||
|
**fit_params: 额外的fit参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果数组
|
||||||
|
- classification: 类别标签或概率
|
||||||
|
- regression: 连续值
|
||||||
|
- ranking: 排序分数
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测概率(仅分类任务)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 特征数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
类别概率数组 [n_samples, n_classes]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: 非分类任务时抛出
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"predict_proba only available for classification tasks"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_feature_importance(self) -> Optional[pl.DataFrame]:
|
||||||
|
"""获取特征重要性(如果模型支持)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame[feature, importance] 或 None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def save(self, path: str) -> None:
|
||||||
|
"""保存模型到文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 保存路径
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
pickle.dump(self, f)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: str) -> "BaseModel":
|
||||||
|
"""从文件加载模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 模型文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载的模型实例
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSplitter(ABC):
|
||||||
|
"""数据划分策略基类
|
||||||
|
|
||||||
|
针对时间序列数据的特殊划分策略,防止未来泄露。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Iterator[Tuple[List[int], List[int]]]:
|
||||||
|
"""生成训练/测试索引
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 完整数据集
|
||||||
|
date_col: 日期列名
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
(train_indices, test_indices) 元组
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_split_dates(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
|
"""获取划分日期范围
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 完整数据集
|
||||||
|
date_col: 日期列名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[(train_start, train_end, test_start, test_end), ...]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMetric(ABC):
|
||||||
|
"""评估指标基类
|
||||||
|
|
||||||
|
所有评估指标必须继承此类。支持单次计算和累积计算两种模式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: Optional[str] = None):
|
||||||
|
"""初始化指标
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 指标名称
|
||||||
|
"""
|
||||||
|
self.name = name or self.__class__.__name__
|
||||||
|
self._values: List[float] = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||||
|
"""计算指标值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: 真实值
|
||||||
|
y_pred: 预测值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
指标值
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric":
|
||||||
|
"""更新累积值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: 真实值
|
||||||
|
y_pred: 预测值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
self._values.append(self.compute(y_true, y_pred))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_mean(self) -> float:
|
||||||
|
"""获取累积值的均值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
均值
|
||||||
|
"""
|
||||||
|
if not self._values:
|
||||||
|
return 0.0
|
||||||
|
return float(np.mean(self._values))
|
||||||
|
|
||||||
|
def get_std(self) -> float:
|
||||||
|
"""获取累积值的标准差
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准差
|
||||||
|
"""
|
||||||
|
if not self._values:
|
||||||
|
return 0.0
|
||||||
|
return float(np.std(self._values))
|
||||||
|
|
||||||
|
def reset(self) -> "BaseMetric":
|
||||||
|
"""重置累积值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
self._values = []
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PipelineStage",
|
||||||
|
"TaskType",
|
||||||
|
"BaseProcessor",
|
||||||
|
"BaseModel",
|
||||||
|
"BaseSplitter",
|
||||||
|
"BaseMetric",
|
||||||
|
]
|
||||||
222
src/models/core/splitter.py
Normal file
222
src/models/core/splitter.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""时间序列数据划分策略
|
||||||
|
|
||||||
|
提供针对金融时间序列的特殊划分策略,防止未来泄露。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Iterator, List, Tuple
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.models.core.base import BaseSplitter
|
||||||
|
|
||||||
|
|
||||||
|
class TimeSeriesSplit(BaseSplitter):
|
||||||
|
"""时间序列划分 - 确保训练数据在测试数据之前
|
||||||
|
|
||||||
|
按照时间顺序进行K折划分,每折的训练数据都在测试数据之前。
|
||||||
|
通过 gap 参数防止训练集和测试集之间的数据泄露。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_splits: 划分折数
|
||||||
|
gap: 训练集和测试集之间的间隔天数(防止泄露)
|
||||||
|
min_train_size: 最小训练集大小(天数)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
|
||||||
|
self.n_splits = n_splits
|
||||||
|
self.gap = gap
|
||||||
|
self.min_train_size = min_train_size
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Iterator[Tuple[List[int], List[int]]]:
|
||||||
|
"""生成训练/测试索引"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
|
||||||
|
test_size = (n_dates - self.min_train_size) // self.n_splits
|
||||||
|
|
||||||
|
for i in range(self.n_splits):
|
||||||
|
train_end_idx = self.min_train_size + i * test_size
|
||||||
|
test_start_idx = train_end_idx + self.gap
|
||||||
|
test_end_idx = test_start_idx + test_size
|
||||||
|
|
||||||
|
if test_end_idx > n_dates:
|
||||||
|
break
|
||||||
|
|
||||||
|
train_dates = dates[:train_end_idx]
|
||||||
|
test_dates = dates[test_start_idx:test_end_idx]
|
||||||
|
|
||||||
|
train_mask = data[date_col].is_in(train_dates.to_list())
|
||||||
|
test_mask = data[date_col].is_in(test_dates.to_list())
|
||||||
|
|
||||||
|
train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
|
||||||
|
test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
|
||||||
|
|
||||||
|
yield train_idx, test_idx
|
||||||
|
|
||||||
|
def get_split_dates(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
|
"""获取划分日期范围"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
test_size = (n_dates - self.min_train_size) // self.n_splits
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for i in range(self.n_splits):
|
||||||
|
train_end_idx = self.min_train_size + i * test_size
|
||||||
|
test_start_idx = train_end_idx + self.gap
|
||||||
|
test_end_idx = test_start_idx + test_size
|
||||||
|
|
||||||
|
if test_end_idx > n_dates:
|
||||||
|
break
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
(
|
||||||
|
str(dates[0]),
|
||||||
|
str(dates[train_end_idx - 1]),
|
||||||
|
str(dates[test_start_idx]),
|
||||||
|
str(dates[test_end_idx - 1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class WalkForwardSplit(BaseSplitter):
|
||||||
|
"""滚动前向验证 - 训练集逐步扩展
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_window: 训练集窗口大小(天数)
|
||||||
|
test_window: 测试集窗口大小(天数)
|
||||||
|
gap: 训练集和测试集之间的间隔天数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
|
||||||
|
self.train_window = train_window
|
||||||
|
self.test_window = test_window
|
||||||
|
self.gap = gap
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Iterator[Tuple[List[int], List[int]]]:
|
||||||
|
"""生成训练/测试索引"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
|
||||||
|
start_idx = self.train_window
|
||||||
|
while start_idx + self.gap + self.test_window <= n_dates:
|
||||||
|
train_start = start_idx - self.train_window
|
||||||
|
train_end = start_idx
|
||||||
|
test_start = start_idx + self.gap
|
||||||
|
test_end = test_start + self.test_window
|
||||||
|
|
||||||
|
train_dates = dates[train_start:train_end]
|
||||||
|
test_dates = dates[test_start:test_end]
|
||||||
|
|
||||||
|
train_mask = data[date_col].is_in(train_dates.to_list())
|
||||||
|
test_mask = data[date_col].is_in(test_dates.to_list())
|
||||||
|
|
||||||
|
train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
|
||||||
|
test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
|
||||||
|
|
||||||
|
yield train_idx, test_idx
|
||||||
|
start_idx += self.test_window
|
||||||
|
|
||||||
|
def get_split_dates(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
|
"""获取划分日期范围"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
start_idx = self.train_window
|
||||||
|
while start_idx + self.gap + self.test_window <= n_dates:
|
||||||
|
train_start = start_idx - self.train_window
|
||||||
|
train_end = start_idx
|
||||||
|
test_start = start_idx + self.gap
|
||||||
|
test_end = test_start + self.test_window
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
(
|
||||||
|
str(dates[train_start]),
|
||||||
|
str(dates[train_end - 1]),
|
||||||
|
str(dates[test_start]),
|
||||||
|
str(dates[test_end - 1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
start_idx += self.test_window
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ExpandingWindowSplit(BaseSplitter):
|
||||||
|
"""扩展窗口划分 - 训练集不断扩大
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_train_size: 初始训练集大小(天数)
|
||||||
|
test_window: 测试集窗口大小(天数)
|
||||||
|
gap: 训练集和测试集之间的间隔天数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5
|
||||||
|
):
|
||||||
|
self.initial_train_size = initial_train_size
|
||||||
|
self.test_window = test_window
|
||||||
|
self.gap = gap
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Iterator[Tuple[List[int], List[int]]]:
|
||||||
|
"""生成训练/测试索引"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
|
||||||
|
train_end_idx = self.initial_train_size
|
||||||
|
while train_end_idx + self.gap + self.test_window <= n_dates:
|
||||||
|
train_dates = dates[:train_end_idx]
|
||||||
|
test_start = train_end_idx + self.gap
|
||||||
|
test_end = test_start + self.test_window
|
||||||
|
test_dates = dates[test_start:test_end]
|
||||||
|
|
||||||
|
train_mask = data[date_col].is_in(train_dates.to_list())
|
||||||
|
test_mask = data[date_col].is_in(test_dates.to_list())
|
||||||
|
|
||||||
|
train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
|
||||||
|
test_idx = data.with_row_index().filter(test_mask)["index"].to_list()
|
||||||
|
|
||||||
|
yield train_idx, test_idx
|
||||||
|
train_end_idx += self.test_window
|
||||||
|
|
||||||
|
def get_split_dates(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
|
"""获取划分日期范围"""
|
||||||
|
dates = data[date_col].unique().sort()
|
||||||
|
n_dates = len(dates)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
train_end_idx = self.initial_train_size
|
||||||
|
while train_end_idx + self.gap + self.test_window <= n_dates:
|
||||||
|
test_start = train_end_idx + self.gap
|
||||||
|
test_end = test_start + self.test_window
|
||||||
|
|
||||||
|
result.append(
|
||||||
|
(
|
||||||
|
str(dates[0]),
|
||||||
|
str(dates[train_end_idx - 1]),
|
||||||
|
str(dates[test_start]),
|
||||||
|
str(dates[test_end - 1]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
train_end_idx += self.test_window
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TimeSeriesSplit",
|
||||||
|
"WalkForwardSplit",
|
||||||
|
"ExpandingWindowSplit",
|
||||||
|
]
|
||||||
11
src/models/models/__init__.py
Normal file
11
src/models/models/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""模型模块"""
|
||||||
|
|
||||||
|
from src.models.models.models import (
|
||||||
|
LightGBMModel,
|
||||||
|
CatBoostModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LightGBMModel",
|
||||||
|
"CatBoostModel",
|
||||||
|
]
|
||||||
210
src/models/models/models.py
Normal file
210
src/models/models/models.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""内置机器学习模型
|
||||||
|
|
||||||
|
提供 LightGBM、CatBoost 等模型的统一接口包装器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.models.core import BaseModel, TaskType
|
||||||
|
from src.models.registry import PluginRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_model("lightgbm")
|
||||||
|
class LightGBMModel(BaseModel):
|
||||||
|
"""LightGBM 模型包装器
|
||||||
|
|
||||||
|
支持分类、回归、排序三种任务类型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_type: TaskType,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(task_type, params, name)
|
||||||
|
self._model = None
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X: pl.DataFrame,
|
||||||
|
y: pl.Series,
|
||||||
|
X_val: Optional[pl.DataFrame] = None,
|
||||||
|
y_val: Optional[pl.Series] = None,
|
||||||
|
**fit_params,
|
||||||
|
) -> "LightGBMModel":
|
||||||
|
"""训练模型"""
|
||||||
|
try:
|
||||||
|
import lightgbm as lgb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"lightgbm is required. Install with: uv pip install lightgbm"
|
||||||
|
)
|
||||||
|
|
||||||
|
X_arr = X.to_numpy()
|
||||||
|
y_arr = y.to_numpy()
|
||||||
|
|
||||||
|
train_data = lgb.Dataset(X_arr, label=y_arr)
|
||||||
|
valid_sets = [train_data]
|
||||||
|
valid_names = ["train"]
|
||||||
|
|
||||||
|
if X_val is not None and y_val is not None:
|
||||||
|
valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
|
||||||
|
valid_sets.append(valid_data)
|
||||||
|
valid_names.append("valid")
|
||||||
|
|
||||||
|
default_params = {
|
||||||
|
"objective": self._get_objective(),
|
||||||
|
"metric": self._get_metric(),
|
||||||
|
"boosting_type": "gbdt",
|
||||||
|
"num_leaves": 31,
|
||||||
|
"learning_rate": 0.05,
|
||||||
|
"feature_fraction": 0.9,
|
||||||
|
"bagging_fraction": 0.8,
|
||||||
|
"bagging_freq": 5,
|
||||||
|
"verbose": -1,
|
||||||
|
}
|
||||||
|
default_params.update(self.params)
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
if len(valid_sets) > 1:
|
||||||
|
callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False))
|
||||||
|
|
||||||
|
self._model = lgb.train(
|
||||||
|
default_params,
|
||||||
|
train_data,
|
||||||
|
num_boost_round=fit_params.get("num_boost_round", 100),
|
||||||
|
valid_sets=valid_sets,
|
||||||
|
valid_names=valid_names,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测"""
|
||||||
|
if not self._is_fitted:
|
||||||
|
raise RuntimeError("Model not fitted yet")
|
||||||
|
return self._model.predict(X.to_numpy())
|
||||||
|
|
||||||
|
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测概率(仅分类任务)"""
|
||||||
|
if self.task_type != "classification":
|
||||||
|
raise ValueError("predict_proba only for classification")
|
||||||
|
probs = self.predict(X)
|
||||||
|
if len(probs.shape) == 1:
|
||||||
|
return np.vstack([1 - probs, probs]).T
|
||||||
|
return probs
|
||||||
|
|
||||||
|
def get_feature_importance(self) -> Optional[pl.DataFrame]:
|
||||||
|
"""获取特征重要性"""
|
||||||
|
if self._model is None:
|
||||||
|
return None
|
||||||
|
importance = self._model.feature_importance(importance_type="gain")
|
||||||
|
feature_names = getattr(
|
||||||
|
self._model,
|
||||||
|
"feature_name",
|
||||||
|
lambda: [f"feature_{i}" for i in range(len(importance))],
|
||||||
|
)()
|
||||||
|
return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
|
||||||
|
"importance", descending=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_objective(self) -> str:
|
||||||
|
objectives = {
|
||||||
|
"classification": "binary",
|
||||||
|
"regression": "regression",
|
||||||
|
"ranking": "lambdarank",
|
||||||
|
}
|
||||||
|
return objectives.get(self.task_type, "regression")
|
||||||
|
|
||||||
|
def _get_metric(self) -> str:
|
||||||
|
metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
|
||||||
|
return metrics.get(self.task_type, "rmse")
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_model("catboost")
|
||||||
|
class CatBoostModel(BaseModel):
|
||||||
|
"""CatBoost 模型包装器"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_type: TaskType,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(task_type, params, name)
|
||||||
|
self._model = None
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X: pl.DataFrame,
|
||||||
|
y: pl.Series,
|
||||||
|
X_val: Optional[pl.DataFrame] = None,
|
||||||
|
y_val: Optional[pl.Series] = None,
|
||||||
|
**fit_params,
|
||||||
|
) -> "CatBoostModel":
|
||||||
|
"""训练模型"""
|
||||||
|
try:
|
||||||
|
from catboost import CatBoostClassifier, CatBoostRegressor
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"catboost is required. Install with: uv pip install catboost"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task_type == "classification":
|
||||||
|
model_class = CatBoostClassifier
|
||||||
|
default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
|
||||||
|
elif self.task_type == "regression":
|
||||||
|
model_class = CatBoostRegressor
|
||||||
|
default_params = {"loss_function": "RMSE"}
|
||||||
|
else:
|
||||||
|
model_class = CatBoostRegressor
|
||||||
|
default_params = {"loss_function": "QueryRMSE"}
|
||||||
|
|
||||||
|
default_params.update(self.params)
|
||||||
|
default_params["verbose"] = False
|
||||||
|
|
||||||
|
self._model = model_class(**default_params)
|
||||||
|
|
||||||
|
eval_set = None
|
||||||
|
if X_val is not None and y_val is not None:
|
||||||
|
eval_set = (X_val.to_pandas(), y_val.to_pandas())
|
||||||
|
|
||||||
|
self._model.fit(
|
||||||
|
X.to_pandas(),
|
||||||
|
y.to_pandas(),
|
||||||
|
eval_set=eval_set,
|
||||||
|
early_stopping_rounds=fit_params.get("early_stopping_rounds", 10),
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测"""
|
||||||
|
if not self._is_fitted:
|
||||||
|
raise RuntimeError("Model not fitted yet")
|
||||||
|
return self._model.predict(X.to_pandas())
|
||||||
|
|
||||||
|
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测概率"""
|
||||||
|
if self.task_type != "classification":
|
||||||
|
raise ValueError("predict_proba only for classification")
|
||||||
|
return self._model.predict_proba(X.to_pandas())
|
||||||
|
|
||||||
|
def get_feature_importance(self) -> Optional[pl.DataFrame]:
|
||||||
|
"""获取特征重要性"""
|
||||||
|
if self._model is None:
|
||||||
|
return None
|
||||||
|
return pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature": self._model.feature_names_,
|
||||||
|
"importance": self._model.feature_importances_,
|
||||||
|
}
|
||||||
|
).sort("importance", descending=True)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["LightGBMModel", "CatBoostModel"]
|
||||||
70
src/models/pipeline.py
Normal file
70
src/models/pipeline.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""数据处理流水线
|
||||||
|
|
||||||
|
管理多个处理器的顺序执行,支持阶段感知处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.models.core import BaseProcessor, PipelineStage
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingPipeline:
|
||||||
|
"""数据处理流水线
|
||||||
|
|
||||||
|
按顺序执行多个处理器,自动处理阶段标记。
|
||||||
|
关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, processors: List[BaseProcessor]):
|
||||||
|
"""初始化流水线
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processors: 处理器列表(按执行顺序)
|
||||||
|
"""
|
||||||
|
self.processors = processors
|
||||||
|
self._fitted_processors: Dict[int, BaseProcessor] = {}
|
||||||
|
|
||||||
|
def fit_transform(
|
||||||
|
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""在训练数据上fit所有处理器并transform"""
|
||||||
|
result = data
|
||||||
|
for i, processor in enumerate(self.processors):
|
||||||
|
if processor.stage in [PipelineStage.ALL, stage]:
|
||||||
|
result = processor.fit_transform(result)
|
||||||
|
self._fitted_processors[i] = processor
|
||||||
|
elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
|
||||||
|
processor.fit(result)
|
||||||
|
self._fitted_processors[i] = processor
|
||||||
|
return result
|
||||||
|
|
||||||
|
def transform(
|
||||||
|
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""在测试数据上应用已fit的处理器"""
|
||||||
|
result = data
|
||||||
|
for i, processor in enumerate(self.processors):
|
||||||
|
if processor.stage in [PipelineStage.ALL, stage]:
|
||||||
|
if i in self._fitted_processors:
|
||||||
|
result = self._fitted_processors[i].transform(result)
|
||||||
|
else:
|
||||||
|
result = processor.transform(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def save_processors(self, path: str) -> None:
|
||||||
|
"""保存所有已fit的处理器状态"""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
pickle.dump(self._fitted_processors, f)
|
||||||
|
|
||||||
|
def load_processors(self, path: str) -> None:
|
||||||
|
"""加载处理器状态"""
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
self._fitted_processors = pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ProcessingPipeline"]
|
||||||
21
src/models/processors/__init__.py
Normal file
21
src/models/processors/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""处理器模块"""
|
||||||
|
|
||||||
|
from src.models.processors.processors import (
|
||||||
|
DropNAProcessor,
|
||||||
|
FillNAProcessor,
|
||||||
|
Winsorizer,
|
||||||
|
StandardScaler,
|
||||||
|
MinMaxScaler,
|
||||||
|
RankTransformer,
|
||||||
|
Neutralizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DropNAProcessor",
|
||||||
|
"FillNAProcessor",
|
||||||
|
"Winsorizer",
|
||||||
|
"StandardScaler",
|
||||||
|
"MinMaxScaler",
|
||||||
|
"RankTransformer",
|
||||||
|
"Neutralizer",
|
||||||
|
]
|
||||||
238
src/models/processors/processors.py
Normal file
238
src/models/processors/processors.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""内置数据处理器
|
||||||
|
|
||||||
|
提供常用的数据预处理和转换处理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.models.core import BaseProcessor, PipelineStage
|
||||||
|
from src.models.registry import PluginRegistry
|
||||||
|
|
||||||
|
# 数值类型列表
|
||||||
|
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_numeric_columns(
|
||||||
|
data: pl.DataFrame, columns: Optional[List[str]] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""获取数值列"""
|
||||||
|
if columns is not None:
|
||||||
|
return columns
|
||||||
|
return [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("dropna")
|
||||||
|
class DropNAProcessor(BaseProcessor):
|
||||||
|
"""缺失值删除处理器"""
|
||||||
|
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
cols = self.columns or data.columns
|
||||||
|
return data.drop_nulls(subset=cols)
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("fillna")
|
||||||
|
class FillNAProcessor(BaseProcessor):
|
||||||
|
"""缺失值填充处理器(只在训练阶段计算填充值)"""
|
||||||
|
|
||||||
|
stage = PipelineStage.TRAIN
|
||||||
|
|
||||||
|
def __init__(self, columns: Optional[List[str]] = None, method: str = "median"):
|
||||||
|
super().__init__(columns)
|
||||||
|
if method not in ["median", "mean", "zero"]:
|
||||||
|
raise ValueError(f"Unknown fill method: {method}")
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
|
||||||
|
cols = _get_numeric_columns(data, self.columns)
|
||||||
|
fill_values = {}
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
if self.method == "median":
|
||||||
|
fill_values[col] = data[col].median()
|
||||||
|
elif self.method == "mean":
|
||||||
|
fill_values[col] = data[col].mean()
|
||||||
|
elif self.method == "zero":
|
||||||
|
fill_values[col] = 0.0
|
||||||
|
|
||||||
|
self._fitted_params = {"fill_values": fill_values, "columns": cols}
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
for col, val in self._fitted_params.get("fill_values", {}).items():
|
||||||
|
if col in result.columns:
|
||||||
|
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("winsorizer")
|
||||||
|
class Winsorizer(BaseProcessor):
|
||||||
|
"""缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)"""
|
||||||
|
|
||||||
|
stage = PipelineStage.TRAIN
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
lower: float = 0.01,
|
||||||
|
upper: float = 0.99,
|
||||||
|
):
|
||||||
|
super().__init__(columns)
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "Winsorizer":
|
||||||
|
cols = _get_numeric_columns(data, self.columns)
|
||||||
|
bounds = {}
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
bounds[col] = {
|
||||||
|
"lower": data[col].quantile(self.lower),
|
||||||
|
"upper": data[col].quantile(self.upper),
|
||||||
|
}
|
||||||
|
|
||||||
|
self._fitted_params = {"bounds": bounds, "columns": cols}
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
for col, bounds in self._fitted_params.get("bounds", {}).items():
|
||||||
|
if col in result.columns:
|
||||||
|
result = result.with_columns(
|
||||||
|
pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("standard_scaler")
|
||||||
|
class StandardScaler(BaseProcessor):
|
||||||
|
"""标准化处理器 - Z-score标准化"""
|
||||||
|
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "StandardScaler":
|
||||||
|
cols = _get_numeric_columns(data, self.columns)
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
stats[col] = {"mean": data[col].mean(), "std": data[col].std()}
|
||||||
|
|
||||||
|
self._fitted_params = {"stats": stats, "columns": cols}
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
for col, stats in self._fitted_params.get("stats", {}).items():
|
||||||
|
if col in result.columns and stats["std"] is not None and stats["std"] > 0:
|
||||||
|
result = result.with_columns(
|
||||||
|
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("minmax_scaler")
|
||||||
|
class MinMaxScaler(BaseProcessor):
|
||||||
|
"""归一化处理器 - 缩放到[0, 1]范围"""
|
||||||
|
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "MinMaxScaler":
|
||||||
|
cols = _get_numeric_columns(data, self.columns)
|
||||||
|
stats = {}
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
stats[col] = {"min": data[col].min(), "max": data[col].max()}
|
||||||
|
|
||||||
|
self._fitted_params = {"stats": stats, "columns": cols}
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
for col, stats in self._fitted_params.get("stats", {}).items():
|
||||||
|
if col in result.columns:
|
||||||
|
range_val = stats["max"] - stats["min"]
|
||||||
|
if range_val is not None and range_val > 0:
|
||||||
|
result = result.with_columns(
|
||||||
|
((pl.col(col) - stats["min"]) / range_val).alias(col)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("rank_transformer")
|
||||||
|
class RankTransformer(BaseProcessor):
|
||||||
|
"""排名转换处理器 - 转换为截面排名"""
|
||||||
|
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "RankTransformer":
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
cols = self.columns or _get_numeric_columns(data)
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
if col in result.columns:
|
||||||
|
result = result.with_columns(
|
||||||
|
pl.col(col).rank().over("trade_date").alias(col)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("neutralizer")
|
||||||
|
class Neutralizer(BaseProcessor):
|
||||||
|
"""中性化处理器 - 行业/市值中性化"""
|
||||||
|
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
group_col: str = "industry",
|
||||||
|
exclude_cols: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(columns)
|
||||||
|
self.group_col = group_col
|
||||||
|
self.exclude_cols = exclude_cols or []
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "Neutralizer":
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data
|
||||||
|
cols = self.columns or _get_numeric_columns(data)
|
||||||
|
|
||||||
|
for col in cols:
|
||||||
|
if col in result.columns and col not in self.exclude_cols:
|
||||||
|
result = result.with_columns(
|
||||||
|
(
|
||||||
|
pl.col(col)
|
||||||
|
- pl.col(col).mean().over(["trade_date", self.group_col])
|
||||||
|
).alias(col)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DropNAProcessor",
|
||||||
|
"FillNAProcessor",
|
||||||
|
"Winsorizer",
|
||||||
|
"StandardScaler",
|
||||||
|
"MinMaxScaler",
|
||||||
|
"RankTransformer",
|
||||||
|
"Neutralizer",
|
||||||
|
]
|
||||||
297
src/models/registry.py
Normal file
297
src/models/registry.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
"""插件注册中心
|
||||||
|
|
||||||
|
提供装饰器方式注册处理器、模型、划分策略等组件。
|
||||||
|
实现真正的插件式架构 - 新功能只需注册即可使用。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> @PluginRegistry.register_processor("standard_scaler")
|
||||||
|
... class StandardScaler(BaseProcessor):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
>>> # 使用
|
||||||
|
>>> scaler = PluginRegistry.get_processor("standard_scaler")()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Type, Dict, List, TypeVar, Optional
|
||||||
|
from functools import wraps
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""插件注册中心
|
||||||
|
|
||||||
|
管理所有组件的注册和获取。使用装饰器方式注册新组件。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_processors: 已注册的处理器字典
|
||||||
|
_models: 已注册的模型字典
|
||||||
|
_splitters: 已注册的划分策略字典
|
||||||
|
_metrics: 已注册的评估指标字典
|
||||||
|
"""
|
||||||
|
|
||||||
|
_processors: Dict[str, Type[BaseProcessor]] = {}
|
||||||
|
_models: Dict[str, Type[BaseModel]] = {}
|
||||||
|
_splitters: Dict[str, Type[BaseSplitter]] = {}
|
||||||
|
_metrics: Dict[str, Type[BaseMetric]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def temp_registry(cls):
|
||||||
|
"""临时注册上下文管理器
|
||||||
|
|
||||||
|
在上下文管理器内部注册的组件会在退出时自动清理,
|
||||||
|
避免测试之间的状态污染。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> with PluginRegistry.temp_registry():
|
||||||
|
... @PluginRegistry.register_processor("temp_processor")
|
||||||
|
... class TempProcessor(BaseProcessor):
|
||||||
|
... pass
|
||||||
|
... # 在此处可以使用 temp_processor
|
||||||
|
... # 退出后自动清理
|
||||||
|
"""
|
||||||
|
original_state = {
|
||||||
|
"_processors": cls._processors.copy(),
|
||||||
|
"_models": cls._models.copy(),
|
||||||
|
"_splitters": cls._splitters.copy(),
|
||||||
|
"_metrics": cls._metrics.copy(),
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
yield cls
|
||||||
|
finally:
|
||||||
|
cls._processors = original_state["_processors"]
|
||||||
|
cls._models = original_state["_models"]
|
||||||
|
cls._splitters = original_state["_splitters"]
|
||||||
|
cls._metrics = original_state["_metrics"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_processor(cls, name: Optional[str] = None):
|
||||||
|
"""注册处理器装饰器
|
||||||
|
|
||||||
|
用于装饰器方式注册数据处理器。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> @PluginRegistry.register_processor("standard_scaler")
|
||||||
|
... class StandardScaler(BaseProcessor):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
>>> # 获取并使用
|
||||||
|
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
|
||||||
|
>>> scaler = scaler_class()
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 注册名称,默认为类名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
装饰器函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
|
||||||
|
key = name or processor_class.__name__
|
||||||
|
cls._processors[key] = processor_class
|
||||||
|
processor_class._registry_name = key
|
||||||
|
return processor_class
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_model(cls, name: Optional[str] = None):
|
||||||
|
"""注册模型装饰器
|
||||||
|
|
||||||
|
用于装饰器方式注册机器学习模型。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> @PluginRegistry.register_model("lightgbm")
|
||||||
|
... class LightGBMModel(BaseModel):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 注册名称,默认为类名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
装饰器函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
|
||||||
|
key = name or model_class.__name__
|
||||||
|
cls._models[key] = model_class
|
||||||
|
model_class._registry_name = key
|
||||||
|
return model_class
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_splitter(cls, name: Optional[str] = None):
|
||||||
|
"""注册划分策略装饰器
|
||||||
|
|
||||||
|
用于装饰器方式注册数据划分策略。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> @PluginRegistry.register_splitter("time_series")
|
||||||
|
... class TimeSeriesSplit(BaseSplitter):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 注册名称,默认为类名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
装饰器函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
|
||||||
|
key = name or splitter_class.__name__
|
||||||
|
cls._splitters[key] = splitter_class
|
||||||
|
splitter_class._registry_name = key
|
||||||
|
return splitter_class
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_metric(cls, name: Optional[str] = None):
|
||||||
|
"""注册评估指标装饰器
|
||||||
|
|
||||||
|
用于装饰器方式注册评估指标。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> @PluginRegistry.register_metric("ic")
|
||||||
|
... class ICMetric(BaseMetric):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 注册名称,默认为类名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
装饰器函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
|
||||||
|
key = name or metric_class.__name__
|
||||||
|
cls._metrics[key] = metric_class
|
||||||
|
metric_class._registry_name = key
|
||||||
|
return metric_class
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_processor(cls, name: str) -> Type[BaseProcessor]:
|
||||||
|
"""获取处理器类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 处理器注册名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理器类
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 处理器不存在时抛出
|
||||||
|
"""
|
||||||
|
if name not in cls._processors:
|
||||||
|
available = list(cls._processors.keys())
|
||||||
|
raise KeyError(f"Processor '{name}' not found. Available: {available}")
|
||||||
|
return cls._processors[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model(cls, name: str) -> Type[BaseModel]:
|
||||||
|
"""获取模型类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 模型注册名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型类
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 模型不存在时抛出
|
||||||
|
"""
|
||||||
|
if name not in cls._models:
|
||||||
|
available = list(cls._models.keys())
|
||||||
|
raise KeyError(f"Model '{name}' not found. Available: {available}")
|
||||||
|
return cls._models[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_splitter(cls, name: str) -> Type[BaseSplitter]:
|
||||||
|
"""获取划分策略类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 划分策略注册名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
划分策略类
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 划分策略不存在时抛出
|
||||||
|
"""
|
||||||
|
if name not in cls._splitters:
|
||||||
|
available = list(cls._splitters.keys())
|
||||||
|
raise KeyError(f"Splitter '{name}' not found. Available: {available}")
|
||||||
|
return cls._splitters[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_metric(cls, name: str) -> Type[BaseMetric]:
|
||||||
|
"""获取评估指标类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 评估指标注册名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估指标类
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 评估指标不存在时抛出
|
||||||
|
"""
|
||||||
|
if name not in cls._metrics:
|
||||||
|
available = list(cls._metrics.keys())
|
||||||
|
raise KeyError(f"Metric '{name}' not found. Available: {available}")
|
||||||
|
return cls._metrics[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_processors(cls) -> List[str]:
|
||||||
|
"""列出所有可用处理器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理器名称列表
|
||||||
|
"""
|
||||||
|
return list(cls._processors.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_models(cls) -> List[str]:
|
||||||
|
"""列出所有可用模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型名称列表
|
||||||
|
"""
|
||||||
|
return list(cls._models.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_splitters(cls) -> List[str]:
|
||||||
|
"""列出所有可用划分策略
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
划分策略名称列表
|
||||||
|
"""
|
||||||
|
return list(cls._splitters.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_metrics(cls) -> List[str]:
|
||||||
|
"""列出所有可用评估指标
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估指标名称列表
|
||||||
|
"""
|
||||||
|
return list(cls._metrics.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_all(cls) -> None:
|
||||||
|
"""清除所有注册(主要用于测试)"""
|
||||||
|
cls._processors.clear()
|
||||||
|
cls._models.clear()
|
||||||
|
cls._splitters.clear()
|
||||||
|
cls._metrics.clear()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PluginRegistry"]
|
||||||
478
tests/models/test_core.py
Normal file
478
tests/models/test_core.py
Normal file
@@ -0,0 +1,478 @@
|
|||||||
|
"""模型框架核心测试
|
||||||
|
|
||||||
|
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
# 确保导入时注册所有组件
|
||||||
|
from src.models import (
|
||||||
|
PluginRegistry,
|
||||||
|
PipelineStage,
|
||||||
|
BaseProcessor,
|
||||||
|
BaseModel,
|
||||||
|
BaseSplitter,
|
||||||
|
ProcessingPipeline,
|
||||||
|
)
|
||||||
|
from src.models.core import TaskType
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试核心抽象类 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineStage:
|
||||||
|
"""测试阶段枚举"""
|
||||||
|
|
||||||
|
def test_stage_values(self):
|
||||||
|
assert PipelineStage.ALL.name == "ALL"
|
||||||
|
assert PipelineStage.TRAIN.name == "TRAIN"
|
||||||
|
assert PipelineStage.TEST.name == "TEST"
|
||||||
|
assert PipelineStage.VALIDATION.name == "VALIDATION"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseProcessor:
|
||||||
|
"""测试处理器基类"""
|
||||||
|
|
||||||
|
def test_processor_initialization(self):
|
||||||
|
"""测试处理器初始化"""
|
||||||
|
|
||||||
|
class DummyProcessor(BaseProcessor):
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "DummyProcessor":
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
return data
|
||||||
|
|
||||||
|
processor = DummyProcessor(columns=["col1", "col2"])
|
||||||
|
assert processor.columns == ["col1", "col2"]
|
||||||
|
assert processor.stage == PipelineStage.ALL
|
||||||
|
assert not processor._is_fitted
|
||||||
|
|
||||||
|
def test_processor_fit_transform(self):
|
||||||
|
"""测试 fit_transform 方法"""
|
||||||
|
|
||||||
|
class AddOneProcessor(BaseProcessor):
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
result = data.clone()
|
||||||
|
for col in self.columns or []:
|
||||||
|
result = result.with_columns((pl.col(col) + 1).alias(col))
|
||||||
|
return result
|
||||||
|
|
||||||
|
processor = AddOneProcessor(columns=["value"])
|
||||||
|
df = pl.DataFrame({"value": [1, 2, 3]})
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
assert processor._is_fitted
|
||||||
|
assert result["value"].to_list() == [2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseModel:
|
||||||
|
"""测试模型基类"""
|
||||||
|
|
||||||
|
def test_model_initialization(self):
|
||||||
|
"""测试模型初始化"""
|
||||||
|
|
||||||
|
class DummyModel(BaseModel):
|
||||||
|
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
|
||||||
|
self._is_fitted = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X):
|
||||||
|
return np.zeros(len(X))
|
||||||
|
|
||||||
|
model = DummyModel(
|
||||||
|
task_type="regression", params={"lr": 0.01}, name="test_model"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model.task_type == "regression"
|
||||||
|
assert model.params == {"lr": 0.01}
|
||||||
|
assert model.name == "test_model"
|
||||||
|
assert not model._is_fitted
|
||||||
|
|
||||||
|
def test_predict_proba_not_implemented(self):
|
||||||
|
"""测试未实现 predict_proba 时抛出异常"""
|
||||||
|
|
||||||
|
class DummyModel(BaseModel):
|
||||||
|
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X):
|
||||||
|
return np.zeros(len(X))
|
||||||
|
|
||||||
|
model = DummyModel(task_type="regression")
|
||||||
|
df = pl.DataFrame({"feature": [1, 2, 3]})
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
model.predict_proba(df)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseSplitter:
|
||||||
|
"""测试划分策略基类"""
|
||||||
|
|
||||||
|
def test_splitter_interface(self):
|
||||||
|
"""测试划分策略接口"""
|
||||||
|
|
||||||
|
class DummySplitter(BaseSplitter):
|
||||||
|
def split(self, data, date_col="trade_date"):
|
||||||
|
yield [0, 1], [2, 3]
|
||||||
|
|
||||||
|
def get_split_dates(self, data, date_col="trade_date"):
|
||||||
|
return [("20200101", "20201231", "20210101", "20211231")]
|
||||||
|
|
||||||
|
splitter = DummySplitter()
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = list(splitter.split(df))
|
||||||
|
assert len(splits) == 1
|
||||||
|
assert splits[0] == ([0, 1], [2, 3])
|
||||||
|
|
||||||
|
dates = splitter.get_split_dates(df)
|
||||||
|
assert dates == [("20200101", "20201231", "20210101", "20211231")]
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试插件注册中心 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRegistry:
|
||||||
|
"""测试插件注册中心"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""每个测试前清除注册"""
|
||||||
|
PluginRegistry.clear_all()
|
||||||
|
|
||||||
|
def test_register_and_get_processor(self):
|
||||||
|
"""测试注册和获取处理器"""
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor("test_processor")
|
||||||
|
class TestProcessor(BaseProcessor):
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
processor_class = PluginRegistry.get_processor("test_processor")
|
||||||
|
assert processor_class == TestProcessor
|
||||||
|
assert "test_processor" in PluginRegistry.list_processors()
|
||||||
|
|
||||||
|
def test_register_and_get_model(self):
|
||||||
|
"""测试注册和获取模型"""
|
||||||
|
|
||||||
|
@PluginRegistry.register_model("test_model")
|
||||||
|
class TestModel(BaseModel):
|
||||||
|
def fit(self, X, y, X_val=None, y_val=None, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X):
|
||||||
|
return np.zeros(len(X))
|
||||||
|
|
||||||
|
model_class = PluginRegistry.get_model("test_model")
|
||||||
|
assert model_class == TestModel
|
||||||
|
assert "test_model" in PluginRegistry.list_models()
|
||||||
|
|
||||||
|
def test_register_and_get_splitter(self):
|
||||||
|
"""测试注册和获取划分策略"""
|
||||||
|
|
||||||
|
@PluginRegistry.register_splitter("test_splitter")
|
||||||
|
class TestSplitter(BaseSplitter):
|
||||||
|
def split(self, data, date_col="trade_date"):
|
||||||
|
yield [], []
|
||||||
|
|
||||||
|
def get_split_dates(self, data, date_col="trade_date"):
|
||||||
|
return []
|
||||||
|
|
||||||
|
splitter_class = PluginRegistry.get_splitter("test_splitter")
|
||||||
|
assert splitter_class == TestSplitter
|
||||||
|
assert "test_splitter" in PluginRegistry.list_splitters()
|
||||||
|
|
||||||
|
def test_get_nonexistent_processor(self):
|
||||||
|
"""测试获取不存在的处理器时抛出异常"""
|
||||||
|
with pytest.raises(KeyError) as exc_info:
|
||||||
|
PluginRegistry.get_processor("nonexistent")
|
||||||
|
assert "nonexistent" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_register_with_default_name(self):
|
||||||
|
"""测试使用默认名称注册"""
|
||||||
|
|
||||||
|
@PluginRegistry.register_processor()
|
||||||
|
class MyCustomProcessor(BaseProcessor):
|
||||||
|
stage = PipelineStage.ALL
|
||||||
|
|
||||||
|
def fit(self, data):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
assert "MyCustomProcessor" in PluginRegistry.list_processors()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试内置处理器 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuiltInProcessors:
|
||||||
|
"""测试内置处理器"""
|
||||||
|
|
||||||
|
def test_dropna_processor(self):
|
||||||
|
"""测试缺失值删除处理器"""
|
||||||
|
from src.models.processors import DropNAProcessor
|
||||||
|
|
||||||
|
processor = DropNAProcessor(columns=["a", "b"])
|
||||||
|
df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# 只有第一行没有缺失值
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result["a"].to_list() == [1]
|
||||||
|
assert result["b"].to_list() == [4]
|
||||||
|
|
||||||
|
def test_fillna_processor(self):
|
||||||
|
"""测试缺失值填充处理器"""
|
||||||
|
from src.models.processors import FillNAProcessor
|
||||||
|
|
||||||
|
processor = FillNAProcessor(columns=["a"], method="mean")
|
||||||
|
df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# 均值 = (1+2+4)/3 = 2.333...
|
||||||
|
assert result["a"][2] == pytest.approx(2.333, rel=0.01)
|
||||||
|
|
||||||
|
def test_standard_scaler(self):
|
||||||
|
"""测试标准化处理器"""
|
||||||
|
from src.models.processors import StandardScaler
|
||||||
|
|
||||||
|
processor = StandardScaler(columns=["value"])
|
||||||
|
df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# Z-score 标准化后均值为0,标准差为1
|
||||||
|
assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
|
||||||
|
assert result["value"].std() == pytest.approx(1.0, rel=0.01)
|
||||||
|
|
||||||
|
def test_winsorizer(self):
|
||||||
|
"""测试缩尾处理器"""
|
||||||
|
from src.models.processors import Winsorizer
|
||||||
|
|
||||||
|
processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"value": list(range(100)) # 0-99
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# 10%和90%分位数应该是10和89(Polars的quantile行为)
|
||||||
|
assert result["value"].min() == 10
|
||||||
|
assert result["value"].max() == 89
|
||||||
|
|
||||||
|
def test_rank_transformer(self):
|
||||||
|
"""测试排名转换处理器"""
|
||||||
|
from src.models.processors import RankTransformer
|
||||||
|
|
||||||
|
processor = RankTransformer(columns=["value"])
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# 排名应该是 1, 3, 2, 5, 4
|
||||||
|
assert result["value"].to_list() == [1, 3, 2, 5, 4]
|
||||||
|
|
||||||
|
def test_neutralizer(self):
|
||||||
|
"""测试中性化处理器"""
|
||||||
|
from src.models.processors import Neutralizer
|
||||||
|
|
||||||
|
processor = Neutralizer(columns=["value"], group_col="industry")
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"trade_date": ["20200101", "20200101", "20200101", "20200101"],
|
||||||
|
"industry": ["A", "A", "B", "B"],
|
||||||
|
"value": [10, 20, 30, 50],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = processor.fit_transform(df)
|
||||||
|
|
||||||
|
# 分组去均值后,每组的均值为0
|
||||||
|
group_a = result.filter(pl.col("industry") == "A")
|
||||||
|
group_b = result.filter(pl.col("industry") == "B")
|
||||||
|
|
||||||
|
assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
|
||||||
|
assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试处理流水线 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessingPipeline:
|
||||||
|
"""测试处理流水线"""
|
||||||
|
|
||||||
|
def test_pipeline_fit_transform(self):
|
||||||
|
"""测试流水线的 fit_transform"""
|
||||||
|
from src.models.processors import StandardScaler
|
||||||
|
|
||||||
|
scaler1 = StandardScaler(columns=["a"])
|
||||||
|
scaler2 = StandardScaler(columns=["b"])
|
||||||
|
|
||||||
|
pipeline = ProcessingPipeline([scaler1, scaler2])
|
||||||
|
|
||||||
|
df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})
|
||||||
|
|
||||||
|
result = pipeline.fit_transform(df)
|
||||||
|
|
||||||
|
# 两个列都应该被标准化
|
||||||
|
assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
|
||||||
|
assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)
|
||||||
|
|
||||||
|
def test_pipeline_transform_uses_fitted_params(self):
|
||||||
|
"""测试 transform 使用已 fit 的参数"""
|
||||||
|
from src.models.processors import StandardScaler
|
||||||
|
|
||||||
|
scaler = StandardScaler(columns=["value"])
|
||||||
|
pipeline = ProcessingPipeline([scaler])
|
||||||
|
|
||||||
|
# 训练数据
|
||||||
|
train_df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"value": [1.0, 2.0, 3.0] # 均值=2,标准差=1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试数据(不同的分布)
|
||||||
|
test_df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"value": [4.0, 5.0, 6.0] # 如果重新计算应该是均值=5
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline.fit_transform(train_df)
|
||||||
|
result = pipeline.transform(test_df)
|
||||||
|
|
||||||
|
# 使用训练数据的均值=2和标准差=1进行标准化
|
||||||
|
# 4 -> (4-2)/1 = 2
|
||||||
|
assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试划分策略 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitters:
|
||||||
|
"""测试划分策略"""
|
||||||
|
|
||||||
|
def test_time_series_split(self):
|
||||||
|
"""测试时间序列划分"""
|
||||||
|
from src.models.core import TimeSeriesSplit
|
||||||
|
|
||||||
|
splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
|
||||||
|
|
||||||
|
# 10天的数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"trade_date": [f"202001{i:02d}" for i in range(1, 11)],
|
||||||
|
"value": list(range(10)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = list(splitter.split(df))
|
||||||
|
|
||||||
|
# 应该有两折
|
||||||
|
assert len(splits) == 2
|
||||||
|
|
||||||
|
# 检查每折训练集在测试集之前
|
||||||
|
for train_idx, test_idx in splits:
|
||||||
|
assert max(train_idx) < min(test_idx)
|
||||||
|
|
||||||
|
def test_walk_forward_split(self):
|
||||||
|
"""测试滚动前向划分"""
|
||||||
|
from src.models.core import WalkForwardSplit
|
||||||
|
|
||||||
|
splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
|
||||||
|
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"trade_date": [f"202001{i:02d}" for i in range(1, 13)],
|
||||||
|
"value": list(range(12)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = list(splitter.split(df))
|
||||||
|
|
||||||
|
# 检查训练集大小固定
|
||||||
|
for train_idx, test_idx in splits:
|
||||||
|
assert len(train_idx) == 5
|
||||||
|
assert len(test_idx) == 2
|
||||||
|
|
||||||
|
def test_expanding_window_split(self):
|
||||||
|
"""测试扩展窗口划分"""
|
||||||
|
from src.models.core import ExpandingWindowSplit
|
||||||
|
|
||||||
|
splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
|
||||||
|
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"trade_date": [f"202001{i:02d}" for i in range(1, 15)],
|
||||||
|
"value": list(range(14)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = list(splitter.split(df))
|
||||||
|
|
||||||
|
# 训练集应该逐渐增大
|
||||||
|
train_sizes = [len(train_idx) for train_idx, _ in splits]
|
||||||
|
assert train_sizes[0] == 3
|
||||||
|
assert train_sizes[1] == 5 # 3 + 2
|
||||||
|
assert train_sizes[2] == 7 # 5 + 2
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 测试内置模型(可选,需要安装依赖) ==========
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels:
|
||||||
|
"""测试内置模型(标记为跳过如果依赖未安装)"""
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="需要安装 lightgbm")
|
||||||
|
def test_lightgbm_model(self):
|
||||||
|
"""测试 LightGBM 模型"""
|
||||||
|
from src.models.models import LightGBMModel
|
||||||
|
|
||||||
|
model = LightGBMModel(task_type="regression", params={"n_estimators": 10})
|
||||||
|
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
|
||||||
|
"feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)
|
||||||
|
|
||||||
|
model.fit(X, y)
|
||||||
|
predictions = model.predict(X)
|
||||||
|
|
||||||
|
assert len(predictions) == len(X)
|
||||||
|
assert model._is_fitted
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user