2025-10-13 21:42:35 +08:00
|
|
|
|
"""
|
2025-11-29 00:23:12 +08:00
|
|
|
|
因子算子框架 - Polars 实现(最终精简版)
|
|
|
|
|
|
- 因子自行生成 ID
|
|
|
|
|
|
- parameters 仅含计算参数(不含因子引用)
|
|
|
|
|
|
- required_factor_ids 是因子ID字符串列表
|
|
|
|
|
|
- calc_factor 通过 self.parameters 和 self.required_factor_ids 获取所需信息
|
2025-10-13 21:42:35 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any, Dict, List, Literal

import polars as pl
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-11-29 00:23:12 +08:00
|
|
|
|
def _normalize_params(params: Dict[str, Any]) -> str:
|
|
|
|
|
|
if not params:
|
|
|
|
|
|
return ""
|
|
|
|
|
|
return json.dumps(sorted(params.items()), separators=(",", ":"))
|
|
|
|
|
|
|
|
|
|
|
|
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
|
|
|
|
|
|
"""
|
|
|
|
|
|
生成简洁因子ID,如:
|
|
|
|
|
|
("sma", {"window": 5}) → "sma_5"
|
|
|
|
|
|
("return", {"days": 20}) → "return_20"
|
|
|
|
|
|
("rank", {"input": "sma_5"}) → "rank_sma_5"
|
|
|
|
|
|
|
|
|
|
|
|
要求: params 的值必须是简单类型(str/int/float/bool)
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not params:
|
|
|
|
|
|
return name
|
|
|
|
|
|
|
|
|
|
|
|
# 提取所有参数值,按 key 排序保证一致性
|
|
|
|
|
|
parts = []
|
|
|
|
|
|
for k in sorted(params.keys()):
|
|
|
|
|
|
v = params[k]
|
|
|
|
|
|
if isinstance(v, (str, int, float, bool)):
|
|
|
|
|
|
# 布尔转小写字符串
|
|
|
|
|
|
if isinstance(v, bool):
|
|
|
|
|
|
v = str(v).lower()
|
|
|
|
|
|
parts.append(str(v))
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
|
|
|
|
|
|
f"Only str/int/float/bool allowed for simple ID.")
|
|
|
|
|
|
|
|
|
|
|
|
return f"{name}_{'_'.join(parts)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseFactor(ABC):
    """Abstract base class for a single factor operator.

    An instance stores its name, its computation parameters and the IDs of
    the factors it depends on, and derives its own ``factor_id`` from the
    name and parameters at construction time.
    """

    def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
        """Store the configuration and derive ``factor_id``.

        Args:
            name: Base name of the factor; used as the ID prefix.
            parameters: Computation parameters passed to the ID generator.
            required_factor_ids: IDs this factor depends on.
        """
        self.name = name
        self.parameters = parameters
        self.required_factor_ids = required_factor_ids
        # Derived last: the generator reads ``name`` and ``parameters``.
        self.factor_id = self._generate_factor_id()

    def _generate_factor_id(self) -> str:
        """Return the ID built from ``self.name`` and ``self.parameters``."""
        generated = _simple_factor_id(self.name, self.parameters)
        return generated

    def get_factor_id(self) -> str:
        """Return the ID computed when the factor was constructed."""
        return self.factor_id

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute this factor's values for ``group_df``; implemented by subclasses."""

    @property
    @abstractmethod
    def operator_type(self) -> Literal["stock", "date"]:
        """Kind of operator, either ``"stock"`` or ``"date"``; implemented by subclasses."""
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-11-29 00:23:12 +08:00
|
|
|
|
class StockWiseFactor(BaseFactor):
    """Factor evaluated one ``ts_code`` group at a time.

    ``calc_factor`` is called once per ``ts_code`` with that group's rows
    ordered by ``trade_date`` (presumably one stock's time series; confirm
    against the input schema).
    """

    @property
    def operator_type(self) -> Literal["stock"]:
        """Identify this operator as the ``"stock"`` kind."""
        return "stock"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute the factor for every ``ts_code`` group.

        Args:
            df: Input frame containing ``ts_code``, ``trade_date`` and every
                column ``calc_factor`` reads.

        Returns:
            A frame with exactly the columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``, sorted by ``ts_code`` then ``trade_date``.
        """
        keys = ["ts_code", "trade_date"]

        def _with_factor(group: pl.DataFrame) -> pl.DataFrame:
            # Force the result column to be named after the factor ID. A
            # Series derived from an input column keeps that column's name,
            # which would overwrite the input and break the select below.
            return group.with_columns(self.calc_factor(group).alias(self.factor_id))

        return (
            df.sort(keys)
            .group_by("ts_code", maintain_order=True)
            .map_groups(_with_factor)
            .select([*keys, self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with this factor's column attached.

        Args:
            df: Input frame; must already contain every column listed in
                ``self.required_factor_ids``.

        Returns:
            ``df`` left-joined with the computed factor on
            ``(ts_code, trade_date)``, so the original row order is kept.

        Raises:
            ValueError: if any required factor column is absent from ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        # _sectional_roll already returns only the join keys and the factor.
        factor_table = self._sectional_roll(df)
        return df.join(factor_table, on=["ts_code", "trade_date"], how="left")
|
|
|
|
|
|
|
2025-10-13 21:42:35 +08:00
|
|
|
|
|
2025-11-29 00:23:12 +08:00
|
|
|
|
class DateWiseFactor(BaseFactor):
    """Factor evaluated one ``trade_date`` group at a time.

    ``calc_factor`` is called once per ``trade_date`` with that group's rows
    ordered by ``ts_code`` (presumably a cross-section of all stocks on one
    day; confirm against the input schema).
    """

    @property
    def operator_type(self) -> Literal["date"]:
        """Identify this operator as the ``"date"`` kind."""
        return "date"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute the factor for every ``trade_date`` group.

        Args:
            df: Input frame containing ``ts_code``, ``trade_date`` and every
                column ``calc_factor`` reads.

        Returns:
            A frame with exactly the columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``, sorted by ``trade_date`` then ``ts_code``.
        """

        def _with_factor(group: pl.DataFrame) -> pl.DataFrame:
            # Force the result column to be named after the factor ID. A
            # Series derived from an input column keeps that column's name,
            # which would overwrite the input and break the select below.
            return group.with_columns(self.calc_factor(group).alias(self.factor_id))

        return (
            df.sort(["trade_date", "ts_code"])
            .group_by("trade_date", maintain_order=True)
            .map_groups(_with_factor)
            .select(["ts_code", "trade_date", self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Return ``df`` with this factor's column attached.

        Args:
            df: Input frame; must already contain every column listed in
                ``self.required_factor_ids``.

        Returns:
            ``df`` left-joined with the computed factor on
            ``(ts_code, trade_date)``, so the original row order is kept.

        Raises:
            ValueError: if any required factor column is absent from ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        # _sectional_roll already returns only the join keys and the factor.
        factor_table = self._sectional_roll(df)
        return df.join(factor_table, on=["ts_code", "trade_date"], how="left")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FactorGraph:
    """Registry of factors with dependency-ordered execution.

    Factors are registered under their factor ID. ``compute`` resolves the
    dependencies of the requested targets and applies the factors in an
    order where every factor runs after the registered factors it requires.
    IDs that are not registered are treated as columns expected to be
    present in the input frame already.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Create an empty graph."""
        self._factors: Dict[str, "BaseFactor"] = {}  # factor_id -> factor

    def add_factor(self, factor: "BaseFactor") -> None:
        """Register ``factor`` under its factor ID.

        Raises:
            ValueError: if a factor with the same ID is already registered.
        """
        fid = factor.get_factor_id()
        if fid in self._factors:
            raise ValueError(f"Factor '{fid}' already registered.")
        self._factors[fid] = factor

    def _topological_sort(self, target_ids: List[str]) -> List[str]:
        """Return the registered factors needed for ``target_ids`` in run order.

        The result is deterministic: it depends only on the order of
        ``target_ids`` and of each factor's ``required_factor_ids``, never
        on set or hash iteration order.

        Raises:
            RuntimeError: if the required factors contain a dependency cycle.
        """
        # Phase 1: breadth-first walk over dependencies. A dict is used as
        # an insertion-ordered set so discovery order is reproducible.
        visited: Dict[str, None] = {}
        pending = deque(target_ids)
        while pending:
            fid = pending.popleft()
            if fid in visited:
                continue
            visited[fid] = None
            factor = self._factors.get(fid)
            if factor is None:
                continue  # not registered: treated as an input column
            pending.extend(
                dep for dep in factor.required_factor_ids if dep not in visited
            )

        to_compute = [fid for fid in visited if fid in self._factors]
        in_graph = set(to_compute)

        # Phase 2: Kahn's algorithm restricted to registered factors.
        indegree = {fid: 0 for fid in to_compute}
        dependents = defaultdict(list)
        for fid in to_compute:
            for dep in self._factors[fid].required_factor_ids:
                if dep in in_graph:
                    dependents[dep].append(fid)
                    indegree[fid] += 1

        ready = deque(fid for fid in to_compute if indegree[fid] == 0)
        order: List[str] = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for nxt in dependents.get(node, ()):
                indegree[nxt] -= 1
                if indegree[nxt] == 0:
                    ready.append(nxt)

        log = self._logger
        if log.isEnabledFor(logging.DEBUG):
            log.debug("=== Factor Dependency Graph ===")
            for fid in sorted(to_compute):
                # Only dependencies that will themselves be computed.
                compute_deps = [
                    d for d in self._factors[fid].required_factor_ids if d in in_graph
                ]
                log.debug("%s -> %s", fid, compute_deps)

        if len(order) != len(to_compute):
            # Anything with a remaining in-degree sits on, or behind, a cycle.
            unresolved = sorted(fid for fid in to_compute if indegree[fid] > 0)
            raise RuntimeError(
                f"Circular dependency! Unresolved factors: {unresolved}"
            )
        return order

    def compute(self, df: "pl.DataFrame", target_factor_ids: List[str]) -> "pl.DataFrame":
        """Compute the target factors and everything they depend on.

        Args:
            df: Input frame; it is cloned, not modified.
            target_factor_ids: IDs of the factors to produce.

        Returns:
            A frame with one added column per factor that was computed.
            Factors whose column already exists are skipped.

        Raises:
            RuntimeError: if the required factors contain a dependency cycle.

        A target that is neither registered nor already a column of ``df``
        cannot be produced; a warning is logged and it is left out of the
        result.
        """
        unknown = [
            fid
            for fid in target_factor_ids
            if fid not in self._factors and fid not in df.columns
        ]
        if unknown:
            self._logger.warning(
                "Target factors %s are neither registered nor present in the "
                "input frame; they will be missing from the result.",
                unknown,
            )

        exec_order = self._topological_sort(target_factor_ids)
        current_df = df.clone()
        for fid in exec_order:
            if fid in current_df.columns:
                continue  # already available, nothing to compute
            self._logger.debug("Computing factor %s", fid)
            current_df = self._factors[fid].apply(current_df)
        return current_df
|