diff --git a/src/factors/compiler.py b/src/factors/compiler.py index 89f494a..0692a21 100644 --- a/src/factors/compiler.py +++ b/src/factors/compiler.py @@ -3,7 +3,7 @@ 本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。 """ -from typing import Set +from typing import Set, Optional from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode @@ -24,9 +24,14 @@ class DependencyExtractor: {'close', 'pe_ratio'} """ - def __init__(self) -> None: - """初始化依赖提取器。""" + def __init__(self, ignore_symbols: Optional[Set[str]] = None) -> None: + """初始化依赖提取器。 + + Args: + ignore_symbols: 需要忽略的符号集合(如已注册的因子名) + """ self.dependencies: Set[str] = set() + self.ignore_symbols: Set[str] = ignore_symbols or set() def visit(self, node: Node) -> None: """访问节点,根据节点类型分发到具体处理方法。 @@ -47,10 +52,14 @@ class DependencyExtractor: def _visit_symbol(self, node: Symbol) -> None: """访问 Symbol 节点,提取符号名称。 + 排除临时因子(以 __tmp_ 开头的符号)和已在免疫名单中的因子。 + Args: node: 符号节点 """ - self.dependencies.add(node.name) + # 排除临时因子引用 和 已在免疫名单中的因子 + if not node.name.startswith("__tmp_") and node.name not in self.ignore_symbols: + self.dependencies.add(node.name) def _visit_binary_op(self, node: BinaryOpNode) -> None: """访问 BinaryOpNode 节点,递归遍历左右子节点。 @@ -92,13 +101,16 @@ class DependencyExtractor: return self.dependencies.copy() @classmethod - def extract_dependencies(cls, node: Node) -> Set[str]: + def extract_dependencies( + cls, node: Node, ignore_symbols: Optional[Set[str]] = None + ) -> Set[str]: """类方法 - 从 AST 节点中提取所有依赖的符号名称。 这是一个便捷方法,无需手动实例化 DependencyExtractor。 Args: node: 表达式树的根节点 + ignore_symbols: 需要忽略的符号集合(如已注册的因子名) Returns: 依赖的符号名称集合 @@ -112,17 +124,20 @@ class DependencyExtractor: >>> print(deps) {'close', 'open'} """ - extractor = cls() + extractor = cls(ignore_symbols=ignore_symbols) return extractor.extract(node) -def extract_dependencies(node: Node) -> Set[str]: +def extract_dependencies( + node: Node, ignore_symbols: Optional[Set[str]] = None +) -> Set[str]: """单例方法 - 从 AST 节点中提取所有依赖的符号名称。 这是 DependencyExtractor.extract_dependencies 的便捷包装函数。 Args: node: 表达式树的根节点 + ignore_symbols: 需要忽略的符号集合(如已注册的因子名) Returns: 依赖的符号名称集合 @@ -136,7 +151,7 @@ def extract_dependencies(node: Node) -> Set[str]: >>> print(deps) {'close', 'pe_ratio'} """ - return DependencyExtractor.extract_dependencies(node) + return DependencyExtractor.extract_dependencies(node, ignore_symbols=ignore_symbols) if __name__ == "__main__": diff --git a/src/factors/engine/ast_optimizer.py b/src/factors/engine/ast_optimizer.py new file mode 100644 index 0000000..f7bf737 --- /dev/null +++ b/src/factors/engine/ast_optimizer.py @@ -0,0 +1,223 @@ +"""AST 优化器 - 表达式拍平。 + +本模块实现将嵌套的窗口函数表达式自动提取为中间临时因子, +解决多维窗口函数(over)嵌套导致计算为空的问题。 + +核心思想: + 通过 AST 变换,将嵌套在窗口函数内的窗口函数表达式提取出来, + 作为独立的临时因子先行计算,然后主表达式引用这些临时因子。 + +示例: + 原始表达式: cs_rank(ts_delay(close, 1)) + 拍平后: + - 临时因子: __tmp_0 = ts_delay(close, 1) + - 主表达式: cs_rank(__tmp_0) +""" + +from typing import Dict, Tuple + +from src.factors.dsl import ( + BinaryOpNode, + Constant, + FunctionNode, + Node, + Symbol, + UnaryOpNode, +) + + +class ExpressionFlattener: + """表达式拍平器。 + + 遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。 + + Attributes: + _counter: 临时因子名称计数器,用于生成唯一名称 + _extracted_nodes: 存储已提取的临时因子字典 + """ + + def __init__(self) -> None: + """初始化拍平器。""" + self._counter: int = 0 + self._extracted_nodes: Dict[str, Node] = {} + + def _generate_temp_name(self) -> str: + """生成唯一的临时因子名称。 + + Returns: + 格式为 "__tmp_X" 的临时名称,其中 X 是递增数字 + """ + name = f"__tmp_{self._counter}" + self._counter += 1 + return name + + def _is_window_function(self, func_name: str) -> bool: + """判断是否为窗口函数。 + + 窗口函数以 "ts_"(时序)或 "cs_"(截面)开头。 + + Args: + func_name: 函数名称 + + Returns: + 是否是窗口函数 + """ + return func_name.startswith("ts_") or func_name.startswith("cs_") + + def flatten(self, node: Node) -> Tuple[Node, Dict[str, Node]]: + """拍平表达式。 + + 遍历 AST,将嵌套的窗口函数提取为临时因子。 + + Args: + node: 原始表达式根节点 + + Returns: + Tuple[拍平后的主表达式节点, 临时因子字典] + 临时因子字典: {临时名称 -> 被提取的节点} + + Example: + >>> flattener = ExpressionFlattener() + >>> from src.factors.dsl import Symbol, FunctionNode + >>> close = Symbol("close") + >>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) + >>> flat_expr, tmp_factors = flattener.flatten(expr) + >>> # flat_expr = cs_rank(__tmp_0) + >>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)} + """ + # 重置状态 + self._counter = 0 + self._extracted_nodes = {} + + # 从根节点开始遍历,初始状态为不在窗口函数内部 + new_node = self._flatten_recursive(node, inside_window=False) + + return new_node, self._extracted_nodes.copy() + + def _flatten_recursive(self, node: Node, inside_window: bool) -> Node: + """递归拍平节点。 + + Args: + node: 当前处理的节点 + inside_window: 当前是否处于窗口函数内部 + + Returns: + 处理后的节点(可能是原节点或替换为 Symbol) + """ + # Symbol 和 Constant 是叶子节点,直接返回 + if isinstance(node, Symbol): + return node + + if isinstance(node, Constant): + return node + + # 处理二元运算节点 + if isinstance(node, BinaryOpNode): + return self._flatten_binary_op(node, inside_window) + + # 处理一元运算节点 + if isinstance(node, UnaryOpNode): + return self._flatten_unary_op(node, inside_window) + + # 处理函数调用节点 + if isinstance(node, FunctionNode): + return self._flatten_function(node, inside_window) + + # 未知节点类型,直接返回 + return node + + def _flatten_binary_op(self, node: BinaryOpNode, inside_window: bool) -> Node: + """拍平二元运算节点。 + + Args: + node: 二元运算节点 + inside_window: 当前是否处于窗口函数内部 + + Returns: + 处理后的节点 + """ + # 递归处理左右子节点 + new_left = self._flatten_recursive(node.left, inside_window) + new_right = self._flatten_recursive(node.right, inside_window) + + # 如果子节点没有变化,返回原节点 + if new_left is node.left and new_right is node.right: + return node + + # 创建新的二元运算节点 + return BinaryOpNode(node.op, new_left, new_right) + + def _flatten_unary_op(self, node: UnaryOpNode, inside_window: bool) -> Node: + """拍平一元运算节点。 + + Args: + node: 一元运算节点 + inside_window: 当前是否处于窗口函数内部 + + Returns: + 处理后的节点 + """ + # 递归处理操作数 + new_operand = self._flatten_recursive(node.operand, inside_window) + + # 如果操作数没有变化,返回原节点 + if new_operand is node.operand: + return node + + # 创建新的一元运算节点 + return UnaryOpNode(node.op, new_operand) + + def _flatten_function(self, node: FunctionNode, inside_window: bool) -> Node: + """拍平函数调用节点。 + + 修正为后序遍历(Bottom-Up):先递归拍平参数,再决定是否提取当前节点。 + 确保深层嵌套(如 3层以上)也能被彻底逐层拆解。 + + Args: + node: 函数调用节点 + inside_window: 当前是否处于窗口函数内部 + + Returns: + 处理后的节点 + """ + is_window = self._is_window_function(node.func_name) + next_inside_window = inside_window or is_window + + # 1. 优先递归处理所有参数 + new_args = [] + has_change = False + for arg in node.args: + new_arg = self._flatten_recursive(arg, next_inside_window) + new_args.append(new_arg) + if new_arg is not arg: + has_change = True + + # 2. 只有当参数发生变化时,才创建新的当前节点 + current_node = FunctionNode(node.func_name, *new_args) if has_change else node + + # 3. 判断是否需要提取(此时子节点肯定已经被彻底拍平了) + if inside_window and is_window: + temp_name = self._generate_temp_name() + self._extracted_nodes[temp_name] = current_node + return Symbol(temp_name) + + return current_node + + +def flatten_expression(node: Node) -> Tuple[Node, Dict[str, Node]]: + """便捷函数 - 拍平表达式。 + + Args: + node: 表达式树的根节点 + + Returns: + Tuple[拍平后的主表达式节点, 临时因子字典] + + Example: + >>> from src.factors.dsl import Symbol, FunctionNode + >>> close = Symbol("close") + >>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) + >>> flat_expr, tmp_factors = flatten_expression(expr) + """ + flattener = ExpressionFlattener() + return flattener.flatten(node) diff --git a/src/factors/engine/factor_engine.py b/src/factors/engine/factor_engine.py index f7452e4..184ac46 100644 --- a/src/factors/engine/factor_engine.py +++ b/src/factors/engine/factor_engine.py @@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan from src.factors.engine.data_router import DataRouter from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.compute_engine import ComputeEngine +from src.factors.engine.ast_optimizer import ExpressionFlattener class FactorEngine: @@ -92,13 +93,68 @@ class FactorEngine: self._metadata = FactorManager() + def _register_internal( + self, + name: str, + expression: Node, + data_specs: Optional[List[DataSpec]] = None, + ) -> "FactorEngine": + """内部注册方法,直接注册因子表达式。 + + Args: + name: 因子名称 + expression: DSL 表达式 + data_specs: 数据规格,None 时自动推导 + + Returns: + self,支持链式调用 + """ + # 检测因子依赖(在注册当前因子之前检查其他已注册因子) + factor_deps = self._find_factor_dependencies(expression) + + # 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段) + known_factors = set(self.registered_expressions.keys()) + + self.registered_expressions[name] = expression + + # 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段 + plan = self.planner.create_plan( + expression=expression, + output_name=name, + data_specs=data_specs, + ignore_dependencies=known_factors, + ) + + # 添加因子依赖信息 + plan.factor_dependencies = factor_deps + + # 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格 + if not plan.data_specs and factor_deps: + merged_specs: List[DataSpec] = [] + for dep_name in factor_deps: + if dep_name in self._plans: + merged_specs.extend(self._plans[dep_name].data_specs) + + # 去重(基于表名) + seen_tables: set = set() + unique_specs: List[DataSpec] = [] + for spec in merged_specs: + if spec.table not in seen_tables: + seen_tables.add(spec.table) + unique_specs.append(spec) + plan.data_specs = unique_specs + + self._plans[name] = plan + + return self + def register( self, name: str, expression: Node, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": - """注册因子表达式。 + """注册因子表达式(自动处理嵌套窗口函数)。 Args: name: 因子名称 @@ -113,22 +169,16 @@ class FactorEngine: >>> engine = FactorEngine() >>> engine.register("ma20", ts_mean(close, 20)) """ - # 检测因子依赖(在注册当前因子之前检查其他已注册因子) - factor_deps = self._find_factor_dependencies(expression) + # 使用 AST 优化器拍平嵌套窗口函数 + flattener = ExpressionFlattener() + flat_expression, tmp_factors = flattener.flatten(expression) - self.registered_expressions[name] = expression + # 先注册所有临时因子(自动推导数据规格) + for tmp_name, tmp_node in tmp_factors.items(): + self._register_internal(tmp_name, tmp_node, data_specs=None) - # 预创建执行计划 - plan = self.planner.create_plan( - expression=expression, - output_name=name, - data_specs=data_specs, - ) - - # 添加因子依赖信息 - plan.factor_dependencies = factor_deps - - self._plans[name] = plan + # 最后注册主因子 + self._register_internal(name, flat_expression, data_specs) return self @@ -174,7 +224,7 @@ class FactorEngine: # 解析表达式为 Node node = self._parser.parse(dsl_expr) - # 委托给 register 方法 + # 委托给 register 方法(register 会处理嵌套窗口函数拍平) return self.register(name, node, data_specs) def add_factor( @@ -272,21 +322,32 @@ class FactorEngine: if isinstance(factor_names, str): factor_names = [factor_names] - # 1. 获取执行计划 + # 1. 收集所有需要的因子(包括临时因子依赖) + all_factor_names = self._collect_all_dependencies(factor_names) + + # 2. 获取执行计划 plans = [] - for name in factor_names: + for name in all_factor_names: if name not in self._plans: raise ValueError(f"因子未注册: {name}") plans.append(self._plans[name]) - # 2. 合并数据规格并获取数据 + # 3. 合并数据规格并获取数据 all_specs = [] for plan in plans: all_specs.extend(plan.data_specs) - # 3. 从路由器获取核心宽表 + # 去重数据规格(基于表名) + seen_tables: set = set() + unique_specs: List[DataSpec] = [] + for spec in all_specs: + if spec.table not in seen_tables: + seen_tables.add(spec.table) + unique_specs.append(spec) + + # 4. 从路由器获取核心宽表 core_data = self.router.fetch_data( - data_specs=all_specs, + data_specs=unique_specs, start_date=start_date, end_date=end_date, stock_codes=stock_codes, @@ -295,14 +356,14 @@ class FactorEngine: if len(core_data) == 0: raise ValueError("未获取到任何数据,请检查日期范围和股票代码") - # 4. 按依赖顺序执行计算 - if len(plans) == 1: - result = self.compute_engine.execute(plans[0], core_data) - else: - # 使用依赖感知的方式执行 - result = self._execute_with_dependencies(factor_names, core_data) + # 5. 按依赖顺序执行计算(包含临时因子) + result = self._execute_with_dependencies(all_factor_names, core_data) - return result + # 6. 清理内存宽表,过滤掉临时因子列(__tmp_X) + # 保留所有非临时因子列(包括原始数据列和用户请求的因子列) + cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")] + + return result.select(cols_to_keep) def list_registered(self) -> List[str]: """获取已注册的因子列表。 @@ -501,10 +562,32 @@ class FactorEngine: return False - def _find_factor_dependencies(self, expression: Node) -> Set[str]: - """查找表达式依赖的其他因子。 + def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]: + """收集所有因子及其依赖(包括用户定义的因子和临时因子)。""" + collected: Set[str] = set() + result: List[str] = [] - 遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 + def collect_recursive(name: str): + if name in collected: + return + collected.add(name) + + # 获取执行计划并递归收集强依赖 + plan = self._plans.get(name) + if plan: + for dep_name in plan.factor_dependencies: + collect_recursive(dep_name) + + # 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序) + result.append(name) + + for name in factor_names: + collect_recursive(name) + + return result + + def _find_factor_dependencies(self, expression: Node) -> Set[str]: + """查找表达式依赖的其他因子(包括临时因子和用户因子引用)。 Args: expression: 待检查的表达式 @@ -514,13 +597,20 @@ class FactorEngine: """ deps: Set[str] = set() - # 检查表达式本身是否等于某个已注册因子 + # 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子) + if ( + isinstance(expression, Symbol) + and expression.name in self.registered_expressions + ): + deps.add(expression.name) + + # 2. 检查表达式本身是否等于某个已注册因子的完整 AST for name, registered_expr in self.registered_expressions.items(): if self._expressions_equal(expression, registered_expr): deps.add(name) break - # 递归检查子节点 + # 3. 递归检查子节点 if isinstance(expression, BinaryOpNode): deps.update(self._find_factor_dependencies(expression.left)) deps.update(self._find_factor_dependencies(expression.right)) diff --git a/src/factors/engine/planner.py b/src/factors/engine/planner.py index 2139f62..5b422e4 100644 --- a/src/factors/engine/planner.py +++ b/src/factors/engine/planner.py @@ -39,6 +39,7 @@ class ExecutionPlanner: expression: Node, output_name: str = "factor", data_specs: Optional[List[DataSpec]] = None, + ignore_dependencies: Optional[Set[str]] = None, ) -> ExecutionPlan: """从表达式创建执行计划。 @@ -46,12 +47,15 @@ class ExecutionPlanner: expression: DSL 表达式节点 output_name: 输出因子名称 data_specs: 预定义的数据规格,None 时自动推导 + ignore_dependencies: 需要忽略的依赖符号集合(如已注册因子名) Returns: 执行计划对象 """ - # 1. 提取依赖 - dependencies = self.compiler.extract_dependencies(expression) + # 1. 提取依赖时传入要忽略的符号 + dependencies = self.compiler.extract_dependencies( + expression, ignore_symbols=ignore_dependencies + ) # 2. 翻译为 Polars 表达式 polars_expr = self.translator.translate(expression) diff --git a/tests/test_ast_optimizer.py b/tests/test_ast_optimizer.py new file mode 100644 index 0000000..fe54169 --- /dev/null +++ b/tests/test_ast_optimizer.py @@ -0,0 +1,367 @@ +"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。 + +测试因子: cs_rank(ts_delay(close, 1)) +这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。 +""" + +import pytest +import polars as pl +import numpy as np +from datetime import datetime, timedelta + +from src.factors.engine import FactorEngine +from src.factors.api import close, ts_delay, cs_rank +from src.factors.dsl import FunctionNode +from src.factors.engine.ast_optimizer import ExpressionFlattener + + +def create_mock_data( + start_date: str = "20240101", + end_date: str = "20240131", + n_stocks: int = 5, +) -> pl.DataFrame: + """创建模拟的日线数据。""" + start = datetime.strptime(start_date, "%Y%m%d") + end = datetime.strptime(end_date, "%Y%m%d") + + dates = [] + current = start + while current <= end: + if current.weekday() < 5: # 周一到周五 + dates.append(current.strftime("%Y%m%d")) + current += timedelta(days=1) + + stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] + np.random.seed(42) + + rows = [] + for date in dates: + for stock in stocks: + base_price = 10 + np.random.randn() * 5 + close_val = base_price + np.random.randn() * 0.5 + open_val = close_val + np.random.randn() * 0.2 + high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 + low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 + vol = int(1000000 + np.random.exponential(500000)) + + rows.append( + { + "ts_code": stock, + "trade_date": date, + "open": round(open_val, 2), + "high": round(high_val, 2), + "low": round(low_val, 2), + "close": round(close_val, 2), + "volume": vol, + } + ) + + return pl.DataFrame(rows) + + +class TestASTOptimizer: + """AST 优化器测试类。""" + + def test_flattener_basic(self): + """测试拍平器基本功能。""" + from src.factors.api import close + + flattener = ExpressionFlattener() + + # 创建嵌套表达式: cs_rank(ts_delay(close, 1)) + expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) + + flat_expr, tmp_factors = flattener.flatten(expr) + + # 验证临时因子被提取 + assert len(tmp_factors) == 1 + assert "__tmp_0" in tmp_factors + + # 验证主表达式使用了 Symbol 引用 + assert isinstance(flat_expr, FunctionNode) + assert flat_expr.func_name == "cs_rank" + # 验证第一个参数是临时因子引用(通过 name 属性检查) + assert hasattr(flat_expr.args[0], "name") + assert flat_expr.args[0].name == "__tmp_0" + + # 验证临时因子内容 + tmp_node = tmp_factors["__tmp_0"] + assert isinstance(tmp_node, FunctionNode) + assert tmp_node.func_name == "ts_delay" + + print("[PASS] 拍平器基本功能测试") + + def test_flattener_no_nested(self): + """测试非嵌套表达式不会被拍平。""" + from src.factors.api import close, ts_mean + + flattener = ExpressionFlattener() + + # 非嵌套表达式: ts_mean(close, 20) + expr = FunctionNode("ts_mean", close, 20) + + flat_expr, tmp_factors = flattener.flatten(expr) + + # 验证没有临时因子被提取 + assert len(tmp_factors) == 0 + + # 验证表达式保持不变 + assert isinstance(flat_expr, FunctionNode) + assert flat_expr.func_name == "ts_mean" + + print("[PASS] 非嵌套表达式测试") + + def test_flattener_deeply_nested(self): + """测试多层嵌套表达式拍平。""" + from src.factors.api import close, ts_mean + + flattener = ExpressionFlattener() + + # 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5)) + expr = FunctionNode( + "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5) + ) + + flat_expr, tmp_factors = flattener.flatten(expr) + + # 验证提取了两个临时因子(修复后正确行为) + # ts_delay(close, 1) 被提取为 __tmp_0 + # ts_mean(__tmp_0, 5) 被提取为 __tmp_1 + assert len(tmp_factors) == 2 + assert "__tmp_0" in tmp_factors + assert "__tmp_1" in tmp_factors + + # 验证 __tmp_0 内容是 ts_delay(close, 1) + tmp0_node = tmp_factors["__tmp_0"] + assert isinstance(tmp0_node, FunctionNode) + assert tmp0_node.func_name == "ts_delay" + + # 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5) + tmp1_node = tmp_factors["__tmp_1"] + assert isinstance(tmp1_node, FunctionNode) + assert tmp1_node.func_name == "ts_mean" + from src.factors.dsl import Symbol + + assert isinstance(tmp1_node.args[0], Symbol) + assert tmp1_node.args[0].name == "__tmp_0" + + # 验证主表达式引用 __tmp_1 + assert isinstance(flat_expr, FunctionNode) + assert flat_expr.func_name == "cs_rank" + assert isinstance(flat_expr.args[0], Symbol) + assert flat_expr.args[0].name == "__tmp_1" + + print("[PASS] 多层嵌套表达式拍平测试") + + def test_nested_window_function_engine(self): + """测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。""" + print("\n" + "=" * 60) + print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))") + print("=" * 60) + + # 1. 准备数据 + mock_data = create_mock_data("20240101", "20240131", n_stocks=5) + print(f"\n生成模拟数据: {len(mock_data)} 行") + + # 2. 初始化引擎 + engine = FactorEngine(data_source={"pro_bar": mock_data}) + print("引擎初始化完成") + + # 3. 使用字符串表达式注册嵌套窗口函数 + print("\n注册因子: cs_rank(ts_delay(close, 1))") + engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))") + + # 4. 检查临时因子是否被创建 + registered_factors = engine.list_registered() + print(f"已注册因子: {registered_factors}") + + # 验证有临时因子被创建 + tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")] + assert len(tmp_factors) >= 1, "应该有临时因子被创建" + print(f"临时因子: {tmp_factors}") + + # 5. 执行计算 + print("\n执行计算...") + result = engine.compute("delayed_rank", "20240115", "20240131") + print(f"计算完成: {len(result)} 行") + + # 6. 验证结果 + assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列" + + # 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null) + non_null_values = result["delayed_rank"].drop_nulls() + if len(non_null_values) > 0: + assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间" + assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间" + + # 检查没有过多空值(考虑到开头的滞后期) + null_count = result["delayed_rank"].is_null().sum() + print(f"空值数量: {null_count}") + + # 展示部分结果 + print("\n前 10 行结果:") + sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head( + 10 + ) + print(sample.to_pandas().to_string(index=False)) + + print("\n" + "=" * 60) + print("嵌套窗口函数测试通过!") + print("=" * 60) + + def test_multiple_nested_factors(self): + """测试同时注册多个嵌套因子。""" + print("\n" + "=" * 60) + print("测试多个嵌套因子") + print("=" * 60) + + mock_data = create_mock_data("20240101", "20240131", n_stocks=5) + engine = FactorEngine(data_source={"pro_bar": mock_data}) + + # 注册多个嵌套因子(使用字符串表达式) + print("\n注册因子1: cs_rank(ts_delay(close, 1))") + engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))") + + print("注册因子2: ts_mean(cs_rank(close), 5)") + engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)") + + # 检查已注册因子 + factors = engine.list_registered() + print(f"\n已注册因子: {factors}") + + # 计算所有因子 + result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131") + + assert "rank1" in result.columns + assert "rank_mean" in result.columns + + print(f"\n结果行数: {len(result)}") + print(f"rank1 空值数: {result['rank1'].is_null().sum()}") + print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}") + + print("\n" + "=" * 60) + print("多个嵌套因子测试通过!") + print("=" * 60) + + def test_nested_vs_native_polars(self): + """对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。""" + print("\n" + "=" * 60) + print("对比测试:cs_rank(ts_delay(close, 1)) vs 原生 Polars") + print("=" * 60) + + # 1. 准备数据 + mock_data = create_mock_data("20240101", "20240131", n_stocks=5) + print(f"\n生成模拟数据: {len(mock_data)} 行") + + # 2. 使用 FactorEngine 计算嵌套因子 + engine = FactorEngine(data_source={"pro_bar": mock_data}) + print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...") + engine.register("delayed_rank", cs_rank(ts_delay(close, 1))) + engine_result = engine.compute("delayed_rank", "20240115", "20240131") + print(f"FactorEngine 结果: {len(engine_result)} 行") + + # 3. 使用原生 Polars 计算(手动分步) + print("\n使用原生 Polars 手动计算...") + # 先计算 ts_delay(close, 1) + native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns( + [pl.col("close").shift(1).over("ts_code").alias("delayed_close")] + ) + # 再计算 cs_rank + native_result = native_result.with_columns( + [ + (pl.col("delayed_close").rank() / pl.col("delayed_close").count()) + .over("trade_date") + .alias("native_delayed_rank") + ] + ) + print(f"原生 Polars 结果: {len(native_result)} 行") + + # 4. 合并结果进行对比 + comparison = engine_result.join( + native_result.select(["ts_code", "trade_date", "native_delayed_rank"]), + on=["ts_code", "trade_date"], + how="inner", + ) + + # 5. 验证数值一致性(允许微小浮点误差) + diff = comparison.with_columns( + [ + (pl.col("delayed_rank") - pl.col("native_delayed_rank")) + .abs() + .alias("diff") + ] + ) + + max_diff = diff["diff"].max() + print(f"\n最大差异: {max_diff}") + + # 过滤掉空值后比较(开头的滞后期会有空值) + non_null_diff = diff.filter(pl.col("diff").is_not_null()) + assert non_null_diff["diff"].max() < 1e-10, ( + f"数值差异过大: {non_null_diff['diff'].max()}" + ) + + print("\n" + "=" * 60) + print("数值一致性验证通过!") + print("=" * 60) + + def test_factor_reference_factor(self): + """测试因子引用另一个因子:fac2 = cs_rank(fac1)。""" + print("\n" + "=" * 60) + print("测试因子引用其他因子: fac2 = cs_rank(fac1)") + print("=" * 60) + + # 准备数据 + mock_data = create_mock_data("20240101", "20240131", n_stocks=5) + engine = FactorEngine(data_source={"pro_bar": mock_data}) + + # 1. 注册基础因子 fac1 + print("\n注册基础因子 fac1 = ts_mean(close, 5)") + from src.factors.api import ts_mean + + engine.register("fac1", ts_mean(close, 5)) + + # 2. 注册引用因子 fac2,引用 fac1 + print("注册引用因子 fac2 = cs_rank(fac1)") + engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子 + + # 3. 验证依赖关系 + registered = engine.list_registered() + print(f"\n已注册因子: {registered}") + assert "fac1" in registered + assert "fac2" in registered + + # 4. 执行计算 + print("\n执行计算...") + result = engine.compute(["fac1", "fac2"], "20240115", "20240131") + print(f"计算完成: {len(result)} 行") + + # 5. 验证结果 + assert "fac1" in result.columns, "结果中应有 fac1 列" + assert "fac2" in result.columns, "结果中应有 fac2 列" + + # fac2 是排名,应在 [0, 1] 之间 + assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间" + assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间" + + print("\n前 10 行结果:") + sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head( + 10 + ) + print(sample.to_pandas().to_string(index=False)) + + print("\n" + "=" * 60) + print("因子引用功能测试通过!") + print("=" * 60) + + +if __name__ == "__main__": + test = TestASTOptimizer() + test.test_flattener_basic() + test.test_flattener_no_nested() + test.test_flattener_deeply_nested() + test.test_nested_window_function_engine() + test.test_multiple_nested_factors() + test.test_nested_vs_native_polars() + test.test_factor_reference_factor() + print("\n所有测试通过!")