feat(models): implement machine learning model training framework

- Add core abstractions: Processor, Model, Splitter, and Metric base classes
- Implement stage awareness (TRAIN/TEST/ALL) to prevent data leakage
- Ship 8 built-in data processors and 3 time-series split strategies
- Support LightGBM and CatBoost models
- Decorator-based registration via PluginRegistry; plugin architecture
- 22 unit tests
README.md
@@ -1,40 +1,300 @@
# ProStock

A-share quantitative investment framework - a complete solution from data acquisition to model training
## Features

### 1. Data Layer (src/data/)
- **Multi-source data access**: Tushare API integration; supports daily bars, stock basics, and the trading calendar
- **DuckDB storage**: high-performance embedded database with SQL query pushdown
- **Smart sync**: incremental/full sync strategies with automatic detection of what needs updating
- **Rate control**: token-bucket algorithm for API throttling
- **Concurrency**: multi-threaded data fetching via ThreadPoolExecutor

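The token-bucket throttling mentioned above can be sketched as follows. This is a minimal illustration only, not the project's actual `rate_limiter.py`; the class and parameter names here are hypothetical:

```python
import threading
import time

class TokenBucket:
    """Minimal token-bucket rate limiter sketch.

    `rate` tokens are added per second, up to `capacity`.
    acquire() blocks until a token is available.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at capacity.
                self.tokens = min(self.capacity,
                                  self.tokens + (now - self.last) * self.rate)
                self.last = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
            time.sleep(1.0 / self.rate)

# Allow roughly 100 calls per minute, with small bursts.
limiter = TokenBucket(rate=100 / 60, capacity=10)
limiter.acquire()  # returns immediately while the bucket is full
```

The burst `capacity` is what distinguishes a token bucket from a fixed-interval sleep: short bursts go through immediately, while the long-run rate stays bounded.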
### 2. Factor Layer (src/factors/)
- **Type safety**: strict separation of cross-sectional vs time-series factors
- **Leakage prevention**: framework-level guards against future data and cross-stock data leakage
- **Factor composition**: supports factor add/subtract/multiply/divide and scalar arithmetic
- **High-performance computation**: Polars vectorized operations with zero-copy data export
- **Easy extension**: base-class abstractions for custom factors

### 3. Model Layer (src/models/)
- **Plugin architecture**: decorator-based registration; new models are plug-and-play
- **Stage awareness**: train/test stage separation to prevent data leakage
- **Multi-model support**: unified interface for LightGBM, CatBoost, and more
- **Data processing**: missing-value handling, winsorization, standardization, neutralization, etc.
- **Time-series splitting**: WalkForward, ExpandingWindow, and other split strategies

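The walk-forward idea behind these splitters can be sketched independently of the framework. This is an illustrative stand-alone version; the real `WalkForwardSplit` lives in src/models/core/splitter.py and may differ in details:

```python
from typing import Iterator, Tuple
import numpy as np

def walk_forward_indices(
    n_samples: int, train_window: int, test_window: int
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (train_idx, test_idx) pairs that roll forward in time.

    Each test block starts right after its training block, so no test
    sample ever precedes a training sample (no lookahead).
    """
    start = 0
    while start + train_window + test_window <= n_samples:
        train_idx = np.arange(start, start + train_window)
        test_idx = np.arange(start + train_window,
                             start + train_window + test_window)
        yield train_idx, test_idx
        start += test_window  # slide forward by one test window

splits = list(walk_forward_indices(n_samples=10, train_window=4, test_window=2))
# Every split keeps the training block strictly before the test block.
```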
## Project Structure

```
ProStock/
├── src/
│   ├── config/                  # Configuration management
│   │   ├── settings.py          # pydantic-settings configuration
│   │   └── __init__.py
│   │
│   ├── data/                    # Data acquisition and storage
│   │   ├── api_wrappers/        # Tushare API wrappers
│   │   │   ├── api_daily.py     # Daily bar endpoint
│   │   │   ├── api_stock_basic.py  # Stock basic info
│   │   │   └── api_trade_cal.py # Trading calendar
│   │   ├── client.py            # Tushare client (with rate limiting)
│   │   ├── config.py            # Data module configuration
│   │   ├── db_manager.py        # DuckDB table management and sync
│   │   ├── db_inspector.py      # Database inspection utilities
│   │   ├── rate_limiter.py      # Token-bucket rate limiter
│   │   ├── storage.py           # DuckDB storage core
│   │   ├── sync.py              # Data sync entry points
│   │   └── __init__.py
│   │
│   ├── factors/                 # Factor computation framework
│   │   ├── base.py              # Factor base classes (cross-sectional / time-series)
│   │   ├── composite.py         # Composite factors and scalar arithmetic
│   │   ├── data_loader.py       # DuckDB data loader
│   │   ├── data_spec.py         # Data specification definitions
│   │   ├── engine.py            # Factor execution engine
│   │   └── __init__.py
│   │
│   ├── models/                  # Model training framework
│   │   ├── core/                # Core abstractions
│   │   │   ├── base.py          # Processor/model/splitter base classes
│   │   │   └── splitter.py      # Time-series split strategies
│   │   ├── models/              # Model implementations
│   │   │   └── models.py        # LightGBM, CatBoost
│   │   ├── processors/          # Data processors
│   │   │   └── processors.py    # Standardization, winsorization, neutralization, etc.
│   │   ├── pipeline.py          # Processing pipeline
│   │   ├── registry.py          # Plugin registry
│   │   └── __init__.py
│   │
│   └── __init__.py
│
├── docs/                        # Documentation
│   ├── factor_framework_design.md   # Factor framework design
│   ├── ml_framework_design.md       # Model framework design
│   ├── db_sync_guide.md             # Data sync guide
│   └── ...
│
├── data/                        # Data storage (DuckDB)
│   ├── prostock.db              # Main database file
│   └── stock_basic.csv          # Stock basics cache
│
├── config/                      # Configuration files
│   └── .env.local               # Environment variables (API token, etc.)
│
└── tests/                       # Tests
    ├── test_sync.py
    └── factors/
```

## Quick Start

### 1. Install Dependencies

**⚠️ This project mandates uv as the Python package manager.**

```bash
# Install uv (if not already installed)
pip install uv

# Install project dependencies
uv pip install -e .
```

### 2. Configure Environment Variables

Create a `config/.env.local` file:

```bash
TUSHARE_TOKEN=your_tushare_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```

### 3. Sync Data

```bash
# First run - full sync (starting from 20180101)
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"

# Daily use - incremental sync (resumes from the latest stored date)
uv run python -c "from src.data.sync import sync_all; sync_all()"

# Preview the sync (check how much data needs syncing)
uv run python -c "from src.data.sync import preview_sync; preview_sync()"

# Custom thread count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```

### 4. Inspect the Database

```bash
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
```

## Usage Examples

### Factor Computation

```python
from src.factors import FactorEngine, DataLoader, DataSpec
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
import polars as pl

# Custom cross-sectional factor: PE rank
class PERankFactor(CrossSectionalFactor):
    name = "pe_rank"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]

    def compute(self, data) -> pl.Series:
        cs = data.get_cross_section()
        return cs["pe"].rank()

# Custom time-series factor: 20-day moving average
class MA20Factor(TimeSeriesFactor):
    name = "ma20"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]

    def compute(self, data) -> pl.Series:
        return data.get_column("close").rolling_mean(window_size=20)

# Run the computations
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)

# Cross-sectional factor
pe_rank = PERankFactor()
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")

# Time-series factor
ma20 = MA20Factor()
result2 = engine.compute(ma20, stock_codes=["000001.SZ"],
                         start_date="20240101", end_date="20240131")

# Factor composition (operands must share the same factor type;
# combining a cross-sectional with a time-series factor raises ValueError)
pe_rank_b = PERankFactor()
combined = 0.5 * pe_rank + 0.3 * pe_rank_b
```

### Model Training

```python
from src.models import PluginRegistry, ProcessingPipeline
from src.models.core import PipelineStage
import polars as pl

# Build a processing pipeline
pipeline = ProcessingPipeline([
    PluginRegistry.get_processor("dropna")(),
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    PluginRegistry.get_processor("standard_scaler")(),
])

# Prepare the data
data = pl.read_csv("features.csv")  # features plus the label column

# Split into train/test sets
from src.models.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=252, test_window=21)

# Fetch the LightGBM model
ModelClass = PluginRegistry.get_model("lightgbm")
model = ModelClass(task_type="regression", params={"n_estimators": 100})

# Training loop
for train_idx, test_idx in splitter.split(data):
    train_data = data[train_idx]
    test_data = data[test_idx]

    # Fit processors on the training split only, then apply to test
    X_train = pipeline.fit_transform(train_data.drop("target"))
    X_test = pipeline.transform(test_data.drop("target"))
    y_train = train_data["target"]
    y_test = test_data["target"]

    # Train and predict
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
```

## Core Design

### 1. Data Leakage Prevention

**Cross-sectional factors (CrossSectionalFactor)**:
- Prevents date leakage: each day receives only data in `[T-lookback+1, T]`
- Allows cross-stock comparison: all stocks for the current day are passed in
- Typical uses: PE rank, market-cap quantile, daily return rank

**Time-series factors (TimeSeriesFactor)**:
- Prevents stock leakage: each stock is computed independently
- Allows historical access: the full time series for that stock is passed in
- Typical uses: moving averages, RSI, historical volatility

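The date-leakage guard above boils down to selecting which dates a cross-sectional factor may see for a given computation day. A dependency-free sketch (the real slicing lives in FactorEngine and works on DataFrames):

```python
from typing import List

def visible_window(trade_dates: List[str], current: str, lookback_days: int) -> List[str]:
    """Dates a cross-sectional factor may see when computing for `current`:
    exactly [T-lookback+1, T], and never anything after T."""
    pos = trade_dates.index(current)
    return trade_dates[max(0, pos - lookback_days + 1): pos + 1]

dates = ["20240101", "20240102", "20240103", "20240104"]
visible_window(dates, "20240103", lookback_days=2)
# -> ["20240102", "20240103"]; "20240104" (the future) is never exposed
```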
### 2. Plugin Registration

```python
from src.models.registry import PluginRegistry
from src.models.core import BaseProcessor, PipelineStage

# Register a custom processor
@PluginRegistry.register_processor("my_processor")
class MyProcessor(BaseProcessor):
    stage = PipelineStage.TRAIN

    def fit(self, data):
        # Learn parameters from the training data
        return self

    def transform(self, data):
        # Apply the learned transformation
        return data

# Use it
processor_class = PluginRegistry.get_processor("my_processor")
processor = processor_class()
```

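A minimal sketch of how such a decorator-based registry can work. This is illustrative only; the real `PluginRegistry` in src/models/registry.py may differ, and the `Registry`/`Identity` names here are hypothetical:

```python
from typing import Callable, Dict

class Registry:
    """Toy decorator-based plugin registry."""

    _processors: Dict[str, type] = {}

    @classmethod
    def register_processor(cls, name: str) -> Callable[[type], type]:
        def decorator(processor_cls: type) -> type:
            if name in cls._processors:
                raise ValueError(f"processor {name!r} already registered")
            cls._processors[name] = processor_cls
            return processor_cls  # the class itself is returned unchanged
        return decorator

    @classmethod
    def get_processor(cls, name: str) -> type:
        try:
            return cls._processors[name]
        except KeyError:
            raise KeyError(f"unknown processor: {name!r}") from None

@Registry.register_processor("identity")
class Identity:
    def transform(self, data):
        return data

proc = Registry.get_processor("identity")()
```

Because the decorator returns the class unchanged, registration is a pure side effect: importing the module is enough to make a plugin discoverable, which is why src/models/__init__.py imports all built-in processors and models.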
### 3. Data Sync Strategy

**Smart incremental sync**:
```python
from src.data.db_manager import SyncManager

manager = SyncManager()
result = manager.sync(
    table_name="daily",
    fetch_func=get_daily,  # daily-bar fetch function from src/data/api_wrappers/
    start_date="20240101",
    end_date="20240131"
)
# Auto-detects: table missing -> full sync; table exists -> incremental
```

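The full-vs-incremental decision sketched above can be expressed as a tiny pure function. This is an assumption-laden illustration; `plan_sync` and `next_day` are hypothetical names and SyncManager's real logic may differ:

```python
from datetime import datetime, timedelta
from typing import Optional, Tuple

def next_day(yyyymmdd: str) -> str:
    """Return the following calendar day in YYYYMMDD format."""
    return (datetime.strptime(yyyymmdd, "%Y%m%d") + timedelta(days=1)).strftime("%Y%m%d")

def plan_sync(existing_max_date: Optional[str], requested_start: str) -> Tuple[str, str]:
    """Decide the sync mode: full when the table holds no data yet,
    otherwise incremental starting the day after the latest stored date."""
    if existing_max_date is None:
        return ("full", requested_start)
    return ("incremental", next_day(existing_max_date))

plan_sync(None, "20180101")        # -> ("full", "20180101")
plan_sync("20240131", "20180101")  # -> ("incremental", "20240201")
```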
## Documentation

- [Data Sync Module](docs/data_sync.md) - detailed data sync usage
- [Factor Framework Design](docs/factor_framework_design.md) - factor computation architecture
- [Model Framework Design](docs/ml_framework_design.md) - model training architecture
- [Data Sync Guide](docs/db_sync_guide.md) - DuckDB sync API reference
- [Code Review Report](docs/code_review_factors_20260222.md) - factor framework code review

## Development Standards

- **Python version**: 3.10+
- **Code style**: Google-style docstrings
- **Type hints**: mandatory type annotations
- **Testing**: pytest
- **Package management**: uv (direct pip/python usage is forbidden)

## Tech Stack

- **Data processing**: Polars, Pandas, NumPy
- **Data storage**: DuckDB (embedded OLAP database)
- **API**: Tushare Pro
- **Machine learning**: LightGBM, CatBoost, scikit-learn
- **Configuration**: pydantic-settings

## License

MIT License

@@ -1,846 +0,0 @@

# ProStock Factor Framework Implementation Plan

## Directory Layout

```
src/factors/
├── __init__.py        # Exports the main classes
├── data_spec.py       # Phase 1: data type definitions
├── base.py            # Phase 2: factor base classes
├── composite.py       # Phase 2: composite factors
├── data_loader.py     # Phase 3: data loading
├── engine.py          # Phase 4: execution engine
└── builtin/           # Phase 5: built-in factor library
    ├── __init__.py
    ├── momentum.py    # Cross-sectional momentum factors
    ├── technical.py   # Time-series technical indicators
    └── value.py       # Cross-sectional valuation factors

tests/factors/         # Phase 6-7: tests
├── __init__.py
├── test_data_spec.py      # Data type tests
├── test_base.py           # Factor base class tests
├── test_composite.py      # Composite factor tests
├── test_data_loader.py    # Data loader tests
├── test_engine.py         # Engine tests
├── test_builtin.py        # Built-in factor tests
└── test_integration.py    # Integration tests
```

---

## Phase 1: Data Type Definitions (data_spec.py)

### 1.1 DataSpec - Data Requirement Specification

**Implementation requirements:**
```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class DataSpec:
    """
    Data requirement specification.

    Args:
        source: H5 file name (e.g. "daily", "fundamental")
        columns: required column names; must include "ts_code" and "trade_date"
        lookback_days: number of days to look back, inclusive of the current day
            - 1 means only the current day [T]
            - 5 means [T-4, T], i.e. 5 days
            - 20 means [T-19, T], i.e. 20 days
    """
    source: str
    columns: List[str]
    lookback_days: int = 1
```

**Validation constraints:**
- `lookback_days >= 1` (at least the current day)
- `columns` must include `ts_code` and `trade_date`
- `source` must not be an empty string

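The constraints above map naturally onto a `__post_init__` check, which still runs on frozen dataclasses. A minimal sketch; the actual validation in data_spec.py may be structured differently:

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class DataSpec:
    source: str
    columns: List[str]
    lookback_days: int = 1

    def __post_init__(self) -> None:
        # frozen=True forbids later mutation, but validation here is fine.
        if not self.source:
            raise ValueError("source must be a non-empty string")
        if self.lookback_days < 1:
            raise ValueError("lookback_days must be >= 1")
        missing = {"ts_code", "trade_date"} - set(self.columns)
        if missing:
            raise ValueError(f"columns must include {sorted(missing)}")

spec = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)
```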
**Test requirements:**
- [ ] Valid DataSpec creation succeeds
- [ ] `lookback_days < 1` raises ValueError
- [ ] Missing `ts_code` or `trade_date` raises ValueError
- [ ] Empty `source` raises ValueError
- [ ] Frozen behavior (immutable after creation)

---

### 1.2 FactorContext - Computation Context

**Implementation requirements:**
```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FactorContext:
    """
    Factor computation context.

    Injected automatically by FactorEngine; factor authors can access it
    via data.context.

    Attributes:
        current_date: current computation date, YYYYMMDD (cross-sectional factors)
        current_stock: current stock code (time-series factors)
        trade_dates: trading calendar list (optional, for alignment)
    """
    current_date: Optional[str] = None
    current_stock: Optional[str] = None
    trade_dates: Optional[List[str]] = None
```

**Test requirements:**
- [ ] Creation with default values
- [ ] Creation with all fields
- [ ] Dataclass-generated methods behave as expected

---

### 1.3 FactorData - Data Container

**Implementation requirements:**
```python
class FactorData:
    """
    Data container handed to factors.

    Wraps the underlying Polars DataFrame and exposes a safe access interface.
    """

    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context

    def get_column(self, col: str) -> pl.Series:
        """
        Return the data for the given column.

        - Cross-sectional factors: that column for all stocks on the current day
        - Time-series factors: that column over the stock's time series

        Args:
            col: column name

        Returns:
            Polars Series

        Raises:
            KeyError: column does not exist
        """
        pass

    def filter_by_date(self, date: str) -> "FactorData":
        """
        Filter by date and return a new FactorData.

        Mainly used by cross-sectional factors to fetch a specific date.

        Args:
            date: date in YYYYMMDD format

        Returns:
            Filtered FactorData
        """
        pass

    def get_cross_section(self) -> pl.DataFrame:
        """
        Return the cross-section for the current date.

        Cross-sectional factors only; returns all stocks on current_date.

        Returns:
            DataFrame containing all stocks for the current date

        Raises:
            ValueError: current_date is unset (non-cross-sectional context)
        """
        pass

    def to_polars(self) -> pl.DataFrame:
        """Return the underlying Polars DataFrame (advanced use)."""
        pass

    @property
    def context(self) -> FactorContext:
        """Return the computation context."""
        pass

    def __len__(self) -> int:
        """Return the number of rows."""
        pass
```

**Test requirements:**
- [ ] `get_column()` returns the correct Series
- [ ] `get_column()` raises KeyError for a missing column
- [ ] `filter_by_date()` returns the correct subset
- [ ] `filter_by_date()` returns an empty DataFrame for a missing date
- [ ] `get_cross_section()` returns the current_date rows
- [ ] `get_cross_section()` raises ValueError when current_date is None
- [ ] `to_polars()` returns the original DataFrame
- [ ] `context` property returns the right context
- [ ] `__len__()` returns the row count

---

## Phase 2: Factor Base Classes (base.py, composite.py)

### 2.1 BaseFactor - Abstract Base Class

**Implementation requirements:**
```python
class BaseFactor(ABC):
    """
    Factor base class - defines the common interface.

    Every factor must inherit from this class and declare:
    - name: unique factor identifier (snake_case)
    - factor_type: "cross_sectional" or "time_series"
    - data_specs: List[DataSpec] of data requirements

    Optional declarations:
    - category: factor category (default "default")
    - description: factor description
    """

    # Required class attributes
    name: str = ""
    factor_type: str = ""  # "cross_sectional" | "time_series"
    data_specs: List[DataSpec] = []

    # Optional class attributes
    category: str = "default"
    description: str = ""

    def __init_subclass__(cls, **kwargs):
        """
        Validate required attributes when a subclass is created.

        Checks:
        1. name is a non-empty string
        2. factor_type is "cross_sectional" or "time_series"
        3. data_specs is a non-empty list
        """
        pass

    def __init__(self, **params):
        """
        Initialize factor parameters.

        Subclasses can accept parameterized configuration via __init__,
        e.g. MA(period=20).
        """
        self.params = params
        self._validate_params()

    def _validate_params(self):
        """
        Validate parameter values.

        Subclasses may override this for custom validation.
        """
        pass

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        Core computation - subclasses must implement.

        Args:
            data: safe data container, already trimmed for the factor type

        Returns:
            Series of computed factor values
        """
        pass

    # ========== Factor composition operators ==========

    def __add__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor addition: f1 + f2 (same type required)."""
        pass

    def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor subtraction: f1 - f2 (same type required)."""
        pass

    def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor multiplication: f1 * f2 (same type required)."""
        pass

    def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
        """Factor division: f1 / f2 (same type required)."""
        pass

    def __rmul__(self, scalar: float) -> "ScalarFactor":
        """Scalar multiplication: 0.5 * f1."""
        pass
```

**Test requirements:**
- [ ] Valid subclass passes validation
- [ ] Missing `name` raises ValueError
- [ ] Empty-string `name` raises ValueError
- [ ] Missing `factor_type` raises ValueError
- [ ] Invalid `factor_type` (neither cs nor ts) raises ValueError
- [ ] Missing `data_specs` raises ValueError
- [ ] Empty `data_specs` list raises ValueError
- [ ] Abstract `compute()` forces subclass implementation
- [ ] Parameterized init stores `params` correctly
- [ ] `_validate_params()` is invoked

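The `__init_subclass__` validation described above can be sketched in isolation. This is a hypothetical stand-alone illustration (the `PluginBase`/`GoodFactor` names are made up), not the BaseFactor implementation itself:

```python
from abc import ABC

class PluginBase(ABC):
    """Illustrative __init_subclass__ validation, mirroring the checks
    BaseFactor is specified to perform."""

    name: str = ""
    factor_type: str = ""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs at class-creation time, so misconfigured factors fail at
        # import time rather than when first instantiated.
        if not isinstance(cls.name, str) or not cls.name:
            raise ValueError(f"{cls.__name__}: 'name' must be a non-empty string")
        if cls.factor_type not in ("cross_sectional", "time_series"):
            raise ValueError(f"{cls.__name__}: invalid factor_type {cls.factor_type!r}")

class GoodFactor(PluginBase):
    name = "good"
    factor_type = "time_series"

# Declaring a subclass with an empty name raises ValueError immediately.
```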
---

### 2.2 CrossSectionalFactor - Date Cross-Section Factor

**Implementation requirements:**
```python
class CrossSectionalFactor(BaseFactor):
    """
    Base class for date cross-section factors.

    Computation model: on each trading day, compute horizontally across
    all stocks.

    Leakage boundary:
    - ❌ No access to future dates (date leakage)
    - ✅ Access to all stocks on the current date

    Data passed in:
    - compute() receives data for [T-lookback+1, T]
    - includes lookback_days of history (for time-series steps before the
      cross-sectional step)
    """

    factor_type: str = "cross_sectional"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        Compute cross-sectional factor values.

        Args:
            data: FactorData with the [T-lookback+1, T] cross-section
                  layout: DataFrame[ts_code, trade_date, col1, col2, ...]

        Returns:
            pl.Series: factor values for all stocks on the current date
                       (length = number of stocks that day)

        Example:
            def compute(self, data):
                # Get the current date's cross-section
                cs = data.get_cross_section()
                # Rank by market cap
                return cs['market_cap'].rank()
        """
        pass
```

**Test requirements:**
- [ ] `factor_type` is automatically "cross_sectional"
- [ ] Subclasses must implement `compute()`
- [ ] `compute()` returns a pl.Series

---

### 2.3 TimeSeriesFactor - Time-Series Factor

**Implementation requirements:**
```python
class TimeSeriesFactor(BaseFactor):
    """
    Base class for time-series factors (per stock).

    Computation model: for each stock, compute vertically along its
    time series.

    Leakage boundary:
    - ❌ No access to other stocks' data (stock leakage)
    - ✅ Access to the stock's full history

    Data passed in:
    - compute() receives a single stock's complete time series
    - covers everything for that stock within [start_date, end_date]
    """

    factor_type: str = "time_series"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        Compute time-series factor values.

        Args:
            data: FactorData with one stock's full time series
                  layout: DataFrame[ts_code, trade_date, col1, col2, ...]

        Returns:
            pl.Series: factor values per date for the stock
                       (length = number of dates)

        Example:
            def compute(self, data):
                series = data.get_column("close")
                return series.rolling_mean(window_size=self.params['period'])
        """
        pass
```

**Test requirements:**
- [ ] `factor_type` is automatically "time_series"
- [ ] Subclasses must implement `compute()`
- [ ] `compute()` returns a pl.Series

---

### 2.4 CompositeFactor - Composite Factor (composite.py)

**Implementation requirements:**
```python
class CompositeFactor(BaseFactor):
    """
    Composite factor - implements arithmetic between factors.

    Constraint: both operands must be the same type (both cross-sectional
    or both time-series).
    """

    def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
        """
        Create a composite factor.

        Args:
            left: left operand factor
            right: right operand factor
            op: operator, one of '+', '-', '*', '/'

        Raises:
            ValueError: operand factor types differ
            ValueError: unsupported operator
        """
        pass

    def _merge_data_specs(self) -> List[DataSpec]:
        """
        Merge the data requirements of both operands.

        Strategy:
        1. DataSpecs with the same source and columns are merged
        2. lookback_days takes the maximum
        """
        pass

    def compute(self, data: FactorData) -> pl.Series:
        """
        Execute the composite operation.

        Steps:
        1. Compute left and right values separately
        2. Apply op
        3. Return the result
        """
        pass
```

**Test requirements:**
- [ ] Same-type composition succeeds (cs + cs)
- [ ] Same-type composition succeeds (ts + ts)
- [ ] Mixed-type composition raises ValueError (cs + ts)
- [ ] Invalid operator raises ValueError
- [ ] `_merge_data_specs()` merges correctly (same source)
- [ ] `_merge_data_specs()` merges correctly (different sources)
- [ ] `_merge_data_specs()` takes the max lookback
- [ ] `compute()` applies the right arithmetic

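The merge strategy can be sketched independently of the framework. A hypothetical stand-alone version (the `Spec`/`merge_specs` names are made up; columns are stored as a tuple so a spec can serve as a dict key):

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Spec:
    source: str
    columns: Tuple[str, ...]
    lookback_days: int = 1

def merge_specs(specs: List[Spec]) -> List[Spec]:
    """Merge specs sharing (source, columns); keep the maximum lookback."""
    merged: Dict[Tuple[str, Tuple[str, ...]], Spec] = {}
    for s in specs:
        key = (s.source, s.columns)
        if key in merged:
            merged[key] = Spec(s.source, s.columns,
                               max(merged[key].lookback_days, s.lookback_days))
        else:
            merged[key] = s
    return list(merged.values())

cols = ("ts_code", "trade_date", "close")
merge_specs([Spec("daily", cols, 5), Spec("daily", cols, 20)])
# -> one spec for "daily" with lookback_days=20
```

Taking the maximum lookback is the conservative choice: loading more history than one operand needs is harmless, while loading less would silently truncate the other operand's window.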
---

### 2.5 ScalarFactor - Scalar Arithmetic Factor (composite.py)

**Implementation requirements:**
```python
class ScalarFactor(BaseFactor):
    """
    Scalar arithmetic factor.

    Supports: scalar * factor and factor * scalar (via __rmul__).
    """

    def __init__(self, factor: BaseFactor, scalar: float, op: str):
        """
        Create a scalar arithmetic factor.

        Args:
            factor: underlying factor
            scalar: scalar value
            op: operator, one of '*', '+'
        """
        pass

    def compute(self, data: FactorData) -> pl.Series:
        """Apply the scalar operation."""
        pass
```

**Test requirements:**
- [ ] Scalar multiplication `0.5 * factor`
- [ ] Scalar multiplication `factor * 0.5`
- [ ] Scalar addition (if supported)
- [ ] Inherits the base factor's data_specs
- [ ] `compute()` returns correctly scaled values

---

## Phase 3: Data Loading (data_loader.py)

### 3.1 DataLoader

**Implementation requirements:**
```python
class DataLoader:
    """
    Data loader - loads data safely from HDF5.

    Responsibilities:
    1. Multi-file aggregation: merge data from several H5 files
    2. Column selection: load only the required columns
    3. Raw-data caching: avoid repeated reads
    """

    def __init__(self, data_dir: str):
        """
        Initialize the DataLoader.

        Args:
            data_dir: directory containing the HDF5 files
        """
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}

    def load(
        self,
        specs: List[DataSpec],
        date_range: Optional[Tuple[str, str]] = None
    ) -> pl.DataFrame:
        """
        Load and aggregate data from multiple H5 files.

        Steps:
        1. For each DataSpec:
           a. Check the cache; use it on a hit
           b. On a miss, read the HDF5 file (via pandas)
           c. Convert to a Polars DataFrame
           d. Filter by date_range
           e. Store in the cache
        2. Merge the DataFrames (join on trade_date and ts_code)

        Args:
            specs: list of data requirement specs
            date_range: optional (start_date, end_date) bound

        Returns:
            Merged Polars DataFrame

        Raises:
            FileNotFoundError: H5 file is missing
            KeyError: a column is absent from the file
        """
        pass

    def clear_cache(self):
        """Clear the cache."""
        pass

    def _read_h5(self, source: str) -> pl.DataFrame:
        """
        Read a single H5 file.

        Implementation: pandas.read_hdf(), then pl.from_pandas().
        """
        pass
```

**Test requirements:**
- [ ] Load from a single H5 file
- [ ] Load and merge from multiple H5 files
- [ ] Column selection (only required columns loaded)
- [ ] Caching (second load is faster)
- [ ] `clear_cache()` empties the cache
- [ ] Filtering by date_range
- [ ] Missing file raises FileNotFoundError
- [ ] Missing column raises KeyError

---

## Phase 4: Execution Engine (engine.py)

### 4.1 FactorEngine

**Implementation requirements:**
```python
class FactorEngine:
    """
    Factor execution engine - picks the computation and leakage-prevention
    strategy based on the factor type.

    Core responsibilities:
    1. CrossSectionalFactor: prevent date leakage; pass [T-lookback+1, T] per day
    2. TimeSeriesFactor: prevent stock leakage; pass each stock's full series
    """

    def __init__(self, data_loader: DataLoader):
        """
        Initialize the engine.

        Args:
            data_loader: data loader instance
        """
        self.data_loader = data_loader

    def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
        """
        Unified computation entry point.

        Dispatches on factor_type:
        - "cross_sectional" -> _compute_cross_sectional()
        - "time_series" -> _compute_time_series()

        Args:
            factor: the factor to compute
            **kwargs: extra parameters, depending on the factor type:
                - cross-sectional: start_date, end_date
                - time-series: stock_codes, start_date, end_date

        Returns:
            DataFrame[trade_date, ts_code, factor_name]
        """
        pass
```

**Test requirements:**
- [ ] `compute()` dispatches to the cross-sectional path
- [ ] `compute()` dispatches to the time-series path
- [ ] Invalid factor_type raises ValueError

---

### 4.2 Cross-Sectional Computation (preventing date leakage)

**Implementation requirements:**
```python
def _compute_cross_sectional(
    self,
    factor: CrossSectionalFactor,
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """
    Run a date cross-section computation.

    Leakage prevention:
    - Date leakage: each day only sees [T-lookback+1, T] (no future data)
    - Cross-stock comparison allowed: all stocks for the day are passed in

    Steps:
    1. Compute max_lookback to determine the data start date
    2. Load all data in [start-max_lookback+1, end] once
    3. For each date T in [start_date, end_date]:
       a. Trim the data to [T-lookback+1, T]
       b. Build FactorData(current_date=T)
       c. Call factor.compute()
       d. Collect the result
    4. Concatenate the per-date results

    Resulting DataFrame layout:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 0.5          │
    │ 20240101   │ 000002.SZ│ 0.3          │
    └────────────┴──────────┴──────────────┘
    """
    pass
```

**Test requirements (leakage checks):**
- [ ] Data trimmed correctly (exactly [T-lookback+1, T] passed in)
- [ ] No data from the future date T+1
- [ ] Each date computed independently
- [ ] Results cover all dates and all stocks
- [ ] Result DataFrame layout is correct
- [ ] Max lookback used when multiple DataSpecs are present

---

### 4.3 Time-Series Computation (preventing stock leakage)

**Implementation requirements:**
```python
def _compute_time_series(
    self,
    factor: TimeSeriesFactor,
    stock_codes: List[str],
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """
    Run a time-series computation.

    Leakage prevention:
    - Stock leakage: each stock is computed separately with only its own series
    - Historical access allowed: time-series math needs history

    Steps:
    1. Compute max_lookback to determine the data start date
    2. Load all data in [start-max_lookback+1, end] once
    3. For each stock S in stock_codes:
       a. Filter down to S's rows (prevents stock leakage)
       b. Build FactorData(current_stock=S)
       c. Call factor.compute() (vectorized over the whole series)
       d. Collect the result
    4. Concatenate the per-stock results

    Performance:
    - Uses Polars vectorized ops such as rolling_mean
    - Each stock is computed exactly once; no duplicated work

    Resulting DataFrame layout:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 10.5         │
    │ 20240102   │ 000001.SZ│ 10.6         │
    └────────────┴──────────┴──────────────┘
    """
    pass
```

**Test requirements (leakage checks):**
- [ ] Each stock sees only its own data
- [ ] No data from other stocks
- [ ] The full time series is passed in (vectorized computation)
- [ ] Results cover all stocks and all dates
- [ ] Result DataFrame layout is correct
- [ ] Stocks missing from the data are skipped (or null-filled)

---

## Phase 5: Built-in Factor Library (builtin/)

### 5.1 momentum.py - Cross-Sectional Momentum Factors

**Factors to implement:**

1. **ReturnRankFactor** - daily return rank
```python
class ReturnRankFactor(CrossSectionalFactor):
    """Daily return rank factor."""
    name = "return_rank"
    # lookback_days=2: two days are needed to compute the return;
    # ts_code/trade_date are required by the DataSpec column constraint
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2)]

    def compute(self, data):
        # Get the current date's cross-section
        cs = data.get_cross_section()
        # Needs yesterday's and today's close; lookback=2 guarantees [T-1, T].
        # data already includes the history; the real computation groups by stock.
        pass
```

**Test requirements:**
- [ ] Return computed correctly
- [ ] Rank computed correctly
- [ ] Null returned when data is missing

2. **MomentumFactor** - past-N-day return rank

---

### 5.2 technical.py - Time-Series Technical Indicators

**Factors to implement:**

1. **MovingAverageFactor** - moving average
```python
class MovingAverageFactor(TimeSeriesFactor):
    """Moving average factor."""
    name = "ma"

    def __init__(self, period: int = 20):
        super().__init__(period=period)
        self.data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"],
                                    lookback_days=period)]

    def compute(self, data):
        return data.get_column("close").rolling_mean(self.params["period"])
```

**Test requirements:**
- [ ] MA20 computed correctly
- [ ] First 19 days return null (Polars default behavior)
- [ ] The period parameter takes effect

2. **RSIFactor** - RSI indicator
3. **MACDFactor** - MACD indicator

---

### 5.3 value.py - Cross-Sectional Valuation Factors

**Factors to implement:**
1. **PERankFactor** - PE industry quantile
2. **PBFactor** - PB rank

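As a sanity check for the MA semantics above (the first period-1 positions are null), a dependency-free sketch of the same windowed mean:

```python
from typing import List, Optional

def rolling_mean(values: List[float], period: int) -> List[Optional[float]]:
    """Simple moving average; the first period-1 positions are None,
    matching the default behavior of Polars' rolling_mean."""
    out: List[Optional[float]] = []
    for i in range(len(values)):
        if i + 1 < period:
            out.append(None)  # window not yet full
        else:
            window = values[i + 1 - period: i + 1]
            out.append(sum(window) / period)
    return out

rolling_mean([1.0, 2.0, 3.0, 4.0], period=3)
# -> [None, None, 2.0, 3.0]
```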
---

## Phase 6-7: Testing Strategy

### Test Pyramid

```
              /\
             /  \
            / Integration \       tests/factors/test_integration.py
           /──────────────\
          /     Engine     \      tests/factors/test_engine.py
         /──────────────────\
        / Base / composite   \    tests/factors/test_base.py, test_composite.py
       /──────────────────────\
      / Data loading / types   \  tests/factors/test_data_loader.py, test_data_spec.py
     /──────────────────────────\
```

### Test Data

Create a `tests/fixtures/` directory containing:
- `sample_daily.h5`: daily bars for a handful of stocks (for testing)
- `sample_fundamental.h5`: fundamental data

### Key Test Scenarios

1. **Leakage tests (core)**
   - Cross-sectional factors: verify compute() cannot see future dates
   - Time-series factors: verify compute() cannot see other stocks

2. **Boundary tests**
   - lookback_days = 1 (the minimum)
   - Start of the data (first N days null)
   - Empty data / trading halts

3. **Performance tests (optional)**
   - Memory footprint on large datasets
   - Cache hit rate

---

## Implementation Status

| Phase | Status | Completed | Test Coverage |
|-------|--------|-----------|---------------|
| Phase 1: Data type definitions | ✅ Done | 2026-02-21 | 27 tests passed |
| Phase 2: Factor base classes | ✅ Done | 2026-02-21 | 49 tests passed |
| Phase 3: Data loading | ✅ Done | 2026-02-21 | 11 tests passed |
| Phase 4: Execution engine | ✅ Done | 2026-02-22 | 10 tests passed |
| Phase 5: Built-in factor library | 📝 Pending | - | - |
| Phase 6-7: Tests & docs | ✅ Done | 2026-02-22 | 76 tests total |

---

## Suggested Implementation Order

1. **Week 1**: Phase 1-2 (data types + base classes)
2. **Week 2**: Phase 3-4 (DataLoader + Engine) ✅ **done**
3. **Week 3**: Phase 5 (built-in factors)
4. **Week 4**: Phase 6-7 (tests + docs)

Run the corresponding tests after each phase to keep quality in check.
@@ -1,10 +1,17 @@

# ProStock HDF5 to DuckDB Migration Plan

**Document version**: v1.1
**Created**: 2026-02-22
**Completed**: 2026-02-22
**Status**: ✅ Done
**Scope**: data module, factors module, related docs

## Related Documents

[DuckDB Data Sync Guide](./db_sync_guide.md) - sync API reference
[Migration Test Report](./test_report_duckdb_migration.md) - test verification results

---

## Table of Contents
docs/ml_framework_design.md (new file, 1472 lines; diff suppressed because it is too large)
@@ -1,6 +1,8 @@

# ProStock HDF5 to DuckDB Migration Test Report

**Report generated**: 2026-02-22
**Completed**: 2026-02-22
**Status**: ✅ Done
**Migration doc**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md)
**Test data range**: January-March 2024 (3 months)

src/models/__init__.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""ProStock model training framework

A componentized, loosely coupled, plugin-based machine learning training framework.

Example:
    >>> from src.models import (
    ...     PluginRegistry, ProcessingPipeline,
    ...     PipelineStage, BaseProcessor
    ... )

    >>> # Fetch a registered processor
    >>> scaler_class = PluginRegistry.get_processor("standard_scaler")
    >>> scaler = scaler_class()

    >>> # Build a processing pipeline
    >>> pipeline = ProcessingPipeline([
    ...     PluginRegistry.get_processor("dropna")(),
    ...     PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    ...     PluginRegistry.get_processor("standard_scaler")(),
    ... ])
"""

# Core abstractions and split strategies
from src.models.core import (
    PipelineStage,
    TaskType,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    BaseMetric,
    TimeSeriesSplit,
    WalkForwardSplit,
    ExpandingWindowSplit,
)

# Registry
from src.models.registry import PluginRegistry

# Processing pipeline
from src.models.pipeline import ProcessingPipeline

# Import (and thereby register) the built-in processors
from src.models.processors.processors import (
    DropNAProcessor,
    FillNAProcessor,
    Winsorizer,
    StandardScaler,
    MinMaxScaler,
    RankTransformer,
    Neutralizer,
)

# Import (and thereby register) the built-in models
from src.models.models.models import (
    LightGBMModel,
    CatBoostModel,
)

__all__ = [
    # Core abstractions
    "PipelineStage",
    "TaskType",
    "BaseProcessor",
    "BaseModel",
    "BaseSplitter",
    "BaseMetric",
    # Split strategies
    "TimeSeriesSplit",
    "WalkForwardSplit",
    "ExpandingWindowSplit",
    # Registry
    "PluginRegistry",
    # Processing pipeline
    "ProcessingPipeline",
    # Processors
    "DropNAProcessor",
    "FillNAProcessor",
    "Winsorizer",
    "StandardScaler",
    "MinMaxScaler",
    "RankTransformer",
    "Neutralizer",
    # Models
    "LightGBMModel",
    "CatBoostModel",
]
30 src/models/core/__init__.py Normal file
@@ -0,0 +1,30 @@
"""Core module exports"""

from src.models.core.base import (
    PipelineStage,
    TaskType,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    BaseMetric,
)

from src.models.core.splitter import (
    TimeSeriesSplit,
    WalkForwardSplit,
    ExpandingWindowSplit,
)

__all__ = [
    # Base abstractions
    "PipelineStage",
    "TaskType",
    "BaseProcessor",
    "BaseModel",
    "BaseSplitter",
    "BaseMetric",
    # Split strategies
    "TimeSeriesSplit",
    "WalkForwardSplit",
    "ExpandingWindowSplit",
]
351 src/models/core/base.py Normal file
@@ -0,0 +1,351 @@
"""Core abstractions for the model training framework

Defines the base classes for processors, models, split strategies,
and evaluation metrics.
"""

from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal

import polars as pl
import numpy as np

# Task type
TaskType = Literal["classification", "regression", "ranking"]


class PipelineStage(Enum):
    """Pipeline stage marker

    Marks which stages a processor is active in, to prevent data leakage.

    Attributes:
        ALL: applies to every stage (train, test, validation)
        TRAIN: training stage only
        TEST: test stage only
        VALIDATION: validation stage only
    """

    ALL = auto()
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()


class BaseProcessor(ABC):
    """Base class for data processors

    Every data processor must inherit from this class. The key feature is
    the `stage` attribute, which controls in which stages the processor runs.

    Stage semantics:
        - ALL: the same learned parameters are applied in both training and testing
        - TRAIN: parameters (quantiles, means, ...) are computed on training data
          only; the test stage reuses the parameters learned during training
        - TEST: runs in the test stage only
    """

    # Subclasses must declare the stage they apply to
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[List[str]] = None, **params):
        """Initialize the processor

        Args:
            columns: columns to process; None means all numeric columns
            **params: processor-specific parameters
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters on training data

        Called exactly once, during the training stage. Learned parameters
        are stored in ``self._fitted_params``.

        Args:
            data: training data

        Returns:
            self (for chaining)
        """
        pass

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform the data

        Called in both the training and the test stage, using the parameters
        learned in ``fit()``.

        Args:
            data: input data

        Returns:
            the transformed data
        """
        pass

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience method: fit, then transform

        Args:
            data: training data

        Returns:
            the transformed data
        """
        return self.fit(data).transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """Return the learned parameters (for save/load)

        Returns:
            a copy of the learned parameter dict
        """
        return self._fitted_params.copy()

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Set learned parameters (for restoring from a checkpoint)

        Args:
            params: parameter dict

        Returns:
            self (for chaining)
        """
        self._fitted_params = params.copy()
        self._is_fitted = True
        return self


class BaseModel(ABC):
    """Base class for machine learning models

    A unified interface for multiple model backends (LightGBM, CatBoost,
    XGBoost, ...) and task types (classification, regression, ranking).
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Initialize the model

        Args:
            task_type: one of "classification", "regression", "ranking"
            params: model-specific parameters
            name: model name (used in logs and reports)
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "BaseModel":
        """Train the model

        Args:
            X: feature data
            y: target variable
            X_val: validation features (optional)
            y_val: validation targets (optional)
            **fit_params: extra fit parameters

        Returns:
            self (for chaining)
        """
        pass

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict

        Args:
            X: feature data

        Returns:
            prediction array
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """
        pass

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification only)

        Args:
            X: feature data

        Returns:
            class probability array [n_samples, n_classes]

        Raises:
            NotImplementedError: for non-classification tasks
        """
        raise NotImplementedError(
            "predict_proba only available for classification tasks"
        )

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances (if supported by the model)

        Returns:
            DataFrame[feature, importance] or None
        """
        return None

    def save(self, path: str) -> None:
        """Save the model to a file

        Args:
            path: destination path
        """
        import pickle

        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file

        Args:
            path: model file path

        Returns:
            the loaded model instance
        """
        import pickle

        with open(path, "rb") as f:
            return pickle.load(f)


class BaseSplitter(ABC):
    """Base class for data split strategies

    Time-series-aware split strategies that prevent lookahead leakage.
    """

    @abstractmethod
    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield train/test indices

        Args:
            data: the full dataset
            date_col: name of the date column

        Yields:
            (train_indices, test_indices) tuples
        """
        pass

    @abstractmethod
    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split

        Args:
            data: the full dataset
            date_col: name of the date column

        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass


class BaseMetric(ABC):
    """Base class for evaluation metrics

    Every metric must inherit from this class. Supports both one-shot
    computation and accumulation across folds.
    """

    def __init__(self, name: Optional[str] = None):
        """Initialize the metric

        Args:
            name: metric name
        """
        self.name = name or self.__class__.__name__
        self._values: List[float] = []

    @abstractmethod
    def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the metric value

        Args:
            y_true: ground-truth values
            y_pred: predicted values

        Returns:
            the metric value
        """
        pass

    def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric":
        """Accumulate one metric value

        Args:
            y_true: ground-truth values
            y_pred: predicted values

        Returns:
            self (for chaining)
        """
        self._values.append(self.compute(y_true, y_pred))
        return self

    def get_mean(self) -> float:
        """Mean of the accumulated values

        Returns:
            the mean
        """
        if not self._values:
            return 0.0
        return float(np.mean(self._values))

    def get_std(self) -> float:
        """Standard deviation of the accumulated values

        Returns:
            the standard deviation
        """
        if not self._values:
            return 0.0
        return float(np.std(self._values))

    def reset(self) -> "BaseMetric":
        """Reset the accumulated values

        Returns:
            self (for chaining)
        """
        self._values = []
        return self


__all__ = [
    "PipelineStage",
    "TaskType",
    "BaseProcessor",
    "BaseModel",
    "BaseSplitter",
    "BaseMetric",
]
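The fit/transform contract above is the core anti-leakage mechanism: parameters are learned in `fit()` on the training split only, and `transform()` reuses them unchanged on any later data. A dependency-free sketch of that contract (`MedianCenterProcessor` is a hypothetical illustration, not part of the framework, and works on plain lists rather than Polars DataFrames):

```python
from statistics import median


class MedianCenterProcessor:
    """Dependency-free sketch of the BaseProcessor contract:
    fit() learns parameters on training data only; transform()
    reuses them, so test data never leaks into the parameters."""

    def __init__(self):
        self._fitted_params = {}
        self._is_fitted = False

    def fit(self, values):
        # Learn the center on the training split only
        self._fitted_params["center"] = median(values)
        self._is_fitted = True
        return self  # chainable, as in BaseProcessor

    def transform(self, values):
        if not self._is_fitted:
            raise RuntimeError("fit() must be called before transform()")
        center = self._fitted_params["center"]
        return [v - center for v in values]


train = [1.0, 2.0, 3.0, 4.0, 5.0]
test = [10.0, 11.0]
proc = MedianCenterProcessor().fit(train)
# Test data is centered with the TRAIN median (3.0), not its own median
print(proc.transform(test))  # -> [7.0, 8.0]
```

Calling `transform()` before `fit()` raises, mirroring how the framework's fitted-parameter dict starts empty.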
222 src/models/core/splitter.py Normal file
@@ -0,0 +1,222 @@
"""Time-series data split strategies

Split strategies tailored to financial time series, preventing
lookahead leakage.
"""

from typing import Iterator, List, Tuple

import polars as pl

from src.models.core.base import BaseSplitter


class TimeSeriesSplit(BaseSplitter):
    """Time-ordered split - training data always precedes test data

    K-fold splitting in chronological order; in each fold the training
    data lies strictly before the test data. The ``gap`` parameter leaves
    a buffer between train and test sets to prevent leakage.

    Args:
        n_splits: number of folds
        gap: number of days between train and test sets (anti-leakage buffer)
        min_train_size: minimum training set size (in days)
    """

    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield train/test indices"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        test_size = (n_dates - self.min_train_size) // self.n_splits

        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size

            if test_end_idx > n_dates:
                break

            train_dates = dates[:train_end_idx]
            test_dates = dates[test_start_idx:test_end_idx]

            train_mask = data[date_col].is_in(train_dates.to_list())
            test_mask = data[date_col].is_in(test_dates.to_list())

            train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
            test_idx = data.with_row_index().filter(test_mask)["index"].to_list()

            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)
        test_size = (n_dates - self.min_train_size) // self.n_splits

        result = []
        for i in range(self.n_splits):
            train_end_idx = self.min_train_size + i * test_size
            test_start_idx = train_end_idx + self.gap
            test_end_idx = test_start_idx + test_size

            if test_end_idx > n_dates:
                break

            result.append(
                (
                    str(dates[0]),
                    str(dates[train_end_idx - 1]),
                    str(dates[test_start_idx]),
                    str(dates[test_end_idx - 1]),
                )
            )
        return result


class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation - a fixed-size training window rolls forward

    Args:
        train_window: training window size (in days)
        test_window: test window size (in days)
        gap: number of days between train and test sets
    """

    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield train/test indices"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window

            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]

            train_mask = data[date_col].is_in(train_dates.to_list())
            test_mask = data[date_col].is_in(test_dates.to_list())

            train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
            test_idx = data.with_row_index().filter(test_mask)["index"].to_list()

            yield train_idx, test_idx
            start_idx += self.test_window

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        result = []
        start_idx = self.train_window
        while start_idx + self.gap + self.test_window <= n_dates:
            train_start = start_idx - self.train_window
            train_end = start_idx
            test_start = start_idx + self.gap
            test_end = test_start + self.test_window

            result.append(
                (
                    str(dates[train_start]),
                    str(dates[train_end - 1]),
                    str(dates[test_start]),
                    str(dates[test_end - 1]),
                )
            )
            start_idx += self.test_window

        return result


class ExpandingWindowSplit(BaseSplitter):
    """Expanding-window split - the training set keeps growing

    Args:
        initial_train_size: initial training set size (in days)
        test_window: test window size (in days)
        gap: number of days between train and test sets
    """

    def __init__(
        self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5
    ):
        self.initial_train_size = initial_train_size
        self.test_window = test_window
        self.gap = gap

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield train/test indices"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        train_end_idx = self.initial_train_size
        while train_end_idx + self.gap + self.test_window <= n_dates:
            train_dates = dates[:train_end_idx]
            test_start = train_end_idx + self.gap
            test_end = test_start + self.test_window
            test_dates = dates[test_start:test_end]

            train_mask = data[date_col].is_in(train_dates.to_list())
            test_mask = data[date_col].is_in(test_dates.to_list())

            train_idx = data.with_row_index().filter(train_mask)["index"].to_list()
            test_idx = data.with_row_index().filter(test_mask)["index"].to_list()

            yield train_idx, test_idx
            train_end_idx += self.test_window

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date ranges of each split"""
        dates = data[date_col].unique().sort()
        n_dates = len(dates)

        result = []
        train_end_idx = self.initial_train_size
        while train_end_idx + self.gap + self.test_window <= n_dates:
            test_start = train_end_idx + self.gap
            test_end = test_start + self.test_window

            result.append(
                (
                    str(dates[0]),
                    str(dates[train_end_idx - 1]),
                    str(dates[test_start]),
                    str(dates[test_end - 1]),
                )
            )
            train_end_idx += self.test_window

        return result


__all__ = [
    "TimeSeriesSplit",
    "WalkForwardSplit",
    "ExpandingWindowSplit",
]
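The window arithmetic these splitters share can be checked in isolation. A stdlib-only sketch of the WalkForwardSplit loop over sorted date positions (`walk_forward_windows` is a hypothetical helper, not framework code; it yields index ranges where the framework yields row indices looked up through a DataFrame):

```python
def walk_forward_windows(n_dates, train_window, test_window, gap):
    """Sketch of WalkForwardSplit's window arithmetic: a fixed-size
    training window slides forward by test_window dates, with `gap`
    dates skipped between train and test to avoid leakage."""
    windows = []
    start = train_window
    while start + gap + test_window <= n_dates:
        train = list(range(start - train_window, start))
        test = list(range(start + gap, start + gap + test_window))
        windows.append((train, test))
        start += test_window
    return windows


# 30 trading dates, 10-date train window, 5-date test window, 2-date gap
folds = walk_forward_windows(n_dates=30, train_window=10, test_window=5, gap=2)
for train, test in folds:
    print(f"train {train[0]}..{train[-1]}  test {test[0]}..{test[-1]}")
# -> 3 folds; the first trains on dates 0..9 and tests on 12..16
```

Replacing the sliding `start - train_window` with a fixed `0` gives the ExpandingWindowSplit variant.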
11 src/models/models/__init__.py Normal file
@@ -0,0 +1,11 @@
"""Model module"""

from src.models.models.models import (
    LightGBMModel,
    CatBoostModel,
)

__all__ = [
    "LightGBMModel",
    "CatBoostModel",
]
210 src/models/models/models.py Normal file
@@ -0,0 +1,210 @@
"""Built-in machine learning models

Unified-interface wrappers for LightGBM, CatBoost, and similar models.
"""

from typing import Optional, Dict, Any

import polars as pl
import numpy as np

from src.models.core import BaseModel, TaskType
from src.models.registry import PluginRegistry


@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM model wrapper

    Supports classification, regression, and ranking tasks.
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "LightGBMModel":
        """Train the model"""
        try:
            import lightgbm as lgb
        except ImportError:
            raise ImportError(
                "lightgbm is required. Install with: uv pip install lightgbm"
            )

        X_arr = X.to_numpy()
        y_arr = y.to_numpy()

        train_data = lgb.Dataset(X_arr, label=y_arr)
        valid_sets = [train_data]
        valid_names = ["train"]

        if X_val is not None and y_val is not None:
            valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
            valid_sets.append(valid_data)
            valid_names.append("valid")

        default_params = {
            "objective": self._get_objective(),
            "metric": self._get_metric(),
            "boosting_type": "gbdt",
            "num_leaves": 31,
            "learning_rate": 0.05,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 5,
            "verbose": -1,
        }
        default_params.update(self.params)

        callbacks = []
        if len(valid_sets) > 1:
            callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False))

        self._model = lgb.train(
            default_params,
            train_data,
            num_boost_round=fit_params.get("num_boost_round", 100),
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=callbacks,
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict"""
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_numpy())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification only)"""
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        probs = self.predict(X)
        if len(probs.shape) == 1:
            return np.vstack([1 - probs, probs]).T
        return probs

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances"""
        if self._model is None:
            return None
        importance = self._model.feature_importance(importance_type="gain")
        feature_names = getattr(
            self._model,
            "feature_name",
            lambda: [f"feature_{i}" for i in range(len(importance))],
        )()
        return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
            "importance", descending=True
        )

    def _get_objective(self) -> str:
        objectives = {
            "classification": "binary",
            "regression": "regression",
            "ranking": "lambdarank",
        }
        return objectives.get(self.task_type, "regression")

    def _get_metric(self) -> str:
        metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
        return metrics.get(self.task_type, "rmse")


@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
    """CatBoost model wrapper"""

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "CatBoostModel":
        """Train the model"""
        try:
            from catboost import CatBoostClassifier, CatBoostRegressor
        except ImportError:
            raise ImportError(
                "catboost is required. Install with: uv pip install catboost"
            )

        if self.task_type == "classification":
            model_class = CatBoostClassifier
            default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
        elif self.task_type == "regression":
            model_class = CatBoostRegressor
            default_params = {"loss_function": "RMSE"}
        else:
            model_class = CatBoostRegressor
            default_params = {"loss_function": "QueryRMSE"}

        default_params.update(self.params)
        default_params["verbose"] = False

        self._model = model_class(**default_params)

        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())

        self._model.fit(
            X.to_pandas(),
            y.to_pandas(),
            eval_set=eval_set,
            early_stopping_rounds=fit_params.get("early_stopping_rounds", 10),
            verbose=False,
        )
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict"""
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_pandas())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities"""
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        return self._model.predict_proba(X.to_pandas())

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances"""
        if self._model is None:
            return None
        return pl.DataFrame(
            {
                "feature": self._model.feature_names_,
                "importance": self._model.feature_importances_,
            }
        ).sort("importance", descending=True)


__all__ = ["LightGBMModel", "CatBoostModel"]
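Both wrappers obey the same contract inherited from BaseModel: `fit()` returns `self`, `predict()` refuses to run before fitting, and probability output is reserved for classification. A dependency-free stand-in illustrating that contract without installing lightgbm or catboost (`MeanRegressor` is hypothetical and only predicts the training-set mean):

```python
class MeanRegressor:
    """Dependency-free stand-in for the BaseModel fit/predict contract.
    Always predicts the training-set mean of y."""

    def __init__(self):
        self._mean = None
        self._is_fitted = False

    def fit(self, X, y):
        # "Training" here is just remembering the target mean
        self._mean = sum(y) / len(y)
        self._is_fitted = True
        return self  # chainable, as in BaseModel

    def predict(self, X):
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return [self._mean] * len(X)


model = MeanRegressor().fit(X=[[0], [1], [2]], y=[1.0, 2.0, 3.0])
print(model.predict([[5], [6]]))  # -> [2.0, 2.0]
```

Because every backend honors this contract, the training loop and splitters never need to know which library sits behind `self._model`.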
70 src/models/pipeline.py Normal file
@@ -0,0 +1,70 @@
"""Data processing pipeline

Runs multiple processors in order, with stage-aware handling.
"""

from typing import List, Dict

import polars as pl

from src.models.core import BaseProcessor, PipelineStage


class ProcessingPipeline:
    """Data processing pipeline

    Executes processors sequentially, honoring their stage markers.
    Key property: the test stage reuses the parameters learned during
    training, preventing data leakage.
    """

    def __init__(self, processors: List[BaseProcessor]):
        """Initialize the pipeline

        Args:
            processors: processors, in execution order
        """
        self.processors = processors
        self._fitted_processors: Dict[int, BaseProcessor] = {}

    def fit_transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit all processors on the training data and transform it"""
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
                # Test-only processors are fitted on training data here,
                # but not applied until transform() runs in the test stage
                processor.fit(result)
                self._fitted_processors[i] = processor
        return result

    def transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply the fitted processors to test data"""
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                if i in self._fitted_processors:
                    result = self._fitted_processors[i].transform(result)
                else:
                    result = processor.transform(result)
        return result

    def save_processors(self, path: str) -> None:
        """Save the state of all fitted processors"""
        import pickle

        with open(path, "wb") as f:
            pickle.dump(self._fitted_processors, f)

    def load_processors(self, path: str) -> None:
        """Load processor state"""
        import pickle

        with open(path, "rb") as f:
            self._fitted_processors = pickle.load(f)


__all__ = ["ProcessingPipeline"]
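The stage dispatch in `fit_transform`/`transform` reduces to a membership test: a processor runs only when its stage is ALL or matches the current stage. A minimal stdlib sketch (`Stage` and `run_pipeline` are hypothetical stand-ins for PipelineStage and ProcessingPipeline, with plain functions in place of processors):

```python
from enum import Enum, auto


class Stage(Enum):
    ALL = auto()
    TRAIN = auto()
    TEST = auto()


def run_pipeline(processors, data, stage):
    """Sketch of ProcessingPipeline's dispatch: each (stage, fn) pair
    runs only when its stage is ALL or equals the current stage."""
    for proc_stage, fn in processors:
        if proc_stage in (Stage.ALL, stage):
            data = fn(data)
    return data


procs = [
    (Stage.ALL, lambda xs: [x * 2 for x in xs]),    # runs in both phases
    (Stage.TRAIN, lambda xs: [x + 1 for x in xs]),  # training phase only
]
print(run_pipeline(procs, [1, 2], Stage.TRAIN))  # -> [3, 5]
print(run_pipeline(procs, [1, 2], Stage.TEST))   # -> [2, 4]
```

The train-only step is simply skipped at test time, which is how the real pipeline keeps train-specific processing out of evaluation data.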
21 src/models/processors/__init__.py Normal file
@@ -0,0 +1,21 @@
"""Processor module"""

from src.models.processors.processors import (
    DropNAProcessor,
    FillNAProcessor,
    Winsorizer,
    StandardScaler,
    MinMaxScaler,
    RankTransformer,
    Neutralizer,
)

__all__ = [
    "DropNAProcessor",
    "FillNAProcessor",
    "Winsorizer",
    "StandardScaler",
    "MinMaxScaler",
    "RankTransformer",
    "Neutralizer",
]
238
src/models/processors/processors.py
Normal file
238
src/models/processors/processors.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""内置数据处理器
|
||||
|
||||
提供常用的数据预处理和转换处理器。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
from src.models.core import BaseProcessor, PipelineStage
|
||||
from src.models.registry import PluginRegistry
|
||||
|
||||
# 数值类型列表
|
||||
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]
|
||||
|
||||
|
||||
def _get_numeric_columns(
|
||||
data: pl.DataFrame, columns: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
"""获取数值列"""
|
||||
if columns is not None:
|
||||
return columns
|
||||
return [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("dropna")
|
||||
class DropNAProcessor(BaseProcessor):
|
||||
"""缺失值删除处理器"""
|
||||
|
||||
stage = PipelineStage.ALL
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
cols = self.columns or data.columns
|
||||
return data.drop_nulls(subset=cols)
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("fillna")
|
||||
class FillNAProcessor(BaseProcessor):
|
||||
"""缺失值填充处理器(只在训练阶段计算填充值)"""
|
||||
|
||||
stage = PipelineStage.TRAIN
|
||||
|
||||
def __init__(self, columns: Optional[List[str]] = None, method: str = "median"):
|
||||
super().__init__(columns)
|
||||
if method not in ["median", "mean", "zero"]:
|
||||
raise ValueError(f"Unknown fill method: {method}")
|
||||
self.method = method
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
fill_values = {}
|
||||
|
||||
for col in cols:
|
||||
if self.method == "median":
|
||||
fill_values[col] = data[col].median()
|
||||
elif self.method == "mean":
|
||||
fill_values[col] = data[col].mean()
|
||||
elif self.method == "zero":
|
||||
fill_values[col] = 0.0
|
||||
|
||||
self._fitted_params = {"fill_values": fill_values, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, val in self._fitted_params.get("fill_values", {}).items():
|
||||
if col in result.columns:
|
||||
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("winsorizer")
|
||||
class Winsorizer(BaseProcessor):
|
||||
"""缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)"""
|
||||
|
||||
stage = PipelineStage.TRAIN
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
columns: Optional[List[str]] = None,
|
||||
lower: float = 0.01,
|
||||
upper: float = 0.99,
|
||||
):
|
||||
super().__init__(columns)
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "Winsorizer":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
bounds = {}
|
||||
|
||||
for col in cols:
|
||||
bounds[col] = {
|
||||
"lower": data[col].quantile(self.lower),
|
||||
"upper": data[col].quantile(self.upper),
|
||||
}
|
||||
|
||||
self._fitted_params = {"bounds": bounds, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, bounds in self._fitted_params.get("bounds", {}).items():
|
||||
if col in result.columns:
|
||||
result = result.with_columns(
|
||||
pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("standard_scaler")
|
||||
class StandardScaler(BaseProcessor):
|
||||
"""标准化处理器 - Z-score标准化"""
|
||||
|
||||
stage = PipelineStage.ALL
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "StandardScaler":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
stats = {}
|
||||
|
||||
for col in cols:
|
||||
stats[col] = {"mean": data[col].mean(), "std": data[col].std()}
|
||||
|
||||
self._fitted_params = {"stats": stats, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, stats in self._fitted_params.get("stats", {}).items():
|
||||
if col in result.columns and stats["std"] is not None and stats["std"] > 0:
|
||||
result = result.with_columns(
|
||||
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("minmax_scaler")
|
||||
class MinMaxScaler(BaseProcessor):
|
||||
"""归一化处理器 - 缩放到[0, 1]范围"""
|
||||
|
||||
stage = PipelineStage.ALL
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "MinMaxScaler":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
stats = {}
|
||||
|
||||
for col in cols:
|
||||
stats[col] = {"min": data[col].min(), "max": data[col].max()}
|
||||
|
||||
self._fitted_params = {"stats": stats, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, stats in self._fitted_params.get("stats", {}).items():
|
||||
if col in result.columns:
|
||||
range_val = stats["max"] - stats["min"]
|
||||
if range_val is not None and range_val > 0:
|
||||
result = result.with_columns(
|
||||
((pl.col(col) - stats["min"]) / range_val).alias(col)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||

@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
    """Rank transformer - converts values to cross-sectional ranks."""

    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "RankTransformer":
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        cols = self.columns or _get_numeric_columns(data)

        for col in cols:
            if col in result.columns:
                result = result.with_columns(
                    pl.col(col).rank().over("trade_date").alias(col)
                )
        return result


@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
    """Neutralizer - industry/market-cap neutralization via group demeaning."""

    stage = PipelineStage.ALL

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        group_col: str = "industry",
        exclude_cols: Optional[List[str]] = None,
    ):
        super().__init__(columns)
        self.group_col = group_col
        self.exclude_cols = exclude_cols or []

    def fit(self, data: pl.DataFrame) -> "Neutralizer":
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        result = data
        cols = self.columns or _get_numeric_columns(data)

        for col in cols:
            if col in result.columns and col not in self.exclude_cols:
                result = result.with_columns(
                    (
                        pl.col(col)
                        - pl.col(col).mean().over(["trade_date", self.group_col])
                    ).alias(col)
                )
        return result


__all__ = [
    "DropNAProcessor",
    "FillNAProcessor",
    "Winsorizer",
    "StandardScaler",
    "MinMaxScaler",
    "RankTransformer",
    "Neutralizer",
]
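The `Neutralizer` above removes group-level structure by subtracting the per-(`trade_date`, `industry`) mean from each factor column, so the residual carries no group information. A stdlib sketch of that group-demean step, with hypothetical rows:

```python
from collections import defaultdict

def demean_by_group(rows, value_key="value", group_keys=("trade_date", "industry")):
    """Subtract the mean of each (trade_date, industry) group from value_key."""
    sums = defaultdict(lambda: [0.0, 0])
    for r in rows:
        k = tuple(r[g] for g in group_keys)
        sums[k][0] += r[value_key]
        sums[k][1] += 1
    out = []
    for r in rows:
        k = tuple(r[g] for g in group_keys)
        mean = sums[k][0] / sums[k][1]
        out.append({**r, value_key: r[value_key] - mean})
    return out

rows = [
    {"trade_date": "20200101", "industry": "A", "value": 10.0},
    {"trade_date": "20200101", "industry": "A", "value": 20.0},
    {"trade_date": "20200101", "industry": "B", "value": 30.0},
    {"trade_date": "20200101", "industry": "B", "value": 50.0},
]
# After demeaning, each group sums to zero: A -> [-5, 5], B -> [-10, 10].
```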
297
src/models/registry.py
Normal file
@@ -0,0 +1,297 @@
"""Plugin registry.

Provides decorator-based registration for processors, models, splitters and
other components, enabling a genuine plugin architecture: new functionality
only needs to be registered to become usable.

Example:
    >>> @PluginRegistry.register_processor("standard_scaler")
    ... class StandardScaler(BaseProcessor):
    ...     pass

    >>> # Usage
    >>> scaler = PluginRegistry.get_processor("standard_scaler")()
"""

from typing import Type, Dict, List, Optional
import contextlib

from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric


class PluginRegistry:
    """Plugin registry.

    Manages registration and lookup of all components; new components are
    registered via decorators.

    Attributes:
        _processors: registered processor classes
        _models: registered model classes
        _splitters: registered splitter classes
        _metrics: registered metric classes
    """

    _processors: Dict[str, Type[BaseProcessor]] = {}
    _models: Dict[str, Type[BaseModel]] = {}
    _splitters: Dict[str, Type[BaseSplitter]] = {}
    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    @contextlib.contextmanager
    def temp_registry(cls):
        """Temporary-registration context manager.

        Components registered inside the context are removed automatically on
        exit, preventing state from leaking between tests.

        Example:
            >>> with PluginRegistry.temp_registry():
            ...     @PluginRegistry.register_processor("temp_processor")
            ...     class TempProcessor(BaseProcessor):
            ...         pass
            ...     # temp_processor is usable here
            ...     # cleaned up automatically on exit
        """
        original_state = {
            "_processors": cls._processors.copy(),
            "_models": cls._models.copy(),
            "_splitters": cls._splitters.copy(),
            "_metrics": cls._metrics.copy(),
        }
        try:
            yield cls
        finally:
            cls._processors = original_state["_processors"]
            cls._models = original_state["_models"]
            cls._splitters = original_state["_splitters"]
            cls._metrics = original_state["_metrics"]

    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """Decorator that registers a data processor.

        Example:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass

            >>> # Retrieve and use
            >>> scaler_class = PluginRegistry.get_processor("standard_scaler")
            >>> scaler = scaler_class()

        Args:
            name: registration name; defaults to the class name

        Returns:
            The decorator function
        """

        def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
            key = name or processor_class.__name__
            cls._processors[key] = processor_class
            processor_class._registry_name = key
            return processor_class

        return decorator

    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """Decorator that registers a machine-learning model.

        Example:
            >>> @PluginRegistry.register_model("lightgbm")
            ... class LightGBMModel(BaseModel):
            ...     pass

        Args:
            name: registration name; defaults to the class name

        Returns:
            The decorator function
        """

        def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]:
            key = name or model_class.__name__
            cls._models[key] = model_class
            model_class._registry_name = key
            return model_class

        return decorator

    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """Decorator that registers a data-splitting strategy.

        Example:
            >>> @PluginRegistry.register_splitter("time_series")
            ... class TimeSeriesSplit(BaseSplitter):
            ...     pass

        Args:
            name: registration name; defaults to the class name

        Returns:
            The decorator function
        """

        def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]:
            key = name or splitter_class.__name__
            cls._splitters[key] = splitter_class
            splitter_class._registry_name = key
            return splitter_class

        return decorator

    @classmethod
    def register_metric(cls, name: Optional[str] = None):
        """Decorator that registers an evaluation metric.

        Example:
            >>> @PluginRegistry.register_metric("ic")
            ... class ICMetric(BaseMetric):
            ...     pass

        Args:
            name: registration name; defaults to the class name

        Returns:
            The decorator function
        """

        def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
            key = name or metric_class.__name__
            cls._metrics[key] = metric_class
            metric_class._registry_name = key
            return metric_class

        return decorator

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Return the processor class registered under ``name``.

        Raises:
            KeyError: if no processor is registered under ``name``
        """
        if name not in cls._processors:
            available = list(cls._processors.keys())
            raise KeyError(f"Processor '{name}' not found. Available: {available}")
        return cls._processors[name]

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Return the model class registered under ``name``.

        Raises:
            KeyError: if no model is registered under ``name``
        """
        if name not in cls._models:
            available = list(cls._models.keys())
            raise KeyError(f"Model '{name}' not found. Available: {available}")
        return cls._models[name]

    @classmethod
    def get_splitter(cls, name: str) -> Type[BaseSplitter]:
        """Return the splitter class registered under ``name``.

        Raises:
            KeyError: if no splitter is registered under ``name``
        """
        if name not in cls._splitters:
            available = list(cls._splitters.keys())
            raise KeyError(f"Splitter '{name}' not found. Available: {available}")
        return cls._splitters[name]

    @classmethod
    def get_metric(cls, name: str) -> Type[BaseMetric]:
        """Return the metric class registered under ``name``.

        Raises:
            KeyError: if no metric is registered under ``name``
        """
        if name not in cls._metrics:
            available = list(cls._metrics.keys())
            raise KeyError(f"Metric '{name}' not found. Available: {available}")
        return cls._metrics[name]

    @classmethod
    def list_processors(cls) -> List[str]:
        """List the names of all registered processors."""
        return list(cls._processors.keys())

    @classmethod
    def list_models(cls) -> List[str]:
        """List the names of all registered models."""
        return list(cls._models.keys())

    @classmethod
    def list_splitters(cls) -> List[str]:
        """List the names of all registered splitters."""
        return list(cls._splitters.keys())

    @classmethod
    def list_metrics(cls) -> List[str]:
        """List the names of all registered metrics."""
        return list(cls._metrics.keys())

    @classmethod
    def clear_all(cls) -> None:
        """Clear all registrations (intended mainly for tests)."""
        cls._processors.clear()
        cls._models.clear()
        cls._splitters.clear()
        cls._metrics.clear()


__all__ = ["PluginRegistry"]
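The registry keeps class-level dicts keyed by name; each decorator stores the class and returns it unchanged, so registration adds no runtime cost at call sites. A self-contained miniature of the pattern (`MiniRegistry` and `Demo` are illustrative names, not part of the repo):

```python
class MiniRegistry:
    _items: dict = {}

    @classmethod
    def register(cls, name=None):
        def decorator(klass):
            key = name or klass.__name__
            cls._items[key] = klass   # store the class itself, not an instance
            klass._registry_name = key
            return klass              # the class is returned unmodified
        return decorator

    @classmethod
    def get(cls, name):
        if name not in cls._items:
            raise KeyError(f"'{name}' not found. Available: {list(cls._items)}")
        return cls._items[name]

@MiniRegistry.register("demo")
class Demo:
    pass

assert MiniRegistry.get("demo") is Demo  # lookup returns the registered class
```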
478
tests/models/test_core.py
Normal file
@@ -0,0 +1,478 @@
"""Core tests for the model framework.

Covers the core abstract classes, the plugin registry, processors, models and
splitting strategies.
"""

import pytest
import polars as pl
import numpy as np

# Importing src.models registers all built-in components.
from src.models import (
    PluginRegistry,
    PipelineStage,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    ProcessingPipeline,
)


# ========== Core abstract classes ==========


class TestPipelineStage:
    """Tests for the pipeline-stage enum."""

    def test_stage_values(self):
        assert PipelineStage.ALL.name == "ALL"
        assert PipelineStage.TRAIN.name == "TRAIN"
        assert PipelineStage.TEST.name == "TEST"
        assert PipelineStage.VALIDATION.name == "VALIDATION"


class TestBaseProcessor:
    """Tests for the processor base class."""

    def test_processor_initialization(self):
        """Processor initialization."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        processor = DummyProcessor(columns=["col1", "col2"])
        assert processor.columns == ["col1", "col2"]
        assert processor.stage == PipelineStage.ALL
        assert not processor._is_fitted

    def test_processor_fit_transform(self):
        """fit_transform fits and then transforms."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                result = data.clone()
                for col in self.columns or []:
                    result = result.with_columns((pl.col(col) + 1).alias(col))
                return result

        processor = AddOneProcessor(columns=["value"])
        df = pl.DataFrame({"value": [1, 2, 3]})

        result = processor.fit_transform(df)

        assert processor._is_fitted
        assert result["value"].to_list() == [2, 3, 4]


class TestBaseModel:
    """Tests for the model base class."""

    def test_model_initialization(self):
        """Model initialization."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(
            task_type="regression", params={"lr": 0.01}, name="test_model"
        )

        assert model.task_type == "regression"
        assert model.params == {"lr": 0.01}
        assert model.name == "test_model"
        assert not model._is_fitted

    def test_predict_proba_not_implemented(self):
        """predict_proba raises when it is not implemented."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(task_type="regression")
        df = pl.DataFrame({"feature": [1, 2, 3]})

        with pytest.raises(NotImplementedError):
            model.predict_proba(df)


class TestBaseSplitter:
    """Tests for the splitter base class."""

    def test_splitter_interface(self):
        """Splitter interface."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        splitter = DummySplitter()
        df = pl.DataFrame(
            {"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
        )

        splits = list(splitter.split(df))
        assert len(splits) == 1
        assert splits[0] == ([0, 1], [2, 3])

        dates = splitter.get_split_dates(df)
        assert dates == [("20200101", "20201231", "20210101", "20211231")]


# ========== Plugin registry ==========


class TestPluginRegistry:
    """Tests for the plugin registry."""

    def setup_method(self):
        """Clear all registrations before each test."""
        PluginRegistry.clear_all()

    def test_register_and_get_processor(self):
        """Register and retrieve a processor."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        processor_class = PluginRegistry.get_processor("test_processor")
        assert processor_class == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """Register and retrieve a model."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model_class = PluginRegistry.get_model("test_model")
        assert model_class == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """Register and retrieve a splitter."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        splitter_class = PluginRegistry.get_splitter("test_splitter")
        assert splitter_class == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Looking up an unknown processor raises KeyError."""
        with pytest.raises(KeyError) as exc_info:
            PluginRegistry.get_processor("nonexistent")
        assert "nonexistent" in str(exc_info.value)

    def test_register_with_default_name(self):
        """Registering without a name falls back to the class name."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        assert "MyCustomProcessor" in PluginRegistry.list_processors()


# ========== Built-in processors ==========


class TestBuiltInProcessors:
    """Tests for the built-in processors."""

    def test_dropna_processor(self):
        """DropNAProcessor removes rows with missing values."""
        from src.models.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})

        result = processor.fit_transform(df)

        # Only the first row has no missing values.
        assert len(result) == 1
        assert result["a"].to_list() == [1]
        assert result["b"].to_list() == [4]

    def test_fillna_processor(self):
        """FillNAProcessor fills missing values."""
        from src.models.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})

        result = processor.fit_transform(df)

        # Mean = (1 + 2 + 4) / 3 = 2.333...
        assert result["a"][2] == pytest.approx(2.333, rel=0.01)

    def test_standard_scaler(self):
        """StandardScaler z-scores the column."""
        from src.models.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})

        result = processor.fit_transform(df)

        # After z-scoring, the mean is 0 and the standard deviation is 1.
        assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["value"].std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Winsorizer clips the tails."""
        from src.models.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame({"value": list(range(100))})  # 0-99

        result = processor.fit_transform(df)

        # The 10% and 90% quantiles are 10 and 89 (Polars quantile behaviour).
        assert result["value"].min() == 10
        assert result["value"].max() == 89

    def test_rank_transformer(self):
        """RankTransformer converts values to cross-sectional ranks."""
        from src.models.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
            {"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
        )

        result = processor.fit_transform(df)

        # Ranks should be 1, 3, 2, 5, 4.
        assert result["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Neutralizer demeans within each group."""
        from src.models.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )

        result = processor.fit_transform(df)

        # After group demeaning, each group's mean is 0.
        group_a = result.filter(pl.col("industry") == "A")
        group_b = result.filter(pl.col("industry") == "B")

        assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)


# ========== Processing pipeline ==========


class TestProcessingPipeline:
    """Tests for the processing pipeline."""

    def test_pipeline_fit_transform(self):
        """fit_transform runs every processor in order."""
        from src.models.processors import StandardScaler

        scaler1 = StandardScaler(columns=["a"])
        scaler2 = StandardScaler(columns=["b"])

        pipeline = ProcessingPipeline([scaler1, scaler2])

        df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        result = pipeline.fit_transform(df)

        # Both columns should be standardized.
        assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """transform reuses the parameters captured during fit."""
        from src.models.processors import StandardScaler

        scaler = StandardScaler(columns=["value"])
        pipeline = ProcessingPipeline([scaler])

        # Training data: mean = 2, std = 1.
        train_df = pl.DataFrame({"value": [1.0, 2.0, 3.0]})

        # Test data with a different distribution (mean would be 5 if refit).
        test_df = pl.DataFrame({"value": [4.0, 5.0, 6.0]})

        pipeline.fit_transform(train_df)
        result = pipeline.transform(test_df)

        # Standardized with the training mean (2) and std (1): 4 -> (4-2)/1 = 2.
        assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)
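The last test pins down the anti-leakage contract: `transform` must reuse the statistics captured during `fit` instead of recomputing them on the test set. A stdlib sketch of that contract (`TinyStandardScaler` is an illustrative name, not the repo class):

```python
class TinyStandardScaler:
    """Fit on training data only; transform always uses the stored mean/std."""

    def fit(self, values):
        n = len(values)
        self.mean = sum(values) / n
        # Sample standard deviation (n - 1 denominator), matching Polars' std().
        var = sum((v - self.mean) ** 2 for v in values) / (n - 1)
        self.std = var ** 0.5
        return self

    def transform(self, values):
        # No recomputation here: test data is scaled with training statistics.
        return [(v - self.mean) / self.std for v in values]

scaler = TinyStandardScaler().fit([1.0, 2.0, 3.0])  # mean=2.0, std=1.0
print(scaler.transform([4.0, 5.0, 6.0]))            # → [2.0, 3.0, 4.0]
```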


# ========== Splitting strategies ==========


class TestSplitters:
    """Tests for the splitting strategies."""

    def test_time_series_split(self):
        """Time-series split."""
        from src.models.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)

        # Ten days of data.
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 11)],
                "value": list(range(10)),
            }
        )

        splits = list(splitter.split(df))

        # There should be two folds.
        assert len(splits) == 2

        # In every fold the training set precedes the test set.
        for train_idx, test_idx in splits:
            assert max(train_idx) < min(test_idx)

    def test_walk_forward_split(self):
        """Walk-forward split."""
        from src.models.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 13)],
                "value": list(range(12)),
            }
        )

        splits = list(splitter.split(df))

        # The training window has a fixed size.
        for train_idx, test_idx in splits:
            assert len(train_idx) == 5
            assert len(test_idx) == 2

    def test_expanding_window_split(self):
        """Expanding-window split."""
        from src.models.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 15)],
                "value": list(range(14)),
            }
        )

        splits = list(splitter.split(df))

        # The training set grows with each fold.
        train_sizes = [len(train_idx) for train_idx, _ in splits]
        assert train_sizes[0] == 3
        assert train_sizes[1] == 5  # 3 + 2
        assert train_sizes[2] == 7  # 5 + 2
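The walk-forward test expects a fixed-size training window, then `gap` skipped rows, then a fixed-size test window, rolled forward each fold. A stdlib sketch of one plausible version of that index arithmetic (rolling forward by `test_window` each fold is an assumption, not taken from the repo):

```python
def walk_forward_indices(n, train_window, test_window, gap=0):
    """Yield (train_idx, test_idx) pairs, rolling forward by test_window per fold."""
    start = 0
    while start + train_window + gap + test_window <= n:
        train_idx = list(range(start, start + train_window))
        test_start = start + train_window + gap
        test_idx = list(range(test_start, test_start + test_window))
        yield train_idx, test_idx
        start += test_window

splits = list(walk_forward_indices(12, train_window=5, test_window=2, gap=1))
# 3 folds; every fold has 5 train rows and 2 test rows, train always before test.
```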


# ========== Built-in models (optional; require extra dependencies) ==========


class TestModels:
    """Tests for the built-in models (skipped when dependencies are missing)."""

    @pytest.mark.skip(reason="requires lightgbm to be installed")
    def test_lightgbm_model(self):
        """LightGBM model."""
        from src.models.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)

        model.fit(X, y)
        predictions = model.predict(X)

        assert len(predictions) == len(X)
        assert model._is_fitted


if __name__ == "__main__":
    pytest.main([__file__, "-v"])