197 lines
5.8 KiB
Python
197 lines
5.8 KiB
Python
"""
|
||
因子算子基础框架 - 简化版本
|
||
提供股票截面和日期截面两个基础函数
|
||
"""
|
||
|
||
import polars as pl
|
||
from typing import Callable, Any, Optional, Union
|
||
import logging
|
||
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def apply_stockwise(
|
||
df: pl.DataFrame,
|
||
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
|
||
*args,
|
||
**kwargs
|
||
) -> pl.DataFrame:
|
||
"""
|
||
在股票截面上应用算子函数
|
||
|
||
Args:
|
||
df: 输入的polars DataFrame,必须包含ts_code和trade_date列
|
||
operator_func: 算子函数,接收单个股票的数据和参数,返回处理后的DataFrame
|
||
*args, **kwargs: 传递给算子函数的额外参数
|
||
|
||
Returns:
|
||
处理后的完整DataFrame
|
||
"""
|
||
# 验证必需列
|
||
required_cols = ['ts_code', 'trade_date']
|
||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||
if missing_cols:
|
||
raise ValueError(f"缺少必需列: {missing_cols}")
|
||
|
||
# 获取股票列表
|
||
stock_list = df['ts_code'].unique().to_list()
|
||
results = []
|
||
|
||
# 按股票分组处理
|
||
for ts_code in stock_list:
|
||
try:
|
||
# 获取单个股票的数据并按日期排序
|
||
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
|
||
|
||
# 应用算子函数
|
||
result_df = operator_func(stock_df, *args, **kwargs)
|
||
results.append(result_df)
|
||
|
||
except Exception as e:
|
||
logger.error(f"股票 {ts_code} 处理失败: {e}")
|
||
# 失败时返回原始数据
|
||
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
|
||
results.append(stock_df)
|
||
|
||
# 合并结果并排序
|
||
if results:
|
||
return pl.concat(results).sort(['ts_code', 'trade_date'])
|
||
else:
|
||
return df
|
||
|
||
|
||
def apply_datewise(
|
||
df: pl.DataFrame,
|
||
operator_func: Callable[[pl.DataFrame, Any], pl.DataFrame],
|
||
*args,
|
||
**kwargs
|
||
) -> pl.DataFrame:
|
||
"""
|
||
在日期截面上应用算子函数
|
||
|
||
Args:
|
||
df: 输入的polars DataFrame,必须包含ts_code和trade_date列
|
||
operator_func: 算子函数,接收单个日期的数据和参数,返回处理后的DataFrame
|
||
*args, **kwargs: 传递给算子函数的额外参数
|
||
|
||
Returns:
|
||
处理后的完整DataFrame
|
||
"""
|
||
# 验证必需列
|
||
required_cols = ['ts_code', 'trade_date']
|
||
missing_cols = [col for col in required_cols if col not in df.columns]
|
||
if missing_cols:
|
||
raise ValueError(f"缺少必需列: {missing_cols}")
|
||
|
||
# 获取日期列表
|
||
date_list = df['trade_date'].unique().to_list()
|
||
results = []
|
||
|
||
# 按日期分组处理
|
||
for trade_date in date_list:
|
||
try:
|
||
# 获取单个日期的数据
|
||
date_df = df.filter(pl.col('trade_date') == trade_date)
|
||
|
||
# 应用算子函数
|
||
result_df = operator_func(date_df, *args, **kwargs)
|
||
results.append(result_df)
|
||
|
||
except Exception as e:
|
||
logger.error(f"日期 {trade_date} 处理失败: {e}")
|
||
# 失败时返回原始数据
|
||
date_df = df.filter(pl.col('trade_date') == trade_date)
|
||
results.append(date_df)
|
||
|
||
# 合并结果并排序
|
||
if results:
|
||
return pl.concat(results).sort(['ts_code', 'trade_date'])
|
||
else:
|
||
return df
|
||
|
||
|
||
# 常用算子函数示例
|
||
def rolling_mean_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
|
||
"""
|
||
滚动均值算子 - 股票截面
|
||
|
||
Args:
|
||
df: 单个股票的数据
|
||
column: 要计算均值的列
|
||
window: 窗口大小
|
||
output_col: 输出列名,默认为f'{column}_mean_{window}'
|
||
|
||
Returns:
|
||
添加均值列的DataFrame
|
||
"""
|
||
if output_col is None:
|
||
output_col = f'{column}_mean_{window}'
|
||
|
||
return df.with_columns(
|
||
pl.col(column).rolling_mean(window_size=window).alias(output_col)
|
||
)
|
||
|
||
|
||
def rolling_std_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
|
||
"""
|
||
滚动标准差算子 - 股票截面
|
||
|
||
Args:
|
||
df: 单个股票的数据
|
||
column: 要计算标准差的列
|
||
window: 窗口大小
|
||
output_col: 输出列名,默认为f'{column}_std_{window}'
|
||
|
||
Returns:
|
||
添加标准差列的DataFrame
|
||
"""
|
||
if output_col is None:
|
||
output_col = f'{column}_std_{window}'
|
||
|
||
return df.with_columns(
|
||
pl.col(column).rolling_std(window_size=window).alias(output_col)
|
||
)
|
||
|
||
|
||
def rank_operator(df: pl.DataFrame, column: str, ascending: bool = True, output_col: str = None) -> pl.DataFrame:
|
||
"""
|
||
排名算子 - 日期截面
|
||
|
||
Args:
|
||
df: 单个日期的数据
|
||
column: 要排名的列
|
||
ascending: 是否升序
|
||
output_col: 输出列名,默认为f'{column}_rank'
|
||
|
||
Returns:
|
||
添加排名列的DataFrame
|
||
"""
|
||
if output_col is None:
|
||
output_col = f'{column}_rank'
|
||
|
||
return df.with_columns(
|
||
pl.col(column).rank(method='dense', descending=not ascending).alias(output_col)
|
||
)
|
||
|
||
|
||
def pct_change_operator(df: pl.DataFrame, column: str, periods: int = 1, output_col: str = None) -> pl.DataFrame:
|
||
"""
|
||
百分比变化算子 - 股票截面
|
||
|
||
Args:
|
||
df: 单个股票的数据
|
||
column: 要计算变化的列
|
||
periods: 期数
|
||
output_col: 输出列名,默认为f'{column}_pct_change_{periods}'
|
||
|
||
Returns:
|
||
添加变化率列的DataFrame
|
||
"""
|
||
if output_col is None:
|
||
output_col = f'{column}_pct_change_{periods}'
|
||
|
||
return df.with_columns(
|
||
((pl.col(column) / pl.col(column).shift(periods)) - 1).alias(output_col)
|
||
)
|