feat(factors): 新增公式解析基础组件
新增公式解析相关模块,支持将字符串表达式解析为 DSL 节点树: - exceptions.py: 定义公式解析异常体系 - FormulaParseError 基类,提供位置指示的错误信息 - UnknownFunctionError 支持模糊匹配建议 - InvalidSyntaxError、EmptyExpressionError 等具体异常 - parser.py: 基于 Python ast 的公式解析器 - 支持符号引用、数值常量、二元/一元运算 - 支持函数调用和比较运算 - 常量折叠优化 - registry.py: 函数注册表 - 支持动态注册和查询公式函数 - 提供可用函数列表和重复注册检查
This commit is contained in:
@@ -52,6 +52,19 @@ from src.factors.engine import (
|
|||||||
ComputeEngine,
|
ComputeEngine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from src.factors.parser import FormulaParser
|
||||||
|
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
|
from src.factors.exceptions import (
|
||||||
|
FormulaParseError,
|
||||||
|
UnknownFunctionError,
|
||||||
|
InvalidSyntaxError,
|
||||||
|
EmptyExpressionError,
|
||||||
|
RegistryError,
|
||||||
|
DuplicateFunctionError,
|
||||||
|
)
|
||||||
|
|
||||||
# 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问
|
# 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问
|
||||||
# 例如:from src.factors.engine import FactorEngine
|
# 例如:from src.factors.engine import FactorEngine
|
||||||
|
|
||||||
@@ -76,4 +89,15 @@ __all__ = [
|
|||||||
"DataRouter",
|
"DataRouter",
|
||||||
"ExecutionPlanner",
|
"ExecutionPlanner",
|
||||||
"ComputeEngine",
|
"ComputeEngine",
|
||||||
|
# 解析器 (Phase 1 新增)
|
||||||
|
"FormulaParser",
|
||||||
|
# 注册表 (Phase 1 新增)
|
||||||
|
"FunctionRegistry",
|
||||||
|
# 异常类 (Phase 1 新增)
|
||||||
|
"FormulaParseError",
|
||||||
|
"UnknownFunctionError",
|
||||||
|
"InvalidSyntaxError",
|
||||||
|
"EmptyExpressionError",
|
||||||
|
"RegistryError",
|
||||||
|
"DuplicateFunctionError",
|
||||||
]
|
]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -23,3 +23,6 @@ __all__ = [
|
|||||||
"ComputeEngine",
|
"ComputeEngine",
|
||||||
"FactorEngine",
|
"FactorEngine",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 类型导出(用于类型注解)
|
||||||
|
# FunctionRegistry 从 src.factors.registry 导入
|
||||||
|
|||||||
@@ -10,10 +10,13 @@
|
|||||||
5. 返回包含因子结果的数据表
|
5. 返回包含因子结果的数据表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
from src.factors.dsl import (
|
from src.factors.dsl import (
|
||||||
Node,
|
Node,
|
||||||
Symbol,
|
Symbol,
|
||||||
@@ -45,25 +48,36 @@ class FactorEngine:
|
|||||||
planner: 执行计划生成器
|
planner: 执行计划生成器
|
||||||
compute_engine: 计算引擎
|
compute_engine: 计算引擎
|
||||||
registered_expressions: 注册的表达式字典
|
registered_expressions: 注册的表达式字典
|
||||||
|
_registry: 函数注册表
|
||||||
|
_parser: 公式解析器
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
||||||
max_workers: int = 4,
|
max_workers: int = 4,
|
||||||
|
registry: Optional["FunctionRegistry"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化因子引擎。
|
"""初始化因子引擎。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_source: 内存数据源,为 None 时使用数据库连接
|
data_source: 内存数据源,为 None 时使用数据库连接
|
||||||
max_workers: 并行计算的最大工作线程数
|
max_workers: 并行计算的最大工作线程数
|
||||||
|
registry: 函数注册表,None 时创建独立实例
|
||||||
"""
|
"""
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
from src.factors.parser import FormulaParser
|
||||||
|
|
||||||
self.router = DataRouter(data_source)
|
self.router = DataRouter(data_source)
|
||||||
self.planner = ExecutionPlanner()
|
self.planner = ExecutionPlanner()
|
||||||
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
self.compute_engine = ComputeEngine(max_workers=max_workers)
|
||||||
self.registered_expressions: Dict[str, Node] = {}
|
self.registered_expressions: Dict[str, Node] = {}
|
||||||
self._plans: Dict[str, ExecutionPlan] = {}
|
self._plans: Dict[str, ExecutionPlan] = {}
|
||||||
|
|
||||||
|
# 初始化注册表和解析器(支持注入外部注册表实现共享)
|
||||||
|
self._registry = registry if registry is not None else FunctionRegistry()
|
||||||
|
self._parser = FormulaParser(self._registry)
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@@ -104,6 +118,63 @@ class FactorEngine:
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def add_factor(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
expression: Union[str, Node],
|
||||||
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
) -> "FactorEngine":
|
||||||
|
"""注册因子(支持字符串或 Node 表达式)。
|
||||||
|
|
||||||
|
这是 register 方法的增强版,支持字符串表达式解析。
|
||||||
|
向后兼容:register 方法保持不变,继续只接受 Node 类型。
|
||||||
|
|
||||||
|
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
expression: 字符串表达式或 Node 对象
|
||||||
|
data_specs: 可选的数据规格
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 当 expression 类型不支持时
|
||||||
|
FormulaParseError: 当字符串解析失败时(立即报错)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> engine = FactorEngine()
|
||||||
|
>>>
|
||||||
|
>>> # 字符串方式(新功能)
|
||||||
|
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||||
|
>>>
|
||||||
|
>>> # Node 方式(与 register 相同)
|
||||||
|
>>> from src.factors.api import close, ts_mean
|
||||||
|
>>> engine.add_factor("ma20", ts_mean(close, 20))
|
||||||
|
>>>
|
||||||
|
>>> # 复杂表达式
|
||||||
|
>>> engine.add_factor("alpha1", "cs_rank(close / open)")
|
||||||
|
>>>
|
||||||
|
>>> # 链式调用
|
||||||
|
>>> (engine
|
||||||
|
... .add_factor("ma5", "ts_mean(close, 5)")
|
||||||
|
... .add_factor("ma10", "ts_mean(close, 10)")
|
||||||
|
... .add_factor("golden_cross", "ma5 > ma10"))
|
||||||
|
"""
|
||||||
|
if isinstance(expression, str):
|
||||||
|
# Fail-Fast:立即解析,失败立即报错
|
||||||
|
node = self._parser.parse(expression)
|
||||||
|
elif isinstance(expression, Node):
|
||||||
|
node = expression
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 委托给现有的 register 方法
|
||||||
|
return self.register(name, node, data_specs)
|
||||||
|
|
||||||
def compute(
|
def compute(
|
||||||
self,
|
self,
|
||||||
factor_names: Union[str, List[str]],
|
factor_names: Union[str, List[str]],
|
||||||
|
|||||||
144
src/factors/exceptions.py
Normal file
144
src/factors/exceptions.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""公式解析异常定义。
|
||||||
|
|
||||||
|
提供清晰的错误信息,帮助用户快速定位公式解析问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class FormulaParseError(Exception):
|
||||||
|
"""公式解析错误基类。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
lineno: 错误所在行号(从1开始)
|
||||||
|
col_offset: 错误所在列号(从0开始)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
lineno: Optional[int] = None,
|
||||||
|
col_offset: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.expr = expr
|
||||||
|
self.lineno = lineno
|
||||||
|
self.col_offset = col_offset
|
||||||
|
|
||||||
|
# 构建详细错误信息
|
||||||
|
full_message = self._format_message(message)
|
||||||
|
super().__init__(full_message)
|
||||||
|
|
||||||
|
def _format_message(self, message: str) -> str:
|
||||||
|
"""格式化错误信息,包含位置指示器。"""
|
||||||
|
lines = [f"FormulaParseError: {message}"]
|
||||||
|
|
||||||
|
if self.expr:
|
||||||
|
lines.append(f" 公式: {self.expr}")
|
||||||
|
|
||||||
|
# 添加错误位置指示器
|
||||||
|
if self.col_offset is not None and self.lineno is not None:
|
||||||
|
# 计算错误行在表达式中的起始位置
|
||||||
|
expr_lines = self.expr.split("\n")
|
||||||
|
if 1 <= self.lineno <= len(expr_lines):
|
||||||
|
error_line = expr_lines[self.lineno - 1]
|
||||||
|
lines.append(f" {error_line}")
|
||||||
|
# 添加指向错误位置的箭头
|
||||||
|
pointer = " " * (self.col_offset + 7) + "^--- 此处出错"
|
||||||
|
lines.append(pointer)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class UnknownFunctionError(FormulaParseError):
|
||||||
|
"""未知函数错误。
|
||||||
|
|
||||||
|
当表达式中使用了未注册的函数时抛出。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
func_name: 未知的函数名
|
||||||
|
available: 可用函数列表
|
||||||
|
suggestions: 模糊匹配建议列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func_name: str,
|
||||||
|
available: List[str],
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
lineno: Optional[int] = None,
|
||||||
|
col_offset: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.func_name = func_name
|
||||||
|
self.available = available
|
||||||
|
|
||||||
|
# 使用 difflib 获取模糊匹配建议
|
||||||
|
self.suggestions = difflib.get_close_matches(
|
||||||
|
func_name, available, n=3, cutoff=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建错误信息
|
||||||
|
if self.suggestions:
|
||||||
|
suggestion_str = ", ".join(f"'{s}'" for s in self.suggestions)
|
||||||
|
hint_msg = f"你是不是想找: {suggestion_str}?"
|
||||||
|
else:
|
||||||
|
# 只显示前10个可用函数
|
||||||
|
available_preview = ", ".join(available[:10])
|
||||||
|
if len(available) > 10:
|
||||||
|
available_preview += f", ... 等共 {len(available)} 个函数"
|
||||||
|
hint_msg = f"可用函数预览: {available_preview}"
|
||||||
|
|
||||||
|
msg = f"未知函数 '{func_name}'。{hint_msg}"
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
message=msg,
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSyntaxError(FormulaParseError):
|
||||||
|
"""语法错误。
|
||||||
|
|
||||||
|
当表达式语法不正确或不支持时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedOperatorError(InvalidSyntaxError):
|
||||||
|
"""不支持的运算符错误。
|
||||||
|
|
||||||
|
当使用了不支持的运算符时抛出(如位运算、矩阵运算等)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyExpressionError(FormulaParseError):
|
||||||
|
"""空表达式错误。"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("表达式不能为空或只包含空白字符")
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryError(Exception):
|
||||||
|
"""注册表错误基类。"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateFunctionError(RegistryError):
|
||||||
|
"""函数重复注册错误。
|
||||||
|
|
||||||
|
当尝试注册已存在的函数且未设置 force=True 时抛出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, func_name: str):
|
||||||
|
self.func_name = func_name
|
||||||
|
super().__init__(
|
||||||
|
f"函数 '{func_name}' 已存在。使用 force=True 覆盖,或选择其他名称。"
|
||||||
|
)
|
||||||
411
src/factors/parser.py
Normal file
411
src/factors/parser.py
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
|
||||||
|
|
||||||
|
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> from src.factors.parser import FormulaParser
|
||||||
|
>>> from src.factors.registry import FunctionRegistry
|
||||||
|
>>> parser = FormulaParser(FunctionRegistry())
|
||||||
|
>>> node = parser.parse("ts_mean(close, 20)")
|
||||||
|
>>> print(node)
|
||||||
|
ts_mean(close, 20)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
|
||||||
|
from src.factors.exceptions import (
|
||||||
|
FormulaParseError,
|
||||||
|
UnknownFunctionError,
|
||||||
|
InvalidSyntaxError,
|
||||||
|
EmptyExpressionError,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.registry import FunctionRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# 运算符映射表
|
||||||
|
BIN_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.Add: "+",
|
||||||
|
ast.Sub: "-",
|
||||||
|
ast.Mult: "*",
|
||||||
|
ast.Div: "/",
|
||||||
|
ast.Pow: "**",
|
||||||
|
ast.FloorDiv: "//",
|
||||||
|
ast.Mod: "%",
|
||||||
|
}
|
||||||
|
|
||||||
|
UNARY_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.UAdd: "+",
|
||||||
|
ast.USub: "-",
|
||||||
|
ast.Invert: "~", # 不支持,应报错
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPARE_OP_MAP: Dict[type, str] = {
|
||||||
|
ast.Eq: "==",
|
||||||
|
ast.NotEq: "!=",
|
||||||
|
ast.Lt: "<",
|
||||||
|
ast.LtE: "<=",
|
||||||
|
ast.Gt: ">",
|
||||||
|
ast.GtE: ">=",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FormulaParser:
|
||||||
|
"""基于 AST 的公式解析器。
|
||||||
|
|
||||||
|
将字符串表达式解析为 DSL 节点树,支持:
|
||||||
|
- 符号引用(如 close, open)
|
||||||
|
- 数值常量(如 20, 3.14)
|
||||||
|
- 二元运算(如 +, -, *, /)
|
||||||
|
- 一元运算(如 -x)
|
||||||
|
- 函数调用(如 ts_mean(close, 20))
|
||||||
|
- 比较运算(如 close > open)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
registry: 函数注册表,用于解析函数调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, registry: "FunctionRegistry") -> None:
|
||||||
|
"""初始化解析器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: 函数注册表,提供函数名到可调用对象的映射
|
||||||
|
"""
|
||||||
|
self.registry = registry
|
||||||
|
|
||||||
|
def parse(self, expr: str) -> Node:
|
||||||
|
"""解析字符串表达式为 Node 树。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expr: 公式字符串,如 "ts_mean(close, 20)"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的 Node 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EmptyExpressionError: 表达式为空时抛出
|
||||||
|
SyntaxError: Python 语法错误时抛出
|
||||||
|
FormulaParseError: 解析失败时抛出
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> parser.parse("close / open")
|
||||||
|
BinaryOpNode("/", Symbol("close"), Symbol("open"))
|
||||||
|
"""
|
||||||
|
# 检查空表达式
|
||||||
|
if not expr or not expr.strip():
|
||||||
|
raise EmptyExpressionError()
|
||||||
|
|
||||||
|
# 解析为 Python AST
|
||||||
|
try:
|
||||||
|
tree = ast.parse(expr, mode="eval")
|
||||||
|
except SyntaxError as e:
|
||||||
|
# 将 SyntaxError 包装为 InvalidSyntaxError,统一异常类型
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"表达式语法错误: {e.msg}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=e.lineno,
|
||||||
|
col_offset=e.offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# 递归访问 AST 节点
|
||||||
|
try:
|
||||||
|
return self._visit(tree.body, expr)
|
||||||
|
except FormulaParseError:
|
||||||
|
# 重新抛出 FormulaParseError(保留已有的位置信息)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 将其他异常包装为 FormulaParseError
|
||||||
|
if not isinstance(e, FormulaParseError):
|
||||||
|
raise FormulaParseError(
|
||||||
|
message=f"解析失败: {str(e)}",
|
||||||
|
expr=expr,
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _visit(self, node: ast.AST, expr: str) -> Node:
|
||||||
|
"""递归访问 AST 节点并转换为 DSL 节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Python AST 节点
|
||||||
|
expr: 原始表达式字符串(用于错误报告)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对应的 DSL 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 遇到不支持的语法时抛出
|
||||||
|
"""
|
||||||
|
# 提取位置信息(如果节点有)
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
return self._visit_Name(node)
|
||||||
|
elif isinstance(node, ast.Constant):
|
||||||
|
return self._visit_Constant(node, expr)
|
||||||
|
elif isinstance(node, ast.BinOp):
|
||||||
|
return self._visit_BinOp(node, expr)
|
||||||
|
elif isinstance(node, ast.UnaryOp):
|
||||||
|
return self._visit_UnaryOp(node, expr)
|
||||||
|
elif isinstance(node, ast.Call):
|
||||||
|
return self._visit_Call(node, expr)
|
||||||
|
elif isinstance(node, ast.Compare):
|
||||||
|
return self._visit_Compare(node, expr)
|
||||||
|
else:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的语法: {type(node).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
except FormulaParseError:
|
||||||
|
# 重新抛出(保留已有的位置信息)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 包装为 FormulaParseError,添加位置信息
|
||||||
|
raise FormulaParseError(
|
||||||
|
message=f"解析节点失败: {str(e)}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _visit_Name(self, node: ast.Name) -> Symbol:
|
||||||
|
"""访问名称节点 - 永远转为 Symbol。
|
||||||
|
|
||||||
|
注意:利用 AST 语法自然区分变量和函数调用:
|
||||||
|
- log → Symbol("log")(数据列引用)
|
||||||
|
- log(close) → 在 _visit_Call 中处理(函数调用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 名称节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol 节点
|
||||||
|
"""
|
||||||
|
return Symbol(node.id)
|
||||||
|
|
||||||
|
def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
|
||||||
|
"""访问常量节点。
|
||||||
|
|
||||||
|
支持的类型:
|
||||||
|
- int/float → Constant 节点
|
||||||
|
- str → Symbol 节点(支持 ts_mean("close", 20) 语法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 常量节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constant 或 Symbol 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的常量类型
|
||||||
|
"""
|
||||||
|
if isinstance(node.value, (int, float)):
|
||||||
|
return Constant(node.value)
|
||||||
|
elif isinstance(node.value, str):
|
||||||
|
# 字符串常量转为 Symbol,支持 "close" 写法
|
||||||
|
return Symbol(node.value)
|
||||||
|
else:
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的常量类型: {type(node.value).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
|
||||||
|
"""访问二元运算节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 二元运算节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BinaryOpNode 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的运算符
|
||||||
|
"""
|
||||||
|
left = self._visit(node.left, expr)
|
||||||
|
right = self._visit(node.right, expr)
|
||||||
|
|
||||||
|
op = BIN_OP_MAP.get(type(node.op))
|
||||||
|
if op is None:
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的运算符: {type(node.op).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BinaryOpNode(op, left, right)
|
||||||
|
|
||||||
|
def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
|
||||||
|
"""访问一元运算节点。
|
||||||
|
|
||||||
|
支持常量折叠优化:纯数值的一元运算直接计算结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 一元运算节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constant(常量折叠)或 UnaryOpNode 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的运算符
|
||||||
|
"""
|
||||||
|
operand = self._visit(node.operand, expr)
|
||||||
|
op = UNARY_OP_MAP.get(type(node.op))
|
||||||
|
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
if op is None:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的一元运算符: {type(node.op).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
if op == "~":
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="位运算 '~' 不被支持",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 常量折叠优化:纯数值直接计算
|
||||||
|
if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
|
||||||
|
if op == "-":
|
||||||
|
return Constant(-operand.value)
|
||||||
|
elif op == "+":
|
||||||
|
return operand # +5 就是 5
|
||||||
|
|
||||||
|
# 非常量使用运算符重载
|
||||||
|
if op == "-":
|
||||||
|
return -operand
|
||||||
|
elif op == "+":
|
||||||
|
return +operand
|
||||||
|
|
||||||
|
# 不应该到达这里
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"无法处理的一元运算符: {op}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _visit_Call(self, node: ast.Call, expr: str) -> Node:
|
||||||
|
"""访问函数调用节点。
|
||||||
|
|
||||||
|
注意:只有在这里查注册表,处理函数调用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 函数调用节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
函数返回的 Node 节点
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 不支持的函数调用语法
|
||||||
|
UnknownFunctionError: 函数未注册
|
||||||
|
"""
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
# 只支持简单函数调用(如 func(a, b))
|
||||||
|
if not isinstance(node.func, ast.Name):
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="只支持简单函数调用(如 func(a, b))",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
func_name = node.func.id
|
||||||
|
func = self.registry.get(func_name)
|
||||||
|
|
||||||
|
if func is None:
|
||||||
|
raise UnknownFunctionError(
|
||||||
|
func_name=func_name,
|
||||||
|
available=self.registry.available_functions(),
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析位置参数
|
||||||
|
args = [self._visit(arg, expr) for arg in node.args]
|
||||||
|
|
||||||
|
# 解析关键字参数(如果有)
|
||||||
|
kwargs = {}
|
||||||
|
for keyword in node.keywords:
|
||||||
|
kwargs[keyword.arg] = self._visit(keyword.value, expr)
|
||||||
|
|
||||||
|
# 应用函数
|
||||||
|
try:
|
||||||
|
if kwargs:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return func(*args)
|
||||||
|
except TypeError as e:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"函数 '{func_name}' 调用失败: {e}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
|
||||||
|
"""访问比较运算节点。
|
||||||
|
|
||||||
|
注意:只支持简单二元比较,不支持链式比较(如 a < b < c)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: AST 比较节点
|
||||||
|
expr: 原始表达式字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BinaryOpNode 节点(使用比较运算符)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidSyntaxError: 链式比较或不支持的运算符
|
||||||
|
"""
|
||||||
|
lineno = getattr(node, "lineno", None)
|
||||||
|
col_offset = getattr(node, "col_offset", None)
|
||||||
|
|
||||||
|
# Python 支持链式比较 (a < b < c),这里简化为二元比较
|
||||||
|
if len(node.ops) != 1 or len(node.comparators) != 1:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message="只支持简单二元比较(如 a > b),不支持链式比较",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
left = self._visit(node.left, expr)
|
||||||
|
op = COMPARE_OP_MAP.get(type(node.ops[0]))
|
||||||
|
|
||||||
|
if op is None:
|
||||||
|
raise InvalidSyntaxError(
|
||||||
|
message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
|
||||||
|
expr=expr,
|
||||||
|
lineno=lineno,
|
||||||
|
col_offset=col_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
right = self._visit(node.comparators[0], expr)
|
||||||
|
return BinaryOpNode(op, left, right)
|
||||||
227
src/factors/registry.py
Normal file
227
src/factors/registry.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""函数注册表 - 管理字符串函数名到 Python 函数的映射。
|
||||||
|
|
||||||
|
支持自动发现和手动注册,与 FormulaParser 配合使用。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
>>> from src.factors.registry import FunctionRegistry
|
||||||
|
>>> registry = FunctionRegistry(auto_scan=True) # 自动加载 api.py 函数
|
||||||
|
>>> registry.available_functions()[:5]
|
||||||
|
['abs', 'clip', 'cs_demean', 'cs_neutralize', 'cs_rank']
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from src.factors.dsl import Node, FunctionNode
|
||||||
|
from src.factors.exceptions import DuplicateFunctionError
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionRegistry:
|
||||||
|
"""函数注册表。
|
||||||
|
|
||||||
|
管理字符串函数名到可调用对象的映射。
|
||||||
|
自动从 api.py 加载标准函数,支持用户自定义函数注册。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_functions: 函数字典,name -> callable
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auto_scan: bool = True) -> None:
|
||||||
|
"""初始化注册表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_scan: 是否自动扫描 api.py 模块,默认 True
|
||||||
|
"""
|
||||||
|
self._functions: Dict[str, Callable] = {}
|
||||||
|
|
||||||
|
if auto_scan:
|
||||||
|
self._scan_api_module()
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self, name: str, func: Callable, force: bool = False
|
||||||
|
) -> "FunctionRegistry":
|
||||||
|
"""注册自定义函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称(字符串形式)
|
||||||
|
func: 可调用对象
|
||||||
|
force: 是否强制覆盖已存在的函数,默认 False
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DuplicateFunctionError: 当函数名已存在且 force=False 时
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> registry = FunctionRegistry(auto_scan=False)
|
||||||
|
>>> registry.register("my_func", lambda x: x * 2)
|
||||||
|
>>> registry.get("my_func")(5)
|
||||||
|
10
|
||||||
|
"""
|
||||||
|
if name in self._functions and not force:
|
||||||
|
raise DuplicateFunctionError(name)
|
||||||
|
|
||||||
|
self._functions[name] = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unregister(self, name: str) -> "FunctionRegistry":
|
||||||
|
"""注销函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 要注销的函数名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 函数不存在时
|
||||||
|
"""
|
||||||
|
if name not in self._functions:
|
||||||
|
raise KeyError(f"函数 '{name}' 不存在")
|
||||||
|
del self._functions[name]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[Callable]:
|
||||||
|
"""获取函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
函数对象,不存在返回 None
|
||||||
|
"""
|
||||||
|
return self._functions.get(name)
|
||||||
|
|
||||||
|
def has(self, name: str) -> bool:
|
||||||
|
"""检查函数是否存在。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 函数名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否存在
|
||||||
|
"""
|
||||||
|
return name in self._functions
|
||||||
|
|
||||||
|
def available_functions(self) -> List[str]:
|
||||||
|
"""返回所有可用函数名列表(按字母序)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
排序后的函数名列表
|
||||||
|
"""
|
||||||
|
return sorted(self._functions.keys())
|
||||||
|
|
||||||
|
def clear(self) -> "FunctionRegistry":
|
||||||
|
"""清空所有注册的函数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
"""
|
||||||
|
self._functions.clear()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def scan_module(
|
||||||
|
self, module: Any, prefix: str = "", force: bool = False
|
||||||
|
) -> "FunctionRegistry":
|
||||||
|
"""扫描指定模块,自动注册符合条件的函数。
|
||||||
|
|
||||||
|
扫描规则:
|
||||||
|
1. 模块级别的函数(排除私有函数 _*)
|
||||||
|
2. 返回类型注解为 Node 或 FunctionNode
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: 要扫描的模块对象
|
||||||
|
prefix: 函数名前缀,用于避免命名冲突
|
||||||
|
force: 是否强制覆盖已存在的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self(支持链式调用)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> import my_custom_module
|
||||||
|
>>> registry.scan_module(my_custom_module, prefix="custom_")
|
||||||
|
"""
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
# 只处理非私有函数
|
||||||
|
if not inspect.isfunction(obj) or name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否应该注册
|
||||||
|
if self._should_register(obj):
|
||||||
|
full_name = prefix + name
|
||||||
|
self.register(full_name, obj, force=force)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _scan_api_module(self) -> None:
|
||||||
|
"""自动扫描 api.py 模块,注册所有符合条件的函数。
|
||||||
|
|
||||||
|
这是默认的自动扫描行为,在 __init__ 中调用。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.factors import api
|
||||||
|
|
||||||
|
self.scan_module(api)
|
||||||
|
except ImportError:
|
||||||
|
# api 模块可能不存在,静默跳过
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _should_register(self, func: Callable) -> bool:
|
||||||
|
"""检查函数是否应该被注册。
|
||||||
|
|
||||||
|
基于类型提示检查函数返回类型,只注册返回 Node 或 FunctionNode 的函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: 要检查的函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否应该注册该函数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
hints = typing.get_type_hints(func)
|
||||||
|
return_type = hints.get("return")
|
||||||
|
|
||||||
|
if return_type is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 处理 Union 类型(如 Union[Node, FunctionNode])
|
||||||
|
origin = typing.get_origin(return_type)
|
||||||
|
args = typing.get_args(return_type)
|
||||||
|
|
||||||
|
if origin is typing.Union:
|
||||||
|
# Union 类型,检查任一参数
|
||||||
|
return any(self._is_node_type(arg) for arg in args)
|
||||||
|
else:
|
||||||
|
# 单一类型
|
||||||
|
return self._is_node_type(return_type)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_node_type(self, typ: Any) -> bool:
|
||||||
|
"""检查类型是否是 Node 或 FunctionNode 的子类。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
typ: 要检查的类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否是 Node 相关类型
|
||||||
|
"""
|
||||||
|
if not isinstance(typ, type):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return issubclass(typ, (Node, FunctionNode))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""返回已注册函数数量。"""
|
||||||
|
return len(self._functions)
|
||||||
|
|
||||||
|
def __contains__(self, name: str) -> bool:
|
||||||
|
"""检查是否包含某个函数名。"""
|
||||||
|
return name in self._functions
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""返回注册表字符串表示。"""
|
||||||
|
return f"FunctionRegistry({len(self._functions)} functions: {self.available_functions()[:5]}...)"
|
||||||
Reference in New Issue
Block a user