Files
ProStock/docs/factor_implementation_plan.md
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

847 lines
24 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ProStock 因子框架实现计划
## 目录结构
```
src/factors/
├── __init__.py # 导出主要类
├── data_spec.py # Phase 1: 数据类型定义
├── base.py # Phase 2: 因子基类
├── composite.py # Phase 2: 组合因子
├── data_loader.py # Phase 3: 数据加载
├── engine.py # Phase 4: 执行引擎
└── builtin/ # Phase 5: 内置因子库
├── __init__.py
├── momentum.py # 截面动量因子
├── technical.py # 时序技术指标
└── value.py # 截面估值因子
tests/factors/ # Phase 6-7: 测试
├── __init__.py
├── test_data_spec.py # 数据类型测试
├── test_base.py # 因子基类测试
├── test_composite.py # 组合因子测试
├── test_data_loader.py # 数据加载测试
├── test_engine.py # 引擎测试
├── test_builtin.py # 内置因子测试
└── test_integration.py # 集成测试
```
---
## Phase 1: 数据类型定义 (data_spec.py)
### 1.1 DataSpec - 数据需求规格
**实现要求:**
```python
@dataclass(frozen=True)
class DataSpec:
"""
数据需求规格说明
Args:
source: H5 文件名(如 "daily", "fundamental"
columns: 需要的列名列表,必须包含 "ts_code""trade_date"
lookback_days: 需要回看的天数(包含当日)
- 1 表示只需要当日数据 [T]
- 5 表示需要 [T-4, T] 共5天
- 20 表示需要 [T-19, T] 共20天
"""
source: str
columns: List[str]
lookback_days: int = 1
```
**约束验证:**
- `lookback_days >= 1`(至少包含当日)
- `columns` 必须包含 `ts_code``trade_date`
- `source` 不能为空字符串
**测试需求:**
- [ ] 测试有效 DataSpec 创建
- [ ] 测试 `lookback_days < 1` 时抛出 ValueError
- [ ] 测试缺少 `ts_code``trade_date` 时抛出 ValueError
- [ ] 测试空 `source` 时抛出 ValueError
- [ ] 测试 frozen 特性(创建后不可修改)
---
### 1.2 FactorContext - 计算上下文
**实现要求:**
```python
@dataclass
class FactorContext:
"""
因子计算上下文
由 FactorEngine 自动注入,因子开发者可通过 data.context 访问
Attributes:
current_date: 当前计算日期 YYYYMMDD截面因子使用
current_stock: 当前计算股票代码(时序因子使用)
trade_dates: 交易日历列表(可选,用于对齐)
"""
current_date: Optional[str] = None
current_stock: Optional[str] = None
trade_dates: Optional[List[str]] = None
```
**测试需求:**
- [ ] 测试默认值创建
- [ ] 测试完整参数创建
- [ ] 测试 dataclass 自动生成的方法
---
### 1.3 FactorData - 数据容器
**实现要求:**
```python
class FactorData:
"""
提供给因子的数据容器
封装底层 Polars DataFrame提供安全的数据访问接口
"""
def __init__(self, df: pl.DataFrame, context: FactorContext):
self._df = df
self._context = context
def get_column(self, col: str) -> pl.Series:
"""
获取指定列的数据
- 截面因子:获取当天所有股票的该列值
- 时序因子:获取该股票时间序列的该列值
Args:
col: 列名
Returns:
Polars Series
Raises:
KeyError: 列不存在
"""
pass
def filter_by_date(self, date: str) -> "FactorData":
"""
按日期过滤数据,返回新的 FactorData
主要用于截面因子获取特定日期的数据
Args:
date: YYYYMMDD 格式的日期
Returns:
过滤后的 FactorData
"""
pass
def get_cross_section(self) -> pl.DataFrame:
"""
获取当前日期的截面数据
仅适用于截面因子,返回 current_date 当天的所有股票数据
Returns:
DataFrame 包含当前日期的所有股票
Raises:
ValueError: current_date 未设置(非截面因子场景)
"""
pass
def to_polars(self) -> pl.DataFrame:
"""获取底层的 Polars DataFrame高级用法"""
pass
@property
def context(self) -> FactorContext:
"""获取计算上下文"""
pass
def __len__(self) -> int:
"""返回数据行数"""
pass
```
**测试需求:**
- [ ] 测试 `get_column()` 返回正确 Series
- [ ] 测试 `get_column()` 列不存在时抛出 KeyError
- [ ] 测试 `filter_by_date()` 返回正确过滤结果
- [ ] 测试 `filter_by_date()` 日期不存在时返回空 DataFrame
- [ ] 测试 `get_cross_section()` 返回 current_date 当天的数据
- [ ] 测试 `get_cross_section()` current_date 为 None 时抛出 ValueError
- [ ] 测试 `to_polars()` 返回原始 DataFrame
- [ ] 测试 `context` 属性返回正确上下文
- [ ] 测试 `__len__()` 返回正确行数
---
## Phase 2: 因子基类 (base.py, composite.py)
### 2.1 BaseFactor - 抽象基类
**实现要求:**
```python
class BaseFactor(ABC):
"""
因子基类 - 定义通用接口
所有因子必须继承此类,并声明以下类属性:
- name: 因子唯一标识snake_case
- factor_type: "cross_sectional""time_series"
- data_specs: List[DataSpec] 数据需求列表
可选声明:
- category: 因子分类(默认 "default"
- description: 因子描述
"""
# 必须声明的类属性
name: str = ""
factor_type: str = "" # "cross_sectional" | "time_series"
data_specs: List[DataSpec] = field(default_factory=list)
# 可选声明的类属性
category: str = "default"
description: str = ""
def __init_subclass__(cls, **kwargs):
"""
子类创建时验证必须属性
验证项:
1. name 必须是非空字符串
2. factor_type 必须是 "cross_sectional""time_series"
3. data_specs 必须是非空列表
"""
pass
def __init__(self, **params):
"""
初始化因子参数
子类可通过 __init__ 接收参数化配置,如 MA(period=20)
"""
self.params = params
self._validate_params()
def _validate_params(self):
"""
验证参数有效性
子类可覆盖此方法进行自定义验证
"""
pass
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
核心计算逻辑 - 子类必须实现
Args:
data: 安全的数据容器,已根据因子类型裁剪
Returns:
计算得到的因子值 Series
"""
pass
# ========== 因子组合运算符 ==========
def __add__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相加f1 + f2要求同类型"""
pass
def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相减f1 - f2要求同类型"""
pass
def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相乘f1 * f2要求同类型"""
pass
def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相除f1 / f2要求同类型"""
pass
def __rmul__(self, scalar: float) -> "ScalarFactor":
"""标量乘法0.5 * f1"""
pass
```
**测试需求:**
- [ ] 测试有效子类创建通过验证
- [ ] 测试缺少 `name` 时抛出 ValueError
- [ ] 测试 `name` 为空字符串时抛出 ValueError
- [ ] 测试缺少 `factor_type` 时抛出 ValueError
- [ ] 测试无效的 `factor_type`(非 cs/ts时抛出 ValueError
- [ ] 测试缺少 `data_specs` 时抛出 ValueError
- [ ] 测试 `data_specs` 为空列表时抛出 ValueError
- [ ] 测试 `compute()` 抽象方法强制子类实现
- [ ] 测试参数化初始化 `params` 正确存储
- [ ] 测试 `_validate_params()` 被调用
---
### 2.2 CrossSectionalFactor - 日期截面因子
**实现要求:**
```python
class CrossSectionalFactor(BaseFactor):
"""
日期截面因子基类
计算逻辑:在每个交易日,对所有股票进行横向计算
防泄露边界:
- ❌ 禁止访问未来日期的数据(日期泄露)
- ✅ 允许访问当前日期的所有股票数据
数据传入:
- compute() 接收的是 [T-lookback+1, T] 的数据
- 包含 lookback_days 的历史数据(用于时序计算后再截面)
"""
factor_type: str = "cross_sectional"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算截面因子值
Args:
data: FactorData包含 [T-lookback+1, T] 的截面数据
格式DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
示例:
def compute(self, data):
# 获取当前日期的截面
cs = data.get_cross_section()
# 计算市值排名
return cs['market_cap'].rank()
"""
pass
```
**测试需求:**
- [ ] 测试 `factor_type` 自动设置为 "cross_sectional"
- [ ] 测试子类必须实现 `compute()`
- [ ] 测试 `compute()` 返回类型为 pl.Series
---
### 2.3 TimeSeriesFactor - 时间序列因子
**实现要求:**
```python
class TimeSeriesFactor(BaseFactor):
"""
时间序列因子基类(股票截面)
计算逻辑:对每只股票,在其时间序列上进行纵向计算
防泄露边界:
- ❌ 禁止访问其他股票的数据(股票泄露)
- ✅ 允许访问该股票的完整历史数据
数据传入:
- compute() 接收的是单只股票的完整时间序列
- 包含该股票在 [start_date, end_date] 范围内的所有数据
"""
factor_type: str = "time_series"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算时间序列因子值
Args:
data: FactorData包含单只股票的完整时间序列
格式DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
示例:
def compute(self, data):
series = data.get_column("close")
return series.rolling_mean(window_size=self.params['period'])
"""
pass
```
**测试需求:**
- [ ] 测试 `factor_type` 自动设置为 "time_series"
- [ ] 测试子类必须实现 `compute()`
- [ ] 测试 `compute()` 返回类型为 pl.Series
---
### 2.4 CompositeFactor - 组合因子 (composite.py)
**实现要求:**
```python
class CompositeFactor(BaseFactor):
"""
组合因子 - 用于实现因子间的数学运算
约束:左右因子必须是同类型(同为截面或同为时序)
"""
def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
"""
创建组合因子
Args:
left: 左操作数因子
right: 右操作数因子
op: 运算符,支持 '+', '-', '*', '/'
Raises:
ValueError: 左右因子类型不一致
ValueError: 不支持的运算符
"""
pass
def _merge_data_specs(self) -> List[DataSpec]:
"""
合并左右因子的数据需求
策略:
1. 相同 source 和 columns 的 DataSpec 合并
2. lookback_days 取最大值
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""
执行组合运算
流程:
1. 分别计算 left 和 right 的值
2. 根据 op 执行运算
3. 返回结果
"""
pass
```
**测试需求:**
- [ ] 测试同类型因子组合成功cs + cs
- [ ] 测试同类型因子组合成功ts + ts
- [ ] 测试不同类型因子组合抛出 ValueErrorcs + ts
- [ ] 测试无效运算符抛出 ValueError
- [ ] 测试 `_merge_data_specs()` 正确合并(相同 source
- [ ] 测试 `_merge_data_specs()` 正确合并(不同 source
- [ ] 测试 `_merge_data_specs()` lookback 取最大值
- [ ] 测试 `compute()` 执行正确的数学运算
---
### 2.5 ScalarFactor - 标量运算因子 (composite.py)
**实现要求:**
```python
class ScalarFactor(BaseFactor):
"""
标量运算因子
支持scalar * factor, factor * scalar通过 __rmul__
"""
def __init__(self, factor: BaseFactor, scalar: float, op: str):
"""
创建标量运算因子
Args:
factor: 基础因子
scalar: 标量值
op: 运算符,支持 '*', '+'
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""执行标量运算"""
pass
```
**测试需求:**
- [ ] 测试标量乘法 `0.5 * factor`
- [ ] 测试标量乘法 `factor * 0.5`
- [ ] 测试标量加法(如支持)
- [ ] 测试继承基础因子的 data_specs
- [ ] 测试 `compute()` 返回正确缩放后的值
---
## Phase 3: 数据加载 (data_loader.py)
### 3.1 DataLoader - 数据加载器
**实现要求:**
```python
class DataLoader:
"""
数据加载器 - 负责从 HDF5 安全加载数据
功能:
1. 多文件聚合:合并多个 H5 文件的数据
2. 列选择:只加载需要的列
3. 原始数据缓存:避免重复读取
"""
def __init__(self, data_dir: str):
"""
初始化 DataLoader
Args:
data_dir: HDF5 文件所在目录
"""
self.data_dir = Path(data_dir)
self._cache: Dict[str, pl.DataFrame] = {}
def load(
self,
specs: List[DataSpec],
date_range: Optional[Tuple[str, str]] = None
) -> pl.DataFrame:
"""
加载并聚合多个 H5 文件的数据
流程:
1. 对每个 DataSpec
a. 检查缓存,命中则直接使用
b. 未命中则读取 HDF5通过 pandas
c. 转换为 Polars DataFrame
d. 按 date_range 过滤
e. 存入缓存
2. 合并多个 DataFrame按 trade_date 和 ts_code join
Args:
specs: 数据需求规格列表
date_range: 日期范围限制 (start_date, end_date),可选
Returns:
合并后的 Polars DataFrame
Raises:
FileNotFoundError: H5 文件不存在
KeyError: 列不存在于文件中
"""
pass
def clear_cache(self):
"""清空缓存"""
pass
def _read_h5(self, source: str) -> pl.DataFrame:
"""
读取单个 H5 文件
实现:使用 pandas.read_hdf(),然后 pl.from_pandas()
"""
pass
```
**测试需求:**
- [ ] 测试从单个 H5 文件加载数据
- [ ] 测试从多个 H5 文件加载并合并
- [ ] 测试列选择(只加载需要的列)
- [ ] 测试缓存机制(第二次加载更快)
- [ ] 测试 `clear_cache()` 清空缓存
- [ ] 测试按 date_range 过滤
- [ ] 测试文件不存在时抛出 FileNotFoundError
- [ ] 测试列不存在时抛出 KeyError
---
## Phase 4: 执行引擎 (engine.py)
### 4.1 FactorEngine - 因子执行引擎
**实现要求:**
```python
class FactorEngine:
"""
因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
核心职责:
1. CrossSectionalFactor防止日期泄露每天传入 [T-lookback+1, T] 数据
2. TimeSeriesFactor防止股票泄露每只股票传入完整序列
"""
def __init__(self, data_loader: DataLoader):
"""
初始化引擎
Args:
data_loader: 数据加载器实例
"""
self.data_loader = data_loader
def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
"""
统一的计算入口
根据 factor_type 分发到具体方法:
- "cross_sectional" -> _compute_cross_sectional()
- "time_series" -> _compute_time_series()
Args:
factor: 要计算的因子
**kwargs: 额外参数,根据因子类型不同:
- 截面因子: start_date, end_date
- 时序因子: stock_codes, start_date, end_date
Returns:
DataFrame[trade_date, ts_code, factor_name]
"""
pass
```
**测试需求:**
- [ ] 测试 `compute()` 正确分发给截面计算
- [ ] 测试 `compute()` 正确分发给时序计算
- [ ] 测试无效 factor_type 时抛出 ValueError
---
### 4.2 截面计算(防止日期泄露)
**实现要求:**
```python
def _compute_cross_sectional(
self,
factor: CrossSectionalFactor,
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行日期截面计算
防泄露策略:
- 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来)
- 允许股票间比较:传入当天所有股票的数据
计算流程:
1. 计算 max_lookback确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每个日期 T in [start_date, end_date]
a. 裁剪数据到 [T-lookback+1, T]
b. 创建 FactorDatacurrent_date=T
c. 调用 factor.compute()
d. 收集结果
4. 合并所有日期的结果
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 0.5 │
│ 20240101 │ 000002.SZ│ 0.3 │
└────────────┴──────────┴──────────────┘
"""
pass
```
**测试需求(防泄露验证):**
- [ ] 测试数据裁剪正确(传入 [T-lookback+1, T]
- [ ] 测试不包含未来日期 T+1 的数据
- [ ] 测试每个日期独立计算
- [ ] 测试结果包含所有日期和所有股票
- [ ] 测试结果 DataFrame 格式正确
- [ ] 测试多个 DataSpec 时 lookback 取最大值
---
### 4.3 时序计算(防止股票泄露)
**实现要求:**
```python
def _compute_time_series(
self,
factor: TimeSeriesFactor,
stock_codes: List[str],
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行时间序列计算
防泄露策略:
- 防止股票泄露:每只股票单独计算,传入该股票的完整序列
- 允许访问历史数据:时序计算需要历史数据
计算流程:
1. 计算 max_lookback确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每只股票 S in stock_codes
a. 过滤出 S 的数据(防止股票泄露)
b. 创建 FactorDatacurrent_stock=S
c. 调用 factor.compute()(向量化计算整个序列)
d. 收集结果
4. 合并所有股票的结果
性能优势:
- 使用 Polars 的 rolling_mean 等向量化操作
- 每只股票只计算一次,无重复计算
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 10.5 │
│ 20240102 │ 000001.SZ│ 10.6 │
└────────────┴──────────┴──────────────┘
"""
pass
```
**测试需求(防泄露验证):**
- [ ] 测试每只股票只看到自己的数据
- [ ] 测试不包含其他股票的数据
- [ ] 测试传入的是完整时间序列(向量化计算)
- [ ] 测试结果包含所有股票和所有日期
- [ ] 测试结果 DataFrame 格式正确
- [ ] 测试股票不在数据中时跳过(或填充 null
---
## Phase 5: 内置因子库 (builtin/)
### 5.1 momentum.py - 截面动量因子
**实现因子:**
1. **ReturnRankFactor** - 当日收益率排名
```python
class ReturnRankFactor(CrossSectionalFactor):
"""当日收益率排名因子"""
name = "return_rank"
data_specs = [DataSpec("daily", ["close"], lookback_days=2)] # 需要2天计算收益率
def compute(self, data):
# 获取当前日期截面
cs = data.get_cross_section()
# 需要前1天和当天的收盘价lookback=2 保证数据包含 [T-1, T]
# 这里假设 data 已经包含历史,实际计算需要 groupby 处理
pass
```
**测试需求:**
- [ ] 测试收益率计算正确
- [ ] 测试排名计算正确
- [ ] 测试无数据时返回 null
2. **MomentumFactor** - 过去 N 日涨幅排名
---
### 5.2 technical.py - 时序技术指标
**实现因子:**
1. **MovingAverageFactor** - 移动平均线
```python
class MovingAverageFactor(TimeSeriesFactor):
"""移动平均线因子"""
name = "ma"
def __init__(self, period: int = 20):
super().__init__(period=period)
self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
def compute(self, data):
return data.get_column("close").rolling_mean(self.params["period"])
```
**测试需求:**
- [ ] 测试 MA20 计算正确
- [ ] 测试前19天返回 nullPolars 默认行为)
- [ ] 测试参数 period 生效
2. **RSIFactor** - RSI 指标
3. **MACDFactor** - MACD 指标
---
### 5.3 value.py - 截面估值因子
**实现因子:**
1. **PERankFactor** - PE 行业分位数
2. **PBFactor** - PB 排名
---
## Phase 6-7: 测试策略
### 测试金字塔
```
/\
/ \
/ 集成\ tests/factors/test_integration.py
/────────\
/ 引擎 \ tests/factors/test_engine.py
/────────────\
/ 基类/组合因子 \ tests/factors/test_base.py, test_composite.py
/────────────────\
/ 数据加载/类型 \ tests/factors/test_data_loader.py, test_data_spec.py
/──────────────────────\
```
### 测试数据准备
创建 `tests/fixtures/` 目录,包含:
- `sample_daily.h5`: 少量股票的日线数据(用于测试)
- `sample_fundamental.h5`: 基本面数据
### 关键测试场景
1. **防泄露测试(核心)**
- 截面因子:验证 compute() 中无法访问未来日期
- 时序因子:验证 compute() 中无法访问其他股票
2. **边界测试**
- lookback_days = 1最小值
- 数据起始点(前 N 天为 null
- 空数据/停牌处理
3. **性能测试(可选)**
- 大数据量下的内存占用
- 缓存命中率
---
## 实现状态
| Phase | 状态 | 完成日期 | 测试覆盖 |
|-------|------|----------|----------|
| Phase 1: 数据类型定义 | ✅ 已完成 | 2026-02-21 | 27 tests passed |
| Phase 2: 因子基类 | ✅ 已完成 | 2026-02-21 | 49 tests passed |
| Phase 3: 数据加载 | ✅ 已完成 | 2026-02-21 | 11 tests passed |
| Phase 4: 执行引擎 | ✅ 已完成 | 2026-02-22 | 10 tests passed |
| Phase 5: 内置因子库 | 📝 待开发 | - | - |
| Phase 6-7: 测试文档 | ✅ 已完成 | 2026-02-22 | 76 tests total |
---
## 实现顺序建议
1. **Week 1**: Phase 1-2数据类型 + 基类)
2. **Week 2**: Phase 3-4DataLoader + Engine**已完成**
3. **Week 3**: Phase 5内置因子
4. **Week 4**: Phase 6-7测试 + 文档)
每个 Phase 完成后运行对应测试,确保质量。