feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
This commit is contained in:
65
src/factors/__init__.py
Normal file
65
src/factors/__init__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""ProStock factor computation framework.

The factor framework provides these core capabilities:
1. Type-safe factor definitions (cross-sectional and time-series factors)
2. Data-leakage protection
3. Factor composition and arithmetic
4. Efficient data loading and a computation engine

Basic data types (Phase 1):
- DataSpec: data requirement specification
- FactorContext: computation context
- FactorData: data container

Factor base classes (Phase 2):
- BaseFactor: abstract base class
- CrossSectionalFactor: per-date cross-sectional factor base class
- TimeSeriesFactor: per-stock time-series factor base class
- CompositeFactor: composed factor
- ScalarFactor: scalar-arithmetic factor

Data loading and execution (Phase 3-4):
- DataLoader: data loader
- FactorEngine: factor execution engine

Usage example:
    from src.factors import DataSpec, FactorContext, FactorData
    from src.factors import CrossSectionalFactor, TimeSeriesFactor
    from src.factors import DataLoader, FactorEngine

    # Declare data requirements
    spec = DataSpec(
        source="daily",
        columns=["ts_code", "trade_date", "close"],
        lookback_days=20
    )

    # Initialise the engine
    loader = DataLoader(data_dir="data")
    engine = FactorEngine(loader)

    # Compute a factor
    result = engine.compute(factor, start_date="20240101", end_date="20240131")
"""

from src.factors.data_spec import DataSpec, FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
from src.factors.composite import CompositeFactor, ScalarFactor
from src.factors.data_loader import DataLoader
from src.factors.engine import FactorEngine

__all__ = [
    # Phase 1: core data types
    "DataSpec",
    "FactorContext",
    "FactorData",
    # Phase 2: factor base classes
    "BaseFactor",
    "CrossSectionalFactor",
    "TimeSeriesFactor",
    "CompositeFactor",
    "ScalarFactor",
    # Phase 3-4: data loading and execution engine
    "DataLoader",
    "FactorEngine",
]
|
||||
274
src/factors/base.py
Normal file
274
src/factors/base.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""因子基类 - Phase 2 核心抽象类
|
||||
|
||||
本模块定义了因子框架的基类:
|
||||
- BaseFactor: 抽象基类,定义通用接口和验证逻辑
|
||||
- CrossSectionalFactor: 日期截面因子基类(防止日期泄露)
|
||||
- TimeSeriesFactor: 时间序列因子基类(防止股票泄露)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
|
||||
|
||||
class BaseFactor(ABC):
    """Abstract factor base class defining the common interface.

    Every concrete factor must subclass this (usually via
    CrossSectionalFactor or TimeSeriesFactor) and declare:
    - name: unique factor identifier (snake_case)
    - factor_type: "cross_sectional" or "time_series"
    - data_specs: List[DataSpec] describing the data the factor needs

    Optional declarations:
    - category: factor category (defaults to "default")
    - description: human-readable description

    Example:
        >>> class MyFactor(CrossSectionalFactor):
        ...     name = "my_factor"
        ...     data_specs = [
        ...         DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
        ...     ]
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         return data.get_column("close").rank()
    """

    # Mandatory class attributes (validated in __init_subclass__).
    name: str = ""
    factor_type: str = ""  # "cross_sectional" | "time_series"
    # Plain empty list, NOT dataclasses.field(): BaseFactor is not a
    # dataclass, so field(default_factory=list) would bind a Field object
    # (always truthy and not a list) to the class attribute instead of a
    # list. Subclasses must shadow this with their own non-empty list.
    data_specs: List[DataSpec] = []

    # Optional class attributes.
    category: str = "default"
    description: str = ""

    def __init_subclass__(cls, **kwargs):
        """Validate mandatory attributes when a subclass is created.

        Checks:
        1. name is a non-empty string defined directly on the subclass
        2. factor_type is "cross_sectional" or "time_series"
        3. data_specs is a non-empty list defined directly on the subclass

        Raises:
            ValueError: if any declaration is missing or invalid.
        """
        super().__init_subclass__(**kwargs)

        # Skip validation for the abstract intermediates and the special
        # composition classes, which configure themselves in __init__.
        if cls.__name__ in (
            "CrossSectionalFactor",
            "TimeSeriesFactor",
            "CompositeFactor",
            "ScalarFactor",
        ):
            return

        # name must be defined directly on the class (not inherited).
        if "name" not in cls.__dict__ or not cls.name:
            raise ValueError(f"Factor {cls.__name__} must define 'name'")
        if not isinstance(cls.name, str):
            raise ValueError(f"Factor {cls.__name__}.name must be a string")

        # factor_type must have a value (inheriting it is fine).
        if not cls.factor_type:
            raise ValueError(f"Factor {cls.__name__} must define 'factor_type'")
        if cls.factor_type not in ("cross_sectional", "time_series"):
            raise ValueError(
                f"Factor {cls.__name__}.factor_type must be 'cross_sectional' "
                f"or 'time_series', got '{cls.factor_type}'"
            )

        # data_specs must be defined directly on the class (not inherited) ...
        if "data_specs" not in cls.__dict__:
            raise ValueError(f"Factor {cls.__name__} must define 'data_specs'")
        # ... be an actual list (checked before emptiness so a non-list
        # value gets the clearer error message) ...
        if not isinstance(cls.data_specs, list):
            raise ValueError(f"Factor {cls.__name__}.data_specs must be a list")
        # ... and be non-empty.
        if not cls.data_specs:
            raise ValueError(f"Factor {cls.__name__}.data_specs cannot be empty")

    def __init__(self, **params):
        """Store factor parameters.

        Subclasses can accept parameterised configuration through
        __init__, e.g. MA(period=20).

        Note: data_specs must be declared at class level (a class
        attribute), never assigned in __init__; it is validated in
        __init_subclass__ when the class object is created.

        Args:
            **params: factor parameters, kept in self.params
        """
        self.params = params

    def _validate_params(self):
        """Hook for subclass parameter validation.

        Subclasses may override this for custom checks (and must call it
        themselves from their own __init__). The base implementation does
        nothing.

        Note: data_specs is validated at class-creation time by
        __init_subclass__ and must not be changed per instance. For
        behaviour that varies with parameters, parameterise the
        computation itself, not data_specs:

        >>> class ParamFactor(TimeSeriesFactor):
        ...     name = "param_factor"
        ...     data_specs = [
        ...         DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)
        ...     ]
        ...
        ...     def __init__(self, period: int = 20):
        ...         super().__init__(period=period)
        ...         # vary the computation through params, not data_specs
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         return data.get_column("close").rolling_mean(self.params["period"])
        """
        pass

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Core computation — must be implemented by subclasses.

        Args:
            data: safe data container, already trimmed for the factor type

        Returns:
            The computed factor values as a Series.
        """
        pass

    # ========== factor composition operators ==========

    def __add__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 + f2 (both operands must share the same factor type)."""
        # Imported lazily to avoid a circular import with composite.py.
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "+")

    def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 - f2 (both operands must share the same factor type)."""
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "-")

    def __mul__(self, other):
        """f1 * f2 or f1 * scalar."""
        if isinstance(other, (int, float)):
            from src.factors.composite import ScalarFactor

            return ScalarFactor(self, float(other), "*")
        elif isinstance(other, BaseFactor):
            from src.factors.composite import CompositeFactor

            return CompositeFactor(self, other, "*")
        return NotImplemented

    def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 / f2 (both operands must share the same factor type)."""
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "/")

    def __rmul__(self, scalar: float) -> "ScalarFactor":
        """scalar * f1, e.g. 0.5 * f1."""
        from src.factors.composite import ScalarFactor

        return ScalarFactor(self, scalar, "*")

    def __repr__(self) -> str:
        """Concise representation with the factor's name and type."""
        return (
            f"{self.__class__.__name__}(name='{self.name}', type='{self.factor_type}')"
        )
|
||||
|
||||
|
||||
class CrossSectionalFactor(BaseFactor):
    """Base class for per-date cross-sectional factors.

    Computation model: on each trading day, compute across all stocks.

    Leakage boundary:
    - Forbidden: data from future dates (date leakage)
    - Allowed: all stocks' data on the current date

    Data passed in:
    - compute() receives data covering [T-lookback+1, T]
    - includes lookback_days of history (for time-series work before the
      cross-sectional step)

    Example:
        >>> class PERankFactor(CrossSectionalFactor):
        ...     name = "pe_rank"
        ...     data_specs = [
        ...         DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)
        ...     ]
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         cs = data.get_cross_section()
        ...         return cs["pe"].rank()
    """

    # Fixed for the whole hierarchy; subclasses inherit it, so
    # __init_subclass__ only requires it to be non-empty, not redefined.
    factor_type: str = "cross_sectional"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute the cross-sectional factor values.

        Args:
            data: FactorData holding the [T-lookback+1, T] slice,
                shaped DataFrame[ts_code, trade_date, col1, col2, ...]

        Returns:
            pl.Series: factor values for every stock on the current date
                (length = number of stocks that day)

        Example:
            >>> def compute(self, data):
            ...     # current-date cross-section
            ...     cs = data.get_cross_section()
            ...     # rank by market cap
            ...     return cs['market_cap'].rank()
        """
        pass
|
||||
|
||||
|
||||
class TimeSeriesFactor(BaseFactor):
    """Base class for per-stock time-series factors.

    Computation model: for each stock, compute along its own time series.

    Leakage boundary:
    - Forbidden: data of other stocks (stock leakage)
    - Allowed: the stock's complete history

    Data passed in:
    - compute() receives a single stock's complete time series
    - covering everything for that stock in [start_date, end_date]

    Example (note: data_specs must be declared at class level —
    BaseFactor.__init_subclass__ rejects subclasses that do not define it
    directly, so it cannot be assigned inside __init__):
        >>> class MovingAverageFactor(TimeSeriesFactor):
        ...     name = "ma"
        ...     data_specs = [
        ...         DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)
        ...     ]
        ...
        ...     def __init__(self, period: int = 20):
        ...         super().__init__(period=period)
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         return data.get_column("close").rolling_mean(self.params["period"])
    """

    # Fixed for the whole hierarchy; subclasses inherit it, so
    # __init_subclass__ only requires it to be non-empty, not redefined.
    factor_type: str = "time_series"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute the time-series factor values for one stock.

        Args:
            data: FactorData holding one stock's full time series,
                shaped DataFrame[ts_code, trade_date, col1, col2, ...]

        Returns:
            pl.Series: the stock's factor value for each date
                (length = number of dates)

        Example:
            >>> def compute(self, data):
            ...     series = data.get_column("close")
            ...     return series.rolling_mean(window_size=self.params['period'])
        """
        pass
|
||||
201
src/factors/composite.py
Normal file
201
src/factors/composite.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""组合因子 - Phase 2 因子组合和标量运算
|
||||
|
||||
本模块定义了因子组合相关的类:
|
||||
- CompositeFactor: 组合因子,用于实现因子间的数学运算
|
||||
- ScalarFactor: 标量运算因子,支持因子与标量的运算
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
from src.factors.base import BaseFactor
|
||||
|
||||
|
||||
class CompositeFactor(BaseFactor):
    """A factor formed by combining two factors with an arithmetic operator.

    Constraint: both operands must share the same factor_type
    (both cross-sectional, or both time-series).

    Supported operators: '+', '-', '*', '/'

    Example:
        >>> f1 = SomeCrossSectionalFactor()
        >>> f2 = AnotherCrossSectionalFactor()
        >>> combined = f1 + f2  # yields a CompositeFactor
    """

    # Operator symbol -> binary operation. Doubles as the set of symbols
    # accepted by __init__.
    _OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
    }

    def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
        """Build a composite factor from two operands.

        Args:
            left: left-hand operand factor
            right: right-hand operand factor
            op: operator symbol, one of '+', '-', '*', '/'

        Raises:
            ValueError: the operand factor types differ
            ValueError: the operator is unsupported
        """
        # Operands must agree on factor type.
        if left.factor_type != right.factor_type:
            raise ValueError(
                f"Cannot combine factors of different types: "
                f"'{left.factor_type}' vs '{right.factor_type}'"
            )

        # Operator must be one we know how to apply.
        if op not in self._OPS:
            raise ValueError(f"Unsupported operator: '{op}'")

        self.left = left
        self.right = right
        self.op = op

        # Derive the factor metadata from the operands.
        self.factor_type = left.factor_type
        self.name = f"({left.name}_{op}_{right.name})"
        self.data_specs = self._merge_data_specs()
        self.category = "composite"
        self.description = f"Composite factor: {left.name} {op} {right.name}"

        # super().__init__() is deliberately skipped: composite factors
        # record their operands in params directly.
        self.params = {"left": left, "right": right, "op": op}

    def _merge_data_specs(self) -> List[DataSpec]:
        """Merge the data requirements of both operands.

        Strategy:
        1. Specs sharing (source, column set) collapse into one.
        2. The merged spec keeps the largest lookback_days.

        Returns:
            The merged DataSpec list.
        """
        # (source, sorted columns) -> largest lookback seen so far.
        # Insertion order of the dict preserves first-seen spec order.
        widest: dict = {}
        for spec in list(self.left.data_specs) + list(self.right.data_specs):
            key = (spec.source, tuple(sorted(spec.columns)))
            current = widest.get(key)
            if current is None or spec.lookback_days > current:
                widest[key] = spec.lookback_days

        return [
            DataSpec(source=src, columns=list(cols), lookback_days=days)
            for (src, cols), days in widest.items()
        ]

    def compute(self, data: FactorData) -> pl.Series:
        """Evaluate both operands on the same data and apply the operator.

        Flow:
        1. Compute the left and right operand values.
        2. Apply the operator.
        3. Return the result.

        Args:
            data: FactorData holding what both operands need

        Returns:
            The combined factor-value Series.
        """
        lhs = self.left.compute(data)
        rhs = self.right.compute(data)
        return self._OPS[self.op](lhs, rhs)

    def _validate_params(self):
        """No extra parameter validation is needed for CompositeFactor."""
        pass
|
||||
|
||||
|
||||
class ScalarFactor(BaseFactor):
    """Applies a scalar arithmetic operation to another factor.

    Supports: scalar * factor and factor * scalar (via BaseFactor.__rmul__).

    Example:
        >>> factor = SomeFactor()
        >>> scaled = 0.5 * factor  # yields a ScalarFactor
    """

    def __init__(self, factor: BaseFactor, scalar: float, op: str):
        """Wrap *factor* with a scalar operation.

        Args:
            factor: the underlying factor
            scalar: the scalar operand
            op: operator symbol, '*' or '+'

        Raises:
            ValueError: op is neither '*' nor '+'
        """
        # Only multiplication and addition are defined for scalars.
        if op not in {"*", "+"}:
            raise ValueError(f"ScalarFactor only supports '*' and '+', got '{op}'")

        self.factor = factor
        self.scalar = scalar
        self.op = op

        # Metadata mirrors the wrapped factor.
        self.factor_type = factor.factor_type
        self.name = f"({scalar}_{op}_{factor.name})"
        self.data_specs = factor.data_specs
        self.category = "scalar"
        self.description = f"Scalar factor: {scalar} {op} {factor.name}"

        # super().__init__() is deliberately skipped: scalar factors record
        # their configuration in params directly.
        self.params = {"factor": factor, "scalar": scalar, "op": op}

    def compute(self, data: FactorData) -> pl.Series:
        """Evaluate the wrapped factor, then apply the scalar operation.

        Args:
            data: FactorData holding what the wrapped factor needs

        Returns:
            The transformed factor-value Series.
        """
        base_values = self.factor.compute(data)

        if self.op == "*":
            return base_values * self.scalar
        if self.op == "+":
            return base_values + self.scalar
        # Unreachable: __init__ already rejected every other operator.
        raise ValueError(f"Unsupported operation: '{self.op}'")

    def _validate_params(self):
        """No extra parameter validation is needed for ScalarFactor."""
        pass
|
||||
183
src/factors/data_loader.py
Normal file
183
src/factors/data_loader.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""数据加载器 - Phase 3 数据加载模块
|
||||
|
||||
本模块负责从 HDF5 文件安全加载数据:
|
||||
- DataLoader: 数据加载器,支持多文件聚合、列选择、缓存
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec
|
||||
|
||||
|
||||
class DataLoader:
    """Loads factor input data safely from HDF5 files.

    Responsibilities:
    1. Multi-file aggregation: join data from several H5 files
    2. Column projection: load only the requested columns
    3. Raw-data caching: avoid re-reading the same file

    Example:
        >>> loader = DataLoader(data_dir="data")
        >>> specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
        >>> df = loader.load(specs, date_range=("20240101", "20240131"))
    """

    def __init__(self, data_dir: str):
        """Initialise the loader.

        Args:
            data_dir: directory that contains the HDF5 files
        """
        self.data_dir = Path(data_dir)
        # Keyed by "<source>_<sorted columns>". Frames are cached
        # UNFILTERED so one entry can serve any later date range.
        self._cache: Dict[str, pl.DataFrame] = {}

    def load(
        self,
        specs: List[DataSpec],
        date_range: Optional[Tuple[str, str]] = None,
    ) -> pl.DataFrame:
        """Load and aggregate data for a list of DataSpecs.

        Flow:
        1. For each DataSpec:
           a. serve from cache when possible
           b. otherwise read the HDF5 file (via pandas)
           c. convert to Polars and project to the requested columns
           d. store the unfiltered frame in the cache
           e. apply the date_range filter
        2. Join the per-spec frames on trade_date and ts_code.

        Args:
            specs: data requirement specifications
            date_range: optional (start_date, end_date) bounds, inclusive

        Returns:
            The merged Polars DataFrame.

        Raises:
            FileNotFoundError: an H5 file does not exist
            KeyError: a requested column is missing from the file
        """
        dataframes = []

        for spec in specs:
            cache_key = f"{spec.source}_{','.join(sorted(spec.columns))}"
            if cache_key in self._cache:
                df = self._cache[cache_key]
            else:
                # Read the H5 file from disk.
                df = self._read_h5(spec.source)

                # Project to the requested columns only.
                missing_cols = set(spec.columns) - set(df.columns)
                if missing_cols:
                    raise KeyError(
                        f"Columns {missing_cols} not found in {spec.source}.h5. "
                        f"Available columns: {df.columns}"
                    )
                df = df.select(spec.columns)

                # Cache BEFORE date filtering so the entry stays
                # range-agnostic.
                self._cache[cache_key] = df

            # Restrict to the requested date window.
            if date_range:
                start_date, end_date = date_range
                df = df.filter(
                    (pl.col("trade_date") >= start_date)
                    & (pl.col("trade_date") <= end_date)
                )

            dataframes.append(df)

        # A single spec needs no join.
        if len(dataframes) == 1:
            return dataframes[0]
        return self._merge_dataframes(dataframes)

    def clear_cache(self):
        """Drop all cached frames."""
        self._cache.clear()

    def _read_h5(self, source: str) -> pl.DataFrame:
        """Read a single H5 file.

        Implementation: pandas.read_hdf() followed by pl.from_pandas().

        Args:
            source: H5 file name without extension; also assumed to be
                the HDF5 store key ("/<source>") — confirm against the
                writer of these files.

        Returns:
            The file contents as a Polars DataFrame.

        Raises:
            FileNotFoundError: the H5 file does not exist
        """
        file_path = self.data_dir / f"{source}.h5"

        if not file_path.exists():
            raise FileNotFoundError(f"HDF5 file not found: {file_path}")

        # read_hdf returns a DataFrame here; silence the LSP type warning.
        pdf = pd.read_hdf(file_path, key=f"/{source}", mode="r")  # type: ignore

        return pl.from_pandas(pdf)  # type: ignore

    def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame:
        """Join multiple DataFrames on (trade_date, ts_code).

        Strategy:
        1. Full (outer) join so rows present in only one frame survive.
        2. Frames contributing no new columns are skipped entirely, so
           rows unique to such a frame are NOT added.

        Args:
            dataframes: frames to merge (at least one)

        Returns:
            The merged DataFrame.

        Raises:
            KeyError: a join key is missing from one of the frames
        """
        result = dataframes[0]

        for df in dataframes[1:]:
            join_keys = ["trade_date", "ts_code"]

            # Both sides must carry the join keys.
            for key in join_keys:
                if key not in result.columns or key not in df.columns:
                    raise KeyError(f"Join key '{key}' not found in DataFrames")

            # Columns this frame adds beyond what we already carry.
            new_cols = [c for c in df.columns if c not in result.columns]

            if new_cols:
                df_to_join = df.select(join_keys + new_cols)

                # coalesce=True folds the key columns of the full join into
                # a single trade_date/ts_code pair. Without it Polars keeps
                # separate *_right key columns and leaves nulls in the left
                # keys for right-only rows, which would break subsequent
                # join-key checks and any later trade_date filtering.
                result = result.join(
                    df_to_join, on=join_keys, how="full", coalesce=True
                )

        return result

    def get_cache_info(self) -> Dict[str, int]:
        """Return cache statistics.

        Returns:
            Dict with "entries" (number of cached frames) and
            "total_rows" (sum of row counts over all cached frames).
        """
        total_rows = sum(len(df) for df in self._cache.values())
        return {
            "entries": len(self._cache),
            "total_rows": total_rows,
        }
|
||||
242
src/factors/data_spec.py
Normal file
242
src/factors/data_spec.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""数据类型定义 - Phase 1 核心数据模型
|
||||
|
||||
本模块定义了因子框架的基础数据类型:
|
||||
- DataSpec: 数据需求规格,声明因子所需的数据源、列和回看窗口
|
||||
- FactorContext: 计算上下文,由引擎自动注入,提供计算点信息
|
||||
- FactorData: 数据容器,封装底层 Polars DataFrame,提供安全的数据访问
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DataSpec:
    """Specification of a factor's data requirements.

    Declares the data source, the columns, and the lookback window a
    factor needs. Instances are immutable (frozen dataclass).

    Args:
        source: H5 file name (e.g. "daily", "fundamental")
        columns: column names; must include "ts_code" and "trade_date"
        lookback_days: days to look back, current day included
            - 1 means only the current day [T]
            - 5 means [T-4, T], five days in total
            - 20 means [T-19, T], twenty days in total

    Raises:
        ValueError: when a constraint is violated

    Examples:
        >>> spec = DataSpec(
        ...     source="daily",
        ...     columns=["ts_code", "trade_date", "close"],
        ...     lookback_days=20
        ... )
    """

    source: str
    columns: List[str]
    lookback_days: int = 1

    def __post_init__(self):
        """Check the declared constraints.

        Checks:
        1. lookback_days >= 1 (the window always includes the current day)
        2. source is a non-empty string
        3. columns contains both ts_code and trade_date

        Because the dataclass is frozen this method only validates; it
        never assigns to fields (which would require object.__setattr__).
        """
        if self.lookback_days < 1:
            raise ValueError(f"lookback_days must be >= 1, got {self.lookback_days}")

        if not self.source:
            raise ValueError("source cannot be empty string")

        required = {"ts_code", "trade_date"}
        missing = required - set(self.columns)
        if missing:
            raise ValueError(
                f"columns must contain {required}, missing: {missing}"
            )
|
||||
|
||||
|
||||
@dataclass
class FactorContext:
    """Context for a single factor computation.

    Injected automatically by FactorEngine; factor authors reach it via
    data.context. Which fields are populated depends on the factor type:
    - CrossSectionalFactor: current_date is the date being computed
    - TimeSeriesFactor: current_stock is the stock being computed

    Attributes:
        current_date: date being computed, YYYYMMDD (cross-sectional)
        current_stock: stock code being computed (time-series)
        trade_dates: optional trading calendar, used for alignment

    Examples:
        >>> context = FactorContext(current_date="20240101")
        >>> context.current_date
        '20240101'
    """

    # Set for cross-sectional computations.
    current_date: Optional[str] = None
    # Set for time-series computations.
    current_stock: Optional[str] = None
    # Optional trading calendar.
    trade_dates: Optional[List[str]] = None
|
||||
|
||||
|
||||
class FactorData:
    """Data container handed to factors.

    Wraps the underlying Polars DataFrame and exposes safe accessors.
    Its contents depend on the factor type:
    - CrossSectionalFactor: current date plus lookback history, all stocks
    - TimeSeriesFactor: the full time series of a single stock

    Args:
        df: the underlying Polars DataFrame
        context: the computation context

    Examples:
        >>> df = pl.DataFrame({
        ...     "ts_code": ["000001.SZ"],
        ...     "trade_date": ["20240101"],
        ...     "close": [10.0]
        ... })
        >>> context = FactorContext(current_date="20240101")
        >>> data = FactorData(df, context)
    """

    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context

    def get_column(self, col: str) -> pl.Series:
        """Return one column of the wrapped data.

        Works for both factor kinds:
        - cross-sectional: the column across all stocks of the day
        - time-series: the column along one stock's series

        Args:
            col: column name

        Returns:
            The column as a Polars Series.

        Raises:
            KeyError: the column is not present in the data.

        Examples:
            >>> prices = data.get_column("close")
            >>> print(prices)
        """
        if col in self._df.columns:
            return self._df[col]
        raise KeyError(
            f"Column '{col}' not found in data. Available columns: {self._df.columns}"
        )

    def filter_by_date(self, date: str) -> "FactorData":
        """Return a new FactorData restricted to one date.

        Mostly useful for cross-sectional factors. Future dates are never
        reachable here — the engine has already trimmed them away.

        Args:
            date: date in YYYYMMDD format

        Returns:
            A fresh FactorData; the original is left untouched.

        Examples:
            >>> today_data = data.filter_by_date("20240101")
            >>> print(len(today_data))
        """
        subset = self._df.filter(pl.col("trade_date") == date)
        return FactorData(subset, self._context)

    def get_cross_section(self) -> pl.DataFrame:
        """Return all stocks for the context's current date.

        Only meaningful for cross-sectional factors.

        Returns:
            DataFrame with every stock on current_date.

        Raises:
            ValueError: current_date is unset (non-cross-sectional usage).

        Examples:
            >>> cs = data.get_cross_section()
            >>> rankings = cs["pe"].rank()
        """
        current = self._context.current_date
        if current is None:
            raise ValueError(
                "current_date is not set in context. "
                "get_cross_section() is only applicable for cross-sectional factors."
            )
        return self._df.filter(pl.col("trade_date") == current)

    def to_polars(self) -> pl.DataFrame:
        """Return the raw underlying DataFrame (advanced use).

        Allows arbitrary Polars operations, but direct access can bypass
        the framework's leakage protection — use with care.

        Returns:
            The underlying Polars DataFrame.

        Examples:
            >>> df = data.to_polars()
            >>> result = df.group_by("industry").agg(pl.col("pe").mean())
        """
        return self._df

    @property
    def context(self) -> FactorContext:
        """The FactorContext attached to this data.

        Examples:
            >>> date = data.context.current_date
            >>> stock = data.context.current_stock
        """
        return self._context

    def __len__(self) -> int:
        """Number of rows in the wrapped DataFrame.

        Examples:
            >>> if len(data) > 0:
            ...     result = data.get_column("close").mean()
        """
        return len(self._df)

    def __repr__(self) -> str:
        """Compact summary: row/column counts plus context info."""
        tags = []
        if self._context.current_date:
            tags.append(f"date={self._context.current_date}")
        if self._context.current_stock:
            tags.append(f"stock={self._context.current_stock}")

        details = ", ".join(tags) or "no context"
        return f"FactorData(rows={len(self)}, cols={len(self._df.columns)}, {details})"
|
||||
367
src/factors/engine.py
Normal file
367
src/factors/engine.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""执行引擎 - Phase 4 因子执行引擎
|
||||
|
||||
本模块负责因子计算的核心逻辑:
|
||||
- FactorEngine: 因子执行引擎,根据因子类型采用不同的计算和防泄露策略
|
||||
|
||||
防泄露策略:
|
||||
1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据
|
||||
2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_loader import DataLoader
|
||||
from src.factors.data_spec import FactorContext, FactorData
|
||||
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
||||
|
||||
|
||||
class FactorEngine:
|
||||
"""因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略
|
||||
|
||||
核心职责:
|
||||
1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据
|
||||
2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列
|
||||
|
||||
示例:
|
||||
>>> loader = DataLoader(data_dir="data")
|
||||
>>> engine = FactorEngine(loader)
|
||||
>>> result = engine.compute(factor, start_date="20240101", end_date="20240131")
|
||||
"""
|
||||
|
||||
def __init__(self, data_loader: DataLoader):
    """Initialise the engine.

    Args:
        data_loader: DataLoader instance used for all raw-data access
    """
    self.data_loader = data_loader
|
||||
|
||||
def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
    """Unified computation entry point.

    Dispatches on factor_type:
    - "cross_sectional" -> _compute_cross_sectional()
    - "time_series" -> _compute_time_series()

    Args:
        factor: the factor to evaluate
        **kwargs: type-dependent parameters:
            - cross-sectional factors: start_date, end_date
            - time-series factors: stock_codes, start_date, end_date

    Returns:
        DataFrame[trade_date, ts_code, factor_name]

    Raises:
        ValueError: unknown factor_type, or required parameters missing
    """
    kind = factor.factor_type

    if kind == "cross_sectional":
        # Both date bounds are mandatory.
        if not {"start_date", "end_date"} <= kwargs.keys():
            raise ValueError(
                "cross_sectional factor requires 'start_date' and 'end_date' parameters"
            )
        return self._compute_cross_sectional(
            factor, kwargs["start_date"], kwargs["end_date"]
        )

    if kind == "time_series":
        # Report every missing parameter at once, in declaration order.
        required = ("stock_codes", "start_date", "end_date")
        missing = [param for param in required if param not in kwargs]
        if missing:
            raise ValueError(
                f"time_series factor requires parameters: {', '.join(missing)}"
            )
        return self._compute_time_series(
            factor,
            kwargs["stock_codes"],
            kwargs["start_date"],
            kwargs["end_date"],
        )

    raise ValueError(f"Unknown factor type: {factor.factor_type}")
|
||||
|
||||
def _compute_cross_sectional(
    self,
    factor: CrossSectionalFactor,
    start_date: str,
    end_date: str,
) -> pl.DataFrame:
    """Run a cross-sectional factor over a date range.

    Leakage protection:
    - No date leakage: each day's computation only sees data in
      [T-lookback+1, T] (never the future).
    - Cross-stock comparison is allowed: every stock of the day is passed.

    Flow:
    1. Derive max_lookback and extend the data start date backwards.
    2. Load all data for [start-max_lookback+1, end] in one shot.
    3. For each date T in [start_date, end_date]:
       a. trim the data to [T-lookback+1, T]
       b. build FactorData with current_date=T
       c. call factor.compute()
       d. collect the per-day result
    4. Concatenate the per-day results.

    Returned shape: one row per (trade_date, ts_code), with the factor
    value in a column named factor.name, e.g.:
        trade_date | ts_code   | <factor.name>
        20240101   | 000001.SZ | 0.5
        20240101   | 000002.SZ | 0.3

    Args:
        factor: the cross-sectional factor to evaluate
        start_date: start date, YYYYMMDD
        end_date: end date, YYYYMMDD

    Returns:
        DataFrame[trade_date, ts_code, <factor.name>] (empty if no data)
    """
    # Widest lookback across all of the factor's data requirements.
    max_lookback = max(spec.lookback_days for spec in factor.data_specs)

    # Extend the load window backwards by the lookback.
    # NOTE(review): _get_trading_date_offset is defined later in this
    # class (outside this view) — presumably offsets by trading days;
    # confirm its calendar semantics.
    data_start = self._get_trading_date_offset(start_date, -max_lookback + 1)

    # Load everything once; per-day views are sliced from this frame.
    raw_data = self.data_loader.load(
        factor.data_specs, date_range=(data_start, end_date)
    )

    results = []

    # Dates to iterate; helper defined later in this class (not visible
    # here).
    date_range = self._get_date_range(start_date, end_date, raw_data)

    # One computation per trading day.
    for current_date in date_range:
        # Keep only current_date and earlier (no future data = no date
        # leakage), but keep every stock (cross-stock comparison is fine).
        day_data = raw_data.filter(pl.col("trade_date") <= current_date)

        # Trim the history to the lookback window. DataSpec enforces
        # lookback_days >= 1, so this branch is effectively always taken.
        if max_lookback > 0:
            lookback_start = self._get_trading_date_offset(
                current_date, -max_lookback + 1
            )
            day_data = day_data.filter(pl.col("trade_date") >= lookback_start)

        # Nothing to compute for this day.
        if len(day_data) == 0:
            continue

        # Package the day's view with its context (history included,
        # future excluded).
        context = FactorContext(
            current_date=current_date,
            trade_dates=date_range,
        )
        factor_data = FactorData(day_data, context)

        # Evaluate the factor on the safe view.
        factor_values = factor.compute(factor_data)

        # Stocks that actually have a row on current_date.
        today_stocks = day_data.filter(pl.col("trade_date") == current_date)[
            "ts_code"
        ]

        # factor.compute() should return one value per stock of the day.
        # If the lengths disagree, re-derive the stock list from the
        # cross-section; if they still disagree, fall back to nulls so
        # the output frame stays well-formed.
        if len(factor_values) != len(today_stocks):
            cs_data = factor_data.get_cross_section()
            if len(cs_data) > 0:
                today_stocks = cs_data["ts_code"]
            if len(factor_values) != len(today_stocks):
                factor_values = pl.Series([None] * len(today_stocks))

        # Collect the day's rows.
        if len(today_stocks) > 0:
            results.append(
                pl.DataFrame(
                    {
                        "trade_date": [current_date] * len(today_stocks),
                        "ts_code": today_stocks,
                        factor.name: factor_values,
                    }
                )
            )

    # Stack all per-day frames; return an empty (but correctly shaped)
    # frame when nothing was computed.
    if results:
        return pl.concat(results)
    else:
        return pl.DataFrame(
            {
                "trade_date": [],
                "ts_code": [],
                factor.name: [],
            }
        )
|
||||
|
||||
def _compute_time_series(
|
||||
self,
|
||||
factor: TimeSeriesFactor,
|
||||
stock_codes: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pl.DataFrame:
|
||||
"""执行时间序列计算
|
||||
|
||||
防泄露策略:
|
||||
- 防止股票泄露:每只股票单独计算,传入该股票的完整序列
|
||||
- 允许访问历史数据:时序计算需要历史数据,这是正常的
|
||||
|
||||
计算流程:
|
||||
1. 计算 max_lookback,确定数据起始日期
|
||||
2. 一次性加载 [start-max_lookback+1, end] 的所有数据
|
||||
3. 对每只股票 S in stock_codes:
|
||||
a. 过滤出 S 的数据(防止股票泄露)
|
||||
b. 创建 FactorData(current_stock=S)
|
||||
c. 调用 factor.compute()(向量化计算整个序列)
|
||||
d. 收集结果
|
||||
4. 合并所有股票的结果
|
||||
|
||||
性能优势:
|
||||
- 使用 Polars 的 rolling_mean 等向量化操作
|
||||
- 每只股票只计算一次,无重复计算
|
||||
|
||||
返回 DataFrame 格式:
|
||||
┌────────────┬──────────┬──────────────┐
|
||||
│ trade_date │ ts_code │ factor_name │
|
||||
├────────────┼──────────┼──────────────┤
|
||||
│ 20240101 │ 000001.SZ│ 10.5 │
|
||||
│ 20240102 │ 000001.SZ│ 10.6 │
|
||||
└────────────┴──────────┴──────────────┘
|
||||
|
||||
Args:
|
||||
factor: 时序因子
|
||||
stock_codes: 股票代码列表
|
||||
start_date: 开始日期 YYYYMMDD
|
||||
end_date: 结束日期 YYYYMMDD
|
||||
|
||||
Returns:
|
||||
包含因子值的 DataFrame
|
||||
"""
|
||||
# 计算最大 lookback
|
||||
max_lookback = max(spec.lookback_days for spec in factor.data_specs)
|
||||
|
||||
# 确定数据起始日期(向前扩展 lookback)
|
||||
data_start = self._get_trading_date_offset(start_date, -max_lookback + 1)
|
||||
|
||||
# 加载所有数据
|
||||
all_data = self.data_loader.load(
|
||||
factor.data_specs, date_range=(data_start, end_date)
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
# 获取所有交易日
|
||||
all_dates = all_data["trade_date"].unique().sort() if len(all_data) > 0 else []
|
||||
|
||||
# 按股票遍历:每只股票一次性计算
|
||||
for stock_code in stock_codes:
|
||||
# 过滤出该股票的数据(防止股票泄露)
|
||||
stock_data = all_data.filter(pl.col("ts_code") == stock_code)
|
||||
|
||||
if len(stock_data) == 0:
|
||||
continue
|
||||
|
||||
# 创建 FactorData(该股票的完整序列)
|
||||
context = FactorContext(
|
||||
current_stock=stock_code,
|
||||
trade_dates=list(all_dates),
|
||||
)
|
||||
factor_data = FactorData(stock_data, context)
|
||||
|
||||
# 一次性计算整个时间序列(向量化,高效)
|
||||
factor_values = factor.compute(factor_data)
|
||||
|
||||
# 获取该股票的日期列表
|
||||
stock_dates = stock_data["trade_date"]
|
||||
|
||||
# 确保 factor_values 长度与日期列表一致
|
||||
if len(factor_values) != len(stock_dates):
|
||||
# 如果长度不一致,用 null 填充
|
||||
factor_values = pl.Series([None] * len(stock_dates))
|
||||
|
||||
# 收集结果
|
||||
results.append(
|
||||
pl.DataFrame(
|
||||
{
|
||||
"trade_date": stock_dates,
|
||||
"ts_code": [stock_code] * len(stock_dates),
|
||||
factor.name: factor_values,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# 合并所有股票的结果
|
||||
if results:
|
||||
return pl.concat(results)
|
||||
else:
|
||||
# 返回空 DataFrame
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"trade_date": [],
|
||||
"ts_code": [],
|
||||
factor.name: [],
|
||||
}
|
||||
)
|
||||
|
||||
def _get_trading_date_offset(self, date: str, offset: int) -> str:
|
||||
"""获取相对于给定日期的交易日偏移
|
||||
|
||||
简单实现:假设每天都有交易,直接计算日期偏移
|
||||
实际项目中可能需要使用交易日历
|
||||
|
||||
Args:
|
||||
date: 基准日期 YYYYMMDD
|
||||
offset: 偏移天数(正数向后,负数向前)
|
||||
|
||||
Returns:
|
||||
偏移后的日期 YYYYMMDD
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
dt = datetime.strptime(date, "%Y%m%d")
|
||||
new_dt = dt + timedelta(days=offset)
|
||||
return new_dt.strftime("%Y%m%d")
|
||||
|
||||
def _get_date_range(
|
||||
self, start_date: str, end_date: str, data: pl.DataFrame
|
||||
) -> List[str]:
|
||||
"""获取日期范围内的所有交易日
|
||||
|
||||
Args:
|
||||
start_date: 开始日期 YYYYMMDD
|
||||
end_date: 结束日期 YYYYMMDD
|
||||
data: 包含 trade_date 列的 DataFrame
|
||||
|
||||
Returns:
|
||||
日期列表
|
||||
"""
|
||||
if len(data) == 0:
|
||||
return []
|
||||
|
||||
# 从数据中获取实际存在的日期
|
||||
dates = (
|
||||
data.filter(
|
||||
(pl.col("trade_date") >= start_date)
|
||||
& (pl.col("trade_date") <= end_date)
|
||||
)["trade_date"]
|
||||
.unique()
|
||||
.sort()
|
||||
.to_list()
|
||||
)
|
||||
|
||||
return dates
|
||||
Reference in New Issue
Block a user