368 lines
13 KiB
Python
368 lines
13 KiB
Python
|
|
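"""Execution engine - Phase 4 factor execution engine.

This module holds the core logic for factor computation:

- FactorEngine: factor execution engine that chooses a computation and
  leakage-prevention strategy based on the factor type.

Leakage-prevention strategies:

1. CrossSectionalFactor: prevents date leakage; for each day T only the data
   in [T-lookback+1, T] is passed in.
2. TimeSeriesFactor: prevents cross-stock leakage; each stock is passed its
   own complete series.
"""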
|
|||
|
|
|
|||
|
|
import logging
from datetime import datetime, timedelta
from typing import List, Optional

import polars as pl

from src.factors.data_loader import DataLoader
from src.factors.data_spec import FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FactorEngine:
    """Factor execution engine that dispatches on the factor type.

    Core responsibilities:
        1. CrossSectionalFactor: prevent date leakage by handing the factor,
           for each day T, only the rows in [T-lookback+1, T].
        2. TimeSeriesFactor: prevent cross-stock leakage by handing the
           factor one stock's rows at a time.

    Example:
        >>> loader = DataLoader(data_dir="data")
        >>> engine = FactorEngine(loader)
        >>> result = engine.compute(factor, start_date="20240101", end_date="20240131")
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, data_loader: DataLoader):
        """Initialise the engine.

        Args:
            data_loader: Data loader instance used to fetch the factor inputs.
        """
        self.data_loader = data_loader

    def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame:
        """Single entry point: dispatch to the method matching ``factor.factor_type``.

        - ``"cross_sectional"`` -> ``_compute_cross_sectional()``
        - ``"time_series"`` -> ``_compute_time_series()``

        Args:
            factor: The factor to compute.
            **kwargs: Extra parameters, depending on the factor type:
                - cross-sectional: ``start_date``, ``end_date``
                - time-series: ``stock_codes``, ``start_date``, ``end_date``

        Returns:
            DataFrame with columns ``trade_date``, ``ts_code`` and
            ``factor.name``.

        Raises:
            ValueError: Unknown ``factor_type`` or a required parameter is
                missing.
        """
        factor_type = factor.factor_type

        if factor_type == "cross_sectional":
            if "start_date" not in kwargs or "end_date" not in kwargs:
                raise ValueError(
                    "cross_sectional factor requires 'start_date' and 'end_date' parameters"
                )
            return self._compute_cross_sectional(
                factor, kwargs["start_date"], kwargs["end_date"]
            )

        if factor_type == "time_series":
            required = ("stock_codes", "start_date", "end_date")
            missing = [key for key in required if key not in kwargs]
            if missing:
                raise ValueError(
                    f"time_series factor requires parameters: {', '.join(missing)}"
                )
            return self._compute_time_series(
                factor,
                kwargs["stock_codes"],
                kwargs["start_date"],
                kwargs["end_date"],
            )

        raise ValueError(f"Unknown factor type: {factor_type}")

    @staticmethod
    def _max_lookback(factor) -> int:
        """Return the largest ``lookback_days`` over ``factor.data_specs`` (0 if empty)."""
        return max((spec.lookback_days for spec in factor.data_specs), default=0)

    @staticmethod
    def _window_start_offset(max_lookback: int) -> int:
        """Return the (non-positive) day offset from T to the first day of the window.

        A window of N days ending at T starts at T-(N-1). The value is clamped
        at 0 so that a lookback of 0 never yields a start *after* T.
        """
        return -max(max_lookback - 1, 0)

    @staticmethod
    def _empty_result(factor_name: str) -> pl.DataFrame:
        """Return an empty frame with the three result columns."""
        return pl.DataFrame(
            {
                "trade_date": [],
                "ts_code": [],
                factor_name: [],
            }
        )

    def _compute_cross_sectional(
        self,
        factor: CrossSectionalFactor,
        start_date: str,
        end_date: str,
    ) -> pl.DataFrame:
        """Run a per-date cross-sectional computation.

        Leakage prevention:
            - No date leakage: for each day T the factor only receives rows
              with ``trade_date`` in [T-lookback+1, T] (never the future).
            - Cross-stock comparison is allowed: all stocks are passed in.

        Steps:
            1. Compute the maximum lookback and the data start date.
            2. Load [start-lookback+1, end] once.
            3. For every date T present in the data within
               [start_date, end_date]: trim the data to the window, build a
               ``FactorData`` with ``current_date=T``, call
               ``factor.compute()`` and collect the result.
            4. Concatenate the per-date results.

        Args:
            factor: Cross-sectional factor.
            start_date: Start date, ``YYYYMMDD``.
            end_date: End date, ``YYYYMMDD``.

        Returns:
            DataFrame with columns ``trade_date``, ``ts_code`` and
            ``factor.name``; empty if no date produced any rows.
        """
        max_lookback = self._max_lookback(factor)
        window_offset = self._window_start_offset(max_lookback)

        # Extend the load range backwards so the first day has its full window.
        data_start = self._get_trading_date_offset(start_date, window_offset)
        raw_data = self.data_loader.load(
            factor.data_specs, date_range=(data_start, end_date)
        )

        date_range = self._get_date_range(start_date, end_date, raw_data)

        results = []
        for current_date in date_range:
            # Drop everything after current_date (prevents date leakage) while
            # keeping every stock (cross-stock comparison is allowed).
            day_data = raw_data.filter(pl.col("trade_date") <= current_date)

            # With a positive lookback, also drop rows older than the window.
            if max_lookback > 0:
                lookback_start = self._get_trading_date_offset(
                    current_date, window_offset
                )
                day_data = day_data.filter(pl.col("trade_date") >= lookback_start)

            if len(day_data) == 0:
                continue

            context = FactorContext(
                current_date=current_date,
                trade_dates=date_range,
            )
            factor_data = FactorData(day_data, context)
            factor_values = factor.compute(factor_data)

            # Stocks that have a row on the current date.
            today_stocks = day_data.filter(pl.col("trade_date") == current_date)[
                "ts_code"
            ]

            if len(factor_values) != len(today_stocks):
                # The factor may have been computed on FactorData's own
                # cross-section; try to align with that stock list first.
                cs_data = factor_data.get_cross_section()
                if len(cs_data) > 0:
                    today_stocks = cs_data["ts_code"]

                if len(factor_values) != len(today_stocks):
                    # Still inconsistent: keep the best-effort null fill, but
                    # make it visible instead of hiding a factor bug.
                    self._logger.warning(
                        "Factor %s returned %d values for %d stocks on %s; "
                        "filling with nulls",
                        factor.name,
                        len(factor_values),
                        len(today_stocks),
                        current_date,
                    )
                    factor_values = pl.Series([None] * len(today_stocks))

            if len(today_stocks) > 0:
                results.append(
                    pl.DataFrame(
                        {
                            "trade_date": [current_date] * len(today_stocks),
                            "ts_code": today_stocks,
                            factor.name: factor_values,
                        }
                    )
                )

        if not results:
            return self._empty_result(factor.name)

        # "vertical_relaxed" lets an all-null filler column be combined with
        # the typed columns of the other days.
        return pl.concat(results, how="vertical_relaxed")

    def _compute_time_series(
        self,
        factor: TimeSeriesFactor,
        stock_codes: List[str],
        start_date: str,
        end_date: str,
    ) -> pl.DataFrame:
        """Run a per-stock time-series computation.

        Leakage prevention:
            - No cross-stock leakage: each stock is computed on its own rows.
            - Access to history is expected for time-series factors.

        Steps:
            1. Compute the maximum lookback and the data start date.
            2. Load [start-lookback+1, end] once.
            3. For every stock in ``stock_codes``: filter its rows, build a
               ``FactorData`` with ``current_stock`` set, call
               ``factor.compute()`` once for the whole series and collect the
               result.
            4. Concatenate the per-stock results.

        Args:
            factor: Time-series factor.
            stock_codes: Stock codes to compute.
            start_date: Start date, ``YYYYMMDD``.
            end_date: End date, ``YYYYMMDD``.

        Returns:
            DataFrame with columns ``trade_date``, ``ts_code`` and
            ``factor.name``; empty if nothing was loaded or no requested
            stock has rows.
        """
        max_lookback = self._max_lookback(factor)

        # Extend the load range backwards to cover the lookback window.
        data_start = self._get_trading_date_offset(
            start_date, self._window_start_offset(max_lookback)
        )
        all_data = self.data_loader.load(
            factor.data_specs, date_range=(data_start, end_date)
        )

        if len(all_data) == 0:
            return self._empty_result(factor.name)

        # Loop-invariant: every date present in the loaded data, ascending.
        trade_dates = all_data["trade_date"].unique().sort().to_list()

        results = []
        for stock_code in stock_codes:
            # Only this stock's rows (prevents cross-stock leakage).
            # NOTE(review): rows keep the loader's order; presumably already
            # sorted by trade_date - confirm against DataLoader.
            stock_data = all_data.filter(pl.col("ts_code") == stock_code)
            if len(stock_data) == 0:
                continue

            context = FactorContext(
                current_stock=stock_code,
                trade_dates=trade_dates,
            )
            factor_data = FactorData(stock_data, context)

            # One call computes the whole series for this stock.
            factor_values = factor.compute(factor_data)
            stock_dates = stock_data["trade_date"]

            if len(factor_values) != len(stock_dates):
                # Best-effort null fill, surfaced so factor bugs are noticed.
                self._logger.warning(
                    "Factor %s returned %d values for %d dates of %s; "
                    "filling with nulls",
                    factor.name,
                    len(factor_values),
                    len(stock_dates),
                    stock_code,
                )
                factor_values = pl.Series([None] * len(stock_dates))

            # NOTE(review): rows from the warm-up window before start_date are
            # part of the output here, unlike the cross-sectional path -
            # confirm whether callers rely on that before trimming.
            results.append(
                pl.DataFrame(
                    {
                        "trade_date": stock_dates,
                        "ts_code": [stock_code] * len(stock_dates),
                        factor.name: factor_values,
                    }
                )
            )

        if not results:
            return self._empty_result(factor.name)

        # "vertical_relaxed" lets an all-null filler column be combined with
        # the typed columns of the other stocks.
        return pl.concat(results, how="vertical_relaxed")

    def _get_trading_date_offset(self, date: str, offset: int) -> str:
        """Return ``date`` shifted by ``offset`` days.

        Simple implementation: the shift is in calendar days, i.e. it treats
        every day as a trading day. A real trading calendar may be needed in
        production.

        Args:
            date: Base date, ``YYYYMMDD``.
            offset: Number of days to shift (positive = later, negative =
                earlier).

        Returns:
            The shifted date, ``YYYYMMDD``.
        """
        shifted = datetime.strptime(date, "%Y%m%d") + timedelta(days=offset)
        return shifted.strftime("%Y%m%d")

    def _get_date_range(
        self, start_date: str, end_date: str, data: pl.DataFrame
    ) -> List[str]:
        """Return the distinct dates present in ``data`` within the range.

        Args:
            start_date: Start date, ``YYYYMMDD`` (inclusive).
            end_date: End date, ``YYYYMMDD`` (inclusive).
            data: DataFrame containing a ``trade_date`` column.

        Returns:
            Sorted list of the ``trade_date`` values found in ``data`` between
            ``start_date`` and ``end_date``; empty if ``data`` has no rows.
        """
        if len(data) == 0:
            return []

        in_range = (pl.col("trade_date") >= start_date) & (
            pl.col("trade_date") <= end_date
        )
        # Only dates that actually occur in the data are returned.
        return data.filter(in_range)["trade_date"].unique().sort().to_list()
|