2025-10-13 21:42:35 +08:00
|
|
|
|
"""
|
2025-10-14 09:44:46 +08:00
|
|
|
|
因子算子框架 - Polars 实现
|
|
|
|
|
|
支持:截面滚动 → 拼回长表 → 按列名合并
|
|
|
|
|
|
返回形式可选:完整 DataFrame(默认)或单列 Series
|
2025-10-13 21:42:35 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
from dataclasses import dataclass
|
2025-10-14 09:44:46 +08:00
|
|
|
|
from typing import List, Literal
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
2025-10-14 09:44:46 +08:00
|
|
|
|
import polars as pl
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OperatorConfig:
    """Configuration record describing one operator.

    All five fields are required and positional, in the order declared below.
    """

    # Operator name; BaseOperator copies it to ``self.name``.
    name: str
    # Free-form description of the operator.
    description: str
    # Columns the input frame must contain (checked by ``validate_input``).
    required_columns: List[str]
    # Columns the operator declares as its output.
    output_columns: List[str]
    # Operator parameters; the schema is not fixed by this class.
    parameters: dict
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseOperator(ABC):
    """Abstract base class for factor operators.

    ``apply`` is the public entry point and works as a template: it validates
    the input frame, asks ``_sectional_roll`` for a long factor table and then
    left-joins that table back onto the input on ``[ts_code, trade_date]``.

    Subclasses must implement ``get_factor_name``, ``calc_factor`` and
    ``_sectional_roll``.
    """

    def __init__(self, config: OperatorConfig):
        """Keep the config and mirror its name and column lists on the instance."""
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    # ---------- must be implemented by subclasses ----------
    @abstractmethod
    def get_factor_name(self) -> str:
        """Return the factor column name (used when merging)."""
        pass

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
        """Compute the factor for a single group.

        Args:
            group_df: Sub-table of one group (grouped by ``ts_code`` or
                ``trade_date``, depending on the concrete operator).
            **kwargs: Extra keyword arguments forwarded from ``apply``.

        Returns:
            A factor Series with one value per row of ``group_df``, in the
            same row order.
        """
        pass

    # ---------- public interface ----------
    def apply(self,
              df: pl.DataFrame,
              return_type: Literal['df', 'series'] = 'df',
              **kwargs) -> pl.DataFrame | pl.Series:
        """Run the operator: sectional roll -> long table -> merge / return.

        Args:
            df: Input frame; must pass ``validate_input``.
            return_type: ``'df'`` returns the merged frame; any other value
                returns only the factor column of the merged frame.
            **kwargs: Forwarded unchanged to ``_sectional_roll``.

        Returns:
            The merged DataFrame, or the factor column as a Series.

        Raises:
            ValueError: If ``validate_input`` rejects ``df``.
        """
        if not self.validate_input(df):
            raise ValueError(f"缺少必需列:{self.required_columns}")

        long_table = self._sectional_roll(df, **kwargs)  # step 1: roll
        merged = self._merge_factor(df, long_table)      # step 2: merge
        if return_type == 'df':
            return merged
        return merged[self.get_factor_name()]

    # ---------- internal pipeline ----------
    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True when every column in ``required_columns`` exists in ``df``."""
        return all(col in df.columns for col in self.required_columns)

    @abstractmethod
    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Sectional-roll template: group -> ``calc_factor`` -> long table.

        Returns:
            A long table containing ``trade_date``, ``ts_code`` and the
            factor column.
        """
        pass

    def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
        """Left-join the factor column onto ``original`` by ``[ts_code, trade_date]``.

        A new frame is returned; ``original`` itself is not modified.

        If ``original`` already has a column named like the factor (for
        example when the operator is applied a second time), that stale
        column is dropped first. Otherwise the join would add the fresh
        values under a suffixed name and leave the old values under the
        factor name, which is what ``apply(..., return_type='series')`` reads.
        """
        factor_name = self.get_factor_name()
        keys = ['ts_code', 'trade_date']

        # Never drop a join key, even if a subclass reports one as its factor name.
        if factor_name in original.columns and factor_name not in keys:
            original = original.drop(factor_name)

        return original.join(
            factor_table.select([*keys, factor_name]),
            on=keys,
            how='left',
        )
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-10-14 09:44:46 +08:00
|
|
|
|
# -------------------- 股票截面:按 ts_code 分组 --------------------
|
|
|
|
|
|
class StockWiseOperator(BaseOperator):
    """Abstract per-stock operator.

    Groups the input by ``ts_code`` and calls ``calc_factor`` on each stock's
    sub-table.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Build the long factor table by rolling over each ``ts_code`` group.

        Args:
            df: Input frame containing at least ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            A frame with columns ``ts_code``, ``trade_date`` and the factor
            column, sorted by ``ts_code`` then ``trade_date``.
        """
        factor_name = self.get_factor_name()

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Alias the result so the column always carries the factor name,
            # whatever name (or none) calc_factor gave its Series; the final
            # select below looks the column up by that name.
            factor = self.calc_factor(group_df, **kwargs)
            return group_df.with_columns(factor.alias(factor_name))

        # Sort first: chronological order inside each stock matters for
        # shift-like operations inside calc_factor.
        df_sorted = df.sort(['ts_code', 'trade_date'])

        # map_groups hands each ts_code group to the callback as a full sub-DataFrame.
        return (
            df_sorted
            .group_by('ts_code', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
2025-10-14 09:44:46 +08:00
|
|
|
|
# -------------------- 日期截面:按 trade_date 分组 --------------------
|
|
|
|
|
|
class DateWiseOperator(BaseOperator):
    """Abstract per-date operator.

    Groups the input by ``trade_date`` and calls ``calc_factor`` on each
    date's cross-section.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Build the long factor table by rolling over each ``trade_date`` group.

        Args:
            df: Input frame containing at least ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            A frame with columns ``ts_code``, ``trade_date`` and the factor
            column, sorted by ``trade_date`` then ``ts_code``.
        """
        factor_name = self.get_factor_name()

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Alias the result so the column always carries the factor name,
            # whatever name (or none) calc_factor gave its Series; the final
            # select below looks the column up by that name.
            factor = self.calc_factor(group_df, **kwargs)
            return group_df.with_columns(factor.alias(factor_name))

        df_sorted = df.sort(['trade_date', 'ts_code'])

        # map_groups hands each trade_date group to the callback as a full sub-DataFrame.
        return (
            df_sorted
            .group_by('trade_date', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
|
|
|
|
|
|
|