"""Factor operator framework.

Unified factor computation implemented with Polars. Operators work on
slices of the data (one stock or one trade date at a time) in order to
avoid data leakage.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import polars as pl

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class OperatorConfig:
    """Static description of an operator."""

    # Operator name; used in log messages.
    name: str
    # Human-readable description.
    description: str
    # Columns that must be present in the input frame.
    required_columns: List[str]
    # Columns the operator adds to the frame.
    output_columns: List[str]
    # Free-form parameters of the operator.
    parameters: Dict[str, Any]


class DataSlice:
    """Base class giving per-stock / per-date views of a frame."""

    def __init__(self, df: pl.DataFrame):
        """Store *df* and validate it immediately.

        Raises:
            ValueError: if ``ts_code`` or ``trade_date`` is missing.
        """
        self.df = df
        self.validate_data()

    def validate_data(self):
        """Check that the frame has the ``ts_code`` and ``trade_date`` columns.

        Raises:
            ValueError: listing the missing columns.
        """
        required_cols = ['ts_code', 'trade_date']
        missing_cols = [col for col in required_cols if col not in self.df.columns]
        if missing_cols:
            raise ValueError(f"缺少必需列: {missing_cols}")

    def get_stock_slice(self, ts_code: str) -> pl.DataFrame:
        """Return the rows of one stock, sorted by ``trade_date``."""
        return self.df.filter(pl.col('ts_code') == ts_code).sort('trade_date')

    def get_date_slice(self, trade_date: str) -> pl.DataFrame:
        """Return the rows of one trade date."""
        return self.df.filter(pl.col('trade_date') == trade_date)

    def get_stock_list(self) -> List[str]:
        """Return the distinct ``ts_code`` values."""
        return self.df['ts_code'].unique().to_list()

    def get_date_list(self) -> List[str]:
        """Return the distinct ``trade_date`` values."""
        return self.df['trade_date'].unique().to_list()


class BaseOperator(ABC):
    """Base class of all operators.

    Calling an operator never raises for missing columns or for errors
    raised by :meth:`apply`: in both cases the input frame is returned
    with the output columns added as all-null columns.
    """

    def __init__(self, config: OperatorConfig):
        """Copy the name and column lists from *config*."""
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True if *df* has every required column.

        Logs a warning (and returns False) when columns are missing.
        """
        missing_cols = [col for col in self.required_columns if col not in df.columns]
        if missing_cols:
            logger.warning(f"算子 {self.name} 缺少必需列: {missing_cols}")
            return False
        return True

    def _with_null_outputs(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return *df* with every output column added as an all-null column."""
        for col in self.output_columns:
            df = df.with_columns(pl.lit(None).alias(col))
        return df

    @abstractmethod
    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply the operator to *df* and return the resulting frame."""
        pass

    def __call__(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Validate *df*, then apply the operator.

        Returns:
            The result of :meth:`apply`, or *df* with all-null output
            columns when validation fails or :meth:`apply` raises.
        """
        if not self.validate_input(df):
            # Keep the original data and add null output columns.
            return self._with_null_outputs(df)

        try:
            return self.apply(df, **kwargs)
        except Exception as e:
            # Best effort: log and fall back to null output columns.
            logger.error(f"算子 {self.name} 应用失败: {e}")
            return self._with_null_outputs(df)


class StockWiseOperator(BaseOperator):
    """Stock-slice operator: computes stock by stock."""

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_stock` to each stock and concatenate the results.

        Each stock's rows are sorted by ``trade_date`` before the call. A
        stock whose computation raises is logged and kept with all-null
        output columns. The result is sorted by ``ts_code``, ``trade_date``.
        """
        results = []
        for ts_code in df['ts_code'].unique().to_list():
            stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
            try:
                results.append(self.apply_stock(stock_df, **kwargs))
            except Exception as e:
                logger.error(f"股票 {ts_code} 算子应用失败: {e}")
                # Keep the failed stock, with null output columns.
                results.append(self._with_null_outputs(stock_df))

        if not results:
            # Nothing to concatenate (empty input): pl.concat([]) would raise.
            return self._with_null_outputs(df)
        return pl.concat(results).sort(['ts_code', 'trade_date'])

    @abstractmethod
    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply the operator to the rows of a single stock."""
        pass


class DateWiseOperator(BaseOperator):
    """Date-slice operator: computes date by date."""

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_date` to each trade date and concatenate the results.

        A date whose computation raises is logged and kept with all-null
        output columns. The result is sorted by ``ts_code``, ``trade_date``.
        """
        results = []
        for trade_date in df['trade_date'].unique().to_list():
            date_df = df.filter(pl.col('trade_date') == trade_date)
            try:
                results.append(self.apply_date(date_df, **kwargs))
            except Exception as e:
                logger.error(f"日期 {trade_date} 算子应用失败: {e}")
                # Keep the failed date, with null output columns.
                results.append(self._with_null_outputs(date_df))

        if not results:
            # Nothing to concatenate (empty input): pl.concat([]) would raise.
            return self._with_null_outputs(df)
        return pl.concat(results).sort(['ts_code', 'trade_date'])

    @abstractmethod
    def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply the operator to the rows of a single trade date."""
        pass


class RollingOperator(StockWiseOperator):
    """Base class of rolling-window operators.

    ``min_periods`` is stored on the instance; the rolling operators
    defined in this section do not forward it to Polars.
    """

    def __init__(self, config: OperatorConfig, window: int, min_periods: Optional[int] = None):
        """Store the window size and the minimum number of periods.

        When *min_periods* is ``None`` (or 0) it falls back to
        ``max(1, window // 2)``.
        """
        super().__init__(config)
        self.window = window
        self.min_periods = min_periods or max(1, window // 2)

    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Delegate the per-stock computation to :meth:`apply_rolling`."""
        return self.apply_rolling(stock_df, **kwargs)

    @abstractmethod
    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Rolling-window computation for the rows of a single stock."""
        pass


# Basic operator implementations


class ReturnOperator(RollingOperator):
    """Return over ``periods`` rows, computed from ``close``."""

    def __init__(self, periods: int = 1):
        config = OperatorConfig(
            name=f"return_{periods}",
            description=f"{periods}期收益率",
            required_columns=['close'],
            output_columns=[f'return_{periods}'],
            parameters={'periods': periods}
        )
        super().__init__(config, window=periods + 1)
        self.periods = periods

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``return_{periods}`` = close / close.shift(periods) - 1."""
        return stock_df.with_columns(
            (pl.col('close') / pl.col('close').shift(self.periods) - 1).alias(f'return_{self.periods}')
        )


class VolatilityOperator(RollingOperator):
    """Rolling standard deviation of ``pct_chg``."""

    def __init__(self, window: int = 20):
        config = OperatorConfig(
            name=f"volatility_{window}",
            description=f"{window}日波动率",
            required_columns=['pct_chg'],
            output_columns=[f'volatility_{window}'],
            parameters={'window': window}
        )
        super().__init__(config, window=window)

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``volatility_{window}`` as the rolling std of ``pct_chg``."""
        # The window is passed positionally: Polars names this parameter
        # `window_size`, so a `window=` keyword raises TypeError.
        return stock_df.with_columns(
            pl.col('pct_chg').rolling_std(self.window).alias(f'volatility_{self.window}')
        )


class MeanOperator(RollingOperator):
    """Rolling mean of an arbitrary column."""

    def __init__(self, column: str, window: int):
        config = OperatorConfig(
            name=f"mean_{column}_{window}",
            description=f"{column}的{window}日均值",
            required_columns=[column],
            output_columns=[f'mean_{column}_{window}'],
            parameters={'column': column, 'window': window}
        )
        super().__init__(config, window=window)
        self.column = column

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``mean_{column}_{window}`` as the rolling mean of ``column``."""
        # The window is passed positionally: Polars names this parameter
        # `window_size`, so a `window=` keyword raises TypeError.
        return stock_df.with_columns(
            pl.col(self.column).rolling_mean(self.window).alias(f'mean_{self.column}_{self.window}')
        )


class RankOperator(DateWiseOperator):
    """Rank operator."""

    def __init__(self, column: str, ascending: bool = True):
        config = OperatorConfig(
            name=f"rank_{column}",
            description=f"{column}的排名",
            required_columns=[column],
            output_columns=[f'rank_{column}'],
            parameters={'column': column, 'ascending': ascending}
        )
        super().__init__(config)
        self.column = column
        self.ascending = ascending