"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。 本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。 支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。 """ from typing import Any, Callable, Dict import polars as pl from src.factors.decorators import cross_section, element_wise, time_series from src.factors.dsl import ( BinaryOpNode, Constant, FunctionNode, Node, Symbol, UnaryOpNode, ) class PolarsTranslator: """Polars 表达式翻译器。 将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图。 Attributes: handlers: 函数处理器注册表,映射 func_name 到处理函数 Example: >>> from src.factors.dsl import Symbol, FunctionNode >>> close = Symbol("close") >>> expr = FunctionNode("ts_mean", close, 20) >>> translator = PolarsTranslator() >>> polars_expr = translator.translate(expr) >>> # 结果: pl.col("close").rolling_mean(20).over("asset") """ def __init__(self) -> None: """初始化翻译器并注册内置函数处理器。""" self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {} self._register_builtin_handlers() def _register_builtin_handlers(self) -> None: """注册内置的函数处理器。""" # 时序因子处理器 (ts_*) self.register_handler("ts_mean", self._handle_ts_mean) self.register_handler("ts_sum", self._handle_ts_sum) self.register_handler("ts_std", self._handle_ts_std) self.register_handler("ts_max", self._handle_ts_max) self.register_handler("ts_min", self._handle_ts_min) self.register_handler("ts_delay", self._handle_ts_delay) self.register_handler("ts_delta", self._handle_ts_delta) self.register_handler("ts_corr", self._handle_ts_corr) self.register_handler("ts_cov", self._handle_ts_cov) # 截面因子处理器 (cs_*) self.register_handler("cs_rank", self._handle_cs_rank) self.register_handler("cs_zscore", self._handle_cs_zscore) self.register_handler("cs_neutral", self._handle_cs_neutral) # 元素级数学函数 (element_wise) self.register_handler("log", self._handle_log) self.register_handler("exp", self._handle_exp) self.register_handler("sqrt", self._handle_sqrt) self.register_handler("sign", self._handle_sign) self.register_handler("cos", self._handle_cos) self.register_handler("sin", self._handle_sin) def register_handler( self, func_name: str, handler: Callable[[FunctionNode], pl.Expr] ) -> None: """注册自定义函数处理器。 Args: func_name: 函数名称 handler: 处理函数,接收 FunctionNode 返回 pl.Expr Example: >>> def handle_custom(node: FunctionNode) -> pl.Expr: ... arg = self.translate(node.args[0]) ... return arg * 2 >>> translator.register_handler("custom", handle_custom) """ self.handlers[func_name] = handler def translate(self, node: Node) -> pl.Expr: """递归翻译 AST 节点为 Polars 表达式。 Args: node: AST 节点(Symbol、Constant、BinaryOpNode、UnaryOpNode、FunctionNode) Returns: Polars 表达式对象 Raises: TypeError: 当遇到未知的节点类型时 """ if isinstance(node, Symbol): return self._translate_symbol(node) elif isinstance(node, Constant): return self._translate_constant(node) elif isinstance(node, BinaryOpNode): return self._translate_binary_op(node) elif isinstance(node, UnaryOpNode): return self._translate_unary_op(node) elif isinstance(node, FunctionNode): return self._translate_function(node) else: raise TypeError(f"未知的节点类型: {type(node).__name__}") def _translate_symbol(self, node: Symbol) -> pl.Expr: """翻译 Symbol 节点为 pl.col() 表达式。 Args: node: 符号节点 Returns: pl.col(node.name) 表达式 """ return pl.col(node.name) def _translate_constant(self, node: Constant) -> pl.Expr: """翻译 Constant 节点为 Polars 字面量。 Args: node: 常量节点 Returns: pl.lit(node.value) 表达式 """ return pl.lit(node.value) def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr: """翻译 BinaryOpNode 为 Polars 二元运算。 Args: node: 二元运算节点 Returns: Polars 二元运算表达式 """ left = self.translate(node.left) right = self.translate(node.right) op_map = { "+": lambda l, r: l + r, "-": lambda l, r: l - r, "*": lambda l, r: l * r, "/": lambda l, r: l / r, "**": lambda l, r: l.pow(r), "//": lambda l, r: l.floor_div(r), "%": lambda l, r: l % r, "==": lambda l, r: l.eq(r), "!=": lambda l, r: l.ne(r), "<": lambda l, r: l.lt(r), "<=": lambda l, r: l.le(r), ">": lambda l, r: l.gt(r), ">=": lambda l, r: l.ge(r), } if node.op not in op_map: raise ValueError(f"不支持的二元运算符: {node.op}") return op_map[node.op](left, right) def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr: """翻译 UnaryOpNode 为 Polars 一元运算。 Args: node: 一元运算节点 Returns: Polars 一元运算表达式 """ operand = self.translate(node.operand) op_map = { "+": lambda x: x, "-": lambda x: -x, "abs": lambda x: x.abs(), } if node.op not in op_map: raise ValueError(f"不支持的一元运算符: {node.op}") return op_map[node.op](operand) def _translate_function(self, node: FunctionNode) -> pl.Expr: """翻译 FunctionNode 为 Polars 函数调用。 优先从 handlers 注册表中查找处理器,未找到则抛出错误。 Args: node: 函数调用节点 Returns: Polars 函数表达式 Raises: ValueError: 当函数名称未注册处理器时 """ func_name = node.func_name if func_name in self.handlers: return self.handlers[func_name](node) else: raise ValueError( f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。" ) # ==================== 时序因子处理器 (ts_*) ==================== # 所有时序因子使用 @time_series 装饰器自动注入 over("ts_code") 防串表 @time_series def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr: """处理 ts_mean(close, window) -> rolling_mean(window)。""" if len(node.args) != 2: raise ValueError("ts_mean 需要 2 个参数: (expr, window)") expr = self.translate(node.args[0]) window = self._extract_window(node.args[1]) return expr.rolling_mean(window_size=window) @time_series def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr: """处理 ts_sum(close, window) -> rolling_sum(window)。""" if len(node.args) != 2: raise ValueError("ts_sum 需要 2 个参数: (expr, window)") expr = self.translate(node.args[0]) window = self._extract_window(node.args[1]) return expr.rolling_sum(window_size=window) @time_series def _handle_ts_std(self, node: FunctionNode) -> pl.Expr: """处理 ts_std(close, window) -> rolling_std(window)。""" if len(node.args) != 2: raise ValueError("ts_std 需要 2 个参数: (expr, window)") expr = self.translate(node.args[0]) window = self._extract_window(node.args[1]) return expr.rolling_std(window_size=window) @time_series def _handle_ts_max(self, node: FunctionNode) -> pl.Expr: """处理 ts_max(close, window) -> rolling_max(window)。""" if len(node.args) != 2: raise ValueError("ts_max 需要 2 个参数: (expr, window)") expr = self.translate(node.args[0]) window = self._extract_window(node.args[1]) return expr.rolling_max(window_size=window) @time_series def _handle_ts_min(self, node: FunctionNode) -> pl.Expr: """处理 ts_min(close, window) -> rolling_min(window)。""" if len(node.args) != 2: raise ValueError("ts_min 需要 2 个参数: (expr, window)") expr = self.translate(node.args[0]) window = self._extract_window(node.args[1]) return expr.rolling_min(window_size=window) @time_series def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr: """处理 ts_delay(close, n) -> shift(n)。""" if len(node.args) != 2: raise ValueError("ts_delay 需要 2 个参数: (expr, n)") expr = self.translate(node.args[0]) n = self._extract_window(node.args[1]) return expr.shift(n) @time_series def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr: """处理 ts_delta(close, n) -> (expr - shift(n))。""" if len(node.args) != 2: raise ValueError("ts_delta 需要 2 个参数: (expr, n)") expr = self.translate(node.args[0]) n = self._extract_window(node.args[1]) return expr - expr.shift(n) @time_series def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr: """处理 ts_corr(x, y, window) -> rolling_corr(y, window)。""" if len(node.args) != 3: raise ValueError("ts_corr 需要 3 个参数: (x, y, window)") x = self.translate(node.args[0]) y = self.translate(node.args[1]) window = self._extract_window(node.args[2]) return x.rolling_corr(y, window_size=window) @time_series def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr: """处理 ts_cov(x, y, window) -> rolling_cov(y, window)。""" if len(node.args) != 3: raise ValueError("ts_cov 需要 3 个参数: (x, y, window)") x = self.translate(node.args[0]) y = self.translate(node.args[1]) window = self._extract_window(node.args[2]) return x.rolling_cov(y, window_size=window) # ==================== 截面因子处理器 (cs_*) ==================== # 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表 @cross_section def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr: """处理 cs_rank(expr) -> rank()/count()。 将排名归一化到 [0, 1] 区间。 """ if len(node.args) != 1: raise ValueError("cs_rank 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.rank() / expr.count() @cross_section def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr: """处理 cs_zscore(expr) -> (expr - mean())/std()。""" if len(node.args) != 1: raise ValueError("cs_zscore 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return (expr - expr.mean()) / expr.std() @cross_section def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr: """处理 cs_neutral(expr, group) -> 分组中性化。""" if len(node.args) not in [1, 2]: raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])") expr = self.translate(node.args[0]) # 简单实现:减去截面均值(可在未来扩展为分组中性化) return expr - expr.mean() # ==================== 元素级数学函数 (element_wise) ==================== # 这些函数对每个元素独立计算,不添加 over @element_wise def _handle_log(self, node: FunctionNode) -> pl.Expr: """处理 log(expr) -> 自然对数。""" if len(node.args) != 1: raise ValueError("log 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.log() @element_wise def _handle_exp(self, node: FunctionNode) -> pl.Expr: """处理 exp(expr) -> 指数函数。""" if len(node.args) != 1: raise ValueError("exp 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.exp() @element_wise def _handle_sqrt(self, node: FunctionNode) -> pl.Expr: """处理 sqrt(expr) -> 平方根。""" if len(node.args) != 1: raise ValueError("sqrt 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.sqrt() @element_wise def _handle_sign(self, node: FunctionNode) -> pl.Expr: """处理 sign(expr) -> 符号函数。""" if len(node.args) != 1: raise ValueError("sign 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.sign() @element_wise def _handle_cos(self, node: FunctionNode) -> pl.Expr: """处理 cos(expr) -> 余弦函数。""" if len(node.args) != 1: raise ValueError("cos 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.cos() @element_wise def _handle_sin(self, node: FunctionNode) -> pl.Expr: """处理 sin(expr) -> 正弦函数。""" if len(node.args) != 1: raise ValueError("sin 需要 1 个参数: (expr)") expr = self.translate(node.args[0]) return expr.sin() # ==================== 辅助方法 ==================== def _extract_window(self, node: Node) -> int: """从节点中提取窗口大小参数。 Args: node: 应该是 Constant 节点 Returns: 整数值 Raises: ValueError: 当节点不是 Constant 或值不是整数时 """ if isinstance(node, Constant): if not isinstance(node.value, int): raise ValueError( f"窗口参数必须是整数,得到: {type(node.value).__name__}" ) return node.value raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}") def translate_to_polars(node: Node) -> pl.Expr: """便捷函数 - 将 AST 节点翻译为 Polars 表达式。 Args: node: 表达式树的根节点 Returns: Polars 表达式对象 Example: >>> from src.factors.dsl import Symbol, FunctionNode >>> close = Symbol("close") >>> expr = FunctionNode("ts_mean", close, 20) >>> polars_expr = translate_to_polars(expr) """ translator = PolarsTranslator() return translator.translate(node) if __name__ == "__main__": # 测试用例 from src.factors.dsl import Symbol, FunctionNode # 创建符号 close = Symbol("close") volume = Symbol("volume") # 测试 1: 简单符号 print("测试 1: Symbol") translator = PolarsTranslator() expr1 = translator.translate(close) print(f" close -> {expr1}") assert str(expr1) == 'col("close")' # 测试 2: 二元运算 print("\n测试 2: BinaryOp") expr2 = translator.translate(close + 10) print(f" close + 10 -> {expr2}") # 测试 3: ts_mean print("\n测试 3: ts_mean") expr3 = translator.translate(FunctionNode("ts_mean", close, 20)) print(f" ts_mean(close, 20) -> {expr3}") # 测试 4: cs_rank print("\n测试 4: cs_rank") expr4 = translator.translate(FunctionNode("cs_rank", close / volume)) print(f" cs_rank(close / volume) -> {expr4}") # 测试 5: 复杂表达式 print("\n测试 5: 复杂表达式") ma20 = FunctionNode("ts_mean", close, 20) ma60 = FunctionNode("ts_mean", close, 60) expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60)) print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}") print("\n✅ 所有测试通过!")