From 9f95be56a08e42d775da10afbdc25b2ce769765e Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Mon, 23 Feb 2026 01:37:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(models):=20=E5=AE=9E=E7=8E=B0=E6=9C=BA?= =?UTF-8?q?=E5=99=A8=E5=AD=A6=E4=B9=A0=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=A1=86=E6=9E=B6=20-=20=E6=B7=BB=E5=8A=A0=E6=A0=B8=E5=BF=83?= =?UTF-8?q?=E6=8A=BD=E8=B1=A1=EF=BC=9AProcessor=E3=80=81Model=E3=80=81Spli?= =?UTF-8?q?tter=E3=80=81Metric=20=E5=9F=BA=E7=B1=BB=20-=20=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E9=98=B6=E6=AE=B5=E6=84=9F=E7=9F=A5=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=EF=BC=88TRAIN/TEST/ALL=EF=BC=89=EF=BC=8C=E9=98=B2=E6=AD=A2?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=B3=84=E9=9C=B2=20-=20=E5=86=85=E7=BD=AE?= =?UTF-8?q?=208=20=E4=B8=AA=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86=E5=99=A8?= =?UTF-8?q?=E5=92=8C=203=20=E7=A7=8D=E6=97=B6=E5=BA=8F=E5=88=92=E5=88=86?= =?UTF-8?q?=E7=AD=96=E7=95=A5=20-=20=E6=94=AF=E6=8C=81=20LightGBM=E3=80=81?= =?UTF-8?q?CatBoost=20=E6=A8=A1=E5=9E=8B=20-=20PluginRegistry=20=E8=A3=85?= =?UTF-8?q?=E9=A5=B0=E5=99=A8=E6=B3=A8=E5=86=8C=EF=BC=8C=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=BC=8F=E6=9E=B6=E6=9E=84=20-=2022=20=E4=B8=AA=E5=8D=95?= =?UTF-8?q?=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 292 ++++- docs/factor_implementation_plan.md | 846 --------------- docs/hdf5_to_duckdb_migration.md | 13 +- docs/ml_framework_design.md | 1472 ++++++++++++++++++++++++++ docs/test_report_duckdb_migration.md | 2 + src/models/__init__.py | 86 ++ src/models/core/__init__.py | 30 + src/models/core/base.py | 351 ++++++ src/models/core/splitter.py | 222 ++++ src/models/models/__init__.py | 11 + src/models/models/models.py | 210 ++++ src/models/pipeline.py | 70 ++ src/models/processors/__init__.py | 21 + src/models/processors/processors.py | 238 +++++ src/models/registry.py | 297 ++++++ tests/models/test_core.py | 478 +++++++++ 16 files changed, 3774 insertions(+), 865 
deletions(-) delete mode 100644 docs/factor_implementation_plan.md create mode 100644 docs/ml_framework_design.md create mode 100644 src/models/__init__.py create mode 100644 src/models/core/__init__.py create mode 100644 src/models/core/base.py create mode 100644 src/models/core/splitter.py create mode 100644 src/models/models/__init__.py create mode 100644 src/models/models/models.py create mode 100644 src/models/pipeline.py create mode 100644 src/models/processors/__init__.py create mode 100644 src/models/processors/processors.py create mode 100644 src/models/registry.py create mode 100644 tests/models/test_core.py diff --git a/README.md b/README.md index 3954294..74022d3 100644 --- a/README.md +++ b/README.md @@ -1,40 +1,300 @@ # ProStock -A股量化投资框架 +A股量化投资框架 - 从数据获取到模型训练的完整解决方案 + +## 功能特性 + +### 1. 数据层 (src/data/) +- **多源数据接入**: Tushare API 集成,支持日线、股票基础信息、交易日历 +- **DuckDB 存储**: 高性能嵌入式数据库,支持 SQL 查询下推 +- **智能同步**: 增量/全量同步策略,自动检测数据更新需求 +- **速率控制**: 令牌桶算法实现 API 限流 +- **并发优化**: ThreadPoolExecutor 多线程数据获取 + +### 2. 因子层 (src/factors/) +- **类型安全**: 严格的截面因子 vs 时序因子区分 +- **防泄露机制**: 框架层面防止未来数据和跨股票数据泄露 +- **因子组合**: 支持因子加减乘除和标量运算 +- **高性能计算**: Polars 向量化操作,零拷贝数据导出 +- **灵活扩展**: 基类抽象便于自定义因子 + +### 3. 
模型层 (src/models/) +- **插件架构**: 装饰器注册机制,新模型即插即用 +- **阶段感知**: 训练/测试阶段区分,防止数据泄露 +- **多模型支持**: LightGBM、CatBoost 等模型统一接口 +- **数据处理**: 缺失值处理、缩尾、标准化、中性化等 +- **时序划分**: WalkForward、ExpandingWindow 等时间序列划分策略 + +## 项目结构 + +``` +ProStock/ +├── src/ +│ ├── config/ # 配置管理 +│ │ ├── settings.py # pydantic-settings 配置 +│ │ └── __init__.py +│ │ +│ ├── data/ # 数据获取与存储 +│ │ ├── api_wrappers/ # Tushare API 封装 +│ │ │ ├── api_daily.py # 日线数据接口 +│ │ │ ├── api_stock_basic.py # 股票基础信息 +│ │ │ └── api_trade_cal.py # 交易日历 +│ │ ├── client.py # Tushare 客户端(含限流) +│ │ ├── config.py # 数据模块配置 +│ │ ├── db_manager.py # DuckDB 表管理和同步 +│ │ ├── db_inspector.py # 数据库信息查看工具 +│ │ ├── rate_limiter.py # 令牌桶限流器 +│ │ ├── storage.py # DuckDB 存储核心 +│ │ ├── sync.py # 数据同步主逻辑 +│ │ └── __init__.py +│ │ +│ ├── factors/ # 因子计算框架 +│ │ ├── base.py # 因子基类(截面/时序) +│ │ ├── composite.py # 组合因子和标量运算 +│ │ ├── data_loader.py # DuckDB 数据加载器 +│ │ ├── data_spec.py # 数据规格定义 +│ │ ├── engine.py # 因子执行引擎 +│ │ └── __init__.py +│ │ +│ ├── models/ # 模型训练框架 +│ │ ├── core/ # 核心抽象 +│ │ │ ├── base.py # 处理器/模型/划分基类 +│ │ │ └── splitter.py # 时间序列划分策略 +│ │ ├── models/ # 模型实现 +│ │ │ └── models.py # LightGBM、CatBoost +│ │ ├── processors/ # 数据处理器 +│ │ │ └── processors.py # 标准化、缩尾、中性化等 +│ │ ├── pipeline.py # 处理流水线 +│ │ ├── registry.py # 插件注册中心 +│ │ └── __init__.py +│ │ +│ └── __init__.py +│ +├── docs/ # 文档 +│ ├── factor_framework_design.md # 因子框架设计 +│ ├── ml_framework_design.md # 模型框架设计 +│ ├── db_sync_guide.md # 数据同步指南 +│ └── ... +│ +├── data/ # 数据存储(DuckDB) +│ ├── prostock.db # 主数据库文件 +│ └── stock_basic.csv # 股票基础信息缓存 +│ +├── config/ # 配置文件 +│ └── .env.local # 环境变量(API Token等) +│ +└── tests/ # 测试文件 + ├── test_sync.py + └── factors/ +``` ## 快速开始 -### 安装依赖 +### 1. 安装依赖 -**⚠️ 本项目强制使用 uv 作为 Python 包管理器,禁止直接使用 `python` 或 `pip` 命令。** +**⚠️ 本项目强制使用 uv 作为 Python 包管理器** ```bash -# 使用 uv 安装(必须) +# 安装 uv (如果尚未安装) +pip install uv + +# 安装项目依赖 uv pip install -e . ``` -### 数据同步 +### 2. 
配置环境变量 + +创建 `config/.env.local` 文件: ```bash -# 增量同步(自动从最新日期开始) +TUSHARE_TOKEN=your_tushare_token_here +DATA_PATH=data +RATE_LIMIT=100 +THREADS=10 +``` + +### 3. 数据同步 + +```bash +# 首次同步 - 全量同步(从20180101开始) +uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" + +# 日常同步 - 增量同步(自动从最新日期开始) uv run python -c "from src.data.sync import sync_all; sync_all()" -# 全量同步(从 20180101 开始) -uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" +# 预览同步(检查需要同步的数据量) +uv run python -c "from src.data.sync import preview_sync; preview_sync()" # 自定义线程数 uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" ``` +### 4. 查看数据库状态 + +```bash +uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()" +``` + +## 使用示例 + +### 因子计算 + +```python +from src.factors import FactorEngine, DataLoader, DataSpec +from src.factors.base import CrossSectionalFactor, TimeSeriesFactor +import polars as pl + +# 自定义截面因子:PE排名 +class PERankFactor(CrossSectionalFactor): + name = "pe_rank" + data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)] + + def compute(self, data) -> pl.Series: + cs = data.get_cross_section() + return cs["pe"].rank() + +# 自定义时序因子:20日移动平均 +class MA20Factor(TimeSeriesFactor): + name = "ma20" + data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)] + + def compute(self, data) -> pl.Series: + return data.get_column("close").rolling_mean(window_size=20) + +# 执行计算 +loader = DataLoader(data_dir="data") +engine = FactorEngine(loader) + +# 计算截面因子 +pe_rank = PERankFactor() +result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131") + +# 计算时序因子 +ma20 = MA20Factor() +result2 = engine.compute(ma20, stock_codes=["000001.SZ"], + start_date="20240101", end_date="20240131") + +# 因子组合 +combined = 0.5 * pe_rank + 0.3 * ma20 +``` + +### 模型训练 + +```python +from src.models import PluginRegistry, ProcessingPipeline +from src.models.core 
import PipelineStage +import polars as pl + +# 创建处理流水线 +pipeline = ProcessingPipeline([ + PluginRegistry.get_processor("dropna")(), + PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), + PluginRegistry.get_processor("standard_scaler")(), +]) + +# 准备数据 +data = pl.read_csv("features.csv") # 包含特征和标签 + +# 划分训练/测试集 +from src.models.core import WalkForwardSplit +splitter = WalkForwardSplit(train_window=252, test_window=21) + +# 获取 LightGBM 模型 +ModelClass = PluginRegistry.get_model("lightgbm") +model = ModelClass(task_type="regression", params={"n_estimators": 100}) + +# 训练循环 +for train_idx, test_idx in splitter.split(data): + train_data = data[train_idx] + test_data = data[test_idx] + + # 数据处理 + X_train = pipeline.fit_transform(train_data.drop("target")) + X_test = pipeline.transform(test_data.drop("target")) + y_train = train_data["target"] + y_test = test_data["target"] + + # 训练模型 + model.fit(X_train, y_train) + predictions = model.predict(X_test) +``` + +## 核心设计 + +### 1. 数据防泄露机制 + +**截面因子 (CrossSectionalFactor)**: +- 防止日期泄露:每天只传入 `[T-lookback+1, T]` 数据 +- 允许股票间比较:传入当天所有股票数据 +- 典型应用:PE排名、市值分位数、当日收益率排名 + +**时序因子 (TimeSeriesFactor)**: +- 防止股票泄露:每只股票单独计算 +- 允许历史数据访问:传入完整时间序列 +- 典型应用:移动平均线、RSI、历史波动率 + +### 2. 插件注册机制 + +```python +from src.models.registry import PluginRegistry + +# 注册自定义处理器 +@PluginRegistry.register_processor("my_processor") +class MyProcessor(BaseProcessor): + stage = PipelineStage.TRAIN + + def fit(self, data): + # 学习参数 + return self + + def transform(self, data): + # 转换数据 + return data + +# 使用 +processor_class = PluginRegistry.get_processor("my_processor") +processor = processor_class() +``` + +### 3. 
数据同步策略 + +**智能增量同步**: +```python +from src.data.db_manager import SyncManager + +manager = SyncManager() +result = manager.sync( + table_name="daily", + fetch_func=get_daily, + start_date="20240101", + end_date="20240131" +) +# 自动检测:表不存在→全量,表存在→增量 +``` + ## 文档 -- [数据同步模块](docs/data_sync.md) - 详细的数据同步使用说明 +- [因子框架设计](docs/factor_framework_design.md) - 因子计算架构详解 +- [模型框架设计](docs/ml_framework_design.md) - 模型训练架构详解 +- [数据同步指南](docs/db_sync_guide.md) - DuckDB 数据同步 API 说明 +- [代码审查报告](docs/code_review_factors_20260222.md) - 因子框架代码审查 -## 模块 +## 开发规范 -- `data/` - 数据获取 -- `factors/` - 因子生成 -- `models/` - 模型训练 -- `backtest/` - 回测分析 -- `utils/` - 工具函数 -- `scripts/` - 运行脚本 +- **Python 版本**: 3.10+ +- **代码风格**: Google 风格文档字符串 +- **类型提示**: 强制类型注解 +- **测试**: pytest 框架 +- **包管理**: uv (禁止直接使用 pip/python) + +## 技术栈 + +- **数据处理**: Polars, Pandas, NumPy +- **数据存储**: DuckDB (嵌入式 OLAP 数据库) +- **API 接口**: Tushare Pro +- **机器学习**: LightGBM, CatBoost, scikit-learn +- **配置管理**: pydantic-settings + +## 许可证 + +MIT License diff --git a/docs/factor_implementation_plan.md b/docs/factor_implementation_plan.md deleted file mode 100644 index b966a0c..0000000 --- a/docs/factor_implementation_plan.md +++ /dev/null @@ -1,846 +0,0 @@ -# ProStock 因子框架实现计划 - -## 目录结构 - -``` -src/factors/ -├── __init__.py # 导出主要类 -├── data_spec.py # Phase 1: 数据类型定义 -├── base.py # Phase 2: 因子基类 -├── composite.py # Phase 2: 组合因子 -├── data_loader.py # Phase 3: 数据加载 -├── engine.py # Phase 4: 执行引擎 -└── builtin/ # Phase 5: 内置因子库 - ├── __init__.py - ├── momentum.py # 截面动量因子 - ├── technical.py # 时序技术指标 - └── value.py # 截面估值因子 - -tests/factors/ # Phase 6-7: 测试 -├── __init__.py -├── test_data_spec.py # 数据类型测试 -├── test_base.py # 因子基类测试 -├── test_composite.py # 组合因子测试 -├── test_data_loader.py # 数据加载测试 -├── test_engine.py # 引擎测试 -├── test_builtin.py # 内置因子测试 -└── test_integration.py # 集成测试 -``` - ---- - -## Phase 1: 数据类型定义 (data_spec.py) - -### 1.1 DataSpec - 数据需求规格 - -**实现要求:** -```python -@dataclass(frozen=True) -class DataSpec: - """ 
- 数据需求规格说明 - - Args: - source: H5 文件名(如 "daily", "fundamental") - columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date" - lookback_days: 需要回看的天数(包含当日) - - 1 表示只需要当日数据 [T] - - 5 表示需要 [T-4, T] 共5天 - - 20 表示需要 [T-19, T] 共20天 - """ - source: str - columns: List[str] - lookback_days: int = 1 -``` - -**约束验证:** -- `lookback_days >= 1`(至少包含当日) -- `columns` 必须包含 `ts_code` 和 `trade_date` -- `source` 不能为空字符串 - -**测试需求:** -- [ ] 测试有效 DataSpec 创建 -- [ ] 测试 `lookback_days < 1` 时抛出 ValueError -- [ ] 测试缺少 `ts_code` 或 `trade_date` 时抛出 ValueError -- [ ] 测试空 `source` 时抛出 ValueError -- [ ] 测试 frozen 特性(创建后不可修改) - ---- - -### 1.2 FactorContext - 计算上下文 - -**实现要求:** -```python -@dataclass -class FactorContext: - """ - 因子计算上下文 - - 由 FactorEngine 自动注入,因子开发者可通过 data.context 访问 - - Attributes: - current_date: 当前计算日期 YYYYMMDD(截面因子使用) - current_stock: 当前计算股票代码(时序因子使用) - trade_dates: 交易日历列表(可选,用于对齐) - """ - current_date: Optional[str] = None - current_stock: Optional[str] = None - trade_dates: Optional[List[str]] = None -``` - -**测试需求:** -- [ ] 测试默认值创建 -- [ ] 测试完整参数创建 -- [ ] 测试 dataclass 自动生成的方法 - ---- - -### 1.3 FactorData - 数据容器 - -**实现要求:** -```python -class FactorData: - """ - 提供给因子的数据容器 - - 封装底层 Polars DataFrame,提供安全的数据访问接口 - """ - - def __init__(self, df: pl.DataFrame, context: FactorContext): - self._df = df - self._context = context - - def get_column(self, col: str) -> pl.Series: - """ - 获取指定列的数据 - - - 截面因子:获取当天所有股票的该列值 - - 时序因子:获取该股票时间序列的该列值 - - Args: - col: 列名 - - Returns: - Polars Series - - Raises: - KeyError: 列不存在 - """ - pass - - def filter_by_date(self, date: str) -> "FactorData": - """ - 按日期过滤数据,返回新的 FactorData - - 主要用于截面因子获取特定日期的数据 - - Args: - date: YYYYMMDD 格式的日期 - - Returns: - 过滤后的 FactorData - """ - pass - - def get_cross_section(self) -> pl.DataFrame: - """ - 获取当前日期的截面数据 - - 仅适用于截面因子,返回 current_date 当天的所有股票数据 - - Returns: - DataFrame 包含当前日期的所有股票 - - Raises: - ValueError: current_date 未设置(非截面因子场景) - """ - pass - - def to_polars(self) -> pl.DataFrame: - """获取底层的 Polars 
DataFrame(高级用法)""" - pass - - @property - def context(self) -> FactorContext: - """获取计算上下文""" - pass - - def __len__(self) -> int: - """返回数据行数""" - pass -``` - -**测试需求:** -- [ ] 测试 `get_column()` 返回正确 Series -- [ ] 测试 `get_column()` 列不存在时抛出 KeyError -- [ ] 测试 `filter_by_date()` 返回正确过滤结果 -- [ ] 测试 `filter_by_date()` 日期不存在时返回空 DataFrame -- [ ] 测试 `get_cross_section()` 返回 current_date 当天的数据 -- [ ] 测试 `get_cross_section()` current_date 为 None 时抛出 ValueError -- [ ] 测试 `to_polars()` 返回原始 DataFrame -- [ ] 测试 `context` 属性返回正确上下文 -- [ ] 测试 `__len__()` 返回正确行数 - ---- - -## Phase 2: 因子基类 (base.py, composite.py) - -### 2.1 BaseFactor - 抽象基类 - -**实现要求:** -```python -class BaseFactor(ABC): - """ - 因子基类 - 定义通用接口 - - 所有因子必须继承此类,并声明以下类属性: - - name: 因子唯一标识(snake_case) - - factor_type: "cross_sectional" 或 "time_series" - - data_specs: List[DataSpec] 数据需求列表 - - 可选声明: - - category: 因子分类(默认 "default") - - description: 因子描述 - """ - - # 必须声明的类属性 - name: str = "" - factor_type: str = "" # "cross_sectional" | "time_series" - data_specs: List[DataSpec] = field(default_factory=list) - - # 可选声明的类属性 - category: str = "default" - description: str = "" - - def __init_subclass__(cls, **kwargs): - """ - 子类创建时验证必须属性 - - 验证项: - 1. name 必须是非空字符串 - 2. factor_type 必须是 "cross_sectional" 或 "time_series" - 3. 
data_specs 必须是非空列表 - """ - pass - - def __init__(self, **params): - """ - 初始化因子参数 - - 子类可通过 __init__ 接收参数化配置,如 MA(period=20) - """ - self.params = params - self._validate_params() - - def _validate_params(self): - """ - 验证参数有效性 - - 子类可覆盖此方法进行自定义验证 - """ - pass - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """ - 核心计算逻辑 - 子类必须实现 - - Args: - data: 安全的数据容器,已根据因子类型裁剪 - - Returns: - 计算得到的因子值 Series - """ - pass - - # ========== 因子组合运算符 ========== - - def __add__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相加:f1 + f2(要求同类型)""" - pass - - def __sub__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相减:f1 - f2(要求同类型)""" - pass - - def __mul__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相乘:f1 * f2(要求同类型)""" - pass - - def __truediv__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相除:f1 / f2(要求同类型)""" - pass - - def __rmul__(self, scalar: float) -> "ScalarFactor": - """标量乘法:0.5 * f1""" - pass -``` - -**测试需求:** -- [ ] 测试有效子类创建通过验证 -- [ ] 测试缺少 `name` 时抛出 ValueError -- [ ] 测试 `name` 为空字符串时抛出 ValueError -- [ ] 测试缺少 `factor_type` 时抛出 ValueError -- [ ] 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError -- [ ] 测试缺少 `data_specs` 时抛出 ValueError -- [ ] 测试 `data_specs` 为空列表时抛出 ValueError -- [ ] 测试 `compute()` 抽象方法强制子类实现 -- [ ] 测试参数化初始化 `params` 正确存储 -- [ ] 测试 `_validate_params()` 被调用 - ---- - -### 2.2 CrossSectionalFactor - 日期截面因子 - -**实现要求:** -```python -class CrossSectionalFactor(BaseFactor): - """ - 日期截面因子基类 - - 计算逻辑:在每个交易日,对所有股票进行横向计算 - - 防泄露边界: - - ❌ 禁止访问未来日期的数据(日期泄露) - - ✅ 允许访问当前日期的所有股票数据 - - 数据传入: - - compute() 接收的是 [T-lookback+1, T] 的数据 - - 包含 lookback_days 的历史数据(用于时序计算后再截面) - """ - - factor_type: str = "cross_sectional" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """ - 计算截面因子值 - - Args: - data: FactorData,包含 [T-lookback+1, T] 的截面数据 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] 
- - Returns: - pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量) - - 示例: - def compute(self, data): - # 获取当前日期的截面 - cs = data.get_cross_section() - # 计算市值排名 - return cs['market_cap'].rank() - """ - pass -``` - -**测试需求:** -- [ ] 测试 `factor_type` 自动设置为 "cross_sectional" -- [ ] 测试子类必须实现 `compute()` -- [ ] 测试 `compute()` 返回类型为 pl.Series - ---- - -### 2.3 TimeSeriesFactor - 时间序列因子 - -**实现要求:** -```python -class TimeSeriesFactor(BaseFactor): - """ - 时间序列因子基类(股票截面) - - 计算逻辑:对每只股票,在其时间序列上进行纵向计算 - - 防泄露边界: - - ❌ 禁止访问其他股票的数据(股票泄露) - - ✅ 允许访问该股票的完整历史数据 - - 数据传入: - - compute() 接收的是单只股票的完整时间序列 - - 包含该股票在 [start_date, end_date] 范围内的所有数据 - """ - - factor_type: str = "time_series" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """ - 计算时间序列因子值 - - Args: - data: FactorData,包含单只股票的完整时间序列 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] - - Returns: - pl.Series: 该股票在各日期的因子值(长度 = 日期数量) - - 示例: - def compute(self, data): - series = data.get_column("close") - return series.rolling_mean(window_size=self.params['period']) - """ - pass -``` - -**测试需求:** -- [ ] 测试 `factor_type` 自动设置为 "time_series" -- [ ] 测试子类必须实现 `compute()` -- [ ] 测试 `compute()` 返回类型为 pl.Series - ---- - -### 2.4 CompositeFactor - 组合因子 (composite.py) - -**实现要求:** -```python -class CompositeFactor(BaseFactor): - """ - 组合因子 - 用于实现因子间的数学运算 - - 约束:左右因子必须是同类型(同为截面或同为时序) - """ - - def __init__(self, left: BaseFactor, right: BaseFactor, op: str): - """ - 创建组合因子 - - Args: - left: 左操作数因子 - right: 右操作数因子 - op: 运算符,支持 '+', '-', '*', '/' - - Raises: - ValueError: 左右因子类型不一致 - ValueError: 不支持的运算符 - """ - pass - - def _merge_data_specs(self) -> List[DataSpec]: - """ - 合并左右因子的数据需求 - - 策略: - 1. 相同 source 和 columns 的 DataSpec 合并 - 2. lookback_days 取最大值 - """ - pass - - def compute(self, data: FactorData) -> pl.Series: - """ - 执行组合运算 - - 流程: - 1. 分别计算 left 和 right 的值 - 2. 根据 op 执行运算 - 3. 
返回结果 - """ - pass -``` - -**测试需求:** -- [ ] 测试同类型因子组合成功(cs + cs) -- [ ] 测试同类型因子组合成功(ts + ts) -- [ ] 测试不同类型因子组合抛出 ValueError(cs + ts) -- [ ] 测试无效运算符抛出 ValueError -- [ ] 测试 `_merge_data_specs()` 正确合并(相同 source) -- [ ] 测试 `_merge_data_specs()` 正确合并(不同 source) -- [ ] 测试 `_merge_data_specs()` lookback 取最大值 -- [ ] 测试 `compute()` 执行正确的数学运算 - ---- - -### 2.5 ScalarFactor - 标量运算因子 (composite.py) - -**实现要求:** -```python -class ScalarFactor(BaseFactor): - """ - 标量运算因子 - - 支持:scalar * factor, factor * scalar(通过 __rmul__) - """ - - def __init__(self, factor: BaseFactor, scalar: float, op: str): - """ - 创建标量运算因子 - - Args: - factor: 基础因子 - scalar: 标量值 - op: 运算符,支持 '*', '+' - """ - pass - - def compute(self, data: FactorData) -> pl.Series: - """执行标量运算""" - pass -``` - -**测试需求:** -- [ ] 测试标量乘法 `0.5 * factor` -- [ ] 测试标量乘法 `factor * 0.5` -- [ ] 测试标量加法(如支持) -- [ ] 测试继承基础因子的 data_specs -- [ ] 测试 `compute()` 返回正确缩放后的值 - ---- - -## Phase 3: 数据加载 (data_loader.py) - -### 3.1 DataLoader - 数据加载器 - -**实现要求:** -```python -class DataLoader: - """ - 数据加载器 - 负责从 HDF5 安全加载数据 - - 功能: - 1. 多文件聚合:合并多个 H5 文件的数据 - 2. 列选择:只加载需要的列 - 3. 原始数据缓存:避免重复读取 - """ - - def __init__(self, data_dir: str): - """ - 初始化 DataLoader - - Args: - data_dir: HDF5 文件所在目录 - """ - self.data_dir = Path(data_dir) - self._cache: Dict[str, pl.DataFrame] = {} - - def load( - self, - specs: List[DataSpec], - date_range: Optional[Tuple[str, str]] = None - ) -> pl.DataFrame: - """ - 加载并聚合多个 H5 文件的数据 - - 流程: - 1. 对每个 DataSpec: - a. 检查缓存,命中则直接使用 - b. 未命中则读取 HDF5(通过 pandas) - c. 转换为 Polars DataFrame - d. 按 date_range 过滤 - e. 存入缓存 - 2. 
合并多个 DataFrame(按 trade_date 和 ts_code join) - - Args: - specs: 数据需求规格列表 - date_range: 日期范围限制 (start_date, end_date),可选 - - Returns: - 合并后的 Polars DataFrame - - Raises: - FileNotFoundError: H5 文件不存在 - KeyError: 列不存在于文件中 - """ - pass - - def clear_cache(self): - """清空缓存""" - pass - - def _read_h5(self, source: str) -> pl.DataFrame: - """ - 读取单个 H5 文件 - - 实现:使用 pandas.read_hdf(),然后 pl.from_pandas() - """ - pass -``` - -**测试需求:** -- [ ] 测试从单个 H5 文件加载数据 -- [ ] 测试从多个 H5 文件加载并合并 -- [ ] 测试列选择(只加载需要的列) -- [ ] 测试缓存机制(第二次加载更快) -- [ ] 测试 `clear_cache()` 清空缓存 -- [ ] 测试按 date_range 过滤 -- [ ] 测试文件不存在时抛出 FileNotFoundError -- [ ] 测试列不存在时抛出 KeyError - ---- - -## Phase 4: 执行引擎 (engine.py) - -### 4.1 FactorEngine - 因子执行引擎 - -**实现要求:** -```python -class FactorEngine: - """ - 因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略 - - 核心职责: - 1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据 - 2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 - """ - - def __init__(self, data_loader: DataLoader): - """ - 初始化引擎 - - Args: - data_loader: 数据加载器实例 - """ - self.data_loader = data_loader - - def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame: - """ - 统一的计算入口 - - 根据 factor_type 分发到具体方法: - - "cross_sectional" -> _compute_cross_sectional() - - "time_series" -> _compute_time_series() - - Args: - factor: 要计算的因子 - **kwargs: 额外参数,根据因子类型不同: - - 截面因子: start_date, end_date - - 时序因子: stock_codes, start_date, end_date - - Returns: - DataFrame[trade_date, ts_code, factor_name] - """ - pass -``` - -**测试需求:** -- [ ] 测试 `compute()` 正确分发给截面计算 -- [ ] 测试 `compute()` 正确分发给时序计算 -- [ ] 测试无效 factor_type 时抛出 ValueError - ---- - -### 4.2 截面计算(防止日期泄露) - -**实现要求:** -```python -def _compute_cross_sectional( - self, - factor: CrossSectionalFactor, - start_date: str, - end_date: str -) -> pl.DataFrame: - """ - 执行日期截面计算 - - 防泄露策略: - - 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来) - - 允许股票间比较:传入当天所有股票的数据 - - 计算流程: - 1. 计算 max_lookback,确定数据起始日期 - 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 - 3. 对每个日期 T in [start_date, end_date]: - a. 
裁剪数据到 [T-lookback+1, T] - b. 创建 FactorData(current_date=T) - c. 调用 factor.compute() - d. 收集结果 - 4. 合并所有日期的结果 - - 返回 DataFrame 格式: - ┌────────────┬──────────┬──────────────┐ - │ trade_date │ ts_code │ factor_name │ - ├────────────┼──────────┼──────────────┤ - │ 20240101 │ 000001.SZ│ 0.5 │ - │ 20240101 │ 000002.SZ│ 0.3 │ - └────────────┴──────────┴──────────────┘ - """ - pass -``` - -**测试需求(防泄露验证):** -- [ ] 测试数据裁剪正确(传入 [T-lookback+1, T]) -- [ ] 测试不包含未来日期 T+1 的数据 -- [ ] 测试每个日期独立计算 -- [ ] 测试结果包含所有日期和所有股票 -- [ ] 测试结果 DataFrame 格式正确 -- [ ] 测试多个 DataSpec 时 lookback 取最大值 - ---- - -### 4.3 时序计算(防止股票泄露) - -**实现要求:** -```python -def _compute_time_series( - self, - factor: TimeSeriesFactor, - stock_codes: List[str], - start_date: str, - end_date: str -) -> pl.DataFrame: - """ - 执行时间序列计算 - - 防泄露策略: - - 防止股票泄露:每只股票单独计算,传入该股票的完整序列 - - 允许访问历史数据:时序计算需要历史数据 - - 计算流程: - 1. 计算 max_lookback,确定数据起始日期 - 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 - 3. 对每只股票 S in stock_codes: - a. 过滤出 S 的数据(防止股票泄露) - b. 创建 FactorData(current_stock=S) - c. 调用 factor.compute()(向量化计算整个序列) - d. 收集结果 - 4. 合并所有股票的结果 - - 性能优势: - - 使用 Polars 的 rolling_mean 等向量化操作 - - 每只股票只计算一次,无重复计算 - - 返回 DataFrame 格式: - ┌────────────┬──────────┬──────────────┐ - │ trade_date │ ts_code │ factor_name │ - ├────────────┼──────────┼──────────────┤ - │ 20240101 │ 000001.SZ│ 10.5 │ - │ 20240102 │ 000001.SZ│ 10.6 │ - └────────────┴──────────┴──────────────┘ - """ - pass -``` - -**测试需求(防泄露验证):** -- [ ] 测试每只股票只看到自己的数据 -- [ ] 测试不包含其他股票的数据 -- [ ] 测试传入的是完整时间序列(向量化计算) -- [ ] 测试结果包含所有股票和所有日期 -- [ ] 测试结果 DataFrame 格式正确 -- [ ] 测试股票不在数据中时跳过(或填充 null) - ---- - -## Phase 5: 内置因子库 (builtin/) - -### 5.1 momentum.py - 截面动量因子 - -**实现因子:** - -1. 
**ReturnRankFactor** - 当日收益率排名 -```python -class ReturnRankFactor(CrossSectionalFactor): - """当日收益率排名因子""" - name = "return_rank" - data_specs = [DataSpec("daily", ["close"], lookback_days=2)] # 需要2天计算收益率 - - def compute(self, data): - # 获取当前日期截面 - cs = data.get_cross_section() - # 需要前1天和当天的收盘价,lookback=2 保证数据包含 [T-1, T] - # 这里假设 data 已经包含历史,实际计算需要 groupby 处理 - pass -``` - -**测试需求:** -- [ ] 测试收益率计算正确 -- [ ] 测试排名计算正确 -- [ ] 测试无数据时返回 null - -2. **MomentumFactor** - 过去 N 日涨幅排名 - ---- - -### 5.2 technical.py - 时序技术指标 - -**实现因子:** - -1. **MovingAverageFactor** - 移动平均线 -```python -class MovingAverageFactor(TimeSeriesFactor): - """移动平均线因子""" - name = "ma" - - def __init__(self, period: int = 20): - super().__init__(period=period) - self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)] - - def compute(self, data): - return data.get_column("close").rolling_mean(self.params["period"]) -``` - -**测试需求:** -- [ ] 测试 MA20 计算正确 -- [ ] 测试前19天返回 null(Polars 默认行为) -- [ ] 测试参数 period 生效 - -2. **RSIFactor** - RSI 指标 -3. **MACDFactor** - MACD 指标 - ---- - -### 5.3 value.py - 截面估值因子 - -**实现因子:** -1. **PERankFactor** - PE 行业分位数 -2. **PBFactor** - PB 排名 - ---- - -## Phase 6-7: 测试策略 - -### 测试金字塔 - -``` - /\ - / \ - / 集成\ tests/factors/test_integration.py - /────────\ - / 引擎 \ tests/factors/test_engine.py - /────────────\ - / 基类/组合因子 \ tests/factors/test_base.py, test_composite.py - /────────────────\ - / 数据加载/类型 \ tests/factors/test_data_loader.py, test_data_spec.py -/──────────────────────\ -``` - -### 测试数据准备 - -创建 `tests/fixtures/` 目录,包含: -- `sample_daily.h5`: 少量股票的日线数据(用于测试) -- `sample_fundamental.h5`: 基本面数据 - -### 关键测试场景 - -1. **防泄露测试(核心)** - - 截面因子:验证 compute() 中无法访问未来日期 - - 时序因子:验证 compute() 中无法访问其他股票 - -2. **边界测试** - - lookback_days = 1(最小值) - - 数据起始点(前 N 天为 null) - - 空数据/停牌处理 - -3. 
**性能测试(可选)** - - 大数据量下的内存占用 - - 缓存命中率 - ---- - -## 实现状态 - -| Phase | 状态 | 完成日期 | 测试覆盖 | -|-------|------|----------|----------| -| Phase 1: 数据类型定义 | ✅ 已完成 | 2026-02-21 | 27 tests passed | -| Phase 2: 因子基类 | ✅ 已完成 | 2026-02-21 | 49 tests passed | -| Phase 3: 数据加载 | ✅ 已完成 | 2026-02-21 | 11 tests passed | -| Phase 4: 执行引擎 | ✅ 已完成 | 2026-02-22 | 10 tests passed | -| Phase 5: 内置因子库 | 📝 待开发 | - | - | -| Phase 6-7: 测试文档 | ✅ 已完成 | 2026-02-22 | 76 tests total | - ---- - -## 实现顺序建议 - -1. **Week 1**: Phase 1-2(数据类型 + 基类) -2. **Week 2**: Phase 3-4(DataLoader + Engine)✅ **已完成** -3. **Week 3**: Phase 5(内置因子) -4. **Week 4**: Phase 6-7(测试 + 文档) - -每个 Phase 完成后运行对应测试,确保质量。 diff --git a/docs/hdf5_to_duckdb_migration.md b/docs/hdf5_to_duckdb_migration.md index 47c7c14..a03a96c 100644 --- a/docs/hdf5_to_duckdb_migration.md +++ b/docs/hdf5_to_duckdb_migration.md @@ -1,10 +1,17 @@ -# ProStock HDF5 到 DuckDB 迁移方案与计划 +# ProStock HDF5 到 DuckDB 迁移方案 -**文档版本**: v1.0 +**文档版本**: v1.1 **创建日期**: 2026-02-22 -**状态**: 待审批 +**完成日期**: 2026-02-22 +**状态**: ✅ 已完成 **影响范围**: data 模块、factors 模块、相关文档 +## 相关文档 + + [DuckDB 数据同步指南](./db_sync_guide.md) - 同步 API 使用说明 + [迁移测试报告](./test_report_duckdb_migration.md) - 测试验证结果 + + --- ## 目录 diff --git a/docs/ml_framework_design.md b/docs/ml_framework_design.md new file mode 100644 index 0000000..def2c50 --- /dev/null +++ b/docs/ml_framework_design.md @@ -0,0 +1,1472 @@ +# ProStock 模型训练框架设计文档 + +## 1. 设计目标与原则 + +### 1.1 核心目标 +- **组件化**:每个阶段(数据获取、处理、训练、评估)都是独立组件 +- **低耦合**:组件间通过标准接口交互,不依赖具体实现 +- **插件式**:新功能通过插件注册,无需修改核心代码 +- **阶段感知**:数据处理区分训练阶段和测试阶段,防止数据泄露 +- **多模型支持**:统一接口支持 LightGBM、CatBoost 等多种模型 +- **多任务支持**:分类、回归、排序三种任务类型 + +### 1.2 设计原则 + +| 原则 | 说明 | +|------|------| +| **单一职责** | 每个组件只做一件事,做好一件事 | +| **开闭原则** | 对扩展开放(插件),对修改封闭(核心) | +| **依赖倒置** | 依赖抽象接口,而非具体实现 | +| **显式优于隐式** | 阶段标记、处理逻辑必须显式声明 | +| **配置驱动** | 通过配置文件或代码配置定义流程,减少硬编码 | + +--- + +## 2. 
整体架构 + +### 2.1 架构概览 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ ML Pipeline Orchestrator │ +│ (流水线编排器 - 配置驱动执行) │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ┌───────────────────────────┼───────────────────────────┐ + ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ Data Source │ │ Data Source │ │ Data Source │ +│ (因子数据) │ │ (行情数据) │ │ (标签数据) │ +└───────┬───────┘ └───────┬───────┘ └───────┬───────┘ + │ │ │ + └──────────────────────────┼──────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Feature Store (特征存储层) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ FactorLoader │ │ LabelLoader │ │ DataMerger │ │ CacheMgr │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Processing Pipeline (处理流水线) │ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ +│ │ Processor │ -> │ Processor │ -> │ Processor │ -> │ ... 
│ │ +│ │ (阶段:ALL) │ │ (阶段:TRAIN)│ │ (阶段:TEST) │ │ │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ └──────────┘ │ +│ │ +│ 处理器类型: │ +│ - FeatureEncoder: 特征编码(类别编码、数值缩放等) │ +│ - FeatureSelector: 特征选择(相关性过滤、重要性筛选等) │ +│ - OutlierHandler: 异常值处理 │ +│ - MissingValueHandler: 缺失值处理 │ +│ - CustomTransformer: 自定义转换器 │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Train/Test Split (数据划分) │ +│ │ +│ 支持多种划分策略: │ +│ - TimeSeriesSplit: 时间序列划分(防止未来泄露) │ +│ - PurgedKFold: 清除重叠样本的K折交叉验证 │ +│ - EmbargoSplit: embargo 延迟验证 │ +│ - CustomSplit: 自定义划分策略 │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Model Training (模型训练层) │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────┐ │ +│ │ Model Registry │ │ +│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ +│ │ │ LightGBM │ │CatBoost │ │ XGBoost │ │ Custom │ ... 
│ │ +│ │ │ Model │ │ Model │ │ Model │ │ Model │ │ │ +│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │ +│ └─────────────────────────────────────────────────────────────────┘ │ +│ │ +│ 任务类型: │ +│ - Classification: 分类任务(上涨/下跌预测) │ +│ - Regression: 回归任务(收益率预测) │ +│ - Ranking: 排序任务(股票排序/选股) │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Evaluation (评估层) │ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌────────────┐ │ +│ │ Metric │ │ Metric │ │ Metric │ │ Analyzer │ │ +│ │ (IC/IR) │ │ (Sharpe) │ │ (Accuracy) │ │ (回测) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ └────────────┘ │ +│ │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ ResultStore │ │ Report │ │ Visualizer │ │ +│ │ (模型存储) │ │ (报告生成) │ │ (可视化) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +### 2.2 数据流向图 + +``` +因子DataFrame (Polars) + │ + ▼ +┌──────────────────────┐ +│ Feature Store │ 1. 加载并合并因子、标签、辅助数据 +│ - 列选择 │ 2. 支持按日期/股票过滤 +│ - 数据对齐 │ 3. 缓存机制避免重复加载 +└──────────┬───────────┘ + │ + ▼ +┌──────────────────────┐ +│ Processing Pipeline │ 顺序执行多个处理器 +│ │ 每个处理器标记适用阶段 (ALL/TRAIN/TEST) +│ for processor in pipeline: +│ if processor.stage in [current_stage, ALL]: +│ data = processor.transform(data) +└──────────┬───────────┘ + │ + ▼ +┌──────────────────────┐ +│ Data Splitter │ 时间序列感知的划分策略 +│ - X_train, y_train │ 防止未来泄露 +│ - X_test, y_test │ +└──────────┬───────────┘ + │ + ▼ +┌──────────────────────┐ +│ Model Training │ 统一接口,支持多种模型 +│ - fit(X_train) │ 任务类型: classification/regression/ranking +│ - predict(X_test) │ +└──────────┬───────────┘ + │ + ▼ +┌──────────────────────┐ +│ Evaluation │ 多维度评估 +│ - 预测指标 │ - IC/IR +│ - 回测指标 │ - 分组收益 +│ - 可视化 │ - 累计收益曲线 +└──────────────────────┘ +``` + +--- + +## 3. 
核心组件设计 + +### 3.1 基础抽象类 + +#### 3.1.1 PipelineStage (流水线阶段枚举) + +```python +from enum import Enum, auto + +class PipelineStage(Enum): + """流水线阶段标记""" + ALL = auto() # 适用于所有阶段 + TRAIN = auto() # 仅训练阶段 + TEST = auto() # 仅测试阶段 + VALIDATION = auto() # 仅验证阶段 +``` + +#### 3.1.2 BaseProcessor (处理器基类) + +```python +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +import polars as pl + +class BaseProcessor(ABC): + """数据处理器基类 + + 所有数据处理器必须继承此类。 + 关键特性:通过 stage 属性控制处理器在哪些阶段生效。 + + 示例: + >>> class StandardScaler(BaseProcessor): + ... stage = PipelineStage.ALL # 训练和测试都使用 + ... + ... def fit(self, data: pl.DataFrame) -> None: + ... self.mean = data[self.columns].mean() + ... self.std = data[self.columns].std() + ... + ... def transform(self, data: pl.DataFrame) -> pl.DataFrame: + ... return (data - self.mean) / self.std + """ + + # 子类必须定义适用阶段 + stage: PipelineStage = PipelineStage.ALL + + def __init__(self, columns: Optional[list] = None, **params): + """初始化处理器 + + Args: + columns: 要处理的列,None表示所有数值列 + **params: 处理器特定参数 + """ + self.columns = columns + self.params = params + self._is_fitted = False + self._fitted_params: Dict[str, Any] = {} + + @abstractmethod + def fit(self, data: pl.DataFrame) -> "BaseProcessor": + """在训练数据上学习参数 + + 此方法只在训练阶段调用一次。 + 学习到的参数存储在 self._fitted_params 中。 + + Args: + data: 训练数据 + + Returns: + self (支持链式调用) + """ + pass + + @abstractmethod + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + """转换数据 + + 在训练和测试阶段都会被调用。 + 使用 fit() 阶段学习到的参数进行转换。 + + Args: + data: 输入数据 + + Returns: + 转换后的数据 + """ + pass + + def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame: + """先fit再transform的便捷方法""" + return self.fit(data).transform(data) + + def get_fitted_params(self) -> Dict[str, Any]: + """获取学习到的参数(用于保存/加载)""" + return self._fitted_params.copy() + + def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor": + """设置学习到的参数(用于从checkpoint恢复)""" + self._fitted_params = params.copy() + self._is_fitted = True + 
return self +``` + +#### 3.1.3 BaseModel (模型基类) + +```python +from abc import ABC, abstractmethod +from typing import Literal, Any, Dict +import polars as pl +import numpy as np + +TaskType = Literal["classification", "regression", "ranking"] + +class BaseModel(ABC): + """机器学习模型基类 + + 统一接口支持多种模型(LightGBM, CatBoost, XGBoost等) + 和多种任务类型(分类、回归、排序)。 + + 示例: + >>> model = LightGBMModel( + ... task_type="classification", + ... params={"n_estimators": 100} + ... ) + >>> model.fit(X_train, y_train) + >>> predictions = model.predict(X_test) + """ + + def __init__( + self, + task_type: TaskType, + params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None + ): + """初始化模型 + + Args: + task_type: 任务类型 - "classification", "regression", "ranking" + params: 模型特定参数 + name: 模型名称(用于日志和报告) + """ + self.task_type = task_type + self.params = params or {} + self.name = name or self.__class__.__name__ + self._model: Any = None + self._is_fitted = False + + @abstractmethod + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params + ) -> "BaseModel": + """训练模型 + + Args: + X: 特征数据 + y: 目标变量 + X_val: 验证集特征(可选) + y_val: 验证集目标(可选) + **fit_params: 额外的fit参数 + + Returns: + self (支持链式调用) + """ + pass + + @abstractmethod + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测 + + Args: + X: 特征数据 + + Returns: + 预测结果数组 + - classification: 类别标签或概率 + - regression: 连续值 + - ranking: 排序分数 + """ + pass + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + """预测概率(仅分类任务) + + Args: + X: 特征数据 + + Returns: + 类别概率数组 [n_samples, n_classes] + """ + raise NotImplementedError("predict_proba only available for classification tasks") + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + """获取特征重要性(如果模型支持) + + Returns: + DataFrame[feature, importance] 或 None + """ + return None + + def save(self, path: str) -> None: + """保存模型到文件""" + import pickle + with open(path, 'wb') as f: + pickle.dump(self, 
f) + + @classmethod + def load(cls, path: str) -> "BaseModel": + """从文件加载模型""" + import pickle + with open(path, 'rb') as f: + return pickle.load(f) +``` + +#### 3.1.4 BaseSplitter (数据划分基类) + +```python +from abc import ABC, abstractmethod +from typing import Iterator, Tuple, List +import polars as pl + +class BaseSplitter(ABC): + """数据划分策略基类 + + 针对时间序列数据的特殊划分策略,防止未来泄露。 + + 示例: + >>> splitter = TimeSeriesSplit(n_splits=5, gap=5) + >>> for train_idx, test_idx in splitter.split(data): + ... X_train, X_test = X[train_idx], X[test_idx] + """ + + @abstractmethod + def split( + self, + data: pl.DataFrame, + date_col: str = "trade_date" + ) -> Iterator[Tuple[List[int], List[int]]]: + """生成训练/测试索引 + + Args: + data: 完整数据集 + date_col: 日期列名 + + Yields: + (train_indices, test_indices) 元组 + """ + pass + + @abstractmethod + def get_split_dates( + self, + data: pl.DataFrame, + date_col: str = "trade_date" + ) -> List[Tuple[str, str, str, str]]: + """获取划分日期范围 + + Returns: + [(train_start, train_end, test_start, test_end), ...] + """ + pass +``` + +--- + +### 3.2 核心组件 + +#### 3.2.1 FeatureStore (特征存储) + +```python +from typing import List, Optional, Dict +import polars as pl +from pathlib import Path + +class FeatureStore: + """特征存储管理器 + + 负责加载、合并、缓存因子数据。 + 支持从多个数据源(因子、标签、行情)加载并合并。 + """ + + def __init__(self, data_dir: str): + self.data_dir = Path(data_dir) + self._cache: Dict[str, pl.DataFrame] = {} + + def load_factors( + self, + factor_names: List[str], + start_date: Optional[str] = None, + end_date: Optional[str] = None, + stock_codes: Optional[List[str]] = None + ) -> pl.DataFrame: + """加载因子数据 + + Args: + factor_names: 因子名称列表 + start_date: 开始日期 YYYYMMDD + end_date: 结束日期 YYYYMMDD + stock_codes: 股票代码列表(可选) + + Returns: + DataFrame[trade_date, ts_code, factor1, factor2, ...] 
+ """ + pass + + def load_labels( + self, + label_name: str, + forward_period: int = 5, + start_date: Optional[str] = None, + end_date: Optional[str] = None + ) -> pl.DataFrame: + """加载标签数据(未来收益) + + Args: + label_name: 标签名称(如 "return", "rank") + forward_period: 前瞻期(如5天后收益) + start_date: 开始日期 + end_date: 结束日期 + + Returns: + DataFrame[trade_date, ts_code, label] + """ + pass + + def build_dataset( + self, + factor_names: List[str], + label_config: Dict, + date_range: Tuple[str, str], + stock_codes: Optional[List[str]] = None, + additional_cols: Optional[List[str]] = None + ) -> pl.DataFrame: + """构建完整数据集 + + 合并因子、标签、辅助列,并对齐数据。 + + Args: + factor_names: 因子列表 + label_config: 标签配置 {"name": str, "forward_period": int} + date_range: (start_date, end_date) + stock_codes: 限定股票列表 + additional_cols: 额外列(如 industry, market_cap) + + Returns: + DataFrame[trade_date, ts_code, factor_cols..., label] + """ + pass +``` + +#### 3.2.2 ProcessingPipeline (处理流水线) + +```python +from typing import List +import polars as pl + +class ProcessingPipeline: + """数据处理流水线 + + 按顺序执行多个处理器,自动处理阶段标记。 + 关键特性:在测试阶段使用训练阶段学习到的参数。 + """ + + def __init__(self, processors: List[BaseProcessor]): + """初始化流水线 + + Args: + processors: 处理器列表(按执行顺序) + """ + self.processors = processors + self._fitted_processors: Dict[int, BaseProcessor] = {} + + def fit_transform( + self, + data: pl.DataFrame, + stage: PipelineStage = PipelineStage.TRAIN + ) -> pl.DataFrame: + """在训练数据上fit所有处理器并transform + + Args: + data: 训练数据 + stage: 当前阶段标记 + + Returns: + 处理后的数据 + """ + result = data + for i, processor in enumerate(self.processors): + # 检查处理器是否适用于当前阶段 + if processor.stage in [PipelineStage.ALL, stage]: + # fit并transform + result = processor.fit_transform(result) + self._fitted_processors[i] = processor + elif stage == PipelineStage.TRAIN: + # 即使不适用于TRAIN阶段,也要fit(为TEST阶段准备) + if processor.stage == PipelineStage.TEST: + processor.fit(result) + self._fitted_processors[i] = processor + return result + + def transform( + self, + 
data: pl.DataFrame, + stage: PipelineStage = PipelineStage.TEST + ) -> pl.DataFrame: + """在测试数据上应用已fit的处理器 + + 使用训练阶段学习到的参数,防止数据泄露。 + + Args: + data: 测试数据 + stage: 当前阶段标记 + + Returns: + 处理后的数据 + """ + result = data + for i, processor in enumerate(self.processors): + if processor.stage in [PipelineStage.ALL, stage]: + if i in self._fitted_processors: + # 使用已fit的处理器 + result = self._fitted_processors[i].transform(result) + else: + # 未fit的处理器(ALL阶段但train时没执行到) + result = processor.transform(result) + return result + + def save_processors(self, path: str) -> None: + """保存所有已fit的处理器状态""" + import pickle + with open(path, 'wb') as f: + pickle.dump(self._fitted_processors, f) + + def load_processors(self, path: str) -> None: + """加载处理器状态""" + import pickle + with open(path, 'rb') as f: + self._fitted_processors = pickle.load(f) +``` + +--- + +## 4. 插件系统 + +### 4.1 注册器模式 + +```python +from typing import Type, Dict, TypeVar +from functools import wraps + +T = TypeVar('T') + +class PluginRegistry: + """插件注册中心 + + 提供装饰器方式注册处理器、模型、划分策略等组件。 + 实现真正的插件式架构 - 新功能只需注册即可使用。 + """ + + _processors: Dict[str, Type[BaseProcessor]] = {} + _models: Dict[str, Type[BaseModel]] = {} + _splitters: Dict[str, Type[BaseSplitter]] = {} + _metrics: Dict[str, Type["BaseMetric"]] = {} + + @classmethod + def register_processor(cls, name: Optional[str] = None): + """注册处理器装饰器 + + 示例: + >>> @PluginRegistry.register_processor("standard_scaler") + ... class StandardScaler(BaseProcessor): + ... 
pass + + >>> # 使用 + >>> scaler = PluginRegistry.get_processor("standard_scaler")() + """ + def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]: + key = name or processor_class.__name__ + cls._processors[key] = processor_class + processor_class._registry_name = key + return processor_class + return decorator + + @classmethod + def register_model(cls, name: Optional[str] = None): + """注册模型装饰器""" + def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]: + key = name or model_class.__name__ + cls._models[key] = model_class + model_class._registry_name = key + return model_class + return decorator + + @classmethod + def register_splitter(cls, name: Optional[str] = None): + """注册划分策略装饰器""" + def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]: + key = name or splitter_class.__name__ + cls._splitters[key] = splitter_class + return splitter_class + return decorator + + @classmethod + def get_processor(cls, name: str) -> Type[BaseProcessor]: + """获取处理器类""" + if name not in cls._processors: + raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}") + return cls._processors[name] + + @classmethod + def get_model(cls, name: str) -> Type[BaseModel]: + """获取模型类""" + if name not in cls._models: + raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}") + return cls._models[name] + + @classmethod + def get_splitter(cls, name: str) -> Type[BaseSplitter]: + """获取划分策略类""" + if name not in cls._splitters: + raise KeyError(f"Splitter '{name}' not found. 
Available: {list(cls._splitters.keys())}") + return cls._splitters[name] + + @classmethod + def list_processors(cls) -> List[str]: + """列出所有可用处理器""" + return list(cls._processors.keys()) + + @classmethod + def list_models(cls) -> List[str]: + """列出所有可用模型""" + return list(cls._models.keys()) +``` + +### 4.2 内置插件 + +```python +# ========== 内置处理器 ========== + +@PluginRegistry.register_processor("standard_scaler") +class StandardScaler(BaseProcessor): + """标准缩放处理器 - Z-score标准化""" + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "StandardScaler": + cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] + self._fitted_params = { + "mean": {c: data[c].mean() for c in cols}, + "std": {c: data[c].std() for c in cols}, + "columns": cols + } + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col in self._fitted_params["columns"]: + mean = self._fitted_params["mean"][col] + std = self._fitted_params["std"][col] + if std > 0: + result = result.with_columns( + ((pl.col(col) - mean) / std).alias(col) + ) + return result + + +@PluginRegistry.register_processor("winsorizer") +class Winsorizer(BaseProcessor): + """缩尾处理器 - 防止极端值影响""" + stage = PipelineStage.TRAIN # 只在训练阶段计算分位数 + + def __init__(self, columns=None, lower=0.01, upper=0.99): + super().__init__(columns) + self.lower = lower + self.upper = upper + + def fit(self, data: pl.DataFrame) -> "Winsorizer": + cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] + self._fitted_params = { + "lower": {c: data[c].quantile(self.lower) for c in cols}, + "upper": {c: data[c].quantile(self.upper) for c in cols}, + "columns": cols + } + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col in self._fitted_params["columns"]: + lower = self._fitted_params["lower"][col] + upper = self._fitted_params["upper"][col] + result = result.with_columns( + pl.col(col).clip(lower, 
upper).alias(col) + ) + return result + + +@PluginRegistry.register_processor("neutralizer") +class Neutralizer(BaseProcessor): + """行业/市值中性化处理器""" + stage = PipelineStage.ALL + + def __init__(self, columns=None, group_col="industry", exclude_cols=None): + super().__init__(columns) + self.group_col = group_col + self.exclude_cols = exclude_cols or [] + + def fit(self, data: pl.DataFrame) -> "Neutralizer": + # 中性化通常在每个截面独立进行,不需要全局fit + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + # 按日期分组,对每个截面进行中性化 + result = data + for col in self.columns or []: + if col in self.exclude_cols: + continue + # 分组去均值 + result = result.with_columns( + (pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col])) + .alias(col) + ) + return result + + +@PluginRegistry.register_processor("dropna") +class DropNAProcessor(BaseProcessor): + """缺失值删除处理器""" + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "DropNAProcessor": + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + cols = self.columns or data.columns + return data.drop_nulls(subset=cols) + + +@PluginRegistry.register_processor("fillna") +class FillNAProcessor(BaseProcessor): + """缺失值填充处理器""" + stage = PipelineStage.TRAIN + + def __init__(self, columns=None, method="median"): + super().__init__(columns) + self.method = method + + def fit(self, data: pl.DataFrame) -> "FillNAProcessor": + cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] + fill_values = {} + for col in cols: + if self.method == "median": + fill_values[col] = data[col].median() + elif self.method == "mean": + fill_values[col] = data[col].mean() + elif self.method == "zero": + fill_values[col] = 0 + self._fitted_params = {"fill_values": fill_values, "columns": cols} + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col, val in self._fitted_params["fill_values"].items(): + result = 
result.with_columns(pl.col(col).fill_null(val).alias(col)) + return result + + +@PluginRegistry.register_processor("rank_transformer") +class RankTransformer(BaseProcessor): + """排名转换处理器 - 转换为截面排名""" + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "RankTransformer": + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col in self.columns or []: + # 按日期分组计算排名 + result = result.with_columns( + pl.col(col).rank().over("trade_date").alias(col) + ) + return result + + +# ========== 内置模型 ========== + +@PluginRegistry.register_model("lightgbm") +class LightGBMModel(BaseModel): + """LightGBM模型包装器""" + + def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None): + super().__init__(task_type, params, name) + self._model = None + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params + ) -> "LightGBMModel": + import lightgbm as lgb + + # 转换数据格式 + X_arr = X.to_numpy() + y_arr = y.to_numpy() + + # 构建数据集 + train_data = lgb.Dataset(X_arr, label=y_arr) + valid_sets = [train_data] + + if X_val is not None and y_val is not None: + valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy()) + valid_sets.append(valid_data) + + # 设置默认参数 + default_params = { + "objective": self._get_objective(), + "metric": self._get_metric(), + "boosting_type": "gbdt", + "num_leaves": 31, + "learning_rate": 0.05, + "feature_fraction": 0.9, + "bagging_fraction": 0.8, + "bagging_freq": 5, + "verbose": -1 + } + default_params.update(self.params) + + # 训练 + self._model = lgb.train( + default_params, + train_data, + num_boost_round=fit_params.get("num_boost_round", 100), + valid_sets=valid_sets, + callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)] if len(valid_sets) > 1 else [] + ) + self._is_fitted = True + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + if not 
self._is_fitted: + raise RuntimeError("Model not fitted yet") + return self._model.predict(X.to_numpy()) + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + if self.task_type != "classification": + raise ValueError("predict_proba only for classification") + probs = self.predict(X) + if len(probs.shape) == 1: + return np.vstack([1 - probs, probs]).T + return probs + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + if self._model is None: + return None + importance = self._model.feature_importance(importance_type="gain") + return pl.DataFrame({ + "feature": self._model.feature_name(), + "importance": importance + }).sort("importance", descending=True) + + def _get_objective(self) -> str: + if self.task_type == "classification": + return "binary" + elif self.task_type == "regression": + return "regression" + elif self.task_type == "ranking": + return "lambdarank" + return "regression" + + def _get_metric(self) -> str: + if self.task_type == "classification": + return "auc" + elif self.task_type == "regression": + return "rmse" + elif self.task_type == "ranking": + return "ndcg" + return "rmse" + + +@PluginRegistry.register_model("catboost") +class CatBoostModel(BaseModel): + """CatBoost模型包装器""" + + def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None): + super().__init__(task_type, params, name) + self._model = None + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params + ) -> "CatBoostModel": + from catboost import CatBoostClassifier, CatBoostRegressor + + # 选择模型类型 + if self.task_type == "classification": + model_class = CatBoostClassifier + default_params = {"loss_function": "Logloss", "eval_metric": "AUC"} + elif self.task_type == "regression": + model_class = CatBoostRegressor + default_params = {"loss_function": "RMSE"} + else: # ranking + model_class = CatBoostRegressor + default_params = 
{"loss_function": "QueryRMSE"} + + default_params.update(self.params) + default_params["verbose"] = False + + self._model = model_class(**default_params) + + # 准备验证集 + eval_set = None + if X_val is not None and y_val is not None: + eval_set = (X_val.to_pandas(), y_val.to_pandas()) + + # 训练 + self._model.fit( + X.to_pandas(), + y.to_pandas(), + eval_set=eval_set, + early_stopping_rounds=10, + verbose=False + ) + self._is_fitted = True + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + if not self._is_fitted: + raise RuntimeError("Model not fitted yet") + return self._model.predict(X.to_pandas()) + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + if self.task_type != "classification": + raise ValueError("predict_proba only for classification") + return self._model.predict_proba(X.to_pandas()) + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + if self._model is None: + return None + return pl.DataFrame({ + "feature": self._model.feature_names_, + "importance": self._model.feature_importances_ + }).sort("importance", descending=True) + + +# ========== 内置划分策略 ========== + +@PluginRegistry.register_splitter("time_series") +class TimeSeriesSplit(BaseSplitter): + """时间序列划分 - 确保训练数据在测试数据之前""" + + def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252): + self.n_splits = n_splits + self.gap = gap + self.min_train_size = min_train_size + + def split(self, data: pl.DataFrame, date_col: str = "trade_date"): + dates = data[date_col].unique().sort() + n_dates = len(dates) + + # 计算每个split的测试集大小 + test_size = (n_dates - self.min_train_size) // self.n_splits + + for i in range(self.n_splits): + # 训练集结束位置 + train_end_idx = self.min_train_size + i * test_size + # 测试集开始位置(留gap防止泄露) + test_start_idx = train_end_idx + self.gap + test_end_idx = test_start_idx + test_size + + if test_end_idx > n_dates: + break + + train_dates = dates[:train_end_idx] + test_dates = dates[test_start_idx:test_end_idx] + + train_mask = 
data[date_col].is_in(train_dates) + test_mask = data[date_col].is_in(test_dates) + + train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list() + test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list() + + yield train_idx, test_idx + + def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"): + dates = data[date_col].unique().sort() + n_dates = len(dates) + test_size = (n_dates - self.min_train_size) // self.n_splits + + result = [] + for i in range(self.n_splits): + train_end_idx = self.min_train_size + i * test_size + test_start_idx = train_end_idx + self.gap + test_end_idx = test_start_idx + test_size + + if test_end_idx > n_dates: + break + + result.append(( + dates[0], + dates[train_end_idx - 1], + dates[test_start_idx], + dates[test_end_idx - 1] + )) + return result + + +@PluginRegistry.register_splitter("walk_forward") +class WalkForwardSplit(BaseSplitter): + """滚动前向验证 - 训练集逐步扩展""" + + def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5): + self.train_window = train_window + self.test_window = test_window + self.gap = gap + + def split(self, data: pl.DataFrame, date_col: str = "trade_date"): + dates = data[date_col].unique().sort() + n_dates = len(dates) + + start_idx = self.train_window + while start_idx + self.gap + self.test_window <= n_dates: + train_start = start_idx - self.train_window + train_end = start_idx + test_start = start_idx + self.gap + test_end = test_start + self.test_window + + train_dates = dates[train_start:train_end] + test_dates = dates[test_start:test_end] + + train_mask = data[date_col].is_in(train_dates) + test_mask = data[date_col].is_in(test_dates) + + train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list() + test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list() + + yield train_idx, test_idx + start_idx += self.test_window +``` + +--- + +## 5. 
使用示例 + +### 5.1 基础用法 + +```python +from src.models import ( + FeatureStore, ProcessingPipeline, PluginRegistry, + PipelineStage, MLPipeline +) + +# 1. 创建数据存储 +store = FeatureStore(data_dir="data") + +# 2. 构建数据集 +dataset = store.build_dataset( + factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"], + label_config={"name": "forward_return", "forward_period": 5}, + date_range=("20200101", "20241231") +) + +# 3. 创建处理流水线 +processors = [ + # 删除缺失值 + PluginRegistry.get_processor("dropna")(), + + # 异常值处理(只在训练阶段计算分位数) + PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), + + # 中性化(行业和市值中性化) + PluginRegistry.get_processor("neutralizer")(group_col="industry"), + + # 标准化(训练和测试都使用) + PluginRegistry.get_processor("standard_scaler")(), +] +pipeline = ProcessingPipeline(processors) + +# 4. 创建划分策略 +splitter = PluginRegistry.get_splitter("time_series")( + n_splits=5, + gap=5, + min_train_size=252 +) + +# 5. 创建模型 +model = PluginRegistry.get_model("lightgbm")( + task_type="regression", + params={"n_estimators": 200, "learning_rate": 0.03} +) + +# 6. 运行完整流程 +ml_pipeline = MLPipeline( + feature_store=store, + processing_pipeline=pipeline, + splitter=splitter, + model=model +) + +results = ml_pipeline.run( + factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"], + label_config={"name": "forward_return", "forward_period": 5}, + date_range=("20200101", "20241231") +) + +# 7. 
查看结果 +print(results.metrics) # 各折的评估指标 +print(results.feature_importance) # 特征重要性 +print(results.predictions) # 预测结果 +``` + +### 5.2 配置驱动用法(推荐) + +```python +# config.yaml +experiment: + name: "momentum_factor_regression" + +data: + factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"] + label: + name: "forward_return" + forward_period: 5 + date_range: ["20200101", "20241231"] + +processing: + - name: "dropna" + params: {} + stage: "all" + + - name: "winsorizer" + params: + lower: 0.01 + upper: 0.99 + stage: "train" # 只在训练阶段计算分位数 + + - name: "neutralizer" + params: + group_col: "industry" + stage: "all" + + - name: "standard_scaler" + params: {} + stage: "all" + +splitting: + strategy: "time_series" + params: + n_splits: 5 + gap: 5 + min_train_size: 252 + +model: + name: "lightgbm" + task_type: "regression" + params: + n_estimators: 200 + learning_rate: 0.03 + max_depth: 6 + +evaluation: + metrics: ["ic", "rank_ic", "mse", "mae"] + output_dir: "results/momentum_experiment" +``` + +```python +# 代码中使用配置 +from src.models import MLPipeline + +pipeline = MLPipeline.from_config("config.yaml") +results = pipeline.run() + +# 保存结果 +results.save("results/momentum_experiment") +``` + +### 5.3 自定义插件 + +```python +# 1. 创建自定义处理器 +@PluginRegistry.register_processor("my_transformer") +class MyTransformer(BaseProcessor): + """自定义转换器示例""" + stage = PipelineStage.ALL + + def __init__(self, columns=None, multiplier=2.0): + super().__init__(columns) + self.multiplier = multiplier + + def fit(self, data: pl.DataFrame) -> "MyTransformer": + # 学习参数(如有需要) + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col in self.columns or []: + result = result.with_columns( + (pl.col(col) * self.multiplier).alias(col) + ) + return result + + +# 2. 创建自定义模型 +@PluginRegistry.register_model("my_model") +class MyModel(BaseModel): + """自定义模型示例""" + + def fit(self, X, y, X_val=None, y_val=None, **kwargs): + # 实现训练逻辑 + self._model = ... 
+ return self + + def predict(self, X): + # 实现预测逻辑 + return self._model.predict(X) + + +# 3. 在配置中使用 +# config.yaml +processing: + - name: "my_transformer" + params: + multiplier: 3.0 + stage: "all" + +model: + name: "my_model" + task_type: "regression" +``` + +--- + +## 6. 目录结构 + +``` +src/ +├── models/ # 模型训练框架 +│ ├── __init__.py # 导出主要类 +│ ├── core/ # 核心抽象和基类 +│ │ ├── __init__.py +│ │ ├── processor.py # BaseProcessor, PipelineStage +│ │ ├── model.py # BaseModel, TaskType +│ │ ├── splitter.py # BaseSplitter +│ │ ├── metric.py # BaseMetric +│ │ └── pipeline.py # MLPipeline (编排器) +│ │ +│ ├── registry.py # PluginRegistry 插件注册中心 +│ │ +│ ├── data/ # 数据相关 +│ │ ├── __init__.py +│ │ ├── feature_store.py # FeatureStore 特征存储 +│ │ ├── label_generator.py # LabelGenerator 标签生成 +│ │ └── dataset.py # Dataset 数据集包装 +│ │ +│ ├── processors/ # 内置处理器 +│ │ ├── __init__.py # 自动注册所有处理器 +│ │ ├── scaler.py # StandardScaler +│ │ ├── winsorizer.py # Winsorizer +│ │ ├── neutralizer.py # Neutralizer +│ │ ├── imputer.py # FillNAProcessor +│ │ ├── selector.py # FeatureSelector +│ │ └── custom.py # 其他处理器 +│ │ +│ ├── models/ # 内置模型 +│ │ ├── __init__.py # 自动注册所有模型 +│ │ ├── lightgbm_model.py # LightGBMModel +│ │ ├── catboost_model.py # CatBoostModel +│ │ └── sklearn_model.py # SklearnModel (LR, RF等) +│ │ +│ ├── splitters/ # 划分策略 +│ │ ├── __init__.py +│ │ ├── time_series.py # TimeSeriesSplit +│ │ ├── walk_forward.py # WalkForwardSplit +│ │ └── purged.py # PurgedKFold +│ │ +│ ├── metrics/ # 评估指标 +│ │ ├── __init__.py +│ │ ├── ic.py # IC, RankIC +│ │ ├── returns.py # 收益指标 +│ │ └── classification.py # 分类指标 +│ │ +│ ├── evaluation/ # 评估和报告 +│ │ ├── __init__.py +│ │ ├── evaluator.py # ModelEvaluator +│ │ ├── report.py # ReportGenerator +│ │ └── visualizer.py # ResultVisualizer +│ │ +│ └── config/ # 配置解析 +│ ├── __init__.py +│ └── parser.py # ConfigParser +│ +├── factors/ # 已有因子框架 +│ └── ... 
+│ +tests/ +├── models/ # 模型框架测试 +│ ├── __init__.py +│ ├── test_processors.py # 处理器测试 +│ ├── test_models.py # 模型测试 +│ ├── test_pipeline.py # 流水线集成测试 +│ └── test_registry.py # 注册器测试 +│ +└── factors/ # 已有因子测试 + └── ... + +configs/ # 配置文件目录 +├── momentum_regression.yaml +├── value_classification.yaml +└├── ranking_lambdamart.yaml + +experiments/ # 实验结果目录 +└── {experiment_name}/ + ├── config.yaml # 实验配置 + ├── model.pkl # 保存的模型 + ├── processors.pkl # 保存的处理器状态 + ├── predictions.parquet # 预测结果 + ├── metrics.json # 评估指标 + ├── feature_importance.csv # 特征重要性 + └── report.html # 可视化报告 +``` + +--- + +## 7. 开发计划 + +### Phase 1: 核心基础设施 (Week 1-2) +- [ ] 设计并实现 `BaseProcessor`, `BaseModel`, `BaseSplitter` 抽象类 +- [ ] 实现 `PluginRegistry` 注册中心 +- [ ] 实现 `PipelineStage` 阶段管理 +- [ ] 编写基础单元测试 + +### Phase 2: 数据层 (Week 2-3) +- [ ] 实现 `FeatureStore` 特征存储 +- [ ] 实现 `LabelGenerator` 标签生成器 +- [ ] 实现 `Dataset` 数据集包装 +- [ ] 集成现有因子框架输出 + +### Phase 3: 处理器 (Week 3-4) +- [ ] 实现 `StandardScaler` 标准化处理器 +- [ ] 实现 `Winsorizer` 缩尾处理器 +- [ ] 实现 `Neutralizer` 中性化处理器 +- [ ] 实现 `FillNAProcessor` 缺失值处理器 +- [ ] 实现 `DropNAProcessor` 缺失值删除处理器 +- [ ] 实现 `FeatureSelector` 特征选择器 +- [ ] 实现 `ProcessingPipeline` 流水线 + +### Phase 4: 模型层 (Week 4-5) +- [ ] 实现 `LightGBMModel` LightGBM包装 +- [ ] 实现 `CatBoostModel` CatBoost包装 +- [ ] 实现 `SklearnModel` sklearn模型支持 +- [ ] 支持 classification/regression/ranking 三种任务 + +### Phase 5: 划分策略 (Week 5) +- [ ] 实现 `TimeSeriesSplit` 时间序列划分 +- [ ] 实现 `WalkForwardSplit` 滚动前向验证 +- [ ] 实现 `PurgedKFold` 清除重叠样本 + +### Phase 6: 评估层 (Week 5-6) +- [ ] 实现 IC/RankIC 指标 +- [ ] 实现收益分析指标 +- [ ] 实现分类指标 +- [ ] 实现 `ModelEvaluator` 评估器 +- [ ] 实现 `ReportGenerator` 报告生成 + +### Phase 7: 配置和编排 (Week 6) +- [ ] 实现配置解析器 +- [ ] 实现 `MLPipeline` 编排器 +- [ ] 支持配置驱动执行 + +### Phase 8: 集成测试和文档 (Week 7) +- [ ] 编写完整集成测试 +- [ ] 编写使用文档 +- [ ] 编写示例代码 +- [ ] 性能基准测试 + +--- + +## 8. 
关键设计决策 + +| 决策点 | 选择 | 理由 | +|--------|------|------| +| **数据处理阶段标记** | `PipelineStage` 枚举 | 显式、类型安全、易于扩展 | +| **插件注册方式** | 装饰器模式 | Pythonic、简洁、自动发现 | +| **数据格式** | Polars DataFrame | 与因子框架一致、高性能 | +| **模型接口** | `fit/predict` 统一接口 | 行业标准、易于替换模型 | +| **配置格式** | YAML | 人类可读、支持复杂结构 | +| **处理器状态保存** | pickle | 简单、Python原生、支持大部分对象 | +| **特征存储** | 从因子框架直接读取 | 避免数据冗余、保持一致性 | + +--- + +## 9. 防数据泄露检查清单 + +- [x] 处理器明确标记适用阶段 (`stage` 属性) +- [x] `TRAIN` 阶段处理器只在训练数据上 `fit` +- [x] `TEST` 阶段使用训练阶段学习到的参数 +- [x] 划分策略支持时间序列感知 (`TimeSeriesSplit`, `WalkForwardSplit`) +- [x] 划分时支持 `gap` 参数防止相邻样本泄露 +- [x] 特征存储从已计算的因子加载(不访问未来数据) +- [x] 标签生成使用预定义的前瞻期(明确的future data) + +--- + +*文档版本: v1.0* +*最后更新: 2026-02-23* +*设计状态: 草案 - 待评审* diff --git a/docs/test_report_duckdb_migration.md b/docs/test_report_duckdb_migration.md index 54832ba..7085d61 100644 --- a/docs/test_report_duckdb_migration.md +++ b/docs/test_report_duckdb_migration.md @@ -1,6 +1,8 @@ # ProStock HDF5 到 DuckDB 迁移测试报告 **报告生成时间**: 2026-02-22 +**完成时间**: 2026-02-22 +**状态**: ✅ 已完成 **迁移文档**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md) **测试数据范围**: 2024年1月-3月(3个月) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..789dcac --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,86 @@ +"""ProStock 模型训练框架 + +组件化、低耦合、插件式的机器学习训练框架。 + +示例: + >>> from src.models import ( + ... PluginRegistry, ProcessingPipeline, + ... PipelineStage, BaseProcessor + ... ) + + >>> # 获取注册的处理器 + >>> scaler_class = PluginRegistry.get_processor("standard_scaler") + >>> scaler = scaler_class() + + >>> # 创建处理流水线 + >>> pipeline = ProcessingPipeline([ + ... PluginRegistry.get_processor("dropna")(), + ... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), + ... PluginRegistry.get_processor("standard_scaler")(), + ... 
]) +""" + +# 导入核心抽象类和划分策略 +from src.models.core import ( + PipelineStage, + TaskType, + BaseProcessor, + BaseModel, + BaseSplitter, + BaseMetric, + TimeSeriesSplit, + WalkForwardSplit, + ExpandingWindowSplit, +) + +# 导入注册中心 +from src.models.registry import PluginRegistry + +# 导入处理流水线 +from src.models.pipeline import ProcessingPipeline + +# 导入并注册内置处理器 +from src.models.processors.processors import ( + DropNAProcessor, + FillNAProcessor, + Winsorizer, + StandardScaler, + MinMaxScaler, + RankTransformer, + Neutralizer, +) + +# 导入并注册内置模型 +from src.models.models.models import ( + LightGBMModel, + CatBoostModel, +) + +__all__ = [ + # 核心抽象 + "PipelineStage", + "TaskType", + "BaseProcessor", + "BaseModel", + "BaseSplitter", + "BaseMetric", + # 划分策略 + "TimeSeriesSplit", + "WalkForwardSplit", + "ExpandingWindowSplit", + # 注册中心 + "PluginRegistry", + # 处理流水线 + "ProcessingPipeline", + # 处理器 + "DropNAProcessor", + "FillNAProcessor", + "Winsorizer", + "StandardScaler", + "MinMaxScaler", + "RankTransformer", + "Neutralizer", + # 模型 + "LightGBMModel", + "CatBoostModel", +] diff --git a/src/models/core/__init__.py b/src/models/core/__init__.py new file mode 100644 index 0000000..7369ced --- /dev/null +++ b/src/models/core/__init__.py @@ -0,0 +1,30 @@ +"""核心模块导出""" + +from src.models.core.base import ( + PipelineStage, + TaskType, + BaseProcessor, + BaseModel, + BaseSplitter, + BaseMetric, +) + +from src.models.core.splitter import ( + TimeSeriesSplit, + WalkForwardSplit, + ExpandingWindowSplit, +) + +__all__ = [ + # 基础抽象 + "PipelineStage", + "TaskType", + "BaseProcessor", + "BaseModel", + "BaseSplitter", + "BaseMetric", + # 划分策略 + "TimeSeriesSplit", + "WalkForwardSplit", + "ExpandingWindowSplit", +] diff --git a/src/models/core/base.py b/src/models/core/base.py new file mode 100644 index 0000000..b083add --- /dev/null +++ b/src/models/core/base.py @@ -0,0 +1,351 @@ +"""模型训练框架核心抽象类 + +提供处理器、模型、划分策略和评估指标的基类定义。 +""" + +from abc import ABC, abstractmethod +from enum import Enum, auto 
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal +import polars as pl +import numpy as np + +# 任务类型 +TaskType = Literal["classification", "regression", "ranking"] + + +class PipelineStage(Enum): + """流水线阶段标记 + + 用于标记处理器在哪些阶段生效,防止数据泄露。 + + Attributes: + ALL: 适用于所有阶段(训练、测试、验证) + TRAIN: 仅训练阶段 + TEST: 仅测试阶段 + VALIDATION: 仅验证阶段 + """ + + ALL = auto() + TRAIN = auto() + TEST = auto() + VALIDATION = auto() + + +class BaseProcessor(ABC): + """数据处理器基类 + + 所有数据处理器必须继承此类。关键特性是通过 stage 属性控制处理器在哪些阶段生效。 + + 阶段标记规则: + - ALL: 训练和测试阶段都使用相同的参数 + - TRAIN: 只在训练阶段计算参数(如分位数、均值等),测试阶段使用训练阶段学到的参数 + - TEST: 只在测试阶段执行 + """ + + # 子类必须定义适用阶段 + stage: PipelineStage = PipelineStage.ALL + + def __init__(self, columns: Optional[List[str]] = None, **params): + """初始化处理器 + + Args: + columns: 要处理的列,None表示所有数值列 + **params: 处理器特定参数 + """ + self.columns = columns + self.params = params + self._is_fitted = False + self._fitted_params: Dict[str, Any] = {} + + @abstractmethod + def fit(self, data: pl.DataFrame) -> "BaseProcessor": + """在训练数据上学习参数 + + 此方法只在训练阶段调用一次。学习到的参数存储在 self._fitted_params 中。 + + Args: + data: 训练数据 + + Returns: + self (支持链式调用) + """ + pass + + @abstractmethod + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + """转换数据 + + 在训练和测试阶段都会被调用。使用 fit() 阶段学习到的参数进行转换。 + + Args: + data: 输入数据 + + Returns: + 转换后的数据 + """ + pass + + def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame: + """先fit再transform的便捷方法 + + Args: + data: 训练数据 + + Returns: + 转换后的数据 + """ + return self.fit(data).transform(data) + + def get_fitted_params(self) -> Dict[str, Any]: + """获取学习到的参数(用于保存/加载) + + Returns: + 学习到的参数字典 + """ + return self._fitted_params.copy() + + def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor": + """设置学习到的参数(用于从checkpoint恢复) + + Args: + params: 参数字典 + + Returns: + self (支持链式调用) + """ + self._fitted_params = params.copy() + self._is_fitted = True + return self + + +class BaseModel(ABC): + """机器学习模型基类 + + 统一接口支持多种模型(LightGBM, CatBoost, XGBoost等) 
+ 和多种任务类型(分类、回归、排序)。 + """ + + def __init__( + self, + task_type: TaskType, + params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ): + """初始化模型 + + Args: + task_type: 任务类型 - "classification", "regression", "ranking" + params: 模型特定参数 + name: 模型名称(用于日志和报告) + """ + self.task_type = task_type + self.params = params or {} + self.name = name or self.__class__.__name__ + self._model: Any = None + self._is_fitted = False + + @abstractmethod + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params, + ) -> "BaseModel": + """训练模型 + + Args: + X: 特征数据 + y: 目标变量 + X_val: 验证集特征(可选) + y_val: 验证集目标(可选) + **fit_params: 额外的fit参数 + + Returns: + self (支持链式调用) + """ + pass + + @abstractmethod + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测 + + Args: + X: 特征数据 + + Returns: + 预测结果数组 + - classification: 类别标签或概率 + - regression: 连续值 + - ranking: 排序分数 + """ + pass + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + """预测概率(仅分类任务) + + Args: + X: 特征数据 + + Returns: + 类别概率数组 [n_samples, n_classes] + + Raises: + NotImplementedError: 非分类任务时抛出 + """ + raise NotImplementedError( + "predict_proba only available for classification tasks" + ) + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + """获取特征重要性(如果模型支持) + + Returns: + DataFrame[feature, importance] 或 None + """ + return None + + def save(self, path: str) -> None: + """保存模型到文件 + + Args: + path: 保存路径 + """ + import pickle + + with open(path, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load(cls, path: str) -> "BaseModel": + """从文件加载模型 + + Args: + path: 模型文件路径 + + Returns: + 加载的模型实例 + """ + import pickle + + with open(path, "rb") as f: + return pickle.load(f) + + +class BaseSplitter(ABC): + """数据划分策略基类 + + 针对时间序列数据的特殊划分策略,防止未来泄露。 + """ + + @abstractmethod + def split( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> Iterator[Tuple[List[int], List[int]]]: + """生成训练/测试索引 + + 
Args: + data: 完整数据集 + date_col: 日期列名 + + Yields: + (train_indices, test_indices) 元组 + """ + pass + + @abstractmethod + def get_split_dates( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> List[Tuple[str, str, str, str]]: + """获取划分日期范围 + + Args: + data: 完整数据集 + date_col: 日期列名 + + Returns: + [(train_start, train_end, test_start, test_end), ...] + """ + pass + + +class BaseMetric(ABC): + """评估指标基类 + + 所有评估指标必须继承此类。支持单次计算和累积计算两种模式。 + """ + + def __init__(self, name: Optional[str] = None): + """初始化指标 + + Args: + name: 指标名称 + """ + self.name = name or self.__class__.__name__ + self._values: List[float] = [] + + @abstractmethod + def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float: + """计算指标值 + + Args: + y_true: 真实值 + y_pred: 预测值 + + Returns: + 指标值 + """ + pass + + def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric": + """更新累积值 + + Args: + y_true: 真实值 + y_pred: 预测值 + + Returns: + self (支持链式调用) + """ + self._values.append(self.compute(y_true, y_pred)) + return self + + def get_mean(self) -> float: + """获取累积值的均值 + + Returns: + 均值 + """ + if not self._values: + return 0.0 + return float(np.mean(self._values)) + + def get_std(self) -> float: + """获取累积值的标准差 + + Returns: + 标准差 + """ + if not self._values: + return 0.0 + return float(np.std(self._values)) + + def reset(self) -> "BaseMetric": + """重置累积值 + + Returns: + self (支持链式调用) + """ + self._values = [] + return self + + +__all__ = [ + "PipelineStage", + "TaskType", + "BaseProcessor", + "BaseModel", + "BaseSplitter", + "BaseMetric", +] diff --git a/src/models/core/splitter.py b/src/models/core/splitter.py new file mode 100644 index 0000000..d2734a6 --- /dev/null +++ b/src/models/core/splitter.py @@ -0,0 +1,222 @@ +"""时间序列数据划分策略 + +提供针对金融时间序列的特殊划分策略,防止未来泄露。 +""" + +from typing import Iterator, List, Tuple +import polars as pl + +from src.models.core.base import BaseSplitter + + +class TimeSeriesSplit(BaseSplitter): + """时间序列划分 - 确保训练数据在测试数据之前 + + 按照时间顺序进行K折划分,每折的训练数据都在测试数据之前。 
+ 通过 gap 参数防止训练集和测试集之间的数据泄露。 + + Args: + n_splits: 划分折数 + gap: 训练集和测试集之间的间隔天数(防止泄露) + min_train_size: 最小训练集大小(天数) + """ + + def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252): + self.n_splits = n_splits + self.gap = gap + self.min_train_size = min_train_size + + def split( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> Iterator[Tuple[List[int], List[int]]]: + """生成训练/测试索引""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + + test_size = (n_dates - self.min_train_size) // self.n_splits + + for i in range(self.n_splits): + train_end_idx = self.min_train_size + i * test_size + test_start_idx = train_end_idx + self.gap + test_end_idx = test_start_idx + test_size + + if test_end_idx > n_dates: + break + + train_dates = dates[:train_end_idx] + test_dates = dates[test_start_idx:test_end_idx] + + train_mask = data[date_col].is_in(train_dates.to_list()) + test_mask = data[date_col].is_in(test_dates.to_list()) + + train_idx = data.with_row_index().filter(train_mask)["index"].to_list() + test_idx = data.with_row_index().filter(test_mask)["index"].to_list() + + yield train_idx, test_idx + + def get_split_dates( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> List[Tuple[str, str, str, str]]: + """获取划分日期范围""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + test_size = (n_dates - self.min_train_size) // self.n_splits + + result = [] + for i in range(self.n_splits): + train_end_idx = self.min_train_size + i * test_size + test_start_idx = train_end_idx + self.gap + test_end_idx = test_start_idx + test_size + + if test_end_idx > n_dates: + break + + result.append( + ( + str(dates[0]), + str(dates[train_end_idx - 1]), + str(dates[test_start_idx]), + str(dates[test_end_idx - 1]), + ) + ) + return result + + +class WalkForwardSplit(BaseSplitter): + """滚动前向验证 - 训练集逐步扩展 + + Args: + train_window: 训练集窗口大小(天数) + test_window: 测试集窗口大小(天数) + gap: 训练集和测试集之间的间隔天数 + """ + + def __init__(self, 
train_window: int = 504, test_window: int = 21, gap: int = 5): + self.train_window = train_window + self.test_window = test_window + self.gap = gap + + def split( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> Iterator[Tuple[List[int], List[int]]]: + """生成训练/测试索引""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + + start_idx = self.train_window + while start_idx + self.gap + self.test_window <= n_dates: + train_start = start_idx - self.train_window + train_end = start_idx + test_start = start_idx + self.gap + test_end = test_start + self.test_window + + train_dates = dates[train_start:train_end] + test_dates = dates[test_start:test_end] + + train_mask = data[date_col].is_in(train_dates.to_list()) + test_mask = data[date_col].is_in(test_dates.to_list()) + + train_idx = data.with_row_index().filter(train_mask)["index"].to_list() + test_idx = data.with_row_index().filter(test_mask)["index"].to_list() + + yield train_idx, test_idx + start_idx += self.test_window + + def get_split_dates( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> List[Tuple[str, str, str, str]]: + """获取划分日期范围""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + + result = [] + start_idx = self.train_window + while start_idx + self.gap + self.test_window <= n_dates: + train_start = start_idx - self.train_window + train_end = start_idx + test_start = start_idx + self.gap + test_end = test_start + self.test_window + + result.append( + ( + str(dates[train_start]), + str(dates[train_end - 1]), + str(dates[test_start]), + str(dates[test_end - 1]), + ) + ) + start_idx += self.test_window + + return result + + +class ExpandingWindowSplit(BaseSplitter): + """扩展窗口划分 - 训练集不断扩大 + + Args: + initial_train_size: 初始训练集大小(天数) + test_window: 测试集窗口大小(天数) + gap: 训练集和测试集之间的间隔天数 + """ + + def __init__( + self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5 + ): + self.initial_train_size = initial_train_size + self.test_window = 
test_window + self.gap = gap + + def split( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> Iterator[Tuple[List[int], List[int]]]: + """生成训练/测试索引""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + + train_end_idx = self.initial_train_size + while train_end_idx + self.gap + self.test_window <= n_dates: + train_dates = dates[:train_end_idx] + test_start = train_end_idx + self.gap + test_end = test_start + self.test_window + test_dates = dates[test_start:test_end] + + train_mask = data[date_col].is_in(train_dates.to_list()) + test_mask = data[date_col].is_in(test_dates.to_list()) + + train_idx = data.with_row_index().filter(train_mask)["index"].to_list() + test_idx = data.with_row_index().filter(test_mask)["index"].to_list() + + yield train_idx, test_idx + train_end_idx += self.test_window + + def get_split_dates( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> List[Tuple[str, str, str, str]]: + """获取划分日期范围""" + dates = data[date_col].unique().sort() + n_dates = len(dates) + + result = [] + train_end_idx = self.initial_train_size + while train_end_idx + self.gap + self.test_window <= n_dates: + test_start = train_end_idx + self.gap + test_end = test_start + self.test_window + + result.append( + ( + str(dates[0]), + str(dates[train_end_idx - 1]), + str(dates[test_start]), + str(dates[test_end - 1]), + ) + ) + train_end_idx += self.test_window + + return result + + +__all__ = [ + "TimeSeriesSplit", + "WalkForwardSplit", + "ExpandingWindowSplit", +] diff --git a/src/models/models/__init__.py b/src/models/models/__init__.py new file mode 100644 index 0000000..9618f58 --- /dev/null +++ b/src/models/models/__init__.py @@ -0,0 +1,11 @@ +"""模型模块""" + +from src.models.models.models import ( + LightGBMModel, + CatBoostModel, +) + +__all__ = [ + "LightGBMModel", + "CatBoostModel", +] diff --git a/src/models/models/models.py b/src/models/models/models.py new file mode 100644 index 0000000..e50e179 --- /dev/null +++ 
b/src/models/models/models.py @@ -0,0 +1,210 @@ +"""内置机器学习模型 + +提供 LightGBM、CatBoost 等模型的统一接口包装器。 +""" + +from typing import Optional, Dict, Any +import polars as pl +import numpy as np + +from src.models.core import BaseModel, TaskType +from src.models.registry import PluginRegistry + + +@PluginRegistry.register_model("lightgbm") +class LightGBMModel(BaseModel): + """LightGBM 模型包装器 + + 支持分类、回归、排序三种任务类型。 + """ + + def __init__( + self, + task_type: TaskType, + params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ): + super().__init__(task_type, params, name) + self._model = None + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params, + ) -> "LightGBMModel": + """训练模型""" + try: + import lightgbm as lgb + except ImportError: + raise ImportError( + "lightgbm is required. Install with: uv pip install lightgbm" + ) + + X_arr = X.to_numpy() + y_arr = y.to_numpy() + + train_data = lgb.Dataset(X_arr, label=y_arr) + valid_sets = [train_data] + valid_names = ["train"] + + if X_val is not None and y_val is not None: + valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy()) + valid_sets.append(valid_data) + valid_names.append("valid") + + default_params = { + "objective": self._get_objective(), + "metric": self._get_metric(), + "boosting_type": "gbdt", + "num_leaves": 31, + "learning_rate": 0.05, + "feature_fraction": 0.9, + "bagging_fraction": 0.8, + "bagging_freq": 5, + "verbose": -1, + } + default_params.update(self.params) + + callbacks = [] + if len(valid_sets) > 1: + callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False)) + + self._model = lgb.train( + default_params, + train_data, + num_boost_round=fit_params.get("num_boost_round", 100), + valid_sets=valid_sets, + valid_names=valid_names, + callbacks=callbacks, + ) + self._is_fitted = True + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测""" + if not 
self._is_fitted: + raise RuntimeError("Model not fitted yet") + return self._model.predict(X.to_numpy()) + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + """预测概率(仅分类任务)""" + if self.task_type != "classification": + raise ValueError("predict_proba only for classification") + probs = self.predict(X) + if len(probs.shape) == 1: + return np.vstack([1 - probs, probs]).T + return probs + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + """获取特征重要性""" + if self._model is None: + return None + importance = self._model.feature_importance(importance_type="gain") + feature_names = getattr( + self._model, + "feature_name", + lambda: [f"feature_{i}" for i in range(len(importance))], + )() + return pl.DataFrame({"feature": feature_names, "importance": importance}).sort( + "importance", descending=True + ) + + def _get_objective(self) -> str: + objectives = { + "classification": "binary", + "regression": "regression", + "ranking": "lambdarank", + } + return objectives.get(self.task_type, "regression") + + def _get_metric(self) -> str: + metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"} + return metrics.get(self.task_type, "rmse") + + +@PluginRegistry.register_model("catboost") +class CatBoostModel(BaseModel): + """CatBoost 模型包装器""" + + def __init__( + self, + task_type: TaskType, + params: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + ): + super().__init__(task_type, params, name) + self._model = None + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + X_val: Optional[pl.DataFrame] = None, + y_val: Optional[pl.Series] = None, + **fit_params, + ) -> "CatBoostModel": + """训练模型""" + try: + from catboost import CatBoostClassifier, CatBoostRegressor + except ImportError: + raise ImportError( + "catboost is required. 
Install with: uv pip install catboost" + ) + + if self.task_type == "classification": + model_class = CatBoostClassifier + default_params = {"loss_function": "Logloss", "eval_metric": "AUC"} + elif self.task_type == "regression": + model_class = CatBoostRegressor + default_params = {"loss_function": "RMSE"} + else: + model_class = CatBoostRegressor + default_params = {"loss_function": "QueryRMSE"} + + default_params.update(self.params) + default_params["verbose"] = False + + self._model = model_class(**default_params) + + eval_set = None + if X_val is not None and y_val is not None: + eval_set = (X_val.to_pandas(), y_val.to_pandas()) + + self._model.fit( + X.to_pandas(), + y.to_pandas(), + eval_set=eval_set, + early_stopping_rounds=fit_params.get("early_stopping_rounds", 10), + verbose=False, + ) + self._is_fitted = True + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测""" + if not self._is_fitted: + raise RuntimeError("Model not fitted yet") + return self._model.predict(X.to_pandas()) + + def predict_proba(self, X: pl.DataFrame) -> np.ndarray: + """预测概率""" + if self.task_type != "classification": + raise ValueError("predict_proba only for classification") + return self._model.predict_proba(X.to_pandas()) + + def get_feature_importance(self) -> Optional[pl.DataFrame]: + """获取特征重要性""" + if self._model is None: + return None + return pl.DataFrame( + { + "feature": self._model.feature_names_, + "importance": self._model.feature_importances_, + } + ).sort("importance", descending=True) + + +__all__ = ["LightGBMModel", "CatBoostModel"] diff --git a/src/models/pipeline.py b/src/models/pipeline.py new file mode 100644 index 0000000..09be2ee --- /dev/null +++ b/src/models/pipeline.py @@ -0,0 +1,70 @@ +"""数据处理流水线 + +管理多个处理器的顺序执行,支持阶段感知处理。 +""" + +from typing import List, Dict +import polars as pl + +from src.models.core import BaseProcessor, PipelineStage + + +class ProcessingPipeline: + """数据处理流水线 + + 按顺序执行多个处理器,自动处理阶段标记。 + 
关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。 + """ + + def __init__(self, processors: List[BaseProcessor]): + """初始化流水线 + + Args: + processors: 处理器列表(按执行顺序) + """ + self.processors = processors + self._fitted_processors: Dict[int, BaseProcessor] = {} + + def fit_transform( + self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN + ) -> pl.DataFrame: + """在训练数据上fit所有处理器并transform""" + result = data + for i, processor in enumerate(self.processors): + if processor.stage in [PipelineStage.ALL, stage]: + result = processor.fit_transform(result) + self._fitted_processors[i] = processor + elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST: + processor.fit(result) + self._fitted_processors[i] = processor + return result + + def transform( + self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST + ) -> pl.DataFrame: + """在测试数据上应用已fit的处理器""" + result = data + for i, processor in enumerate(self.processors): + if processor.stage in [PipelineStage.ALL, stage]: + if i in self._fitted_processors: + result = self._fitted_processors[i].transform(result) + else: + result = processor.transform(result) + return result + + def save_processors(self, path: str) -> None: + """保存所有已fit的处理器状态""" + import pickle + + with open(path, "wb") as f: + pickle.dump(self._fitted_processors, f) + + def load_processors(self, path: str) -> None: + """加载处理器状态""" + import pickle + + with open(path, "rb") as f: + self._fitted_processors = pickle.load(f) + + +__all__ = ["ProcessingPipeline"] diff --git a/src/models/processors/__init__.py b/src/models/processors/__init__.py new file mode 100644 index 0000000..f68eb14 --- /dev/null +++ b/src/models/processors/__init__.py @@ -0,0 +1,21 @@ +"""处理器模块""" + +from src.models.processors.processors import ( + DropNAProcessor, + FillNAProcessor, + Winsorizer, + StandardScaler, + MinMaxScaler, + RankTransformer, + Neutralizer, +) + +__all__ = [ + "DropNAProcessor", + "FillNAProcessor", + "Winsorizer", + "StandardScaler", + 
"MinMaxScaler", + "RankTransformer", + "Neutralizer", +] diff --git a/src/models/processors/processors.py b/src/models/processors/processors.py new file mode 100644 index 0000000..c38a6a3 --- /dev/null +++ b/src/models/processors/processors.py @@ -0,0 +1,238 @@ +"""内置数据处理器 + +提供常用的数据预处理和转换处理器。 +""" + +from typing import List, Optional, Dict, Any +import polars as pl +import numpy as np + +from src.models.core import BaseProcessor, PipelineStage +from src.models.registry import PluginRegistry + +# 数值类型列表 +FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64] + + +def _get_numeric_columns( + data: pl.DataFrame, columns: Optional[List[str]] = None +) -> List[str]: + """获取数值列""" + if columns is not None: + return columns + return [c for c in data.columns if data[c].dtype in FLOAT_TYPES] + + +@PluginRegistry.register_processor("dropna") +class DropNAProcessor(BaseProcessor): + """缺失值删除处理器""" + + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "DropNAProcessor": + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + cols = self.columns or data.columns + return data.drop_nulls(subset=cols) + + +@PluginRegistry.register_processor("fillna") +class FillNAProcessor(BaseProcessor): + """缺失值填充处理器(只在训练阶段计算填充值)""" + + stage = PipelineStage.TRAIN + + def __init__(self, columns: Optional[List[str]] = None, method: str = "median"): + super().__init__(columns) + if method not in ["median", "mean", "zero"]: + raise ValueError(f"Unknown fill method: {method}") + self.method = method + + def fit(self, data: pl.DataFrame) -> "FillNAProcessor": + cols = _get_numeric_columns(data, self.columns) + fill_values = {} + + for col in cols: + if self.method == "median": + fill_values[col] = data[col].median() + elif self.method == "mean": + fill_values[col] = data[col].mean() + elif self.method == "zero": + fill_values[col] = 0.0 + + self._fitted_params = {"fill_values": fill_values, "columns": cols} + 
self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col, val in self._fitted_params.get("fill_values", {}).items(): + if col in result.columns: + result = result.with_columns(pl.col(col).fill_null(val).alias(col)) + return result + + +@PluginRegistry.register_processor("winsorizer") +class Winsorizer(BaseProcessor): + """缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)""" + + stage = PipelineStage.TRAIN + + def __init__( + self, + columns: Optional[List[str]] = None, + lower: float = 0.01, + upper: float = 0.99, + ): + super().__init__(columns) + self.lower = lower + self.upper = upper + + def fit(self, data: pl.DataFrame) -> "Winsorizer": + cols = _get_numeric_columns(data, self.columns) + bounds = {} + + for col in cols: + bounds[col] = { + "lower": data[col].quantile(self.lower), + "upper": data[col].quantile(self.upper), + } + + self._fitted_params = {"bounds": bounds, "columns": cols} + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col, bounds in self._fitted_params.get("bounds", {}).items(): + if col in result.columns: + result = result.with_columns( + pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col) + ) + return result + + +@PluginRegistry.register_processor("standard_scaler") +class StandardScaler(BaseProcessor): + """标准化处理器 - Z-score标准化""" + + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "StandardScaler": + cols = _get_numeric_columns(data, self.columns) + stats = {} + + for col in cols: + stats[col] = {"mean": data[col].mean(), "std": data[col].std()} + + self._fitted_params = {"stats": stats, "columns": cols} + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col, stats in self._fitted_params.get("stats", {}).items(): + if col in result.columns and stats["std"] is not None and stats["std"] > 0: + result = result.with_columns( + 
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col) + ) + return result + + +@PluginRegistry.register_processor("minmax_scaler") +class MinMaxScaler(BaseProcessor): + """归一化处理器 - 缩放到[0, 1]范围""" + + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "MinMaxScaler": + cols = _get_numeric_columns(data, self.columns) + stats = {} + + for col in cols: + stats[col] = {"min": data[col].min(), "max": data[col].max()} + + self._fitted_params = {"stats": stats, "columns": cols} + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + for col, stats in self._fitted_params.get("stats", {}).items(): + if col in result.columns: + range_val = stats["max"] - stats["min"] + if range_val is not None and range_val > 0: + result = result.with_columns( + ((pl.col(col) - stats["min"]) / range_val).alias(col) + ) + return result + + +@PluginRegistry.register_processor("rank_transformer") +class RankTransformer(BaseProcessor): + """排名转换处理器 - 转换为截面排名""" + + stage = PipelineStage.ALL + + def fit(self, data: pl.DataFrame) -> "RankTransformer": + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + cols = self.columns or _get_numeric_columns(data) + + for col in cols: + if col in result.columns: + result = result.with_columns( + pl.col(col).rank().over("trade_date").alias(col) + ) + return result + + +@PluginRegistry.register_processor("neutralizer") +class Neutralizer(BaseProcessor): + """中性化处理器 - 行业/市值中性化""" + + stage = PipelineStage.ALL + + def __init__( + self, + columns: Optional[List[str]] = None, + group_col: str = "industry", + exclude_cols: Optional[List[str]] = None, + ): + super().__init__(columns) + self.group_col = group_col + self.exclude_cols = exclude_cols or [] + + def fit(self, data: pl.DataFrame) -> "Neutralizer": + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + result = data + 
cols = self.columns or _get_numeric_columns(data) + + for col in cols: + if col in result.columns and col not in self.exclude_cols: + result = result.with_columns( + ( + pl.col(col) + - pl.col(col).mean().over(["trade_date", self.group_col]) + ).alias(col) + ) + return result + + +__all__ = [ + "DropNAProcessor", + "FillNAProcessor", + "Winsorizer", + "StandardScaler", + "MinMaxScaler", + "RankTransformer", + "Neutralizer", +] diff --git a/src/models/registry.py b/src/models/registry.py new file mode 100644 index 0000000..b65767e --- /dev/null +++ b/src/models/registry.py @@ -0,0 +1,297 @@ +"""插件注册中心 + +提供装饰器方式注册处理器、模型、划分策略等组件。 +实现真正的插件式架构 - 新功能只需注册即可使用。 + +示例: + >>> @PluginRegistry.register_processor("standard_scaler") + ... class StandardScaler(BaseProcessor): + ... pass + + >>> # 使用 + >>> scaler = PluginRegistry.get_processor("standard_scaler")() +""" + +from typing import Type, Dict, List, TypeVar, Optional +from functools import wraps +from weakref import WeakValueDictionary +import contextlib + +from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric + +T = TypeVar("T") + + +class PluginRegistry: + """插件注册中心 + + 管理所有组件的注册和获取。使用装饰器方式注册新组件。 + + Attributes: + _processors: 已注册的处理器字典 + _models: 已注册的模型字典 + _splitters: 已注册的划分策略字典 + _metrics: 已注册的评估指标字典 + """ + + _processors: Dict[str, Type[BaseProcessor]] = {} + _models: Dict[str, Type[BaseModel]] = {} + _splitters: Dict[str, Type[BaseSplitter]] = {} + _metrics: Dict[str, Type[BaseMetric]] = {} + + @classmethod + @contextlib.contextmanager + def temp_registry(cls): + """临时注册上下文管理器 + + 在上下文管理器内部注册的组件会在退出时自动清理, + 避免测试之间的状态污染。 + + 示例: + >>> with PluginRegistry.temp_registry(): + ... @PluginRegistry.register_processor("temp_processor") + ... class TempProcessor(BaseProcessor): + ... pass + ... # 在此处可以使用 temp_processor + ... 
# 退出后自动清理 + """ + original_state = { + "_processors": cls._processors.copy(), + "_models": cls._models.copy(), + "_splitters": cls._splitters.copy(), + "_metrics": cls._metrics.copy(), + } + try: + yield cls + finally: + cls._processors = original_state["_processors"] + cls._models = original_state["_models"] + cls._splitters = original_state["_splitters"] + cls._metrics = original_state["_metrics"] + + @classmethod + def register_processor(cls, name: Optional[str] = None): + """注册处理器装饰器 + + 用于装饰器方式注册数据处理器。 + + 示例: + >>> @PluginRegistry.register_processor("standard_scaler") + ... class StandardScaler(BaseProcessor): + ... pass + + >>> # 获取并使用 + >>> scaler_class = PluginRegistry.get_processor("standard_scaler") + >>> scaler = scaler_class() + + Args: + name: 注册名称,默认为类名 + + Returns: + 装饰器函数 + """ + + def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]: + key = name or processor_class.__name__ + cls._processors[key] = processor_class + processor_class._registry_name = key + return processor_class + + return decorator + + @classmethod + def register_model(cls, name: Optional[str] = None): + """注册模型装饰器 + + 用于装饰器方式注册机器学习模型。 + + 示例: + >>> @PluginRegistry.register_model("lightgbm") + ... class LightGBMModel(BaseModel): + ... pass + + Args: + name: 注册名称,默认为类名 + + Returns: + 装饰器函数 + """ + + def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]: + key = name or model_class.__name__ + cls._models[key] = model_class + model_class._registry_name = key + return model_class + + return decorator + + @classmethod + def register_splitter(cls, name: Optional[str] = None): + """注册划分策略装饰器 + + 用于装饰器方式注册数据划分策略。 + + 示例: + >>> @PluginRegistry.register_splitter("time_series") + ... class TimeSeriesSplit(BaseSplitter): + ... 
pass + + Args: + name: 注册名称,默认为类名 + + Returns: + 装饰器函数 + """ + + def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]: + key = name or splitter_class.__name__ + cls._splitters[key] = splitter_class + splitter_class._registry_name = key + return splitter_class + + return decorator + + @classmethod + def register_metric(cls, name: Optional[str] = None): + """注册评估指标装饰器 + + 用于装饰器方式注册评估指标。 + + 示例: + >>> @PluginRegistry.register_metric("ic") + ... class ICMetric(BaseMetric): + ... pass + + Args: + name: 注册名称,默认为类名 + + Returns: + 装饰器函数 + """ + + def decorator(metric_class: Type[BaseMetric]) -> Type[BaseMetric]: + key = name or metric_class.__name__ + cls._metrics[key] = metric_class + metric_class._registry_name = key + return metric_class + + return decorator + + @classmethod + def get_processor(cls, name: str) -> Type[BaseProcessor]: + """获取处理器类 + + Args: + name: 处理器注册名称 + + Returns: + 处理器类 + + Raises: + KeyError: 处理器不存在时抛出 + """ + if name not in cls._processors: + available = list(cls._processors.keys()) + raise KeyError(f"Processor '{name}' not found. Available: {available}") + return cls._processors[name] + + @classmethod + def get_model(cls, name: str) -> Type[BaseModel]: + """获取模型类 + + Args: + name: 模型注册名称 + + Returns: + 模型类 + + Raises: + KeyError: 模型不存在时抛出 + """ + if name not in cls._models: + available = list(cls._models.keys()) + raise KeyError(f"Model '{name}' not found. Available: {available}") + return cls._models[name] + + @classmethod + def get_splitter(cls, name: str) -> Type[BaseSplitter]: + """获取划分策略类 + + Args: + name: 划分策略注册名称 + + Returns: + 划分策略类 + + Raises: + KeyError: 划分策略不存在时抛出 + """ + if name not in cls._splitters: + available = list(cls._splitters.keys()) + raise KeyError(f"Splitter '{name}' not found. 
"""Core tests for the model framework.

Covers the core abstractions (processor / model / splitter), the plugin
registry, the built-in processors, the processing pipeline and the
time-series split strategies.
"""

import pytest
import polars as pl
import numpy as np
from typing import List, Optional

# Importing the package registers all built-in components as a side effect.
from src.models import (
    PluginRegistry,
    PipelineStage,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    ProcessingPipeline,
)
from src.models.core import TaskType


# ========== Core abstract classes ==========


class TestPipelineStage:
    """Tests for the pipeline-stage enum."""

    def test_stage_values(self):
        assert PipelineStage.ALL.name == "ALL"
        assert PipelineStage.TRAIN.name == "TRAIN"
        assert PipelineStage.TEST.name == "TEST"
        assert PipelineStage.VALIDATION.name == "VALIDATION"


class TestBaseProcessor:
    """Tests for the processor base class."""

    def test_processor_initialization(self):
        """A freshly built processor stores its columns and starts unfitted."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        processor = DummyProcessor(columns=["col1", "col2"])
        assert processor.columns == ["col1", "col2"]
        assert processor.stage == PipelineStage.ALL
        assert not processor._is_fitted

    def test_processor_fit_transform(self):
        """fit_transform fits the processor and transforms in one call."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                result = data.clone()
                for col in self.columns or []:
                    result = result.with_columns((pl.col(col) + 1).alias(col))
                return result

        processor = AddOneProcessor(columns=["value"])
        df = pl.DataFrame({"value": [1, 2, 3]})

        result = processor.fit_transform(df)

        assert processor._is_fitted
        assert result["value"].to_list() == [2, 3, 4]


class TestBaseModel:
    """Tests for the model base class."""

    def test_model_initialization(self):
        """Constructor stores task type, params and name; model starts unfitted."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(
            task_type="regression", params={"lr": 0.01}, name="test_model"
        )

        assert model.task_type == "regression"
        assert model.params == {"lr": 0.01}
        assert model.name == "test_model"
        assert not model._is_fitted

    def test_predict_proba_not_implemented(self):
        """predict_proba raises when a subclass does not implement it."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(task_type="regression")
        df = pl.DataFrame({"feature": [1, 2, 3]})

        with pytest.raises(NotImplementedError):
            model.predict_proba(df)


class TestBaseSplitter:
    """Tests for the split-strategy base class."""

    def test_splitter_interface(self):
        """split yields index pairs; get_split_dates returns date tuples."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        splitter = DummySplitter()
        df = pl.DataFrame(
            {"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
        )

        splits = list(splitter.split(df))
        assert len(splits) == 1
        assert splits[0] == ([0, 1], [2, 3])

        dates = splitter.get_split_dates(df)
        assert dates == [("20200101", "20201231", "20210101", "20211231")]


# ========== Plugin registry ==========


class TestPluginRegistry:
    """Tests for the plugin registry."""

    def setup_method(self):
        """Clear all registrations before each test.

        NOTE(review): clear_all() wipes the process-global registries and
        nothing restores the built-in registrations afterwards, so tests
        that run later in the same session and rely on built-ins may fail
        depending on test order — consider snapshotting and restoring the
        registry state in a teardown_method. TODO confirm against
        PluginRegistry internals.
        """
        PluginRegistry.clear_all()

    def test_register_and_get_processor(self):
        """A processor registered by decorator is retrievable by name."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        processor_class = PluginRegistry.get_processor("test_processor")
        assert processor_class == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """A model registered by decorator is retrievable by name."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model_class = PluginRegistry.get_model("test_model")
        assert model_class == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """A splitter registered by decorator is retrievable by name."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        splitter_class = PluginRegistry.get_splitter("test_splitter")
        assert splitter_class == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Fetching an unknown processor raises KeyError naming it."""
        with pytest.raises(KeyError) as exc_info:
            PluginRegistry.get_processor("nonexistent")
        assert "nonexistent" in str(exc_info.value)

    def test_register_with_default_name(self):
        """Registering without a name falls back to the class name."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        assert "MyCustomProcessor" in PluginRegistry.list_processors()


# ========== Built-in processors ==========


class TestBuiltInProcessors:
    """Tests for the built-in processors."""

    def test_dropna_processor(self):
        """Rows with a null in any watched column are dropped."""
        from src.models.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})

        result = processor.fit_transform(df)

        # Only the first row is free of nulls.
        assert len(result) == 1
        assert result["a"].to_list() == [1]
        assert result["b"].to_list() == [4]

    def test_fillna_processor(self):
        """Nulls are replaced with the column mean."""
        from src.models.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})

        result = processor.fit_transform(df)

        # mean = (1 + 2 + 4) / 3 = 2.333...
        assert result["a"][2] == pytest.approx(2.333, rel=0.01)

    def test_standard_scaler(self):
        """Z-score scaling yields mean 0 and std 1."""
        from src.models.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})

        result = processor.fit_transform(df)

        assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["value"].std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Values are clipped to the configured quantiles."""
        from src.models.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame(
            {
                "value": list(range(100))  # 0-99
            }
        )

        result = processor.fit_transform(df)

        # The 10% / 90% quantiles should be 10 and 89 (Polars quantile
        # behaviour) — NOTE(review): this pins Polars' default interpolation;
        # confirm it still holds after a Polars upgrade.
        assert result["value"].min() == 10
        assert result["value"].max() == 89

    def test_rank_transformer(self):
        """Values are replaced by their within-date rank."""
        from src.models.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
            {"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
        )

        result = processor.fit_transform(df)

        # Expected ranks: 1, 3, 2, 5, 4
        assert result["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Group demeaning leaves every group with zero mean."""
        from src.models.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )

        result = processor.fit_transform(df)

        group_a = result.filter(pl.col("industry") == "A")
        group_b = result.filter(pl.col("industry") == "B")

        assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)


# ========== Processing pipeline ==========


class TestProcessingPipeline:
    """Tests for the processing pipeline."""

    def test_pipeline_fit_transform(self):
        """fit_transform applies every step in order."""
        from src.models.processors import StandardScaler

        scaler1 = StandardScaler(columns=["a"])
        scaler2 = StandardScaler(columns=["b"])

        pipeline = ProcessingPipeline([scaler1, scaler2])

        df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        result = pipeline.fit_transform(df)

        # Both columns should be standardized.
        assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """transform reuses parameters fitted on the training data."""
        from src.models.processors import StandardScaler

        scaler = StandardScaler(columns=["value"])
        pipeline = ProcessingPipeline([scaler])

        # Training data.
        train_df = pl.DataFrame(
            {
                "value": [1.0, 2.0, 3.0]  # mean=2, std=1
            }
        )

        # Test data with a different distribution.
        test_df = pl.DataFrame(
            {
                "value": [4.0, 5.0, 6.0]  # would have mean=5 if refitted
            }
        )

        pipeline.fit_transform(train_df)
        result = pipeline.transform(test_df)

        # Standardized with the training mean=2 and std=1:
        # 4 -> (4 - 2) / 1 = 2
        assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)


# ========== Split strategies ==========


class TestSplitters:
    """Tests for the split strategies."""

    def test_time_series_split(self):
        """Time-series split keeps every training fold before its test fold."""
        from src.models.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)

        # Ten days of data.
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 11)],
                "value": list(range(10)),
            }
        )

        splits = list(splitter.split(df))

        # Expect exactly two folds.
        assert len(splits) == 2

        # Each fold's training indices must precede its test indices.
        for train_idx, test_idx in splits:
            assert max(train_idx) < min(test_idx)

    def test_walk_forward_split(self):
        """Walk-forward split uses fixed-size train and test windows."""
        from src.models.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 13)],
                "value": list(range(12)),
            }
        )

        splits = list(splitter.split(df))

        # Training-set size is fixed.
        for train_idx, test_idx in splits:
            assert len(train_idx) == 5
            assert len(test_idx) == 2

    def test_expanding_window_split(self):
        """Expanding-window split grows the training set each fold."""
        from src.models.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 15)],
                "value": list(range(14)),
            }
        )

        splits = list(splitter.split(df))

        # The training set should grow fold by fold.
        train_sizes = [len(train_idx) for train_idx, _ in splits]
        assert train_sizes[0] == 3
        assert train_sizes[1] == 5  # 3 + 2
        assert train_sizes[2] == 7  # 5 + 2


# ========== Built-in models (optional, need extra dependencies) ==========


class TestModels:
    """Tests for built-in models (skipped when optional deps are missing)."""

    def test_lightgbm_model(self):
        """Test the LightGBM model wrapper."""
        # Skip gracefully only when lightgbm is actually unavailable,
        # instead of unconditionally skipping the test.
        pytest.importorskip("lightgbm")
        from src.models.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)

        model.fit(X, y)
        predictions = model.predict(X)

        assert len(predictions) == len(X)
        assert model._is_fitted


if __name__ == "__main__":
    pytest.main([__file__, "-v"])