factor优化,改为polars
This commit is contained in:
196
main/factor/operator_base.py
Normal file
196
main/factor/operator_base.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
Factor operator base framework - simplified version.

Provides two basic helpers: one that applies an operator per stock
(``apply_stockwise``) and one that applies an operator per trading date
(``apply_datewise``).
"""
|
||||
|
||||
import polars as pl
|
||||
from typing import Callable, Any, Optional, Union
|
||||
import logging
|
||||
|
||||
# NOTE: basicConfig runs at import time, so importing this module configures
# the root logger at INFO level unless the application configured it earlier.
logging.basicConfig(level=logging.INFO)
|
||||
# Module-level logger used by the apply_* helpers to report per-group failures.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_stockwise(
    df: pl.DataFrame,
    operator_func: Callable[..., pl.DataFrame],
    *args,
    **kwargs
) -> pl.DataFrame:
    """
    Apply an operator function to each stock's rows independently.

    The frame is split into one sub-frame per distinct ``ts_code``; each
    sub-frame is sorted by ``trade_date`` and handed to ``operator_func``.
    If the operator raises for a stock, the error is logged and that stock's
    sorted, unprocessed rows are kept in the output instead.

    Args:
        df: Input polars DataFrame; must contain the ``ts_code`` and
            ``trade_date`` columns.
        operator_func: Called as ``operator_func(stock_df, *args, **kwargs)``;
            expected to return a DataFrame.
        *args, **kwargs: Extra arguments forwarded to ``operator_func``.

    Returns:
        The per-stock results concatenated and sorted by ``ts_code`` then
        ``trade_date``. Stocks that fell back to their raw rows carry nulls
        in any column that only the operator output contains. A frame with
        no rows is returned unchanged.

    Raises:
        ValueError: If ``ts_code`` or ``trade_date`` is missing from ``df``.
    """
    required_cols = ['ts_code', 'trade_date']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"缺少必需列: {missing_cols}")

    # Nothing to group; also guarantees every partition below is non-empty.
    if df.height == 0:
        return df

    results = []
    # One partitioning pass instead of a full-frame filter per stock. Unlike
    # an equality filter, this also keeps rows whose ts_code is null.
    for stock_df in df.partition_by('ts_code', maintain_order=True):
        stock_df = stock_df.sort('trade_date')
        ts_code = stock_df['ts_code'][0]
        try:
            results.append(operator_func(stock_df, *args, **kwargs))
        except Exception as e:
            # Deliberate best effort: one failing stock must not abort the run.
            logger.error(f"股票 {ts_code} 处理失败: {e}")
            results.append(stock_df)

    # Fallback frames lack the columns added by the operator, so a plain
    # vertical concat would raise on the schema mismatch; a diagonal concat
    # takes the union of columns and fills the gaps with nulls.
    return pl.concat(results, how='diagonal').sort(['ts_code', 'trade_date'])
|
||||
|
||||
|
||||
def apply_datewise(
    df: pl.DataFrame,
    operator_func: Callable[..., pl.DataFrame],
    *args,
    **kwargs
) -> pl.DataFrame:
    """
    Apply an operator function to each trading date's rows independently.

    The frame is split into one sub-frame per distinct ``trade_date`` and each
    sub-frame is handed to ``operator_func``. If the operator raises for a
    date, the error is logged and that date's unprocessed rows are kept in
    the output instead.

    Args:
        df: Input polars DataFrame; must contain the ``ts_code`` and
            ``trade_date`` columns.
        operator_func: Called as ``operator_func(date_df, *args, **kwargs)``;
            expected to return a DataFrame.
        *args, **kwargs: Extra arguments forwarded to ``operator_func``.

    Returns:
        The per-date results concatenated and sorted by ``ts_code`` then
        ``trade_date``. Dates that fell back to their raw rows carry nulls
        in any column that only the operator output contains. A frame with
        no rows is returned unchanged.

    Raises:
        ValueError: If ``ts_code`` or ``trade_date`` is missing from ``df``.
    """
    required_cols = ['ts_code', 'trade_date']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"缺少必需列: {missing_cols}")

    # Nothing to group; also guarantees every partition below is non-empty.
    if df.height == 0:
        return df

    results = []
    # One partitioning pass instead of a full-frame filter per date. Unlike
    # an equality filter, this also keeps rows whose trade_date is null.
    for date_df in df.partition_by('trade_date', maintain_order=True):
        trade_date = date_df['trade_date'][0]
        try:
            results.append(operator_func(date_df, *args, **kwargs))
        except Exception as e:
            # Deliberate best effort: one failing date must not abort the run.
            logger.error(f"日期 {trade_date} 处理失败: {e}")
            results.append(date_df)

    # Fallback frames lack the columns added by the operator, so a plain
    # vertical concat would raise on the schema mismatch; a diagonal concat
    # takes the union of columns and fills the gaps with nulls.
    return pl.concat(results, how='diagonal').sort(['ts_code', 'trade_date'])
|
||||
|
||||
|
||||
# Examples of commonly used operator functions
|
||||
def rolling_mean_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
    """
    Rolling-mean operator, intended for per-stock use.

    Args:
        df: Rows to process; expected to belong to a single stock.
        column: Name of the column to average.
        window: Rolling window size, passed as ``window_size``.
        output_col: Name of the result column; defaults to
            ``f'{column}_mean_{window}'`` when None.

    Returns:
        ``df`` with the rolling-mean column added.
    """
    target_name = f'{column}_mean_{window}' if output_col is None else output_col
    mean_expr = pl.col(column).rolling_mean(window_size=window)
    return df.with_columns(mean_expr.alias(target_name))
|
||||
|
||||
|
||||
def rolling_std_operator(df: pl.DataFrame, column: str, window: int, output_col: str = None) -> pl.DataFrame:
    """
    Rolling-standard-deviation operator, intended for per-stock use.

    Args:
        df: Rows to process; expected to belong to a single stock.
        column: Name of the column whose standard deviation is computed.
        window: Rolling window size, passed as ``window_size``.
        output_col: Name of the result column; defaults to
            ``f'{column}_std_{window}'`` when None.

    Returns:
        ``df`` with the rolling-standard-deviation column added.
    """
    target_name = f'{column}_std_{window}' if output_col is None else output_col
    std_expr = pl.col(column).rolling_std(window_size=window)
    return df.with_columns(std_expr.alias(target_name))
|
||||
|
||||
|
||||
def rank_operator(df: pl.DataFrame, column: str, ascending: bool = True, output_col: str = None) -> pl.DataFrame:
    """
    Ranking operator, intended for per-date (cross-sectional) use.

    Args:
        df: Rows to process; expected to belong to a single trading date.
        column: Name of the column to rank.
        ascending: Rank in ascending order when True, descending when False.
        output_col: Name of the result column; defaults to
            ``f'{column}_rank'`` when None.

    Returns:
        ``df`` with the dense-rank column added.
    """
    target_name = f'{column}_rank' if output_col is None else output_col
    # The polars API takes a "descending" flag, i.e. the inverse of ours.
    rank_expr = pl.col(column).rank(method='dense', descending=not ascending)
    return df.with_columns(rank_expr.alias(target_name))
|
||||
|
||||
|
||||
def pct_change_operator(df: pl.DataFrame, column: str, periods: int = 1, output_col: str = None) -> pl.DataFrame:
    """
    Percentage-change operator, intended for per-stock use.

    Args:
        df: Rows to process; expected to belong to a single stock.
        column: Name of the column whose change is computed.
        periods: Number of rows to shift by when picking the base value.
        output_col: Name of the result column; defaults to
            ``f'{column}_pct_change_{periods}'`` when None.

    Returns:
        ``df`` with the change-rate column added, computed as
        ``value / shifted_value - 1``.
    """
    target_name = f'{column}_pct_change_{periods}' if output_col is None else output_col
    ratio_expr = pl.col(column) / pl.col(column).shift(periods)
    return df.with_columns((ratio_expr - 1).alias(target_name))
|
||||
Reference in New Issue
Block a user