Files
ProStock/docs/factor_implementation_plan.md

847 lines
24 KiB
Markdown
Raw Normal View History

# ProStock 因子框架实现计划
## 目录结构
```
src/factors/
├── __init__.py # 导出主要类
├── data_spec.py # Phase 1: 数据类型定义
├── base.py # Phase 2: 因子基类
├── composite.py # Phase 2: 组合因子
├── data_loader.py # Phase 3: 数据加载
├── engine.py # Phase 4: 执行引擎
└── builtin/ # Phase 5: 内置因子库
├── __init__.py
├── momentum.py # 截面动量因子
├── technical.py # 时序技术指标
└── value.py # 截面估值因子
tests/factors/ # Phase 6-7: 测试
├── __init__.py
├── test_data_spec.py # 数据类型测试
├── test_base.py # 因子基类测试
├── test_composite.py # 组合因子测试
├── test_data_loader.py # 数据加载测试
├── test_engine.py # 引擎测试
├── test_builtin.py # 内置因子测试
└── test_integration.py # 集成测试
```
---
## Phase 1: 数据类型定义 (data_spec.py)
### 1.1 DataSpec - 数据需求规格
**实现要求:**
```python
@dataclass(frozen=True)
class DataSpec:
"""
数据需求规格说明
Args:
source: H5 文件名(如 "daily", "fundamental"
columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date"
lookback_days: 需要回看的天数(包含当日)
- 1 表示只需要当日数据 [T]
- 5 表示需要 [T-4, T] 共5天
- 20 表示需要 [T-19, T] 共20天
"""
source: str
columns: List[str]
lookback_days: int = 1
```
**约束验证:**
- `lookback_days >= 1`(至少包含当日)
- `columns` 必须包含 `ts_code``trade_date`
- `source` 不能为空字符串
**测试需求:**
- [ ] 测试有效 DataSpec 创建
- [ ] 测试 `lookback_days < 1` 时抛出 ValueError
- [ ] 测试缺少 `ts_code``trade_date` 时抛出 ValueError
- [ ] 测试空 `source` 时抛出 ValueError
- [ ] 测试 frozen 特性(创建后不可修改)
---
### 1.2 FactorContext - 计算上下文
**实现要求:**
```python
@dataclass
class FactorContext:
"""
因子计算上下文
由 FactorEngine 自动注入,因子开发者可通过 data.context 访问
Attributes:
current_date: 当前计算日期 YYYYMMDD截面因子使用
current_stock: 当前计算股票代码(时序因子使用)
trade_dates: 交易日历列表(可选,用于对齐)
"""
current_date: Optional[str] = None
current_stock: Optional[str] = None
trade_dates: Optional[List[str]] = None
```
**测试需求:**
- [ ] 测试默认值创建
- [ ] 测试完整参数创建
- [ ] 测试 dataclass 自动生成的方法
---
### 1.3 FactorData - 数据容器
**实现要求:**
```python
class FactorData:
"""
提供给因子的数据容器
封装底层 Polars DataFrame提供安全的数据访问接口
"""
def __init__(self, df: pl.DataFrame, context: FactorContext):
self._df = df
self._context = context
def get_column(self, col: str) -> pl.Series:
"""
获取指定列的数据
- 截面因子:获取当天所有股票的该列值
- 时序因子:获取该股票时间序列的该列值
Args:
col: 列名
Returns:
Polars Series
Raises:
KeyError: 列不存在
"""
pass
def filter_by_date(self, date: str) -> "FactorData":
"""
按日期过滤数据,返回新的 FactorData
主要用于截面因子获取特定日期的数据
Args:
date: YYYYMMDD 格式的日期
Returns:
过滤后的 FactorData
"""
pass
def get_cross_section(self) -> pl.DataFrame:
"""
获取当前日期的截面数据
仅适用于截面因子,返回 current_date 当天的所有股票数据
Returns:
DataFrame 包含当前日期的所有股票
Raises:
ValueError: current_date 未设置(非截面因子场景)
"""
pass
def to_polars(self) -> pl.DataFrame:
"""获取底层的 Polars DataFrame高级用法"""
pass
@property
def context(self) -> FactorContext:
"""获取计算上下文"""
pass
def __len__(self) -> int:
"""返回数据行数"""
pass
```
**测试需求:**
- [ ] 测试 `get_column()` 返回正确 Series
- [ ] 测试 `get_column()` 列不存在时抛出 KeyError
- [ ] 测试 `filter_by_date()` 返回正确过滤结果
- [ ] 测试 `filter_by_date()` 日期不存在时返回空 DataFrame
- [ ] 测试 `get_cross_section()` 返回 current_date 当天的数据
- [ ] 测试 `get_cross_section()` current_date 为 None 时抛出 ValueError
- [ ] 测试 `to_polars()` 返回原始 DataFrame
- [ ] 测试 `context` 属性返回正确上下文
- [ ] 测试 `__len__()` 返回正确行数
---
## Phase 2: 因子基类 (base.py, composite.py)
### 2.1 BaseFactor - 抽象基类
**实现要求:**
```python
class BaseFactor(ABC):
"""
因子基类 - 定义通用接口
所有因子必须继承此类,并声明以下类属性:
- name: 因子唯一标识snake_case
- factor_type: "cross_sectional" 或 "time_series"
- data_specs: List[DataSpec] 数据需求列表
可选声明:
- category: 因子分类(默认 "default"
- description: 因子描述
"""
# 必须声明的类属性
name: str = ""
factor_type: str = "" # "cross_sectional" | "time_series"
data_specs: List[DataSpec] = field(default_factory=list)
# 可选声明的类属性
category: str = "default"
description: str = ""
def __init_subclass__(cls, **kwargs):
"""
子类创建时验证必须属性
验证项:
1. name 必须是非空字符串
2. factor_type 必须是 "cross_sectional" 或 "time_series"
3. data_specs 必须是非空列表
"""
pass
def __init__(self, **params):
"""
初始化因子参数
子类可通过 __init__ 接收参数化配置,如 MA(period=20)
"""
self.params = params
self._validate_params()
def _validate_params(self):
"""
验证参数有效性
子类可覆盖此方法进行自定义验证
"""
pass
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
核心计算逻辑 - 子类必须实现
Args:
data: 安全的数据容器,已根据因子类型裁剪
Returns:
计算得到的因子值 Series
"""
pass
# ========== 因子组合运算符 ==========
def __add__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相加f1 + f2要求同类型"""
pass
def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相减f1 - f2要求同类型"""
pass
def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相乘f1 * f2要求同类型"""
pass
def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相除f1 / f2要求同类型"""
pass
def __rmul__(self, scalar: float) -> "ScalarFactor":
"""标量乘法0.5 * f1"""
pass
```
**测试需求:**
- [ ] 测试有效子类创建通过验证
- [ ] 测试缺少 `name` 时抛出 ValueError
- [ ] 测试 `name` 为空字符串时抛出 ValueError
- [ ] 测试缺少 `factor_type` 时抛出 ValueError
- [ ] 测试无效的 `factor_type`(非 cs/ts时抛出 ValueError
- [ ] 测试缺少 `data_specs` 时抛出 ValueError
- [ ] 测试 `data_specs` 为空列表时抛出 ValueError
- [ ] 测试 `compute()` 抽象方法强制子类实现
- [ ] 测试参数化初始化 `params` 正确存储
- [ ] 测试 `_validate_params()` 被调用
---
### 2.2 CrossSectionalFactor - 日期截面因子
**实现要求:**
```python
class CrossSectionalFactor(BaseFactor):
"""
日期截面因子基类
计算逻辑:在每个交易日,对所有股票进行横向计算
防泄露边界:
- ❌ 禁止访问未来日期的数据(日期泄露)
- ✅ 允许访问当前日期的所有股票数据
数据传入:
- compute() 接收的是 [T-lookback+1, T] 的数据
- 包含 lookback_days 的历史数据(用于时序计算后再截面)
"""
factor_type: str = "cross_sectional"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算截面因子值
Args:
data: FactorData包含 [T-lookback+1, T] 的截面数据
格式DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
示例:
def compute(self, data):
# 获取当前日期的截面
cs = data.get_cross_section()
# 计算市值排名
return cs['market_cap'].rank()
"""
pass
```
**测试需求:**
- [ ] 测试 `factor_type` 自动设置为 "cross_sectional"
- [ ] 测试子类必须实现 `compute()`
- [ ] 测试 `compute()` 返回类型为 pl.Series
---
### 2.3 TimeSeriesFactor - 时间序列因子
**实现要求:**
```python
class TimeSeriesFactor(BaseFactor):
"""
时间序列因子基类(股票截面)
计算逻辑:对每只股票,在其时间序列上进行纵向计算
防泄露边界:
- ❌ 禁止访问其他股票的数据(股票泄露)
- ✅ 允许访问该股票的完整历史数据
数据传入:
- compute() 接收的是单只股票的完整时间序列
- 包含该股票在 [start_date, end_date] 范围内的所有数据
"""
factor_type: str = "time_series"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算时间序列因子值
Args:
data: FactorData包含单只股票的完整时间序列
格式DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
示例:
def compute(self, data):
series = data.get_column("close")
return series.rolling_mean(window_size=self.params['period'])
"""
pass
```
**测试需求:**
- [ ] 测试 `factor_type` 自动设置为 "time_series"
- [ ] 测试子类必须实现 `compute()`
- [ ] 测试 `compute()` 返回类型为 pl.Series
---
### 2.4 CompositeFactor - 组合因子 (composite.py)
**实现要求:**
```python
class CompositeFactor(BaseFactor):
"""
组合因子 - 用于实现因子间的数学运算
约束:左右因子必须是同类型(同为截面或同为时序)
"""
def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
"""
创建组合因子
Args:
left: 左操作数因子
right: 右操作数因子
op: 运算符,支持 '+', '-', '*', '/'
Raises:
ValueError: 左右因子类型不一致
ValueError: 不支持的运算符
"""
pass
def _merge_data_specs(self) -> List[DataSpec]:
"""
合并左右因子的数据需求
策略:
1. 相同 source 和 columns 的 DataSpec 合并
2. lookback_days 取最大值
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""
执行组合运算
流程:
1. 分别计算 left 和 right 的值
2. 根据 op 执行运算
3. 返回结果
"""
pass
```
**测试需求:**
- [ ] 测试同类型因子组合成功cs + cs
- [ ] 测试同类型因子组合成功ts + ts
- [ ] 测试不同类型因子组合抛出 ValueErrorcs + ts
- [ ] 测试无效运算符抛出 ValueError
- [ ] 测试 `_merge_data_specs()` 正确合并(相同 source
- [ ] 测试 `_merge_data_specs()` 正确合并(不同 source
- [ ] 测试 `_merge_data_specs()` lookback 取最大值
- [ ] 测试 `compute()` 执行正确的数学运算
---
### 2.5 ScalarFactor - 标量运算因子 (composite.py)
**实现要求:**
```python
class ScalarFactor(BaseFactor):
"""
标量运算因子
支持scalar * factor, factor * scalar通过 __rmul__
"""
def __init__(self, factor: BaseFactor, scalar: float, op: str):
"""
创建标量运算因子
Args:
factor: 基础因子
scalar: 标量值
op: 运算符,支持 '*', '+'
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""执行标量运算"""
pass
```
**测试需求:**
- [ ] 测试标量乘法 `0.5 * factor`
- [ ] 测试标量乘法 `factor * 0.5`
- [ ] 测试标量加法(如支持)
- [ ] 测试继承基础因子的 data_specs
- [ ] 测试 `compute()` 返回正确缩放后的值
---
## Phase 3: 数据加载 (data_loader.py)
### 3.1 DataLoader - 数据加载器
**实现要求:**
```python
class DataLoader:
"""
数据加载器 - 负责从 HDF5 安全加载数据
功能:
1. 多文件聚合:合并多个 H5 文件的数据
2. 列选择:只加载需要的列
3. 原始数据缓存:避免重复读取
"""
def __init__(self, data_dir: str):
"""
初始化 DataLoader
Args:
data_dir: HDF5 文件所在目录
"""
self.data_dir = Path(data_dir)
self._cache: Dict[str, pl.DataFrame] = {}
def load(
self,
specs: List[DataSpec],
date_range: Optional[Tuple[str, str]] = None
) -> pl.DataFrame:
"""
加载并聚合多个 H5 文件的数据
流程:
1. 对每个 DataSpec
a. 检查缓存,命中则直接使用
b. 未命中则读取 HDF5通过 pandas
c. 转换为 Polars DataFrame
d. 按 date_range 过滤
e. 存入缓存
2. 合并多个 DataFrame按 trade_date 和 ts_code join
Args:
specs: 数据需求规格列表
date_range: 日期范围限制 (start_date, end_date),可选
Returns:
合并后的 Polars DataFrame
Raises:
FileNotFoundError: H5 文件不存在
KeyError: 列不存在于文件中
"""
pass
def clear_cache(self):
"""清空缓存"""
pass
def _read_h5(self, source: str) -> pl.DataFrame:
"""
读取单个 H5 文件
实现:使用 pandas.read_hdf(),然后 pl.from_pandas()
"""
pass
```
**测试需求:**
- [ ] 测试从单个 H5 文件加载数据
- [ ] 测试从多个 H5 文件加载并合并
- [ ] 测试列选择(只加载需要的列)
- [ ] 测试缓存机制(第二次加载更快)
- [ ] 测试 `clear_cache()` 清空缓存
- [ ] 测试按 date_range 过滤
- [ ] 测试文件不存在时抛出 FileNotFoundError
- [ ] 测试列不存在时抛出 KeyError
---
## Phase 4: 执行引擎 (engine.py)
### 4.1 FactorEngine - 因子执行引擎
**实现要求:**
```python
class FactorEngine:
"""
因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
核心职责:
1. CrossSectionalFactor防止日期泄露每天传入 [T-lookback+1, T] 数据
2. TimeSeriesFactor防止股票泄露每只股票传入完整序列
"""
def __init__(self, data_loader: DataLoader):
"""
初始化引擎
Args:
data_loader: 数据加载器实例
"""
self.data_loader = data_loader
def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
"""
统一的计算入口
根据 factor_type 分发到具体方法:
- "cross_sectional" -> _compute_cross_sectional()
- "time_series" -> _compute_time_series()
Args:
factor: 要计算的因子
**kwargs: 额外参数,根据因子类型不同:
- 截面因子: start_date, end_date
- 时序因子: stock_codes, start_date, end_date
Returns:
DataFrame[trade_date, ts_code, factor_name]
"""
pass
```
**测试需求:**
- [ ] 测试 `compute()` 正确分发给截面计算
- [ ] 测试 `compute()` 正确分发给时序计算
- [ ] 测试无效 factor_type 时抛出 ValueError
---
### 4.2 截面计算(防止日期泄露)
**实现要求:**
```python
def _compute_cross_sectional(
self,
factor: CrossSectionalFactor,
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行日期截面计算
防泄露策略:
- 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来)
- 允许股票间比较:传入当天所有股票的数据
计算流程:
1. 计算 max_lookback确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每个日期 T in [start_date, end_date]
a. 裁剪数据到 [T-lookback+1, T]
b. 创建 FactorDatacurrent_date=T
c. 调用 factor.compute()
d. 收集结果
4. 合并所有日期的结果
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 0.5 │
│ 20240101 │ 000002.SZ│ 0.3 │
└────────────┴──────────┴──────────────┘
"""
pass
```
**测试需求(防泄露验证):**
- [ ] 测试数据裁剪正确(传入 [T-lookback+1, T]
- [ ] 测试不包含未来日期 T+1 的数据
- [ ] 测试每个日期独立计算
- [ ] 测试结果包含所有日期和所有股票
- [ ] 测试结果 DataFrame 格式正确
- [ ] 测试多个 DataSpec 时 lookback 取最大值
---
### 4.3 时序计算(防止股票泄露)
**实现要求:**
```python
def _compute_time_series(
self,
factor: TimeSeriesFactor,
stock_codes: List[str],
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行时间序列计算
防泄露策略:
- 防止股票泄露:每只股票单独计算,传入该股票的完整序列
- 允许访问历史数据:时序计算需要历史数据
计算流程:
1. 计算 max_lookback确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每只股票 S in stock_codes
a. 过滤出 S 的数据(防止股票泄露)
b. 创建 FactorDatacurrent_stock=S
c. 调用 factor.compute()(向量化计算整个序列)
d. 收集结果
4. 合并所有股票的结果
性能优势:
- 使用 Polars 的 rolling_mean 等向量化操作
- 每只股票只计算一次,无重复计算
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 10.5 │
│ 20240102 │ 000001.SZ│ 10.6 │
└────────────┴──────────┴──────────────┘
"""
pass
```
**测试需求(防泄露验证):**
- [ ] 测试每只股票只看到自己的数据
- [ ] 测试不包含其他股票的数据
- [ ] 测试传入的是完整时间序列(向量化计算)
- [ ] 测试结果包含所有股票和所有日期
- [ ] 测试结果 DataFrame 格式正确
- [ ] 测试股票不在数据中时跳过(或填充 null
---
## Phase 5: 内置因子库 (builtin/)
### 5.1 momentum.py - 截面动量因子
**实现因子:**
1. **ReturnRankFactor** - 当日收益率排名
```python
class ReturnRankFactor(CrossSectionalFactor):
"""当日收益率排名因子"""
name = "return_rank"
data_specs = [DataSpec("daily", ["close"], lookback_days=2)] # 需要2天计算收益率
def compute(self, data):
# 获取当前日期截面
cs = data.get_cross_section()
# 需要前1天和当天的收盘价lookback=2 保证数据包含 [T-1, T]
# 这里假设 data 已经包含历史,实际计算需要 groupby 处理
pass
```
**测试需求:**
- [ ] 测试收益率计算正确
- [ ] 测试排名计算正确
- [ ] 测试无数据时返回 null
2. **MomentumFactor** - 过去 N 日涨幅排名
---
### 5.2 technical.py - 时序技术指标
**实现因子:**
1. **MovingAverageFactor** - 移动平均线
```python
class MovingAverageFactor(TimeSeriesFactor):
"""移动平均线因子"""
name = "ma"
def __init__(self, period: int = 20):
super().__init__(period=period)
self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
def compute(self, data):
return data.get_column("close").rolling_mean(self.params["period"])
```
**测试需求:**
- [ ] 测试 MA20 计算正确
- [ ] 测试前19天返回 nullPolars 默认行为)
- [ ] 测试参数 period 生效
2. **RSIFactor** - RSI 指标
3. **MACDFactor** - MACD 指标
---
### 5.3 value.py - 截面估值因子
**实现因子:**
1. **PERankFactor** - PE 行业分位数
2. **PBFactor** - PB 排名
---
## Phase 6-7: 测试策略
### 测试金字塔
```
/\
/ \
/ 集成\ tests/factors/test_integration.py
/────────\
/ 引擎 \ tests/factors/test_engine.py
/────────────\
/ 基类/组合因子 \ tests/factors/test_base.py, test_composite.py
/────────────────\
/ 数据加载/类型 \ tests/factors/test_data_loader.py, test_data_spec.py
/──────────────────────\
```
### 测试数据准备
创建 `tests/fixtures/` 目录,包含:
- `sample_daily.h5`: 少量股票的日线数据(用于测试)
- `sample_fundamental.h5`: 基本面数据
### 关键测试场景
1. **防泄露测试(核心)**
- 截面因子:验证 compute() 中无法访问未来日期
- 时序因子:验证 compute() 中无法访问其他股票
2. **边界测试**
- lookback_days = 1最小值
- 数据起始点(前 N 天为 null
- 空数据/停牌处理
3. **性能测试(可选)**
- 大数据量下的内存占用
- 缓存命中率
---
## 实现状态
| Phase | 状态 | 完成日期 | 测试覆盖 |
|-------|------|----------|----------|
| Phase 1: 数据类型定义 | ✅ 已完成 | 2026-02-21 | 27 tests passed |
| Phase 2: 因子基类 | ✅ 已完成 | 2026-02-21 | 49 tests passed |
| Phase 3: 数据加载 | ✅ 已完成 | 2026-02-21 | 11 tests passed |
| Phase 4: 执行引擎 | ✅ 已完成 | 2026-02-22 | 10 tests passed |
| Phase 5: 内置因子库 | 📝 待开发 | - | - |
| Phase 6-7: 测试文档 | ✅ 已完成 | 2026-02-22 | 76 tests total |
---
## 实现顺序建议
1. **Week 1**: Phase 1-2数据类型 + 基类)
2. **Week 2**: Phase 3-4DataLoader + Engine**已完成**
3. **Week 3**: Phase 5内置因子
4. **Week 4**: Phase 6-7测试 + 文档)
每个 Phase 完成后运行对应测试,确保质量。