Files
NewStock/main/factor/operator_base.py

197 lines
5.8 KiB
Python
Raw Normal View History

2025-10-13 21:42:35 +08:00
"""
因子算子基础框架 - 简化版本
提供股票截面和日期截面两个基础函数
"""
import polars as pl
from typing import Callable, Any, Optional, Union
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def apply_stockwise(
df: pl.DataFrame,
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
*args,
**kwargs
) -> pl.DataFrame:
"""
在股票截面上应用算子函数
Args:
df: 输入的polars DataFrame必须包含ts_code和trade_date列
operator_func: 算子函数接收单个股票的数据和参数返回处理后的DataFrame
*args, **kwargs: 传递给算子函数的额外参数
Returns:
处理后的完整DataFrame
"""
# 验证必需列
required_cols = ['ts_code', 'trade_date']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"缺少必需列: {missing_cols}")
# 获取股票列表
stock_list = df['ts_code'].unique().to_list()
results = []
# 按股票分组处理
for ts_code in stock_list:
try:
# 获取单个股票的数据并按日期排序
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
# 应用算子函数
result_df = operator_func(stock_df, *args, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"股票 {ts_code} 处理失败: {e}")
# 失败时返回原始数据
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
results.append(stock_df)
# 合并结果并排序
if results:
return pl.concat(results).sort(['ts_code', 'trade_date'])
else:
return df
def apply_datewise(
df: pl.DataFrame,
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
*args,
**kwargs
) -> pl.DataFrame:
"""
在日期截面上应用算子函数
Args:
df: 输入的polars DataFrame必须包含ts_code和trade_date列
operator_func: 算子函数接收单个日期的数据和参数返回处理后的DataFrame
*args, **kwargs: 传递给算子函数的额外参数
Returns:
处理后的完整DataFrame
"""
# 验证必需列
required_cols = ['ts_code', 'trade_date']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"缺少必需列: {missing_cols}")
# 获取日期列表
date_list = df['trade_date'].unique().to_list()
results = []
# 按日期分组处理
for trade_date in date_list:
try:
# 获取单个日期的数据
date_df = df.filter(pl.col('trade_date') == trade_date)
# 应用算子函数
result_df = operator_func(date_df, *args, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"日期 {trade_date} 处理失败: {e}")
# 失败时返回原始数据
date_df = df.filter(pl.col('trade_date') == trade_date)
results.append(date_df)
# 合并结果并排序
if results:
return pl.concat(results).sort(['ts_code', 'trade_date'])
else:
return df
# 常用算子函数示例
def rolling_mean_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
"""
滚动均值算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算均值的列
window: 窗口大小
output_col: 输出列名默认为f'{column}_mean_{window}'
Returns:
添加均值列的DataFrame
"""
if output_col is None:
output_col = f'{column}_mean_{window}'
return df.with_columns(
pl.col(column).rolling_mean(window_size=window).alias(output_col)
)
def rolling_std_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
"""
滚动标准差算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算标准差的列
window: 窗口大小
output_col: 输出列名默认为f'{column}_std_{window}'
Returns:
添加标准差列的DataFrame
"""
if output_col is None:
output_col = f'{column}_std_{window}'
return df.with_columns(
pl.col(column).rolling_std(window_size=window).alias(output_col)
)
def rank_operator(df: pl.DataFrame, column: str, ascending: bool = True, output_col: str = None) -> pl.DataFrame:
"""
排名算子 - 日期截面
Args:
df: 单个日期的数据
column: 要排名的列
ascending: 是否升序
output_col: 输出列名默认为f'{column}_rank'
Returns:
添加排名列的DataFrame
"""
if output_col is None:
output_col = f'{column}_rank'
return df.with_columns(
pl.col(column).rank(method='dense', descending=not ascending).alias(output_col)
)
def pct_change_operator(df: pl.DataFrame, column: str, periods: int = 1, output_col: str = None) -> pl.DataFrame:
"""
百分比变化算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算变化的列
periods: 期数
output_col: 输出列名默认为f'{column}_pct_change_{periods}'
Returns:
添加变化率列的DataFrame
"""
if output_col is None:
output_col = f'{column}_pct_change_{periods}'
return df.with_columns(
((pl.col(column) / pl.col(column).shift(periods)) - 1).alias(output_col)
)