""" 因子算子基础框架 - 简化版本 提供股票截面和日期截面两个基础函数 """ import polars as pl from typing import Callable, Any, Optional, Union import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def apply_stockwise( df: pl.DataFrame, operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame], *args, **kwargs ) -> pl.DataFrame: """ 在股票截面上应用算子函数 Args: df: 输入的polars DataFrame,必须包含ts_code和trade_date列 operator_func: 算子函数,接收单个股票的数据和参数,返回处理后的DataFrame *args, **kwargs: 传递给算子函数的额外参数 Returns: 处理后的完整DataFrame """ # 验证必需列 required_cols = ['ts_code', 'trade_date'] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: raise ValueError(f"缺少必需列: {missing_cols}") # 获取股票列表 stock_list = df['ts_code'].unique().to_list() results = [] # 按股票分组处理 for ts_code in stock_list: try: # 获取单个股票的数据并按日期排序 stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date') # 应用算子函数 result_df = operator_func(stock_df, *args, **kwargs) results.append(result_df) except Exception as e: logger.error(f"股票 {ts_code} 处理失败: {e}") # 失败时返回原始数据 stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date') results.append(stock_df) # 合并结果并排序 if results: return pl.concat(results).sort(['ts_code', 'trade_date']) else: return df def apply_datewise( df: pl.DataFrame, operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame], *args, **kwargs ) -> pl.DataFrame: """ 在日期截面上应用算子函数 Args: df: 输入的polars DataFrame,必须包含ts_code和trade_date列 operator_func: 算子函数,接收单个日期的数据和参数,返回处理后的DataFrame *args, **kwargs: 传递给算子函数的额外参数 Returns: 处理后的完整DataFrame """ # 验证必需列 required_cols = ['ts_code', 'trade_date'] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: raise ValueError(f"缺少必需列: {missing_cols}") # 获取日期列表 date_list = df['trade_date'].unique().to_list() results = [] # 按日期分组处理 for trade_date in date_list: try: # 获取单个日期的数据 date_df = df.filter(pl.col('trade_date') == trade_date) # 应用算子函数 result_df = operator_func(date_df, *args, **kwargs) results.append(result_df) except Exception as e: logger.error(f"日期 {trade_date} 处理失败: {e}") # 失败时返回原始数据 date_df = df.filter(pl.col('trade_date') == trade_date) results.append(date_df) # 合并结果并排序 if results: return pl.concat(results).sort(['ts_code', 'trade_date']) else: return df # 常用算子函数示例 def rolling_mean_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame: """ 滚动均值算子 - 股票截面 Args: df: 单个股票的数据 column: 要计算均值的列 window: 窗口大小 output_col: 输出列名,默认为f'{column}_mean_{window}' Returns: 添加均值列的DataFrame """ if output_col is None: output_col = f'{column}_mean_{window}' return df.with_columns( pl.col(column).rolling_mean(window_size=window).alias(output_col) ) def rolling_std_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame: """ 滚动标准差算子 - 股票截面 Args: df: 单个股票的数据 column: 要计算标准差的列 window: 窗口大小 output_col: 输出列名,默认为f'{column}_std_{window}' Returns: 添加标准差列的DataFrame """ if output_col is None: output_col = f'{column}_std_{window}' return df.with_columns( pl.col(column).rolling_std(window_size=window).alias(output_col) ) def rank_operator(df: pl.DataFrame, column: str, ascending: bool = True, output_col: str = None) -> pl.DataFrame: """ 排名算子 - 日期截面 Args: df: 单个日期的数据 column: 要排名的列 ascending: 是否升序 output_col: 输出列名,默认为f'{column}_rank' Returns: 添加排名列的DataFrame """ if output_col is None: output_col = f'{column}_rank' return df.with_columns( pl.col(column).rank(method='dense', descending=not ascending).alias(output_col) ) def pct_change_operator(df: pl.DataFrame, column: str, periods: int = 1, output_col: str = None) -> pl.DataFrame: """ 百分比变化算子 - 股票截面 Args: df: 单个股票的数据 column: 要计算变化的列 periods: 期数 output_col: 输出列名,默认为f'{column}_pct_change_{periods}' Returns: 添加变化率列的DataFrame """ if output_col is None: output_col = f'{column}_pct_change_{periods}' return df.with_columns( ((pl.col(column) / pl.col(column).shift(periods)) - 1).alias(output_col) )