"""
Factor operator framework - Polars implementation.

Pipeline: cross-sectional roll -> stitch the groups back into a long table ->
merge onto the input by key columns.  The result can be returned either as
the full DataFrame (default) or as a single-column Series.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Literal

import polars as pl


@dataclass
class OperatorConfig:
    """Operator configuration."""

    # Operator name; copied onto the operator instance as ``name``.
    name: str
    # Free-text description; not read anywhere in this module.
    description: str
    # Columns that ``BaseOperator.validate_input`` requires in the input frame.
    required_columns: List[str]
    # Copied onto the operator instance; not otherwise used in this module.
    output_columns: List[str]
    # Free-form parameter mapping; not read anywhere in this module.
    parameters: dict


class BaseOperator(ABC):
    """Base class for factor operators.

    Subclasses provide the factor column name, the per-group computation and
    the grouping strategy (``_sectional_roll``); this class wires them into
    the ``apply`` pipeline.
    """

    def __init__(self, config: OperatorConfig):
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    # ---------- must be implemented by subclasses ----------
    @abstractmethod
    def get_factor_name(self) -> str:
        """Return the factor column name (used when merging)."""

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
        """Compute the factor for a single group.

        Args:
            group_df: The sub-table of one group (one ``ts_code`` or one
                ``trade_date``, depending on the concrete operator).
            **kwargs: Extra arguments forwarded unchanged from ``apply``.

        Returns:
            A Series with one value per row of ``group_df``, in the same row
            order.  The grouping helpers rename it to ``get_factor_name()``,
            so the name of the returned Series does not matter.
        """

    # ---------- public interface ----------
    def apply(
        self,
        df: pl.DataFrame,
        return_type: Literal['df', 'series'] = 'df',
        **kwargs,
    ) -> pl.DataFrame | pl.Series:
        """Run the pipeline: sectional roll -> long table -> merge / return.

        Args:
            df: Input long table.  Must contain ``required_columns`` as well
                as the ``ts_code`` / ``trade_date`` key columns used for
                grouping and merging.
            return_type: ``'df'`` returns the input with the factor column
                attached; any other value returns only the factor column.
            **kwargs: Forwarded to ``_sectional_roll`` (and from there to
                ``calc_factor``).

        Raises:
            ValueError: If ``validate_input`` rejects ``df``.
        """
        if not self.validate_input(df):
            raise ValueError(f"缺少必需列:{self.required_columns}")

        long_table = self._sectional_roll(df, **kwargs)  # step 1: roll
        merged = self._merge_factor(df, long_table)  # step 2: merge
        if return_type == 'df':
            return merged
        return merged[self.get_factor_name()]

    # ---------- internal flow ----------
    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True when every required column is present in ``df``."""
        return all(col in df.columns for col in self.required_columns)

    @abstractmethod
    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Group ``df``, run ``calc_factor`` per group, stitch the results.

        Returns:
            A long table with the columns ``ts_code``, ``trade_date`` and the
            factor column.
        """

    def _roll_by_group(
        self,
        df: pl.DataFrame,
        group_key: str,
        sort_keys: List[str],
        calc_kwargs: dict,
    ) -> pl.DataFrame:
        """Shared group -> calc_factor -> long-table implementation.

        Args:
            df: Input long table.
            group_key: Column to group by.
            sort_keys: Columns to sort by before grouping, so each group is
                handed to ``calc_factor`` in a deterministic row order.
            calc_kwargs: Keyword arguments forwarded to ``calc_factor``.
                Passed as a dict so they cannot collide with this helper's
                own parameter names.

        Returns:
            A long table with ``ts_code``, ``trade_date`` and the factor
            column.
        """
        factor_name = self.get_factor_name()

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            factor = self.calc_factor(group_df, **calc_kwargs)
            # Force the column name: without the alias, a Series carrying any
            # other name would either overwrite an unrelated input column or
            # make the select below fail on a missing factor column.
            return group_df.with_columns(factor.alias(factor_name))

        return (
            df.sort(sort_keys)
            .group_by(group_key, maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )

    def _merge_factor(
        self,
        original: pl.DataFrame,
        factor_table: pl.DataFrame,
    ) -> pl.DataFrame:
        """Left-join the factor column onto ``original`` by key columns.

        Args:
            original: The frame passed to ``apply``.
            factor_table: Output of ``_sectional_roll``.

        Returns:
            ``original`` with the factor column attached, joined on
            ``['ts_code', 'trade_date']``.
        """
        factor_name = self.get_factor_name()
        keys = ['ts_code', 'trade_date']

        # If the input already carries a column with the factor's name (for
        # example when apply() is re-run on its own output), the join would
        # keep the stale values under that name and park the fresh ones under
        # a suffixed column.  Drop the stale column so the new values win.
        base = original
        if factor_name in base.columns:
            base = base.drop(factor_name)

        return base.join(
            factor_table.select([*keys, factor_name]),
            on=keys,
            how='left',
        )


# -------------------- per-stock section: group by ts_code --------------------
class StockWiseOperator(BaseOperator):
    """Abstract per-stock operator.

    Groups by ``ts_code`` and computes the factor on each stock's time series.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Sort by (ts_code, trade_date), then compute the factor per stock."""
        # Chronological order inside each stock matters for shift-like logic
        # in calc_factor, hence trade_date as the secondary sort key.
        return self._roll_by_group(
            df,
            group_key='ts_code',
            sort_keys=['ts_code', 'trade_date'],
            calc_kwargs=kwargs,
        )


# -------------------- per-date section: group by trade_date --------------------
class DateWiseOperator(BaseOperator):
    """Abstract per-date operator.

    Groups by ``trade_date`` and computes the factor on each cross-section.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Sort by (trade_date, ts_code), then compute the factor per date."""
        return self._roll_by_group(
            df,
            group_key='trade_date',
            sort_keys=['trade_date', 'ts_code'],
            calc_kwargs=kwargs,
        )