124 lines
4.3 KiB
Python
124 lines
4.3 KiB
Python
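"""
因子算子框架 - Polars 实现
(Factor operator framework, Polars implementation.)

支持:截面滚动 → 拼回长表 → 按列名合并
(Pipeline: group-wise computation → stack results back into a long table → merge by column name.)

返回形式可选:完整 DataFrame(默认)或单列 Series
(Output: either the full DataFrame (default) or only the factor column as a Series.)
"""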
|
||
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import List, Literal
|
||
|
||
import polars as pl
|
||
|
||
|
||
@dataclass
class OperatorConfig:
    """Static configuration describing a single factor operator.

    Field order defines positional construction order, so it must not change.
    """

    # Operator identifier; BaseOperator copies it to ``self.name``.
    name: str
    # Free-text description of what the operator computes.
    description: str
    # Columns the input frame must contain (checked by ``validate_input``).
    required_columns: list[str]
    # Columns the operator declares as its output.
    output_columns: list[str]
    # Operator-specific parameters; BaseOperator itself never reads them directly.
    parameters: dict
|
||
|
||
|
||
class BaseOperator(ABC):
    """Abstract base class for factor operators.

    Subclasses supply the factor column name (``get_factor_name``), the
    per-group computation (``calc_factor``) and the grouping strategy
    (``_sectional_roll``). ``apply`` validates the input, runs the grouped
    computation and left-joins the factor column back onto the input on
    ``[ts_code, trade_date]``.
    """

    def __init__(self, config: OperatorConfig):
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    # ---------- must be implemented by subclasses ----------

    @abstractmethod
    def get_factor_name(self) -> str:
        """Return the factor column name (used when selecting and merging)."""

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
        """Compute the factor for one group.

        Args:
            group_df: Sub-frame for a single group, grouped by ``ts_code`` or
                ``trade_date`` depending on the subclass's ``_sectional_roll``.
            **kwargs: Extra arguments forwarded from ``apply``.

        Returns:
            A Series with exactly one value per row of ``group_df``, in the
            same row order. Polars has no index, so alignment is positional.
            It should be named ``get_factor_name()`` so the result can be
            selected by that name.
        """

    # ---------- public interface ----------

    def apply(self,
              df: pl.DataFrame,
              return_type: Literal['df', 'series'] = 'df',
              **kwargs) -> pl.DataFrame | pl.Series:
        """Run the operator: validate → grouped computation → merge/return.

        Args:
            df: Input frame. It must pass ``validate_input`` and contain the
                ``ts_code``/``trade_date`` join keys used by ``_merge_factor``.
            return_type: ``'df'`` returns ``df`` with the factor column joined
                on. ``'series'`` returns only that factor column.
            **kwargs: Forwarded to ``_sectional_roll``.

        Raises:
            ValueError: If ``return_type`` is not ``'df'`` or ``'series'``, or
                if ``validate_input`` rejects ``df``.
        """
        # Reject typos up front; previously any non-'df' value silently
        # produced a Series.
        if return_type not in ('df', 'series'):
            raise ValueError(
                f"return_type 必须是 'df' 或 'series',实际为 {return_type!r}")

        if not self.validate_input(df):
            missing = [col for col in self.required_columns if col not in df.columns]
            # Fall back to the full list if a subclass's validate_input
            # rejected df for a reason other than missing columns.
            raise ValueError(f"缺少必需列:{missing or self.required_columns}")

        long_table = self._sectional_roll(df, **kwargs)  # (1) grouped computation
        merged = self._merge_factor(df, long_table)       # (2) merge back
        return merged if return_type == 'df' else merged[self.get_factor_name()]

    # ---------- internal pipeline ----------

    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True if every name in ``required_columns`` is a column of ``df``."""
        return all(col in df.columns for col in self.required_columns)

    @abstractmethod
    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Group ``df``, apply ``calc_factor`` per group and stack the results.

        Returns:
            A long table containing at least ``ts_code``, ``trade_date`` and
            the column named ``get_factor_name()``.
        """

    def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
        """Left-join the factor column onto ``original`` on ``[ts_code, trade_date]``.

        If ``original`` already has a column with the factor name, that column
        is dropped first. Otherwise Polars would add the new values under a
        ``_right``-suffixed name, and ``apply(..., return_type='series')``
        would return the stale column.

        NOTE(review): duplicate ``(ts_code, trade_date)`` keys would multiply
        rows in this join. Confirm the input keys are unique.
        """
        factor_name = self.get_factor_name()
        if factor_name in original.columns:
            original = original.drop(factor_name)
        keys = ['ts_code', 'trade_date']
        return original.join(factor_table.select([*keys, factor_name]),
                             on=keys,
                             how='left')
|
||
|
||
|
||
# -------------------- 股票截面:按 ts_code 分组 --------------------
|
||
class StockWiseOperator(BaseOperator):
    """Per-stock operator: groups rows by ``ts_code`` and computes the factor
    over each stock's rows sorted by ``trade_date``."""

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the factor stock by stock and stack the results.

        Args:
            df: Input frame containing ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            DataFrame with columns ``[ts_code, trade_date, <factor name>]``.
        """
        factor_name = self.get_factor_name()

        # Time order within each stock matters for shift/rolling-style factors.
        df_sorted = df.sort(['ts_code', 'trade_date'])

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Alias the result so the select below finds it even when
            # calc_factor returns an unnamed or differently named Series.
            factor = self.calc_factor(group_df, **kwargs).alias(factor_name)
            return group_df.with_columns(factor)

        # map_groups hands each ts_code group to the callback as a full sub-frame.
        return (
            df_sorted
            .group_by('ts_code', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
|
||
|
||
# -------------------- 日期截面:按 trade_date 分组 --------------------
|
||
class DateWiseOperator(BaseOperator):
    """Per-date (cross-sectional) operator: groups rows by ``trade_date`` and
    computes the factor across all stocks of each date."""

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Compute the factor date by date and stack the results.

        Args:
            df: Input frame containing ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            DataFrame with columns ``[ts_code, trade_date, <factor name>]``.
        """
        factor_name = self.get_factor_name()

        # Sort by date, then stock, so each cross-section is in a stable order.
        df_sorted = df.sort(['trade_date', 'ts_code'])

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Alias the result so the select below finds it even when
            # calc_factor returns an unnamed or differently named Series.
            factor = self.calc_factor(group_df, **kwargs).alias(factor_name)
            return group_df.with_columns(factor)

        # map_groups hands each trade_date group to the callback as a full sub-frame.
        return (
            df_sorted
            .group_by('trade_date', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
|
||
|