"""执行引擎 - Phase 4 因子执行引擎 本模块负责因子计算的核心逻辑: - FactorEngine: 因子执行引擎,根据因子类型采用不同的计算和防泄露策略 防泄露策略: 1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据 2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 """ from typing import List, Optional import polars as pl from src.factors.data_loader import DataLoader from src.factors.data_spec import FactorContext, FactorData from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor class FactorEngine: """因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略 核心职责: 1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据 2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 示例: >>> loader = DataLoader(data_dir="data") >>> engine = FactorEngine(loader) >>> result = engine.compute(factor, start_date="20240101", end_date="20240131") """ def __init__(self, data_loader: DataLoader): """初始化引擎 Args: data_loader: 数据加载器实例 """ self.data_loader = data_loader def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame: """统一的计算入口 根据 factor_type 分发到具体方法: - "cross_sectional" -> _compute_cross_sectional() - "time_series" -> _compute_time_series() Args: factor: 要计算的因子 **kwargs: 额外参数,根据因子类型不同: - 截面因子: start_date, end_date - 时序因子: stock_codes, start_date, end_date Returns: DataFrame[trade_date, ts_code, factor_name] Raises: ValueError: 无效的 factor_type 或缺少必需参数 """ if factor.factor_type == "cross_sectional": if "start_date" not in kwargs or "end_date" not in kwargs: raise ValueError( "cross_sectional factor requires 'start_date' and 'end_date' parameters" ) return self._compute_cross_sectional( factor, kwargs["start_date"], kwargs["end_date"] ) elif factor.factor_type == "time_series": missing = [] if "stock_codes" not in kwargs: missing.append("stock_codes") if "start_date" not in kwargs: missing.append("start_date") if "end_date" not in kwargs: missing.append("end_date") if missing: raise ValueError( f"time_series factor requires parameters: {', '.join(missing)}" ) return self._compute_time_series( factor, kwargs["stock_codes"], kwargs["start_date"], kwargs["end_date"], ) else: raise ValueError(f"Unknown factor type: {factor.factor_type}") def _compute_cross_sectional( self, factor: CrossSectionalFactor, start_date: str, end_date: str, ) -> pl.DataFrame: """执行日期截面计算 防泄露策略: - 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来) - 允许股票间比较:传入当天所有股票的数据 计算流程: 1. 计算 max_lookback,确定数据起始日期 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 3. 对每个日期 T in [start_date, end_date]: a. 裁剪数据到 [T-lookback+1, T] b. 创建 FactorData(current_date=T) c. 调用 factor.compute() d. 收集结果 4. 合并所有日期的结果 返回 DataFrame 格式: ┌────────────┬──────────┬──────────────┐ │ trade_date │ ts_code │ factor_name │ ├────────────┼──────────┼──────────────┤ │ 20240101 │ 000001.SZ│ 0.5 │ │ 20240101 │ 000002.SZ│ 0.3 │ └────────────┴──────────┴──────────────┘ Args: factor: 截面因子 start_date: 开始日期 YYYYMMDD end_date: 结束日期 YYYYMMDD Returns: 包含因子值的 DataFrame """ # 计算最大 lookback max_lookback = max(spec.lookback_days for spec in factor.data_specs) # 确定数据起始日期(向前扩展 lookback) data_start = self._get_trading_date_offset(start_date, -max_lookback + 1) # 一次性加载所有数据 raw_data = self.data_loader.load( factor.data_specs, date_range=(data_start, end_date) ) results = [] # 获取日期范围 date_range = self._get_date_range(start_date, end_date, raw_data) # 按日期遍历:每天计算一次 for current_date in date_range: # 裁剪数据:只保留 current_date 及之前的数据(防止日期泄露) # 但保留所有股票的数据(允许股票间比较) day_data = raw_data.filter(pl.col("trade_date") <= current_date) # 如果 lookback > 0,进一步裁剪到 lookback 窗口 if max_lookback > 0: lookback_start = self._get_trading_date_offset( current_date, -max_lookback + 1 ) day_data = day_data.filter(pl.col("trade_date") >= lookback_start) # 如果没有数据,跳过 if len(day_data) == 0: continue # 创建 FactorData(包含当天及历史数据,无未来数据) context = FactorContext( current_date=current_date, trade_dates=date_range, ) factor_data = FactorData(day_data, context) # 计算因子值 factor_values = factor.compute(factor_data) # 获取当前日期的股票列表 today_stocks = day_data.filter(pl.col("trade_date") == current_date)[ "ts_code" ] # 确保 factor_values 长度与股票列表一致 if len(factor_values) != len(today_stocks): # 如果长度不一致,可能是 factor.compute 返回了错误的长度 # 尝试从 factor_data 重新提取 cs_data = factor_data.get_cross_section() if len(cs_data) > 0: today_stocks = cs_data["ts_code"] # 如果 factor_values 仍然不匹配,用 null 填充 if len(factor_values) != len(today_stocks): factor_values = pl.Series([None] * len(today_stocks)) # 收集结果 if len(today_stocks) > 0: results.append( pl.DataFrame( { "trade_date": [current_date] * len(today_stocks), "ts_code": today_stocks, factor.name: factor_values, } ) ) # 合并所有日期的结果 if results: return pl.concat(results) else: # 返回空 DataFrame return pl.DataFrame( { "trade_date": [], "ts_code": [], factor.name: [], } ) def _compute_time_series( self, factor: TimeSeriesFactor, stock_codes: List[str], start_date: str, end_date: str, ) -> pl.DataFrame: """执行时间序列计算 防泄露策略: - 防止股票泄露:每只股票单独计算,传入该股票的完整序列 - 允许访问历史数据:时序计算需要历史数据,这是正常的 计算流程: 1. 计算 max_lookback,确定数据起始日期 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 3. 对每只股票 S in stock_codes: a. 过滤出 S 的数据(防止股票泄露) b. 创建 FactorData(current_stock=S) c. 调用 factor.compute()(向量化计算整个序列) d. 收集结果 4. 合并所有股票的结果 性能优势: - 使用 Polars 的 rolling_mean 等向量化操作 - 每只股票只计算一次,无重复计算 返回 DataFrame 格式: ┌────────────┬──────────┬──────────────┐ │ trade_date │ ts_code │ factor_name │ ├────────────┼──────────┼──────────────┤ │ 20240101 │ 000001.SZ│ 10.5 │ │ 20240102 │ 000001.SZ│ 10.6 │ └────────────┴──────────┴──────────────┘ Args: factor: 时序因子 stock_codes: 股票代码列表 start_date: 开始日期 YYYYMMDD end_date: 结束日期 YYYYMMDD Returns: 包含因子值的 DataFrame """ # 计算最大 lookback max_lookback = max(spec.lookback_days for spec in factor.data_specs) # 确定数据起始日期(向前扩展 lookback) data_start = self._get_trading_date_offset(start_date, -max_lookback + 1) # 加载所有数据 all_data = self.data_loader.load( factor.data_specs, date_range=(data_start, end_date) ) results = [] # 获取所有交易日 all_dates = all_data["trade_date"].unique().sort() if len(all_data) > 0 else [] # 按股票遍历:每只股票一次性计算 for stock_code in stock_codes: # 过滤出该股票的数据(防止股票泄露) stock_data = all_data.filter(pl.col("ts_code") == stock_code) if len(stock_data) == 0: continue # 创建 FactorData(该股票的完整序列) context = FactorContext( current_stock=stock_code, trade_dates=list(all_dates), ) factor_data = FactorData(stock_data, context) # 一次性计算整个时间序列(向量化,高效) factor_values = factor.compute(factor_data) # 获取该股票的日期列表 stock_dates = stock_data["trade_date"] # 确保 factor_values 长度与日期列表一致 if len(factor_values) != len(stock_dates): # 如果长度不一致,用 null 填充 factor_values = pl.Series([None] * len(stock_dates)) # 收集结果 results.append( pl.DataFrame( { "trade_date": stock_dates, "ts_code": [stock_code] * len(stock_dates), factor.name: factor_values, } ) ) # 合并所有股票的结果 if results: return pl.concat(results) else: # 返回空 DataFrame return pl.DataFrame( { "trade_date": [], "ts_code": [], factor.name: [], } ) def _get_trading_date_offset(self, date: str, offset: int) -> str: """获取相对于给定日期的交易日偏移 简单实现:假设每天都有交易,直接计算日期偏移 实际项目中可能需要使用交易日历 Args: date: 基准日期 YYYYMMDD offset: 偏移天数(正数向后,负数向前) Returns: 偏移后的日期 YYYYMMDD """ from datetime import datetime, timedelta dt = datetime.strptime(date, "%Y%m%d") new_dt = dt + timedelta(days=offset) return new_dt.strftime("%Y%m%d") def _get_date_range( self, start_date: str, end_date: str, data: pl.DataFrame ) -> List[str]: """获取日期范围内的所有交易日 Args: start_date: 开始日期 YYYYMMDD end_date: 结束日期 YYYYMMDD data: 包含 trade_date 列的 DataFrame Returns: 日期列表 """ if len(data) == 0: return [] # 从数据中获取实际存在的日期 dates = ( data.filter( (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) )["trade_date"] .unique() .sort() .to_list() ) return dates