fix(factors): 修复 cs_rank 等截面函数在依赖表达式时输出全 null 的问题

This commit is contained in:
2026-03-02 22:21:43 +08:00
parent 9b826c1845
commit 1c0c4a0de1

View File

@@ -55,12 +55,14 @@ class ExecutionPlan:
polars_expr: Polars 表达式
dependencies: 依赖的原始字段
output_name: 输出因子名称
factor_dependencies: 依赖的其他因子名称(用于分步执行)
"""
data_specs: List[DataSpec]
polars_expr: pl.Expr
dependencies: Set[str]
output_name: str
factor_dependencies: Set[str] = field(default_factory=set)
class DataRouter:
@@ -706,6 +708,9 @@ class FactorEngine:
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
self.registered_expressions[name] = expression
# 预创建执行计划
@@ -714,6 +719,10 @@ class FactorEngine:
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
return self
@@ -772,11 +781,12 @@ class FactorEngine:
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 执行计算
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
result = self.compute_engine.execute_batch(plans, core_data)
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
return result
@@ -815,3 +825,239 @@ class FactorEngine:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景。
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用。
Args:
plan: 原始执行计划
current_data: 当前数据(包含已计算的依赖列)
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系。
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算。
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result