diff --git a/src/factors/engine.py b/src/factors/engine.py
index ce7ea21..827d380 100644
--- a/src/factors/engine.py
+++ b/src/factors/engine.py
@@ -55,12 +55,14 @@ class ExecutionPlan:
         polars_expr: Polars 表达式
         dependencies: 依赖的原始字段
         output_name: 输出因子名称
+        factor_dependencies: Names of other factors this plan depends on (used for step-by-step execution)
     """
 
     data_specs: List[DataSpec]
     polars_expr: pl.Expr
     dependencies: Set[str]
     output_name: str
+    factor_dependencies: Set[str] = field(default_factory=set)
 
 
 class DataRouter:
@@ -706,6 +708,13 @@ class FactorEngine:
             >>> engine = FactorEngine()
             >>> engine.register("ma20", ts_mean(close, 20))
         """
+        # Detect factor dependencies before storing the new expression, so the
+        # lookup only sees previously registered factors.
+        factor_deps = self._find_factor_dependencies(expression)
+        # Re-registering an existing name would otherwise match its own stale
+        # expression and record a self-dependency (reported later as a cycle).
+        factor_deps.discard(name)
+
         self.registered_expressions[name] = expression
 
         # 预创建执行计划
@@ -714,6 +723,10 @@ class FactorEngine:
             output_name=name,
             data_specs=data_specs,
         )
+
+        # Attach the factor dependency information to the plan.
+        plan.factor_dependencies = factor_deps
+
         self._plans[name] = plan
 
         return self
@@ -772,11 +785,12 @@ class FactorEngine:
         if len(core_data) == 0:
             raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
 
-        # 4. 执行计算
+        # 4. Execute the computation in dependency order.
         if len(plans) == 1:
             result = self.compute_engine.execute(plans[0], core_data)
         else:
-            result = self.compute_engine.execute_batch(plans, core_data)
+            # Execute in a dependency-aware way.
+            result = self._execute_with_dependencies(factor_names, core_data)
 
         return result
 
@@ -815,3 +829,253 @@ class FactorEngine:
             执行计划,未注册时返回 None
         """
         return self._plans.get(factor_name)
+
+    def _execute_with_dependencies(
+        self,
+        factor_names: List[str],
+        core_data: pl.DataFrame,
+    ) -> pl.DataFrame:
+        """Compute the requested factors in dependency order.
+
+        Supports cases such as cs_rank, which need the dependency columns to
+        already exist: each factor is executed on the previous step's output.
+
+        Args:
+            factor_names: Names of the factors to compute.
+            core_data: Core wide-table data.
+
+        Returns:
+            The data table containing the results of all factors.
+        """
+        # 1. Topological sort, so dependencies are computed first.
+        sorted_names = self._topological_sort(factor_names)
+
+        # 2. Execute in order, feeding each result into the next step.
+        result = core_data
+        for name in sorted_names:
+            plan = self._plans[name]
+
+            # Build a plan that references dependency columns already computed.
+            new_plan = self._create_optimized_plan(plan, result)
+
+            result = self.compute_engine.execute(new_plan, result)
+
+        return result
+
+    def _create_optimized_plan(
+        self,
+        plan: ExecutionPlan,
+        current_data: pl.DataFrame,
+    ) -> ExecutionPlan:
+        """Create an execution plan that reuses already computed factors.
+
+        Sub-expressions equal to a dependency factor whose column already
+        exists in ``current_data`` are replaced with a column reference.
+
+        Args:
+            plan: The original execution plan.
+            current_data: Current data, including computed dependency columns.
+
+        Returns:
+            ``plan`` itself when no dependency column is available, otherwise
+            a new execution plan built from the rewritten expression.
+        """
+        from typing import Any
+
+        from src.factors.dsl import Symbol
+
+        # Dependency factors whose column is already present in the data.
+        deps_available = plan.factor_dependencies & set(current_data.columns)
+
+        if not deps_available:
+            # Nothing to reuse: return the original plan unchanged.
+            return plan
+
+        original_expr = self.registered_expressions[plan.output_name]
+
+        # Resolve the dependency expressions once rather than on every node
+        # visit; sorting makes the match order deterministic.
+        available_exprs = [
+            (dep_name, self.registered_expressions[dep_name])
+            for dep_name in sorted(deps_available)
+        ]
+
+        def replace_with_symbol(node: Node) -> Node:
+            """Recursively replace dependency sub-expressions with Symbols."""
+            n: Any = node
+
+            # A node equal to a computed dependency becomes a column reference.
+            for dep_name, dep_expr in available_exprs:
+                if self._expressions_equal(node, dep_expr):
+                    return Symbol(dep_name)
+
+            # Otherwise recurse, rebuilding a node only when a child changed.
+            if isinstance(n, BinaryOpNode):
+                new_left = replace_with_symbol(n.left)
+                new_right = replace_with_symbol(n.right)
+                if new_left is not n.left or new_right is not n.right:
+                    return BinaryOpNode(n.op, new_left, new_right)
+            elif isinstance(n, UnaryOpNode):
+                new_operand = replace_with_symbol(n.operand)
+                if new_operand is not n.operand:
+                    return UnaryOpNode(n.op, new_operand)
+            elif isinstance(n, FunctionNode):
+                new_args = [replace_with_symbol(arg) for arg in n.args]
+                if any(
+                    new_arg is not old_arg
+                    for new_arg, old_arg in zip(new_args, n.args)
+                ):
+                    return FunctionNode(n.func_name, *new_args)
+
+            return node
+
+        new_expr = replace_with_symbol(original_expr)
+
+        # Translate the rewritten expression into a Polars expression again.
+        translator = PolarsTranslator()
+        new_polars_expr = translator.translate(new_expr)
+
+        # Reused factors become plain column dependencies of the new plan.
+        new_factor_deps = plan.factor_dependencies - deps_available
+        new_deps = plan.dependencies | deps_available
+
+        return ExecutionPlan(
+            data_specs=plan.data_specs,
+            polars_expr=new_polars_expr,
+            dependencies=new_deps,
+            output_name=plan.output_name,
+            factor_dependencies=new_factor_deps,
+        )
+
+    def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
+        """Return whether two expression trees are structurally equal.
+
+        Used to detect dependencies between factors.
+
+        Args:
+            expr1: The first expression.
+            expr2: The second expression.
+
+        Returns:
+            True when both trees are equal; False otherwise, including for
+            node types this method does not handle.
+        """
+        from typing import Any
+
+        e1: Any = expr1
+        e2: Any = expr2
+
+        # Exact type identity: a subclass never equals its base class here.
+        if type(e1) is not type(e2):
+            return False
+
+        if isinstance(e1, Symbol):
+            return e1.name == e2.name
+
+        if isinstance(e1, Constant):
+            return e1.value == e2.value
+
+        if isinstance(e1, BinaryOpNode):
+            return (
+                e1.op == e2.op
+                and self._expressions_equal(e1.left, e2.left)
+                and self._expressions_equal(e1.right, e2.right)
+            )
+
+        if isinstance(e1, UnaryOpNode):
+            return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
+
+        if isinstance(e1, FunctionNode):
+            if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
+                return False
+            return all(
+                self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
+            )
+
+        return False
+
+    def _find_factor_dependencies(self, expression: Node) -> Set[str]:
+        """Find the registered factors that an expression depends on.
+
+        Walks the expression tree and records registered factors whose
+        complete expression equals the expression or one of its sub-trees.
+
+        Args:
+            expression: The expression to inspect.
+
+        Returns:
+            The names of the factors the expression depends on.
+        """
+        deps: Set[str] = set()
+
+        # Does this node itself equal a registered factor? Only the first
+        # match is recorded.
+        for name, registered_expr in self.registered_expressions.items():
+            if self._expressions_equal(expression, registered_expr):
+                deps.add(name)
+                break
+
+        # Recurse into the child nodes.
+        if isinstance(expression, BinaryOpNode):
+            deps.update(self._find_factor_dependencies(expression.left))
+            deps.update(self._find_factor_dependencies(expression.right))
+        elif isinstance(expression, UnaryOpNode):
+            deps.update(self._find_factor_dependencies(expression.operand))
+        elif isinstance(expression, FunctionNode):
+            for arg in expression.args:
+                deps.update(self._find_factor_dependencies(arg))
+
+        return deps
+
+    def _topological_sort(self, factor_names: List[str]) -> List[str]:
+        """Sort factors topologically so dependencies are computed first.
+
+        Factors that become ready together keep their order from
+        ``factor_names``; a name listed more than once is scheduled once.
+
+        Args:
+            factor_names: Names of the factors to sort.
+
+        Returns:
+            The distinct factor names, each placed after the requested
+            factors it depends on.
+
+        Raises:
+            ValueError: If a circular dependency is detected.
+        """
+        import heapq
+
+        # Deduplicate while keeping first-seen order: a repeated name would
+        # otherwise be counted twice and misreported as a cycle.
+        unique_names = list(dict.fromkeys(factor_names))
+        position = {name: index for index, name in enumerate(unique_names)}
+        requested = set(unique_names)
+
+        # Build the graph from dependencies inside this request only.
+        dependents: Dict[str, List[str]] = {name: [] for name in unique_names}
+        in_degree: Dict[str, int] = {}
+        for name in unique_names:
+            deps = self._plans[name].factor_dependencies & requested
+            in_degree[name] = len(deps)
+            for dep in deps:
+                dependents[dep].append(name)
+
+        # Kahn's algorithm. The heap holds request positions, so the ready
+        # factor listed earliest is always processed first.
+        ready = [position[name] for name in unique_names if in_degree[name] == 0]
+        heapq.heapify(ready)
+        result: List[str] = []
+
+        while ready:
+            name = unique_names[heapq.heappop(ready)]
+            result.append(name)
+
+            for other in dependents[name]:
+                in_degree[other] -= 1
+                if in_degree[other] == 0:
+                    heapq.heappush(ready, position[other])
+
+        if len(result) != len(unique_names):
+            raise ValueError("检测到因子循环依赖")
+
+        return result