Compare commits
3 Commits
9b826c1845
...
05d0c90312
| Author | SHA1 | Date | |
|---|---|---|---|
| 05d0c90312 | |||
| 77e4e94e05 | |||
| 1c0c4a0de1 |
@@ -52,6 +52,22 @@ from src.factors.engine import (
|
||||
ComputeEngine,
|
||||
)
|
||||
|
||||
from src.factors.parser import FormulaParser
|
||||
|
||||
from src.factors.registry import FunctionRegistry
|
||||
|
||||
from src.factors.exceptions import (
|
||||
FormulaParseError,
|
||||
UnknownFunctionError,
|
||||
InvalidSyntaxError,
|
||||
EmptyExpressionError,
|
||||
RegistryError,
|
||||
DuplicateFunctionError,
|
||||
)
|
||||
|
||||
# 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问
|
||||
# 例如:from src.factors.engine import FactorEngine
|
||||
|
||||
__all__ = [
|
||||
# DSL 层
|
||||
"Node",
|
||||
@@ -73,4 +89,15 @@ __all__ = [
|
||||
"DataRouter",
|
||||
"ExecutionPlanner",
|
||||
"ComputeEngine",
|
||||
# 解析器 (Phase 1 新增)
|
||||
"FormulaParser",
|
||||
# 注册表 (Phase 1 新增)
|
||||
"FunctionRegistry",
|
||||
# 异常类 (Phase 1 新增)
|
||||
"FormulaParseError",
|
||||
"UnknownFunctionError",
|
||||
"InvalidSyntaxError",
|
||||
"EmptyExpressionError",
|
||||
"RegistryError",
|
||||
"DuplicateFunctionError",
|
||||
]
|
||||
|
||||
@@ -1,817 +0,0 @@
|
||||
"""FactorEngine - 因子计算引擎统一入口。
|
||||
|
||||
提供从表达式注册到结果输出的完整执行链路:
|
||||
接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表
|
||||
-> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.dsl import (
|
||||
Node,
|
||||
Symbol,
|
||||
FunctionNode,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
Constant,
|
||||
)
|
||||
from src.factors.compiler import DependencyExtractor
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.data.storage import Storage
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataSpec:
|
||||
"""数据规格定义。
|
||||
|
||||
描述因子计算所需的数据表和字段。
|
||||
|
||||
Attributes:
|
||||
table: 数据表名称
|
||||
columns: 需要的字段列表
|
||||
lookback_days: 回看天数(用于时序计算)
|
||||
"""
|
||||
|
||||
table: str
|
||||
columns: List[str]
|
||||
lookback_days: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionPlan:
|
||||
"""执行计划。
|
||||
|
||||
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
|
||||
|
||||
Attributes:
|
||||
data_specs: 数据规格列表
|
||||
polars_expr: Polars 表达式
|
||||
dependencies: 依赖的原始字段
|
||||
output_name: 输出因子名称
|
||||
"""
|
||||
|
||||
data_specs: List[DataSpec]
|
||||
polars_expr: pl.Expr
|
||||
dependencies: Set[str]
|
||||
output_name: str
|
||||
|
||||
|
||||
class DataRouter:
|
||||
"""数据路由器 - 按需取数、组装核心宽表。
|
||||
|
||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||
支持内存数据源(用于测试)和真实数据库连接。
|
||||
|
||||
Attributes:
|
||||
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
|
||||
is_memory_mode: 是否为内存模式
|
||||
"""
|
||||
|
||||
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
|
||||
"""初始化数据路由器。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
||||
为 None 时自动连接 DuckDB 数据库
|
||||
"""
|
||||
self.data_source = data_source or {}
|
||||
self.is_memory_mode = data_source is not None
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 数据库模式下初始化 Storage
|
||||
if not self.is_memory_mode:
|
||||
self._storage = Storage()
|
||||
else:
|
||||
self._storage = None
|
||||
|
||||
def fetch_data(
|
||||
self,
|
||||
data_specs: List[DataSpec],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""根据数据规格获取并组装核心宽表。
|
||||
|
||||
Args:
|
||||
data_specs: 数据规格列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
stock_codes: 股票代码列表,None 表示全市场
|
||||
|
||||
Returns:
|
||||
组装好的核心宽表 DataFrame
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据源中缺少必要的表或字段时
|
||||
"""
|
||||
if not data_specs:
|
||||
raise ValueError("数据规格不能为空")
|
||||
|
||||
# 收集所有需要的表和字段
|
||||
required_tables: Dict[str, Set[str]] = {}
|
||||
max_lookback = 0
|
||||
|
||||
for spec in data_specs:
|
||||
if spec.table not in required_tables:
|
||||
required_tables[spec.table] = set()
|
||||
required_tables[spec.table].update(spec.columns)
|
||||
max_lookback = max(max_lookback, spec.lookback_days)
|
||||
|
||||
# 调整日期范围以包含回看期
|
||||
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
||||
|
||||
# 从数据源获取各表数据
|
||||
table_data = {}
|
||||
for table_name, columns in required_tables.items():
|
||||
df = self._load_table(
|
||||
table_name=table_name,
|
||||
columns=list(columns),
|
||||
start_date=adjusted_start,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
table_data[table_name] = df
|
||||
|
||||
# 组装核心宽表
|
||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||
|
||||
# 过滤到实际请求日期范围
|
||||
core_table = core_table.filter(
|
||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
return core_table
|
||||
|
||||
def _load_table(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""加载单个表的数据。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
columns: 需要的字段
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
stock_codes: 股票代码过滤
|
||||
|
||||
Returns:
|
||||
过滤后的 DataFrame
|
||||
"""
|
||||
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
if self.is_memory_mode:
|
||||
df = self._load_from_memory(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
else:
|
||||
df = self._load_from_database(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._cache[cache_key] = df
|
||||
|
||||
return df
|
||||
|
||||
def _load_from_memory(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""从内存数据源加载数据。"""
|
||||
if table_name not in self.data_source:
|
||||
raise ValueError(f"内存数据源中缺少表: {table_name}")
|
||||
|
||||
df = self.data_source[table_name]
|
||||
|
||||
# 确保必需字段存在
|
||||
for col in columns:
|
||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||
|
||||
# 过滤日期和股票
|
||||
df = df.filter(
|
||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
if stock_codes is not None:
|
||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||
|
||||
# 选择需要的列
|
||||
select_cols = ["ts_code", "trade_date"] + [
|
||||
c for c in columns if c in df.columns
|
||||
]
|
||||
return df.select(select_cols)
|
||||
|
||||
def _load_from_database(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""从 DuckDB 数据库加载数据。
|
||||
|
||||
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
|
||||
"""
|
||||
if self._storage is None:
|
||||
raise RuntimeError("Storage 未初始化")
|
||||
|
||||
# 检查表是否存在
|
||||
if not self._storage.exists(table_name):
|
||||
raise ValueError(f"数据库中不存在表: {table_name}")
|
||||
|
||||
# 构建查询参数
|
||||
# Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况
|
||||
if stock_codes is not None and len(stock_codes) == 1:
|
||||
ts_code_filter = stock_codes[0]
|
||||
else:
|
||||
ts_code_filter = None
|
||||
|
||||
try:
|
||||
# 从数据库加载原始数据
|
||||
df = self._storage.load_polars(
|
||||
name=table_name,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
ts_code=ts_code_filter,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
|
||||
|
||||
# 如果 stock_codes 是列表且长度 > 1,在内存中过滤
|
||||
if stock_codes is not None and len(stock_codes) > 1:
|
||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||
|
||||
# 检查必需字段
|
||||
for col in columns:
|
||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||
|
||||
# 选择需要的列
|
||||
select_cols = ["ts_code", "trade_date"] + [
|
||||
c for c in columns if c in df.columns
|
||||
]
|
||||
|
||||
return df.select(select_cols)
|
||||
|
||||
def _assemble_wide_table(
|
||||
self,
|
||||
table_data: Dict[str, pl.DataFrame],
|
||||
required_tables: Dict[str, Set[str]],
|
||||
) -> pl.DataFrame:
|
||||
"""组装多表数据为核心宽表。
|
||||
|
||||
使用 left join 合并各表数据,以第一个表为基准。
|
||||
|
||||
Args:
|
||||
table_data: 表名到 DataFrame 的映射
|
||||
required_tables: 表名到字段集合的映射
|
||||
|
||||
Returns:
|
||||
组装后的宽表
|
||||
"""
|
||||
if not table_data:
|
||||
raise ValueError("没有数据可组装")
|
||||
|
||||
# 以第一个表为基准
|
||||
base_table_name = list(table_data.keys())[0]
|
||||
result = table_data[base_table_name]
|
||||
|
||||
# 与其他表 join
|
||||
for table_name, df in table_data.items():
|
||||
if table_name == base_table_name:
|
||||
continue
|
||||
|
||||
# 使用 ts_code 和 trade_date 作为 join 键
|
||||
result = result.join(
|
||||
df,
|
||||
on=["ts_code", "trade_date"],
|
||||
how="left",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
||||
"""根据回看天数调整开始日期。
|
||||
|
||||
Args:
|
||||
start_date: 原始开始日期 (YYYYMMDD)
|
||||
lookback_days: 需要回看的交易日数
|
||||
|
||||
Returns:
|
||||
调整后的开始日期
|
||||
"""
|
||||
# 简化的日期调整:假设每月30天,向前推移
|
||||
# 实际应用中应该使用交易日历
|
||||
year = int(start_date[:4])
|
||||
month = int(start_date[4:6])
|
||||
day = int(start_date[6:8])
|
||||
|
||||
total_days = lookback_days + 30 # 额外缓冲
|
||||
|
||||
day -= total_days
|
||||
while day <= 0:
|
||||
month -= 1
|
||||
if month <= 0:
|
||||
month = 12
|
||||
year -= 1
|
||||
day += 30
|
||||
|
||||
return f"{year:04d}{month:02d}{day:02d}"
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除数据缓存。"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
# 数据库模式下清理 Storage 连接(可选)
|
||||
if not self.is_memory_mode and self._storage is not None:
|
||||
# Storage 使用单例模式,不需要关闭连接
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionPlanner:
|
||||
"""执行计划生成器。
|
||||
|
||||
整合编译器和翻译器,生成完整的执行计划。
|
||||
|
||||
Attributes:
|
||||
compiler: 依赖提取器
|
||||
translator: Polars 翻译器
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化执行计划生成器。"""
|
||||
self.compiler = DependencyExtractor()
|
||||
self.translator = PolarsTranslator()
|
||||
|
||||
def create_plan(
|
||||
self,
|
||||
expression: Node,
|
||||
output_name: str = "factor",
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""从表达式创建执行计划。
|
||||
|
||||
Args:
|
||||
expression: DSL 表达式节点
|
||||
output_name: 输出因子名称
|
||||
data_specs: 预定义的数据规格,None 时自动推导
|
||||
|
||||
Returns:
|
||||
执行计划对象
|
||||
"""
|
||||
# 1. 提取依赖
|
||||
dependencies = self.compiler.extract_dependencies(expression)
|
||||
|
||||
# 2. 翻译为 Polars 表达式
|
||||
polars_expr = self.translator.translate(expression)
|
||||
|
||||
# 3. 推导或验证数据规格
|
||||
if data_specs is None:
|
||||
data_specs = self._infer_data_specs(dependencies, expression)
|
||||
|
||||
return ExecutionPlan(
|
||||
data_specs=data_specs,
|
||||
polars_expr=polars_expr,
|
||||
dependencies=dependencies,
|
||||
output_name=output_name,
|
||||
)
|
||||
|
||||
def _infer_data_specs(
|
||||
self,
|
||||
dependencies: Set[str],
|
||||
expression: Node,
|
||||
) -> List[DataSpec]:
|
||||
"""从依赖推导数据规格。
|
||||
|
||||
根据表达式中的函数类型推断回看天数需求。
|
||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
||||
默认从 pro_bar 表获取。
|
||||
|
||||
Args:
|
||||
dependencies: 依赖的字段集合
|
||||
expression: 表达式节点
|
||||
|
||||
Returns:
|
||||
数据规格列表
|
||||
"""
|
||||
# 计算最大回看窗口
|
||||
max_window = self._extract_max_window(expression)
|
||||
lookback_days = max(1, max_window)
|
||||
|
||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
||||
pro_bar_fields = {
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"vol",
|
||||
"amount",
|
||||
"pre_close",
|
||||
"change",
|
||||
"pct_chg",
|
||||
"turnover_rate",
|
||||
"volume_ratio",
|
||||
}
|
||||
|
||||
# 将依赖分为 pro_bar 字段和其他字段
|
||||
pro_bar_deps = dependencies & pro_bar_fields
|
||||
other_deps = dependencies - pro_bar_fields
|
||||
|
||||
data_specs = []
|
||||
|
||||
# pro_bar 表的数据规格
|
||||
if pro_bar_deps:
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
table="pro_bar",
|
||||
columns=sorted(pro_bar_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
# 其他字段从 daily 表获取
|
||||
if other_deps:
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
table="daily",
|
||||
columns=sorted(other_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
return data_specs
|
||||
|
||||
def _extract_max_window(self, node: Node) -> int:
|
||||
"""从表达式中提取最大窗口大小。
|
||||
|
||||
Args:
|
||||
node: AST 节点
|
||||
|
||||
Returns:
|
||||
最大窗口大小,无时序函数返回 1
|
||||
"""
|
||||
if isinstance(node, FunctionNode):
|
||||
window = 1
|
||||
# 检查函数参数中的窗口大小
|
||||
for arg in node.args:
|
||||
if (
|
||||
isinstance(arg, Constant)
|
||||
and isinstance(arg.value, int)
|
||||
and arg.value > window
|
||||
):
|
||||
window = arg.value
|
||||
|
||||
# 递归检查子表达式
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
||||
window = max(window, self._extract_max_window(arg))
|
||||
|
||||
return window
|
||||
|
||||
elif isinstance(node, BinaryOpNode):
|
||||
return max(
|
||||
self._extract_max_window(node.left),
|
||||
self._extract_max_window(node.right),
|
||||
)
|
||||
|
||||
elif isinstance(node, UnaryOpNode):
|
||||
return self._extract_max_window(node.operand)
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
class ComputeEngine:
|
||||
"""计算引擎 - 执行并行运算。
|
||||
|
||||
负责将执行计划应用到数据上,支持并行计算。
|
||||
|
||||
Attributes:
|
||||
max_workers: 最大并行工作线程数
|
||||
use_processes: 是否使用进程池(CPU 密集型任务)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: int = 4,
|
||||
use_processes: bool = False,
|
||||
) -> None:
|
||||
"""初始化计算引擎。
|
||||
|
||||
Args:
|
||||
max_workers: 最大并行工作线程数
|
||||
use_processes: 是否使用进程池代替线程池
|
||||
"""
|
||||
self.max_workers = max_workers
|
||||
self.use_processes = use_processes
|
||||
|
||||
def execute(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""执行计算计划。
|
||||
|
||||
Args:
|
||||
plan: 执行计划
|
||||
data: 输入数据(核心宽表)
|
||||
|
||||
Returns:
|
||||
包含因子结果的 DataFrame
|
||||
"""
|
||||
# 检查依赖字段是否存在
|
||||
missing_cols = plan.dependencies - set(data.columns)
|
||||
if missing_cols:
|
||||
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
|
||||
|
||||
# 执行计算
|
||||
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
|
||||
|
||||
return result
|
||||
|
||||
def execute_batch(
|
||||
self,
|
||||
plans: List[ExecutionPlan],
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""批量执行多个计算计划。
|
||||
|
||||
Args:
|
||||
plans: 执行计划列表
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含所有因子结果的 DataFrame
|
||||
"""
|
||||
result = data
|
||||
|
||||
for plan in plans:
|
||||
result = self.execute(plan, result)
|
||||
|
||||
return result
|
||||
|
||||
def execute_parallel(
|
||||
self,
|
||||
plans: List[ExecutionPlan],
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""并行执行多个计算计划。
|
||||
|
||||
Args:
|
||||
plans: 执行计划列表
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含所有因子结果的 DataFrame
|
||||
"""
|
||||
# 检查计划间依赖
|
||||
independent_plans = []
|
||||
dependent_plans = []
|
||||
available_cols = set(data.columns)
|
||||
|
||||
for plan in plans:
|
||||
if plan.dependencies <= available_cols:
|
||||
independent_plans.append(plan)
|
||||
available_cols.add(plan.output_name)
|
||||
else:
|
||||
dependent_plans.append(plan)
|
||||
|
||||
# 并行执行独立计划
|
||||
if independent_plans:
|
||||
ExecutorClass = (
|
||||
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
||||
)
|
||||
|
||||
with ExecutorClass(max_workers=self.max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(self._execute_single, plan, data): plan
|
||||
for plan in independent_plans
|
||||
}
|
||||
|
||||
results = []
|
||||
for future in futures:
|
||||
plan = futures[future]
|
||||
try:
|
||||
result_col = future.result()
|
||||
results.append((plan.output_name, result_col))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
|
||||
|
||||
# 合并结果
|
||||
for name, series in results:
|
||||
data = data.with_columns([series.alias(name)])
|
||||
|
||||
# 顺序执行依赖计划
|
||||
for plan in dependent_plans:
|
||||
data = self.execute(plan, data)
|
||||
|
||||
return data
|
||||
|
||||
def _execute_single(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
data: pl.DataFrame,
|
||||
) -> pl.Series:
|
||||
"""执行单个计划并返回结果列。
|
||||
|
||||
Args:
|
||||
plan: 执行计划
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
计算结果序列
|
||||
"""
|
||||
result = self.execute(plan, data)
|
||||
return result[plan.output_name]
|
||||
|
||||
|
||||
class FactorEngine:
|
||||
"""因子计算引擎 - 系统统一入口。
|
||||
|
||||
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
||||
|
||||
执行流程:
|
||||
1. 注册表达式 -> 调用编译器解析依赖
|
||||
2. 调用路由器连接数据库拉取并组装核心宽表
|
||||
3. 调用翻译器生成物理执行计划
|
||||
4. 将计划提交给计算引擎执行并行运算
|
||||
5. 返回包含因子结果的数据表
|
||||
|
||||
Attributes:
|
||||
router: 数据路由器
|
||||
planner: 执行计划生成器
|
||||
compute_engine: 计算引擎
|
||||
registered_expressions: 注册的表达式字典
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
||||
max_workers: int = 4,
|
||||
) -> None:
|
||||
"""初始化因子引擎。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,为 None 时使用数据库连接
|
||||
max_workers: 并行计算的最大工作线程数
|
||||
"""
|
||||
self.router = DataRouter(data_source)
|
||||
self.planner = ExecutionPlanner()
|
||||
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
||||
self.registered_expressions: Dict[str, Node] = {}
|
||||
self._plans: Dict[str, ExecutionPlan] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
expression: Node,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> FactorEngine:
|
||||
"""注册因子表达式。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
expression: DSL 表达式
|
||||
data_specs: 数据规格,None 时自动推导
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
|
||||
Example:
|
||||
>>> from src.factors.api import close, ts_mean
|
||||
>>> engine = FactorEngine()
|
||||
>>> engine.register("ma20", ts_mean(close, 20))
|
||||
"""
|
||||
self.registered_expressions[name] = expression
|
||||
|
||||
# 预创建执行计划
|
||||
plan = self.planner.create_plan(
|
||||
expression=expression,
|
||||
output_name=name,
|
||||
data_specs=data_specs,
|
||||
)
|
||||
self._plans[name] = plan
|
||||
|
||||
return self
|
||||
|
||||
def compute(
|
||||
self,
|
||||
factor_names: Union[str, List[str]],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""计算指定因子的值。
|
||||
|
||||
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
|
||||
|
||||
Args:
|
||||
factor_names: 因子名称或名称列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
stock_codes: 股票代码列表,None 表示全市场
|
||||
|
||||
Returns:
|
||||
包含因子结果的数据表
|
||||
|
||||
Raises:
|
||||
ValueError: 当因子未注册或数据不足时
|
||||
|
||||
Example:
|
||||
>>> result = engine.compute("ma20", "20240101", "20240131")
|
||||
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
|
||||
"""
|
||||
# 标准化因子名称
|
||||
if isinstance(factor_names, str):
|
||||
factor_names = [factor_names]
|
||||
|
||||
# 1. 获取执行计划
|
||||
plans = []
|
||||
for name in factor_names:
|
||||
if name not in self._plans:
|
||||
raise ValueError(f"因子未注册: {name}")
|
||||
plans.append(self._plans[name])
|
||||
|
||||
# 2. 合并数据规格并获取数据
|
||||
all_specs = []
|
||||
for plan in plans:
|
||||
all_specs.extend(plan.data_specs)
|
||||
|
||||
# 3. 从路由器获取核心宽表
|
||||
core_data = self.router.fetch_data(
|
||||
data_specs=all_specs,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
|
||||
if len(core_data) == 0:
|
||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||
|
||||
# 4. 执行计算
|
||||
if len(plans) == 1:
|
||||
result = self.compute_engine.execute(plans[0], core_data)
|
||||
else:
|
||||
result = self.compute_engine.execute_batch(plans, core_data)
|
||||
|
||||
return result
|
||||
|
||||
def list_registered(self) -> List[str]:
|
||||
"""获取已注册的因子列表。
|
||||
|
||||
Returns:
|
||||
因子名称列表
|
||||
"""
|
||||
return list(self.registered_expressions.keys())
|
||||
|
||||
def get_expression(self, name: str) -> Optional[Node]:
|
||||
"""获取已注册的表达式。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
|
||||
Returns:
|
||||
表达式节点,未注册时返回 None
|
||||
"""
|
||||
return self.registered_expressions.get(name)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有注册的表达式和缓存。"""
|
||||
self.registered_expressions.clear()
|
||||
self._plans.clear()
|
||||
self.router.clear_cache()
|
||||
|
||||
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
|
||||
"""预览因子的执行计划。
|
||||
|
||||
Args:
|
||||
factor_name: 因子名称
|
||||
|
||||
Returns:
|
||||
执行计划,未注册时返回 None
|
||||
"""
|
||||
return self._plans.get(factor_name)
|
||||
28
src/factors/engine/__init__.py
Normal file
28
src/factors/engine/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""因子计算引擎模块。
|
||||
|
||||
提供完整的因子计算引擎组件:
|
||||
- DataSpec: 数据规格定义
|
||||
- ExecutionPlan: 执行计划
|
||||
- DataRouter: 数据路由器
|
||||
- ExecutionPlanner: 执行计划生成器
|
||||
- ComputeEngine: 计算引擎
|
||||
- FactorEngine: 因子计算引擎(统一入口)
|
||||
"""
|
||||
|
||||
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||
from src.factors.engine.data_router import DataRouter
|
||||
from src.factors.engine.planner import ExecutionPlanner
|
||||
from src.factors.engine.compute_engine import ComputeEngine
|
||||
from src.factors.engine.factor_engine import FactorEngine
|
||||
|
||||
__all__ = [
|
||||
"DataSpec",
|
||||
"ExecutionPlan",
|
||||
"DataRouter",
|
||||
"ExecutionPlanner",
|
||||
"ComputeEngine",
|
||||
"FactorEngine",
|
||||
]
|
||||
|
||||
# 类型导出(用于类型注解)
|
||||
# FunctionRegistry 从 src.factors.registry 导入
|
||||
155
src/factors/engine/compute_engine.py
Normal file
155
src/factors/engine/compute_engine.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""计算引擎。
|
||||
|
||||
执行并行运算,负责将执行计划应用到数据上。
|
||||
"""
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.engine.data_spec import ExecutionPlan
|
||||
|
||||
|
||||
class ComputeEngine:
|
||||
"""计算引擎 - 执行并行运算。
|
||||
|
||||
负责将执行计划应用到数据上,支持并行计算。
|
||||
|
||||
Attributes:
|
||||
max_workers: 最大并行工作线程数
|
||||
use_processes: 是否使用进程池(CPU 密集型任务)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: int = 4,
|
||||
use_processes: bool = False,
|
||||
) -> None:
|
||||
"""初始化计算引擎。
|
||||
|
||||
Args:
|
||||
max_workers: 最大并行工作线程数
|
||||
use_processes: 是否使用进程池代替线程池
|
||||
"""
|
||||
self.max_workers = max_workers
|
||||
self.use_processes = use_processes
|
||||
|
||||
def execute(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""执行计算计划。
|
||||
|
||||
Args:
|
||||
plan: 执行计划
|
||||
data: 输入数据(核心宽表)
|
||||
|
||||
Returns:
|
||||
包含因子结果的 DataFrame
|
||||
"""
|
||||
# 检查依赖字段是否存在
|
||||
missing_cols = plan.dependencies - set(data.columns)
|
||||
if missing_cols:
|
||||
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
|
||||
|
||||
# 执行计算
|
||||
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
|
||||
|
||||
return result
|
||||
|
||||
def execute_batch(
|
||||
self,
|
||||
plans: List[ExecutionPlan],
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""批量执行多个计算计划。
|
||||
|
||||
Args:
|
||||
plans: 执行计划列表
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含所有因子结果的 DataFrame
|
||||
"""
|
||||
result = data
|
||||
|
||||
for plan in plans:
|
||||
result = self.execute(plan, result)
|
||||
|
||||
return result
|
||||
|
||||
def execute_parallel(
|
||||
self,
|
||||
plans: List[ExecutionPlan],
|
||||
data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""并行执行多个计算计划。
|
||||
|
||||
Args:
|
||||
plans: 执行计划列表
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含所有因子结果的 DataFrame
|
||||
"""
|
||||
# 检查计划间依赖
|
||||
independent_plans = []
|
||||
dependent_plans = []
|
||||
available_cols = set(data.columns)
|
||||
|
||||
for plan in plans:
|
||||
if plan.dependencies <= available_cols:
|
||||
independent_plans.append(plan)
|
||||
available_cols.add(plan.output_name)
|
||||
else:
|
||||
dependent_plans.append(plan)
|
||||
|
||||
# 并行执行独立计划
|
||||
if independent_plans:
|
||||
ExecutorClass = (
|
||||
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
|
||||
)
|
||||
|
||||
with ExecutorClass(max_workers=self.max_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(self._execute_single, plan, data): plan
|
||||
for plan in independent_plans
|
||||
}
|
||||
|
||||
results = []
|
||||
for future in futures:
|
||||
plan = futures[future]
|
||||
try:
|
||||
result_col = future.result()
|
||||
results.append((plan.output_name, result_col))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
|
||||
|
||||
# 合并结果
|
||||
for name, series in results:
|
||||
data = data.with_columns([series.alias(name)])
|
||||
|
||||
# 顺序执行依赖计划
|
||||
for plan in dependent_plans:
|
||||
data = self.execute(plan, data)
|
||||
|
||||
return data
|
||||
|
||||
def _execute_single(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
data: pl.DataFrame,
|
||||
) -> pl.Series:
|
||||
"""执行单个计划并返回结果列。
|
||||
|
||||
Args:
|
||||
plan: 执行计划
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
计算结果序列
|
||||
"""
|
||||
result = self.execute(plan, data)
|
||||
return result[plan.output_name]
|
||||
304
src/factors/engine/data_router.py
Normal file
304
src/factors/engine/data_router.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""数据路由器。
|
||||
|
||||
按需取数、组装核心宽表。
|
||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||
支持内存数据源(用于测试)和真实数据库连接。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
import threading
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.engine.data_spec import DataSpec
|
||||
from src.data.storage import Storage
|
||||
|
||||
|
||||
class DataRouter:
|
||||
"""数据路由器 - 按需取数、组装核心宽表。
|
||||
|
||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||
支持内存数据源(用于测试)和真实数据库连接。
|
||||
|
||||
Attributes:
|
||||
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
|
||||
is_memory_mode: 是否为内存模式
|
||||
"""
|
||||
|
||||
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
|
||||
"""初始化数据路由器。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
||||
为 None 时自动连接 DuckDB 数据库
|
||||
"""
|
||||
self.data_source = data_source or {}
|
||||
self.is_memory_mode = data_source is not None
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 数据库模式下初始化 Storage
|
||||
if not self.is_memory_mode:
|
||||
self._storage = Storage()
|
||||
else:
|
||||
self._storage = None
|
||||
|
||||
def fetch_data(
|
||||
self,
|
||||
data_specs: List[DataSpec],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""根据数据规格获取并组装核心宽表。
|
||||
|
||||
Args:
|
||||
data_specs: 数据规格列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
stock_codes: 股票代码列表,None 表示全市场
|
||||
|
||||
Returns:
|
||||
组装好的核心宽表 DataFrame
|
||||
|
||||
Raises:
|
||||
ValueError: 当数据源中缺少必要的表或字段时
|
||||
"""
|
||||
if not data_specs:
|
||||
raise ValueError("数据规格不能为空")
|
||||
|
||||
# 收集所有需要的表和字段
|
||||
required_tables: Dict[str, Set[str]] = {}
|
||||
max_lookback = 0
|
||||
|
||||
for spec in data_specs:
|
||||
if spec.table not in required_tables:
|
||||
required_tables[spec.table] = set()
|
||||
required_tables[spec.table].update(spec.columns)
|
||||
max_lookback = max(max_lookback, spec.lookback_days)
|
||||
|
||||
# 调整日期范围以包含回看期
|
||||
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
||||
|
||||
# 从数据源获取各表数据
|
||||
table_data = {}
|
||||
for table_name, columns in required_tables.items():
|
||||
df = self._load_table(
|
||||
table_name=table_name,
|
||||
columns=list(columns),
|
||||
start_date=adjusted_start,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
table_data[table_name] = df
|
||||
|
||||
# 组装核心宽表
|
||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||
|
||||
# 过滤到实际请求日期范围
|
||||
core_table = core_table.filter(
|
||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
return core_table
|
||||
|
||||
def _load_table(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""加载单个表的数据。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
columns: 需要的字段
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
stock_codes: 股票代码过滤
|
||||
|
||||
Returns:
|
||||
过滤后的 DataFrame
|
||||
"""
|
||||
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
|
||||
|
||||
with self._lock:
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
if self.is_memory_mode:
|
||||
df = self._load_from_memory(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
else:
|
||||
df = self._load_from_database(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._cache[cache_key] = df
|
||||
|
||||
return df
|
||||
|
||||
def _load_from_memory(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""从内存数据源加载数据。"""
|
||||
if table_name not in self.data_source:
|
||||
raise ValueError(f"内存数据源中缺少表: {table_name}")
|
||||
|
||||
df = self.data_source[table_name]
|
||||
|
||||
# 确保必需字段存在
|
||||
for col in columns:
|
||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||
|
||||
# 过滤日期和股票
|
||||
df = df.filter(
|
||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
if stock_codes is not None:
|
||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||
|
||||
# 选择需要的列
|
||||
select_cols = ["ts_code", "trade_date"] + [
|
||||
c for c in columns if c in df.columns
|
||||
]
|
||||
return df.select(select_cols)
|
||||
|
||||
def _load_from_database(
|
||||
self,
|
||||
table_name: str,
|
||||
columns: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""从 DuckDB 数据库加载数据。
|
||||
|
||||
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
|
||||
"""
|
||||
if self._storage is None:
|
||||
raise RuntimeError("Storage 未初始化")
|
||||
|
||||
# 检查表是否存在
|
||||
if not self._storage.exists(table_name):
|
||||
raise ValueError(f"数据库中不存在表: {table_name}")
|
||||
|
||||
# 构建查询参数
|
||||
# Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况
|
||||
if stock_codes is not None and len(stock_codes) == 1:
|
||||
ts_code_filter = stock_codes[0]
|
||||
else:
|
||||
ts_code_filter = None
|
||||
|
||||
try:
|
||||
# 从数据库加载原始数据
|
||||
df = self._storage.load_polars(
|
||||
name=table_name,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
ts_code=ts_code_filter,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
|
||||
|
||||
# 如果 stock_codes 是列表且长度 > 1,在内存中过滤
|
||||
if stock_codes is not None and len(stock_codes) > 1:
|
||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||
|
||||
# 检查必需字段
|
||||
for col in columns:
|
||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||
|
||||
# 选择需要的列
|
||||
select_cols = ["ts_code", "trade_date"] + [
|
||||
c for c in columns if c in df.columns
|
||||
]
|
||||
|
||||
return df.select(select_cols)
|
||||
|
||||
def _assemble_wide_table(
|
||||
self,
|
||||
table_data: Dict[str, pl.DataFrame],
|
||||
required_tables: Dict[str, Set[str]],
|
||||
) -> pl.DataFrame:
|
||||
"""组装多表数据为核心宽表。
|
||||
|
||||
使用 left join 合并各表数据,以第一个表为基准。
|
||||
|
||||
Args:
|
||||
table_data: 表名到 DataFrame 的映射
|
||||
required_tables: 表名到字段集合的映射
|
||||
|
||||
Returns:
|
||||
组装后的宽表
|
||||
"""
|
||||
if not table_data:
|
||||
raise ValueError("没有数据可组装")
|
||||
|
||||
# 以第一个表为基准
|
||||
base_table_name = list(table_data.keys())[0]
|
||||
result = table_data[base_table_name]
|
||||
|
||||
# 与其他表 join
|
||||
for table_name, df in table_data.items():
|
||||
if table_name == base_table_name:
|
||||
continue
|
||||
|
||||
# 使用 ts_code 和 trade_date 作为 join 键
|
||||
result = result.join(
|
||||
df,
|
||||
on=["ts_code", "trade_date"],
|
||||
how="left",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
||||
"""根据回看天数调整开始日期。
|
||||
|
||||
Args:
|
||||
start_date: 原始开始日期 (YYYYMMDD)
|
||||
lookback_days: 需要回看的交易日数
|
||||
|
||||
Returns:
|
||||
调整后的开始日期
|
||||
"""
|
||||
# 简化的日期调整:假设每月30天,向前推移
|
||||
# 实际应用中应该使用交易日历
|
||||
year = int(start_date[:4])
|
||||
month = int(start_date[4:6])
|
||||
day = int(start_date[6:8])
|
||||
|
||||
total_days = lookback_days + 30 # 额外缓冲
|
||||
|
||||
day -= total_days
|
||||
while day <= 0:
|
||||
month -= 1
|
||||
if month <= 0:
|
||||
month = 12
|
||||
year -= 1
|
||||
day += 30
|
||||
|
||||
return f"{year:04d}{month:02d}{day:02d}"
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除数据缓存。"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
# 数据库模式下清理 Storage 连接(可选)
|
||||
if not self.is_memory_mode and self._storage is not None:
|
||||
# Storage 使用单例模式,不需要关闭连接
|
||||
pass
|
||||
47
src/factors/engine/data_spec.py
Normal file
47
src/factors/engine/data_spec.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""数据规格和执行计划定义。
|
||||
|
||||
定义因子计算所需的数据规格和执行计划结构。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataSpec:
|
||||
"""数据规格定义。
|
||||
|
||||
描述因子计算所需的数据表和字段。
|
||||
|
||||
Attributes:
|
||||
table: 数据表名称
|
||||
columns: 需要的字段列表
|
||||
lookback_days: 回看天数(用于时序计算)
|
||||
"""
|
||||
|
||||
table: str
|
||||
columns: List[str]
|
||||
lookback_days: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionPlan:
|
||||
"""执行计划。
|
||||
|
||||
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
|
||||
|
||||
Attributes:
|
||||
data_specs: 数据规格列表
|
||||
polars_expr: Polars 表达式
|
||||
dependencies: 依赖的原始字段
|
||||
output_name: 输出因子名称
|
||||
factor_dependencies: 依赖的其他因子名称(用于分步执行)
|
||||
"""
|
||||
|
||||
data_specs: List[DataSpec]
|
||||
polars_expr: pl.Expr
|
||||
dependencies: Set[str]
|
||||
output_name: str
|
||||
factor_dependencies: Set[str] = field(default_factory=set)
|
||||
513
src/factors/engine/factor_engine.py
Normal file
513
src/factors/engine/factor_engine.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""因子计算引擎 - 系统统一入口。
|
||||
|
||||
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
||||
|
||||
执行流程:
|
||||
1. 注册表达式 -> 调用编译器解析依赖
|
||||
2. 调用路由器连接数据库拉取并组装核心宽表
|
||||
3. 调用翻译器生成物理执行计划
|
||||
4. 将计划提交给计算引擎执行并行运算
|
||||
5. 返回包含因子结果的数据表
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
||||
|
||||
import polars as pl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.factors.registry import FunctionRegistry
|
||||
|
||||
from src.factors.dsl import (
|
||||
Node,
|
||||
Symbol,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
FunctionNode,
|
||||
)
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||
from src.factors.engine.data_router import DataRouter
|
||||
from src.factors.engine.planner import ExecutionPlanner
|
||||
from src.factors.engine.compute_engine import ComputeEngine
|
||||
|
||||
|
||||
class FactorEngine:
|
||||
"""因子计算引擎 - 系统统一入口。
|
||||
|
||||
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
||||
|
||||
执行流程:
|
||||
1. 注册表达式 -> 调用编译器解析依赖
|
||||
2. 调用路由器连接数据库拉取并组装核心宽表
|
||||
3. 调用翻译器生成物理执行计划
|
||||
4. 将计划提交给计算引擎执行并行运算
|
||||
5. 返回包含因子结果的数据表
|
||||
|
||||
Attributes:
|
||||
router: 数据路由器
|
||||
planner: 执行计划生成器
|
||||
compute_engine: 计算引擎
|
||||
registered_expressions: 注册的表达式字典
|
||||
_registry: 函数注册表
|
||||
_parser: 公式解析器
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
||||
max_workers: int = 4,
|
||||
registry: Optional["FunctionRegistry"] = None,
|
||||
) -> None:
|
||||
"""初始化因子引擎。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,为 None 时使用数据库连接
|
||||
max_workers: 并行计算的最大工作线程数
|
||||
registry: 函数注册表,None 时创建独立实例
|
||||
"""
|
||||
from src.factors.registry import FunctionRegistry
|
||||
from src.factors.parser import FormulaParser
|
||||
|
||||
self.router = DataRouter(data_source)
|
||||
self.planner = ExecutionPlanner()
|
||||
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
||||
self.registered_expressions: Dict[str, Node] = {}
|
||||
self._plans: Dict[str, ExecutionPlan] = {}
|
||||
|
||||
# 初始化注册表和解析器(支持注入外部注册表实现共享)
|
||||
self._registry = registry if registry is not None else FunctionRegistry()
|
||||
self._parser = FormulaParser(self._registry)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
expression: Node,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""注册因子表达式。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
expression: DSL 表达式
|
||||
data_specs: 数据规格,None 时自动推导
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
|
||||
Example:
|
||||
>>> from src.factors.api import close, ts_mean
|
||||
>>> engine = FactorEngine()
|
||||
>>> engine.register("ma20", ts_mean(close, 20))
|
||||
"""
|
||||
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
||||
factor_deps = self._find_factor_dependencies(expression)
|
||||
|
||||
self.registered_expressions[name] = expression
|
||||
|
||||
# 预创建执行计划
|
||||
plan = self.planner.create_plan(
|
||||
expression=expression,
|
||||
output_name=name,
|
||||
data_specs=data_specs,
|
||||
)
|
||||
|
||||
# 添加因子依赖信息
|
||||
plan.factor_dependencies = factor_deps
|
||||
|
||||
self._plans[name] = plan
|
||||
|
||||
return self
|
||||
|
||||
def add_factor(
|
||||
self,
|
||||
name: str,
|
||||
expression: Union[str, Node],
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""注册因子(支持字符串或 Node 表达式)。
|
||||
|
||||
这是 register 方法的增强版,支持字符串表达式解析。
|
||||
向后兼容:register 方法保持不变,继续只接受 Node 类型。
|
||||
|
||||
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
expression: 字符串表达式或 Node 对象
|
||||
data_specs: 可选的数据规格
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
|
||||
Raises:
|
||||
TypeError: 当 expression 类型不支持时
|
||||
FormulaParseError: 当字符串解析失败时(立即报错)
|
||||
|
||||
Example:
|
||||
>>> engine = FactorEngine()
|
||||
>>>
|
||||
>>> # 字符串方式(新功能)
|
||||
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||
>>>
|
||||
>>> # Node 方式(与 register 相同)
|
||||
>>> from src.factors.api import close, ts_mean
|
||||
>>> engine.add_factor("ma20", ts_mean(close, 20))
|
||||
>>>
|
||||
>>> # 复杂表达式
|
||||
>>> engine.add_factor("alpha1", "cs_rank(close / open)")
|
||||
>>>
|
||||
>>> # 链式调用
|
||||
>>> (engine
|
||||
... .add_factor("ma5", "ts_mean(close, 5)")
|
||||
... .add_factor("ma10", "ts_mean(close, 10)")
|
||||
... .add_factor("golden_cross", "ma5 > ma10"))
|
||||
"""
|
||||
if isinstance(expression, str):
|
||||
# Fail-Fast:立即解析,失败立即报错
|
||||
node = self._parser.parse(expression)
|
||||
elif isinstance(expression, Node):
|
||||
node = expression
|
||||
else:
|
||||
raise TypeError(
|
||||
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
|
||||
)
|
||||
|
||||
# 委托给现有的 register 方法
|
||||
return self.register(name, node, data_specs)
|
||||
|
||||
def compute(
|
||||
self,
|
||||
factor_names: Union[str, List[str]],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""计算指定因子的值。
|
||||
|
||||
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
|
||||
|
||||
Args:
|
||||
factor_names: 因子名称或名称列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
stock_codes: 股票代码列表,None 表示全市场
|
||||
|
||||
Returns:
|
||||
包含因子结果的数据表
|
||||
|
||||
Raises:
|
||||
ValueError: 当因子未注册或数据不足时
|
||||
|
||||
Example:
|
||||
>>> result = engine.compute("ma20", "20240101", "20240131")
|
||||
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
|
||||
"""
|
||||
# 标准化因子名称
|
||||
if isinstance(factor_names, str):
|
||||
factor_names = [factor_names]
|
||||
|
||||
# 1. 获取执行计划
|
||||
plans = []
|
||||
for name in factor_names:
|
||||
if name not in self._plans:
|
||||
raise ValueError(f"因子未注册: {name}")
|
||||
plans.append(self._plans[name])
|
||||
|
||||
# 2. 合并数据规格并获取数据
|
||||
all_specs = []
|
||||
for plan in plans:
|
||||
all_specs.extend(plan.data_specs)
|
||||
|
||||
# 3. 从路由器获取核心宽表
|
||||
core_data = self.router.fetch_data(
|
||||
data_specs=all_specs,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
|
||||
if len(core_data) == 0:
|
||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||
|
||||
# 4. 按依赖顺序执行计算
|
||||
if len(plans) == 1:
|
||||
result = self.compute_engine.execute(plans[0], core_data)
|
||||
else:
|
||||
# 使用依赖感知的方式执行
|
||||
result = self._execute_with_dependencies(factor_names, core_data)
|
||||
|
||||
return result
|
||||
|
||||
def list_registered(self) -> List[str]:
|
||||
"""获取已注册的因子列表。
|
||||
|
||||
Returns:
|
||||
因子名称列表
|
||||
"""
|
||||
return list(self.registered_expressions.keys())
|
||||
|
||||
def get_expression(self, name: str) -> Optional[Node]:
|
||||
"""获取已注册的表达式。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
|
||||
Returns:
|
||||
表达式节点,未注册时返回 None
|
||||
"""
|
||||
return self.registered_expressions.get(name)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有注册的表达式和缓存。"""
|
||||
self.registered_expressions.clear()
|
||||
self._plans.clear()
|
||||
self.router.clear_cache()
|
||||
|
||||
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
|
||||
"""预览因子的执行计划。
|
||||
|
||||
Args:
|
||||
factor_name: 因子名称
|
||||
|
||||
Returns:
|
||||
执行计划,未注册时返回 None
|
||||
"""
|
||||
return self._plans.get(factor_name)
|
||||
|
||||
def _execute_with_dependencies(
|
||||
self,
|
||||
factor_names: List[str],
|
||||
core_data: pl.DataFrame,
|
||||
) -> pl.DataFrame:
|
||||
"""按依赖顺序执行因子计算。
|
||||
|
||||
支持 cs_rank 等需要依赖列已存在的场景。
|
||||
|
||||
Args:
|
||||
factor_names: 因子名称列表
|
||||
core_data: 核心宽表数据
|
||||
|
||||
Returns:
|
||||
包含所有因子结果的数据表
|
||||
"""
|
||||
# 1. 拓扑排序
|
||||
sorted_names = self._topological_sort(factor_names)
|
||||
|
||||
# 2. 按顺序执行
|
||||
result = core_data
|
||||
for name in sorted_names:
|
||||
plan = self._plans[name]
|
||||
|
||||
# 创建新的执行计划,引用已计算的依赖列
|
||||
new_plan = self._create_optimized_plan(plan, result)
|
||||
|
||||
# 执行计算
|
||||
result = self.compute_engine.execute(new_plan, result)
|
||||
|
||||
return result
|
||||
|
||||
def _create_optimized_plan(
|
||||
self,
|
||||
plan: ExecutionPlan,
|
||||
current_data: pl.DataFrame,
|
||||
) -> ExecutionPlan:
|
||||
"""创建优化的执行计划。
|
||||
|
||||
将表达式中已计算的依赖因子替换为列引用。
|
||||
|
||||
Args:
|
||||
plan: 原始执行计划
|
||||
current_data: 当前数据(包含已计算的依赖列)
|
||||
|
||||
Returns:
|
||||
新的执行计划
|
||||
"""
|
||||
from src.factors.dsl import Symbol
|
||||
|
||||
# 获取当前数据中已存在的列
|
||||
existing_cols = set(current_data.columns)
|
||||
|
||||
# 检查依赖列是否已存在
|
||||
deps_available = plan.factor_dependencies & existing_cols
|
||||
|
||||
if not deps_available:
|
||||
# 没有可用的依赖列,直接返回原计划
|
||||
return plan
|
||||
|
||||
# 获取原始表达式
|
||||
original_expr = self.registered_expressions[plan.output_name]
|
||||
|
||||
# 创建新的表达式,用 Symbol 引用替换依赖因子
|
||||
def replace_with_symbol(node: Node) -> Node:
|
||||
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
|
||||
from typing import Any
|
||||
|
||||
n: Any = node
|
||||
|
||||
# 检查当前节点是否等于某个已计算依赖因子
|
||||
for dep_name in deps_available:
|
||||
dep_expr = self.registered_expressions[dep_name]
|
||||
if self._expressions_equal(node, dep_expr):
|
||||
return Symbol(dep_name)
|
||||
|
||||
# 递归处理子节点
|
||||
if isinstance(n, BinaryOpNode):
|
||||
new_left = replace_with_symbol(n.left)
|
||||
new_right = replace_with_symbol(n.right)
|
||||
if new_left is not n.left or new_right is not n.right:
|
||||
return BinaryOpNode(n.op, new_left, new_right)
|
||||
elif isinstance(n, UnaryOpNode):
|
||||
new_operand = replace_with_symbol(n.operand)
|
||||
if new_operand is not n.operand:
|
||||
return UnaryOpNode(n.op, new_operand)
|
||||
elif isinstance(n, FunctionNode):
|
||||
new_args = [replace_with_symbol(arg) for arg in n.args]
|
||||
if any(
|
||||
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
|
||||
):
|
||||
return FunctionNode(n.func_name, *new_args)
|
||||
|
||||
return node
|
||||
|
||||
# 替换表达式
|
||||
new_expr = replace_with_symbol(original_expr)
|
||||
|
||||
# 重新翻译表达式
|
||||
translator = PolarsTranslator()
|
||||
new_polars_expr = translator.translate(new_expr)
|
||||
|
||||
# 更新依赖集合
|
||||
new_factor_deps = plan.factor_dependencies - deps_available
|
||||
new_deps = plan.dependencies | deps_available
|
||||
|
||||
return ExecutionPlan(
|
||||
data_specs=plan.data_specs,
|
||||
polars_expr=new_polars_expr,
|
||||
dependencies=new_deps,
|
||||
output_name=plan.output_name,
|
||||
factor_dependencies=new_factor_deps,
|
||||
)
|
||||
|
||||
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
|
||||
"""比较两个表达式是否相等。
|
||||
|
||||
用于检测因子间的依赖关系。
|
||||
|
||||
Args:
|
||||
expr1: 第一个表达式
|
||||
expr2: 第二个表达式
|
||||
|
||||
Returns:
|
||||
是否相等
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
e1: Any = expr1
|
||||
e2: Any = expr2
|
||||
|
||||
if type(e1) != type(e2):
|
||||
return False
|
||||
|
||||
if isinstance(e1, Symbol):
|
||||
return e1.name == e2.name
|
||||
|
||||
from src.factors.dsl import Constant
|
||||
|
||||
if isinstance(e1, Constant):
|
||||
return e1.value == e2.value
|
||||
|
||||
if isinstance(e1, BinaryOpNode):
|
||||
return (
|
||||
e1.op == e2.op
|
||||
and self._expressions_equal(e1.left, e2.left)
|
||||
and self._expressions_equal(e1.right, e2.right)
|
||||
)
|
||||
|
||||
if isinstance(e1, UnaryOpNode):
|
||||
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
|
||||
|
||||
if isinstance(e1, FunctionNode):
|
||||
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
|
||||
return False
|
||||
return all(
|
||||
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
||||
"""查找表达式依赖的其他因子。
|
||||
|
||||
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
|
||||
|
||||
Args:
|
||||
expression: 待检查的表达式
|
||||
|
||||
Returns:
|
||||
依赖的因子名称集合
|
||||
"""
|
||||
deps: Set[str] = set()
|
||||
|
||||
# 检查表达式本身是否等于某个已注册因子
|
||||
for name, registered_expr in self.registered_expressions.items():
|
||||
if self._expressions_equal(expression, registered_expr):
|
||||
deps.add(name)
|
||||
break
|
||||
|
||||
# 递归检查子节点
|
||||
if isinstance(expression, BinaryOpNode):
|
||||
deps.update(self._find_factor_dependencies(expression.left))
|
||||
deps.update(self._find_factor_dependencies(expression.right))
|
||||
elif isinstance(expression, UnaryOpNode):
|
||||
deps.update(self._find_factor_dependencies(expression.operand))
|
||||
elif isinstance(expression, FunctionNode):
|
||||
for arg in expression.args:
|
||||
deps.update(self._find_factor_dependencies(arg))
|
||||
|
||||
return deps
|
||||
|
||||
def _topological_sort(self, factor_names: List[str]) -> List[str]:
|
||||
"""按依赖关系对因子进行拓扑排序。
|
||||
|
||||
确保依赖的因子先被计算。
|
||||
|
||||
Args:
|
||||
factor_names: 因子名称列表
|
||||
|
||||
Returns:
|
||||
排序后的因子名称列表
|
||||
|
||||
Raises:
|
||||
ValueError: 当检测到循环依赖时
|
||||
"""
|
||||
# 构建依赖图
|
||||
graph: Dict[str, Set[str]] = {}
|
||||
in_degree: Dict[str, int] = {}
|
||||
|
||||
for name in factor_names:
|
||||
plan = self._plans[name]
|
||||
# 只考虑在本次计算范围内的依赖
|
||||
deps = plan.factor_dependencies & set(factor_names)
|
||||
graph[name] = deps
|
||||
in_degree[name] = len(deps)
|
||||
|
||||
# Kahn 算法
|
||||
result = []
|
||||
queue = [name for name, degree in in_degree.items() if degree == 0]
|
||||
|
||||
while queue:
|
||||
# 按原始顺序处理同级别的因子
|
||||
queue.sort(key=lambda x: factor_names.index(x))
|
||||
name = queue.pop(0)
|
||||
result.append(name)
|
||||
|
||||
for other in factor_names:
|
||||
if name in graph[other]:
|
||||
in_degree[other] -= 1
|
||||
if in_degree[other] == 0:
|
||||
queue.append(other)
|
||||
|
||||
if len(result) != len(factor_names):
|
||||
raise ValueError("检测到因子循环依赖")
|
||||
|
||||
return result
|
||||
170
src/factors/engine/planner.py
Normal file
170
src/factors/engine/planner.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""执行计划生成器。
|
||||
|
||||
整合编译器和翻译器,生成完整的执行计划。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from src.factors.dsl import (
|
||||
Node,
|
||||
Symbol,
|
||||
FunctionNode,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
Constant,
|
||||
)
|
||||
from src.factors.compiler import DependencyExtractor
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||
|
||||
|
||||
class ExecutionPlanner:
|
||||
"""执行计划生成器。
|
||||
|
||||
整合编译器和翻译器,生成完整的执行计划。
|
||||
|
||||
Attributes:
|
||||
compiler: 依赖提取器
|
||||
translator: Polars 翻译器
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化执行计划生成器。"""
|
||||
self.compiler = DependencyExtractor()
|
||||
self.translator = PolarsTranslator()
|
||||
|
||||
def create_plan(
|
||||
self,
|
||||
expression: Node,
|
||||
output_name: str = "factor",
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""从表达式创建执行计划。
|
||||
|
||||
Args:
|
||||
expression: DSL 表达式节点
|
||||
output_name: 输出因子名称
|
||||
data_specs: 预定义的数据规格,None 时自动推导
|
||||
|
||||
Returns:
|
||||
执行计划对象
|
||||
"""
|
||||
# 1. 提取依赖
|
||||
dependencies = self.compiler.extract_dependencies(expression)
|
||||
|
||||
# 2. 翻译为 Polars 表达式
|
||||
polars_expr = self.translator.translate(expression)
|
||||
|
||||
# 3. 推导或验证数据规格
|
||||
if data_specs is None:
|
||||
data_specs = self._infer_data_specs(dependencies, expression)
|
||||
|
||||
return ExecutionPlan(
|
||||
data_specs=data_specs,
|
||||
polars_expr=polars_expr,
|
||||
dependencies=dependencies,
|
||||
output_name=output_name,
|
||||
)
|
||||
|
||||
def _infer_data_specs(
|
||||
self,
|
||||
dependencies: Set[str],
|
||||
expression: Node,
|
||||
) -> List[DataSpec]:
|
||||
"""从依赖推导数据规格。
|
||||
|
||||
根据表达式中的函数类型推断回看天数需求。
|
||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
||||
默认从 pro_bar 表获取。
|
||||
|
||||
Args:
|
||||
dependencies: 依赖的字段集合
|
||||
expression: 表达式节点
|
||||
|
||||
Returns:
|
||||
数据规格列表
|
||||
"""
|
||||
# 计算最大回看窗口
|
||||
max_window = self._extract_max_window(expression)
|
||||
lookback_days = max(1, max_window)
|
||||
|
||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
||||
pro_bar_fields = {
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"vol",
|
||||
"amount",
|
||||
"pre_close",
|
||||
"change",
|
||||
"pct_chg",
|
||||
"turnover_rate",
|
||||
"volume_ratio",
|
||||
}
|
||||
|
||||
# 将依赖分为 pro_bar 字段和其他字段
|
||||
pro_bar_deps = dependencies & pro_bar_fields
|
||||
other_deps = dependencies - pro_bar_fields
|
||||
|
||||
data_specs = []
|
||||
|
||||
# pro_bar 表的数据规格
|
||||
if pro_bar_deps:
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
table="pro_bar",
|
||||
columns=sorted(pro_bar_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
# 其他字段从 daily 表获取
|
||||
if other_deps:
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
table="daily",
|
||||
columns=sorted(other_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
return data_specs
|
||||
|
||||
def _extract_max_window(self, node: Node) -> int:
|
||||
"""从表达式中提取最大窗口大小。
|
||||
|
||||
Args:
|
||||
node: AST 节点
|
||||
|
||||
Returns:
|
||||
最大窗口大小,无时序函数返回 1
|
||||
"""
|
||||
if isinstance(node, FunctionNode):
|
||||
window = 1
|
||||
# 检查函数参数中的窗口大小
|
||||
for arg in node.args:
|
||||
if (
|
||||
isinstance(arg, Constant)
|
||||
and isinstance(arg.value, int)
|
||||
and arg.value > window
|
||||
):
|
||||
window = arg.value
|
||||
|
||||
# 递归检查子表达式
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
||||
window = max(window, self._extract_max_window(arg))
|
||||
|
||||
return window
|
||||
|
||||
elif isinstance(node, BinaryOpNode):
|
||||
return max(
|
||||
self._extract_max_window(node.left),
|
||||
self._extract_max_window(node.right),
|
||||
)
|
||||
|
||||
elif isinstance(node, UnaryOpNode):
|
||||
return self._extract_max_window(node.operand)
|
||||
|
||||
return 1
|
||||
144
src/factors/exceptions.py
Normal file
144
src/factors/exceptions.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""公式解析异常定义。
|
||||
|
||||
提供清晰的错误信息,帮助用户快速定位公式解析问题。
|
||||
"""
|
||||
|
||||
import difflib
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class FormulaParseError(Exception):
|
||||
"""公式解析错误基类。
|
||||
|
||||
Attributes:
|
||||
expr: 原始表达式字符串
|
||||
lineno: 错误所在行号(从1开始)
|
||||
col_offset: 错误所在列号(从0开始)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
expr: Optional[str] = None,
|
||||
lineno: Optional[int] = None,
|
||||
col_offset: Optional[int] = None,
|
||||
):
|
||||
self.expr = expr
|
||||
self.lineno = lineno
|
||||
self.col_offset = col_offset
|
||||
|
||||
# 构建详细错误信息
|
||||
full_message = self._format_message(message)
|
||||
super().__init__(full_message)
|
||||
|
||||
def _format_message(self, message: str) -> str:
|
||||
"""格式化错误信息,包含位置指示器。"""
|
||||
lines = [f"FormulaParseError: {message}"]
|
||||
|
||||
if self.expr:
|
||||
lines.append(f" 公式: {self.expr}")
|
||||
|
||||
# 添加错误位置指示器
|
||||
if self.col_offset is not None and self.lineno is not None:
|
||||
# 计算错误行在表达式中的起始位置
|
||||
expr_lines = self.expr.split("\n")
|
||||
if 1 <= self.lineno <= len(expr_lines):
|
||||
error_line = expr_lines[self.lineno - 1]
|
||||
lines.append(f" {error_line}")
|
||||
# 添加指向错误位置的箭头
|
||||
pointer = " " * (self.col_offset + 7) + "^--- 此处出错"
|
||||
lines.append(pointer)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class UnknownFunctionError(FormulaParseError):
|
||||
"""未知函数错误。
|
||||
|
||||
当表达式中使用了未注册的函数时抛出。
|
||||
|
||||
Attributes:
|
||||
func_name: 未知的函数名
|
||||
available: 可用函数列表
|
||||
suggestions: 模糊匹配建议列表
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func_name: str,
|
||||
available: List[str],
|
||||
expr: Optional[str] = None,
|
||||
lineno: Optional[int] = None,
|
||||
col_offset: Optional[int] = None,
|
||||
):
|
||||
self.func_name = func_name
|
||||
self.available = available
|
||||
|
||||
# 使用 difflib 获取模糊匹配建议
|
||||
self.suggestions = difflib.get_close_matches(
|
||||
func_name, available, n=3, cutoff=0.5
|
||||
)
|
||||
|
||||
# 构建错误信息
|
||||
if self.suggestions:
|
||||
suggestion_str = ", ".join(f"'{s}'" for s in self.suggestions)
|
||||
hint_msg = f"你是不是想找: {suggestion_str}?"
|
||||
else:
|
||||
# 只显示前10个可用函数
|
||||
available_preview = ", ".join(available[:10])
|
||||
if len(available) > 10:
|
||||
available_preview += f", ... 等共 {len(available)} 个函数"
|
||||
hint_msg = f"可用函数预览: {available_preview}"
|
||||
|
||||
msg = f"未知函数 '{func_name}'。{hint_msg}"
|
||||
|
||||
super().__init__(
|
||||
message=msg,
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
|
||||
class InvalidSyntaxError(FormulaParseError):
|
||||
"""语法错误。
|
||||
|
||||
当表达式语法不正确或不支持时抛出。
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedOperatorError(InvalidSyntaxError):
|
||||
"""不支持的运算符错误。
|
||||
|
||||
当使用了不支持的运算符时抛出(如位运算、矩阵运算等)。
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EmptyExpressionError(FormulaParseError):
|
||||
"""空表达式错误。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("表达式不能为空或只包含空白字符")
|
||||
|
||||
|
||||
class RegistryError(Exception):
|
||||
"""注册表错误基类。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateFunctionError(RegistryError):
|
||||
"""函数重复注册错误。
|
||||
|
||||
当尝试注册已存在的函数且未设置 force=True 时抛出。
|
||||
"""
|
||||
|
||||
def __init__(self, func_name: str):
|
||||
self.func_name = func_name
|
||||
super().__init__(
|
||||
f"函数 '{func_name}' 已存在。使用 force=True 覆盖,或选择其他名称。"
|
||||
)
|
||||
411
src/factors/parser.py
Normal file
411
src/factors/parser.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
|
||||
|
||||
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
|
||||
|
||||
示例:
|
||||
>>> from src.factors.parser import FormulaParser
|
||||
>>> from src.factors.registry import FunctionRegistry
|
||||
>>> parser = FormulaParser(FunctionRegistry())
|
||||
>>> node = parser.parse("ts_mean(close, 20)")
|
||||
>>> print(node)
|
||||
ts_mean(close, 20)
|
||||
"""
|
||||
|
||||
import ast
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
|
||||
from src.factors.exceptions import (
|
||||
FormulaParseError,
|
||||
UnknownFunctionError,
|
||||
InvalidSyntaxError,
|
||||
EmptyExpressionError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.factors.registry import FunctionRegistry
|
||||
|
||||
|
||||
# 运算符映射表
|
||||
BIN_OP_MAP: Dict[type, str] = {
|
||||
ast.Add: "+",
|
||||
ast.Sub: "-",
|
||||
ast.Mult: "*",
|
||||
ast.Div: "/",
|
||||
ast.Pow: "**",
|
||||
ast.FloorDiv: "//",
|
||||
ast.Mod: "%",
|
||||
}
|
||||
|
||||
UNARY_OP_MAP: Dict[type, str] = {
|
||||
ast.UAdd: "+",
|
||||
ast.USub: "-",
|
||||
ast.Invert: "~", # 不支持,应报错
|
||||
}
|
||||
|
||||
COMPARE_OP_MAP: Dict[type, str] = {
|
||||
ast.Eq: "==",
|
||||
ast.NotEq: "!=",
|
||||
ast.Lt: "<",
|
||||
ast.LtE: "<=",
|
||||
ast.Gt: ">",
|
||||
ast.GtE: ">=",
|
||||
}
|
||||
|
||||
|
||||
class FormulaParser:
|
||||
"""基于 AST 的公式解析器。
|
||||
|
||||
将字符串表达式解析为 DSL 节点树,支持:
|
||||
- 符号引用(如 close, open)
|
||||
- 数值常量(如 20, 3.14)
|
||||
- 二元运算(如 +, -, *, /)
|
||||
- 一元运算(如 -x)
|
||||
- 函数调用(如 ts_mean(close, 20))
|
||||
- 比较运算(如 close > open)
|
||||
|
||||
Attributes:
|
||||
registry: 函数注册表,用于解析函数调用
|
||||
"""
|
||||
|
||||
def __init__(self, registry: "FunctionRegistry") -> None:
|
||||
"""初始化解析器。
|
||||
|
||||
Args:
|
||||
registry: 函数注册表,提供函数名到可调用对象的映射
|
||||
"""
|
||||
self.registry = registry
|
||||
|
||||
def parse(self, expr: str) -> Node:
|
||||
"""解析字符串表达式为 Node 树。
|
||||
|
||||
Args:
|
||||
expr: 公式字符串,如 "ts_mean(close, 20)"
|
||||
|
||||
Returns:
|
||||
解析后的 Node 节点
|
||||
|
||||
Raises:
|
||||
EmptyExpressionError: 表达式为空时抛出
|
||||
SyntaxError: Python 语法错误时抛出
|
||||
FormulaParseError: 解析失败时抛出
|
||||
|
||||
Example:
|
||||
>>> parser.parse("close / open")
|
||||
BinaryOpNode("/", Symbol("close"), Symbol("open"))
|
||||
"""
|
||||
# 检查空表达式
|
||||
if not expr or not expr.strip():
|
||||
raise EmptyExpressionError()
|
||||
|
||||
# 解析为 Python AST
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
# 将 SyntaxError 包装为 InvalidSyntaxError,统一异常类型
|
||||
raise InvalidSyntaxError(
|
||||
message=f"表达式语法错误: {e.msg}",
|
||||
expr=expr,
|
||||
lineno=e.lineno,
|
||||
col_offset=e.offset,
|
||||
) from e
|
||||
|
||||
# 递归访问 AST 节点
|
||||
try:
|
||||
return self._visit(tree.body, expr)
|
||||
except FormulaParseError:
|
||||
# 重新抛出 FormulaParseError(保留已有的位置信息)
|
||||
raise
|
||||
except Exception as e:
|
||||
# 将其他异常包装为 FormulaParseError
|
||||
if not isinstance(e, FormulaParseError):
|
||||
raise FormulaParseError(
|
||||
message=f"解析失败: {str(e)}",
|
||||
expr=expr,
|
||||
) from e
|
||||
raise
|
||||
|
||||
def _visit(self, node: ast.AST, expr: str) -> Node:
|
||||
"""递归访问 AST 节点并转换为 DSL 节点。
|
||||
|
||||
Args:
|
||||
node: Python AST 节点
|
||||
expr: 原始表达式字符串(用于错误报告)
|
||||
|
||||
Returns:
|
||||
对应的 DSL 节点
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 遇到不支持的语法时抛出
|
||||
"""
|
||||
# 提取位置信息(如果节点有)
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
|
||||
try:
|
||||
if isinstance(node, ast.Name):
|
||||
return self._visit_Name(node)
|
||||
elif isinstance(node, ast.Constant):
|
||||
return self._visit_Constant(node, expr)
|
||||
elif isinstance(node, ast.BinOp):
|
||||
return self._visit_BinOp(node, expr)
|
||||
elif isinstance(node, ast.UnaryOp):
|
||||
return self._visit_UnaryOp(node, expr)
|
||||
elif isinstance(node, ast.Call):
|
||||
return self._visit_Call(node, expr)
|
||||
elif isinstance(node, ast.Compare):
|
||||
return self._visit_Compare(node, expr)
|
||||
else:
|
||||
raise InvalidSyntaxError(
|
||||
message=f"不支持的语法: {type(node).__name__}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
except FormulaParseError:
|
||||
# 重新抛出(保留已有的位置信息)
|
||||
raise
|
||||
except Exception as e:
|
||||
# 包装为 FormulaParseError,添加位置信息
|
||||
raise FormulaParseError(
|
||||
message=f"解析节点失败: {str(e)}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
) from e
|
||||
|
||||
def _visit_Name(self, node: ast.Name) -> Symbol:
|
||||
"""访问名称节点 - 永远转为 Symbol。
|
||||
|
||||
注意:利用 AST 语法自然区分变量和函数调用:
|
||||
- log → Symbol("log")(数据列引用)
|
||||
- log(close) → 在 _visit_Call 中处理(函数调用)
|
||||
|
||||
Args:
|
||||
node: AST 名称节点
|
||||
|
||||
Returns:
|
||||
Symbol 节点
|
||||
"""
|
||||
return Symbol(node.id)
|
||||
|
||||
def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
|
||||
"""访问常量节点。
|
||||
|
||||
支持的类型:
|
||||
- int/float → Constant 节点
|
||||
- str → Symbol 节点(支持 ts_mean("close", 20) 语法)
|
||||
|
||||
Args:
|
||||
node: AST 常量节点
|
||||
expr: 原始表达式字符串
|
||||
|
||||
Returns:
|
||||
Constant 或 Symbol 节点
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 不支持的常量类型
|
||||
"""
|
||||
if isinstance(node.value, (int, float)):
|
||||
return Constant(node.value)
|
||||
elif isinstance(node.value, str):
|
||||
# 字符串常量转为 Symbol,支持 "close" 写法
|
||||
return Symbol(node.value)
|
||||
else:
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
raise InvalidSyntaxError(
|
||||
message=f"不支持的常量类型: {type(node.value).__name__}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
|
||||
"""访问二元运算节点。
|
||||
|
||||
Args:
|
||||
node: AST 二元运算节点
|
||||
expr: 原始表达式字符串
|
||||
|
||||
Returns:
|
||||
BinaryOpNode 节点
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 不支持的运算符
|
||||
"""
|
||||
left = self._visit(node.left, expr)
|
||||
right = self._visit(node.right, expr)
|
||||
|
||||
op = BIN_OP_MAP.get(type(node.op))
|
||||
if op is None:
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
raise InvalidSyntaxError(
|
||||
message=f"不支持的运算符: {type(node.op).__name__}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
return BinaryOpNode(op, left, right)
|
||||
|
||||
def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
|
||||
"""访问一元运算节点。
|
||||
|
||||
支持常量折叠优化:纯数值的一元运算直接计算结果。
|
||||
|
||||
Args:
|
||||
node: AST 一元运算节点
|
||||
expr: 原始表达式字符串
|
||||
|
||||
Returns:
|
||||
Constant(常量折叠)或 UnaryOpNode 节点
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 不支持的运算符
|
||||
"""
|
||||
operand = self._visit(node.operand, expr)
|
||||
op = UNARY_OP_MAP.get(type(node.op))
|
||||
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
|
||||
if op is None:
|
||||
raise InvalidSyntaxError(
|
||||
message=f"不支持的一元运算符: {type(node.op).__name__}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
if op == "~":
|
||||
raise InvalidSyntaxError(
|
||||
message="位运算 '~' 不被支持",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
# 常量折叠优化:纯数值直接计算
|
||||
if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
|
||||
if op == "-":
|
||||
return Constant(-operand.value)
|
||||
elif op == "+":
|
||||
return operand # +5 就是 5
|
||||
|
||||
# 非常量使用运算符重载
|
||||
if op == "-":
|
||||
return -operand
|
||||
elif op == "+":
|
||||
return +operand
|
||||
|
||||
# 不应该到达这里
|
||||
raise InvalidSyntaxError(
|
||||
message=f"无法处理的一元运算符: {op}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
def _visit_Call(self, node: ast.Call, expr: str) -> Node:
|
||||
"""访问函数调用节点。
|
||||
|
||||
注意:只有在这里查注册表,处理函数调用。
|
||||
|
||||
Args:
|
||||
node: AST 函数调用节点
|
||||
expr: 原始表达式字符串
|
||||
|
||||
Returns:
|
||||
函数返回的 Node 节点
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 不支持的函数调用语法
|
||||
UnknownFunctionError: 函数未注册
|
||||
"""
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
|
||||
# 只支持简单函数调用(如 func(a, b))
|
||||
if not isinstance(node.func, ast.Name):
|
||||
raise InvalidSyntaxError(
|
||||
message="只支持简单函数调用(如 func(a, b))",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
func_name = node.func.id
|
||||
func = self.registry.get(func_name)
|
||||
|
||||
if func is None:
|
||||
raise UnknownFunctionError(
|
||||
func_name=func_name,
|
||||
available=self.registry.available_functions(),
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
# 解析位置参数
|
||||
args = [self._visit(arg, expr) for arg in node.args]
|
||||
|
||||
# 解析关键字参数(如果有)
|
||||
kwargs = {}
|
||||
for keyword in node.keywords:
|
||||
kwargs[keyword.arg] = self._visit(keyword.value, expr)
|
||||
|
||||
# 应用函数
|
||||
try:
|
||||
if kwargs:
|
||||
return func(*args, **kwargs)
|
||||
return func(*args)
|
||||
except TypeError as e:
|
||||
raise InvalidSyntaxError(
|
||||
message=f"函数 '{func_name}' 调用失败: {e}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
) from e
|
||||
|
||||
def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
|
||||
"""访问比较运算节点。
|
||||
|
||||
注意:只支持简单二元比较,不支持链式比较(如 a < b < c)。
|
||||
|
||||
Args:
|
||||
node: AST 比较节点
|
||||
expr: 原始表达式字符串
|
||||
|
||||
Returns:
|
||||
BinaryOpNode 节点(使用比较运算符)
|
||||
|
||||
Raises:
|
||||
InvalidSyntaxError: 链式比较或不支持的运算符
|
||||
"""
|
||||
lineno = getattr(node, "lineno", None)
|
||||
col_offset = getattr(node, "col_offset", None)
|
||||
|
||||
# Python 支持链式比较 (a < b < c),这里简化为二元比较
|
||||
if len(node.ops) != 1 or len(node.comparators) != 1:
|
||||
raise InvalidSyntaxError(
|
||||
message="只支持简单二元比较(如 a > b),不支持链式比较",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
left = self._visit(node.left, expr)
|
||||
op = COMPARE_OP_MAP.get(type(node.ops[0]))
|
||||
|
||||
if op is None:
|
||||
raise InvalidSyntaxError(
|
||||
message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
|
||||
expr=expr,
|
||||
lineno=lineno,
|
||||
col_offset=col_offset,
|
||||
)
|
||||
|
||||
right = self._visit(node.comparators[0], expr)
|
||||
return BinaryOpNode(op, left, right)
|
||||
227
src/factors/registry.py
Normal file
227
src/factors/registry.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""函数注册表 - 管理字符串函数名到 Python 函数的映射。
|
||||
|
||||
支持自动发现和手动注册,与 FormulaParser 配合使用。
|
||||
|
||||
示例:
|
||||
>>> from src.factors.registry import FunctionRegistry
|
||||
>>> registry = FunctionRegistry(auto_scan=True) # 自动加载 api.py 函数
|
||||
>>> registry.available_functions()[:5]
|
||||
['abs', 'clip', 'cs_demean', 'cs_neutralize', 'cs_rank']
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from src.factors.dsl import Node, FunctionNode
|
||||
from src.factors.exceptions import DuplicateFunctionError
|
||||
|
||||
|
||||
class FunctionRegistry:
|
||||
"""函数注册表。
|
||||
|
||||
管理字符串函数名到可调用对象的映射。
|
||||
自动从 api.py 加载标准函数,支持用户自定义函数注册。
|
||||
|
||||
Attributes:
|
||||
_functions: 函数字典,name -> callable
|
||||
"""
|
||||
|
||||
def __init__(self, auto_scan: bool = True) -> None:
|
||||
"""初始化注册表。
|
||||
|
||||
Args:
|
||||
auto_scan: 是否自动扫描 api.py 模块,默认 True
|
||||
"""
|
||||
self._functions: Dict[str, Callable] = {}
|
||||
|
||||
if auto_scan:
|
||||
self._scan_api_module()
|
||||
|
||||
def register(
|
||||
self, name: str, func: Callable, force: bool = False
|
||||
) -> "FunctionRegistry":
|
||||
"""注册自定义函数。
|
||||
|
||||
Args:
|
||||
name: 函数名称(字符串形式)
|
||||
func: 可调用对象
|
||||
force: 是否强制覆盖已存在的函数,默认 False
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Raises:
|
||||
DuplicateFunctionError: 当函数名已存在且 force=False 时
|
||||
|
||||
Example:
|
||||
>>> registry = FunctionRegistry(auto_scan=False)
|
||||
>>> registry.register("my_func", lambda x: x * 2)
|
||||
>>> registry.get("my_func")(5)
|
||||
10
|
||||
"""
|
||||
if name in self._functions and not force:
|
||||
raise DuplicateFunctionError(name)
|
||||
|
||||
self._functions[name] = func
|
||||
return self
|
||||
|
||||
def unregister(self, name: str) -> "FunctionRegistry":
|
||||
"""注销函数。
|
||||
|
||||
Args:
|
||||
name: 要注销的函数名
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Raises:
|
||||
KeyError: 函数不存在时
|
||||
"""
|
||||
if name not in self._functions:
|
||||
raise KeyError(f"函数 '{name}' 不存在")
|
||||
del self._functions[name]
|
||||
return self
|
||||
|
||||
def get(self, name: str) -> Optional[Callable]:
|
||||
"""获取函数。
|
||||
|
||||
Args:
|
||||
name: 函数名称
|
||||
|
||||
Returns:
|
||||
函数对象,不存在返回 None
|
||||
"""
|
||||
return self._functions.get(name)
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""检查函数是否存在。
|
||||
|
||||
Args:
|
||||
name: 函数名称
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
return name in self._functions
|
||||
|
||||
def available_functions(self) -> List[str]:
|
||||
"""返回所有可用函数名列表(按字母序)。
|
||||
|
||||
Returns:
|
||||
排序后的函数名列表
|
||||
"""
|
||||
return sorted(self._functions.keys())
|
||||
|
||||
def clear(self) -> "FunctionRegistry":
|
||||
"""清空所有注册的函数。
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
"""
|
||||
self._functions.clear()
|
||||
return self
|
||||
|
||||
def scan_module(
|
||||
self, module: Any, prefix: str = "", force: bool = False
|
||||
) -> "FunctionRegistry":
|
||||
"""扫描指定模块,自动注册符合条件的函数。
|
||||
|
||||
扫描规则:
|
||||
1. 模块级别的函数(排除私有函数 _*)
|
||||
2. 返回类型注解为 Node 或 FunctionNode
|
||||
|
||||
Args:
|
||||
module: 要扫描的模块对象
|
||||
prefix: 函数名前缀,用于避免命名冲突
|
||||
force: 是否强制覆盖已存在的函数
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
>>> import my_custom_module
|
||||
>>> registry.scan_module(my_custom_module, prefix="custom_")
|
||||
"""
|
||||
for name, obj in inspect.getmembers(module):
|
||||
# 只处理非私有函数
|
||||
if not inspect.isfunction(obj) or name.startswith("_"):
|
||||
continue
|
||||
|
||||
# 检查是否应该注册
|
||||
if self._should_register(obj):
|
||||
full_name = prefix + name
|
||||
self.register(full_name, obj, force=force)
|
||||
|
||||
return self
|
||||
|
||||
def _scan_api_module(self) -> None:
|
||||
"""自动扫描 api.py 模块,注册所有符合条件的函数。
|
||||
|
||||
这是默认的自动扫描行为,在 __init__ 中调用。
|
||||
"""
|
||||
try:
|
||||
from src.factors import api
|
||||
|
||||
self.scan_module(api)
|
||||
except ImportError:
|
||||
# api 模块可能不存在,静默跳过
|
||||
pass
|
||||
|
||||
def _should_register(self, func: Callable) -> bool:
|
||||
"""检查函数是否应该被注册。
|
||||
|
||||
基于类型提示检查函数返回类型,只注册返回 Node 或 FunctionNode 的函数。
|
||||
|
||||
Args:
|
||||
func: 要检查的函数
|
||||
|
||||
Returns:
|
||||
是否应该注册该函数
|
||||
"""
|
||||
try:
|
||||
hints = typing.get_type_hints(func)
|
||||
return_type = hints.get("return")
|
||||
|
||||
if return_type is None:
|
||||
return False
|
||||
|
||||
# 处理 Union 类型(如 Union[Node, FunctionNode])
|
||||
origin = typing.get_origin(return_type)
|
||||
args = typing.get_args(return_type)
|
||||
|
||||
if origin is typing.Union:
|
||||
# Union 类型,检查任一参数
|
||||
return any(self._is_node_type(arg) for arg in args)
|
||||
else:
|
||||
# 单一类型
|
||||
return self._is_node_type(return_type)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_node_type(self, typ: Any) -> bool:
|
||||
"""检查类型是否是 Node 或 FunctionNode 的子类。
|
||||
|
||||
Args:
|
||||
typ: 要检查的类型
|
||||
|
||||
Returns:
|
||||
是否是 Node 相关类型
|
||||
"""
|
||||
if not isinstance(typ, type):
|
||||
return False
|
||||
|
||||
return issubclass(typ, (Node, FunctionNode))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""返回已注册函数数量。"""
|
||||
return len(self._functions)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
"""检查是否包含某个函数名。"""
|
||||
return name in self._functions
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回注册表字符串表示。"""
|
||||
return f"FunctionRegistry({len(self._functions)} functions: {self.available_functions()[:5]}...)"
|
||||
Reference in New Issue
Block a user