Files
NewStock/main/factor/operator_framework.py
2025-11-29 00:23:12 +08:00

190 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
因子算子框架 - Polars 实现(最终精简版)
- 因子自行生成 ID
- parameters 仅含计算参数(不含因子引用)
- required_factor_ids 是因子ID字符串列表
- calc_factor 通过 self.parameters 和 self.required_factor_ids 获取所需信息
"""
import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Dict, List, Literal

import polars as pl
def _normalize_params(params: Dict[str, Any]) -> str:
if not params:
return ""
return json.dumps(sorted(params.items()), separators=(",", ":"))
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
"""
生成简洁因子ID如:
("sma", {"window": 5}) → "sma_5"
("return", {"days": 20}) → "return_20"
("rank", {"input": "sma_5"}) → "rank_sma_5"
要求: params 的值必须是简单类型str/int/float/bool
"""
if not params:
return name
# 提取所有参数值,按 key 排序保证一致性
parts = []
for k in sorted(params.keys()):
v = params[k]
if isinstance(v, (str, int, float, bool)):
# 布尔转小写字符串
if isinstance(v, bool):
v = str(v).lower()
parts.append(str(v))
else:
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
f"Only str/int/float/bool allowed for simple ID.")
return f"{name}_{'_'.join(parts)}"
class BaseFactor(ABC):
    """Abstract base class shared by all factor operators.

    Each factor carries:

    * ``name`` - base name, used as the prefix of the factor ID;
    * ``parameters`` - computation parameters only (no references to other
      factors); their values are folded into the factor ID;
    * ``required_factor_ids`` - IDs of the factor columns this factor reads;
    * ``factor_id`` - derived once in ``__init__`` by ``_generate_factor_id``.

    Concrete subclasses implement ``calc_factor`` and ``operator_type``.
    """

    def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
        # Arguments are stored as-is (no defensive copy). They must be set
        # before the ID is derived, because _generate_factor_id reads them.
        self.name, self.parameters, self.required_factor_ids = (
            name,
            parameters,
            required_factor_ids,
        )
        self.factor_id = self._generate_factor_id()

    def get_factor_id(self) -> str:
        """Return the ID that was computed at construction time."""
        return self.factor_id

    def _generate_factor_id(self) -> str:
        """Derive this factor's ID from ``name`` and ``parameters``.

        Called from ``__init__`` through ``self``, so an override in a
        subclass is used instead.
        """
        factor_name, factor_params = self.name, self.parameters
        return _simple_factor_id(factor_name, factor_params)

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute the factor values for one group of rows.

        Implementations read what they need from ``self.parameters`` and the
        columns named in ``self.required_factor_ids``. How rows are grouped,
        and how the result is attached, is decided by the concrete subclass.
        NOTE(review): presumably one value per row of *group_df* is expected.
        """
        ...

    @property
    @abstractmethod
    def operator_type(self) -> Literal["stock", "date"]:
        """``"stock"`` for per-stock operators, ``"date"`` for per-date operators."""
        ...
class StockWiseFactor(BaseFactor):
    """Factor evaluated separately on each stock's history.

    Rows are grouped by ``ts_code``, sorted by ``trade_date`` within each
    group, and ``calc_factor`` is called once per group. The result is
    left-joined back onto the input frame as a column named ``factor_id``.
    """

    @property
    def operator_type(self) -> Literal["stock"]:
        return "stock"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Apply ``calc_factor`` to every ``ts_code`` group.

        Returns:
            A frame with exactly the columns ``ts_code``, ``trade_date`` and
            ``factor_id``.
        """
        fid = self.factor_id

        def _per_stock(group: pl.DataFrame) -> pl.DataFrame:
            # alias() pins the output column name to factor_id. Without it the
            # select below raises unless calc_factor happened to name its
            # result exactly factor_id.
            return group.with_columns(self.calc_factor(group).alias(fid))

        return (
            df.sort(["ts_code", "trade_date"])
            .group_by("ts_code", maintain_order=True)
            .map_groups(_per_stock)
            .select(["ts_code", "trade_date", fid])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return *df* with this factor's column left-joined on ``(ts_code, trade_date)``.

        Raises:
            ValueError: if any of ``required_factor_ids``, or either key column
                ``ts_code`` / ``trade_date``, is not a column of *df*.
        """
        columns = set(df.columns)
        missing = [fid for fid in self.required_factor_ids if fid not in columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        missing_keys = [key for key in ("ts_code", "trade_date") if key not in columns]
        if missing_keys:
            raise ValueError(f"Missing key columns for {self.factor_id}: {missing_keys}")
        # _sectional_roll already restricts its output to key + factor columns.
        long_table = self._sectional_roll(df)
        return df.join(long_table, on=["ts_code", "trade_date"], how="left")
class DateWiseFactor(BaseFactor):
    """Factor evaluated cross-sectionally on each trading date.

    Rows are grouped by ``trade_date``, sorted by ``ts_code`` within each
    group, and ``calc_factor`` is called once per group. The result is
    left-joined back onto the input frame as a column named ``factor_id``.
    """

    @property
    def operator_type(self) -> Literal["date"]:
        return "date"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Apply ``calc_factor`` to every ``trade_date`` group.

        Returns:
            A frame with exactly the columns ``ts_code``, ``trade_date`` and
            ``factor_id``.
        """
        fid = self.factor_id

        def _per_date(group: pl.DataFrame) -> pl.DataFrame:
            # alias() pins the output column name to factor_id. Without it the
            # select below raises unless calc_factor happened to name its
            # result exactly factor_id.
            return group.with_columns(self.calc_factor(group).alias(fid))

        return (
            df.sort(["trade_date", "ts_code"])
            .group_by("trade_date", maintain_order=True)
            .map_groups(_per_date)
            .select(["ts_code", "trade_date", fid])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return *df* with this factor's column left-joined on ``(ts_code, trade_date)``.

        Raises:
            ValueError: if any of ``required_factor_ids``, or either key column
                ``ts_code`` / ``trade_date``, is not a column of *df*.
        """
        columns = set(df.columns)
        missing = [fid for fid in self.required_factor_ids if fid not in columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        missing_keys = [key for key in ("ts_code", "trade_date") if key not in columns]
        if missing_keys:
            raise ValueError(f"Missing key columns for {self.factor_id}: {missing_keys}")
        # _sectional_roll already restricts its output to key + factor columns.
        long_table = self._sectional_roll(df)
        return df.join(long_table, on=["ts_code", "trade_date"], how="left")
class FactorGraph:
    """Registry of factors that evaluates requested factors in dependency order.

    Factors are registered by ``factor_id`` with :meth:`add_factor`.
    :meth:`compute` collects the transitive dependencies of the requested
    targets, orders the registered ones topologically and applies them to the
    frame one by one. Dependency IDs that are not registered factors are
    treated as columns that must already exist in the input frame.
    """

    # Diagnostics go through logging (DEBUG level) instead of print(), so
    # callers of this library decide whether they are shown.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._factors: Dict[str, BaseFactor] = {}  # factor_id -> factor

    def add_factor(self, factor: BaseFactor):
        """Register *factor* under its ``factor_id``.

        Raises:
            ValueError: if a factor with the same ID is already registered.
        """
        fid = factor.get_factor_id()
        if fid in self._factors:
            raise ValueError(f"Factor '{fid}' already registered.")
        self._factors[fid] = factor

    def _collect_reachable(self, target_ids: List[str]) -> set:
        """Return every ID reachable from *target_ids* via ``required_factor_ids`` (BFS)."""
        reachable = set()
        queue = deque(target_ids)
        while queue:
            fid = queue.popleft()
            if fid in reachable:
                continue
            reachable.add(fid)
            factor = self._factors.get(fid)
            if factor is not None:
                queue.extend(dep for dep in factor.required_factor_ids if dep not in reachable)
        return reachable

    def _log_graph(self, to_compute: List[str], to_compute_set: set) -> None:
        """Log each factor with its computable dependencies, at DEBUG level."""
        if not self._logger.isEnabledFor(logging.DEBUG):
            return
        lines = [
            f"{fid} -> {[d for d in self._factors[fid].required_factor_ids if d in to_compute_set]}"
            for fid in to_compute
        ]
        self._logger.debug("Factor dependency graph:\n%s", "\n".join(lines))

    def _topological_sort(self, target_ids: List[str]) -> List[str]:
        """Return the registered factor IDs needed for *target_ids*, dependencies first.

        Uses Kahn's algorithm over the registered factors reachable from the
        targets. Candidates are processed in sorted-ID order, so the result is
        deterministic (it does not depend on set/hash iteration order).

        Raises:
            RuntimeError: if the involved registered factors contain a cycle.
        """
        reachable = self._collect_reachable(target_ids)
        # Only registered factors are computed; other IDs are raw input columns.
        to_compute = sorted(fid for fid in reachable if fid in self._factors)
        to_compute_set = set(to_compute)

        indegree = {fid: 0 for fid in to_compute}
        dependents = defaultdict(list)  # dependency -> factors that need it
        for fid in to_compute:
            for dep in self._factors[fid].required_factor_ids:
                if dep in to_compute_set:
                    dependents[dep].append(fid)
                    indegree[fid] += 1

        ready = deque(fid for fid in to_compute if indegree[fid] == 0)
        order = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for nb in dependents[node]:
                indegree[nb] -= 1
                if indegree[nb] == 0:
                    ready.append(nb)

        self._log_graph(to_compute, to_compute_set)

        if len(order) != len(to_compute):
            # Factors never reaching indegree 0 are in, or depend on, a cycle.
            unresolved = sorted(to_compute_set.difference(order))
            raise RuntimeError(f"Circular dependency! Unresolved factors: {unresolved}")
        return order

    def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame:
        """Return a copy of *df* with the target factors (and their dependencies) added.

        Factors whose column already exists in the frame are skipped rather
        than recomputed.

        Raises:
            ValueError: if a target is neither a registered factor nor a column
                of *df*. ``apply`` also raises it when a dependency column is
                missing.
            RuntimeError: on circular dependencies.
        """
        available = set(df.columns)
        unknown = [
            fid for fid in target_factor_ids
            if fid not in self._factors and fid not in available
        ]
        if unknown:
            # Previously such targets were silently dropped from the result.
            raise ValueError(
                f"Unknown target factors (neither registered nor columns of df): {unknown}"
            )

        exec_order = self._topological_sort(target_factor_ids)
        current_df = df.clone()
        for fid in exec_order:
            if fid in current_df.columns:
                self._logger.debug("Skipping %s: column already present", fid)
                continue
            self._logger.debug("Computing %s", fid)
            current_df = self._factors[fid].apply(current_df)
        return current_df