"""公式解析器 - 将字符串表达式转换为 DSL 节点树。 基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。 示例: >>> from src.factors.parser import FormulaParser >>> from src.factors.registry import FunctionRegistry >>> parser = FormulaParser(FunctionRegistry()) >>> node = parser.parse("ts_mean(close, 20)") >>> print(node) ts_mean(close, 20) """ import ast from typing import Any, Dict, Optional, TYPE_CHECKING from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode from src.factors.exceptions import ( FormulaParseError, UnknownFunctionError, InvalidSyntaxError, EmptyExpressionError, ) if TYPE_CHECKING: from src.factors.registry import FunctionRegistry # 运算符映射表 BIN_OP_MAP: Dict[type, str] = { ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/", ast.Pow: "**", ast.FloorDiv: "//", ast.Mod: "%", } UNARY_OP_MAP: Dict[type, str] = { ast.UAdd: "+", ast.USub: "-", ast.Invert: "~", # 不支持,应报错 } COMPARE_OP_MAP: Dict[type, str] = { ast.Eq: "==", ast.NotEq: "!=", ast.Lt: "<", ast.LtE: "<=", ast.Gt: ">", ast.GtE: ">=", } class FormulaParser: """基于 AST 的公式解析器。 将字符串表达式解析为 DSL 节点树,支持: - 符号引用(如 close, open) - 数值常量(如 20, 3.14) - 二元运算(如 +, -, *, /) - 一元运算(如 -x) - 函数调用(如 ts_mean(close, 20)) - 比较运算(如 close > open) Attributes: registry: 函数注册表,用于解析函数调用 """ def __init__(self, registry: "FunctionRegistry") -> None: """初始化解析器。 Args: registry: 函数注册表,提供函数名到可调用对象的映射 """ self.registry = registry def parse(self, expr: str) -> Node: """解析字符串表达式为 Node 树。 Args: expr: 公式字符串,如 "ts_mean(close, 20)" Returns: 解析后的 Node 节点 Raises: EmptyExpressionError: 表达式为空时抛出 SyntaxError: Python 语法错误时抛出 FormulaParseError: 解析失败时抛出 Example: >>> parser.parse("close / open") BinaryOpNode("/", Symbol("close"), Symbol("open")) """ # 检查空表达式 if not expr or not expr.strip(): raise EmptyExpressionError() # 解析为 Python AST try: tree = ast.parse(expr, mode="eval") except SyntaxError as e: # 将 SyntaxError 包装为 InvalidSyntaxError,统一异常类型 raise InvalidSyntaxError( message=f"表达式语法错误: {e.msg}", expr=expr, lineno=e.lineno, col_offset=e.offset, ) from e # 递归访问 AST 节点 try: return self._visit(tree.body, expr) except FormulaParseError: # 重新抛出 FormulaParseError(保留已有的位置信息) raise except Exception as e: # 将其他异常包装为 FormulaParseError if not isinstance(e, FormulaParseError): raise FormulaParseError( message=f"解析失败: {str(e)}", expr=expr, ) from e raise def _visit(self, node: ast.AST, expr: str) -> Node: """递归访问 AST 节点并转换为 DSL 节点。 Args: node: Python AST 节点 expr: 原始表达式字符串(用于错误报告) Returns: 对应的 DSL 节点 Raises: InvalidSyntaxError: 遇到不支持的语法时抛出 """ # 提取位置信息(如果节点有) lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) try: if isinstance(node, ast.Name): return self._visit_Name(node) elif isinstance(node, ast.Constant): return self._visit_Constant(node, expr) elif isinstance(node, ast.BinOp): return self._visit_BinOp(node, expr) elif isinstance(node, ast.UnaryOp): return self._visit_UnaryOp(node, expr) elif isinstance(node, ast.Call): return self._visit_Call(node, expr) elif isinstance(node, ast.Compare): return self._visit_Compare(node, expr) else: raise InvalidSyntaxError( message=f"不支持的语法: {type(node).__name__}", expr=expr, lineno=lineno, col_offset=col_offset, ) except FormulaParseError: # 重新抛出(保留已有的位置信息) raise except Exception as e: # 包装为 FormulaParseError,添加位置信息 raise FormulaParseError( message=f"解析节点失败: {str(e)}", expr=expr, lineno=lineno, col_offset=col_offset, ) from e def _visit_Name(self, node: ast.Name) -> Symbol: """访问名称节点 - 永远转为 Symbol。 注意:利用 AST 语法自然区分变量和函数调用: - log → Symbol("log")(数据列引用) - log(close) → 在 _visit_Call 中处理(函数调用) Args: node: AST 名称节点 Returns: Symbol 节点 """ return Symbol(node.id) def _visit_Constant(self, node: ast.Constant, expr: str) -> Node: """访问常量节点。 支持的类型: - int/float → Constant 节点 - str → Symbol 节点(支持 ts_mean("close", 20) 语法) Args: node: AST 常量节点 expr: 原始表达式字符串 Returns: Constant 或 Symbol 节点 Raises: InvalidSyntaxError: 不支持的常量类型 """ if isinstance(node.value, (int, float)): return Constant(node.value) elif isinstance(node.value, str): # 字符串常量转为 Symbol,支持 "close" 写法 return Symbol(node.value) else: lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) raise InvalidSyntaxError( message=f"不支持的常量类型: {type(node.value).__name__}", expr=expr, lineno=lineno, col_offset=col_offset, ) def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode: """访问二元运算节点。 Args: node: AST 二元运算节点 expr: 原始表达式字符串 Returns: BinaryOpNode 节点 Raises: InvalidSyntaxError: 不支持的运算符 """ left = self._visit(node.left, expr) right = self._visit(node.right, expr) op = BIN_OP_MAP.get(type(node.op)) if op is None: lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) raise InvalidSyntaxError( message=f"不支持的运算符: {type(node.op).__name__}", expr=expr, lineno=lineno, col_offset=col_offset, ) return BinaryOpNode(op, left, right) def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node: """访问一元运算节点。 支持常量折叠优化:纯数值的一元运算直接计算结果。 Args: node: AST 一元运算节点 expr: 原始表达式字符串 Returns: Constant(常量折叠)或 UnaryOpNode 节点 Raises: InvalidSyntaxError: 不支持的运算符 """ operand = self._visit(node.operand, expr) op = UNARY_OP_MAP.get(type(node.op)) lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) if op is None: raise InvalidSyntaxError( message=f"不支持的一元运算符: {type(node.op).__name__}", expr=expr, lineno=lineno, col_offset=col_offset, ) if op == "~": raise InvalidSyntaxError( message="位运算 '~' 不被支持", expr=expr, lineno=lineno, col_offset=col_offset, ) # 常量折叠优化:纯数值直接计算 if isinstance(operand, Constant) and isinstance(operand.value, (int, float)): if op == "-": return Constant(-operand.value) elif op == "+": return operand # +5 就是 5 # 非常量使用运算符重载 if op == "-": return -operand elif op == "+": return +operand # 不应该到达这里 raise InvalidSyntaxError( message=f"无法处理的一元运算符: {op}", expr=expr, lineno=lineno, col_offset=col_offset, ) def _visit_Call(self, node: ast.Call, expr: str) -> Node: """访问函数调用节点。 注意:只有在这里查注册表,处理函数调用。 Args: node: AST 函数调用节点 expr: 原始表达式字符串 Returns: 函数返回的 Node 节点 Raises: InvalidSyntaxError: 不支持的函数调用语法 UnknownFunctionError: 函数未注册 """ lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) # 只支持简单函数调用(如 func(a, b)) if not isinstance(node.func, ast.Name): raise InvalidSyntaxError( message="只支持简单函数调用(如 func(a, b))", expr=expr, lineno=lineno, col_offset=col_offset, ) func_name = node.func.id func = self.registry.get(func_name) if func is None: raise UnknownFunctionError( func_name=func_name, available=self.registry.available_functions(), expr=expr, lineno=lineno, col_offset=col_offset, ) # 解析位置参数 args = [self._visit(arg, expr) for arg in node.args] # 解析关键字参数(如果有) kwargs = {} for keyword in node.keywords: kwargs[keyword.arg] = self._visit(keyword.value, expr) # 应用函数 try: if kwargs: return func(*args, **kwargs) return func(*args) except TypeError as e: raise InvalidSyntaxError( message=f"函数 '{func_name}' 调用失败: {e}", expr=expr, lineno=lineno, col_offset=col_offset, ) from e def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode: """访问比较运算节点。 注意:只支持简单二元比较,不支持链式比较(如 a < b < c)。 Args: node: AST 比较节点 expr: 原始表达式字符串 Returns: BinaryOpNode 节点(使用比较运算符) Raises: InvalidSyntaxError: 链式比较或不支持的运算符 """ lineno = getattr(node, "lineno", None) col_offset = getattr(node, "col_offset", None) # Python 支持链式比较 (a < b < c),这里简化为二元比较 if len(node.ops) != 1 or len(node.comparators) != 1: raise InvalidSyntaxError( message="只支持简单二元比较(如 a > b),不支持链式比较", expr=expr, lineno=lineno, col_offset=col_offset, ) left = self._visit(node.left, expr) op = COMPARE_OP_MAP.get(type(node.ops[0])) if op is None: raise InvalidSyntaxError( message=f"不支持的比较运算符: {type(node.ops[0]).__name__}", expr=expr, lineno=lineno, col_offset=col_offset, ) right = self._visit(node.comparators[0], expr) return BinaryOpNode(op, left, right)