Files
NewStock/main/factor/operator_framework.py

124 lines
4.3 KiB
Python
Raw Normal View History

"""
Factor operator framework - Polars implementation.

Supports: sectional rolling -> stitching the results back into a long table
-> merging by column name.
The return form is selectable: the full DataFrame (default) or a single-column
Series.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
2025-10-14 09:44:46 +08:00
from typing import List, Literal
2025-10-13 21:42:35 +08:00
2025-10-14 09:44:46 +08:00
import polars as pl
2025-10-13 21:42:35 +08:00
@dataclass
class OperatorConfig:
    """Operator configuration: static metadata describing one factor operator."""

    # Operator name; BaseOperator copies it onto the instance as ``name``.
    name: str
    # Free-form, human-readable description of the operator.
    description: str
    # Columns that must be present in the input frame before the operator runs.
    required_columns: List[str]
    # Output column names declared by the operator.
    output_columns: List[str]
    # Operator-specific parameters; the schema is defined by each operator.
    parameters: dict
2025-10-13 21:42:35 +08:00
class BaseOperator(ABC):
    """Abstract base class for factor operators.

    Subclasses provide the factor name, the per-group calculation
    (``calc_factor``) and the grouping strategy (``_sectional_roll``).
    ``apply`` ties them together: validate -> roll -> merge -> return.
    """

    def __init__(self, config: OperatorConfig):
        """Store the configuration and expose its most-used fields as attributes."""
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    # ---------- must be implemented by subclasses ----------
    @abstractmethod
    def get_factor_name(self) -> str:
        """Return the factor column name (used for selecting and merging)."""

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
        """Compute the factor for one group.

        Args:
            group_df: The sub-table produced by grouping on ``ts_code`` or
                ``trade_date``.
            **kwargs: Extra arguments forwarded unchanged from ``apply``.

        Returns:
            A factor Series aligned row-for-row with ``group_df``. It should
            be named ``get_factor_name()`` so it can be selected by that name.
        """

    # ---------- public interface ----------
    def apply(self,
              df: pl.DataFrame,
              return_type: Literal['df', 'series'] = 'df',
              **kwargs) -> pl.DataFrame | pl.Series:
        """Entry point: sectional roll -> long table -> merge / return.

        Args:
            df: Input long table. Must contain ``required_columns`` plus the
                ``ts_code`` and ``trade_date`` join keys.
            return_type: ``'df'`` returns the input frame with the factor
                column attached; any other value returns only the factor
                column as a Series.
            **kwargs: Forwarded to ``_sectional_roll``.

        Raises:
            ValueError: If any required column is missing from ``df``.
        """
        if not self.validate_input(df):
            raise ValueError(f"缺少必需列:{self.required_columns}")

        long_table = self._sectional_roll(df, **kwargs)   # step 1: roll
        merged = self._merge_factor(df, long_table)       # step 2: merge
        if return_type == 'df':
            return merged
        return merged[self.get_factor_name()]

    # ---------- internal pipeline ----------
    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True when every required column is present in ``df``."""
        return all(col in df.columns for col in self.required_columns)

    @abstractmethod
    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Sectional-roll template: group -> ``calc_factor`` -> long table.

        Returns:
            A long table with the columns ``ts_code``, ``trade_date`` and the
            factor column.
        """

    def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
        """Left-join the factor column onto ``original`` by [ts_code, trade_date].

        The join returns a new frame; ``original`` itself is not modified.
        """
        factor_name = self.get_factor_name()
        # If the input already carries a column with the factor's name, a plain
        # join would keep that stale column and add the fresh values under a
        # suffixed duplicate name, so apply() would return the stale data.
        # Drop the old column first so the fresh values are the ones returned.
        if factor_name in original.columns:
            original = original.drop(factor_name)
        return original.join(
            factor_table.select(['ts_code', 'trade_date', factor_name]),
            on=['ts_code', 'trade_date'],
            how='left',
        )
# -------------------- Stock-wise section: group by ts_code --------------------
class StockWiseOperator(BaseOperator):
    """Stock-wise operator base: group by ``ts_code`` and compute the factor
    over each stock's own time series."""

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run ``calc_factor`` once per ``ts_code`` group and stack the results.

        Args:
            df: Input long table containing ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            A long table with ``ts_code``, ``trade_date`` and the factor column.
        """
        factor_name = self.get_factor_name()
        # Sort first: chronological order within each stock is essential for
        # shift-style operations inside calc_factor.
        df_sorted = df.sort(['ts_code', 'trade_date'])

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Each group receives the complete sub-frame for one ts_code.
            factor = self.calc_factor(group_df, **kwargs)
            # Force the expected column name so the select below cannot fail
            # (or pick up an unrelated column) when calc_factor returns a
            # differently named Series.
            return group_df.with_columns(factor.alias(factor_name))

        return (
            df_sorted
            .group_by('ts_code', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
# -------------------- Date-wise section: group by trade_date --------------------
class DateWiseOperator(BaseOperator):
    """Date-wise operator base: group by ``trade_date`` and compute the factor
    over each cross-section."""

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run ``calc_factor`` once per ``trade_date`` group and stack the results.

        Args:
            df: Input long table containing ``ts_code`` and ``trade_date``.
            **kwargs: Forwarded unchanged to ``calc_factor``.

        Returns:
            A long table with ``ts_code``, ``trade_date`` and the factor column.
        """
        factor_name = self.get_factor_name()
        # Sort so each cross-section is presented in a stable ts_code order.
        df_sorted = df.sort(['trade_date', 'ts_code'])

        def _attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            # Each group receives the complete sub-frame for one trade_date.
            factor = self.calc_factor(group_df, **kwargs)
            # Force the expected column name so the select below cannot fail
            # (or pick up an unrelated column) when calc_factor returns a
            # differently named Series.
            return group_df.with_columns(factor.alias(factor_name))

        return (
            df_sorted
            .group_by('trade_date', maintain_order=True)
            .map_groups(_attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )