# 251 lines · 8.4 KiB · Python
"""
Factor operator framework - unified factor computation implemented with Polars.

Avoids data leakage and supports slice-wise computation.
"""
|
||
|
||
import polars as pl
|
||
import numpy as np
|
||
from typing import Dict, List, Callable, Optional, Union, Any
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
import logging
|
||
|
||
# Logging setup: configure the root logger at INFO level, then create the
# module-level logger used by every operator in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class OperatorConfig:
    """Plain-data description of an operator.

    The dataclass performs no validation; it only groups the values that
    ``BaseOperator.__init__`` reads.
    """

    # Operator identifier; BaseOperator mirrors it as ``self.name``.
    name: str
    # Free-text description of what the operator computes.
    description: str
    # Column names that must be present in the input frame.
    required_columns: List[str]
    # Column names the operator is expected to add.
    output_columns: List[str]
    # Arbitrary parameter mapping recorded alongside the operator.
    parameters: Dict[str, Any]
|
||
|
||
|
||
class DataSlice:
    """Base class wrapping a frame and exposing per-stock / per-date views."""

    def __init__(self, df: pl.DataFrame):
        """Store ``df`` and validate it immediately.

        Raises:
            ValueError: if ``ts_code`` or ``trade_date`` is missing.
        """
        self.df = df
        self.validate_data()

    def validate_data(self):
        """Check that the wrapped frame carries both key columns.

        Raises:
            ValueError: listing whichever of ``ts_code`` / ``trade_date``
                is absent.
        """
        present = self.df.columns
        missing_cols = [
            name for name in ('ts_code', 'trade_date') if name not in present
        ]
        if missing_cols:
            raise ValueError(f"缺少必需列: {missing_cols}")

    def get_stock_slice(self, ts_code: str) -> pl.DataFrame:
        """Return the rows of one stock, sorted by ``trade_date``."""
        is_requested_stock = pl.col('ts_code') == ts_code
        stock_rows = self.df.filter(is_requested_stock)
        return stock_rows.sort('trade_date')

    def get_date_slice(self, trade_date: str) -> pl.DataFrame:
        """Return the rows of one trade date (no sorting is applied)."""
        is_requested_date = pl.col('trade_date') == trade_date
        return self.df.filter(is_requested_date)

    def get_stock_list(self) -> List[str]:
        """Return the distinct ``ts_code`` values as a Python list."""
        codes = self.df['ts_code']
        return codes.unique().to_list()

    def get_date_list(self) -> List[str]:
        """Return the distinct ``trade_date`` values as a Python list."""
        dates = self.df['trade_date']
        return dates.unique().to_list()
|
||
|
||
|
||
class BaseOperator(ABC):
    """Abstract base class for operators.

    Subclasses implement :meth:`apply`. Calling the instance validates the
    input first and, on missing columns or on any exception raised by
    ``apply``, returns the input frame with the output columns filled with
    nulls instead of raising.
    """

    def __init__(self, config: OperatorConfig):
        """Keep ``config`` and mirror its most-used fields as attributes."""
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True when ``df`` contains every required column.

        Missing columns are reported through a warning log entry, not an
        exception.
        """
        missing_cols = [col for col in self.required_columns if col not in df.columns]
        if missing_cols:
            logger.warning(f"算子 {self.name} 缺少必需列: {missing_cols}")
            return False
        return True

    def _with_null_outputs(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with each configured output column set to null."""
        for col in self.output_columns:
            df = df.with_columns(pl.lit(None).alias(col))
        return df

    @abstractmethod
    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator on ``df``; implemented by subclasses."""

    def __call__(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run the operator with validation and a null-filled fallback.

        Returns:
            The result of :meth:`apply`, or ``df`` plus null output columns
            when validation fails or ``apply`` raises.
        """
        if not self.validate_input(df):
            return self._with_null_outputs(df)

        try:
            return self.apply(df, **kwargs)
        except Exception as e:
            # Deliberate best-effort boundary: a failing operator must not
            # abort the caller. logger.exception keeps the traceback, which
            # the previous logger.error call discarded.
            logger.exception(f"算子 {self.name} 应用失败: {e}")
            return self._with_null_outputs(df)
|
||
|
||
|
||
class StockWiseOperator(BaseOperator):
    """Stock-slice operator: evaluates each ``ts_code`` group separately."""

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_stock` to every stock and stitch the results.

        Each group is sorted by ``trade_date`` before being handed to
        ``apply_stock``. A group whose computation raises is kept, with its
        output columns filled with nulls, and the error is logged.

        Returns:
            All groups concatenated and sorted by ``ts_code``, ``trade_date``.
            For an empty ``df`` the input is returned with null output
            columns (``pl.concat`` cannot take an empty list).
        """
        if df.height == 0:
            return self._null_outputs(df)

        results = []
        # One pass over the frame instead of one full-frame filter per stock.
        # NOTE(review): rows whose ts_code is null should form their own group
        # here, whereas the equality filter never matched them - confirm.
        for stock_df in df.partition_by('ts_code'):
            stock_df = stock_df.sort('trade_date')
            try:
                results.append(self.apply_stock(stock_df, **kwargs))
            except Exception as e:
                ts_code = stock_df['ts_code'][0]
                logger.error(f"股票 {ts_code} 算子应用失败: {e}")
                results.append(self._null_outputs(stock_df))

        # The null fallback columns have no concrete dtype, so let concat
        # coerce them to the dtype produced by the successful groups.
        combined = pl.concat(results, how='vertical_relaxed')
        return combined.sort(['ts_code', 'trade_date'])

    def _null_outputs(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Return ``frame`` with each configured output column set to null."""
        for col in self.output_columns:
            frame = frame.with_columns(pl.lit(None).alias(col))
        return frame

    @abstractmethod
    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator for one stock's date-sorted rows."""
|
||
|
||
|
||
class DateWiseOperator(BaseOperator):
    """Date-slice operator: evaluates each ``trade_date`` group separately."""

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_date` to every trade date and stitch the results.

        A date whose computation raises is kept, with its output columns
        filled with nulls, and the error is logged.

        Returns:
            All groups concatenated and sorted by ``ts_code``, ``trade_date``.
            For an empty ``df`` the input is returned with null output
            columns (``pl.concat`` cannot take an empty list).
        """
        if df.height == 0:
            return self._null_outputs(df)

        results = []
        # One pass over the frame instead of one full-frame filter per date.
        # NOTE(review): rows whose trade_date is null should form their own
        # group here, whereas the equality filter never matched them - confirm.
        for date_df in df.partition_by('trade_date'):
            try:
                results.append(self.apply_date(date_df, **kwargs))
            except Exception as e:
                trade_date = date_df['trade_date'][0]
                logger.error(f"日期 {trade_date} 算子应用失败: {e}")
                results.append(self._null_outputs(date_df))

        # The null fallback columns have no concrete dtype, so let concat
        # coerce them to the dtype produced by the successful groups.
        combined = pl.concat(results, how='vertical_relaxed')
        return combined.sort(['ts_code', 'trade_date'])

    def _null_outputs(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Return ``frame`` with each configured output column set to null."""
        for col in self.output_columns:
            frame = frame.with_columns(pl.lit(None).alias(col))
        return frame

    @abstractmethod
    def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator for the rows of a single trade date."""
|
||
|
||
|
||
class RollingOperator(StockWiseOperator):
    """Base class for rolling-window operators evaluated per stock."""

    def __init__(self, config: OperatorConfig, window: int, min_periods: Optional[int] = None):
        """Store the window settings.

        Args:
            config: Operator description passed to the parent class.
            window: Window length stored as ``self.window``.
            min_periods: Stored as ``self.min_periods``. When omitted it
                defaults to ``max(1, window // 2)``.
        """
        super().__init__(config)
        self.window = window
        # Test for None explicitly: `min_periods or ...` replaced an explicit
        # 0 with the default.
        if min_periods is None:
            min_periods = max(1, window // 2)
        self.min_periods = min_periods

    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Delegate the per-stock computation to :meth:`apply_rolling`."""
        return self.apply_rolling(stock_df, **kwargs)

    @abstractmethod
    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Rolling-window logic for one stock; implemented by subclasses."""
|
||
|
||
|
||
# 基础算子实现
|
||
class ReturnOperator(RollingOperator):
    """Return operator: ``close`` divided by its shifted value, minus one."""

    def __init__(self, periods: int = 1):
        """Build the config for a ``periods``-row return.

        The parent window is set to ``periods + 1``.
        """
        label = f"return_{periods}"
        super().__init__(
            OperatorConfig(
                name=label,
                description=f"{periods}期收益率",
                required_columns=['close'],
                output_columns=[label],
                parameters={'periods': periods},
            ),
            window=periods + 1,
        )
        self.periods = periods

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``return_<periods>`` computed from ``close`` and its shift."""
        close = pl.col('close')
        shifted_close = close.shift(self.periods)
        simple_return = close / shifted_close - 1
        return stock_df.with_columns(
            simple_return.alias(f'return_{self.periods}')
        )
|
||
|
||
|
||
class VolatilityOperator(RollingOperator):
    """Volatility operator: rolling standard deviation of ``pct_chg``."""

    def __init__(self, window: int = 20):
        """Build the config for a ``window``-row rolling volatility."""
        config = OperatorConfig(
            name=f"volatility_{window}",
            description=f"{window}日波动率",
            required_columns=['pct_chg'],
            output_columns=[f'volatility_{window}'],
            parameters={'window': window}
        )
        super().__init__(config, window=window)

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``volatility_<window>`` as the rolling std of ``pct_chg``.

        The Polars keyword is ``window_size``; the previous ``window=`` raised
        ``TypeError`` on every call, which the per-stock error handler turned
        into an all-null column.
        """
        # NOTE(review): self.min_periods is not forwarded, so the Polars
        # default applies (per its docs, a full window is required) - confirm
        # whether partial windows are wanted.
        return stock_df.with_columns(
            pl.col('pct_chg')
            .rolling_std(window_size=self.window)
            .alias(f'volatility_{self.window}')
        )
|
||
|
||
|
||
class MeanOperator(RollingOperator):
    """Mean operator: rolling mean of a chosen column."""

    def __init__(self, column: str, window: int):
        """Build the config for a ``window``-row rolling mean of ``column``."""
        config = OperatorConfig(
            name=f"mean_{column}_{window}",
            description=f"{column}的{window}日均值",
            required_columns=[column],
            output_columns=[f'mean_{column}_{window}'],
            parameters={'column': column, 'window': window}
        )
        super().__init__(config, window=window)
        self.column = column

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``mean_<column>_<window>`` as the rolling mean of ``column``.

        The Polars keyword is ``window_size``; the previous ``window=`` raised
        ``TypeError`` on every call, which the per-stock error handler turned
        into an all-null column.
        """
        # NOTE(review): self.min_periods is not forwarded, so the Polars
        # default applies (per its docs, a full window is required) - confirm
        # whether partial windows are wanted.
        return stock_df.with_columns(
            pl.col(self.column)
            .rolling_mean(window_size=self.window)
            .alias(f'mean_{self.column}_{self.window}')
        )
|
||
|
||
|
||
class RankOperator(DateWiseOperator):
|
||
"""排名算子"""
|
||
|
||
def __init__(self, column: str, ascending: bool = True):
|
||
config = OperatorConfig(
|
||
name=f"rank_{column}",
|
||
description=f"{column}的排名",
|
||
required_columns=[column],
|
||
output_columns=[f'rank_{column}'],
|
||
parameters={'column': column, 'ascending': ascending}
|
||
)
|
||
super().__init__(config)
|
||
self.column = column
|
||
self.ascending = ascending
|
||
|