# Files
# NewStock/main/factor/operator_framework.py
#
# 251 lines
# 8.4 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
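"""
Factor operator framework - unified factor computation implemented with Polars.
Avoids data leakage and supports slice-wise (per-stock / per-date) computation.
"""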
import polars as pl
import numpy as np
from typing import Dict, List, Callable, Optional, Union, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
# Configure logging.
# NOTE(review): logging.basicConfig() runs at import time and configures the
# root logger for the whole process; confirm this is intended for a module
# that is imported as a library rather than run as a script.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every operator in this file.
logger = logging.getLogger(__name__)
@dataclass
class OperatorConfig:
    """Static description of a single operator.

    Groups the operator's identity together with the column names it
    consumes and produces, plus a free-form parameter mapping.
    """

    # Identifier of the operator.
    name: str
    # Human-readable summary of what the operator computes.
    description: str
    # Names of the columns the operator needs in its input.
    required_columns: List[str]
    # Names of the columns the operator produces.
    output_columns: List[str]
    # Arbitrary settings recorded for the operator, keyed by name.
    parameters: Dict[str, Any]
class DataSlice:
    """Base class for data slices over a stock/date panel.

    Wraps a Polars DataFrame that must contain the ``ts_code`` and
    ``trade_date`` columns and offers per-stock and per-date views of it.
    """

    def __init__(self, df: pl.DataFrame):
        """Store ``df`` and validate it immediately.

        Raises:
            ValueError: if ``ts_code`` or ``trade_date`` is missing.
        """
        self.df = df
        self.validate_data()

    def validate_data(self):
        """Check that the wrapped frame has the mandatory key columns.

        Raises:
            ValueError: listing every required column that is absent.
        """
        required_cols = ['ts_code', 'trade_date']
        missing_cols = [col for col in required_cols if col not in self.df.columns]
        if missing_cols:
            raise ValueError(f"缺少必需列: {missing_cols}")

    def get_stock_slice(self, ts_code: str) -> pl.DataFrame:
        """Return the rows of one stock, sorted by ``trade_date``."""
        return self.df.filter(pl.col('ts_code') == ts_code).sort('trade_date')

    def get_date_slice(self, trade_date: str) -> pl.DataFrame:
        """Return the rows of one trade date (row order is left as-is)."""
        return self.df.filter(pl.col('trade_date') == trade_date)

    def get_stock_list(self) -> List[str]:
        """Return the distinct ``ts_code`` values in ascending order."""
        # unique() alone does not promise a stable order; sorting makes the
        # result reproducible between calls and runs.
        return self.df['ts_code'].unique().sort().to_list()

    def get_date_list(self) -> List[str]:
        """Return the distinct ``trade_date`` values in ascending order."""
        # Sorted for the same reason as get_stock_list().
        return self.df['trade_date'].unique().sort().to_list()
class BaseOperator(ABC):
    """Abstract base class for all operators.

    An operator declares the columns it needs (``required_columns``) and
    the columns it produces (``output_columns``). Calling the instance
    validates the input and delegates to :meth:`apply`; when validation
    fails or ``apply`` raises, the input frame is returned with every
    output column set to null instead of propagating the error.
    """

    def __init__(self, config: OperatorConfig):
        """Copy the commonly used fields of ``config`` onto the instance."""
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True when ``df`` has every required column.

        Missing columns are logged as a warning and yield False.
        """
        missing_cols = [col for col in self.required_columns if col not in df.columns]
        if missing_cols:
            logger.warning("算子 %s 缺少必需列: %s", self.name, missing_cols)
            return False
        return True

    def _with_null_outputs(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with every output column present and all-null."""
        if not self.output_columns:
            return df
        return df.with_columns(
            [pl.lit(None).alias(col) for col in self.output_columns]
        )

    @abstractmethod
    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator on ``df``; implemented by subclasses."""
        pass

    def __call__(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run the operator on ``df`` with a best-effort fallback.

        Returns:
            The result of :meth:`apply`, or ``df`` with null output columns
            when required columns are missing or ``apply`` raises.
        """
        if not self.validate_input(df):
            return self._with_null_outputs(df)
        try:
            return self.apply(df, **kwargs)
        except Exception as e:
            # Deliberate best-effort: a failing operator must not abort the
            # caller, so the failure is logged and null columns are returned.
            logger.error("算子 %s 应用失败: %s", self.name, e)
            return self._with_null_outputs(df)
class StockWiseOperator(BaseOperator):
    """Stock-slice operator: computes per ``ts_code`` group."""

    def _fill_outputs_with_null(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Return ``frame`` with every output column present and all-null."""
        if not self.output_columns:
            return frame
        return frame.with_columns(
            [pl.lit(None).alias(col) for col in self.output_columns]
        )

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_stock` to each stock and stitch the results.

        Each stock's rows are handed to ``apply_stock`` sorted by
        ``trade_date``. A stock whose computation raises is logged and kept
        with null output columns. The combined frame is sorted by
        ``ts_code`` then ``trade_date``.
        """
        # An empty frame has no groups and pl.concat([]) would raise, so
        # return it directly with the (null) output columns attached.
        if df.height == 0:
            return self._fill_outputs_with_null(df)

        # Sort once, then split by stock in a single pass instead of
        # re-filtering the whole frame for every ts_code.
        # NOTE(review): rows with a null ts_code are presumably grouped
        # together here, whereas an equality filter would not match them.
        per_stock = df.sort('trade_date').partition_by('ts_code', maintain_order=True)

        results = []
        for stock_df in per_stock:
            ts_code = stock_df['ts_code'][0]
            try:
                results.append(self.apply_stock(stock_df, **kwargs))
            except Exception as e:
                logger.error("股票 %s 算子应用失败: %s", ts_code, e)
                # Keep the failed stock's rows, with null output columns.
                results.append(self._fill_outputs_with_null(stock_df))

        # The null fallback columns have no concrete dtype; the relaxed mode
        # coerces them to the dtype produced by the successful stocks.
        combined = pl.concat(results, how='vertical_relaxed')
        return combined.sort(['ts_code', 'trade_date'])

    @abstractmethod
    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator on the rows of a single stock."""
        pass
class DateWiseOperator(BaseOperator):
    """Date-slice operator: computes per ``trade_date`` group."""

    def _fill_outputs_with_null(self, frame: pl.DataFrame) -> pl.DataFrame:
        """Return ``frame`` with every output column present and all-null."""
        if not self.output_columns:
            return frame
        return frame.with_columns(
            [pl.lit(None).alias(col) for col in self.output_columns]
        )

    def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Apply :meth:`apply_date` to each trade date and stitch the results.

        A date whose computation raises is logged and kept with null output
        columns. The combined frame is sorted by ``ts_code`` then
        ``trade_date``.
        """
        # An empty frame has no groups and pl.concat([]) would raise, so
        # return it directly with the (null) output columns attached.
        if df.height == 0:
            return self._fill_outputs_with_null(df)

        # Split by date in a single pass instead of re-filtering the whole
        # frame for every trade_date; row order inside a group is kept.
        # NOTE(review): rows with a null trade_date are presumably grouped
        # together here, whereas an equality filter would not match them.
        per_date = df.partition_by('trade_date', maintain_order=True)

        results = []
        for date_df in per_date:
            trade_date = date_df['trade_date'][0]
            try:
                results.append(self.apply_date(date_df, **kwargs))
            except Exception as e:
                logger.error("日期 %s 算子应用失败: %s", trade_date, e)
                # Keep the failed date's rows, with null output columns.
                results.append(self._fill_outputs_with_null(date_df))

        # The null fallback columns have no concrete dtype; the relaxed mode
        # coerces them to the dtype produced by the successful dates.
        combined = pl.concat(results, how='vertical_relaxed')
        return combined.sort(['ts_code', 'trade_date'])

    @abstractmethod
    def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the operator on the rows of a single trade date."""
        pass
class RollingOperator(StockWiseOperator):
    """Base class for rolling-window operators evaluated per stock."""

    def __init__(self, config: OperatorConfig, window: int, min_periods: Optional[int] = None):
        """Store the window settings.

        Args:
            config: operator description passed to the parent class.
            window: window length.
            min_periods: stored on the instance; when omitted it defaults to
                ``max(1, window // 2)``.
        """
        super().__init__(config)
        self.window = window
        # Compare against None explicitly so that an explicit 0 is kept
        # instead of being replaced by the default.
        if min_periods is None:
            min_periods = max(1, window // 2)
        self.min_periods = min_periods

    def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run the rolling computation on one stock's rows."""
        return self.apply_rolling(stock_df, **kwargs)

    @abstractmethod
    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Rolling-window logic; implemented by subclasses."""
        pass
# Basic operator implementations
class ReturnOperator(RollingOperator):
    """Return operator: ``close`` relative to its value ``periods`` rows back."""

    def __init__(self, periods: int = 1):
        """Build the operator for a lag of ``periods`` rows (default 1)."""
        label = f"return_{periods}"
        super().__init__(
            OperatorConfig(
                name=label,
                description=f"{periods}期收益率",
                required_columns=['close'],
                output_columns=[f'return_{periods}'],
                parameters={'periods': periods},
            ),
            window=periods + 1,
        )
        self.periods = periods

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``return_<periods>`` = close / shifted close - 1 to ``stock_df``."""
        close = pl.col('close')
        lagged_close = pl.col('close').shift(self.periods)
        ratio_minus_one = close / lagged_close - 1
        target = f'return_{self.periods}'
        return stock_df.with_columns(ratio_minus_one.alias(target))
class VolatilityOperator(RollingOperator):
    """Volatility operator: rolling standard deviation of ``pct_chg``."""

    def __init__(self, window: int = 20):
        """Build the operator for a rolling window of ``window`` rows."""
        config = OperatorConfig(
            name=f"volatility_{window}",
            description=f"{window}日波动率",
            required_columns=['pct_chg'],
            output_columns=[f'volatility_{window}'],
            parameters={'window': window}
        )
        super().__init__(config, window=window)

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``volatility_<window>`` to one stock's rows."""
        # The Polars parameter is named ``window_size``; the previous
        # ``window=`` keyword raised TypeError. It is passed positionally to
        # stay independent of keyword naming.
        # NOTE(review): self.min_periods is not forwarded, so the Polars
        # default applies - confirm whether it should be passed through.
        return stock_df.with_columns(
            pl.col('pct_chg').rolling_std(self.window).alias(f'volatility_{self.window}')
        )
class MeanOperator(RollingOperator):
    """Mean operator: rolling mean of an arbitrary column."""

    def __init__(self, column: str, window: int):
        """Build the operator for ``column`` over a window of ``window`` rows."""
        config = OperatorConfig(
            name=f"mean_{column}_{window}",
            description=f"{column}{window}日均值",
            required_columns=[column],
            output_columns=[f'mean_{column}_{window}'],
            parameters={'column': column, 'window': window}
        )
        super().__init__(config, window=window)
        self.column = column

    def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Add ``mean_<column>_<window>`` to one stock's rows."""
        # The Polars parameter is named ``window_size``; the previous
        # ``window=`` keyword raised TypeError. It is passed positionally to
        # stay independent of keyword naming.
        # NOTE(review): self.min_periods is not forwarded, so the Polars
        # default applies - confirm whether it should be passed through.
        return stock_df.with_columns(
            pl.col(self.column).rolling_mean(self.window).alias(f'mean_{self.column}_{self.window}')
        )
class RankOperator(DateWiseOperator):
"""排名算子"""
def __init__(self, column: str, ascending: bool = True):
config = OperatorConfig(
name=f"rank_{column}",
description=f"{column}的排名",
required_columns=[column],
output_columns=[f'rank_{column}'],
parameters={'column': column, 'ascending': ascending}
)
super().__init__(config)
self.column = column
self.ascending = ascending