fix(factors): 修复 AST 优化器并发命名冲突及逻辑运算支持
- 修复 ExpressionFlattener 跨实例临时名称冲突 - 添加 & 和 | 逻辑运算符的 DSL/Parser/Translator 支持 - 增加回归测试验证修复
This commit is contained in:
31
AGENTS.md
31
AGENTS.md
@@ -937,6 +937,37 @@ LSP 报错:Syntax error on line 45
|
||||
python tests/test_sync.py
|
||||
```
|
||||
|
||||
### 因子编写规范
|
||||
|
||||
**⚠️ 强制要求:编写因子时,优先使用字符串表达式而非 DSL 表达式。**
|
||||
|
||||
1. **推荐方式(字符串表达式)**
|
||||
```python
|
||||
from src.factors import FactorEngine
|
||||
|
||||
engine = FactorEngine()
|
||||
engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||
```
|
||||
|
||||
2. **不推荐方式(DSL 表达式)**
|
||||
```python
|
||||
from src.factors.api import close, ts_mean, cs_rank
|
||||
|
||||
engine.register("ma20", ts_mean(close, 20)) # 不推荐
|
||||
```
|
||||
|
||||
3. **原因说明**
|
||||
- 字符串表达式更易于序列化存储到因子元数据(`factors.jsonl`)
|
||||
- 字符串表达式支持从元数据动态加载和复用
|
||||
- 字符串表达式便于在配置文件中定义和维护
|
||||
- 与 `src/scripts/register_factors.py` 批量注册脚本兼容
|
||||
|
||||
4. **使用场景**
|
||||
- ✅ 在 `register_factors.py` 的 `FACTORS` 列表中定义因子
|
||||
- ✅ 动态添加因子到 FactorEngine
|
||||
- ✅ 从因子元数据查询并注册因子
|
||||
|
||||
### Emoji 表情禁用规则
|
||||
|
||||
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**
|
||||
|
||||
@@ -115,6 +115,24 @@ class Node(ABC):
|
||||
"""大于等于: self >= other"""
|
||||
return BinaryOpNode(">=", self, _ensure_node(other))
|
||||
|
||||
# ==================== 位运算符重载(用于逻辑运算) ====================
|
||||
|
||||
def __and__(self, other: Any) -> BinaryOpNode:
    """Build the expression node for ``self & other`` (logical AND).

    Args:
        other: Right-hand operand; it is passed through ``_ensure_node``
            before being attached to the result.

    Returns:
        A ``BinaryOpNode`` with operator ``"&"``, ``self`` on the left and
        the converted ``other`` on the right.
    """
    right_operand = _ensure_node(other)
    return BinaryOpNode("&", self, right_operand)
|
||||
|
||||
def __rand__(self, other: Any) -> BinaryOpNode:
    """Build the expression node for ``other & self`` (reflected logical AND).

    Args:
        other: Left-hand operand; it is passed through ``_ensure_node``
            before being attached to the result.

    Returns:
        A ``BinaryOpNode`` with operator ``"&"``, the converted ``other`` on
        the left and ``self`` on the right.
    """
    left_operand = _ensure_node(other)
    return BinaryOpNode("&", left_operand, self)
|
||||
|
||||
def __or__(self, other: Any) -> BinaryOpNode:
    """Build the expression node for ``self | other`` (logical OR).

    Args:
        other: Right-hand operand; it is passed through ``_ensure_node``
            before being attached to the result.

    Returns:
        A ``BinaryOpNode`` with operator ``"|"``, ``self`` on the left and
        the converted ``other`` on the right.
    """
    right_operand = _ensure_node(other)
    return BinaryOpNode("|", self, right_operand)
|
||||
|
||||
def __ror__(self, other: Any) -> BinaryOpNode:
    """Build the expression node for ``other | self`` (reflected logical OR).

    Args:
        other: Left-hand operand; it is passed through ``_ensure_node``
            before being attached to the result.

    Returns:
        A ``BinaryOpNode`` with operator ``"|"``, the converted ``other`` on
        the left and ``self`` on the right.
    """
    left_operand = _ensure_node(other)
    return BinaryOpNode("|", left_operand, self)
|
||||
|
||||
# ==================== 抽象方法 ====================
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
- 主表达式: cs_rank(__tmp_0)
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from src.factors.dsl import (
|
||||
@@ -26,30 +27,45 @@ from src.factors.dsl import (
|
||||
)
|
||||
|
||||
|
||||
# 模块级全局计数器,用于生成唯一的临时因子名称
|
||||
_global_counter: int = 0
|
||||
_counter_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_next_counter() -> int:
    """Advance the module-level counter by one and return the new value.

    The read-increment-write sequence runs while holding ``_counter_lock``.

    Returns:
        The value of ``_global_counter`` after the increment.
    """
    global _global_counter
    with _counter_lock:
        next_value = _global_counter + 1
        _global_counter = next_value
    # next_value is a local snapshot taken under the lock, so returning it
    # after the lock is released yields the same number.
    return next_value
|
||||
|
||||
|
||||
class ExpressionFlattener:
|
||||
"""表达式拍平器。
|
||||
|
||||
遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。
|
||||
|
||||
Attributes:
|
||||
_counter: 临时因子名称计数器,用于生成唯一名称
|
||||
_extracted_nodes: 存储已提取的临时因子字典
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化拍平器。"""
|
||||
self._counter: int = 0
|
||||
self._extracted_nodes: Dict[str, Node] = {}
|
||||
|
||||
def _generate_temp_name(self) -> str:
|
||||
"""生成唯一的临时因子名称。
|
||||
|
||||
使用模块级全局计数器确保跨因子注册时的唯一性。
|
||||
|
||||
Returns:
|
||||
格式为 "__tmp_X" 的临时名称,其中 X 是递增数字
|
||||
格式为 "__tmp_X" 的临时名称,其中 X 是全局递增数字
|
||||
"""
|
||||
name = f"__tmp_{self._counter}"
|
||||
self._counter += 1
|
||||
return name
|
||||
return f"__tmp_{_get_next_counter()}"
|
||||
|
||||
def _is_window_function(self, func_name: str) -> bool:
|
||||
"""判断是否为窗口函数。
|
||||
@@ -85,8 +101,7 @@ class ExpressionFlattener:
|
||||
>>> # flat_expr = cs_rank(__tmp_0)
|
||||
>>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
|
||||
"""
|
||||
# 重置状态
|
||||
self._counter = 0
|
||||
# 重置状态(只重置提取的节点字典,计数器保持全局递增)
|
||||
self._extracted_nodes = {}
|
||||
|
||||
# 从根节点开始遍历,初始状态为不在窗口函数内部
|
||||
|
||||
@@ -35,6 +35,8 @@ BIN_OP_MAP: Dict[type, str] = {
|
||||
ast.Pow: "**",
|
||||
ast.FloorDiv: "//",
|
||||
ast.Mod: "%",
|
||||
ast.BitAnd: "&",
|
||||
ast.BitOr: "|",
|
||||
}
|
||||
|
||||
UNARY_OP_MAP: Dict[type, str] = {
|
||||
|
||||
@@ -178,6 +178,8 @@ class PolarsTranslator:
|
||||
"<=": lambda l, r: l.le(r),
|
||||
">": lambda l, r: l.gt(r),
|
||||
">=": lambda l, r: l.ge(r),
|
||||
"&": lambda l, r: l & r,
|
||||
"|": lambda l, r: l | r,
|
||||
}
|
||||
|
||||
if node.op not in op_map:
|
||||
|
||||
144
tests/test_bugfixes.py
Normal file
144
tests/test_bugfixes.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""测试 Bug 修复:
|
||||
1. 临时因子命名冲突修复验证
|
||||
2. 逻辑运算符支持验证
|
||||
"""
|
||||
|
||||
import sys
from pathlib import Path

# Make the project root importable (this file lives in <root>/tests/), so
# that ``src.factors`` resolves on any machine instead of relying on one
# developer's absolute Windows path.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from src.factors.dsl import Symbol, BinaryOpNode
|
||||
from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression
|
||||
|
||||
|
||||
def test_temp_name_uniqueness():
    """Check that two ExpressionFlattener instances never share a temp name.

    Each flattener flattens one nested expression; the temporary factor
    names they report must be disjoint.

    Returns:
        True when no temporary name is shared. The value is kept so the
        ``__main__`` runner can print its summary.

    Raises:
        AssertionError: if both flatteners produced a common temporary
            name. Raising (rather than returning False) is what makes the
            check fail under pytest, which ignores a test's return value.
    """
    from src.factors.dsl import FunctionNode

    print("测试 1: 临时因子命名冲突修复")
    print("-" * 50)

    close = Symbol("close")
    open_price = Symbol("open")

    # Two independent flattener instances.
    flattener1 = ExpressionFlattener()
    flattener2 = ExpressionFlattener()

    # Factor A: cs_rank(ts_delay(close, 1))
    expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
    _, temps_a = flattener1.flatten(expr_a)

    # Factor B: cs_mean(ts_delay(open, 2))
    expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2))
    _, temps_b = flattener2.flatten(expr_b)

    temp_names_a = set(temps_a.keys())
    temp_names_b = set(temps_b.keys())

    print(f"因子 A 临时名称: {temp_names_a}")
    print(f"因子 B 临时名称: {temp_names_b}")

    # NOTE(review): if either flattener extracts nothing, the intersection is
    # trivially empty and this check passes vacuously - confirm both
    # expressions are expected to yield at least one temporary factor.
    common_names = temp_names_a & temp_names_b
    if common_names:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise AssertionError(f"[失败] 发现命名冲突: {common_names}")

    print("[通过] 临时因子名称全局唯一,无冲突")
    return True
|
||||
|
||||
|
||||
def test_logical_operators():
    """Check ``&`` / ``|`` support in the DSL, the parser and the translator.

    The three layers are exercised in order: operator overloading on DSL
    nodes, parsing of string formulas, and translation of the parsed
    expressions with ``PolarsTranslator``.

    Returns:
        True when every step succeeds. The value is kept so the
        ``__main__`` runner can print its summary.

    Raises:
        AssertionError: if any layer rejects the operators. Parser and
            translator errors are re-raised as AssertionError (chained to
            the original exception) instead of being swallowed into a
            ``return False``, which pytest would have counted as a pass.
    """
    print("\n测试 2: 逻辑运算符支持")
    print("-" * 50)

    # ---- DSL layer ----
    close = Symbol("close")
    open_price = Symbol("open")

    # Parentheses are required: & and | bind tighter than comparisons.
    and_expr = (close > open_price) & (close > 0)
    print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}")
    assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode"
    assert and_expr.op == "&", "运算符应为 &"
    print("[通过] DSL 层支持 & 运算符")

    or_expr = (close < open_price) | (close < 0)
    print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}")
    assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode"
    assert or_expr.op == "|", "运算符应为 |"
    print("[通过] DSL 层支持 | 运算符")

    # ---- String parsing ----
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    parser = FormulaParser(FunctionRegistry())

    try:
        parsed_and = parser.parse("(close > open) & (volume > 0)")
    except Exception as exc:
        raise AssertionError(f"[失败] Parser 解析 & 失败: {exc}") from exc
    print(f"解析器支持 & 运算符: {parsed_and}")
    print("[通过] Parser 支持 & 运算符")

    try:
        parsed_or = parser.parse("(close < open) | (volume < 0)")
    except Exception as exc:
        raise AssertionError(f"[失败] Parser 解析 | 失败: {exc}") from exc
    print(f"解析器支持 | 运算符: {parsed_or}")
    print("[通过] Parser 支持 | 运算符")

    # ---- Translation to Polars ----
    from src.factors.translator import PolarsTranslator

    translator = PolarsTranslator()

    try:
        polars_and = translator.translate(parsed_and)
    except Exception as exc:
        raise AssertionError(f"[失败] Translator 翻译 & 失败: {exc}") from exc
    print(f"Polars 表达式 (&): {polars_and}")
    print("[通过] Translator 支持 & 运算符")

    try:
        polars_or = translator.translate(parsed_or)
    except Exception as exc:
        raise AssertionError(f"[失败] Translator 翻译 | 失败: {exc}") from exc
    print(f"Polars 表达式 (|): {polars_or}")
    print("[通过] Translator 支持 | 运算符")

    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run both checks, print a summary, and exit with
    # status 0 only when both reported success.
    banner = "=" * 60

    print(banner)
    print("Bug 修复验证测试")
    print(banner)

    # Both checks are invoked unconditionally, in this order.
    test1_passed = test_temp_name_uniqueness()
    test2_passed = test_logical_operators()

    print("\n" + banner)
    print("测试结果汇总")
    print(banner)
    print(f"临时因子命名冲突修复: {'[通过]' if test1_passed else '[失败]'}")
    print(f"逻辑运算符支持: {'[通过]' if test2_passed else '[失败]'}")

    if not (test1_passed and test2_passed):
        print("\n存在失败的测试!")
        sys.exit(1)

    print("\n所有测试通过!")
    sys.exit(0)
|
||||
Reference in New Issue
Block a user