feat(factors): 实现 AST 拍平优化支持嵌套窗口函数

- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
This commit is contained in:
2026-03-14 01:06:17 +08:00
parent 282fe1fef5
commit c8808d07eb
5 changed files with 742 additions and 43 deletions

View File

@@ -3,7 +3,7 @@
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
"""
from typing import Set
from typing import Set, Optional
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
@@ -24,9 +24,14 @@ class DependencyExtractor:
{'close', 'pe_ratio'}
"""
def __init__(self) -> None:
"""初始化依赖提取器。"""
def __init__(self, ignore_symbols: Optional[Set[str]] = None) -> None:
"""初始化依赖提取器。
Args:
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
"""
self.dependencies: Set[str] = set()
self.ignore_symbols: Set[str] = ignore_symbols or set()
def visit(self, node: Node) -> None:
"""访问节点,根据节点类型分发到具体处理方法。
@@ -47,10 +52,14 @@ class DependencyExtractor:
def _visit_symbol(self, node: Symbol) -> None:
"""访问 Symbol 节点,提取符号名称。
排除临时因子(以 __tmp_ 开头的符号)和已在免疫名单中的因子。
Args:
node: 符号节点
"""
self.dependencies.add(node.name)
# 排除临时因子引用 和 已在免疫名单中的因子
if not node.name.startswith("__tmp_") and node.name not in self.ignore_symbols:
self.dependencies.add(node.name)
def _visit_binary_op(self, node: BinaryOpNode) -> None:
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
@@ -92,13 +101,16 @@ class DependencyExtractor:
return self.dependencies.copy()
@classmethod
def extract_dependencies(cls, node: Node) -> Set[str]:
def extract_dependencies(
cls, node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
这是一个便捷方法,无需手动实例化 DependencyExtractor。
Args:
node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns:
依赖的符号名称集合
@@ -112,17 +124,20 @@ class DependencyExtractor:
>>> print(deps)
{'close', 'open'}
"""
extractor = cls()
extractor = cls(ignore_symbols=ignore_symbols)
return extractor.extract(node)
def extract_dependencies(node: Node) -> Set[str]:
def extract_dependencies(
node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
Args:
node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns:
依赖的符号名称集合
@@ -136,7 +151,7 @@ def extract_dependencies(node: Node) -> Set[str]:
>>> print(deps)
{'close', 'pe_ratio'}
"""
return DependencyExtractor.extract_dependencies(node)
return DependencyExtractor.extract_dependencies(node, ignore_symbols=ignore_symbols)
if __name__ == "__main__":

View File

@@ -0,0 +1,223 @@
"""AST 优化器 - 表达式拍平。
本模块实现将嵌套的窗口函数表达式自动提取为中间临时因子,
解决多维窗口函数over嵌套导致计算为空的问题。
核心思想:
通过 AST 变换,将嵌套在窗口函数内的窗口函数表达式提取出来,
作为独立的临时因子先行计算,然后主表达式引用这些临时因子。
示例:
原始表达式: cs_rank(ts_delay(close, 1))
拍平后:
- 临时因子: __tmp_0 = ts_delay(close, 1)
- 主表达式: cs_rank(__tmp_0)
"""
from typing import Dict, Tuple
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class ExpressionFlattener:
"""表达式拍平器。
遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。
Attributes:
_counter: 临时因子名称计数器,用于生成唯一名称
_extracted_nodes: 存储已提取的临时因子字典
"""
def __init__(self) -> None:
"""初始化拍平器。"""
self._counter: int = 0
self._extracted_nodes: Dict[str, Node] = {}
def _generate_temp_name(self) -> str:
"""生成唯一的临时因子名称。
Returns:
格式为 "__tmp_X" 的临时名称,其中 X 是递增数字
"""
name = f"__tmp_{self._counter}"
self._counter += 1
return name
def _is_window_function(self, func_name: str) -> bool:
"""判断是否为窗口函数。
窗口函数以 "ts_"(时序)或 "cs_"(截面)开头。
Args:
func_name: 函数名称
Returns:
是否是窗口函数
"""
return func_name.startswith("ts_") or func_name.startswith("cs_")
def flatten(self, node: Node) -> Tuple[Node, Dict[str, Node]]:
"""拍平表达式。
遍历 AST将嵌套的窗口函数提取为临时因子。
Args:
node: 原始表达式根节点
Returns:
Tuple[拍平后的主表达式节点, 临时因子字典]
临时因子字典: {临时名称 -> 被提取的节点}
Example:
>>> flattener = ExpressionFlattener()
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
>>> flat_expr, tmp_factors = flattener.flatten(expr)
>>> # flat_expr = cs_rank(__tmp_0)
>>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
"""
# 重置状态
self._counter = 0
self._extracted_nodes = {}
# 从根节点开始遍历,初始状态为不在窗口函数内部
new_node = self._flatten_recursive(node, inside_window=False)
return new_node, self._extracted_nodes.copy()
def _flatten_recursive(self, node: Node, inside_window: bool) -> Node:
"""递归拍平节点。
Args:
node: 当前处理的节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点(可能是原节点或替换为 Symbol
"""
# Symbol 和 Constant 是叶子节点,直接返回
if isinstance(node, Symbol):
return node
if isinstance(node, Constant):
return node
# 处理二元运算节点
if isinstance(node, BinaryOpNode):
return self._flatten_binary_op(node, inside_window)
# 处理一元运算节点
if isinstance(node, UnaryOpNode):
return self._flatten_unary_op(node, inside_window)
# 处理函数调用节点
if isinstance(node, FunctionNode):
return self._flatten_function(node, inside_window)
# 未知节点类型,直接返回
return node
def _flatten_binary_op(self, node: BinaryOpNode, inside_window: bool) -> Node:
"""拍平二元运算节点。
Args:
node: 二元运算节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
# 递归处理左右子节点
new_left = self._flatten_recursive(node.left, inside_window)
new_right = self._flatten_recursive(node.right, inside_window)
# 如果子节点没有变化,返回原节点
if new_left is node.left and new_right is node.right:
return node
# 创建新的二元运算节点
return BinaryOpNode(node.op, new_left, new_right)
def _flatten_unary_op(self, node: UnaryOpNode, inside_window: bool) -> Node:
"""拍平一元运算节点。
Args:
node: 一元运算节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
# 递归处理操作数
new_operand = self._flatten_recursive(node.operand, inside_window)
# 如果操作数没有变化,返回原节点
if new_operand is node.operand:
return node
# 创建新的一元运算节点
return UnaryOpNode(node.op, new_operand)
def _flatten_function(self, node: FunctionNode, inside_window: bool) -> Node:
"""拍平函数调用节点。
修正为后序遍历Bottom-Up先递归拍平参数再决定是否提取当前节点。
确保深层嵌套(如 3层以上也能被彻底逐层拆解。
Args:
node: 函数调用节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
is_window = self._is_window_function(node.func_name)
next_inside_window = inside_window or is_window
# 1. 优先递归处理所有参数
new_args = []
has_change = False
for arg in node.args:
new_arg = self._flatten_recursive(arg, next_inside_window)
new_args.append(new_arg)
if new_arg is not arg:
has_change = True
# 2. 只有当参数发生变化时,才创建新的当前节点
current_node = FunctionNode(node.func_name, *new_args) if has_change else node
# 3. 判断是否需要提取(此时子节点肯定已经被彻底拍平了)
if inside_window and is_window:
temp_name = self._generate_temp_name()
self._extracted_nodes[temp_name] = current_node
return Symbol(temp_name)
return current_node
def flatten_expression(node: Node) -> Tuple[Node, Dict[str, Node]]:
"""便捷函数 - 拍平表达式。
Args:
node: 表达式树的根节点
Returns:
Tuple[拍平后的主表达式节点, 临时因子字典]
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
>>> flat_expr, tmp_factors = flatten_expression(expr)
"""
flattener = ExpressionFlattener()
return flattener.flatten(node)

View File

@@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
class FactorEngine:
@@ -92,13 +93,68 @@ class FactorEngine:
self._metadata = FactorManager()
def _register_internal(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""内部注册方法,直接注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
known_factors = set(self.registered_expressions.keys())
self.registered_expressions[name] = expression
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
ignore_dependencies=known_factors,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
if not plan.data_specs and factor_deps:
merged_specs: List[DataSpec] = []
for dep_name in factor_deps:
if dep_name in self._plans:
merged_specs.extend(self._plans[dep_name].data_specs)
# 去重(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in merged_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
plan.data_specs = unique_specs
self._plans[name] = plan
return self
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式。
"""注册因子表达式(自动处理嵌套窗口函数)
Args:
name: 因子名称
@@ -113,22 +169,16 @@ class FactorEngine:
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 使用 AST 优化器拍平嵌套窗口函数
flattener = ExpressionFlattener()
flat_expression, tmp_factors = flattener.flatten(expression)
self.registered_expressions[name] = expression
# 先注册所有临时因子(自动推导数据规格)
for tmp_name, tmp_node in tmp_factors.items():
self._register_internal(tmp_name, tmp_node, data_specs=None)
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
# 最后注册主因子
self._register_internal(name, flat_expression, data_specs)
return self
@@ -174,7 +224,7 @@ class FactorEngine:
# 解析表达式为 Node
node = self._parser.parse(dsl_expr)
# 委托给 register 方法
# 委托给 register 方法register 会处理嵌套窗口函数拍平)
return self.register(name, node, data_specs)
def add_factor(
@@ -272,21 +322,32 @@ class FactorEngine:
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
# 1. 收集所有需要的因子(包括临时因子依赖)
all_factor_names = self._collect_all_dependencies(factor_names)
# 2. 获取执行计划
plans = []
for name in factor_names:
for name in all_factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
# 3. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
# 去重数据规格(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in all_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
# 4. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
data_specs=unique_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
@@ -295,14 +356,14 @@ class FactorEngine:
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
# 5. 按依赖顺序执行计算(包含临时因子)
result = self._execute_with_dependencies(all_factor_names, core_data)
return result
# 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
return result.select(cols_to_keep)
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
@@ -501,10 +562,32 @@ class FactorEngine:
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
collected: Set[str] = set()
result: List[str] = []
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
def collect_recursive(name: str):
if name in collected:
return
collected.add(name)
# 获取执行计划并递归收集强依赖
plan = self._plans.get(name)
if plan:
for dep_name in plan.factor_dependencies:
collect_recursive(dep_name)
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
result.append(name)
for name in factor_names:
collect_recursive(name)
return result
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
Args:
expression: 待检查的表达式
@@ -514,13 +597,20 @@ class FactorEngine:
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
# 1. 【新增】如果直接引用了已注册因子名称(包含 __tmp_X 或用户因子)
if (
isinstance(expression, Symbol)
and expression.name in self.registered_expressions
):
deps.add(expression.name)
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
# 3. 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))

View File

@@ -39,6 +39,7 @@ class ExecutionPlanner:
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
ignore_dependencies: Optional[Set[str]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
@@ -46,12 +47,15 @@ class ExecutionPlanner:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
ignore_dependencies: 需要忽略的依赖符号集合(如已注册因子名)
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 1. 提取依赖时传入要忽略的符号
dependencies = self.compiler.extract_dependencies(
expression, ignore_symbols=ignore_dependencies
)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)