190 lines
6.4 KiB
Python
190 lines
6.4 KiB
Python
"""
|
||
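Factor operator framework - Polars implementation (final streamlined version)

- Each factor generates its own ID
- parameters holds computation parameters only (no factor references)
- required_factor_ids is a list of factor-ID strings
- calc_factor obtains what it needs through self.parameters and self.required_factor_ids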
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Literal, Dict, Any
|
||
from collections import defaultdict, deque
|
||
import json
|
||
import polars as pl
|
||
|
||
|
||
def _normalize_params(params: Dict[str, Any]) -> str:
|
||
if not params:
|
||
return ""
|
||
return json.dumps(sorted(params.items()), separators=(",", ":"))
|
||
|
||
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
|
||
"""
|
||
生成简洁因子ID,如:
|
||
("sma", {"window": 5}) → "sma_5"
|
||
("return", {"days": 20}) → "return_20"
|
||
("rank", {"input": "sma_5"}) → "rank_sma_5"
|
||
|
||
要求: params 的值必须是简单类型(str/int/float/bool)
|
||
"""
|
||
if not params:
|
||
return name
|
||
|
||
# 提取所有参数值,按 key 排序保证一致性
|
||
parts = []
|
||
for k in sorted(params.keys()):
|
||
v = params[k]
|
||
if isinstance(v, (str, int, float, bool)):
|
||
# 布尔转小写字符串
|
||
if isinstance(v, bool):
|
||
v = str(v).lower()
|
||
parts.append(str(v))
|
||
else:
|
||
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
|
||
f"Only str/int/float/bool allowed for simple ID.")
|
||
|
||
return f"{name}_{'_'.join(parts)}"
|
||
|
||
|
||
class BaseFactor(ABC):
    """Abstract base class shared by all factor operators.

    A factor stores its name, its computation parameters and the IDs of the
    factors it depends on, and derives its own ``factor_id`` at construction
    time from the name and parameters.
    """

    def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
        """Store the factor definition and derive its ID.

        Args:
            name: Base name of the factor.
            parameters: Computation parameters used to build the factor ID.
            required_factor_ids: IDs of the factors this factor depends on.
        """
        self.required_factor_ids = required_factor_ids
        self.parameters = parameters
        self.name = name
        # The ID is derived last because it reads ``name`` and ``parameters``.
        self.factor_id = self._generate_factor_id()

    def _generate_factor_id(self) -> str:
        """Return the ID built from ``self.name`` and ``self.parameters``."""
        return _simple_factor_id(self.name, self.parameters)

    def get_factor_id(self) -> str:
        """Return the ID computed when the factor was constructed."""
        return self.factor_id

    @abstractmethod
    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute the factor values for a single group of rows.

        Args:
            group_df: The rows of one group.

        Returns:
            The factor values for ``group_df`` as a Series.
        """
        ...

    @property
    @abstractmethod
    def operator_type(self) -> Literal["stock", "date"]:
        """Return the grouping kind of this operator: "stock" or "date"."""
        ...
|
||
|
||
|
||
class StockWiseFactor(BaseFactor):
    """Factor evaluated independently for each stock (one group per ``ts_code``).

    The input frame must contain ``ts_code`` and ``trade_date`` columns; they
    are used for sorting, grouping and joining the result back.
    """

    @property
    def operator_type(self) -> Literal["stock"]:
        """Return the grouping kind of this operator, always "stock"."""
        return "stock"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Run ``calc_factor`` on every ``ts_code`` group.

        Args:
            df: Input frame with ``ts_code`` and ``trade_date`` columns.

        Returns:
            A frame with the columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``.
        """
        # Sort first so every group reaches calc_factor in trade_date order.
        df_sorted = df.sort(["ts_code", "trade_date"])
        return (
            df_sorted
            .group_by("ts_code", maintain_order=True)
            # Force the result column to be named after the factor ID. Without
            # the alias, a Series returned under any other name would either
            # overwrite an input column or make the select below fail.
            .map_groups(
                lambda g: g.with_columns(self.calc_factor(g).alias(self.factor_id))
            )
            .select(["ts_code", "trade_date", self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute the factor and attach it to ``df`` as a new column.

        Args:
            df: Input frame; must already contain every column listed in
                ``self.required_factor_ids``.

        Returns:
            ``df`` left-joined with the factor column on
            ``ts_code`` / ``trade_date``.

        Raises:
            ValueError: If any required factor column is missing from ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        # _sectional_roll already returns exactly the key columns plus the
        # factor column, so it can be joined directly.
        long_table = self._sectional_roll(df)
        return df.join(
            long_table,
            on=["ts_code", "trade_date"],
            how="left",
        )
|
||
|
||
|
||
class DateWiseFactor(BaseFactor):
    """Factor evaluated independently for each date (one group per ``trade_date``).

    The input frame must contain ``ts_code`` and ``trade_date`` columns; they
    are used for sorting, grouping and joining the result back.
    """

    @property
    def operator_type(self) -> Literal["date"]:
        """Return the grouping kind of this operator, always "date"."""
        return "date"

    def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
        """Run ``calc_factor`` on every ``trade_date`` group.

        Args:
            df: Input frame with ``ts_code`` and ``trade_date`` columns.

        Returns:
            A frame with the columns ``ts_code``, ``trade_date`` and
            ``self.factor_id``.
        """
        # Sort first so every group reaches calc_factor in ts_code order.
        df_sorted = df.sort(["trade_date", "ts_code"])
        return (
            df_sorted
            .group_by("trade_date", maintain_order=True)
            # Force the result column to be named after the factor ID. Without
            # the alias, a Series returned under any other name would either
            # overwrite an input column or make the select below fail.
            .map_groups(
                lambda g: g.with_columns(self.calc_factor(g).alias(self.factor_id))
            )
            .select(["ts_code", "trade_date", self.factor_id])
        )

    def apply(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute the factor and attach it to ``df`` as a new column.

        Args:
            df: Input frame; must already contain every column listed in
                ``self.required_factor_ids``.

        Returns:
            ``df`` left-joined with the factor column on
            ``ts_code`` / ``trade_date``.

        Raises:
            ValueError: If any required factor column is missing from ``df``.
        """
        missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
        if missing:
            raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
        # _sectional_roll already returns exactly the key columns plus the
        # factor column, so it can be joined directly.
        long_table = self._sectional_roll(df)
        return df.join(
            long_table,
            on=["ts_code", "trade_date"],
            how="left",
        )
|
||
|
||
|
||
class FactorGraph:
    """Registry of factors that resolves dependencies and computes them in order."""

    def __init__(self):
        """Create an empty registry."""
        self._factors = {}  # factor_id -> factor

    def add_factor(self, factor: BaseFactor):
        """Register ``factor`` under its own factor ID.

        Raises:
            ValueError: If a factor with the same ID is already registered.
        """
        fid = factor.get_factor_id()
        if fid in self._factors:
            raise ValueError(f"Factor '{fid}' already registered.")
        self._factors[fid] = factor

    def _topological_sort(self, target_ids: List[str]) -> List[str]:
        """Return the registered factors needed for ``target_ids`` in execution order.

        IDs that are not registered are treated as already-available inputs
        and never appear in the result. The order is deterministic: it
        depends only on ``target_ids`` and on each factor's
        ``required_factor_ids``, not on set iteration order. The dependency
        graph is printed to stdout as a diagnostic.

        Args:
            target_ids: IDs of the factors the caller wants computed.

        Returns:
            Registered factor IDs ordered so that every factor comes after
            all of its registered dependencies.

        Raises:
            RuntimeError: If the registered factors contain a dependency cycle.
        """
        # Breadth-first walk from the targets through their dependencies.
        # `discovered` records first-visit order so that every later step
        # iterates in a reproducible order (a bare set of strings iterates in
        # hash order, which can change between interpreter runs).
        seen = set()
        discovered = []
        queue = deque(target_ids)
        while queue:
            fid = queue.popleft()
            if fid in seen:
                continue
            seen.add(fid)
            discovered.append(fid)
            if fid in self._factors:
                for dep in self._factors[fid].required_factor_ids:
                    if dep not in seen:
                        queue.append(dep)

        # Only registered factors have to be computed.
        to_compute = [fid for fid in discovered if fid in self._factors]
        indegree = {fid: 0 for fid in to_compute}
        adj = defaultdict(list)

        # Edge dep -> fid means "dep must be computed before fid".
        for fid in to_compute:
            for dep in self._factors[fid].required_factor_ids:
                if dep in indegree:
                    adj[dep].append(fid)
                    indegree[fid] += 1

        # Kahn's algorithm, seeded in discovery order.
        queue = deque(fid for fid in to_compute if indegree[fid] == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for nb in adj[node]:
                indegree[nb] -= 1
                if indegree[nb] == 0:
                    queue.append(nb)

        print("\n=== Factor Dependency Graph ===")
        for fid in sorted(to_compute):
            deps = self._factors[fid].required_factor_ids
            # Show only the dependencies that will themselves be computed.
            compute_deps = [d for d in deps if d in indegree]
            print(f"{fid} -> {compute_deps}")
        print("================================\n")

        # Nodes on a cycle never reach indegree 0, so they are missing from order.
        if len(order) != len(to_compute):
            print(len(order), len(to_compute))
            raise RuntimeError("Circular dependency!")
        return order

    def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame:
        """Compute the target factors and all their registered dependencies.

        Factors whose ID is already a column of the working frame are
        skipped. Target IDs that are not registered are not computed here.
        Each factor ID is printed to stdout as it is reached.

        Args:
            df: Input frame; it is cloned, not modified.
            target_factor_ids: IDs of the factors to compute.

        Returns:
            A frame containing the columns of ``df`` plus one column per
            computed factor.
        """
        exec_order = self._topological_sort(target_factor_ids)
        current_df = df.clone()
        for fid in exec_order:
            print(fid)
            if fid in current_df.columns:
                continue
            factor = self._factors[fid]
            current_df = factor.apply(current_df)
        return current_df