factor优化(暂存版)

This commit is contained in:
2025-10-14 09:44:46 +08:00
parent 44315b2c76
commit 7862b9739a
9 changed files with 804 additions and 4427 deletions

View File

@@ -1,18 +1,14 @@
"""
因子算子框架 - 使用Polars实现统一的因子计算
避免数据泄露,支持切面计算
因子算子框架 - Polars 实现
支持:截面滚动 → 拼回长表 → 按列名合并
返回形式可选:完整 DataFrame默认或单列 Series
"""
import polars as pl
import numpy as np
from typing import Dict, List, Callable, Optional, Union, Any
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
from typing import List, Literal
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import polars as pl
@dataclass
@@ -22,229 +18,107 @@ class OperatorConfig:
description: str
required_columns: List[str]
output_columns: List[str]
parameters: Dict[str, Any]
class DataSlice:
"""数据切面基类"""
def __init__(self, df: pl.DataFrame):
self.df = df
self.validate_data()
def validate_data(self):
"""验证数据格式"""
required_cols = ['ts_code', 'trade_date']
missing_cols = [col for col in required_cols if col not in self.df.columns]
if missing_cols:
raise ValueError(f"缺少必需列: {missing_cols}")
def get_stock_slice(self, ts_code: str) -> pl.DataFrame:
"""获取单个股票的数据切面"""
return self.df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
def get_date_slice(self, trade_date: str) -> pl.DataFrame:
"""获取单个日期的数据切面"""
return self.df.filter(pl.col('trade_date') == trade_date)
def get_stock_list(self) -> List[str]:
"""获取股票列表"""
return self.df['ts_code'].unique().to_list()
def get_date_list(self) -> List[str]:
"""获取日期列表"""
return self.df['trade_date'].unique().to_list()
parameters: dict
class BaseOperator(ABC):
"""算子基类"""
def __init__(self, config: OperatorConfig):
self.config = config
self.name = config.name
self.required_columns = config.required_columns
self.output_columns = config.output_columns
def validate_input(self, df: pl.DataFrame) -> bool:
"""验证输入数据"""
missing_cols = [col for col in self.required_columns if col not in df.columns]
if missing_cols:
logger.warning(f"算子 {self.name} 缺少必需列: {missing_cols}")
return False
return True
# ---------- 子类必须实现 ----------
@abstractmethod
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""应用算子"""
def get_factor_name(self) -> str:
"""返回因子列名(用于合并)"""
pass
def __call__(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""调用算子"""
@abstractmethod
def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
"""
真正的截面计算逻辑。
参数:按 ts_code 或 trade_date 分组后的子表
返回:与 group_df 行数一一对应的因子 Series含正确索引
"""
pass
# ---------- 公共接口 ----------
def apply(self,
df: pl.DataFrame,
return_type: Literal['df', 'series'] = 'df',
**kwargs) -> pl.DataFrame | pl.Series:
"""入口:截面滚动 → 拼回长表 → 合并/返回"""
if not self.validate_input(df):
# 返回原始数据添加NaN列
for col in self.output_columns:
df = df.with_columns(pl.lit(None).alias(col))
return df
try:
return self.apply(df, **kwargs)
except Exception as e:
logger.error(f"算子 {self.name} 应用失败: {e}")
# 返回原始数据添加NaN列
for col in self.output_columns:
df = df.with_columns(pl.lit(None).alias(col))
return df
raise ValueError(f"缺少必需列:{self.required_columns}")
long_table = self._sectional_roll(df, **kwargs) # ① 滚动
merged = self._merge_factor(df, long_table) # ② 合并
return merged if return_type == 'df' else merged[self.get_factor_name()]
# ---------- 内部流程 ----------
def validate_input(self, df: pl.DataFrame) -> bool:
return all(col in df.columns for col in self.required_columns)
@abstractmethod
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""
截面滚动模板group → calc_factor → 拼回长表
返回含【trade_date, ts_code, factor】的长表
"""
pass
def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
"""按 [ts_code, trade_date] 左联,原地追加因子列"""
factor_name = self.get_factor_name()
return original.join(factor_table.select(['ts_code', 'trade_date', factor_name]),
on=['ts_code', 'trade_date'],
how='left')
# -------------------- 股票截面:按 ts_code 分组 --------------------
class StockWiseOperator(BaseOperator):
"""股票切面算子 - 按股票分组计算"""
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""按股票分组应用算子"""
stock_list = df['ts_code'].unique().to_list()
results = []
for ts_code in stock_list:
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
try:
result_df = self.apply_stock(stock_df, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"股票 {ts_code} 算子应用失败: {e}")
# 为失败的股票添加NaN列
for col in self.output_columns:
stock_df = stock_df.with_columns(pl.lit(None).alias(col))
results.append(stock_df)
return pl.concat(results).sort(['ts_code', 'trade_date'])
@abstractmethod
def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""应用到单个股票数据"""
pass
"""股票切面算子抽象类:按 ts_code 分组,对每个股票的时间序列计算因子"""
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
factor_name = self.get_factor_name()
# 确保排序(时间顺序对 shift 等操作至关重要)
df_sorted = df.sort(['ts_code', 'trade_date'])
# 使用 map_groups对每个 ts_code 分组,传入完整子 DataFrame
result = (
df_sorted
.group_by('ts_code', maintain_order=True)
.map_groups(
lambda group_df: group_df.with_columns(
self.calc_factor(group_df, **kwargs)
)
)
.select(['ts_code', 'trade_date', factor_name])
)
return result
# -------------------- 日期截面:按 trade_date 分组 --------------------
class DateWiseOperator(BaseOperator):
"""日期切面算子 - 按日期分组计算"""
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""按日期分组应用算子"""
date_list = df['trade_date'].unique().to_list()
results = []
"""日期切面算子抽象类:按 trade_date 分组,对每个截面计算因子"""
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
factor_name = self.get_factor_name()
for trade_date in date_list:
date_df = df.filter(pl.col('trade_date') == trade_date)
try:
result_df = self.apply_date(date_df, **kwargs)
results.append(result_df)
except Exception as e:
logger.error(f"日期 {trade_date} 算子应用失败: {e}")
# 为失败的日期添加NaN列
for col in self.output_columns:
date_df = date_df.with_columns(pl.lit(None).alias(col))
results.append(date_df)
df_sorted = df.sort(['trade_date', 'ts_code'])
return pl.concat(results).sort(['ts_code', 'trade_date'])
@abstractmethod
def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""应用到单个日期数据"""
pass
class RollingOperator(StockWiseOperator):
"""滚动窗口算子基类"""
def __init__(self, config: OperatorConfig, window: int, min_periods: Optional[int] = None):
super().__init__(config)
self.window = window
self.min_periods = min_periods or max(1, window // 2)
def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""应用滚动窗口计算"""
return self.apply_rolling(stock_df, **kwargs)
@abstractmethod
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""滚动窗口计算逻辑"""
pass
# 基础算子实现
class ReturnOperator(RollingOperator):
"""收益率算子"""
def __init__(self, periods: int = 1):
config = OperatorConfig(
name=f"return_{periods}",
description=f"{periods}期收益率",
required_columns=['close'],
output_columns=[f'return_{periods}'],
parameters={'periods': periods}
result = (
df_sorted
.group_by('trade_date', maintain_order=True)
.map_groups(
lambda group_df: group_df.with_columns(
self.calc_factor(group_df, **kwargs)
)
)
.select(['ts_code', 'trade_date', factor_name])
)
super().__init__(config, window=periods + 1)
self.periods = periods
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""计算收益率"""
return stock_df.with_columns(
(pl.col('close') / pl.col('close').shift(self.periods) - 1).alias(f'return_{self.periods}')
)
class VolatilityOperator(RollingOperator):
"""波动率算子"""
def __init__(self, window: int = 20):
config = OperatorConfig(
name=f"volatility_{window}",
description=f"{window}日波动率",
required_columns=['pct_chg'],
output_columns=[f'volatility_{window}'],
parameters={'window': window}
)
super().__init__(config, window=window)
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""计算波动率"""
return stock_df.with_columns(
pl.col('pct_chg').rolling_std(window=self.window).alias(f'volatility_{self.window}')
)
class MeanOperator(RollingOperator):
"""均值算子"""
def __init__(self, column: str, window: int):
config = OperatorConfig(
name=f"mean_{column}_{window}",
description=f"{column}{window}日均值",
required_columns=[column],
output_columns=[f'mean_{column}_{window}'],
parameters={'column': column, 'window': window}
)
super().__init__(config, window=window)
self.column = column
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""计算均值"""
return stock_df.with_columns(
pl.col(self.column).rolling_mean(window=self.window).alias(f'mean_{self.column}_{self.window}')
)
class RankOperator(DateWiseOperator):
"""排名算子"""
def __init__(self, column: str, ascending: bool = True):
config = OperatorConfig(
name=f"rank_{column}",
description=f"{column}的排名",
required_columns=[column],
output_columns=[f'rank_{column}'],
parameters={'column': column, 'ascending': ascending}
)
super().__init__(config)
self.column = column
self.ascending = ascending
return result