Files
NewStock/main/factor/operator_framework.py

124 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
因子算子框架 - Polars 实现
支持:截面滚动 → 拼回长表 → 按列名合并
返回形式可选:完整 DataFrame(默认)或单列 Series
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Literal
import polars as pl
@dataclass
class OperatorConfig:
    """Static configuration shared by every operator instance.

    Values are stored exactly as given; no validation happens here.
    """

    # Operator identifier; BaseOperator copies it onto ``self.name``.
    name: str
    # Free-form human-readable description.
    description: str
    # Columns the input frame must contain (checked by BaseOperator.validate_input).
    required_columns: list[str]
    # Columns the operator declares it produces; BaseOperator copies it onto
    # ``self.output_columns``.
    output_columns: list[str]
    # Arbitrary operator parameters. BaseOperator does not copy these onto the
    # instance; subclasses can reach them through ``self.config.parameters``.
    parameters: dict
class BaseOperator(ABC):
    """Base class for factor operators.

    A concrete operator supplies three pieces:

    * :meth:`get_factor_name` – the factor column name;
    * :meth:`calc_factor` – the per-group computation;
    * :meth:`_sectional_roll` – the grouping strategy.

    :meth:`apply` validates the input and runs the grouped computation.
    It then left-joins the resulting factor column back onto the input
    frame on ``[ts_code, trade_date]``.
    """

    # Columns used to join the factor table back onto the input frame.
    _KEY_COLUMNS = ('ts_code', 'trade_date')

    def __init__(self, config: OperatorConfig):
        self.config = config
        self.name = config.name
        self.required_columns = config.required_columns
        self.output_columns = config.output_columns

    # ---------- Must be implemented by subclasses ----------
    @abstractmethod
    def get_factor_name(self) -> str:
        """Return the factor column name (used for the merge and the output)."""

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
        """Compute the factor for a single group.

        Args:
            group_df: rows of one group. Depending on the subclass, this is
                one ``ts_code`` or one ``trade_date``.
            **kwargs: forwarded unchanged from :meth:`apply`.

        Returns:
            A Series with one value per row of ``group_df``, in the same row
            order. Polars has no row index, so alignment is purely
            positional. Name it ``get_factor_name()``: after the groups are
            recombined, the factor column is looked up by that name.
        """

    # ---------- Public interface ----------
    def apply(self,
              df: pl.DataFrame,
              return_type: Literal['df', 'series'] = 'df',
              **kwargs) -> pl.DataFrame | pl.Series:
        """Compute the factor and attach it to ``df``.

        Pipeline: validate columns -> grouped computation
        (:meth:`_sectional_roll`) -> left join onto ``df``
        (:meth:`_merge_factor`).

        Args:
            df: input frame. It must contain ``required_columns`` plus the
                join keys ``ts_code`` and ``trade_date``.
            return_type: ``'df'`` returns a new frame with the factor column
                appended. ``'series'`` returns only the factor column, taken
                from that merged frame.
            **kwargs: forwarded to :meth:`_sectional_roll`, which passes them
                on to :meth:`calc_factor`.

        Raises:
            ValueError: if required or key columns are missing, or if
                ``return_type`` is neither ``'df'`` nor ``'series'``.
        """
        if return_type not in ('df', 'series'):
            raise ValueError(f"return_type must be 'df' or 'series', got {return_type!r}")

        present = set(df.columns)
        # dict.fromkeys de-duplicates while keeping order (keys may also be required columns).
        needed = dict.fromkeys([*self.required_columns, *self._KEY_COLUMNS])
        missing = [col for col in needed if col not in present]
        if missing or not self.validate_input(df):
            # Report only the absent columns. Fall back to the full list when an
            # overridden validate_input rejects df for some other reason.
            raise ValueError(f"缺少必需列:{missing or self.required_columns}")

        long_table = self._sectional_roll(df, **kwargs)   # 1) grouped computation
        merged = self._merge_factor(df, long_table)       # 2) join back onto df
        if return_type == 'df':
            return merged
        return merged[self.get_factor_name()]

    # ---------- Internal pipeline ----------
    def validate_input(self, df: pl.DataFrame) -> bool:
        """Return True if every column in ``required_columns`` is present in ``df``."""
        return all(col in df.columns for col in self.required_columns)

    @abstractmethod
    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Grouping template: group -> :meth:`calc_factor` -> recombine.

        Must return a long table containing the columns
        ``[trade_date, ts_code, <factor_name>]``.
        """

    def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
        """Left-join the factor column onto ``original`` by ``[ts_code, trade_date]``.

        Returns a new frame; ``original`` is not modified. If ``original``
        already holds a column with the factor's name, that column is
        replaced. Without the drop, polars would keep the stale column and
        add the new values as ``<name>_right``. ``apply(...,
        return_type='series')`` would then return the stale values.
        """
        factor_name = self.get_factor_name()
        keys = list(self._KEY_COLUMNS)
        if factor_name in original.columns:
            original = original.drop(factor_name)
        # NOTE(review): newer polars releases expose `maintain_order` on join to
        # guarantee left-row order; confirm the ordering behavior of the pinned version.
        return original.join(
            factor_table.select([*keys, factor_name]),
            on=keys,
            how='left',
        )
# -------------------- Stock-wise: group by ts_code --------------------
class StockWiseOperator(BaseOperator):
    """Abstract per-stock operator.

    Groups rows by ``ts_code`` and calls :meth:`calc_factor` once per stock.
    Each stock's rows are passed sorted by ``trade_date``, because time
    order matters for shift/rolling style computations.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run :meth:`calc_factor` per ``ts_code`` and return a long table.

        Args:
            df: input frame containing ``ts_code`` and ``trade_date``.
            **kwargs: forwarded unchanged to :meth:`calc_factor`.

        Returns:
            A frame with columns ``[ts_code, trade_date, <factor_name>]``,
            ordered by ``ts_code`` then ``trade_date``.
        """
        factor_name = self.get_factor_name()
        # Sort first: every group must be in time order for shift/rolling logic.
        df_sorted = df.sort(['ts_code', 'trade_date'])

        if df_sorted.is_empty():
            # Nothing to compute: skip map_groups and return an empty long
            # table that still has the columns the merge step selects.
            return df_sorted.select(['ts_code', 'trade_date']).with_columns(
                pl.lit(None).alias(factor_name)
            )

        def attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            values = self.calc_factor(group_df, **kwargs)
            # Force the expected column name so the select below always finds it,
            # whatever name calc_factor gave its Series/Expr.
            if isinstance(values, (pl.Series, pl.Expr)):
                values = values.alias(factor_name)
            return group_df.with_columns(values)

        return (
            df_sorted
            .group_by('ts_code', maintain_order=True)
            .map_groups(attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )
# -------------------- Date-wise: group by trade_date --------------------
class DateWiseOperator(BaseOperator):
    """Abstract per-date (cross-sectional) operator.

    Groups rows by ``trade_date`` and calls :meth:`calc_factor` once per
    date. Each date's rows are passed sorted by ``ts_code``.
    """

    def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
        """Run :meth:`calc_factor` per ``trade_date`` and return a long table.

        Args:
            df: input frame containing ``ts_code`` and ``trade_date``.
            **kwargs: forwarded unchanged to :meth:`calc_factor`.

        Returns:
            A frame with columns ``[ts_code, trade_date, <factor_name>]``,
            ordered by ``trade_date`` then ``ts_code``.
        """
        factor_name = self.get_factor_name()
        df_sorted = df.sort(['trade_date', 'ts_code'])

        if df_sorted.is_empty():
            # Nothing to compute: skip map_groups and return an empty long
            # table that still has the columns the merge step selects.
            return df_sorted.select(['ts_code', 'trade_date']).with_columns(
                pl.lit(None).alias(factor_name)
            )

        def attach_factor(group_df: pl.DataFrame) -> pl.DataFrame:
            values = self.calc_factor(group_df, **kwargs)
            # Force the expected column name so the select below always finds it,
            # whatever name calc_factor gave its Series/Expr.
            if isinstance(values, (pl.Series, pl.Expr)):
                values = values.alias(factor_name)
            return group_df.with_columns(values)

        return (
            df_sorted
            .group_by('trade_date', maintain_order=True)
            .map_groups(attach_factor)
            .select(['ts_code', 'trade_date', factor_name])
        )