Compare commits
3 Commits
9b826c1845
...
05d0c90312
| Author | SHA1 | Date | |
|---|---|---|---|
| 05d0c90312 | |||
| 77e4e94e05 | |||
| 1c0c4a0de1 |
@@ -52,6 +52,22 @@ from src.factors.engine import (
|
|||||||
ComputeEngine,
|
ComputeEngine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from src.factors.parser import FormulaParser
|
||||||
|
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
|
from src.factors.exceptions import (
|
||||||
|
FormulaParseError,
|
||||||
|
UnknownFunctionError,
|
||||||
|
InvalidSyntaxError,
|
||||||
|
EmptyExpressionError,
|
||||||
|
RegistryError,
|
||||||
|
DuplicateFunctionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问
|
||||||
|
# 例如:from src.factors.engine import FactorEngine
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# DSL 层
|
# DSL 层
|
||||||
"Node",
|
"Node",
|
||||||
@@ -73,4 +89,15 @@ __all__ = [
|
|||||||
"DataRouter",
|
"DataRouter",
|
||||||
"ExecutionPlanner",
|
"ExecutionPlanner",
|
||||||
"ComputeEngine",
|
"ComputeEngine",
|
||||||
|
# 解析器 (Phase 1 新增)
|
||||||
|
"FormulaParser",
|
||||||
|
# 注册表 (Phase 1 新增)
|
||||||
|
"FunctionRegistry",
|
||||||
|
# 异常类 (Phase 1 新增)
|
||||||
|
"FormulaParseError",
|
||||||
|
"UnknownFunctionError",
|
||||||
|
"InvalidSyntaxError",
|
||||||
|
"EmptyExpressionError",
|
||||||
|
"RegistryError",
|
||||||
|
"DuplicateFunctionError",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,817 +0,0 @@
|
|||||||
"""FactorEngine - 因子计算引擎统一入口。
|
|
||||||
|
|
||||||
提供从表达式注册到结果输出的完整执行链路:
|
|
||||||
接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表
|
|
||||||
-> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
from src.factors.dsl import (
|
|
||||||
Node,
|
|
||||||
Symbol,
|
|
||||||
FunctionNode,
|
|
||||||
BinaryOpNode,
|
|
||||||
UnaryOpNode,
|
|
||||||
Constant,
|
|
||||||
)
|
|
||||||
from src.factors.compiler import DependencyExtractor
|
|
||||||
from src.factors.translator import PolarsTranslator
|
|
||||||
from src.data.storage import Storage
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DataSpec:
|
|
||||||
"""数据规格定义。
|
|
||||||
|
|
||||||
描述因子计算所需的数据表和字段。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
table: 数据表名称
|
|
||||||
columns: 需要的字段列表
|
|
||||||
lookback_days: 回看天数(用于时序计算)
|
|
||||||
"""
|
|
||||||
|
|
||||||
table: str
|
|
||||||
columns: List[str]
|
|
||||||
lookback_days: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExecutionPlan:
|
|
||||||
"""执行计划。
|
|
||||||
|
|
||||||
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
data_specs: 数据规格列表
|
|
||||||
polars_expr: Polars 表达式
|
|
||||||
dependencies: 依赖的原始字段
|
|
||||||
output_name: 输出因子名称
|
|
||||||
"""
|
|
||||||
|
|
||||||
data_specs: List[DataSpec]
|
|
||||||
polars_expr: pl.Expr
|
|
||||||
dependencies: Set[str]
|
|
||||||
output_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class DataRouter:
|
|
||||||
"""数据路由器 - 按需取数、组装核心宽表。
|
|
||||||
|
|
||||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
|
||||||
支持内存数据源(用于测试)和真实数据库连接。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
|
|
||||||
is_memory_mode: 是否为内存模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
|
|
||||||
"""初始化数据路由器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
|
||||||
为 None 时自动连接 DuckDB 数据库
|
|
||||||
"""
|
|
||||||
self.data_source = data_source or {}
|
|
||||||
self.is_memory_mode = data_source is not None
|
|
||||||
self._cache: Dict[str, pl.DataFrame] = {}
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
# 数据库模式下初始化 Storage
|
|
||||||
if not self.is_memory_mode:
|
|
||||||
self._storage = Storage()
|
|
||||||
else:
|
|
||||||
self._storage = None
|
|
||||||
|
|
||||||
def fetch_data(
|
|
||||||
self,
|
|
||||||
data_specs: List[DataSpec],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
stock_codes: Optional[List[str]] = None,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""根据数据规格获取并组装核心宽表。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_specs: 数据规格列表
|
|
||||||
start_date: 开始日期 (YYYYMMDD)
|
|
||||||
end_date: 结束日期 (YYYYMMDD)
|
|
||||||
stock_codes: 股票代码列表,None 表示全市场
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
组装好的核心宽表 DataFrame
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 当数据源中缺少必要的表或字段时
|
|
||||||
"""
|
|
||||||
if not data_specs:
|
|
||||||
raise ValueError("数据规格不能为空")
|
|
||||||
|
|
||||||
# 收集所有需要的表和字段
|
|
||||||
required_tables: Dict[str, Set[str]] = {}
|
|
||||||
max_lookback = 0
|
|
||||||
|
|
||||||
for spec in data_specs:
|
|
||||||
if spec.table not in required_tables:
|
|
||||||
required_tables[spec.table] = set()
|
|
||||||
required_tables[spec.table].update(spec.columns)
|
|
||||||
max_lookback = max(max_lookback, spec.lookback_days)
|
|
||||||
|
|
||||||
# 调整日期范围以包含回看期
|
|
||||||
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
|
||||||
|
|
||||||
# 从数据源获取各表数据
|
|
||||||
table_data = {}
|
|
||||||
for table_name, columns in required_tables.items():
|
|
||||||
df = self._load_table(
|
|
||||||
table_name=table_name,
|
|
||||||
columns=list(columns),
|
|
||||||
start_date=adjusted_start,
|
|
||||||
end_date=end_date,
|
|
||||||
stock_codes=stock_codes,
|
|
||||||
)
|
|
||||||
table_data[table_name] = df
|
|
||||||
|
|
||||||
# 组装核心宽表
|
|
||||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
|
||||||
|
|
||||||
# 过滤到实际请求日期范围
|
|
||||||
core_table = core_table.filter(
|
|
||||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
|
||||||
)
|
|
||||||
|
|
||||||
return core_table
|
|
||||||
|
|
||||||
def _load_table(
|
|
||||||
self,
|
|
||||||
table_name: str,
|
|
||||||
columns: List[str],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
stock_codes: Optional[List[str]] = None,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""加载单个表的数据。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table_name: 表名
|
|
||||||
columns: 需要的字段
|
|
||||||
start_date: 开始日期
|
|
||||||
end_date: 结束日期
|
|
||||||
stock_codes: 股票代码过滤
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
过滤后的 DataFrame
|
|
||||||
"""
|
|
||||||
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if cache_key in self._cache:
|
|
||||||
return self._cache[cache_key]
|
|
||||||
|
|
||||||
if self.is_memory_mode:
|
|
||||||
df = self._load_from_memory(
|
|
||||||
table_name, columns, start_date, end_date, stock_codes
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
df = self._load_from_database(
|
|
||||||
table_name, columns, start_date, end_date, stock_codes
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
self._cache[cache_key] = df
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
def _load_from_memory(
|
|
||||||
self,
|
|
||||||
table_name: str,
|
|
||||||
columns: List[str],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
stock_codes: Optional[List[str]] = None,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""从内存数据源加载数据。"""
|
|
||||||
if table_name not in self.data_source:
|
|
||||||
raise ValueError(f"内存数据源中缺少表: {table_name}")
|
|
||||||
|
|
||||||
df = self.data_source[table_name]
|
|
||||||
|
|
||||||
# 确保必需字段存在
|
|
||||||
for col in columns:
|
|
||||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
|
||||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
|
||||||
|
|
||||||
# 过滤日期和股票
|
|
||||||
df = df.filter(
|
|
||||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
|
||||||
)
|
|
||||||
|
|
||||||
if stock_codes is not None:
|
|
||||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
|
||||||
|
|
||||||
# 选择需要的列
|
|
||||||
select_cols = ["ts_code", "trade_date"] + [
|
|
||||||
c for c in columns if c in df.columns
|
|
||||||
]
|
|
||||||
return df.select(select_cols)
|
|
||||||
|
|
||||||
def _load_from_database(
|
|
||||||
self,
|
|
||||||
table_name: str,
|
|
||||||
columns: List[str],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
stock_codes: Optional[List[str]] = None,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""从 DuckDB 数据库加载数据。
|
|
||||||
|
|
||||||
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
|
|
||||||
"""
|
|
||||||
if self._storage is None:
|
|
||||||
raise RuntimeError("Storage 未初始化")
|
|
||||||
|
|
||||||
# 检查表是否存在
|
|
||||||
if not self._storage.exists(table_name):
|
|
||||||
raise ValueError(f"数据库中不存在表: {table_name}")
|
|
||||||
|
|
||||||
# 构建查询参数
|
|
||||||
# Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况
|
|
||||||
if stock_codes is not None and len(stock_codes) == 1:
|
|
||||||
ts_code_filter = stock_codes[0]
|
|
||||||
else:
|
|
||||||
ts_code_filter = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 从数据库加载原始数据
|
|
||||||
df = self._storage.load_polars(
|
|
||||||
name=table_name,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
ts_code=ts_code_filter,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
|
|
||||||
|
|
||||||
# 如果 stock_codes 是列表且长度 > 1,在内存中过滤
|
|
||||||
if stock_codes is not None and len(stock_codes) > 1:
|
|
||||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
|
||||||
|
|
||||||
# 检查必需字段
|
|
||||||
for col in columns:
|
|
||||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
|
||||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
|
||||||
|
|
||||||
# 选择需要的列
|
|
||||||
select_cols = ["ts_code", "trade_date"] + [
|
|
||||||
c for c in columns if c in df.columns
|
|
||||||
]
|
|
||||||
|
|
||||||
return df.select(select_cols)
|
|
||||||
|
|
||||||
def _assemble_wide_table(
|
|
||||||
self,
|
|
||||||
table_data: Dict[str, pl.DataFrame],
|
|
||||||
required_tables: Dict[str, Set[str]],
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""组装多表数据为核心宽表。
|
|
||||||
|
|
||||||
使用 left join 合并各表数据,以第一个表为基准。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table_data: 表名到 DataFrame 的映射
|
|
||||||
required_tables: 表名到字段集合的映射
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
组装后的宽表
|
|
||||||
"""
|
|
||||||
if not table_data:
|
|
||||||
raise ValueError("没有数据可组装")
|
|
||||||
|
|
||||||
# 以第一个表为基准
|
|
||||||
base_table_name = list(table_data.keys())[0]
|
|
||||||
result = table_data[base_table_name]
|
|
||||||
|
|
||||||
# 与其他表 join
|
|
||||||
for table_name, df in table_data.items():
|
|
||||||
if table_name == base_table_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用 ts_code 和 trade_date 作为 join 键
|
|
||||||
result = result.join(
|
|
||||||
df,
|
|
||||||
on=["ts_code", "trade_date"],
|
|
||||||
how="left",
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
|
||||||
"""根据回看天数调整开始日期。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_date: 原始开始日期 (YYYYMMDD)
|
|
||||||
lookback_days: 需要回看的交易日数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
调整后的开始日期
|
|
||||||
"""
|
|
||||||
# 简化的日期调整:假设每月30天,向前推移
|
|
||||||
# 实际应用中应该使用交易日历
|
|
||||||
year = int(start_date[:4])
|
|
||||||
month = int(start_date[4:6])
|
|
||||||
day = int(start_date[6:8])
|
|
||||||
|
|
||||||
total_days = lookback_days + 30 # 额外缓冲
|
|
||||||
|
|
||||||
day -= total_days
|
|
||||||
while day <= 0:
|
|
||||||
month -= 1
|
|
||||||
if month <= 0:
|
|
||||||
month = 12
|
|
||||||
year -= 1
|
|
||||||
day += 30
|
|
||||||
|
|
||||||
return f"{year:04d}{month:02d}{day:02d}"
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
|
||||||
"""清除数据缓存。"""
|
|
||||||
with self._lock:
|
|
||||||
self._cache.clear()
|
|
||||||
|
|
||||||
# 数据库模式下清理 Storage 连接(可选)
|
|
||||||
if not self.is_memory_mode and self._storage is not None:
|
|
||||||
# Storage 使用单例模式,不需要关闭连接
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanner:
|
|
||||||
"""执行计划生成器。
|
|
||||||
|
|
||||||
整合编译器和翻译器,生成完整的执行计划。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
compiler: 依赖提取器
|
|
||||||
translator: Polars 翻译器
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""初始化执行计划生成器。"""
|
|
||||||
self.compiler = DependencyExtractor()
|
|
||||||
self.translator = PolarsTranslator()
|
|
||||||
|
|
||||||
def create_plan(
|
|
||||||
self,
|
|
||||||
expression: Node,
|
|
||||||
output_name: str = "factor",
|
|
||||||
data_specs: Optional[List[DataSpec]] = None,
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""从表达式创建执行计划。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
expression: DSL 表达式节点
|
|
||||||
output_name: 输出因子名称
|
|
||||||
data_specs: 预定义的数据规格,None 时自动推导
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
执行计划对象
|
|
||||||
"""
|
|
||||||
# 1. 提取依赖
|
|
||||||
dependencies = self.compiler.extract_dependencies(expression)
|
|
||||||
|
|
||||||
# 2. 翻译为 Polars 表达式
|
|
||||||
polars_expr = self.translator.translate(expression)
|
|
||||||
|
|
||||||
# 3. 推导或验证数据规格
|
|
||||||
if data_specs is None:
|
|
||||||
data_specs = self._infer_data_specs(dependencies, expression)
|
|
||||||
|
|
||||||
return ExecutionPlan(
|
|
||||||
data_specs=data_specs,
|
|
||||||
polars_expr=polars_expr,
|
|
||||||
dependencies=dependencies,
|
|
||||||
output_name=output_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _infer_data_specs(
|
|
||||||
self,
|
|
||||||
dependencies: Set[str],
|
|
||||||
expression: Node,
|
|
||||||
) -> List[DataSpec]:
|
|
||||||
"""从依赖推导数据规格。
|
|
||||||
|
|
||||||
根据表达式中的函数类型推断回看天数需求。
|
|
||||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
|
||||||
默认从 pro_bar 表获取。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dependencies: 依赖的字段集合
|
|
||||||
expression: 表达式节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
数据规格列表
|
|
||||||
"""
|
|
||||||
# 计算最大回看窗口
|
|
||||||
max_window = self._extract_max_window(expression)
|
|
||||||
lookback_days = max(1, max_window)
|
|
||||||
|
|
||||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
|
||||||
pro_bar_fields = {
|
|
||||||
"open",
|
|
||||||
"high",
|
|
||||||
"low",
|
|
||||||
"close",
|
|
||||||
"vol",
|
|
||||||
"amount",
|
|
||||||
"pre_close",
|
|
||||||
"change",
|
|
||||||
"pct_chg",
|
|
||||||
"turnover_rate",
|
|
||||||
"volume_ratio",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 将依赖分为 pro_bar 字段和其他字段
|
|
||||||
pro_bar_deps = dependencies & pro_bar_fields
|
|
||||||
other_deps = dependencies - pro_bar_fields
|
|
||||||
|
|
||||||
data_specs = []
|
|
||||||
|
|
||||||
# pro_bar 表的数据规格
|
|
||||||
if pro_bar_deps:
|
|
||||||
data_specs.append(
|
|
||||||
DataSpec(
|
|
||||||
table="pro_bar",
|
|
||||||
columns=sorted(pro_bar_deps),
|
|
||||||
lookback_days=lookback_days,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 其他字段从 daily 表获取
|
|
||||||
if other_deps:
|
|
||||||
data_specs.append(
|
|
||||||
DataSpec(
|
|
||||||
table="daily",
|
|
||||||
columns=sorted(other_deps),
|
|
||||||
lookback_days=lookback_days,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return data_specs
|
|
||||||
|
|
||||||
def _extract_max_window(self, node: Node) -> int:
|
|
||||||
"""从表达式中提取最大窗口大小。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: AST 节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最大窗口大小,无时序函数返回 1
|
|
||||||
"""
|
|
||||||
if isinstance(node, FunctionNode):
|
|
||||||
window = 1
|
|
||||||
# 检查函数参数中的窗口大小
|
|
||||||
for arg in node.args:
|
|
||||||
if (
|
|
||||||
isinstance(arg, Constant)
|
|
||||||
and isinstance(arg.value, int)
|
|
||||||
and arg.value > window
|
|
||||||
):
|
|
||||||
window = arg.value
|
|
||||||
|
|
||||||
# 递归检查子表达式
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
|
||||||
window = max(window, self._extract_max_window(arg))
|
|
||||||
|
|
||||||
return window
|
|
||||||
|
|
||||||
elif isinstance(node, BinaryOpNode):
|
|
||||||
return max(
|
|
||||||
self._extract_max_window(node.left),
|
|
||||||
self._extract_max_window(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(node, UnaryOpNode):
|
|
||||||
return self._extract_max_window(node.operand)
|
|
||||||
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
class ComputeEngine:
|
|
||||||
"""计算引擎 - 执行并行运算。
|
|
||||||
|
|
||||||
负责将执行计划应用到数据上,支持并行计算。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
max_workers: 最大并行工作线程数
|
|
||||||
use_processes: 是否使用进程池(CPU 密集型任务)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_workers: int = 4,
|
|
||||||
use_processes: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""初始化计算引擎。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_workers: 最大并行工作线程数
|
|
||||||
use_processes: 是否使用进程池代替线程池
|
|
||||||
"""
|
|
||||||
self.max_workers = max_workers
|
|
||||||
self.use_processes = use_processes
|
|
||||||
|
|
||||||
def execute(
|
|
||||||
self,
|
|
||||||
plan: ExecutionPlan,
|
|
||||||
data: pl.DataFrame,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""执行计算计划。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plan: 执行计划
|
|
||||||
data: 输入数据(核心宽表)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含因子结果的 DataFrame
|
|
||||||
"""
|
|
||||||
# 检查依赖字段是否存在
|
|
||||||
missing_cols = plan.dependencies - set(data.columns)
|
|
||||||
if missing_cols:
|
|
||||||
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
|
|
||||||
|
|
||||||
# 执行计算
|
|
||||||
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def execute_batch(
|
|
||||||
self,
|
|
||||||
plans: List[ExecutionPlan],
|
|
||||||
data: pl.DataFrame,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""批量执行多个计算计划。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plans: 执行计划列表
|
|
||||||
data: 输入数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含所有因子结果的 DataFrame
|
|
||||||
"""
|
|
||||||
result = data
|
|
||||||
|
|
||||||
for plan in plans:
|
|
||||||
result = self.execute(plan, result)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def execute_parallel(
|
|
||||||
self,
|
|
||||||
plans: List[ExecutionPlan],
|
|
||||||
data: pl.DataFrame,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""并行执行多个计算计划。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plans: 执行计划列表
|
|
||||||
data: 输入数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含所有因子结果的 DataFrame
|
|
||||||
"""
|
|
||||||
# 检查计划间依赖
|
|
||||||
independent_plans = []
|
|
||||||
dependent_plans = []
|
|
||||||
available_cols = set(data.columns)
|
|
||||||
|
|
||||||
for plan in plans:
|
|
||||||
if plan.dependencies <= available_cols:
|
|
||||||
independent_plans.append(plan)
|
|
||||||
available_cols.add(plan.output_name)
|
|
||||||
else:
|
|
||||||
dependent_plans.append(plan)
|
|
||||||
|
|
||||||
# 并行执行独立计划
|
|
||||||
if independent_plans:
|
|
||||||
ExecutorClass = (
|
|
||||||
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
|
||||||
)
|
|
||||||
|
|
||||||
with ExecutorClass(max_workers=self.max_workers) as executor:
|
|
||||||
futures = {
|
|
||||||
executor.submit(self._execute_single, plan, data): plan
|
|
||||||
for plan in independent_plans
|
|
||||||
}
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for future in futures:
|
|
||||||
plan = futures[future]
|
|
||||||
try:
|
|
||||||
result_col = future.result()
|
|
||||||
results.append((plan.output_name, result_col))
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
|
|
||||||
|
|
||||||
# 合并结果
|
|
||||||
for name, series in results:
|
|
||||||
data = data.with_columns([series.alias(name)])
|
|
||||||
|
|
||||||
# 顺序执行依赖计划
|
|
||||||
for plan in dependent_plans:
|
|
||||||
data = self.execute(plan, data)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _execute_single(
|
|
||||||
self,
|
|
||||||
plan: ExecutionPlan,
|
|
||||||
data: pl.DataFrame,
|
|
||||||
) -> pl.Series:
|
|
||||||
"""执行单个计划并返回结果列。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plan: 执行计划
|
|
||||||
data: 输入数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
计算结果序列
|
|
||||||
"""
|
|
||||||
result = self.execute(plan, data)
|
|
||||||
return result[plan.output_name]
|
|
||||||
|
|
||||||
|
|
||||||
class FactorEngine:
|
|
||||||
"""因子计算引擎 - 系统统一入口。
|
|
||||||
|
|
||||||
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
|
||||||
|
|
||||||
执行流程:
|
|
||||||
1. 注册表达式 -> 调用编译器解析依赖
|
|
||||||
2. 调用路由器连接数据库拉取并组装核心宽表
|
|
||||||
3. 调用翻译器生成物理执行计划
|
|
||||||
4. 将计划提交给计算引擎执行并行运算
|
|
||||||
5. 返回包含因子结果的数据表
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
router: 数据路由器
|
|
||||||
planner: 执行计划生成器
|
|
||||||
compute_engine: 计算引擎
|
|
||||||
registered_expressions: 注册的表达式字典
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
|
||||||
max_workers: int = 4,
|
|
||||||
) -> None:
|
|
||||||
"""初始化因子引擎。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_source: 内存数据源,为 None 时使用数据库连接
|
|
||||||
max_workers: 并行计算的最大工作线程数
|
|
||||||
"""
|
|
||||||
self.router = DataRouter(data_source)
|
|
||||||
self.planner = ExecutionPlanner()
|
|
||||||
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
|
||||||
self.registered_expressions: Dict[str, Node] = {}
|
|
||||||
self._plans: Dict[str, ExecutionPlan] = {}
|
|
||||||
|
|
||||||
def register(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
expression: Node,
|
|
||||||
data_specs: Optional[List[DataSpec]] = None,
|
|
||||||
) -> FactorEngine:
|
|
||||||
"""注册因子表达式。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: 因子名称
|
|
||||||
expression: DSL 表达式
|
|
||||||
data_specs: 数据规格,None 时自动推导
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
self,支持链式调用
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from src.factors.api import close, ts_mean
|
|
||||||
>>> engine = FactorEngine()
|
|
||||||
>>> engine.register("ma20", ts_mean(close, 20))
|
|
||||||
"""
|
|
||||||
self.registered_expressions[name] = expression
|
|
||||||
|
|
||||||
# 预创建执行计划
|
|
||||||
plan = self.planner.create_plan(
|
|
||||||
expression=expression,
|
|
||||||
output_name=name,
|
|
||||||
data_specs=data_specs,
|
|
||||||
)
|
|
||||||
self._plans[name] = plan
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def compute(
|
|
||||||
self,
|
|
||||||
factor_names: Union[str, List[str]],
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
stock_codes: Optional[List[str]] = None,
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""计算指定因子的值。
|
|
||||||
|
|
||||||
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factor_names: 因子名称或名称列表
|
|
||||||
start_date: 开始日期 (YYYYMMDD)
|
|
||||||
end_date: 结束日期 (YYYYMMDD)
|
|
||||||
stock_codes: 股票代码列表,None 表示全市场
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含因子结果的数据表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 当因子未注册或数据不足时
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> result = engine.compute("ma20", "20240101", "20240131")
|
|
||||||
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
|
|
||||||
"""
|
|
||||||
# 标准化因子名称
|
|
||||||
if isinstance(factor_names, str):
|
|
||||||
factor_names = [factor_names]
|
|
||||||
|
|
||||||
# 1. 获取执行计划
|
|
||||||
plans = []
|
|
||||||
for name in factor_names:
|
|
||||||
if name not in self._plans:
|
|
||||||
raise ValueError(f"因子未注册: {name}")
|
|
||||||
plans.append(self._plans[name])
|
|
||||||
|
|
||||||
# 2. 合并数据规格并获取数据
|
|
||||||
all_specs = []
|
|
||||||
for plan in plans:
|
|
||||||
all_specs.extend(plan.data_specs)
|
|
||||||
|
|
||||||
# 3. 从路由器获取核心宽表
|
|
||||||
core_data = self.router.fetch_data(
|
|
||||||
data_specs=all_specs,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
stock_codes=stock_codes,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(core_data) == 0:
|
|
||||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
|
||||||
|
|
||||||
# 4. 执行计算
|
|
||||||
if len(plans) == 1:
|
|
||||||
result = self.compute_engine.execute(plans[0], core_data)
|
|
||||||
else:
|
|
||||||
result = self.compute_engine.execute_batch(plans, core_data)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def list_registered(self) -> List[str]:
|
|
||||||
"""获取已注册的因子列表。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
因子名称列表
|
|
||||||
"""
|
|
||||||
return list(self.registered_expressions.keys())
|
|
||||||
|
|
||||||
def get_expression(self, name: str) -> Optional[Node]:
|
|
||||||
"""获取已注册的表达式。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: 因子名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
表达式节点,未注册时返回 None
|
|
||||||
"""
|
|
||||||
return self.registered_expressions.get(name)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""清除所有注册的表达式和缓存。"""
|
|
||||||
self.registered_expressions.clear()
|
|
||||||
self._plans.clear()
|
|
||||||
self.router.clear_cache()
|
|
||||||
|
|
||||||
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
|
|
||||||
"""预览因子的执行计划。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factor_name: 因子名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
执行计划,未注册时返回 None
|
|
||||||
"""
|
|
||||||
return self._plans.get(factor_name)
|
|
||||||
28
src/factors/engine/__init__.py
Normal file
28
src/factors/engine/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""因子计算引擎模块。
|
||||||
|
|
||||||
|
提供完整的因子计算引擎组件:
|
||||||
|
- DataSpec: 数据规格定义
|
||||||
|
- ExecutionPlan: 执行计划
|
||||||
|
- DataRouter: 数据路由器
|
||||||
|
- ExecutionPlanner: 执行计划生成器
|
||||||
|
- ComputeEngine: 计算引擎
|
||||||
|
- FactorEngine: 因子计算引擎(统一入口)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||||
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
from src.factors.engine.planner import ExecutionPlanner
|
||||||
|
from src.factors.engine.compute_engine import ComputeEngine
|
||||||
|
from src.factors.engine.factor_engine import FactorEngine
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DataSpec",
|
||||||
|
"ExecutionPlan",
|
||||||
|
"DataRouter",
|
||||||
|
"ExecutionPlanner",
|
||||||
|
"ComputeEngine",
|
||||||
|
"FactorEngine",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 类型导出(用于类型注解)
|
||||||
|
# FunctionRegistry 从 src.factors.registry 导入
|
||||||
155
src/factors/engine/compute_engine.py
Normal file
155
src/factors/engine/compute_engine.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""计算引擎。
|
||||||
|
|
||||||
|
执行并行运算,负责将执行计划应用到数据上。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factors.engine.data_spec import ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
class ComputeEngine:
|
||||||
|
"""计算引擎 - 执行并行运算。
|
||||||
|
|
||||||
|
负责将执行计划应用到数据上,支持并行计算。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
max_workers: 最大并行工作线程数
|
||||||
|
use_processes: 是否使用进程池(CPU 密集型任务)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_workers: int = 4,
|
||||||
|
use_processes: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""初始化计算引擎。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_workers: 最大并行工作线程数
|
||||||
|
use_processes: 是否使用进程池代替线程池
|
||||||
|
"""
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.use_processes = use_processes
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
plan: ExecutionPlan,
|
||||||
|
data: pl.DataFrame,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""执行计算计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: 执行计划
|
||||||
|
data: 输入数据(核心宽表)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含因子结果的 DataFrame
|
||||||
|
"""
|
||||||
|
# 检查依赖字段是否存在
|
||||||
|
missing_cols = plan.dependencies - set(data.columns)
|
||||||
|
if missing_cols:
|
||||||
|
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
|
||||||
|
|
||||||
|
# 执行计算
|
||||||
|
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def execute_batch(
|
||||||
|
self,
|
||||||
|
plans: List[ExecutionPlan],
|
||||||
|
data: pl.DataFrame,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""批量执行多个计算计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plans: 执行计划列表
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有因子结果的 DataFrame
|
||||||
|
"""
|
||||||
|
result = data
|
||||||
|
|
||||||
|
for plan in plans:
|
||||||
|
result = self.execute(plan, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def execute_parallel(
|
||||||
|
self,
|
||||||
|
plans: List[ExecutionPlan],
|
||||||
|
data: pl.DataFrame,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""并行执行多个计算计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plans: 执行计划列表
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有因子结果的 DataFrame
|
||||||
|
"""
|
||||||
|
# 检查计划间依赖
|
||||||
|
independent_plans = []
|
||||||
|
dependent_plans = []
|
||||||
|
available_cols = set(data.columns)
|
||||||
|
|
||||||
|
for plan in plans:
|
||||||
|
if plan.dependencies <= available_cols:
|
||||||
|
independent_plans.append(plan)
|
||||||
|
available_cols.add(plan.output_name)
|
||||||
|
else:
|
||||||
|
dependent_plans.append(plan)
|
||||||
|
|
||||||
|
# 并行执行独立计划
|
||||||
|
if independent_plans:
|
||||||
|
ExecutorClass = (
|
||||||
|
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
||||||
|
)
|
||||||
|
|
||||||
|
with ExecutorClass(max_workers=self.max_workers) as executor:
|
||||||
|
futures = {
|
||||||
|
executor.submit(self._execute_single, plan, data): plan
|
||||||
|
for plan in independent_plans
|
||||||
|
}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for future in futures:
|
||||||
|
plan = futures[future]
|
||||||
|
try:
|
||||||
|
result_col = future.result()
|
||||||
|
results.append((plan.output_name, result_col))
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
|
||||||
|
|
||||||
|
# 合并结果
|
||||||
|
for name, series in results:
|
||||||
|
data = data.with_columns([series.alias(name)])
|
||||||
|
|
||||||
|
# 顺序执行依赖计划
|
||||||
|
for plan in dependent_plans:
|
||||||
|
data = self.execute(plan, data)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _execute_single(
|
||||||
|
self,
|
||||||
|
plan: ExecutionPlan,
|
||||||
|
data: pl.DataFrame,
|
||||||
|
) -> pl.Series:
|
||||||
|
"""执行单个计划并返回结果列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: 执行计划
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
计算结果序列
|
||||||
|
"""
|
||||||
|
result = self.execute(plan, data)
|
||||||
|
return result[plan.output_name]
|
||||||
304
src/factors/engine/data_router.py
Normal file
304
src/factors/engine/data_router.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""数据路由器。
|
||||||
|
|
||||||
|
按需取数、组装核心宽表。
|
||||||
|
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||||
|
支持内存数据源(用于测试)和真实数据库连接。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factors.engine.data_spec import DataSpec
|
||||||
|
from src.data.storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
class DataRouter:
|
||||||
|
"""数据路由器 - 按需取数、组装核心宽表。
|
||||||
|
|
||||||
|
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||||
|
支持内存数据源(用于测试)和真实数据库连接。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
|
||||||
|
is_memory_mode: 是否为内存模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
|
||||||
|
"""初始化数据路由器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
||||||
|
为 None 时自动连接 DuckDB 数据库
|
||||||
|
"""
|
||||||
|
self.data_source = data_source or {}
|
||||||
|
self.is_memory_mode = data_source is not None
|
||||||
|
self._cache: Dict[str, pl.DataFrame] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
# 数据库模式下初始化 Storage
|
||||||
|
if not self.is_memory_mode:
|
||||||
|
self._storage = Storage()
|
||||||
|
else:
|
||||||
|
self._storage = None
|
||||||
|
|
||||||
|
def fetch_data(
|
||||||
|
self,
|
||||||
|
data_specs: List[DataSpec],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""根据数据规格获取并组装核心宽表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_specs: 数据规格列表
|
||||||
|
start_date: 开始日期 (YYYYMMDD)
|
||||||
|
end_date: 结束日期 (YYYYMMDD)
|
||||||
|
stock_codes: 股票代码列表,None 表示全市场
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
组装好的核心宽表 DataFrame
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当数据源中缺少必要的表或字段时
|
||||||
|
"""
|
||||||
|
if not data_specs:
|
||||||
|
raise ValueError("数据规格不能为空")
|
||||||
|
|
||||||
|
# 收集所有需要的表和字段
|
||||||
|
required_tables: Dict[str, Set[str]] = {}
|
||||||
|
max_lookback = 0
|
||||||
|
|
||||||
|
for spec in data_specs:
|
||||||
|
if spec.table not in required_tables:
|
||||||
|
required_tables[spec.table] = set()
|
||||||
|
required_tables[spec.table].update(spec.columns)
|
||||||
|
max_lookback = max(max_lookback, spec.lookback_days)
|
||||||
|
|
||||||
|
# 调整日期范围以包含回看期
|
||||||
|
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
||||||
|
|
||||||
|
# 从数据源获取各表数据
|
||||||
|
table_data = {}
|
||||||
|
for table_name, columns in required_tables.items():
|
||||||
|
df = self._load_table(
|
||||||
|
table_name=table_name,
|
||||||
|
columns=list(columns),
|
||||||
|
start_date=adjusted_start,
|
||||||
|
end_date=end_date,
|
||||||
|
stock_codes=stock_codes,
|
||||||
|
)
|
||||||
|
table_data[table_name] = df
|
||||||
|
|
||||||
|
# 组装核心宽表
|
||||||
|
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||||
|
|
||||||
|
# 过滤到实际请求日期范围
|
||||||
|
core_table = core_table.filter(
|
||||||
|
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||||
|
)
|
||||||
|
|
||||||
|
return core_table
|
||||||
|
|
||||||
|
def _load_table(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
columns: List[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""加载单个表的数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
columns: 需要的字段
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
stock_codes: 股票代码过滤
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过滤后的 DataFrame
|
||||||
|
"""
|
||||||
|
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if cache_key in self._cache:
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
if self.is_memory_mode:
|
||||||
|
df = self._load_from_memory(
|
||||||
|
table_name, columns, start_date, end_date, stock_codes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
df = self._load_from_database(
|
||||||
|
table_name, columns, start_date, end_date, stock_codes
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._cache[cache_key] = df
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
def _load_from_memory(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
columns: List[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""从内存数据源加载数据。"""
|
||||||
|
if table_name not in self.data_source:
|
||||||
|
raise ValueError(f"内存数据源中缺少表: {table_name}")
|
||||||
|
|
||||||
|
df = self.data_source[table_name]
|
||||||
|
|
||||||
|
# 确保必需字段存在
|
||||||
|
for col in columns:
|
||||||
|
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||||
|
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||||
|
|
||||||
|
# 过滤日期和股票
|
||||||
|
df = df.filter(
|
||||||
|
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||||
|
)
|
||||||
|
|
||||||
|
if stock_codes is not None:
|
||||||
|
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||||
|
|
||||||
|
# 选择需要的列
|
||||||
|
select_cols = ["ts_code", "trade_date"] + [
|
||||||
|
c for c in columns if c in df.columns
|
||||||
|
]
|
||||||
|
return df.select(select_cols)
|
||||||
|
|
||||||
|
def _load_from_database(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
columns: List[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""从 DuckDB 数据库加载数据。
|
||||||
|
|
||||||
|
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
|
||||||
|
"""
|
||||||
|
if self._storage is None:
|
||||||
|
raise RuntimeError("Storage 未初始化")
|
||||||
|
|
||||||
|
# 检查表是否存在
|
||||||
|
if not self._storage.exists(table_name):
|
||||||
|
raise ValueError(f"数据库中不存在表: {table_name}")
|
||||||
|
|
||||||
|
# 构建查询参数
|
||||||
|
# Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况
|
||||||
|
if stock_codes is not None and len(stock_codes) == 1:
|
||||||
|
ts_code_filter = stock_codes[0]
|
||||||
|
else:
|
||||||
|
ts_code_filter = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 从数据库加载原始数据
|
||||||
|
df = self._storage.load_polars(
|
||||||
|
name=table_name,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
ts_code=ts_code_filter,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
|
||||||
|
|
||||||
|
# 如果 stock_codes 是列表且长度 > 1,在内存中过滤
|
||||||
|
if stock_codes is not None and len(stock_codes) > 1:
|
||||||
|
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||||
|
|
||||||
|
# 检查必需字段
|
||||||
|
for col in columns:
|
||||||
|
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||||
|
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||||
|
|
||||||
|
# 选择需要的列
|
||||||
|
select_cols = ["ts_code", "trade_date"] + [
|
||||||
|
c for c in columns if c in df.columns
|
||||||
|
]
|
||||||
|
|
||||||
|
return df.select(select_cols)
|
||||||
|
|
||||||
|
def _assemble_wide_table(
|
||||||
|
self,
|
||||||
|
table_data: Dict[str, pl.DataFrame],
|
||||||
|
required_tables: Dict[str, Set[str]],
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""组装多表数据为核心宽表。
|
||||||
|
|
||||||
|
使用 left join 合并各表数据,以第一个表为基准。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_data: 表名到 DataFrame 的映射
|
||||||
|
required_tables: 表名到字段集合的映射
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
组装后的宽表
|
||||||
|
"""
|
||||||
|
if not table_data:
|
||||||
|
raise ValueError("没有数据可组装")
|
||||||
|
|
||||||
|
# 以第一个表为基准
|
||||||
|
base_table_name = list(table_data.keys())[0]
|
||||||
|
result = table_data[base_table_name]
|
||||||
|
|
||||||
|
# 与其他表 join
|
||||||
|
for table_name, df in table_data.items():
|
||||||
|
if table_name == base_table_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 使用 ts_code 和 trade_date 作为 join 键
|
||||||
|
result = result.join(
|
||||||
|
df,
|
||||||
|
on=["ts_code", "trade_date"],
|
||||||
|
how="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
||||||
|
"""根据回看天数调整开始日期。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: 原始开始日期 (YYYYMMDD)
|
||||||
|
lookback_days: 需要回看的交易日数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
调整后的开始日期
|
||||||
|
"""
|
||||||
|
# 简化的日期调整:假设每月30天,向前推移
|
||||||
|
# 实际应用中应该使用交易日历
|
||||||
|
year = int(start_date[:4])
|
||||||
|
month = int(start_date[4:6])
|
||||||
|
day = int(start_date[6:8])
|
||||||
|
|
||||||
|
total_days = lookback_days + 30 # 额外缓冲
|
||||||
|
|
||||||
|
day -= total_days
|
||||||
|
while day <= 0:
|
||||||
|
month -= 1
|
||||||
|
if month <= 0:
|
||||||
|
month = 12
|
||||||
|
year -= 1
|
||||||
|
day += 30
|
||||||
|
|
||||||
|
return f"{year:04d}{month:02d}{day:02d}"
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清除数据缓存。"""
|
||||||
|
with self._lock:
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
# 数据库模式下清理 Storage 连接(可选)
|
||||||
|
if not self.is_memory_mode and self._storage is not None:
|
||||||
|
# Storage 使用单例模式,不需要关闭连接
|
||||||
|
pass
|
||||||
47
src/factors/engine/data_spec.py
Normal file
47
src/factors/engine/data_spec.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""数据规格和执行计划定义。
|
||||||
|
|
||||||
|
定义因子计算所需的数据规格和执行计划结构。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataSpec:
|
||||||
|
"""数据规格定义。
|
||||||
|
|
||||||
|
描述因子计算所需的数据表和字段。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
table: 数据表名称
|
||||||
|
columns: 需要的字段列表
|
||||||
|
lookback_days: 回看天数(用于时序计算)
|
||||||
|
"""
|
||||||
|
|
||||||
|
table: str
|
||||||
|
columns: List[str]
|
||||||
|
lookback_days: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionPlan:
|
||||||
|
"""执行计划。
|
||||||
|
|
||||||
|
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
data_specs: 数据规格列表
|
||||||
|
polars_expr: Polars 表达式
|
||||||
|
dependencies: 依赖的原始字段
|
||||||
|
output_name: 输出因子名称
|
||||||
|
factor_dependencies: 依赖的其他因子名称(用于分步执行)
|
||||||
|
"""
|
||||||
|
|
||||||
|
data_specs: List[DataSpec]
|
||||||
|
polars_expr: pl.Expr
|
||||||
|
dependencies: Set[str]
|
||||||
|
output_name: str
|
||||||
|
factor_dependencies: Set[str] = field(default_factory=set)
|
||||||
513
src/factors/engine/factor_engine.py
Normal file
513
src/factors/engine/factor_engine.py
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
"""因子计算引擎 - 系统统一入口。
|
||||||
|
|
||||||
|
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
||||||
|
|
||||||
|
执行流程:
|
||||||
|
1. 注册表达式 -> 调用编译器解析依赖
|
||||||
|
2. 调用路由器连接数据库拉取并组装核心宽表
|
||||||
|
3. 调用翻译器生成物理执行计划
|
||||||
|
4. 将计划提交给计算引擎执行并行运算
|
||||||
|
5. 返回包含因子结果的数据表
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
|
from src.factors.dsl import (
|
||||||
|
Node,
|
||||||
|
Symbol,
|
||||||
|
BinaryOpNode,
|
||||||
|
UnaryOpNode,
|
||||||
|
FunctionNode,
|
||||||
|
)
|
||||||
|
from src.factors.translator import PolarsTranslator
|
||||||
|
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||||
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
from src.factors.engine.planner import ExecutionPlanner
|
||||||
|
from src.factors.engine.compute_engine import ComputeEngine
|
||||||
|
|
||||||
|
|
||||||
|
class FactorEngine:
|
||||||
|
"""因子计算引擎 - 系统统一入口。
|
||||||
|
|
||||||
|
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
||||||
|
|
||||||
|
执行流程:
|
||||||
|
1. 注册表达式 -> 调用编译器解析依赖
|
||||||
|
2. 调用路由器连接数据库拉取并组装核心宽表
|
||||||
|
3. 调用翻译器生成物理执行计划
|
||||||
|
4. 将计划提交给计算引擎执行并行运算
|
||||||
|
5. 返回包含因子结果的数据表
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
router: 数据路由器
|
||||||
|
planner: 执行计划生成器
|
||||||
|
compute_engine: 计算引擎
|
||||||
|
registered_expressions: 注册的表达式字典
|
||||||
|
_registry: 函数注册表
|
||||||
|
_parser: 公式解析器
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
||||||
|
max_workers: int = 4,
|
||||||
|
registry: Optional["FunctionRegistry"] = None,
|
||||||
|
) -> None:
|
||||||
|
"""初始化因子引擎。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source: 内存数据源,为 None 时使用数据库连接
|
||||||
|
max_workers: 并行计算的最大工作线程数
|
||||||
|
registry: 函数注册表,None 时创建独立实例
|
||||||
|
"""
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
from src.factors.parser import FormulaParser
|
||||||
|
|
||||||
|
self.router = DataRouter(data_source)
|
||||||
|
self.planner = ExecutionPlanner()
|
||||||
|
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
||||||
|
self.registered_expressions: Dict[str, Node] = {}
|
||||||
|
self._plans: Dict[str, ExecutionPlan] = {}
|
||||||
|
|
||||||
|
# 初始化注册表和解析器(支持注入外部注册表实现共享)
|
||||||
|
self._registry = registry if registry is not None else FunctionRegistry()
|
||||||
|
self._parser = FormulaParser(self._registry)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
expression: Node,
|
||||||
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
) -> "FactorEngine":
|
||||||
|
"""注册因子表达式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
expression: DSL 表达式
|
||||||
|
data_specs: 数据规格,None 时自动推导
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from src.factors.api import close, ts_mean
|
||||||
|
>>> engine = FactorEngine()
|
||||||
|
>>> engine.register("ma20", ts_mean(close, 20))
|
||||||
|
"""
|
||||||
|
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
||||||
|
factor_deps = self._find_factor_dependencies(expression)
|
||||||
|
|
||||||
|
self.registered_expressions[name] = expression
|
||||||
|
|
||||||
|
# 预创建执行计划
|
||||||
|
plan = self.planner.create_plan(
|
||||||
|
expression=expression,
|
||||||
|
output_name=name,
|
||||||
|
data_specs=data_specs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加因子依赖信息
|
||||||
|
plan.factor_dependencies = factor_deps
|
||||||
|
|
||||||
|
self._plans[name] = plan
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_factor(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
expression: Union[str, Node],
|
||||||
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
) -> "FactorEngine":
|
||||||
|
"""注册因子(支持字符串或 Node 表达式)。
|
||||||
|
|
||||||
|
这是 register 方法的增强版,支持字符串表达式解析。
|
||||||
|
向后兼容:register 方法保持不变,继续只接受 Node 类型。
|
||||||
|
|
||||||
|
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
expression: 字符串表达式或 Node 对象
|
||||||
|
data_specs: 可选的数据规格
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 当 expression 类型不支持时
|
||||||
|
FormulaParseError: 当字符串解析失败时(立即报错)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> engine = FactorEngine()
|
||||||
|
>>>
|
||||||
|
>>> # 字符串方式(新功能)
|
||||||
|
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||||
|
>>>
|
||||||
|
>>> # Node 方式(与 register 相同)
|
||||||
|
>>> from src.factors.api import close, ts_mean
|
||||||
|
>>> engine.add_factor("ma20", ts_mean(close, 20))
|
||||||
|
>>>
|
||||||
|
>>> # 复杂表达式
|
||||||
|
>>> engine.add_factor("alpha1", "cs_rank(close / open)")
|
||||||
|
>>>
|
||||||
|
>>> # 链式调用
|
||||||
|
>>> (engine
|
||||||
|
... .add_factor("ma5", "ts_mean(close, 5)")
|
||||||
|
... .add_factor("ma10", "ts_mean(close, 10)")
|
||||||
|
... .add_factor("golden_cross", "ma5 > ma10"))
|
||||||
|
"""
|
||||||
|
if isinstance(expression, str):
|
||||||
|
# Fail-Fast:立即解析,失败立即报错
|
||||||
|
node = self._parser.parse(expression)
|
||||||
|
elif isinstance(expression, Node):
|
||||||
|
node = expression
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 委托给现有的 register 方法
|
||||||
|
return self.register(name, node, data_specs)
|
||||||
|
|
||||||
|
def compute(
|
||||||
|
self,
|
||||||
|
factor_names: Union[str, List[str]],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""计算指定因子的值。
|
||||||
|
|
||||||
|
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor_names: 因子名称或名称列表
|
||||||
|
start_date: 开始日期 (YYYYMMDD)
|
||||||
|
end_date: 结束日期 (YYYYMMDD)
|
||||||
|
stock_codes: 股票代码列表,None 表示全市场
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含因子结果的数据表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当因子未注册或数据不足时
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> result = engine.compute("ma20", "20240101", "20240131")
|
||||||
|
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
|
||||||
|
"""
|
||||||
|
# 标准化因子名称
|
||||||
|
if isinstance(factor_names, str):
|
||||||
|
factor_names = [factor_names]
|
||||||
|
|
||||||
|
# 1. 获取执行计划
|
||||||
|
plans = []
|
||||||
|
for name in factor_names:
|
||||||
|
if name not in self._plans:
|
||||||
|
raise ValueError(f"因子未注册: {name}")
|
||||||
|
plans.append(self._plans[name])
|
||||||
|
|
||||||
|
# 2. 合并数据规格并获取数据
|
||||||
|
all_specs = []
|
||||||
|
for plan in plans:
|
||||||
|
all_specs.extend(plan.data_specs)
|
||||||
|
|
||||||
|
# 3. 从路由器获取核心宽表
|
||||||
|
core_data = self.router.fetch_data(
|
||||||
|
data_specs=all_specs,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
stock_codes=stock_codes,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(core_data) == 0:
|
||||||
|
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||||
|
|
||||||
|
# 4. 按依赖顺序执行计算
|
||||||
|
if len(plans) == 1:
|
||||||
|
result = self.compute_engine.execute(plans[0], core_data)
|
||||||
|
else:
|
||||||
|
# 使用依赖感知的方式执行
|
||||||
|
result = self._execute_with_dependencies(factor_names, core_data)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def list_registered(self) -> List[str]:
|
||||||
|
"""获取已注册的因子列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
因子名称列表
|
||||||
|
"""
|
||||||
|
return list(self.registered_expressions.keys())
|
||||||
|
|
||||||
|
def get_expression(self, name: str) -> Optional[Node]:
|
||||||
|
"""获取已注册的表达式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表达式节点,未注册时返回 None
|
||||||
|
"""
|
||||||
|
return self.registered_expressions.get(name)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清除所有注册的表达式和缓存。"""
|
||||||
|
self.registered_expressions.clear()
|
||||||
|
self._plans.clear()
|
||||||
|
self.router.clear_cache()
|
||||||
|
|
||||||
|
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
|
||||||
|
"""预览因子的执行计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor_name: 因子名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行计划,未注册时返回 None
|
||||||
|
"""
|
||||||
|
return self._plans.get(factor_name)
|
||||||
|
|
||||||
|
def _execute_with_dependencies(
|
||||||
|
self,
|
||||||
|
factor_names: List[str],
|
||||||
|
core_data: pl.DataFrame,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""按依赖顺序执行因子计算。
|
||||||
|
|
||||||
|
支持 cs_rank 等需要依赖列已存在的场景。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor_names: 因子名称列表
|
||||||
|
core_data: 核心宽表数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有因子结果的数据表
|
||||||
|
"""
|
||||||
|
# 1. 拓扑排序
|
||||||
|
sorted_names = self._topological_sort(factor_names)
|
||||||
|
|
||||||
|
# 2. 按顺序执行
|
||||||
|
result = core_data
|
||||||
|
for name in sorted_names:
|
||||||
|
plan = self._plans[name]
|
||||||
|
|
||||||
|
# 创建新的执行计划,引用已计算的依赖列
|
||||||
|
new_plan = self._create_optimized_plan(plan, result)
|
||||||
|
|
||||||
|
# 执行计算
|
||||||
|
result = self.compute_engine.execute(new_plan, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _create_optimized_plan(
|
||||||
|
self,
|
||||||
|
plan: ExecutionPlan,
|
||||||
|
current_data: pl.DataFrame,
|
||||||
|
) -> ExecutionPlan:
|
||||||
|
"""创建优化的执行计划。
|
||||||
|
|
||||||
|
将表达式中已计算的依赖因子替换为列引用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: 原始执行计划
|
||||||
|
current_data: 当前数据(包含已计算的依赖列)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新的执行计划
|
||||||
|
"""
|
||||||
|
from src.factors.dsl import Symbol
|
||||||
|
|
||||||
|
# 获取当前数据中已存在的列
|
||||||
|
existing_cols = set(current_data.columns)
|
||||||
|
|
||||||
|
# 检查依赖列是否已存在
|
||||||
|
deps_available = plan.factor_dependencies & existing_cols
|
||||||
|
|
||||||
|
if not deps_available:
|
||||||
|
# 没有可用的依赖列,直接返回原计划
|
||||||
|
return plan
|
||||||
|
|
||||||
|
# 获取原始表达式
|
||||||
|
original_expr = self.registered_expressions[plan.output_name]
|
||||||
|
|
||||||
|
# 创建新的表达式,用 Symbol 引用替换依赖因子
|
||||||
|
def replace_with_symbol(node: Node) -> Node:
|
||||||
|
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
n: Any = node
|
||||||
|
|
||||||
|
# 检查当前节点是否等于某个已计算依赖因子
|
||||||
|
for dep_name in deps_available:
|
||||||
|
dep_expr = self.registered_expressions[dep_name]
|
||||||
|
if self._expressions_equal(node, dep_expr):
|
||||||
|
return Symbol(dep_name)
|
||||||
|
|
||||||
|
# 递归处理子节点
|
||||||
|
if isinstance(n, BinaryOpNode):
|
||||||
|
new_left = replace_with_symbol(n.left)
|
||||||
|
new_right = replace_with_symbol(n.right)
|
||||||
|
if new_left is not n.left or new_right is not n.right:
|
||||||
|
return BinaryOpNode(n.op, new_left, new_right)
|
||||||
|
elif isinstance(n, UnaryOpNode):
|
||||||
|
new_operand = replace_with_symbol(n.operand)
|
||||||
|
if new_operand is not n.operand:
|
||||||
|
return UnaryOpNode(n.op, new_operand)
|
||||||
|
elif isinstance(n, FunctionNode):
|
||||||
|
new_args = [replace_with_symbol(arg) for arg in n.args]
|
||||||
|
if any(
|
||||||
|
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
|
||||||
|
):
|
||||||
|
return FunctionNode(n.func_name, *new_args)
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
# 替换表达式
|
||||||
|
new_expr = replace_with_symbol(original_expr)
|
||||||
|
|
||||||
|
# 重新翻译表达式
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
new_polars_expr = translator.translate(new_expr)
|
||||||
|
|
||||||
|
# 更新依赖集合
|
||||||
|
new_factor_deps = plan.factor_dependencies - deps_available
|
||||||
|
new_deps = plan.dependencies | deps_available
|
||||||
|
|
||||||
|
return ExecutionPlan(
|
||||||
|
data_specs=plan.data_specs,
|
||||||
|
polars_expr=new_polars_expr,
|
||||||
|
dependencies=new_deps,
|
||||||
|
output_name=plan.output_name,
|
||||||
|
factor_dependencies=new_factor_deps,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
|
||||||
|
"""比较两个表达式是否相等。
|
||||||
|
|
||||||
|
用于检测因子间的依赖关系。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr1: 第一个表达式
|
||||||
|
expr2: 第二个表达式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否相等
|
||||||
|
"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
e1: Any = expr1
|
||||||
|
e2: Any = expr2
|
||||||
|
|
||||||
|
if type(e1) != type(e2):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(e1, Symbol):
|
||||||
|
return e1.name == e2.name
|
||||||
|
|
||||||
|
from src.factors.dsl import Constant
|
||||||
|
|
||||||
|
if isinstance(e1, Constant):
|
||||||
|
return e1.value == e2.value
|
||||||
|
|
||||||
|
if isinstance(e1, BinaryOpNode):
|
||||||
|
return (
|
||||||
|
e1.op == e2.op
|
||||||
|
and self._expressions_equal(e1.left, e2.left)
|
||||||
|
and self._expressions_equal(e1.right, e2.right)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(e1, UnaryOpNode):
|
||||||
|
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
|
||||||
|
|
||||||
|
if isinstance(e1, FunctionNode):
|
||||||
|
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
|
||||||
|
return False
|
||||||
|
return all(
|
||||||
|
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
|
||||||
|
)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
||||||
|
"""查找表达式依赖的其他因子。
|
||||||
|
|
||||||
|
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: 待检查的表达式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
依赖的因子名称集合
|
||||||
|
"""
|
||||||
|
deps: Set[str] = set()
|
||||||
|
|
||||||
|
# 检查表达式本身是否等于某个已注册因子
|
||||||
|
for name, registered_expr in self.registered_expressions.items():
|
||||||
|
if self._expressions_equal(expression, registered_expr):
|
||||||
|
deps.add(name)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 递归检查子节点
|
||||||
|
if isinstance(expression, BinaryOpNode):
|
||||||
|
deps.update(self._find_factor_dependencies(expression.left))
|
||||||
|
deps.update(self._find_factor_dependencies(expression.right))
|
||||||
|
elif isinstance(expression, UnaryOpNode):
|
||||||
|
deps.update(self._find_factor_dependencies(expression.operand))
|
||||||
|
elif isinstance(expression, FunctionNode):
|
||||||
|
for arg in expression.args:
|
||||||
|
deps.update(self._find_factor_dependencies(arg))
|
||||||
|
|
||||||
|
return deps
|
||||||
|
|
||||||
|
def _topological_sort(self, factor_names: List[str]) -> List[str]:
|
||||||
|
"""按依赖关系对因子进行拓扑排序。
|
||||||
|
|
||||||
|
确保依赖的因子先被计算。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor_names: 因子名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
排序后的因子名称列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当检测到循环依赖时
|
||||||
|
"""
|
||||||
|
# 构建依赖图
|
||||||
|
graph: Dict[str, Set[str]] = {}
|
||||||
|
in_degree: Dict[str, int] = {}
|
||||||
|
|
||||||
|
for name in factor_names:
|
||||||
|
plan = self._plans[name]
|
||||||
|
# 只考虑在本次计算范围内的依赖
|
||||||
|
deps = plan.factor_dependencies & set(factor_names)
|
||||||
|
graph[name] = deps
|
||||||
|
in_degree[name] = len(deps)
|
||||||
|
|
||||||
|
# Kahn 算法
|
||||||
|
result = []
|
||||||
|
queue = [name for name, degree in in_degree.items() if degree == 0]
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
# 按原始顺序处理同级别的因子
|
||||||
|
queue.sort(key=lambda x: factor_names.index(x))
|
||||||
|
name = queue.pop(0)
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for other in factor_names:
|
||||||
|
if name in graph[other]:
|
||||||
|
in_degree[other] -= 1
|
||||||
|
if in_degree[other] == 0:
|
||||||
|
queue.append(other)
|
||||||
|
|
||||||
|
if len(result) != len(factor_names):
|
||||||
|
raise ValueError("检测到因子循环依赖")
|
||||||
|
|
||||||
|
return result
|
||||||
170
src/factors/engine/planner.py
Normal file
170
src/factors/engine/planner.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""执行计划生成器。
|
||||||
|
|
||||||
|
整合编译器和翻译器,生成完整的执行计划。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
from src.factors.dsl import (
|
||||||
|
Node,
|
||||||
|
Symbol,
|
||||||
|
FunctionNode,
|
||||||
|
BinaryOpNode,
|
||||||
|
UnaryOpNode,
|
||||||
|
Constant,
|
||||||
|
)
|
||||||
|
from src.factors.compiler import DependencyExtractor
|
||||||
|
from src.factors.translator import PolarsTranslator
|
||||||
|
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionPlanner:
|
||||||
|
"""执行计划生成器。
|
||||||
|
|
||||||
|
整合编译器和翻译器,生成完整的执行计划。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
compiler: 依赖提取器
|
||||||
|
translator: Polars 翻译器
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""初始化执行计划生成器。"""
|
||||||
|
self.compiler = DependencyExtractor()
|
||||||
|
self.translator = PolarsTranslator()
|
||||||
|
|
||||||
|
def create_plan(
|
||||||
|
self,
|
||||||
|
expression: Node,
|
||||||
|
output_name: str = "factor",
|
||||||
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
) -> ExecutionPlan:
|
||||||
|
"""从表达式创建执行计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: DSL 表达式节点
|
||||||
|
output_name: 输出因子名称
|
||||||
|
data_specs: 预定义的数据规格,None 时自动推导
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行计划对象
|
||||||
|
"""
|
||||||
|
# 1. 提取依赖
|
||||||
|
dependencies = self.compiler.extract_dependencies(expression)
|
||||||
|
|
||||||
|
# 2. 翻译为 Polars 表达式
|
||||||
|
polars_expr = self.translator.translate(expression)
|
||||||
|
|
||||||
|
# 3. 推导或验证数据规格
|
||||||
|
if data_specs is None:
|
||||||
|
data_specs = self._infer_data_specs(dependencies, expression)
|
||||||
|
|
||||||
|
return ExecutionPlan(
|
||||||
|
data_specs=data_specs,
|
||||||
|
polars_expr=polars_expr,
|
||||||
|
dependencies=dependencies,
|
||||||
|
output_name=output_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _infer_data_specs(
|
||||||
|
self,
|
||||||
|
dependencies: Set[str],
|
||||||
|
expression: Node,
|
||||||
|
) -> List[DataSpec]:
|
||||||
|
"""从依赖推导数据规格。
|
||||||
|
|
||||||
|
根据表达式中的函数类型推断回看天数需求。
|
||||||
|
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
||||||
|
默认从 pro_bar 表获取。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dependencies: 依赖的字段集合
|
||||||
|
expression: 表达式节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
数据规格列表
|
||||||
|
"""
|
||||||
|
# 计算最大回看窗口
|
||||||
|
max_window = self._extract_max_window(expression)
|
||||||
|
lookback_days = max(1, max_window)
|
||||||
|
|
||||||
|
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
||||||
|
pro_bar_fields = {
|
||||||
|
"open",
|
||||||
|
"high",
|
||||||
|
"low",
|
||||||
|
"close",
|
||||||
|
"vol",
|
||||||
|
"amount",
|
||||||
|
"pre_close",
|
||||||
|
"change",
|
||||||
|
"pct_chg",
|
||||||
|
"turnover_rate",
|
||||||
|
"volume_ratio",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将依赖分为 pro_bar 字段和其他字段
|
||||||
|
pro_bar_deps = dependencies & pro_bar_fields
|
||||||
|
other_deps = dependencies - pro_bar_fields
|
||||||
|
|
||||||
|
data_specs = []
|
||||||
|
|
||||||
|
# pro_bar 表的数据规格
|
||||||
|
if pro_bar_deps:
|
||||||
|
data_specs.append(
|
||||||
|
DataSpec(
|
||||||
|
table="pro_bar",
|
||||||
|
columns=sorted(pro_bar_deps),
|
||||||
|
lookback_days=lookback_days,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 其他字段从 daily 表获取
|
||||||
|
if other_deps:
|
||||||
|
data_specs.append(
|
||||||
|
DataSpec(
|
||||||
|
table="daily",
|
||||||
|
columns=sorted(other_deps),
|
||||||
|
lookback_days=lookback_days,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return data_specs
|
||||||
|
|
||||||
|
def _extract_max_window(self, node: Node) -> int:
|
||||||
|
"""从表达式中提取最大窗口大小。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最大窗口大小,无时序函数返回 1
|
||||||
|
"""
|
||||||
|
if isinstance(node, FunctionNode):
|
||||||
|
window = 1
|
||||||
|
# 检查函数参数中的窗口大小
|
||||||
|
for arg in node.args:
|
||||||
|
if (
|
||||||
|
isinstance(arg, Constant)
|
||||||
|
and isinstance(arg.value, int)
|
||||||
|
and arg.value > window
|
||||||
|
):
|
||||||
|
window = arg.value
|
||||||
|
|
||||||
|
# 递归检查子表达式
|
||||||
|
for arg in node.args:
|
||||||
|
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
||||||
|
window = max(window, self._extract_max_window(arg))
|
||||||
|
|
||||||
|
return window
|
||||||
|
|
||||||
|
elif isinstance(node, BinaryOpNode):
|
||||||
|
return max(
|
||||||
|
self._extract_max_window(node.left),
|
||||||
|
self._extract_max_window(node.right),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(node, UnaryOpNode):
|
||||||
|
return self._extract_max_window(node.operand)
|
||||||
|
|
||||||
|
return 1
|
||||||
144
src/factors/exceptions.py
Normal file
144
src/factors/exceptions.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""公式解析异常定义。
|
||||||
|
|
||||||
|
提供清晰的错误信息,帮助用户快速定位公式解析问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class FormulaParseError(Exception):
|
||||||
|
"""公式解析错误基类。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
lineno: 错误所在行号(从1开始)
|
||||||
|
col_offset: 错误所在列号(从0开始)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
lineno: Optional[int] = None,
|
||||||
|
col_offset: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.expr = expr
|
||||||
|
self.lineno = lineno
|
||||||
|
self.col_offset = col_offset
|
||||||
|
|
||||||
|
# 构建详细错误信息
|
||||||
|
full_message = self._format_message(message)
|
||||||
|
super().__init__(full_message)
|
||||||
|
|
||||||
|
def _format_message(self, message: str) -> str:
|
||||||
|
"""格式化错误信息,包含位置指示器。"""
|
||||||
|
lines = [f"FormulaParseError: {message}"]
|
||||||
|
|
||||||
|
if self.expr:
|
||||||
|
lines.append(f" 公式: {self.expr}")
|
||||||
|
|
||||||
|
# 添加错误位置指示器
|
||||||
|
if self.col_offset is not None and self.lineno is not None:
|
||||||
|
# 计算错误行在表达式中的起始位置
|
||||||
|
expr_lines = self.expr.split("\n")
|
||||||
|
if 1 <= self.lineno <= len(expr_lines):
|
||||||
|
error_line = expr_lines[self.lineno - 1]
|
||||||
|
lines.append(f" {error_line}")
|
||||||
|
# 添加指向错误位置的箭头
|
||||||
|
pointer = " " * (self.col_offset + 7) + "^--- 此处出错"
|
||||||
|
lines.append(pointer)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownFunctionError(FormulaParseError):
|
||||||
|
"""未知函数错误。
|
||||||
|
|
||||||
|
当表达式中使用了未注册的函数时抛出。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
func_name: 未知的函数名
|
||||||
|
available: 可用函数列表
|
||||||
|
suggestions: 模糊匹配建议列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func_name: str,
|
||||||
|
available: List[str],
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
lineno: Optional[int] = None,
|
||||||
|
col_offset: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.func_name = func_name
|
||||||
|
self.available = available
|
||||||
|
|
||||||
|
# 使用 difflib 获取模糊匹配建议
|
||||||
|
self.suggestions = difflib.get_close_matches(
|
||||||
|
func_name, available, n=3, cutoff=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建错误信息
|
||||||
|
if self.suggestions:
|
||||||
|
suggestion_str = ", ".join(f"'{s}'" for s in self.suggestions)
|
||||||
|
hint_msg = f"你是不是想找: {suggestion_str}?"
|
||||||
|
else:
|
||||||
|
# 只显示前10个可用函数
|
||||||
|
available_preview = ", ".join(available[:10])
|
||||||
|
if len(available) > 10:
|
||||||
|
available_preview += f", ... 等共 {len(available)} 个函数"
|
||||||
|
hint_msg = f"可用函数预览: {available_preview}"
|
||||||
|
|
||||||
|
msg = f"未知函数 '{func_name}'。{hint_msg}"
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
message=msg,
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSyntaxError(FormulaParseError):
|
||||||
|
"""语法错误。
|
||||||
|
|
||||||
|
当表达式语法不正确或不支持时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedOperatorError(InvalidSyntaxError):
|
||||||
|
"""不支持的运算符错误。
|
||||||
|
|
||||||
|
当使用了不支持的运算符时抛出(如位运算、矩阵运算等)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyExpressionError(FormulaParseError):
|
||||||
|
"""空表达式错误。"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("表达式不能为空或只包含空白字符")
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryError(Exception):
|
||||||
|
"""注册表错误基类。"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateFunctionError(RegistryError):
|
||||||
|
"""函数重复注册错误。
|
||||||
|
|
||||||
|
当尝试注册已存在的函数且未设置 force=True 时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func_name: str):
|
||||||
|
self.func_name = func_name
|
||||||
|
super().__init__(
|
||||||
|
f"函数 '{func_name}' 已存在。使用 force=True 覆盖,或选择其他名称。"
|
||||||
|
)
|
||||||
411
src/factors/parser.py
Normal file
411
src/factors/parser.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
|
||||||
|
|
||||||
|
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> from src.factors.parser import FormulaParser
|
||||||
|
>>> from src.factors.registry import FunctionRegistry
|
||||||
|
>>> parser = FormulaParser(FunctionRegistry())
|
||||||
|
>>> node = parser.parse("ts_mean(close, 20)")
|
||||||
|
>>> print(node)
|
||||||
|
ts_mean(close, 20)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
|
||||||
|
from src.factors.exceptions import (
|
||||||
|
FormulaParseError,
|
||||||
|
UnknownFunctionError,
|
||||||
|
InvalidSyntaxError,
|
||||||
|
EmptyExpressionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# 运算符映射表
|
||||||
|
BIN_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.Add: "+",
|
||||||
|
ast.Sub: "-",
|
||||||
|
ast.Mult: "*",
|
||||||
|
ast.Div: "/",
|
||||||
|
ast.Pow: "**",
|
||||||
|
ast.FloorDiv: "//",
|
||||||
|
ast.Mod: "%",
|
||||||
|
}
|
||||||
|
|
||||||
|
UNARY_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.UAdd: "+",
|
||||||
|
ast.USub: "-",
|
||||||
|
ast.Invert: "~", # 不支持,应报错
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARE_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.Eq: "==",
|
||||||
|
ast.NotEq: "!=",
|
||||||
|
ast.Lt: "<",
|
||||||
|
ast.LtE: "<=",
|
||||||
|
ast.Gt: ">",
|
||||||
|
ast.GtE: ">=",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FormulaParser:
|
||||||
|
"""基于 AST 的公式解析器。
|
||||||
|
|
||||||
|
将字符串表达式解析为 DSL 节点树,支持:
|
||||||
|
- 符号引用(如 close, open)
|
||||||
|
- 数值常量(如 20, 3.14)
|
||||||
|
- 二元运算(如 +, -, *, /)
|
||||||
|
- 一元运算(如 -x)
|
||||||
|
- 函数调用(如 ts_mean(close, 20))
|
||||||
|
- 比较运算(如 close > open)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
registry: 函数注册表,用于解析函数调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, registry: "FunctionRegistry") -> None:
|
||||||
|
"""初始化解析器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: 函数注册表,提供函数名到可调用对象的映射
|
||||||
|
"""
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def parse(self, expr: str) -> Node:
|
||||||
|
"""解析字符串表达式为 Node 树。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr: 公式字符串,如 "ts_mean(close, 20)"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的 Node 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EmptyExpressionError: 表达式为空时抛出
|
||||||
|
SyntaxError: Python 语法错误时抛出
|
||||||
|
FormulaParseError: 解析失败时抛出
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> parser.parse("close / open")
|
||||||
|
BinaryOpNode("/", Symbol("close"), Symbol("open"))
|
||||||
|
"""
|
||||||
|
# 检查空表达式
|
||||||
|
if not expr or not expr.strip():
|
||||||
|
raise EmptyExpressionError()
|
||||||
|
|
||||||
|
# 解析为 Python AST
|
||||||
|
try:
|
||||||
|
tree = ast.parse(expr, mode="eval")
|
||||||
|
except SyntaxError as e:
|
||||||
|
# 将 SyntaxError 包装为 InvalidSyntaxError,统一异常类型
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"表达式语法错误: {e.msg}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=e.lineno,
|
||||||
|
col_offset=e.offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# 递归访问 AST 节点
|
||||||
|
try:
|
||||||
|
return self._visit(tree.body, expr)
|
||||||
|
except FormulaParseError:
|
||||||
|
# 重新抛出 FormulaParseError(保留已有的位置信息)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 将其他异常包装为 FormulaParseError
|
||||||
|
if not isinstance(e, FormulaParseError):
|
||||||
|
raise FormulaParseError(
|
||||||
|
message=f"解析失败: {str(e)}",
|
||||||
|
expr=expr,
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _visit(self, node: ast.AST, expr: str) -> Node:
|
||||||
|
"""递归访问 AST 节点并转换为 DSL 节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Python AST 节点
|
||||||
|
expr: 原始表达式字符串(用于错误报告)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对应的 DSL 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 遇到不支持的语法时抛出
|
||||||
|
"""
|
||||||
|
# 提取位置信息(如果节点有)
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
return self._visit_Name(node)
|
||||||
|
elif isinstance(node, ast.Constant):
|
||||||
|
return self._visit_Constant(node, expr)
|
||||||
|
elif isinstance(node, ast.BinOp):
|
||||||
|
return self._visit_BinOp(node, expr)
|
||||||
|
elif isinstance(node, ast.UnaryOp):
|
||||||
|
return self._visit_UnaryOp(node, expr)
|
||||||
|
elif isinstance(node, ast.Call):
|
||||||
|
return self._visit_Call(node, expr)
|
||||||
|
elif isinstance(node, ast.Compare):
|
||||||
|
return self._visit_Compare(node, expr)
|
||||||
|
else:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的语法: {type(node).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
except FormulaParseError:
|
||||||
|
# 重新抛出(保留已有的位置信息)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 包装为 FormulaParseError,添加位置信息
|
||||||
|
raise FormulaParseError(
|
||||||
|
message=f"解析节点失败: {str(e)}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _visit_Name(self, node: ast.Name) -> Symbol:
|
||||||
|
"""访问名称节点 - 永远转为 Symbol。
|
||||||
|
|
||||||
|
注意:利用 AST 语法自然区分变量和函数调用:
|
||||||
|
- log → Symbol("log")(数据列引用)
|
||||||
|
- log(close) → 在 _visit_Call 中处理(函数调用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 名称节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol 节点
|
||||||
|
"""
|
||||||
|
return Symbol(node.id)
|
||||||
|
|
||||||
|
def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
|
||||||
|
"""访问常量节点。
|
||||||
|
|
||||||
|
支持的类型:
|
||||||
|
- int/float → Constant 节点
|
||||||
|
- str → Symbol 节点(支持 ts_mean("close", 20) 语法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 常量节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constant 或 Symbol 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的常量类型
|
||||||
|
"""
|
||||||
|
if isinstance(node.value, (int, float)):
|
||||||
|
return Constant(node.value)
|
||||||
|
elif isinstance(node.value, str):
|
||||||
|
# 字符串常量转为 Symbol,支持 "close" 写法
|
||||||
|
return Symbol(node.value)
|
||||||
|
else:
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的常量类型: {type(node.value).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
|
||||||
|
"""访问二元运算节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 二元运算节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BinaryOpNode 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的运算符
|
||||||
|
"""
|
||||||
|
left = self._visit(node.left, expr)
|
||||||
|
right = self._visit(node.right, expr)
|
||||||
|
|
||||||
|
op = BIN_OP_MAP.get(type(node.op))
|
||||||
|
if op is None:
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的运算符: {type(node.op).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BinaryOpNode(op, left, right)
|
||||||
|
|
||||||
|
def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
|
||||||
|
"""访问一元运算节点。
|
||||||
|
|
||||||
|
支持常量折叠优化:纯数值的一元运算直接计算结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 一元运算节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constant(常量折叠)或 UnaryOpNode 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的运算符
|
||||||
|
"""
|
||||||
|
operand = self._visit(node.operand, expr)
|
||||||
|
op = UNARY_OP_MAP.get(type(node.op))
|
||||||
|
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
if op is None:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的一元运算符: {type(node.op).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if op == "~":
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="位运算 '~' 不被支持",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 常量折叠优化:纯数值直接计算
|
||||||
|
if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
|
||||||
|
if op == "-":
|
||||||
|
return Constant(-operand.value)
|
||||||
|
elif op == "+":
|
||||||
|
return operand # +5 就是 5
|
||||||
|
|
||||||
|
# 非常量使用运算符重载
|
||||||
|
if op == "-":
|
||||||
|
return -operand
|
||||||
|
elif op == "+":
|
||||||
|
return +operand
|
||||||
|
|
||||||
|
# 不应该到达这里
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"无法处理的一元运算符: {op}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_Call(self, node: ast.Call, expr: str) -> Node:
|
||||||
|
"""访问函数调用节点。
|
||||||
|
|
||||||
|
注意:只有在这里查注册表,处理函数调用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 函数调用节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
函数返回的 Node 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的函数调用语法
|
||||||
|
UnknownFunctionError: 函数未注册
|
||||||
|
"""
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
# 只支持简单函数调用(如 func(a, b))
|
||||||
|
if not isinstance(node.func, ast.Name):
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="只支持简单函数调用(如 func(a, b))",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
func_name = node.func.id
|
||||||
|
func = self.registry.get(func_name)
|
||||||
|
|
||||||
|
if func is None:
|
||||||
|
raise UnknownFunctionError(
|
||||||
|
func_name=func_name,
|
||||||
|
available=self.registry.available_functions(),
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析位置参数
|
||||||
|
args = [self._visit(arg, expr) for arg in node.args]
|
||||||
|
|
||||||
|
# 解析关键字参数(如果有)
|
||||||
|
kwargs = {}
|
||||||
|
for keyword in node.keywords:
|
||||||
|
kwargs[keyword.arg] = self._visit(keyword.value, expr)
|
||||||
|
|
||||||
|
# 应用函数
|
||||||
|
try:
|
||||||
|
if kwargs:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return func(*args)
|
||||||
|
except TypeError as e:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"函数 '{func_name}' 调用失败: {e}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
|
||||||
|
"""访问比较运算节点。
|
||||||
|
|
||||||
|
注意:只支持简单二元比较,不支持链式比较(如 a < b < c)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 比较节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BinaryOpNode 节点(使用比较运算符)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 链式比较或不支持的运算符
|
||||||
|
"""
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
# Python 支持链式比较 (a < b < c),这里简化为二元比较
|
||||||
|
if len(node.ops) != 1 or len(node.comparators) != 1:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="只支持简单二元比较(如 a > b),不支持链式比较",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
left = self._visit(node.left, expr)
|
||||||
|
op = COMPARE_OP_MAP.get(type(node.ops[0]))
|
||||||
|
|
||||||
|
if op is None:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
right = self._visit(node.comparators[0], expr)
|
||||||
|
return BinaryOpNode(op, left, right)
|
||||||
227
src/factors/registry.py
Normal file
227
src/factors/registry.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""函数注册表 - 管理字符串函数名到 Python 函数的映射。
|
||||||
|
|
||||||
|
支持自动发现和手动注册,与 FormulaParser 配合使用。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> from src.factors.registry import FunctionRegistry
|
||||||
|
>>> registry = FunctionRegistry(auto_scan=True) # 自动加载 api.py 函数
|
||||||
|
>>> registry.available_functions()[:5]
|
||||||
|
['abs', 'clip', 'cs_demean', 'cs_neutralize', 'cs_rank']
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from src.factors.dsl import Node, FunctionNode
|
||||||
|
from src.factors.exceptions import DuplicateFunctionError
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionRegistry:
|
||||||
|
"""函数注册表。
|
||||||
|
|
||||||
|
管理字符串函数名到可调用对象的映射。
|
||||||
|
自动从 api.py 加载标准函数,支持用户自定义函数注册。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_functions: 函数字典,name -> callable
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auto_scan: bool = True) -> None:
|
||||||
|
"""初始化注册表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_scan: 是否自动扫描 api.py 模块,默认 True
|
||||||
|
"""
|
||||||
|
self._functions: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
if auto_scan:
|
||||||
|
self._scan_api_module()
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self, name: str, func: Callable, force: bool = False
|
||||||
|
) -> "FunctionRegistry":
|
||||||
|
"""注册自定义函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称(字符串形式)
|
||||||
|
func: 可调用对象
|
||||||
|
force: 是否强制覆盖已存在的函数,默认 False
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DuplicateFunctionError: 当函数名已存在且 force=False 时
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> registry = FunctionRegistry(auto_scan=False)
|
||||||
|
>>> registry.register("my_func", lambda x: x * 2)
|
||||||
|
>>> registry.get("my_func")(5)
|
||||||
|
10
|
||||||
|
"""
|
||||||
|
if name in self._functions and not force:
|
||||||
|
raise DuplicateFunctionError(name)
|
||||||
|
|
||||||
|
self._functions[name] = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> "FunctionRegistry":
|
||||||
|
"""注销函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要注销的函数名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 函数不存在时
|
||||||
|
"""
|
||||||
|
if name not in self._functions:
|
||||||
|
raise KeyError(f"函数 '{name}' 不存在")
|
||||||
|
del self._functions[name]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Callable]:
|
||||||
|
"""获取函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
函数对象,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._functions.get(name)
|
||||||
|
|
||||||
|
def has(self, name: str) -> bool:
|
||||||
|
"""检查函数是否存在。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否存在
|
||||||
|
"""
|
||||||
|
return name in self._functions
|
||||||
|
|
||||||
|
def available_functions(self) -> List[str]:
|
||||||
|
"""返回所有可用函数名列表(按字母序)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
排序后的函数名列表
|
||||||
|
"""
|
||||||
|
return sorted(self._functions.keys())
|
||||||
|
|
||||||
|
def clear(self) -> "FunctionRegistry":
|
||||||
|
"""清空所有注册的函数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
"""
|
||||||
|
self._functions.clear()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def scan_module(
|
||||||
|
self, module: Any, prefix: str = "", force: bool = False
|
||||||
|
) -> "FunctionRegistry":
|
||||||
|
"""扫描指定模块,自动注册符合条件的函数。
|
||||||
|
|
||||||
|
扫描规则:
|
||||||
|
1. 模块级别的函数(排除私有函数 _*)
|
||||||
|
2. 返回类型注解为 Node 或 FunctionNode
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: 要扫描的模块对象
|
||||||
|
prefix: 函数名前缀,用于避免命名冲突
|
||||||
|
force: 是否强制覆盖已存在的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import my_custom_module
|
||||||
|
>>> registry.scan_module(my_custom_module, prefix="custom_")
|
||||||
|
"""
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
# 只处理非私有函数
|
||||||
|
if not inspect.isfunction(obj) or name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否应该注册
|
||||||
|
if self._should_register(obj):
|
||||||
|
full_name = prefix + name
|
||||||
|
self.register(full_name, obj, force=force)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _scan_api_module(self) -> None:
|
||||||
|
"""自动扫描 api.py 模块,注册所有符合条件的函数。
|
||||||
|
|
||||||
|
这是默认的自动扫描行为,在 __init__ 中调用。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.factors import api
|
||||||
|
|
||||||
|
self.scan_module(api)
|
||||||
|
except ImportError:
|
||||||
|
# api 模块可能不存在,静默跳过
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _should_register(self, func: Callable) -> bool:
|
||||||
|
"""检查函数是否应该被注册。
|
||||||
|
|
||||||
|
基于类型提示检查函数返回类型,只注册返回 Node 或 FunctionNode 的函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: 要检查的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否应该注册该函数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
hints = typing.get_type_hints(func)
|
||||||
|
return_type = hints.get("return")
|
||||||
|
|
||||||
|
if return_type is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 处理 Union 类型(如 Union[Node, FunctionNode])
|
||||||
|
origin = typing.get_origin(return_type)
|
||||||
|
args = typing.get_args(return_type)
|
||||||
|
|
||||||
|
if origin is typing.Union:
|
||||||
|
# Union 类型,检查任一参数
|
||||||
|
return any(self._is_node_type(arg) for arg in args)
|
||||||
|
else:
|
||||||
|
# 单一类型
|
||||||
|
return self._is_node_type(return_type)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_node_type(self, typ: Any) -> bool:
|
||||||
|
"""检查类型是否是 Node 或 FunctionNode 的子类。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
typ: 要检查的类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否是 Node 相关类型
|
||||||
|
"""
|
||||||
|
if not isinstance(typ, type):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return issubclass(typ, (Node, FunctionNode))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""返回已注册函数数量。"""
|
||||||
|
return len(self._functions)
|
||||||
|
|
||||||
|
def __contains__(self, name: str) -> bool:
|
||||||
|
"""检查是否包含某个函数名。"""
|
||||||
|
return name in self._functions
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""返回注册表字符串表示。"""
|
||||||
|
return f"FunctionRegistry({len(self._functions)} functions: {self.available_functions()[:5]}...)"
|
||||||
Reference in New Issue
Block a user