""" 因子算子框架 - Polars 实现(最终精简版) - 因子自行生成 ID - parameters 仅含计算参数(不含因子引用) - required_factor_ids 是因子ID字符串列表 - calc_factor 通过 self.parameters 和 self.required_factor_ids 获取所需信息 """ from abc import ABC, abstractmethod from typing import List, Literal, Dict, Any from collections import defaultdict, deque import json import polars as pl def _normalize_params(params: Dict[str, Any]) -> str: if not params: return "" return json.dumps(sorted(params.items()), separators=(",", ":")) def _simple_factor_id(name: str, params: Dict[str, Any]) -> str: """ 生成简洁因子ID,如: ("sma", {"window": 5}) → "sma_5" ("return", {"days": 20}) → "return_20" ("rank", {"input": "sma_5"}) → "rank_sma_5" 要求: params 的值必须是简单类型(str/int/float/bool) """ if not params: return name # 提取所有参数值,按 key 排序保证一致性 parts = [] for k in sorted(params.keys()): v = params[k] if isinstance(v, (str, int, float, bool)): # 布尔转小写字符串 if isinstance(v, bool): v = str(v).lower() parts.append(str(v)) else: raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. " f"Only str/int/float/bool allowed for simple ID.") return f"{name}_{'_'.join(parts)}" class BaseFactor(ABC): def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]): self.name = name self.parameters = parameters self.required_factor_ids = required_factor_ids self.factor_id = self._generate_factor_id() def _generate_factor_id(self) -> str: return _simple_factor_id(self.name, self.parameters) def get_factor_id(self) -> str: return self.factor_id @abstractmethod def calc_factor(self, group_df: pl.DataFrame) -> pl.Series: pass @property @abstractmethod def operator_type(self) -> Literal["stock", "date"]: pass class StockWiseFactor(BaseFactor): @property def operator_type(self) -> Literal["stock"]: return "stock" def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame: df_sorted = df.sort(["ts_code", "trade_date"]) result = ( df_sorted .group_by("ts_code", maintain_order=True) .map_groups(lambda g: g.with_columns(self.calc_factor(g))) .select(["ts_code", "trade_date", self.factor_id]) ) return result def apply(self, df: pl.DataFrame) -> pl.DataFrame: missing = [fid for fid in self.required_factor_ids if fid not in df.columns] if missing: raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}") long_table = self._sectional_roll(df) return df.join( long_table.select(["ts_code", "trade_date", self.factor_id]), on=["ts_code", "trade_date"], how="left" ) class DateWiseFactor(BaseFactor): @property def operator_type(self) -> Literal["date"]: return "date" def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame: df_sorted = df.sort(["trade_date", "ts_code"]) result = ( df_sorted .group_by("trade_date", maintain_order=True) .map_groups(lambda g: g.with_columns(self.calc_factor(g))) .select(["ts_code", "trade_date", self.factor_id]) ) return result def apply(self, df: pl.DataFrame) -> pl.DataFrame: missing = [fid for fid in self.required_factor_ids if fid not in df.columns] if missing: raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}") long_table = self._sectional_roll(df) return df.join( long_table.select(["ts_code", "trade_date", self.factor_id]), on=["ts_code", "trade_date"], how="left" ) class FactorGraph: def __init__(self): self._factors = {} # factor_id -> factor def add_factor(self, factor: BaseFactor): fid = factor.get_factor_id() if fid in self._factors: raise ValueError(f"Factor '{fid}' already registered.") self._factors[fid] = factor def _topological_sort(self, target_ids: List[str]) -> List[str]: all_factors = set() queue = deque(target_ids) while queue: f = queue.popleft() if f not in all_factors: all_factors.add(f) if f in self._factors: for dep in self._factors[f].required_factor_ids: if dep not in all_factors: queue.append(dep) to_compute = {f for f in all_factors if f in self._factors} indegree = {f: 0 for f in to_compute} adj = defaultdict(list) for f in to_compute: for dep in self._factors[f].required_factor_ids: if dep in to_compute: adj[dep].append(f) indegree[f] += 1 queue = deque([f for f in to_compute if indegree[f] == 0]) order = [] while queue: node = queue.popleft() order.append(node) for nb in adj[node]: indegree[nb] -= 1 if indegree[nb] == 0: queue.append(nb) print("\n=== Factor Dependency Graph ===") to_compute = {f for f in all_factors if f in self._factors} for fid in sorted(to_compute): deps = self._factors[fid].required_factor_ids compute_deps = [d for d in deps if d in to_compute] # 只显示可计算的依赖 print(f"{fid} -> {compute_deps}") print("================================\n") if len(order) != len(to_compute): print(len(order), len(to_compute)) raise RuntimeError("Circular dependency!") return order def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame: exec_order = self._topological_sort(target_factor_ids) current_df = df.clone() for fid in exec_order: print(fid) if fid in current_df.columns: continue factor = self._factors[fid] current_df = factor.apply(current_df) return current_df