Files
ProStock/src/factors/parser.py
liaozhaorun 2034d60fbb fix(factors): 修复 AST 优化器并发命名冲突及逻辑运算支持
- 修复 ExpressionFlattener 跨实例临时名称冲突
- 添加 & 和 | 逻辑运算符的 DSL/Parser/Translator 支持
- 增加回归测试验证修复
2026-03-14 01:17:14 +08:00

414 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
示例:
>>> from src.factors.parser import FormulaParser
>>> from src.factors.registry import FunctionRegistry
>>> parser = FormulaParser(FunctionRegistry())
>>> node = parser.parse("ts_mean(close, 20)")
>>> print(node)
ts_mean(close, 20)
"""
import ast
from typing import Any, Dict, Optional, TYPE_CHECKING
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
from src.factors.exceptions import (
FormulaParseError,
UnknownFunctionError,
InvalidSyntaxError,
EmptyExpressionError,
)
if TYPE_CHECKING:
from src.factors.registry import FunctionRegistry
# 运算符映射表
BIN_OP_MAP: Dict[type, str] = {
ast.Add: "+",
ast.Sub: "-",
ast.Mult: "*",
ast.Div: "/",
ast.Pow: "**",
ast.FloorDiv: "//",
ast.Mod: "%",
ast.BitAnd: "&",
ast.BitOr: "|",
}
UNARY_OP_MAP: Dict[type, str] = {
ast.UAdd: "+",
ast.USub: "-",
ast.Invert: "~", # 不支持,应报错
}
COMPARE_OP_MAP: Dict[type, str] = {
ast.Eq: "==",
ast.NotEq: "!=",
ast.Lt: "<",
ast.LtE: "<=",
ast.Gt: ">",
ast.GtE: ">=",
}
class FormulaParser:
"""基于 AST 的公式解析器。
将字符串表达式解析为 DSL 节点树,支持:
- 符号引用(如 close, open
- 数值常量(如 20, 3.14
- 二元运算(如 +, -, *, /
- 一元运算(如 -x
- 函数调用(如 ts_mean(close, 20)
- 比较运算(如 close > open
Attributes:
registry: 函数注册表,用于解析函数调用
"""
def __init__(self, registry: "FunctionRegistry") -> None:
"""初始化解析器。
Args:
registry: 函数注册表,提供函数名到可调用对象的映射
"""
self.registry = registry
def parse(self, expr: str) -> Node:
"""解析字符串表达式为 Node 树。
Args:
expr: 公式字符串,如 "ts_mean(close, 20)"
Returns:
解析后的 Node 节点
Raises:
EmptyExpressionError: 表达式为空时抛出
SyntaxError: Python 语法错误时抛出
FormulaParseError: 解析失败时抛出
Example:
>>> parser.parse("close / open")
BinaryOpNode("/", Symbol("close"), Symbol("open"))
"""
# 检查空表达式
if not expr or not expr.strip():
raise EmptyExpressionError()
# 解析为 Python AST
try:
tree = ast.parse(expr, mode="eval")
except SyntaxError as e:
# 将 SyntaxError 包装为 InvalidSyntaxError统一异常类型
raise InvalidSyntaxError(
message=f"表达式语法错误: {e.msg}",
expr=expr,
lineno=e.lineno,
col_offset=e.offset,
) from e
# 递归访问 AST 节点
try:
return self._visit(tree.body, expr)
except FormulaParseError:
# 重新抛出 FormulaParseError保留已有的位置信息
raise
except Exception as e:
# 将其他异常包装为 FormulaParseError
if not isinstance(e, FormulaParseError):
raise FormulaParseError(
message=f"解析失败: {str(e)}",
expr=expr,
) from e
raise
def _visit(self, node: ast.AST, expr: str) -> Node:
"""递归访问 AST 节点并转换为 DSL 节点。
Args:
node: Python AST 节点
expr: 原始表达式字符串(用于错误报告)
Returns:
对应的 DSL 节点
Raises:
InvalidSyntaxError: 遇到不支持的语法时抛出
"""
# 提取位置信息(如果节点有)
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
try:
if isinstance(node, ast.Name):
return self._visit_Name(node)
elif isinstance(node, ast.Constant):
return self._visit_Constant(node, expr)
elif isinstance(node, ast.BinOp):
return self._visit_BinOp(node, expr)
elif isinstance(node, ast.UnaryOp):
return self._visit_UnaryOp(node, expr)
elif isinstance(node, ast.Call):
return self._visit_Call(node, expr)
elif isinstance(node, ast.Compare):
return self._visit_Compare(node, expr)
else:
raise InvalidSyntaxError(
message=f"不支持的语法: {type(node).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
except FormulaParseError:
# 重新抛出(保留已有的位置信息)
raise
except Exception as e:
# 包装为 FormulaParseError添加位置信息
raise FormulaParseError(
message=f"解析节点失败: {str(e)}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
) from e
def _visit_Name(self, node: ast.Name) -> Symbol:
"""访问名称节点 - 永远转为 Symbol。
注意:利用 AST 语法自然区分变量和函数调用:
- log → Symbol("log")(数据列引用)
- log(close) → 在 _visit_Call 中处理(函数调用)
Args:
node: AST 名称节点
Returns:
Symbol 节点
"""
return Symbol(node.id)
def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
"""访问常量节点。
支持的类型:
- int/float → Constant 节点
- str → Symbol 节点(支持 ts_mean("close", 20) 语法)
Args:
node: AST 常量节点
expr: 原始表达式字符串
Returns:
Constant 或 Symbol 节点
Raises:
InvalidSyntaxError: 不支持的常量类型
"""
if isinstance(node.value, (int, float)):
return Constant(node.value)
elif isinstance(node.value, str):
# 字符串常量转为 Symbol支持 "close" 写法
return Symbol(node.value)
else:
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
raise InvalidSyntaxError(
message=f"不支持的常量类型: {type(node.value).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
"""访问二元运算节点。
Args:
node: AST 二元运算节点
expr: 原始表达式字符串
Returns:
BinaryOpNode 节点
Raises:
InvalidSyntaxError: 不支持的运算符
"""
left = self._visit(node.left, expr)
right = self._visit(node.right, expr)
op = BIN_OP_MAP.get(type(node.op))
if op is None:
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
raise InvalidSyntaxError(
message=f"不支持的运算符: {type(node.op).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
return BinaryOpNode(op, left, right)
def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
"""访问一元运算节点。
支持常量折叠优化:纯数值的一元运算直接计算结果。
Args:
node: AST 一元运算节点
expr: 原始表达式字符串
Returns:
Constant常量折叠或 UnaryOpNode 节点
Raises:
InvalidSyntaxError: 不支持的运算符
"""
operand = self._visit(node.operand, expr)
op = UNARY_OP_MAP.get(type(node.op))
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
if op is None:
raise InvalidSyntaxError(
message=f"不支持的一元运算符: {type(node.op).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
if op == "~":
raise InvalidSyntaxError(
message="位运算 '~' 不被支持",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
# 常量折叠优化:纯数值直接计算
if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
if op == "-":
return Constant(-operand.value)
elif op == "+":
return operand # +5 就是 5
# 非常量使用运算符重载
if op == "-":
return -operand
elif op == "+":
return +operand
# 不应该到达这里
raise InvalidSyntaxError(
message=f"无法处理的一元运算符: {op}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
def _visit_Call(self, node: ast.Call, expr: str) -> Node:
"""访问函数调用节点。
注意:只有在这里查注册表,处理函数调用。
Args:
node: AST 函数调用节点
expr: 原始表达式字符串
Returns:
函数返回的 Node 节点
Raises:
InvalidSyntaxError: 不支持的函数调用语法
UnknownFunctionError: 函数未注册
"""
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
# 只支持简单函数调用(如 func(a, b)
if not isinstance(node.func, ast.Name):
raise InvalidSyntaxError(
message="只支持简单函数调用(如 func(a, b)",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
func_name = node.func.id
func = self.registry.get(func_name)
if func is None:
raise UnknownFunctionError(
func_name=func_name,
available=self.registry.available_functions(),
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
# 解析位置参数
args = [self._visit(arg, expr) for arg in node.args]
# 解析关键字参数(如果有)
kwargs = {}
for keyword in node.keywords:
kwargs[keyword.arg] = self._visit(keyword.value, expr)
# 应用函数
try:
if kwargs:
return func(*args, **kwargs)
return func(*args)
except TypeError as e:
raise InvalidSyntaxError(
message=f"函数 '{func_name}' 调用失败: {e}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
) from e
def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
"""访问比较运算节点。
注意:只支持简单二元比较,不支持链式比较(如 a < b < c
Args:
node: AST 比较节点
expr: 原始表达式字符串
Returns:
BinaryOpNode 节点(使用比较运算符)
Raises:
InvalidSyntaxError: 链式比较或不支持的运算符
"""
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
# Python 支持链式比较 (a < b < c),这里简化为二元比较
if len(node.ops) != 1 or len(node.comparators) != 1:
raise InvalidSyntaxError(
message="只支持简单二元比较(如 a > b不支持链式比较",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
left = self._visit(node.left, expr)
op = COMPARE_OP_MAP.get(type(node.ops[0]))
if op is None:
raise InvalidSyntaxError(
message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
right = self._visit(node.comparators[0], expr)
return BinaryOpNode(op, left, right)