fix(factors): 修复 AST 优化器并发命名冲突及逻辑运算支持

- 修复 ExpressionFlattener 跨实例临时名称冲突
- 添加 & 和 | 逻辑运算符的 DSL/Parser/Translator 支持
- 增加回归测试验证修复
This commit is contained in:
2026-03-14 01:17:14 +08:00
parent c8808d07eb
commit 2034d60fbb
6 changed files with 220 additions and 8 deletions

View File

@@ -115,6 +115,24 @@ class Node(ABC):
"""大于等于: self >= other"""
return BinaryOpNode(">=", self, _ensure_node(other))
# ==================== 位运算符重载(用于逻辑运算) ====================
def __and__(self, other: Any) -> BinaryOpNode:
"""逻辑与: self & other"""
return BinaryOpNode("&", self, _ensure_node(other))
def __rand__(self, other: Any) -> BinaryOpNode:
"""右逻辑与: other & self"""
return BinaryOpNode("&", _ensure_node(other), self)
def __or__(self, other: Any) -> BinaryOpNode:
"""逻辑或: self | other"""
return BinaryOpNode("|", self, _ensure_node(other))
def __ror__(self, other: Any) -> BinaryOpNode:
"""右逻辑或: other | self"""
return BinaryOpNode("|", _ensure_node(other), self)
# ==================== 抽象方法 ====================
@abstractmethod

View File

@@ -14,6 +14,7 @@
- 主表达式: cs_rank(__tmp_0)
"""
import threading
from typing import Dict, Tuple
from src.factors.dsl import (
@@ -26,30 +27,45 @@ from src.factors.dsl import (
)
# 模块级全局计数器,用于生成唯一的临时因子名称
_global_counter: int = 0
_counter_lock = threading.Lock()
def _get_next_counter() -> int:
"""获取下一个全局计数器值。
Returns:
递增后的全局计数器值
"""
global _global_counter
with _counter_lock:
_global_counter += 1
return _global_counter
class ExpressionFlattener:
"""表达式拍平器。
遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。
Attributes:
_counter: 临时因子名称计数器,用于生成唯一名称
_extracted_nodes: 存储已提取的临时因子字典
"""
def __init__(self) -> None:
"""初始化拍平器。"""
self._counter: int = 0
self._extracted_nodes: Dict[str, Node] = {}
def _generate_temp_name(self) -> str:
"""生成唯一的临时因子名称。
使用模块级全局计数器确保跨因子注册时的唯一性。
Returns:
格式为 "__tmp_X" 的临时名称,其中 X 是递增数字
格式为 "__tmp_X" 的临时名称,其中 X 是全局递增数字
"""
name = f"__tmp_{self._counter}"
self._counter += 1
return name
return f"__tmp_{_get_next_counter()}"
def _is_window_function(self, func_name: str) -> bool:
"""判断是否为窗口函数。
@@ -85,8 +101,7 @@ class ExpressionFlattener:
>>> # flat_expr = cs_rank(__tmp_0)
>>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
"""
# 重置状态
self._counter = 0
# 重置状态(只重置提取的节点字典,计数器保持全局递增)
self._extracted_nodes = {}
# 从根节点开始遍历,初始状态为不在窗口函数内部

View File

@@ -35,6 +35,8 @@ BIN_OP_MAP: Dict[type, str] = {
ast.Pow: "**",
ast.FloorDiv: "//",
ast.Mod: "%",
ast.BitAnd: "&",
ast.BitOr: "|",
}
UNARY_OP_MAP: Dict[type, str] = {

View File

@@ -178,6 +178,8 @@ class PolarsTranslator:
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
"&": lambda l, r: l & r,
"|": lambda l, r: l | r,
}
if node.op not in op_map: