1、策略更新

2、新增qmt
This commit is contained in:
2025-11-29 00:23:12 +08:00
parent 0a942f92d1
commit c9b61db5b7
47 changed files with 97116 additions and 8867 deletions

View File

@@ -1,124 +1,190 @@
"""
因子算子框架 - Polars 实现
支持:截面滚动 → 拼回长表 → 按列名合并
返回形式可选:完整 DataFrame默认或单列 Series
因子算子框架 - Polars 实现(最终精简版)
- 因子自行生成 ID
- parameters 仅含计算参数(不含因子引用)
- required_factor_ids 是因子ID字符串列表
- calc_factor 通过 self.parameters 和 self.required_factor_ids 获取所需信息
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Literal
from typing import List, Literal, Dict, Any
from collections import defaultdict, deque
import json
import polars as pl
@dataclass
class OperatorConfig:
"""算子配置"""
name: str
description: str
required_columns: List[str]
output_columns: List[str]
parameters: dict
def _normalize_params(params: Dict[str, Any]) -> str:
if not params:
return ""
return json.dumps(sorted(params.items()), separators=(",", ":"))
def _simple_factor_id(name: str, params: Dict[str, Any]) -> str:
"""
生成简洁因子ID如:
("sma", {"window": 5}) → "sma_5"
("return", {"days": 20}) → "return_20"
("rank", {"input": "sma_5"}) → "rank_sma_5"
要求: params 的值必须是简单类型str/int/float/bool
"""
if not params:
return name
# 提取所有参数值,按 key 排序保证一致性
parts = []
for k in sorted(params.keys()):
v = params[k]
if isinstance(v, (str, int, float, bool)):
# 布尔转小写字符串
if isinstance(v, bool):
v = str(v).lower()
parts.append(str(v))
else:
raise ValueError(f"Unsupported parameter type for '{k}': {type(v)}. "
f"Only str/int/float/bool allowed for simple ID.")
return f"{name}_{'_'.join(parts)}"
class BaseOperator(ABC):
"""算子基类"""
class BaseFactor(ABC):
def __init__(self, name: str, parameters: Dict[str, Any], required_factor_ids: List[str]):
self.name = name
self.parameters = parameters
self.required_factor_ids = required_factor_ids
self.factor_id = self._generate_factor_id()
def __init__(self, config: OperatorConfig):
self.config = config
self.name = config.name
self.required_columns = config.required_columns
self.output_columns = config.output_columns
def _generate_factor_id(self) -> str:
return _simple_factor_id(self.name, self.parameters)
# ---------- 子类必须实现 ----------
@abstractmethod
def get_factor_name(self) -> str:
"""返回因子列名(用于合并)"""
pass
def get_factor_id(self) -> str:
return self.factor_id
@abstractmethod
def calc_factor(self, group_df: pl.DataFrame, **kwargs) -> pl.Series:
"""
真正的截面计算逻辑。
参数:按 ts_code 或 trade_date 分组后的子表
返回:与 group_df 行数一一对应的因子 Series含正确索引
"""
def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
pass
# ---------- 公共接口 ----------
def apply(self,
df: pl.DataFrame,
return_type: Literal['df', 'series'] = 'df',
**kwargs) -> pl.DataFrame | pl.Series:
"""入口:截面滚动 → 拼回长表 → 合并/返回"""
if not self.validate_input(df):
raise ValueError(f"缺少必需列:{self.required_columns}")
long_table = self._sectional_roll(df, **kwargs) # ① 滚动
merged = self._merge_factor(df, long_table) # ② 合并
return merged if return_type == 'df' else merged[self.get_factor_name()]
# ---------- 内部流程 ----------
def validate_input(self, df: pl.DataFrame) -> bool:
return all(col in df.columns for col in self.required_columns)
@property
@abstractmethod
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
"""
截面滚动模板group → calc_factor → 拼回长表
返回含【trade_date, ts_code, factor】的长表
"""
def operator_type(self) -> Literal["stock", "date"]:
pass
def _merge_factor(self, original: pl.DataFrame, factor_table: pl.DataFrame) -> pl.DataFrame:
"""按 [ts_code, trade_date] 左联,原地追加因子列"""
factor_name = self.get_factor_name()
return original.join(factor_table.select(['ts_code', 'trade_date', factor_name]),
on=['ts_code', 'trade_date'],
how='left')
class StockWiseFactor(BaseFactor):
@property
def operator_type(self) -> Literal["stock"]:
return "stock"
# -------------------- 股票截面:按 ts_code 分组 --------------------
class StockWiseOperator(BaseOperator):
"""股票切面算子抽象类:按 ts_code 分组,对每个股票的时间序列计算因子"""
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
factor_name = self.get_factor_name()
# 确保排序(时间顺序对 shift 等操作至关重要)
df_sorted = df.sort(['ts_code', 'trade_date'])
# 使用 map_groups对每个 ts_code 分组,传入完整子 DataFrame
def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
df_sorted = df.sort(["ts_code", "trade_date"])
result = (
df_sorted
.group_by('ts_code', maintain_order=True)
.map_groups(
lambda group_df: group_df.with_columns(
self.calc_factor(group_df, **kwargs)
)
)
.select(['ts_code', 'trade_date', factor_name])
.group_by("ts_code", maintain_order=True)
.map_groups(lambda g: g.with_columns(self.calc_factor(g)))
.select(["ts_code", "trade_date", self.factor_id])
)
return result
# -------------------- 日期截面:按 trade_date 分组 --------------------
class DateWiseOperator(BaseOperator):
"""日期切面算子抽象类:按 trade_date 分组,对每个截面计算因子"""
def apply(self, df: pl.DataFrame) -> pl.DataFrame:
missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
if missing:
raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
long_table = self._sectional_roll(df)
return df.join(
long_table.select(["ts_code", "trade_date", self.factor_id]),
on=["ts_code", "trade_date"],
how="left"
)
def _sectional_roll(self, df: pl.DataFrame, **kwargs) -> pl.DataFrame:
factor_name = self.get_factor_name()
df_sorted = df.sort(['trade_date', 'ts_code'])
class DateWiseFactor(BaseFactor):
@property
def operator_type(self) -> Literal["date"]:
return "date"
def _sectional_roll(self, df: pl.DataFrame) -> pl.DataFrame:
df_sorted = df.sort(["trade_date", "ts_code"])
result = (
df_sorted
.group_by('trade_date', maintain_order=True)
.map_groups(
lambda group_df: group_df.with_columns(
self.calc_factor(group_df, **kwargs)
)
)
.select(['ts_code', 'trade_date', factor_name])
.group_by("trade_date", maintain_order=True)
.map_groups(lambda g: g.with_columns(self.calc_factor(g)))
.select(["ts_code", "trade_date", self.factor_id])
)
return result
def apply(self, df: pl.DataFrame) -> pl.DataFrame:
missing = [fid for fid in self.required_factor_ids if fid not in df.columns]
if missing:
raise ValueError(f"Missing dependencies for {self.factor_id}: {missing}")
long_table = self._sectional_roll(df)
return df.join(
long_table.select(["ts_code", "trade_date", self.factor_id]),
on=["ts_code", "trade_date"],
how="left"
)
class FactorGraph:
def __init__(self):
self._factors = {} # factor_id -> factor
def add_factor(self, factor: BaseFactor):
fid = factor.get_factor_id()
if fid in self._factors:
raise ValueError(f"Factor '{fid}' already registered.")
self._factors[fid] = factor
def _topological_sort(self, target_ids: List[str]) -> List[str]:
all_factors = set()
queue = deque(target_ids)
while queue:
f = queue.popleft()
if f not in all_factors:
all_factors.add(f)
if f in self._factors:
for dep in self._factors[f].required_factor_ids:
if dep not in all_factors:
queue.append(dep)
to_compute = {f for f in all_factors if f in self._factors}
indegree = {f: 0 for f in to_compute}
adj = defaultdict(list)
for f in to_compute:
for dep in self._factors[f].required_factor_ids:
if dep in to_compute:
adj[dep].append(f)
indegree[f] += 1
queue = deque([f for f in to_compute if indegree[f] == 0])
order = []
while queue:
node = queue.popleft()
order.append(node)
for nb in adj[node]:
indegree[nb] -= 1
if indegree[nb] == 0:
queue.append(nb)
print("\n=== Factor Dependency Graph ===")
to_compute = {f for f in all_factors if f in self._factors}
for fid in sorted(to_compute):
deps = self._factors[fid].required_factor_ids
compute_deps = [d for d in deps if d in to_compute] # 只显示可计算的依赖
print(f"{fid} -> {compute_deps}")
print("================================\n")
if len(order) != len(to_compute):
print(len(order), len(to_compute))
raise RuntimeError("Circular dependency!")
return order
def compute(self, df: pl.DataFrame, target_factor_ids: List[str]) -> pl.DataFrame:
exec_order = self._topological_sort(target_factor_ids)
current_df = df.clone()
for fid in exec_order:
print(fid)
if fid in current_df.columns:
continue
factor = self._factors[fid]
current_df = factor.apply(current_df)
return current_df