Files
NewStock/main/factor/operator_base.py

197 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
因子算子基础框架 - 简化版本
提供股票截面和日期截面两个基础函数
"""
import polars as pl
from typing import Callable, Any, Optional, Union
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def apply_stockwise(
df: pl.DataFrame,
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
*args,
**kwargs
) -> pl.DataFrame:
"""
在股票截面上应用算子函数
Args:
df: 输入的polars DataFrame必须包含ts_code和trade_date列
operator_func: 算子函数接收单个股票的数据和参数返回处理后的DataFrame
*args, **kwargs: 传递给算子函数的额外参数
Returns:
处理后的完整DataFrame
"""
# 验证必需列
required_cols = ['ts_code', 'trade_date']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"缺少必需列: {missing_cols}")
# 获取股票列表
stock_list = df['ts_code'].unique().to_list()
results = []
# 按股票分组处理
for ts_code in stock_list:
try:
# 获取单个股票的数据并按日期排序
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
# 应用算子函数
result_df = operator_func(stock_df, *args, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"股票 {ts_code} 处理失败: {e}")
# 失败时返回原始数据
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
results.append(stock_df)
# 合并结果并排序
if results:
return pl.concat(results).sort(['ts_code', 'trade_date'])
else:
return df
def apply_datewise(
df: pl.DataFrame,
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
*args,
**kwargs
) -> pl.DataFrame:
"""
在日期截面上应用算子函数
Args:
df: 输入的polars DataFrame必须包含ts_code和trade_date列
operator_func: 算子函数接收单个日期的数据和参数返回处理后的DataFrame
*args, **kwargs: 传递给算子函数的额外参数
Returns:
处理后的完整DataFrame
"""
# 验证必需列
required_cols = ['ts_code', 'trade_date']
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"缺少必需列: {missing_cols}")
# 获取日期列表
date_list = df['trade_date'].unique().to_list()
results = []
# 按日期分组处理
for trade_date in date_list:
try:
# 获取单个日期的数据
date_df = df.filter(pl.col('trade_date') == trade_date)
# 应用算子函数
result_df = operator_func(date_df, *args, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"日期 {trade_date} 处理失败: {e}")
# 失败时返回原始数据
date_df = df.filter(pl.col('trade_date') == trade_date)
results.append(date_df)
# 合并结果并排序
if results:
return pl.concat(results).sort(['ts_code', 'trade_date'])
else:
return df
# 常用算子函数示例
def rolling_mean_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
"""
滚动均值算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算均值的列
window: 窗口大小
output_col: 输出列名默认为f'{column}_mean_{window}'
Returns:
添加均值列的DataFrame
"""
if output_col is None:
output_col = f'{column}_mean_{window}'
return df.with_columns(
pl.col(column).rolling_mean(window_size=window).alias(output_col)
)
def rolling_std_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
"""
滚动标准差算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算标准差的列
window: 窗口大小
output_col: 输出列名默认为f'{column}_std_{window}'
Returns:
添加标准差列的DataFrame
"""
if output_col is None:
output_col = f'{column}_std_{window}'
return df.with_columns(
pl.col(column).rolling_std(window_size=window).alias(output_col)
)
def rank_operator(df: pl.DataFrame, column: str, ascending: bool = True, output_col: str = None) -> pl.DataFrame:
"""
排名算子 - 日期截面
Args:
df: 单个日期的数据
column: 要排名的列
ascending: 是否升序
output_col: 输出列名默认为f'{column}_rank'
Returns:
添加排名列的DataFrame
"""
if output_col is None:
output_col = f'{column}_rank'
return df.with_columns(
pl.col(column).rank(method='dense', descending=not ascending).alias(output_col)
)
def pct_change_operator(df: pl.DataFrame, column: str, periods: int = 1, output_col: str = None) -> pl.DataFrame:
"""
百分比变化算子 - 股票截面
Args:
df: 单个股票的数据
column: 要计算变化的列
periods: 期数
output_col: 输出列名默认为f'{column}_pct_change_{periods}'
Returns:
添加变化率列的DataFrame
"""
if output_col is None:
output_col = f'{column}_pct_change_{periods}'
return df.with_columns(
((pl.col(column) / pl.col(column).shift(periods)) - 1).alias(output_col)
)