1、策略更新
2、新增qmt
This commit is contained in:
@@ -1,124 +1,190 @@
|
||||
"""
|
||||
因子算子框架 - Polars 实现
|
||||
支持:截面滚动 → 拼回长表 → 按列名合并
|
||||
返回形式可选:完整 DataFrame(默认)或单列 Series
|
||||
因子算子框架 - Polars 实现(最终精简版)
|
||||
- 因子自行生成 ID
|
||||
- parameters 仅含计算参数(不含因子引用)
|
||||
- required_factor_ids 是因子ID字符串列表
|
||||
- calc_factor 通过 self.parameters 和 self.required_factor_ids 获取所需信息
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Literal
|
||||
|
||||
from typing import List, Literal, Dict, Any
|
||||
from collections import defaultdict, deque
|
||||
import json
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperatorConfig:
|
||||
"""算子配置"""
|
||||
name: str
|
||||
description: str
|
||||
required_columns: List[str]
|
||||
output_columns: List[str]
|
||||
parameters: dict
|
||||
def _normalize_params(params: Dict[str, Any]) -> str:
|
||||
if not params:
|
||||
return ""
|
||||
return json.dumps(sorted(params.items()), separators=(",", ":"))
|
||||
|
||||
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
|
||||
"""
|
||||
生成简洁因子ID,如:
|
||||
("sma", {"window": 5}) → "sma_5"
|
||||
("return", {"days": 20}) → "return_20"
|
||||
("rank", {"input": "sma_5"}) → "rank_sma_5"
|
||||
|
||||
要求: params 的值必须是简单类型(str/int/float/bool)
|
||||
"""
|
||||
if not params:
|
||||
return name
|
||||
|
||||
# 提取所有参数值,按 key 排序保证一致性
|
||||
parts = []
|
||||
for k in sorted(params.keys()):
|
||||
v = params[k]
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
# 布尔转小写字符串
|
||||
if isinstance(v, bool):
|
||||
v = str(v).lower()
|
||||
parts.append(str(v))
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
|
||||
f"Only str/int/float/bool allowed for simple ID.")
|
||||
|
||||
return f"{name}_{'_'.join(parts)}"
|
||||
|
||||
|
||||
class BaseOperator(ABC):
|
||||
"""算子基类"""
|
||||
class BaseFactor(ABC):
|
||||
def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
|
||||
self.name = name
|
||||
self.parameters = parameters
|
||||
self.required_factor_ids = required_factor_ids
|
||||
self.factor_id = self._generate_factor_id()
|
||||
|
||||
def __init__(self, config: OperatorConfig):
|
||||
self.config = config
|
||||
self.name = config.name
|
||||
self.required_columns = config.required_columns
|
||||
self.output_columns = config.output_columns
|
||||
def _generate_factor_id(self) -> str:
|
||||
return _simple_factor_id(self.name, self.parameters)
|
||||
|
||||
# ---------- 子类必须实现 ----------
|
||||
@abstractmethod
|
||||
def get_factor_name(self) -> str:
|
||||
"""返回因子列名(用于合并)"""
|
||||
pass
|
||||
def get_factor_id(self) -> str:
|
||||
return self.factor_id
|
||||
|
||||
@abstractmethod
|
||||
def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
|
||||
"""
|
||||
真正的截面计算逻辑。
|
||||
参数:按 ts_code 或 trade_date 分组后的子表
|
||||
返回:与 group_df 行数一一对应的因子 Series(含正确索引)
|
||||
"""
|
||||
def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
|
||||
pass
|
||||
|
||||
# ---------- 公共接口 ----------
|
||||
def apply(self,
|
||||
df: pl.DataFrame,
|
||||
return_type: Literal['df', 'series'] = 'df',
|
||||
**kwargs) -> pl.DataFrame | pl.Series:
|
||||
"""入口:截面滚动 → 拼回长表 → 合并/返回"""
|
||||
if not self.validate_input(df):
|
||||
raise ValueError(f"缺少必需列:{self.required_columns}")
|
||||
|
||||
long_table = self._sectional_roll(df, **kwargs) # ① 滚动
|
||||
merged = self._merge_factor(df, long_table) # ② 合并
|
||||
return merged if return_type == 'df' else merged[self.get_factor_name()]
|
||||
|
||||
# ---------- 内部流程 ----------
|
||||
def validate_input(self, df: pl.DataFrame) -> bool:
|
||||
return all(col in df.columns for col in self.required_columns)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
"""
|
||||
截面滚动模板:group → calc_factor → 拼回长表
|
||||
返回:含【trade_date, ts_code, factor】的长表
|
||||
"""
|
||||
def operator_type(self) -> Literal["stock", "date"]:
|
||||
pass
|
||||
|
||||
def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
|
||||
"""按 [ts_code, trade_date] 左联,原地追加因子列"""
|
||||
factor_name = self.get_factor_name()
|
||||
return original.join(factor_table.select(['ts_code', 'trade_date', factor_name]),
|
||||
on=['ts_code', 'trade_date'],
|
||||
how='left')
|
||||
|
||||
class StockWiseFactor(BaseFactor):
|
||||
@property
|
||||
def operator_type(self) -> Literal["stock"]:
|
||||
return "stock"
|
||||
|
||||
# -------------------- 股票截面:按 ts_code 分组 --------------------
|
||||
class StockWiseOperator(BaseOperator):
|
||||
"""股票切面算子抽象类:按 ts_code 分组,对每个股票的时间序列计算因子"""
|
||||
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
factor_name = self.get_factor_name()
|
||||
|
||||
# 确保排序(时间顺序对 shift 等操作至关重要)
|
||||
df_sorted = df.sort(['ts_code', 'trade_date'])
|
||||
|
||||
# 使用 map_groups:对每个 ts_code 分组,传入完整子 DataFrame
|
||||
def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
|
||||
df_sorted = df.sort(["ts_code", "trade_date"])
|
||||
result = (
|
||||
df_sorted
|
||||
.group_by('ts_code', maintain_order=True)
|
||||
.map_groups(
|
||||
lambda group_df: group_df.with_columns(
|
||||
self.calc_factor(group_df, **kwargs)
|
||||
)
|
||||
)
|
||||
.select(['ts_code', 'trade_date', factor_name])
|
||||
.group_by("ts_code", maintain_order=True)
|
||||
.map_groups(lambda g: g.with_columns(self.calc_factor(g)))
|
||||
.select(["ts_code", "trade_date", self.factor_id])
|
||||
)
|
||||
return result
|
||||
|
||||
# -------------------- 日期截面:按 trade_date 分组 --------------------
|
||||
class DateWiseOperator(BaseOperator):
|
||||
"""日期切面算子抽象类:按 trade_date 分组,对每个截面计算因子"""
|
||||
def apply(self, df: pl.DataFrame) -> pl.DataFrame:
|
||||
missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
|
||||
if missing:
|
||||
raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
|
||||
long_table = self._sectional_roll(df)
|
||||
return df.join(
|
||||
long_table.select(["ts_code", "trade_date", self.factor_id]),
|
||||
on=["ts_code", "trade_date"],
|
||||
how="left"
|
||||
)
|
||||
|
||||
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
|
||||
factor_name = self.get_factor_name()
|
||||
|
||||
df_sorted = df.sort(['trade_date', 'ts_code'])
|
||||
|
||||
|
||||
class DateWiseFactor(BaseFactor):
|
||||
@property
|
||||
def operator_type(self) -> Literal["date"]:
|
||||
return "date"
|
||||
|
||||
def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
|
||||
df_sorted = df.sort(["trade_date", "ts_code"])
|
||||
result = (
|
||||
df_sorted
|
||||
.group_by('trade_date', maintain_order=True)
|
||||
.map_groups(
|
||||
lambda group_df: group_df.with_columns(
|
||||
self.calc_factor(group_df, **kwargs)
|
||||
)
|
||||
)
|
||||
.select(['ts_code', 'trade_date', factor_name])
|
||||
.group_by("trade_date", maintain_order=True)
|
||||
.map_groups(lambda g: g.with_columns(self.calc_factor(g)))
|
||||
.select(["ts_code", "trade_date", self.factor_id])
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def apply(self, df: pl.DataFrame) -> pl.DataFrame:
|
||||
missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
|
||||
if missing:
|
||||
raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
|
||||
long_table = self._sectional_roll(df)
|
||||
return df.join(
|
||||
long_table.select(["ts_code", "trade_date", self.factor_id]),
|
||||
on=["ts_code", "trade_date"],
|
||||
how="left"
|
||||
)
|
||||
|
||||
|
||||
class FactorGraph:
|
||||
def __init__(self):
|
||||
self._factors = {} # factor_id -> factor
|
||||
|
||||
def add_factor(self, factor: BaseFactor):
|
||||
fid = factor.get_factor_id()
|
||||
if fid in self._factors:
|
||||
raise ValueError(f"Factor '{fid}' already registered.")
|
||||
self._factors[fid] = factor
|
||||
|
||||
def _topological_sort(self, target_ids: List[str]) -> List[str]:
|
||||
all_factors = set()
|
||||
queue = deque(target_ids)
|
||||
while queue:
|
||||
f = queue.popleft()
|
||||
if f not in all_factors:
|
||||
all_factors.add(f)
|
||||
if f in self._factors:
|
||||
for dep in self._factors[f].required_factor_ids:
|
||||
if dep not in all_factors:
|
||||
queue.append(dep)
|
||||
|
||||
to_compute = {f for f in all_factors if f in self._factors}
|
||||
indegree = {f: 0 for f in to_compute}
|
||||
adj = defaultdict(list)
|
||||
|
||||
for f in to_compute:
|
||||
for dep in self._factors[f].required_factor_ids:
|
||||
if dep in to_compute:
|
||||
adj[dep].append(f)
|
||||
indegree[f] += 1
|
||||
|
||||
queue = deque([f for f in to_compute if indegree[f] == 0])
|
||||
order = []
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
order.append(node)
|
||||
for nb in adj[node]:
|
||||
indegree[nb] -= 1
|
||||
if indegree[nb] == 0:
|
||||
queue.append(nb)
|
||||
|
||||
print("\n=== Factor Dependency Graph ===")
|
||||
to_compute = {f for f in all_factors if f in self._factors}
|
||||
for fid in sorted(to_compute):
|
||||
deps = self._factors[fid].required_factor_ids
|
||||
compute_deps = [d for d in deps if d in to_compute] # 只显示可计算的依赖
|
||||
print(f"{fid} -> {compute_deps}")
|
||||
print("================================\n")
|
||||
|
||||
if len(order) != len(to_compute):
|
||||
print(len(order), len(to_compute))
|
||||
raise RuntimeError("Circular dependency!")
|
||||
return order
|
||||
|
||||
def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame:
|
||||
exec_order = self._topological_sort(target_factor_ids)
|
||||
current_df = df.clone()
|
||||
for fid in exec_order:
|
||||
print(fid)
|
||||
if fid in current_df.columns:
|
||||
continue
|
||||
factor = self._factors[fid]
|
||||
current_df = factor.apply(current_df)
|
||||
return current_df
|
||||
Reference in New Issue
Block a user