Files
ProStock/docs/factor_implementation_plan.md
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

24 KiB
Raw Blame History

ProStock 因子框架实现计划

目录结构

src/factors/
├── __init__.py              # 导出主要类
├── data_spec.py             # Phase 1: 数据类型定义
├── base.py                  # Phase 2: 因子基类
├── composite.py             # Phase 2: 组合因子
├── data_loader.py           # Phase 3: 数据加载
├── engine.py                # Phase 4: 执行引擎
└── builtin/                 # Phase 5: 内置因子库
    ├── __init__.py
    ├── momentum.py          # 截面动量因子
    ├── technical.py         # 时序技术指标
    └── value.py             # 截面估值因子

tests/factors/               # Phase 6-7: 测试
├── __init__.py
├── test_data_spec.py        # 数据类型测试
├── test_base.py             # 因子基类测试
├── test_composite.py        # 组合因子测试
├── test_data_loader.py      # 数据加载测试
├── test_engine.py           # 引擎测试
├── test_builtin.py          # 内置因子测试
└── test_integration.py      # 集成测试

Phase 1: 数据类型定义 (data_spec.py)

1.1 DataSpec - 数据需求规格

实现要求:

@dataclass(frozen=True)
class DataSpec:
    """
    数据需求规格说明
    
    Args:
        source: H5 文件名(如 "daily", "fundamental"
        columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date"
        lookback_days: 需要回看的天数(包含当日)
            - 1 表示只需要当日数据 [T]
            - 5 表示需要 [T-4, T] 共5天
            - 20 表示需要 [T-19, T] 共20天
    """
    source: str
    columns: List[str]
    lookback_days: int = 1

约束验证:

  • lookback_days >= 1(至少包含当日)
  • columns 必须包含 ts_codetrade_date
  • source 不能为空字符串

测试需求:

  • 测试有效 DataSpec 创建
  • 测试 lookback_days < 1 时抛出 ValueError
  • 测试缺少 ts_codetrade_date 时抛出 ValueError
  • 测试空 source 时抛出 ValueError
  • 测试 frozen 特性(创建后不可修改)

1.2 FactorContext - 计算上下文

实现要求:

@dataclass
class FactorContext:
    """
    因子计算上下文
    
    由 FactorEngine 自动注入,因子开发者可通过 data.context 访问
    
    Attributes:
        current_date: 当前计算日期 YYYYMMDD截面因子使用
        current_stock: 当前计算股票代码(时序因子使用)
        trade_dates: 交易日历列表(可选,用于对齐)
    """
    current_date: Optional[str] = None
    current_stock: Optional[str] = None
    trade_dates: Optional[List[str]] = None

测试需求:

  • 测试默认值创建
  • 测试完整参数创建
  • 测试 dataclass 自动生成的方法

1.3 FactorData - 数据容器

实现要求:

class FactorData:
    """
    提供给因子的数据容器
    
    封装底层 Polars DataFrame提供安全的数据访问接口
    """
    
    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context
    
    def get_column(self, col: str) -> pl.Series:
        """
        获取指定列的数据
        
        - 截面因子:获取当天所有股票的该列值
        - 时序因子:获取该股票时间序列的该列值
        
        Args:
            col: 列名
            
        Returns:
            Polars Series
            
        Raises:
            KeyError: 列不存在
        """
        pass
    
    def filter_by_date(self, date: str) -> "FactorData":
        """
        按日期过滤数据,返回新的 FactorData
        
        主要用于截面因子获取特定日期的数据
        
        Args:
            date: YYYYMMDD 格式的日期
            
        Returns:
            过滤后的 FactorData
        """
        pass
    
    def get_cross_section(self) -> pl.DataFrame:
        """
        获取当前日期的截面数据
        
        仅适用于截面因子,返回 current_date 当天的所有股票数据
        
        Returns:
            DataFrame 包含当前日期的所有股票
            
        Raises:
            ValueError: current_date 未设置(非截面因子场景)
        """
        pass
    
    def to_polars(self) -> pl.DataFrame:
        """获取底层的 Polars DataFrame高级用法"""
        pass
    
    @property
    def context(self) -> FactorContext:
        """获取计算上下文"""
        pass
    
    def __len__(self) -> int:
        """返回数据行数"""
        pass

测试需求:

  • 测试 get_column() 返回正确 Series
  • 测试 get_column() 列不存在时抛出 KeyError
  • 测试 filter_by_date() 返回正确过滤结果
  • 测试 filter_by_date() 日期不存在时返回空 DataFrame
  • 测试 get_cross_section() 返回 current_date 当天的数据
  • 测试 get_cross_section() current_date 为 None 时抛出 ValueError
  • 测试 to_polars() 返回原始 DataFrame
  • 测试 context 属性返回正确上下文
  • 测试 __len__() 返回正确行数

Phase 2: 因子基类 (base.py, composite.py)

2.1 BaseFactor - 抽象基类

实现要求:

class BaseFactor(ABC):
    """
    因子基类 - 定义通用接口
    
    所有因子必须继承此类,并声明以下类属性:
    - name: 因子唯一标识snake_case
    - factor_type: "cross_sectional" 或 "time_series"
    - data_specs: List[DataSpec] 数据需求列表
    
    可选声明:
    - category: 因子分类(默认 "default"
    - description: 因子描述
    """
    
    # 必须声明的类属性
    name: str = ""
    factor_type: str = ""  # "cross_sectional" | "time_series"
    data_specs: List[DataSpec] = field(default_factory=list)
    
    # 可选声明的类属性
    category: str = "default"
    description: str = ""
    
    def __init_subclass__(cls, **kwargs):
        """
        子类创建时验证必须属性
        
        验证项:
        1. name 必须是非空字符串
        2. factor_type 必须是 "cross_sectional" 或 "time_series"
        3. data_specs 必须是非空列表
        """
        pass
    
    def __init__(self, **params):
        """
        初始化因子参数
        
        子类可通过 __init__ 接收参数化配置,如 MA(period=20)
        """
        self.params = params
        self._validate_params()
    
    def _validate_params(self):
        """
        验证参数有效性
        
        子类可覆盖此方法进行自定义验证
        """
        pass
    
    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        核心计算逻辑 - 子类必须实现
        
        Args:
            data: 安全的数据容器,已根据因子类型裁剪
            
        Returns:
            计算得到的因子值 Series
        """
        pass
    
    # ========== 因子组合运算符 ==========
    
    def __add__(self, other: "BaseFactor") -> "CompositeFactor":
        """因子相加f1 + f2要求同类型"""
        pass
    
    def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
        """因子相减f1 - f2要求同类型"""
        pass
    
    def __mul__(self, other: "BaseFactor") -> "CompositeFactor":
        """因子相乘f1 * f2要求同类型"""
        pass
    
    def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
        """因子相除f1 / f2要求同类型"""
        pass
    
    def __rmul__(self, scalar: float) -> "ScalarFactor":
        """标量乘法0.5 * f1"""
        pass

测试需求:

  • 测试有效子类创建通过验证
  • 测试缺少 name 时抛出 ValueError
  • 测试 name 为空字符串时抛出 ValueError
  • 测试缺少 factor_type 时抛出 ValueError
  • 测试无效的 factor_type(非 cs/ts时抛出 ValueError
  • 测试缺少 data_specs 时抛出 ValueError
  • 测试 data_specs 为空列表时抛出 ValueError
  • 测试 compute() 抽象方法强制子类实现
  • 测试参数化初始化 params 正确存储
  • 测试 _validate_params() 被调用

2.2 CrossSectionalFactor - 日期截面因子

实现要求:

class CrossSectionalFactor(BaseFactor):
    """
    日期截面因子基类
    
    计算逻辑:在每个交易日,对所有股票进行横向计算
    
    防泄露边界:
    - ❌ 禁止访问未来日期的数据(日期泄露)
    - ✅ 允许访问当前日期的所有股票数据
    
    数据传入:
    - compute() 接收的是 [T-lookback+1, T] 的数据
    - 包含 lookback_days 的历史数据(用于时序计算后再截面)
    """
    
    factor_type: str = "cross_sectional"
    
    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        计算截面因子值
        
        Args:
            data: FactorData包含 [T-lookback+1, T] 的截面数据
                  格式DataFrame[ts_code, trade_date, col1, col2, ...]
                  
        Returns:
            pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
            
        示例:
            def compute(self, data):
                # 获取当前日期的截面
                cs = data.get_cross_section()
                # 计算市值排名
                return cs['market_cap'].rank()
        """
        pass

测试需求:

  • 测试 factor_type 自动设置为 "cross_sectional"
  • 测试子类必须实现 compute()
  • 测试 compute() 返回类型为 pl.Series

2.3 TimeSeriesFactor - 时间序列因子

实现要求:

class TimeSeriesFactor(BaseFactor):
    """
    时间序列因子基类(股票截面)
    
    计算逻辑:对每只股票,在其时间序列上进行纵向计算
    
    防泄露边界:
    - ❌ 禁止访问其他股票的数据(股票泄露)
    - ✅ 允许访问该股票的完整历史数据
    
    数据传入:
    - compute() 接收的是单只股票的完整时间序列
    - 包含该股票在 [start_date, end_date] 范围内的所有数据
    """
    
    factor_type: str = "time_series"
    
    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """
        计算时间序列因子值
        
        Args:
            data: FactorData包含单只股票的完整时间序列
                  格式DataFrame[ts_code, trade_date, col1, col2, ...]
                  
        Returns:
            pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
            
        示例:
            def compute(self, data):
                series = data.get_column("close")
                return series.rolling_mean(window_size=self.params['period'])
        """
        pass

测试需求:

  • 测试 factor_type 自动设置为 "time_series"
  • 测试子类必须实现 compute()
  • 测试 compute() 返回类型为 pl.Series

2.4 CompositeFactor - 组合因子 (composite.py)

实现要求:

class CompositeFactor(BaseFactor):
    """
    组合因子 - 用于实现因子间的数学运算
    
    约束:左右因子必须是同类型(同为截面或同为时序)
    """
    
    def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
        """
        创建组合因子
        
        Args:
            left: 左操作数因子
            right: 右操作数因子
            op: 运算符,支持 '+', '-', '*', '/'
            
        Raises:
            ValueError: 左右因子类型不一致
            ValueError: 不支持的运算符
        """
        pass
    
    def _merge_data_specs(self) -> List[DataSpec]:
        """
        合并左右因子的数据需求
        
        策略:
        1. 相同 source 和 columns 的 DataSpec 合并
        2. lookback_days 取最大值
        """
        pass
    
    def compute(self, data: FactorData) -> pl.Series:
        """
        执行组合运算
        
        流程:
        1. 分别计算 left 和 right 的值
        2. 根据 op 执行运算
        3. 返回结果
        """
        pass

测试需求:

  • 测试同类型因子组合成功cs + cs
  • 测试同类型因子组合成功ts + ts
  • 测试不同类型因子组合抛出 ValueErrorcs + ts
  • 测试无效运算符抛出 ValueError
  • 测试 _merge_data_specs() 正确合并(相同 source
  • 测试 _merge_data_specs() 正确合并(不同 source
  • 测试 _merge_data_specs() lookback 取最大值
  • 测试 compute() 执行正确的数学运算

2.5 ScalarFactor - 标量运算因子 (composite.py)

实现要求:

class ScalarFactor(BaseFactor):
    """
    标量运算因子
    
    支持scalar * factor, factor * scalar通过 __rmul__
    """
    
    def __init__(self, factor: BaseFactor, scalar: float, op: str):
        """
        创建标量运算因子
        
        Args:
            factor: 基础因子
            scalar: 标量值
            op: 运算符,支持 '*', '+'
        """
        pass
    
    def compute(self, data: FactorData) -> pl.Series:
        """执行标量运算"""
        pass

测试需求:

  • 测试标量乘法 0.5 * factor
  • 测试标量乘法 factor * 0.5
  • 测试标量加法(如支持)
  • 测试继承基础因子的 data_specs
  • 测试 compute() 返回正确缩放后的值

Phase 3: 数据加载 (data_loader.py)

3.1 DataLoader - 数据加载器

实现要求:

class DataLoader:
    """
    数据加载器 - 负责从 HDF5 安全加载数据
    
    功能:
    1. 多文件聚合:合并多个 H5 文件的数据
    2. 列选择:只加载需要的列
    3. 原始数据缓存:避免重复读取
    """
    
    def __init__(self, data_dir: str):
        """
        初始化 DataLoader
        
        Args:
            data_dir: HDF5 文件所在目录
        """
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}
    
    def load(
        self, 
        specs: List[DataSpec],
        date_range: Optional[Tuple[str, str]] = None
    ) -> pl.DataFrame:
        """
        加载并聚合多个 H5 文件的数据
        
        流程:
        1. 对每个 DataSpec
           a. 检查缓存,命中则直接使用
           b. 未命中则读取 HDF5通过 pandas
           c. 转换为 Polars DataFrame
           d. 按 date_range 过滤
           e. 存入缓存
        2. 合并多个 DataFrame按 trade_date 和 ts_code join
        
        Args:
            specs: 数据需求规格列表
            date_range: 日期范围限制 (start_date, end_date),可选
            
        Returns:
            合并后的 Polars DataFrame
            
        Raises:
            FileNotFoundError: H5 文件不存在
            KeyError: 列不存在于文件中
        """
        pass
    
    def clear_cache(self):
        """清空缓存"""
        pass
    
    def _read_h5(self, source: str) -> pl.DataFrame:
        """
        读取单个 H5 文件
        
        实现:使用 pandas.read_hdf(),然后 pl.from_pandas()
        """
        pass

测试需求:

  • 测试从单个 H5 文件加载数据
  • 测试从多个 H5 文件加载并合并
  • 测试列选择(只加载需要的列)
  • 测试缓存机制(第二次加载更快)
  • 测试 clear_cache() 清空缓存
  • 测试按 date_range 过滤
  • 测试文件不存在时抛出 FileNotFoundError
  • 测试列不存在时抛出 KeyError

Phase 4: 执行引擎 (engine.py)

4.1 FactorEngine - 因子执行引擎

实现要求:

class FactorEngine:
    """
    因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
    
    核心职责:
    1. CrossSectionalFactor防止日期泄露每天传入 [T-lookback+1, T] 数据
    2. TimeSeriesFactor防止股票泄露每只股票传入完整序列
    """
    
    def __init__(self, data_loader: DataLoader):
        """
        初始化引擎
        
        Args:
            data_loader: 数据加载器实例
        """
        self.data_loader = data_loader
    
    def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
        """
        统一的计算入口
        
        根据 factor_type 分发到具体方法:
        - "cross_sectional" -> _compute_cross_sectional()
        - "time_series" -> _compute_time_series()
        
        Args:
            factor: 要计算的因子
            **kwargs: 额外参数,根据因子类型不同:
                - 截面因子: start_date, end_date
                - 时序因子: stock_codes, start_date, end_date
                
        Returns:
            DataFrame[trade_date, ts_code, factor_name]
        """
        pass

测试需求:

  • 测试 compute() 正确分发给截面计算
  • 测试 compute() 正确分发给时序计算
  • 测试无效 factor_type 时抛出 ValueError

4.2 截面计算(防止日期泄露)

实现要求:

def _compute_cross_sectional(
    self,
    factor: CrossSectionalFactor,
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """
    执行日期截面计算
    
    防泄露策略:
    - 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来)
    - 允许股票间比较:传入当天所有股票的数据
    
    计算流程:
    1. 计算 max_lookback确定数据起始日期
    2. 一次性加载 [start-max_lookback+1, end] 的所有数据
    3. 对每个日期 T in [start_date, end_date]
       a. 裁剪数据到 [T-lookback+1, T]
       b. 创建 FactorDatacurrent_date=T
       c. 调用 factor.compute()
       d. 收集结果
    4. 合并所有日期的结果
    
    返回 DataFrame 格式:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 0.5          │
    │ 20240101   │ 000002.SZ│ 0.3          │
    └────────────┴──────────┴──────────────┘
    """
    pass

测试需求(防泄露验证):

  • 测试数据裁剪正确(传入 [T-lookback+1, T]
  • 测试不包含未来日期 T+1 的数据
  • 测试每个日期独立计算
  • 测试结果包含所有日期和所有股票
  • 测试结果 DataFrame 格式正确
  • 测试多个 DataSpec 时 lookback 取最大值

4.3 时序计算(防止股票泄露)

实现要求:

def _compute_time_series(
    self,
    factor: TimeSeriesFactor,
    stock_codes: List[str],
    start_date: str,
    end_date: str
) -> pl.DataFrame:
    """
    执行时间序列计算
    
    防泄露策略:
    - 防止股票泄露:每只股票单独计算,传入该股票的完整序列
    - 允许访问历史数据:时序计算需要历史数据
    
    计算流程:
    1. 计算 max_lookback确定数据起始日期
    2. 一次性加载 [start-max_lookback+1, end] 的所有数据
    3. 对每只股票 S in stock_codes
       a. 过滤出 S 的数据(防止股票泄露)
       b. 创建 FactorDatacurrent_stock=S
       c. 调用 factor.compute()(向量化计算整个序列)
       d. 收集结果
    4. 合并所有股票的结果
    
    性能优势:
    - 使用 Polars 的 rolling_mean 等向量化操作
    - 每只股票只计算一次,无重复计算
    
    返回 DataFrame 格式:
    ┌────────────┬──────────┬──────────────┐
    │ trade_date │ ts_code  │ factor_name  │
    ├────────────┼──────────┼──────────────┤
    │ 20240101   │ 000001.SZ│ 10.5         │
    │ 20240102   │ 000001.SZ│ 10.6         │
    └────────────┴──────────┴──────────────┘
    """
    pass

测试需求(防泄露验证):

  • 测试每只股票只看到自己的数据
  • 测试不包含其他股票的数据
  • 测试传入的是完整时间序列(向量化计算)
  • 测试结果包含所有股票和所有日期
  • 测试结果 DataFrame 格式正确
  • 测试股票不在数据中时跳过(或填充 null

Phase 5: 内置因子库 (builtin/)

5.1 momentum.py - 截面动量因子

实现因子:

  1. ReturnRankFactor - 当日收益率排名
class ReturnRankFactor(CrossSectionalFactor):
    """当日收益率排名因子"""
    name = "return_rank"
    data_specs = [DataSpec("daily", ["close"], lookback_days=2)]  # 需要2天计算收益率
    
    def compute(self, data):
        # 获取当前日期截面
        cs = data.get_cross_section()
        # 需要前1天和当天的收盘价lookback=2 保证数据包含 [T-1, T]
        # 这里假设 data 已经包含历史,实际计算需要 groupby 处理
        pass

测试需求:

  • 测试收益率计算正确
  • 测试排名计算正确
  • 测试无数据时返回 null
  1. MomentumFactor - 过去 N 日涨幅排名

5.2 technical.py - 时序技术指标

实现因子:

  1. MovingAverageFactor - 移动平均线
class MovingAverageFactor(TimeSeriesFactor):
    """移动平均线因子"""
    name = "ma"
    
    def __init__(self, period: int = 20):
        super().__init__(period=period)
        self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
    
    def compute(self, data):
        return data.get_column("close").rolling_mean(self.params["period"])

测试需求:

  • 测试 MA20 计算正确
  • 测试前19天返回 nullPolars 默认行为)
  • 测试参数 period 生效
  1. RSIFactor - RSI 指标
  2. MACDFactor - MACD 指标

5.3 value.py - 截面估值因子

实现因子:

  1. PERankFactor - PE 行业分位数
  2. PBFactor - PB 排名

Phase 6-7: 测试策略

测试金字塔

         /\
        /  \
       / 集成\     tests/factors/test_integration.py
      /────────\
     /   引擎    \   tests/factors/test_engine.py
    /────────────\
   /  基类/组合因子 \  tests/factors/test_base.py, test_composite.py
  /────────────────\
 /    数据加载/类型    \ tests/factors/test_data_loader.py, test_data_spec.py
/──────────────────────\

测试数据准备

创建 tests/fixtures/ 目录,包含:

  • sample_daily.h5: 少量股票的日线数据(用于测试)
  • sample_fundamental.h5: 基本面数据

关键测试场景

  1. 防泄露测试(核心)

    • 截面因子:验证 compute() 中无法访问未来日期
    • 时序因子:验证 compute() 中无法访问其他股票
  2. 边界测试

    • lookback_days = 1最小值
    • 数据起始点(前 N 天为 null
    • 空数据/停牌处理
  3. 性能测试(可选)

    • 大数据量下的内存占用
    • 缓存命中率

实现状态

Phase 状态 完成日期 测试覆盖
Phase 1: 数据类型定义 已完成 2026-02-21 27 tests passed
Phase 2: 因子基类 已完成 2026-02-21 49 tests passed
Phase 3: 数据加载 已完成 2026-02-21 11 tests passed
Phase 4: 执行引擎 已完成 2026-02-22 10 tests passed
Phase 5: 内置因子库 📝 待开发 - -
Phase 6-7: 测试文档 已完成 2026-02-22 76 tests total

实现顺序建议

  1. Week 1: Phase 1-2数据类型 + 基类)
  2. Week 2: Phase 3-4DataLoader + Engine 已完成
  3. Week 3: Phase 5内置因子
  4. Week 4: Phase 6-7测试 + 文档)

每个 Phase 完成后运行对应测试,确保质量。