Files
NewStock/main/factor/operator_framework.py

190 lines
6.4 KiB
Python
Raw Normal View History

2025-10-13 21:42:35 +08:00
"""
2025-11-29 00:23:12 +08:00
因子算子框架 - Polars 实现最终精简版
- 因子自行生成 ID
- parameters 仅含计算参数不含因子引用
- required_factor_ids 是因子ID字符串列表
- calc_factor 通过 self.parameters self.required_factor_ids 获取所需信息
2025-10-13 21:42:35 +08:00
"""
from abc import ABC, abstractmethod
2025-11-29 00:23:12 +08:00
from typing import List, Literal, Dict, Any
from collections import defaultdict, deque
import json
2025-10-14 09:44:46 +08:00
import polars as pl
2025-10-13 21:42:35 +08:00
2025-11-29 00:23:12 +08:00
def _normalize_params(params: Dict[str, Any]) -> str:
if not params:
return ""
return json.dumps(sorted(params.items()), separators=(",", ":"))
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
"""
生成简洁因子ID:
("sma", {"window": 5}) "sma_5"
("return", {"days": 20}) "return_20"
("rank", {"input": "sma_5"}) "rank_sma_5"
要求: params 的值必须是简单类型str/int/float/bool
"""
if not params:
return name
# 提取所有参数值,按 key 排序保证一致性
parts = []
for k in sorted(params.keys()):
v = params[k]
if isinstance(v, (str, int, float, bool)):
# 布尔转小写字符串
if isinstance(v, bool):
v = str(v).lower()
parts.append(str(v))
else:
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
f"Only str/int/float/bool allowed for simple ID.")
return f"{name}_{'_'.join(parts)}"
class BaseFactor(ABC):
    """Abstract base class for all factor operators.

    A factor is identified by ``factor_id``, which is derived once at
    construction from ``name`` and ``parameters`` via ``_simple_factor_id``
    (e.g. ``("sma", {"window": 5})`` -> ``"sma_5"``). ``parameters`` holds only
    calculation parameters. The ids of the factors/columns this factor reads
    are listed separately in ``required_factor_ids``.

    Subclasses implement ``calc_factor`` (compute the factor for one group of
    rows) and ``operator_type`` (``"stock"`` or ``"date"``; in this module
    ``StockWiseFactor`` groups by ``ts_code`` and ``DateWiseFactor`` by
    ``trade_date``).
    """

    def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
        """Create a factor and derive its id.

        Args:
            name: Base name of the factor; the first component of the id.
            parameters: Calculation parameters. Values must be
                str/int/float/bool (see ``_simple_factor_id``). ``None`` is
                treated as no parameters.
            required_factor_ids: Ids of columns that must already exist in
                the input frame before this factor can be computed. ``None``
                is treated as no dependencies.

        Raises:
            ValueError: if a parameter value has an unsupported type.
        """
        self.name = name
        # Defensive copies: factor_id is derived only once (below), so keeping
        # the caller's containers would let later mutations make the
        # parameters used by calc_factor disagree with the factor's id.
        self.parameters = dict(parameters or {})
        self.required_factor_ids = list(required_factor_ids or [])
        self.factor_id = self._generate_factor_id()

    def _generate_factor_id(self) -> str:
        """Return the id for this factor, built from ``name`` and ``parameters``."""
        return _simple_factor_id(self.name, self.parameters)

    def get_factor_id(self) -> str:
        """Return the id computed at construction time."""
        return self.factor_id

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(factor_id={self.factor_id!r}, "
                f"required_factor_ids={self.required_factor_ids!r})")

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute the factor for one group of rows.

        The returned Series should have one value per row of ``group_df`` and
        should be named ``self.factor_id``.
        """
        pass

    @property
    @abstractmethod
    def operator_type(self) -> Literal["stock", "date"]:
        """Grouping mode of the factor: ``"stock"`` or ``"date"``."""
        pass
class StockWiseFactor(BaseFactor):
    """Factor computed per stock.

    Rows are sorted by ``(ts_code, trade_date)``, grouped by ``ts_code``, and
    each group is passed to ``calc_factor``.
    """

    @property
    def operator_type(self) -> Literal["stock"]:
        """Always ``"stock"``: groups are formed by ``ts_code``."""
        return "stock"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Evaluate ``calc_factor`` on every ``ts_code`` group.

        Returns:
            A long table with columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``.
        """
        def _per_group(group: pl.DataFrame) -> pl.DataFrame:
            # Force the output column name. Otherwise a calc_factor that forgets
            # to name its Series either breaks the select below or silently
            # selects a pre-existing column with the factor's name.
            return group.with_columns(self.calc_factor(group).alias(self.factor_id))

        return (
            df.sort(["ts_code", "trade_date"])
            .group_by("ts_code", maintain_order=True)
            .map_groups(_per_group)
            .select(["ts_code", "trade_date", self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with this factor's column left-joined on (ts_code, trade_date).

        Raises:
            ValueError: if any id in ``required_factor_ids`` is not a column of ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        long_table = self._sectional_roll(df)
        # Drop a stale copy of this factor's column. Otherwise the join keeps the
        # old values under the expected name and adds the fresh ones as "<id>_right".
        base = df.drop(self.factor_id) if self.factor_id in df.columns else df
        return base.join(
            long_table.select(["ts_code", "trade_date", self.factor_id]),
            on=["ts_code", "trade_date"],
            how="left"
        )
class DateWiseFactor(BaseFactor):
    """Factor computed cross-sectionally per trading date.

    Rows are sorted by ``(trade_date, ts_code)``, grouped by ``trade_date``,
    and each group is passed to ``calc_factor``.
    """

    @property
    def operator_type(self) -> Literal["date"]:
        """Always ``"date"``: groups are formed by ``trade_date``."""
        return "date"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Evaluate ``calc_factor`` on every ``trade_date`` group.

        Returns:
            A long table with columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``.
        """
        def _per_group(group: pl.DataFrame) -> pl.DataFrame:
            # Force the output column name. Otherwise a calc_factor that forgets
            # to name its Series either breaks the select below or silently
            # selects a pre-existing column with the factor's name.
            return group.with_columns(self.calc_factor(group).alias(self.factor_id))

        return (
            df.sort(["trade_date", "ts_code"])
            .group_by("trade_date", maintain_order=True)
            .map_groups(_per_group)
            .select(["ts_code", "trade_date", self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with this factor's column left-joined on (ts_code, trade_date).

        Raises:
            ValueError: if any id in ``required_factor_ids`` is not a column of ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        long_table = self._sectional_roll(df)
        # Drop a stale copy of this factor's column. Otherwise the join keeps the
        # old values under the expected name and adds the fresh ones as "<id>_right".
        base = df.drop(self.factor_id) if self.factor_id in df.columns else df
        return base.join(
            long_table.select(["ts_code", "trade_date", self.factor_id]),
            on=["ts_code", "trade_date"],
            how="left"
        )
class FactorGraph:
    """Registry of factors that computes requested ids in dependency order.

    Factors are registered under their ``factor_id``. ``compute`` expands
    the requested targets to every registered transitive dependency, orders
    them so that dependencies run first, and applies them one after another.
    Ids that are not registered (e.g. raw input columns) are treated as
    external inputs and must already be columns of the input frame.
    """

    def __init__(self):
        self._factors: Dict[str, BaseFactor] = {}  # factor_id -> factor

    def add_factor(self, factor: BaseFactor):
        """Register ``factor`` under its ``factor_id``.

        Raises:
            ValueError: if a factor with the same id is already registered.
        """
        fid = factor.get_factor_id()
        if fid in self._factors:
            raise ValueError(f"Factor '{fid}' already registered.")
        self._factors[fid] = factor

    def _topological_sort(self, target_ids: List[str]) -> List[str]:
        """Return the registered factor ids needed for ``target_ids``, dependencies first.

        Unregistered ids are traversed as leaves but never appear in the
        result. The order is deterministic: ties are broken by sorted id.
        Also prints the dependency graph of the registered factors involved.

        Raises:
            RuntimeError: if the registered factors involved contain a cycle.
        """
        # 1) Transitive closure of the targets over registered dependencies.
        reachable = set()
        pending = deque(target_ids)
        while pending:
            fid = pending.popleft()
            if fid in reachable:
                continue
            reachable.add(fid)
            factor = self._factors.get(fid)
            if factor is not None:
                pending.extend(dep for dep in factor.required_factor_ids if dep not in reachable)

        # Iterate in sorted order: iterating a set of str follows hash order,
        # which changes between interpreter runs and would make the execution
        # order (and the result's column order) non-reproducible.
        to_compute = sorted(fid for fid in reachable if fid in self._factors)
        to_compute_set = set(to_compute)

        # 2) Kahn's algorithm restricted to the registered factors.
        indegree = {fid: 0 for fid in to_compute}
        dependents = defaultdict(list)
        for fid in to_compute:
            for dep in self._factors[fid].required_factor_ids:
                if dep in to_compute_set:
                    dependents[dep].append(fid)
                    indegree[fid] += 1
        ready = deque(fid for fid in to_compute if indegree[fid] == 0)
        order = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for nb in dependents[node]:
                indegree[nb] -= 1
                if indegree[nb] == 0:
                    ready.append(nb)

        print("\n=== Factor Dependency Graph ===")
        for fid in to_compute:
            deps = self._factors[fid].required_factor_ids
            compute_deps = [d for d in deps if d in to_compute_set]  # only show computable deps
            print(f"{fid} -> {compute_deps}")
        print("================================\n")

        if len(order) != len(to_compute):
            # Factors left with a positive in-degree are on, or downstream of, a cycle.
            unresolved = [fid for fid in to_compute if indegree[fid] > 0]
            raise RuntimeError(
                f"Circular dependency! {len(unresolved)} of {len(to_compute)} factors "
                f"could not be ordered (in or downstream of a cycle): {unresolved}"
            )
        return order

    def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame:
        """Compute ``target_factor_ids`` (and their dependencies) on a copy of ``df``.

        Factors whose column already exists are skipped. Their registered
        dependencies are still part of the execution order. Prints each
        factor id as it is visited.

        Raises:
            ValueError: if a target is neither registered nor a column of ``df``,
                or if a factor's dependencies are missing when it is applied.
            RuntimeError: if the required factors form a cycle.
        """
        # Without this check an unknown target (e.g. a typo) would silently be
        # absent from the result.
        unknown = [fid for fid in target_factor_ids
                   if fid not in self._factors and fid not in df.columns]
        if unknown:
            raise ValueError(f"Unknown target factors (not registered and not in input columns): {unknown}")
        exec_order = self._topological_sort(target_factor_ids)
        current_df = df.clone()
        for fid in exec_order:
            print(fid)
            if fid in current_df.columns:
                continue
            current_df = self._factors[fid].apply(current_df)
        return current_df