2026-03-03 00:04:48 +08:00
|
|
|
|
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
|
|
|
|
|
|
|
|
|
|
|
|
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
|
|
|
|
|
|
|
|
|
|
|
|
示例:
|
|
|
|
|
|
>>> from src.factors.parser import FormulaParser
|
|
|
|
|
|
>>> from src.factors.registry import FunctionRegistry
|
|
|
|
|
|
>>> parser = FormulaParser(FunctionRegistry())
|
|
|
|
|
|
>>> node = parser.parse("ts_mean(close, 20)")
|
|
|
|
|
|
>>> print(node)
|
|
|
|
|
|
ts_mean(close, 20)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import ast
|
|
|
|
|
|
from typing import Any, Dict, Optional, TYPE_CHECKING
|
|
|
|
|
|
|
|
|
|
|
|
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
|
|
|
|
|
|
from src.factors.exceptions import (
|
|
|
|
|
|
FormulaParseError,
|
|
|
|
|
|
UnknownFunctionError,
|
|
|
|
|
|
InvalidSyntaxError,
|
|
|
|
|
|
EmptyExpressionError,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
|
from src.factors.registry import FunctionRegistry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Operator lookup tables: each maps a Python ``ast`` operator class to the
# operator text handed to the DSL node constructors.
BIN_OP_MAP: Dict[type, str] = {
    # Arithmetic operators.
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.Pow: "**",
    ast.FloorDiv: "//",
    ast.Mod: "%",
    # Bitwise and / or.
    ast.BitAnd: "&",
    ast.BitOr: "|",
}
|
|
|
|
|
|
|
|
|
|
|
|
# Unary operators recognised by the parser, keyed by ``ast`` operator class.
UNARY_OP_MAP: Dict[type, str] = {
    ast.UAdd: "+",
    ast.USub: "-",
    # "~" is listed only so it can be recognised: it is not supported and
    # is expected to be rejected with an error by the code using this table.
    ast.Invert: "~",
}
|
|
|
|
|
|
|
|
|
|
|
|
# Comparison operators recognised by the parser, keyed by ``ast`` operator
# class. Operators absent from this table (e.g. ``is``, ``in``) have no entry.
COMPARE_OP_MAP: Dict[type, str] = {
    # Equality.
    ast.Eq: "==",
    ast.NotEq: "!=",
    # Ordering.
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FormulaParser:
    """AST-based formula parser.

    Parses a string expression into a tree of DSL nodes. Supported syntax:

    - symbol references (e.g. ``close``, ``open``)
    - numeric constants (e.g. ``20``, ``3.14``)
    - binary operators listed in ``BIN_OP_MAP`` (e.g. ``+``, ``-``, ``*``, ``/``)
    - unary ``+`` / ``-`` (e.g. ``-x``)
    - function calls (e.g. ``ts_mean(close, 20)``)
    - simple two-operand comparisons (e.g. ``close > open``)

    Attributes:
        registry: Function registry used to resolve function calls.
    """

    def __init__(self, registry: "FunctionRegistry") -> None:
        """Initialise the parser.

        Args:
            registry: Function registry. The parser calls ``registry.get(name)``
                to look up a callable and ``registry.available_functions()``
                when reporting an unknown function.
        """
        self.registry = registry

    def parse(self, expr: str) -> Node:
        """Parse a string expression into a Node tree.

        Surrounding whitespace is stripped before parsing, so an expression
        such as ``" close / open "`` is accepted. Positions reported in errors
        refer to the stripped expression.

        Args:
            expr: Formula string, e.g. ``"ts_mean(close, 20)"``.

        Returns:
            The parsed Node.

        Raises:
            EmptyExpressionError: If the expression is empty or only whitespace.
            InvalidSyntaxError: If the expression is not valid Python syntax
                or uses syntax the parser does not support.
            FormulaParseError: If parsing fails for any other reason.

        Example:
            >>> parser.parse("close / open")
            BinaryOpNode("/", Symbol("close"), Symbol("open"))
        """
        # Reject empty / whitespace-only expressions.
        if not expr or not expr.strip():
            raise EmptyExpressionError()

        # Leading whitespace would make ast.parse fail with
        # "unexpected indent", so parse the stripped text.
        source = expr.strip()

        # Parse into a Python AST.
        try:
            tree = ast.parse(source, mode="eval")
        except (SyntaxError, ValueError) as e:
            # Wrap into InvalidSyntaxError so callers see one exception family.
            # ValueError covers sources containing null bytes on Python
            # versions that do not report them as SyntaxError.
            # NOTE(review): SyntaxError.offset is 1-based whereas AST
            # col_offset values are 0-based - confirm how InvalidSyntaxError
            # interprets col_offset.
            detail = getattr(e, "msg", None) or str(e)
            raise InvalidSyntaxError(
                message=f"表达式语法错误: {detail}",
                expr=source,
                lineno=getattr(e, "lineno", None),
                col_offset=getattr(e, "offset", None),
            ) from e

        # Recursively convert the AST into DSL nodes.
        try:
            return self._visit(tree.body, source)
        except FormulaParseError:
            # Already carries position information; propagate unchanged.
            raise
        except Exception as e:
            # Anything else is wrapped so callers only deal with
            # FormulaParseError and its subclasses.
            raise FormulaParseError(
                message=f"解析失败: {str(e)}",
                expr=source,
            ) from e

    def _visit(self, node: ast.AST, expr: str) -> Node:
        """Recursively convert a Python AST node into a DSL node.

        Args:
            node: Python AST node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            The corresponding DSL node.

        Raises:
            InvalidSyntaxError: If the node type is not supported.
            FormulaParseError: If converting the node fails with any other
                exception (the original exception is chained as the cause).
        """
        # Position information, when the node carries it.
        lineno = getattr(node, "lineno", None)
        col_offset = getattr(node, "col_offset", None)

        try:
            if isinstance(node, ast.Name):
                return self._visit_Name(node)
            if isinstance(node, ast.Constant):
                return self._visit_Constant(node, expr)
            if isinstance(node, ast.BinOp):
                return self._visit_BinOp(node, expr)
            if isinstance(node, ast.UnaryOp):
                return self._visit_UnaryOp(node, expr)
            if isinstance(node, ast.Call):
                return self._visit_Call(node, expr)
            if isinstance(node, ast.Compare):
                return self._visit_Compare(node, expr)
            raise InvalidSyntaxError(
                message=f"不支持的语法: {type(node).__name__}",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )
        except FormulaParseError:
            # Keep the position information of the innermost failure.
            raise
        except Exception as e:
            # Wrap unexpected errors and attach this node's position.
            raise FormulaParseError(
                message=f"解析节点失败: {str(e)}",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            ) from e

    def _visit_Name(self, node: ast.Name) -> Symbol:
        """Convert a name node; a bare name always becomes a Symbol.

        The AST itself distinguishes variables from function calls:

        - ``log`` -> ``Symbol("log")`` (treated as a data reference)
        - ``log(close)`` -> handled by ``_visit_Call`` (function call)

        Args:
            node: AST name node.

        Returns:
            A Symbol named after the identifier.
        """
        return Symbol(node.id)

    def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
        """Convert a constant node.

        Supported value types:

        - int / float -> Constant node
        - str -> Symbol node (allows ``ts_mean("close", 20)``)

        Args:
            node: AST constant node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            A Constant or Symbol node.

        Raises:
            InvalidSyntaxError: If the constant has an unsupported type.
        """
        value = node.value
        # Note: bool is a subclass of int, so True/False also pass this check.
        if isinstance(value, (int, float)):
            return Constant(value)
        if isinstance(value, str):
            # A string literal is treated as a symbol, e.g. "close".
            return Symbol(value)
        raise InvalidSyntaxError(
            message=f"不支持的常量类型: {type(node.value).__name__}",
            expr=expr,
            lineno=getattr(node, "lineno", None),
            col_offset=getattr(node, "col_offset", None),
        )

    def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
        """Convert a binary operation node.

        Both operands are converted before the operator is checked, so an
        error inside an operand is reported first.

        Args:
            node: AST binary operation node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            A BinaryOpNode.

        Raises:
            InvalidSyntaxError: If the operator is not in ``BIN_OP_MAP``.
        """
        left = self._visit(node.left, expr)
        right = self._visit(node.right, expr)

        op = BIN_OP_MAP.get(type(node.op))
        if op is None:
            raise InvalidSyntaxError(
                message=f"不支持的运算符: {type(node.op).__name__}",
                expr=expr,
                lineno=getattr(node, "lineno", None),
                col_offset=getattr(node, "col_offset", None),
            )

        return BinaryOpNode(op, left, right)

    def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
        """Convert a unary operation node.

        Applies constant folding: a unary operator on a plain number is
        evaluated immediately and returned as a Constant.

        Args:
            node: AST unary operation node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            A Constant (constant folding) or whatever the operand's unary
            operator overload returns.

        Raises:
            InvalidSyntaxError: If the operator is not supported.
        """
        operand = self._visit(node.operand, expr)
        op = UNARY_OP_MAP.get(type(node.op))

        lineno = getattr(node, "lineno", None)
        col_offset = getattr(node, "col_offset", None)

        if op is None:
            raise InvalidSyntaxError(
                message=f"不支持的一元运算符: {type(node.op).__name__}",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        # "~" is present in the table only so it can be rejected explicitly.
        if op == "~":
            raise InvalidSyntaxError(
                message="位运算 '~' 不被支持",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        # Constant folding: evaluate plain numbers directly.
        if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
            if op == "-":
                return Constant(-operand.value)
            if op == "+":
                return operand  # +5 is just 5

        # Non-constants rely on the operand's operator overloads.
        if op == "-":
            return -operand
        if op == "+":
            return +operand

        # Defensive: not expected to be reached with the current table.
        raise InvalidSyntaxError(
            message=f"无法处理的一元运算符: {op}",
            expr=expr,
            lineno=lineno,
            col_offset=col_offset,
        )

    def _visit_Call(self, node: ast.Call, expr: str) -> Node:
        """Convert a function call node.

        This is the only place where the function registry is consulted.

        Args:
            node: AST call node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            Whatever the registered function returns for the converted
            arguments.

        Raises:
            InvalidSyntaxError: If the call form is unsupported (callee is not
                a plain name, ``**`` unpacking is used) or the function
                raises TypeError when called.
            UnknownFunctionError: If the function is not registered.
        """
        lineno = getattr(node, "lineno", None)
        col_offset = getattr(node, "col_offset", None)

        # Only plain calls such as func(a, b) are supported.
        if not isinstance(node.func, ast.Name):
            raise InvalidSyntaxError(
                message="只支持简单函数调用(如 func(a, b))",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        func_name = node.func.id
        func = self.registry.get(func_name)

        if func is None:
            raise UnknownFunctionError(
                func_name=func_name,
                available=self.registry.available_functions(),
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        # Positional arguments.
        args = [self._visit(arg, expr) for arg in node.args]

        # Keyword arguments (if any).
        kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``func(**mapping)``: the AST stores no keyword name, so the
                # value cannot be passed as a named argument. Reject it here
                # instead of failing later with an obscure TypeError.
                raise InvalidSyntaxError(
                    message="不支持 ** 关键字参数解包",
                    expr=expr,
                    lineno=lineno,
                    col_offset=col_offset,
                )
            kwargs[keyword.arg] = self._visit(keyword.value, expr)

        # Apply the function; a TypeError (e.g. wrong argument count) is
        # reported as a syntax problem in the formula.
        try:
            return func(*args, **kwargs)
        except TypeError as e:
            raise InvalidSyntaxError(
                message=f"函数 '{func_name}' 调用失败: {e}",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            ) from e

    def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
        """Convert a comparison node.

        Only simple two-operand comparisons are supported; chained
        comparisons such as ``a < b < c`` are rejected.

        Args:
            node: AST comparison node.
            expr: Expression string being parsed (used for error reporting).

        Returns:
            A BinaryOpNode carrying the comparison operator.

        Raises:
            InvalidSyntaxError: For chained comparisons or operators not in
                ``COMPARE_OP_MAP``.
        """
        lineno = getattr(node, "lineno", None)
        col_offset = getattr(node, "col_offset", None)

        # Python allows chained comparisons (a < b < c); accept only one
        # operator with one right-hand operand.
        if len(node.ops) != 1 or len(node.comparators) != 1:
            raise InvalidSyntaxError(
                message="只支持简单二元比较(如 a > b),不支持链式比较",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        left = self._visit(node.left, expr)
        op = COMPARE_OP_MAP.get(type(node.ops[0]))

        if op is None:
            raise InvalidSyntaxError(
                message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
                expr=expr,
                lineno=lineno,
                col_offset=col_offset,
            )

        right = self._visit(node.comparators[0], expr)
        return BinaryOpNode(op, left, right)
|