factor优化(暂存版)
This commit is contained in:
@@ -1,18 +1,14 @@
|
||||
"""
|
||||
因子算子框架 - 使用Polars实现统一的因子计算
|
||||
避免数据泄露,支持切面计算
|
||||
因子算子框架 - Polars 实现
|
||||
支持:截面滚动 → 拼回长表 → 按列名合并
|
||||
返回形式可选:完整 DataFrame(默认)或单列 Series
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from typing import Dict, List, Callable, Optional, Union, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import List, Literal
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,229 +18,107 @@ class OperatorConfig:
|
||||
description: str
|
||||
required_columns: List[str]
|
||||
output_columns: List[str]
|
||||
parameters: Dict[str, Any]
|
||||
|
||||
|
||||
class DataSlice:
|
||||
"""数据切面基类"""
|
||||
|
||||
def __init__(self, df: pl.DataFrame):
|
||||
self.df = df
|
||||
self.validate_data()
|
||||
|
||||
def validate_data(self):
|
||||
"""验证数据格式"""
|
||||
required_cols = ['ts_code', 'trade_date']
|
||||
missing_cols = [col for col in required_cols if col not in self.df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"缺少必需列: {missing_cols}")
|
||||
|
||||
def get_stock_slice(self, ts_code: str) -> pl.DataFrame:
|
||||
"""获取单个股票的数据切面"""
|
||||
return self.df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
|
||||
|
||||
def get_date_slice(self, trade_date: str) -> pl.DataFrame:
|
||||
"""获取单个日期的数据切面"""
|
||||
return self.df.filter(pl.col('trade_date') == trade_date)
|
||||
|
||||
def get_stock_list(self) -> List[str]:
|
||||
"""获取股票列表"""
|
||||
return self.df['ts_code'].unique().to_list()
|
||||
|
||||
def get_date_list(self) -> List[str]:
|
||||
"""获取日期列表"""
|
||||
return self.df['trade_date'].unique().to_list()
|
||||
parameters: dict
|
||||
|
||||
|
||||
class BaseOperator(ABC):
|
||||
"""算子基类"""
|
||||
|
||||
|
||||
def __init__(self, config: OperatorConfig):
|
||||
self.config = config
|
||||
self.name = config.name
|
||||
self.required_columns = config.required_columns
|
||||
self.output_columns = config.output_columns
|
||||
|
||||
def validate_input(self, df: pl.DataFrame) -> bool:
|
||||
"""验证输入数据"""
|
||||
missing_cols = [col for col in self.required_columns if col not in df.columns]
|
||||
if missing_cols:
|
||||
logger.warning(f"算子 {self.name} 缺少必需列: {missing_cols}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ---------- 子类必须实现 ----------
|
||||
@abstractmethod
|
||||
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""应用算子"""
|
||||
def get_factor_name(self) -> str:
|
||||
"""返回因子列名(用于合并)"""
|
||||
pass
|
||||
|
||||
def __call__(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""调用算子"""
|
||||
|
||||
@abstractmethod
|
||||
def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
|
||||
"""
|
||||
真正的截面计算逻辑。
|
||||
参数:按 ts_code 或 trade_date 分组后的子表
|
||||
返回:与 group_df 行数一一对应的因子 Series(含正确索引)
|
||||
"""
|
||||
pass
|
||||
|
||||
# ---------- 公共接口 ----------
|
||||
def apply(self,
|
||||
df: pl.DataFrame,
|
||||
return_type: Literal['df', 'series'] = 'df',
|
||||
**kwargs) -> pl.DataFrame | pl.Series:
|
||||
"""入口:截面滚动 → 拼回长表 → 合并/返回"""
|
||||
if not self.validate_input(df):
|
||||
# 返回原始数据,添加NaN列
|
||||
for col in self.output_columns:
|
||||
df = df.with_columns(pl.lit(None).alias(col))
|
||||
return df
|
||||
|
||||
try:
|
||||
return self.apply(df, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"算子 {self.name} 应用失败: {e}")
|
||||
# 返回原始数据,添加NaN列
|
||||
for col in self.output_columns:
|
||||
df = df.with_columns(pl.lit(None).alias(col))
|
||||
return df
|
||||
raise ValueError(f"缺少必需列:{self.required_columns}")
|
||||
|
||||
long_table = self._sectional_roll(df, **kwargs) # ① 滚动
|
||||
merged = self._merge_factor(df, long_table) # ② 合并
|
||||
return merged if return_type == 'df' else merged[self.get_factor_name()]
|
||||
|
||||
# ---------- 内部流程 ----------
|
||||
def validate_input(self, df: pl.DataFrame) -> bool:
|
||||
return all(col in df.columns for col in self.required_columns)
|
||||
|
||||
@abstractmethod
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""
|
||||
截面滚动模板:group → calc_factor → 拼回长表
|
||||
返回:含【trade_date, ts_code, factor】的长表
|
||||
"""
|
||||
pass
|
||||
|
||||
def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
|
||||
"""按 [ts_code, trade_date] 左联,原地追加因子列"""
|
||||
factor_name = self.get_factor_name()
|
||||
return original.join(factor_table.select(['ts_code', 'trade_date', factor_name]),
|
||||
on=['ts_code', 'trade_date'],
|
||||
how='left')
|
||||
|
||||
|
||||
# -------------------- 股票截面:按 ts_code 分组 --------------------
|
||||
class StockWiseOperator(BaseOperator):
|
||||
"""股票切面算子 - 按股票分组计算"""
|
||||
|
||||
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""按股票分组应用算子"""
|
||||
stock_list = df['ts_code'].unique().to_list()
|
||||
results = []
|
||||
|
||||
for ts_code in stock_list:
|
||||
stock_df = df.filter(pl.col('ts_code') == ts_code).sort('trade_date')
|
||||
try:
|
||||
result_df = self.apply_stock(stock_df, **kwargs)
|
||||
results.append(result_df)
|
||||
except Exception as e:
|
||||
logger.error(f"股票 {ts_code} 算子应用失败: {e}")
|
||||
# 为失败的股票添加NaN列
|
||||
for col in self.output_columns:
|
||||
stock_df = stock_df.with_columns(pl.lit(None).alias(col))
|
||||
results.append(stock_df)
|
||||
|
||||
return pl.concat(results).sort(['ts_code', 'trade_date'])
|
||||
|
||||
@abstractmethod
|
||||
def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""应用到单个股票数据"""
|
||||
pass
|
||||
"""股票切面算子抽象类:按 ts_code 分组,对每个股票的时间序列计算因子"""
|
||||
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
factor_name = self.get_factor_name()
|
||||
|
||||
# 确保排序(时间顺序对 shift 等操作至关重要)
|
||||
df_sorted = df.sort(['ts_code', 'trade_date'])
|
||||
|
||||
# 使用 map_groups:对每个 ts_code 分组,传入完整子 DataFrame
|
||||
result = (
|
||||
df_sorted
|
||||
.group_by('ts_code', maintain_order=True)
|
||||
.map_groups(
|
||||
lambda group_df: group_df.with_columns(
|
||||
self.calc_factor(group_df, **kwargs)
|
||||
)
|
||||
)
|
||||
.select(['ts_code', 'trade_date', factor_name])
|
||||
)
|
||||
return result
|
||||
|
||||
# -------------------- 日期截面:按 trade_date 分组 --------------------
|
||||
class DateWiseOperator(BaseOperator):
|
||||
"""日期切面算子 - 按日期分组计算"""
|
||||
|
||||
def apply(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""按日期分组应用算子"""
|
||||
date_list = df['trade_date'].unique().to_list()
|
||||
results = []
|
||||
"""日期切面算子抽象类:按 trade_date 分组,对每个截面计算因子"""
|
||||
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
factor_name = self.get_factor_name()
|
||||
|
||||
for trade_date in date_list:
|
||||
date_df = df.filter(pl.col('trade_date') == trade_date)
|
||||
try:
|
||||
result_df = self.apply_date(date_df, **kwargs)
|
||||
results.append(result_df)
|
||||
except Exception as e:
|
||||
logger.error(f"日期 {trade_date} 算子应用失败: {e}")
|
||||
# 为失败的日期添加NaN列
|
||||
for col in self.output_columns:
|
||||
date_df = date_df.with_columns(pl.lit(None).alias(col))
|
||||
results.append(date_df)
|
||||
df_sorted = df.sort(['trade_date', 'ts_code'])
|
||||
|
||||
return pl.concat(results).sort(['ts_code', 'trade_date'])
|
||||
|
||||
@abstractmethod
|
||||
def apply_date(self, date_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""应用到单个日期数据"""
|
||||
pass
|
||||
|
||||
|
||||
class RollingOperator(StockWiseOperator):
|
||||
"""滚动窗口算子基类"""
|
||||
|
||||
def __init__(self, config: OperatorConfig, window: int, min_periods: Optional[int] = None):
|
||||
super().__init__(config)
|
||||
self.window = window
|
||||
self.min_periods = min_periods or max(1, window // 2)
|
||||
|
||||
def apply_stock(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""应用滚动窗口计算"""
|
||||
return self.apply_rolling(stock_df, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""滚动窗口计算逻辑"""
|
||||
pass
|
||||
|
||||
|
||||
# 基础算子实现
|
||||
class ReturnOperator(RollingOperator):
|
||||
"""收益率算子"""
|
||||
|
||||
def __init__(self, periods: int = 1):
|
||||
config = OperatorConfig(
|
||||
name=f"return_{periods}",
|
||||
description=f"{periods}期收益率",
|
||||
required_columns=['close'],
|
||||
output_columns=[f'return_{periods}'],
|
||||
parameters={'periods': periods}
|
||||
result = (
|
||||
df_sorted
|
||||
.group_by('trade_date', maintain_order=True)
|
||||
.map_groups(
|
||||
lambda group_df: group_df.with_columns(
|
||||
self.calc_factor(group_df, **kwargs)
|
||||
)
|
||||
)
|
||||
.select(['ts_code', 'trade_date', factor_name])
|
||||
)
|
||||
super().__init__(config, window=periods + 1)
|
||||
self.periods = periods
|
||||
|
||||
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""计算收益率"""
|
||||
return stock_df.with_columns(
|
||||
(pl.col('close') / pl.col('close').shift(self.periods) - 1).alias(f'return_{self.periods}')
|
||||
)
|
||||
|
||||
|
||||
class VolatilityOperator(RollingOperator):
|
||||
"""波动率算子"""
|
||||
|
||||
def __init__(self, window: int = 20):
|
||||
config = OperatorConfig(
|
||||
name=f"volatility_{window}",
|
||||
description=f"{window}日波动率",
|
||||
required_columns=['pct_chg'],
|
||||
output_columns=[f'volatility_{window}'],
|
||||
parameters={'window': window}
|
||||
)
|
||||
super().__init__(config, window=window)
|
||||
|
||||
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""计算波动率"""
|
||||
return stock_df.with_columns(
|
||||
pl.col('pct_chg').rolling_std(window=self.window).alias(f'volatility_{self.window}')
|
||||
)
|
||||
|
||||
|
||||
class MeanOperator(RollingOperator):
|
||||
"""均值算子"""
|
||||
|
||||
def __init__(self, column: str, window: int):
|
||||
config = OperatorConfig(
|
||||
name=f"mean_{column}_{window}",
|
||||
description=f"{column}的{window}日均值",
|
||||
required_columns=[column],
|
||||
output_columns=[f'mean_{column}_{window}'],
|
||||
parameters={'column': column, 'window': window}
|
||||
)
|
||||
super().__init__(config, window=window)
|
||||
self.column = column
|
||||
|
||||
def apply_rolling(self, stock_df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""计算均值"""
|
||||
return stock_df.with_columns(
|
||||
pl.col(self.column).rolling_mean(window=self.window).alias(f'mean_{self.column}_{self.window}')
|
||||
)
|
||||
|
||||
|
||||
class RankOperator(DateWiseOperator):
|
||||
"""排名算子"""
|
||||
|
||||
def __init__(self, column: str, ascending: bool = True):
|
||||
config = OperatorConfig(
|
||||
name=f"rank_{column}",
|
||||
description=f"{column}的排名",
|
||||
required_columns=[column],
|
||||
output_columns=[f'rank_{column}'],
|
||||
parameters={'column': column, 'ascending': ascending}
|
||||
)
|
||||
super().__init__(config)
|
||||
self.column = column
|
||||
self.ascending = ascending
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user