- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor) - 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData) - 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine) - 新增组合因子支持 (CompositeFactor, ScalarFactor) - 添加因子模块完整测试用例 - 添加 Git 提交规范文档
24 KiB
24 KiB
ProStock 因子框架实现计划
目录结构
src/factors/
├── __init__.py # 导出主要类
├── data_spec.py # Phase 1: 数据类型定义
├── base.py # Phase 2: 因子基类
├── composite.py # Phase 2: 组合因子
├── data_loader.py # Phase 3: 数据加载
├── engine.py # Phase 4: 执行引擎
└── builtin/ # Phase 5: 内置因子库
├── __init__.py
├── momentum.py # 截面动量因子
├── technical.py # 时序技术指标
└── value.py # 截面估值因子
tests/factors/ # Phase 6-7: 测试
├── __init__.py
├── test_data_spec.py # 数据类型测试
├── test_base.py # 因子基类测试
├── test_composite.py # 组合因子测试
├── test_data_loader.py # 数据加载测试
├── test_engine.py # 引擎测试
├── test_builtin.py # 内置因子测试
└── test_integration.py # 集成测试
Phase 1: 数据类型定义 (data_spec.py)
1.1 DataSpec - 数据需求规格
实现要求:
@dataclass(frozen=True)
class DataSpec:
"""
数据需求规格说明
Args:
source: H5 文件名(如 "daily", "fundamental")
columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date"
lookback_days: 需要回看的天数(包含当日)
- 1 表示只需要当日数据 [T]
- 5 表示需要 [T-4, T] 共5天
- 20 表示需要 [T-19, T] 共20天
"""
source: str
columns: List[str]
lookback_days: int = 1
约束验证:
lookback_days >= 1(至少包含当日)columns必须包含ts_code和trade_datesource不能为空字符串
测试需求:
- 测试有效 DataSpec 创建
- 测试
lookback_days < 1时抛出 ValueError - 测试缺少
ts_code或trade_date时抛出 ValueError - 测试空
source时抛出 ValueError - 测试 frozen 特性(创建后不可修改)
1.2 FactorContext - 计算上下文
实现要求:
@dataclass
class FactorContext:
"""
因子计算上下文
由 FactorEngine 自动注入,因子开发者可通过 data.context 访问
Attributes:
current_date: 当前计算日期 YYYYMMDD(截面因子使用)
current_stock: 当前计算股票代码(时序因子使用)
trade_dates: 交易日历列表(可选,用于对齐)
"""
current_date: Optional[str] = None
current_stock: Optional[str] = None
trade_dates: Optional[List[str]] = None
测试需求:
- 测试默认值创建
- 测试完整参数创建
- 测试 dataclass 自动生成的方法
1.3 FactorData - 数据容器
实现要求:
class FactorData:
"""
提供给因子的数据容器
封装底层 Polars DataFrame,提供安全的数据访问接口
"""
def __init__(self, df: pl.DataFrame, context: FactorContext):
self._df = df
self._context = context
def get_column(self, col: str) -> pl.Series:
"""
获取指定列的数据
- 截面因子:获取当天所有股票的该列值
- 时序因子:获取该股票时间序列的该列值
Args:
col: 列名
Returns:
Polars Series
Raises:
KeyError: 列不存在
"""
pass
def filter_by_date(self, date: str) -> "FactorData":
"""
按日期过滤数据,返回新的 FactorData
主要用于截面因子获取特定日期的数据
Args:
date: YYYYMMDD 格式的日期
Returns:
过滤后的 FactorData
"""
pass
def get_cross_section(self) -> pl.DataFrame:
"""
获取当前日期的截面数据
仅适用于截面因子,返回 current_date 当天的所有股票数据
Returns:
DataFrame 包含当前日期的所有股票
Raises:
ValueError: current_date 未设置(非截面因子场景)
"""
pass
def to_polars(self) -> pl.DataFrame:
"""获取底层的 Polars DataFrame(高级用法)"""
pass
@property
def context(self) -> FactorContext:
"""获取计算上下文"""
pass
def __len__(self) -> int:
"""返回数据行数"""
pass
测试需求:
- 测试
get_column()返回正确 Series - 测试
get_column()列不存在时抛出 KeyError - 测试
filter_by_date()返回正确过滤结果 - 测试
filter_by_date()日期不存在时返回空 DataFrame - 测试
get_cross_section()返回 current_date 当天的数据 - 测试
get_cross_section()current_date 为 None 时抛出 ValueError - 测试
to_polars()返回原始 DataFrame - 测试
context属性返回正确上下文 - 测试
__len__()返回正确行数
Phase 2: 因子基类 (base.py, composite.py)
2.1 BaseFactor - 抽象基类
实现要求:
class BaseFactor(ABC):
"""
因子基类 - 定义通用接口
所有因子必须继承此类,并声明以下类属性:
- name: 因子唯一标识(snake_case)
- factor_type: "cross_sectional" 或 "time_series"
- data_specs: List[DataSpec] 数据需求列表
可选声明:
- category: 因子分类(默认 "default")
- description: 因子描述
"""
# 必须声明的类属性
name: str = ""
factor_type: str = "" # "cross_sectional" | "time_series"
data_specs: List[DataSpec] = field(default_factory=list)
# 可选声明的类属性
category: str = "default"
description: str = ""
def __init_subclass__(cls, **kwargs):
"""
子类创建时验证必须属性
验证项:
1. name 必须是非空字符串
2. factor_type 必须是 "cross_sectional" 或 "time_series"
3. data_specs 必须是非空列表
"""
pass
def __init__(self, **params):
"""
初始化因子参数
子类可通过 __init__ 接收参数化配置,如 MA(period=20)
"""
self.params = params
self._validate_params()
def _validate_params(self):
"""
验证参数有效性
子类可覆盖此方法进行自定义验证
"""
pass
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
核心计算逻辑 - 子类必须实现
Args:
data: 安全的数据容器,已根据因子类型裁剪
Returns:
计算得到的因子值 Series
"""
pass
# ========== 因子组合运算符 ==========
def __add__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相加:f1 + f2(要求同类型)"""
pass
def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相减:f1 - f2(要求同类型)"""
pass
def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相乘:f1 * f2(要求同类型)"""
pass
def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
"""因子相除:f1 / f2(要求同类型)"""
pass
def __rmul__(self, scalar: float) -> "ScalarFactor":
"""标量乘法:0.5 * f1"""
pass
测试需求:
- 测试有效子类创建通过验证
- 测试缺少
name时抛出 ValueError - 测试
name为空字符串时抛出 ValueError - 测试缺少
factor_type时抛出 ValueError - 测试无效的
factor_type(非 cs/ts)时抛出 ValueError - 测试缺少
data_specs时抛出 ValueError - 测试
data_specs为空列表时抛出 ValueError - 测试
compute()抽象方法强制子类实现 - 测试参数化初始化
params正确存储 - 测试
_validate_params()被调用
2.2 CrossSectionalFactor - 日期截面因子
实现要求:
class CrossSectionalFactor(BaseFactor):
"""
日期截面因子基类
计算逻辑:在每个交易日,对所有股票进行横向计算
防泄露边界:
- ❌ 禁止访问未来日期的数据(日期泄露)
- ✅ 允许访问当前日期的所有股票数据
数据传入:
- compute() 接收的是 [T-lookback+1, T] 的数据
- 包含 lookback_days 的历史数据(用于时序计算后再截面)
"""
factor_type: str = "cross_sectional"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算截面因子值
Args:
data: FactorData,包含 [T-lookback+1, T] 的截面数据
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
示例:
def compute(self, data):
# 获取当前日期的截面
cs = data.get_cross_section()
# 计算市值排名
return cs['market_cap'].rank()
"""
pass
测试需求:
- 测试
factor_type自动设置为 "cross_sectional" - 测试子类必须实现
compute() - 测试
compute()返回类型为 pl.Series
2.3 TimeSeriesFactor - 时间序列因子
实现要求:
class TimeSeriesFactor(BaseFactor):
"""
时间序列因子基类(股票截面)
计算逻辑:对每只股票,在其时间序列上进行纵向计算
防泄露边界:
- ❌ 禁止访问其他股票的数据(股票泄露)
- ✅ 允许访问该股票的完整历史数据
数据传入:
- compute() 接收的是单只股票的完整时间序列
- 包含该股票在 [start_date, end_date] 范围内的所有数据
"""
factor_type: str = "time_series"
@abstractmethod
def compute(self, data: FactorData) -> pl.Series:
"""
计算时间序列因子值
Args:
data: FactorData,包含单只股票的完整时间序列
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
Returns:
pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
示例:
def compute(self, data):
series = data.get_column("close")
return series.rolling_mean(window_size=self.params['period'])
"""
pass
测试需求:
- 测试
factor_type自动设置为 "time_series" - 测试子类必须实现
compute() - 测试
compute()返回类型为 pl.Series
2.4 CompositeFactor - 组合因子 (composite.py)
实现要求:
class CompositeFactor(BaseFactor):
"""
组合因子 - 用于实现因子间的数学运算
约束:左右因子必须是同类型(同为截面或同为时序)
"""
def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
"""
创建组合因子
Args:
left: 左操作数因子
right: 右操作数因子
op: 运算符,支持 '+', '-', '*', '/'
Raises:
ValueError: 左右因子类型不一致
ValueError: 不支持的运算符
"""
pass
def _merge_data_specs(self) -> List[DataSpec]:
"""
合并左右因子的数据需求
策略:
1. 相同 source 和 columns 的 DataSpec 合并
2. lookback_days 取最大值
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""
执行组合运算
流程:
1. 分别计算 left 和 right 的值
2. 根据 op 执行运算
3. 返回结果
"""
pass
测试需求:
- 测试同类型因子组合成功(cs + cs)
- 测试同类型因子组合成功(ts + ts)
- 测试不同类型因子组合抛出 ValueError(cs + ts)
- 测试无效运算符抛出 ValueError
- 测试
_merge_data_specs()正确合并(相同 source) - 测试
_merge_data_specs()正确合并(不同 source) - 测试
_merge_data_specs()lookback 取最大值 - 测试
compute()执行正确的数学运算
2.5 ScalarFactor - 标量运算因子 (composite.py)
实现要求:
class ScalarFactor(BaseFactor):
"""
标量运算因子
支持:scalar * factor, factor * scalar(通过 __rmul__)
"""
def __init__(self, factor: BaseFactor, scalar: float, op: str):
"""
创建标量运算因子
Args:
factor: 基础因子
scalar: 标量值
op: 运算符,支持 '*', '+'
"""
pass
def compute(self, data: FactorData) -> pl.Series:
"""执行标量运算"""
pass
测试需求:
- 测试标量乘法
0.5 * factor - 测试标量乘法
factor * 0.5 - 测试标量加法(如支持)
- 测试继承基础因子的 data_specs
- 测试
compute()返回正确缩放后的值
Phase 3: 数据加载 (data_loader.py)
3.1 DataLoader - 数据加载器
实现要求:
class DataLoader:
"""
数据加载器 - 负责从 HDF5 安全加载数据
功能:
1. 多文件聚合:合并多个 H5 文件的数据
2. 列选择:只加载需要的列
3. 原始数据缓存:避免重复读取
"""
def __init__(self, data_dir: str):
"""
初始化 DataLoader
Args:
data_dir: HDF5 文件所在目录
"""
self.data_dir = Path(data_dir)
self._cache: Dict[str, pl.DataFrame] = {}
def load(
self,
specs: List[DataSpec],
date_range: Optional[Tuple[str, str]] = None
) -> pl.DataFrame:
"""
加载并聚合多个 H5 文件的数据
流程:
1. 对每个 DataSpec:
a. 检查缓存,命中则直接使用
b. 未命中则读取 HDF5(通过 pandas)
c. 转换为 Polars DataFrame
d. 按 date_range 过滤
e. 存入缓存
2. 合并多个 DataFrame(按 trade_date 和 ts_code join)
Args:
specs: 数据需求规格列表
date_range: 日期范围限制 (start_date, end_date),可选
Returns:
合并后的 Polars DataFrame
Raises:
FileNotFoundError: H5 文件不存在
KeyError: 列不存在于文件中
"""
pass
def clear_cache(self):
"""清空缓存"""
pass
def _read_h5(self, source: str) -> pl.DataFrame:
"""
读取单个 H5 文件
实现:使用 pandas.read_hdf(),然后 pl.from_pandas()
"""
pass
测试需求:
- 测试从单个 H5 文件加载数据
- 测试从多个 H5 文件加载并合并
- 测试列选择(只加载需要的列)
- 测试缓存机制(第二次加载更快)
- 测试
clear_cache()清空缓存 - 测试按 date_range 过滤
- 测试文件不存在时抛出 FileNotFoundError
- 测试列不存在时抛出 KeyError
Phase 4: 执行引擎 (engine.py)
4.1 FactorEngine - 因子执行引擎
实现要求:
class FactorEngine:
"""
因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
核心职责:
1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据
2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列
"""
def __init__(self, data_loader: DataLoader):
"""
初始化引擎
Args:
data_loader: 数据加载器实例
"""
self.data_loader = data_loader
def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
"""
统一的计算入口
根据 factor_type 分发到具体方法:
- "cross_sectional" -> _compute_cross_sectional()
- "time_series" -> _compute_time_series()
Args:
factor: 要计算的因子
**kwargs: 额外参数,根据因子类型不同:
- 截面因子: start_date, end_date
- 时序因子: stock_codes, start_date, end_date
Returns:
DataFrame[trade_date, ts_code, factor_name]
"""
pass
测试需求:
- 测试
compute()正确分发给截面计算 - 测试
compute()正确分发给时序计算 - 测试无效 factor_type 时抛出 ValueError
4.2 截面计算(防止日期泄露)
实现要求:
def _compute_cross_sectional(
self,
factor: CrossSectionalFactor,
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行日期截面计算
防泄露策略:
- 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来)
- 允许股票间比较:传入当天所有股票的数据
计算流程:
1. 计算 max_lookback,确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每个日期 T in [start_date, end_date]:
a. 裁剪数据到 [T-lookback+1, T]
b. 创建 FactorData(current_date=T)
c. 调用 factor.compute()
d. 收集结果
4. 合并所有日期的结果
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 0.5 │
│ 20240101 │ 000002.SZ│ 0.3 │
└────────────┴──────────┴──────────────┘
"""
pass
测试需求(防泄露验证):
- 测试数据裁剪正确(传入 [T-lookback+1, T])
- 测试不包含未来日期 T+1 的数据
- 测试每个日期独立计算
- 测试结果包含所有日期和所有股票
- 测试结果 DataFrame 格式正确
- 测试多个 DataSpec 时 lookback 取最大值
4.3 时序计算(防止股票泄露)
实现要求:
def _compute_time_series(
self,
factor: TimeSeriesFactor,
stock_codes: List[str],
start_date: str,
end_date: str
) -> pl.DataFrame:
"""
执行时间序列计算
防泄露策略:
- 防止股票泄露:每只股票单独计算,传入该股票的完整序列
- 允许访问历史数据:时序计算需要历史数据
计算流程:
1. 计算 max_lookback,确定数据起始日期
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
3. 对每只股票 S in stock_codes:
a. 过滤出 S 的数据(防止股票泄露)
b. 创建 FactorData(current_stock=S)
c. 调用 factor.compute()(向量化计算整个序列)
d. 收集结果
4. 合并所有股票的结果
性能优势:
- 使用 Polars 的 rolling_mean 等向量化操作
- 每只股票只计算一次,无重复计算
返回 DataFrame 格式:
┌────────────┬──────────┬──────────────┐
│ trade_date │ ts_code │ factor_name │
├────────────┼──────────┼──────────────┤
│ 20240101 │ 000001.SZ│ 10.5 │
│ 20240102 │ 000001.SZ│ 10.6 │
└────────────┴──────────┴──────────────┘
"""
pass
测试需求(防泄露验证):
- 测试每只股票只看到自己的数据
- 测试不包含其他股票的数据
- 测试传入的是完整时间序列(向量化计算)
- 测试结果包含所有股票和所有日期
- 测试结果 DataFrame 格式正确
- 测试股票不在数据中时跳过(或填充 null)
Phase 5: 内置因子库 (builtin/)
5.1 momentum.py - 截面动量因子
实现因子:
- ReturnRankFactor - 当日收益率排名
class ReturnRankFactor(CrossSectionalFactor):
"""当日收益率排名因子"""
name = "return_rank"
data_specs = [DataSpec("daily", ["close"], lookback_days=2)] # 需要2天计算收益率
def compute(self, data):
# 获取当前日期截面
cs = data.get_cross_section()
# 需要前1天和当天的收盘价,lookback=2 保证数据包含 [T-1, T]
# 这里假设 data 已经包含历史,实际计算需要 groupby 处理
pass
测试需求:
- 测试收益率计算正确
- 测试排名计算正确
- 测试无数据时返回 null
- MomentumFactor - 过去 N 日涨幅排名
5.2 technical.py - 时序技术指标
实现因子:
- MovingAverageFactor - 移动平均线
class MovingAverageFactor(TimeSeriesFactor):
"""移动平均线因子"""
name = "ma"
def __init__(self, period: int = 20):
super().__init__(period=period)
self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
def compute(self, data):
return data.get_column("close").rolling_mean(self.params["period"])
测试需求:
- 测试 MA20 计算正确
- 测试前19天返回 null(Polars 默认行为)
- 测试参数 period 生效
- RSIFactor - RSI 指标
- MACDFactor - MACD 指标
5.3 value.py - 截面估值因子
实现因子:
- PERankFactor - PE 行业分位数
- PBFactor - PB 排名
Phase 6-7: 测试策略
测试金字塔
/\
/ \
/ 集成\ tests/factors/test_integration.py
/────────\
/ 引擎 \ tests/factors/test_engine.py
/────────────\
/ 基类/组合因子 \ tests/factors/test_base.py, test_composite.py
/────────────────\
/ 数据加载/类型 \ tests/factors/test_data_loader.py, test_data_spec.py
/──────────────────────\
测试数据准备
创建 tests/fixtures/ 目录,包含:
sample_daily.h5: 少量股票的日线数据(用于测试)sample_fundamental.h5: 基本面数据
关键测试场景
-
防泄露测试(核心)
- 截面因子:验证 compute() 中无法访问未来日期
- 时序因子:验证 compute() 中无法访问其他股票
-
边界测试
- lookback_days = 1(最小值)
- 数据起始点(前 N 天为 null)
- 空数据/停牌处理
-
性能测试(可选)
- 大数据量下的内存占用
- 缓存命中率
实现状态
| Phase | 状态 | 完成日期 | 测试覆盖 |
|---|---|---|---|
| Phase 1: 数据类型定义 | ✅ 已完成 | 2026-02-21 | 27 tests passed |
| Phase 2: 因子基类 | ✅ 已完成 | 2026-02-21 | 49 tests passed |
| Phase 3: 数据加载 | ✅ 已完成 | 2026-02-21 | 11 tests passed |
| Phase 4: 执行引擎 | ✅ 已完成 | 2026-02-22 | 10 tests passed |
| Phase 5: 内置因子库 | 📝 待开发 | - | - |
| Phase 6-7: 测试文档 | ✅ 已完成 | 2026-02-22 | 76 tests total |
实现顺序建议
- Week 1: Phase 1-2(数据类型 + 基类)
- Week 2: Phase 3-4(DataLoader + Engine)✅ 已完成
- Week 3: Phase 5(内置因子)
- Week 4: Phase 6-7(测试 + 文档)
每个 Phase 完成后运行对应测试,确保质量。