diff --git a/AGENTS.md b/AGENTS.md index 4a0d80c..f1c7550 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -937,6 +937,37 @@ LSP 报错:Syntax error on line 45 python tests/test_sync.py ``` +### 因子编写规范 + +**⚠️ 强制要求:编写因子时,优先使用字符串表达式而非 DSL 表达式。** + +1. **推荐方式(字符串表达式)** + ```python + from src.factors import FactorEngine + + engine = FactorEngine() + engine.add_factor("ma20", "ts_mean(close, 20)") + engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))") + ``` + +2. **不推荐方式(DSL 表达式)** + ```python + from src.factors.api import close, ts_mean, cs_rank + + engine.register("ma20", ts_mean(close, 20)) # 不推荐 + ``` + +3. **原因说明** + - 字符串表达式更易于序列化存储到因子元数据(`factors.jsonl`) + - 字符串表达式支持从元数据动态加载和复用 + - 字符串表达式便于在配置文件中定义和维护 + - 与 `src/scripts/register_factors.py` 批量注册脚本兼容 + +4. **使用场景** + - ✅ 在 `register_factors.py` 的 `FACTORS` 列表中定义因子 + - ✅ 动态添加因子到 FactorEngine + - ✅ 从因子元数据查询并注册因子 + ### Emoji 表情禁用规则 **⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。** diff --git a/src/factors/dsl.py b/src/factors/dsl.py index 4e1e2eb..2ac03b6 100644 --- a/src/factors/dsl.py +++ b/src/factors/dsl.py @@ -115,6 +115,24 @@ class Node(ABC): """大于等于: self >= other""" return BinaryOpNode(">=", self, _ensure_node(other)) + # ==================== 位运算符重载(用于逻辑运算) ==================== + + def __and__(self, other: Any) -> BinaryOpNode: + """逻辑与: self & other""" + return BinaryOpNode("&", self, _ensure_node(other)) + + def __rand__(self, other: Any) -> BinaryOpNode: + """右逻辑与: other & self""" + return BinaryOpNode("&", _ensure_node(other), self) + + def __or__(self, other: Any) -> BinaryOpNode: + """逻辑或: self | other""" + return BinaryOpNode("|", self, _ensure_node(other)) + + def __ror__(self, other: Any) -> BinaryOpNode: + """右逻辑或: other | self""" + return BinaryOpNode("|", _ensure_node(other), self) + # ==================== 抽象方法 ==================== @abstractmethod diff --git a/src/factors/engine/ast_optimizer.py b/src/factors/engine/ast_optimizer.py index f7bf737..5910c6b 100644 --- a/src/factors/engine/ast_optimizer.py +++ b/src/factors/engine/ast_optimizer.py @@ -14,6 +14,7 @@ - 主表达式: cs_rank(__tmp_0) """ +import threading from typing import Dict, Tuple from src.factors.dsl import ( @@ -26,30 +27,45 @@ from src.factors.dsl import ( ) +# 模块级全局计数器,用于生成唯一的临时因子名称 +_global_counter: int = 0 +_counter_lock = threading.Lock() + + +def _get_next_counter() -> int: + """获取下一个全局计数器值。 + + Returns: + 递增后的全局计数器值 + """ + global _global_counter + with _counter_lock: + _global_counter += 1 + return _global_counter + + class ExpressionFlattener: """表达式拍平器。 遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。 Attributes: - _counter: 临时因子名称计数器,用于生成唯一名称 _extracted_nodes: 存储已提取的临时因子字典 """ def __init__(self) -> None: """初始化拍平器。""" - self._counter: int = 0 self._extracted_nodes: Dict[str, Node] = {} def _generate_temp_name(self) -> str: """生成唯一的临时因子名称。 + 使用模块级全局计数器确保跨因子注册时的唯一性。 + Returns: - 格式为 "__tmp_X" 的临时名称,其中 X 是递增数字 + 格式为 "__tmp_X" 的临时名称,其中 X 是全局递增数字 """ - name = f"__tmp_{self._counter}" - self._counter += 1 - return name + return f"__tmp_{_get_next_counter()}" def _is_window_function(self, func_name: str) -> bool: """判断是否为窗口函数。 @@ -85,8 +101,7 @@ class ExpressionFlattener: >>> # flat_expr = cs_rank(__tmp_0) >>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)} """ - # 重置状态 - self._counter = 0 + # 重置状态(只重置提取的节点字典,计数器保持全局递增) self._extracted_nodes = {} # 从根节点开始遍历,初始状态为不在窗口函数内部 diff --git a/src/factors/parser.py b/src/factors/parser.py index 7d8f844..14393c5 100644 --- a/src/factors/parser.py +++ b/src/factors/parser.py @@ -35,6 +35,8 @@ BIN_OP_MAP: Dict[type, str] = { ast.Pow: "**", ast.FloorDiv: "//", ast.Mod: "%", + ast.BitAnd: "&", + ast.BitOr: "|", } UNARY_OP_MAP: Dict[type, str] = { diff --git a/src/factors/translator.py b/src/factors/translator.py index 43258bf..64686f8 100644 --- a/src/factors/translator.py +++ b/src/factors/translator.py @@ -178,6 +178,8 @@ class PolarsTranslator: "<=": lambda l, r: l.le(r), ">": lambda l, r: l.gt(r), ">=": lambda l, r: l.ge(r), + "&": lambda l, r: l & r, + "|": lambda l, r: l | r, } if node.op not in op_map: diff --git a/tests/test_bugfixes.py b/tests/test_bugfixes.py new file mode 100644 index 0000000..f94597e --- /dev/null +++ b/tests/test_bugfixes.py @@ -0,0 +1,144 @@ +"""测试 Bug 修复: +1. 临时因子命名冲突修复验证 +2. 逻辑运算符支持验证 +""" + +import sys + +sys.path.insert(0, "D:/PyProject/ProStock") + +from src.factors.dsl import Symbol, BinaryOpNode +from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression + + +def test_temp_name_uniqueness(): + """测试:临时因子名称全局唯一性。""" + print("测试 1: 临时因子命名冲突修复") + print("-" * 50) + + close = Symbol("close") + open_price = Symbol("open") + + # 创建两个表达式拍平器实例 + flattener1 = ExpressionFlattener() + flattener2 = ExpressionFlattener() + + # 模拟因子 A: cs_rank(ts_delay(close, 1)) + from src.factors.dsl import FunctionNode + + expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) + flat_a, temps_a = flattener1.flatten(expr_a) + + # 模拟因子 B: cs_mean(ts_delay(open, 2)) + expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2)) + flat_b, temps_b = flattener2.flatten(expr_b) + + # 验证临时名称不冲突 + temp_names_a = set(temps_a.keys()) + temp_names_b = set(temps_b.keys()) + + print(f"因子 A 临时名称: {temp_names_a}") + print(f"因子 B 临时名称: {temp_names_b}") + + # 检查是否有名称冲突 + common_names = temp_names_a & temp_names_b + if common_names: + print(f"[失败] 发现命名冲突: {common_names}") + return False + + print("[通过] 临时因子名称全局唯一,无冲突") + return True + + +def test_logical_operators(): + """测试:逻辑运算符支持。""" + print("\n测试 2: 逻辑运算符支持") + print("-" * 50) + + # 测试 DSL 层 + close = Symbol("close") + open_price = Symbol("open") + + # 测试 & 运算符(注意 Python 运算符优先级,需要用括号) + and_expr = (close > open_price) & (close > 0) + print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}") + assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode" + assert and_expr.op == "&", "运算符应为 &" + print("[通过] DSL 层支持 & 运算符") + + # 测试 | 运算符(注意 Python 运算符优先级,需要用括号) + or_expr = (close < open_price) | (close < 0) + print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}") + assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode" + assert or_expr.op == "|", "运算符应为 |" + print("[通过] DSL 层支持 | 运算符") + + # 测试字符串解析 + from src.factors.parser import FormulaParser + from src.factors.registry import FunctionRegistry + + parser = FormulaParser(FunctionRegistry()) + + # 解析包含 & 的表达式 + try: + parsed_and = parser.parse("(close > open) & (volume > 0)") + print(f"解析器支持 & 运算符: {parsed_and}") + print("[通过] Parser 支持 & 运算符") + except Exception as e: + print(f"[失败] Parser 解析 & 失败: {e}") + return False + + # 解析包含 | 的表达式 + try: + parsed_or = parser.parse("(close < open) | (volume < 0)") + print(f"解析器支持 | 运算符: {parsed_or}") + print("[通过] Parser 支持 | 运算符") + except Exception as e: + print(f"[失败] Parser 解析 | 失败: {e}") + return False + + # 测试翻译到 Polars + from src.factors.translator import PolarsTranslator + import polars as pl + + translator = PolarsTranslator() + + try: + polars_and = translator.translate(parsed_and) + print(f"Polars 表达式 (&): {polars_and}") + print("[通过] Translator 支持 & 运算符") + except Exception as e: + print(f"[失败] Translator 翻译 & 失败: {e}") + return False + + try: + polars_or = translator.translate(parsed_or) + print(f"Polars 表达式 (|): {polars_or}") + print("[通过] Translator 支持 | 运算符") + except Exception as e: + print(f"[失败] Translator 翻译 | 失败: {e}") + return False + + return True + + +if __name__ == "__main__": + print("=" * 60) + print("Bug 修复验证测试") + print("=" * 60) + + test1_passed = test_temp_name_uniqueness() + test2_passed = test_logical_operators() + + print("\n" + "=" * 60) + print("测试结果汇总") + print("=" * 60) + print(f"临时因子命名冲突修复: {'[通过]' if test1_passed else '[失败]'}") + print(f"逻辑运算符支持: {'[通过]' if test2_passed else '[失败]'}") + + if test1_passed and test2_passed: + print("\n所有测试通过!") + sys.exit(0) + else: + print("\n存在失败的测试!") + sys.exit(1)