feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor) - 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData) - 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine) - 新增组合因子支持 (CompositeFactor, ScalarFactor) - 添加因子模块完整测试用例 - 添加 Git 提交规范文档
This commit is contained in:
367
src/factors/engine.py
Normal file
367
src/factors/engine.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""执行引擎 - Phase 4 因子执行引擎
|
||||
|
||||
本模块负责因子计算的核心逻辑:
|
||||
- FactorEngine: 因子执行引擎,根据因子类型采用不同的计算和防泄露策略
|
||||
|
||||
防泄露策略:
|
||||
1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据
|
||||
2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_loader import DataLoader
|
||||
from src.factors.data_spec import FactorContext, FactorData
|
||||
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
||||
|
||||
|
||||
class FactorEngine:
    """Factor execution engine that dispatches on the factor type.

    Responsibilities:
    1. CrossSectionalFactor: prevent date leakage - for each date T the
       factor only sees rows in [T-lookback+1, T].
    2. TimeSeriesFactor: prevent cross-stock leakage - each stock is
       computed from its own rows only.

    Example:
        >>> loader = DataLoader(data_dir="data")
        >>> engine = FactorEngine(loader)
        >>> result = engine.compute(factor, start_date="20240101", end_date="20240131")
    """

    def __init__(self, data_loader: DataLoader):
        """Store the loader used to fetch raw data.

        Args:
            data_loader: Data loader instance; its ``load(data_specs,
                date_range=(start, end))`` method is called by the engine.
        """
        self.data_loader = data_loader

    def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
        """Single entry point: dispatch on ``factor.factor_type``.

        - "cross_sectional" -> _compute_cross_sectional()
        - "time_series"     -> _compute_time_series()

        Args:
            factor: The factor to compute.
            **kwargs: Extra parameters, depending on the factor type:
                - cross-sectional: start_date, end_date
                - time-series: stock_codes, start_date, end_date

        Returns:
            DataFrame with columns [trade_date, ts_code, <factor.name>].

        Raises:
            ValueError: Unknown factor_type, or a required parameter is
                missing.
        """
        if factor.factor_type == "cross_sectional":
            if "start_date" not in kwargs or "end_date" not in kwargs:
                raise ValueError(
                    "cross_sectional factor requires 'start_date' and 'end_date' parameters"
                )
            return self._compute_cross_sectional(
                factor, kwargs["start_date"], kwargs["end_date"]
            )

        if factor.factor_type == "time_series":
            # Report every missing parameter at once, in a stable order.
            missing = [
                key
                for key in ("stock_codes", "start_date", "end_date")
                if key not in kwargs
            ]
            if missing:
                raise ValueError(
                    f"time_series factor requires parameters: {', '.join(missing)}"
                )
            return self._compute_time_series(
                factor,
                kwargs["stock_codes"],
                kwargs["start_date"],
                kwargs["end_date"],
            )

        raise ValueError(f"Unknown factor type: {factor.factor_type}")

    def _load_with_lookback(self, factor: BaseFactor, start_date: str, end_date: str):
        """Load the factor's data for [start_date - (max_lookback - 1), end_date].

        Args:
            factor: Factor whose ``data_specs`` describe the required data.
            start_date: First requested date, YYYYMMDD.
            end_date: Last requested date, YYYYMMDD.

        Returns:
            Tuple ``(data, max_lookback)`` where ``data`` is whatever the
            loader returned and ``max_lookback`` is the largest
            ``lookback_days`` among the factor's data specs.

        Raises:
            ValueError: The factor declares no data specs.
        """
        lookbacks = [spec.lookback_days for spec in factor.data_specs]
        if not lookbacks:
            # A bare max() would fail with an unhelpful "empty sequence" message.
            raise ValueError(f"factor '{factor.name}' declares no data_specs")
        max_lookback = max(lookbacks)

        # A window of N days reaches N-1 days back. Clamp at 0 so that a
        # lookback of 0 never pushes the load start *after* start_date
        # (which would silently drop the first requested day).
        days_back = max(max_lookback - 1, 0)
        data_start = self._get_trading_date_offset(start_date, -days_back)

        data = self.data_loader.load(
            factor.data_specs, date_range=(data_start, end_date)
        )
        return data, max_lookback

    def _empty_result(self, factor_name: str) -> pl.DataFrame:
        """Return an empty frame with the standard three result columns."""
        return pl.DataFrame(
            {
                "trade_date": [],
                "ts_code": [],
                factor_name: [],
            }
        )

    def _compute_cross_sectional(
        self,
        factor: CrossSectionalFactor,
        start_date: str,
        end_date: str,
    ) -> pl.DataFrame:
        """Compute a cross-sectional factor date by date.

        Anti-leakage strategy:
        - No date leakage: for date T only rows in [T-lookback+1, T] are
          passed to the factor (never future rows).
        - Cross-stock comparison is allowed: all stocks of the window are
          passed together.

        Steps:
        1. Load [start - max_lookback + 1, end] once.
        2. For each date T present in the data within [start_date, end_date]:
           a. crop the data to [T-lookback+1, T];
           b. build FactorData with current_date=T;
           c. call factor.compute();
           d. collect the result.
        3. Concatenate the per-date results.

        Args:
            factor: Cross-sectional factor.
            start_date: Start date, YYYYMMDD.
            end_date: End date, YYYYMMDD.

        Returns:
            DataFrame with columns [trade_date, ts_code, <factor.name>];
            empty (same columns) when nothing could be computed.
        """
        raw_data, max_lookback = self._load_with_lookback(
            factor, start_date, end_date
        )
        date_range = self._get_date_range(start_date, end_date, raw_data)

        results = []
        for current_date in date_range:
            # Drop everything after current_date (prevents date leakage)
            # while keeping every stock (cross-stock comparison is allowed).
            day_data = raw_data.filter(pl.col("trade_date") <= current_date)

            # With a positive lookback, also drop rows older than the window.
            if max_lookback > 0:
                lookback_start = self._get_trading_date_offset(
                    current_date, -max_lookback + 1
                )
                day_data = day_data.filter(pl.col("trade_date") >= lookback_start)

            # len() is used on purpose: truthiness of a DataFrame is ambiguous.
            if len(day_data) == 0:
                continue

            context = FactorContext(
                current_date=current_date,
                trade_dates=date_range,
            )
            factor_data = FactorData(day_data, context)
            factor_values = factor.compute(factor_data)

            # Stocks that have a row on current_date.
            today_stocks = day_data.filter(pl.col("trade_date") == current_date)[
                "ts_code"
            ]

            if len(factor_values) != len(today_stocks):
                # Length mismatch: fall back to the stock list reported by
                # FactorData's own cross-section, if it has one.
                cs_data = factor_data.get_cross_section()
                if len(cs_data) > 0:
                    today_stocks = cs_data["ts_code"]
                # Still inconsistent: best effort, emit nulls for this date.
                if len(factor_values) != len(today_stocks):
                    factor_values = pl.Series([None] * len(today_stocks))

            if len(today_stocks) > 0:
                results.append(
                    pl.DataFrame(
                        {
                            "trade_date": [current_date] * len(today_stocks),
                            "ts_code": today_stocks,
                            factor.name: factor_values,
                        }
                    )
                )

        if results:
            return pl.concat(results)
        return self._empty_result(factor.name)

    def _compute_time_series(
        self,
        factor: TimeSeriesFactor,
        stock_codes: List[str],
        start_date: str,
        end_date: str,
    ) -> pl.DataFrame:
        """Compute a time-series factor stock by stock.

        Anti-leakage strategy:
        - No cross-stock leakage: each stock is computed separately from its
          own rows only.
        - Access to that stock's history is expected and allowed.

        Steps:
        1. Load [start - max_lookback + 1, end] once.
        2. For each stock S in stock_codes:
           a. keep only S's rows;
           b. build FactorData with current_stock=S;
           c. call factor.compute() once for the whole series;
           d. collect the result.
        3. Concatenate the per-stock results.

        Args:
            factor: Time-series factor.
            stock_codes: Stock codes to compute.
            start_date: Start date, YYYYMMDD.
            end_date: End date, YYYYMMDD.

        Returns:
            DataFrame with columns [trade_date, ts_code, <factor.name>];
            empty (same columns) when nothing could be computed.
        """
        all_data, _ = self._load_with_lookback(factor, start_date, end_date)

        # Distinct dates present in the loaded data; built once, outside the
        # per-stock loop.
        if len(all_data) > 0:
            trade_dates = all_data["trade_date"].unique().sort().to_list()
        else:
            trade_dates = []

        results = []
        for stock_code in stock_codes:
            # Only this stock's rows are visible to the factor.
            # NOTE(review): rows are used in loader order; if the loader does
            # not return them sorted by trade_date, rolling factors may need
            # an explicit sort - confirm against DataLoader.
            stock_data = all_data.filter(pl.col("ts_code") == stock_code)
            if len(stock_data) == 0:
                continue

            context = FactorContext(
                current_stock=stock_code,
                trade_dates=trade_dates,
            )
            factor_data = FactorData(stock_data, context)

            # One call covers the stock's whole series.
            factor_values = factor.compute(factor_data)

            # NOTE(review): these are all loaded dates for the stock, so the
            # output can include lookback warm-up rows dated before
            # start_date (assuming the loader returns the full requested
            # range) - confirm whether callers expect them to be trimmed.
            stock_dates = stock_data["trade_date"]

            # Length mismatch: best effort, emit nulls for this stock.
            if len(factor_values) != len(stock_dates):
                factor_values = pl.Series([None] * len(stock_dates))

            results.append(
                pl.DataFrame(
                    {
                        "trade_date": stock_dates,
                        "ts_code": [stock_code] * len(stock_dates),
                        factor.name: factor_values,
                    }
                )
            )

        if results:
            return pl.concat(results)
        return self._empty_result(factor.name)

    def _get_trading_date_offset(self, date: str, offset: int) -> str:
        """Shift a YYYYMMDD date by a number of days.

        Simplified implementation: every calendar day is treated as a
        trading day, so the shift is a plain calendar-day offset. A real
        trading calendar may be needed in production.

        Args:
            date: Base date, YYYYMMDD.
            offset: Number of days to shift (positive = later,
                negative = earlier).

        Returns:
            The shifted date, YYYYMMDD.
        """
        from datetime import datetime, timedelta

        shifted = datetime.strptime(date, "%Y%m%d") + timedelta(days=offset)
        return shifted.strftime("%Y%m%d")

    def _get_date_range(
        self, start_date: str, end_date: str, data: pl.DataFrame
    ) -> List[str]:
        """Return the sorted distinct dates of ``data`` within [start_date, end_date].

        Args:
            start_date: Start date, YYYYMMDD (inclusive).
            end_date: End date, YYYYMMDD (inclusive).
            data: DataFrame with a ``trade_date`` column.

        Returns:
            Sorted list of the dates that actually occur in ``data``; empty
            when ``data`` has no rows.
        """
        if len(data) == 0:
            return []

        # Only dates that really exist in the data are returned.
        in_range = data.filter(
            (pl.col("trade_date") >= start_date)
            & (pl.col("trade_date") <= end_date)
        )
        return in_range["trade_date"].unique().sort().to_list()
|
||||
Reference in New Issue
Block a user