feat(factors): 实现 AST 拍平优化支持嵌套窗口函数

- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
This commit is contained in:
2026-03-14 01:06:17 +08:00
parent 282fe1fef5
commit c8808d07eb
5 changed files with 742 additions and 43 deletions

View File

@@ -3,7 +3,7 @@
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。 本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
""" """
from typing import Set from typing import Set, Optional
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
@@ -24,9 +24,14 @@ class DependencyExtractor:
{'close', 'pe_ratio'} {'close', 'pe_ratio'}
""" """
def __init__(self) -> None: def __init__(self, ignore_symbols: Optional[Set[str]] = None) -> None:
"""初始化依赖提取器。""" """初始化依赖提取器。
Args:
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
"""
self.dependencies: Set[str] = set() self.dependencies: Set[str] = set()
self.ignore_symbols: Set[str] = ignore_symbols or set()
def visit(self, node: Node) -> None: def visit(self, node: Node) -> None:
"""访问节点,根据节点类型分发到具体处理方法。 """访问节点,根据节点类型分发到具体处理方法。
@@ -47,10 +52,14 @@ class DependencyExtractor:
def _visit_symbol(self, node: Symbol) -> None: def _visit_symbol(self, node: Symbol) -> None:
"""访问 Symbol 节点,提取符号名称。 """访问 Symbol 节点,提取符号名称。
排除临时因子(以 __tmp_ 开头的符号)和已在免疫名单中的因子。
Args: Args:
node: 符号节点 node: 符号节点
""" """
self.dependencies.add(node.name) # 排除临时因子引用 和 已在免疫名单中的因子
if not node.name.startswith("__tmp_") and node.name not in self.ignore_symbols:
self.dependencies.add(node.name)
def _visit_binary_op(self, node: BinaryOpNode) -> None: def _visit_binary_op(self, node: BinaryOpNode) -> None:
"""访问 BinaryOpNode 节点,递归遍历左右子节点。 """访问 BinaryOpNode 节点,递归遍历左右子节点。
@@ -92,13 +101,16 @@ class DependencyExtractor:
return self.dependencies.copy() return self.dependencies.copy()
@classmethod @classmethod
def extract_dependencies(cls, node: Node) -> Set[str]: def extract_dependencies(
cls, node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。 """类方法 - 从 AST 节点中提取所有依赖的符号名称。
这是一个便捷方法,无需手动实例化 DependencyExtractor。 这是一个便捷方法,无需手动实例化 DependencyExtractor。
Args: Args:
node: 表达式树的根节点 node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns: Returns:
依赖的符号名称集合 依赖的符号名称集合
@@ -112,17 +124,20 @@ class DependencyExtractor:
>>> print(deps) >>> print(deps)
{'close', 'open'} {'close', 'open'}
""" """
extractor = cls() extractor = cls(ignore_symbols=ignore_symbols)
return extractor.extract(node) return extractor.extract(node)
def extract_dependencies(node: Node) -> Set[str]: def extract_dependencies(
node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。 """单例方法 - 从 AST 节点中提取所有依赖的符号名称。
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。 这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
Args: Args:
node: 表达式树的根节点 node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns: Returns:
依赖的符号名称集合 依赖的符号名称集合
@@ -136,7 +151,7 @@ def extract_dependencies(node: Node) -> Set[str]:
>>> print(deps) >>> print(deps)
{'close', 'pe_ratio'} {'close', 'pe_ratio'}
""" """
return DependencyExtractor.extract_dependencies(node) return DependencyExtractor.extract_dependencies(node, ignore_symbols=ignore_symbols)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -0,0 +1,223 @@
"""AST 优化器 - 表达式拍平。
本模块实现将嵌套的窗口函数表达式自动提取为中间临时因子,
解决多维窗口函数（over）嵌套导致计算为空的问题。
核心思想:
通过 AST 变换,将嵌套在窗口函数内的窗口函数表达式提取出来,
作为独立的临时因子先行计算,然后主表达式引用这些临时因子。
示例:
原始表达式: cs_rank(ts_delay(close, 1))
拍平后:
- 临时因子: __tmp_0 = ts_delay(close, 1)
- 主表达式: cs_rank(__tmp_0)
"""
from typing import Dict, Tuple
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class ExpressionFlattener:
    """Flatten nested window-function calls into standalone temporary factors.

    The AST is walked bottom-up. Every window function (name starting with
    ``ts_`` or ``cs_``) that appears *inside* another window function is
    pulled out, stored under a generated temporary name, and replaced in the
    parent expression by a ``Symbol`` carrying that name.

    Temporary names are ``<prefix><n>`` with ``n`` restarting at 0 on every
    :meth:`flatten` call. Two calls that share a prefix therefore produce the
    same names; a caller that stores the temporaries of several expressions
    in one namespace should give each expression its own prefix, otherwise a
    later ``<prefix>0`` silently replaces an earlier one.

    Attributes:
        _prefix: Leading text of every generated temporary name.
        _counter: Running index used to build the next temporary name.
        _extracted_nodes: Mapping of temporary name -> extracted sub-expression.
    """

    # Function-name prefixes that mark a window function.
    _WINDOW_PREFIXES = ("ts_", "cs_")

    def __init__(self, prefix: str = "__tmp_") -> None:
        """Initialise the flattener.

        Args:
            prefix: Leading text for generated temporary names. Other parts
                of this change set recognise temporaries by a leading
                ``__tmp_``, so a custom prefix should keep that lead-in
                (for example ``"__tmp_myfactor_"``).
        """
        self._prefix: str = prefix
        self._counter: int = 0
        self._extracted_nodes: Dict[str, Node] = {}

    def _generate_temp_name(self) -> str:
        """Return the next unused temporary name, ``<prefix><n>``."""
        name = f"{self._prefix}{self._counter}"
        self._counter += 1
        return name

    def _is_window_function(self, func_name: str) -> bool:
        """Tell whether ``func_name`` names a window function.

        Args:
            func_name: Function name taken from a ``FunctionNode``.

        Returns:
            True when the name starts with ``ts_`` or ``cs_``.
        """
        return func_name.startswith(self._WINDOW_PREFIXES)

    def flatten(self, node: Node) -> Tuple[Node, Dict[str, Node]]:
        """Flatten ``node`` and return the rewritten tree plus the temporaries.

        Args:
            node: Root of the original expression.

        Returns:
            A pair ``(flat_root, temporaries)`` where ``temporaries`` maps
            each generated name to the sub-expression it stands for. The
            mapping is filled innermost-first, so iterating it yields every
            temporary after the temporaries it refers to.

        Example:
            >>> flattener = ExpressionFlattener()
            >>> from src.factors.dsl import Symbol, FunctionNode
            >>> close = Symbol("close")
            >>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
            >>> flat_expr, tmp_factors = flattener.flatten(expr)
            >>> # flat_expr   = cs_rank(__tmp_0)
            >>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
        """
        # Start every run from a clean slate so an instance can be reused.
        self._counter = 0
        self._extracted_nodes = {}

        # The root is, by definition, not inside any window function.
        new_node = self._flatten_recursive(node, inside_window=False)

        # Hand out a copy so later runs cannot mutate the caller's result.
        return new_node, self._extracted_nodes.copy()

    def _flatten_recursive(self, node: Node, inside_window: bool) -> Node:
        """Dispatch on the node type and flatten the subtree.

        Args:
            node: Node currently being processed.
            inside_window: True when an ancestor is a window function.

        Returns:
            The processed node: the original object when nothing below it
            changed, a rebuilt node, or a ``Symbol`` naming a temporary.
        """
        # Leaves carry nothing to extract.
        if isinstance(node, (Symbol, Constant)):
            return node
        if isinstance(node, BinaryOpNode):
            return self._flatten_binary_op(node, inside_window)
        if isinstance(node, UnaryOpNode):
            return self._flatten_unary_op(node, inside_window)
        if isinstance(node, FunctionNode):
            return self._flatten_function(node, inside_window)
        # Unknown node kinds are passed through untouched.
        return node

    def _flatten_binary_op(self, node: BinaryOpNode, inside_window: bool) -> Node:
        """Flatten both operands of a binary operation.

        Args:
            node: Binary operation node.
            inside_window: True when an ancestor is a window function.

        Returns:
            ``node`` itself when neither operand changed, otherwise a new
            ``BinaryOpNode`` built from the rewritten operands.
        """
        new_left = self._flatten_recursive(node.left, inside_window)
        new_right = self._flatten_recursive(node.right, inside_window)

        # Reuse the original object when the subtree is unchanged.
        if new_left is node.left and new_right is node.right:
            return node
        return BinaryOpNode(node.op, new_left, new_right)

    def _flatten_unary_op(self, node: UnaryOpNode, inside_window: bool) -> Node:
        """Flatten the operand of a unary operation.

        Args:
            node: Unary operation node.
            inside_window: True when an ancestor is a window function.

        Returns:
            ``node`` itself when the operand is unchanged, otherwise a new
            ``UnaryOpNode`` wrapping the rewritten operand.
        """
        new_operand = self._flatten_recursive(node.operand, inside_window)
        if new_operand is node.operand:
            return node
        return UnaryOpNode(node.op, new_operand)

    def _flatten_function(self, node: FunctionNode, inside_window: bool) -> Node:
        """Flatten a function call, extracting it when it is a nested window.

        Arguments are flattened first (post-order), and only then is the
        call itself considered for extraction. This is what lets chains of
        three or more nested window functions be peeled apart layer by layer.

        Args:
            node: Function call node.
            inside_window: True when an ancestor is a window function.

        Returns:
            A ``Symbol`` naming the new temporary when this call is a window
            function nested inside another window function; otherwise the
            (possibly rebuilt) call node.
        """
        is_window = self._is_window_function(node.func_name)
        # Arguments of a window function are "inside a window" from here on.
        args_inside_window = inside_window or is_window

        # 1. Flatten every argument before looking at this node.
        new_args = [
            self._flatten_recursive(arg, args_inside_window) for arg in node.args
        ]
        has_change = any(new is not old for new, old in zip(new_args, node.args))

        # 2. Rebuild the call only when an argument was actually rewritten.
        current_node = FunctionNode(node.func_name, *new_args) if has_change else node

        # 3. A window function below another window function becomes a
        #    temporary; its own arguments are already fully flattened.
        if inside_window and is_window:
            temp_name = self._generate_temp_name()
            self._extracted_nodes[temp_name] = current_node
            return Symbol(temp_name)

        return current_node
def flatten_expression(node: Node) -> Tuple[Node, Dict[str, Node]]:
    """Flatten an expression with a throw-away ``ExpressionFlattener``.

    Args:
        node: Root of the expression tree.

    Returns:
        A pair of the flattened main expression and the mapping of
        temporary factor names to the sub-expressions extracted from it.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
        >>> flat_expr, tmp_factors = flatten_expression(expr)
    """
    return ExpressionFlattener().flatten(node)

View File

@@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
class FactorEngine: class FactorEngine:
@@ -92,13 +93,68 @@ class FactorEngine:
self._metadata = FactorManager() self._metadata = FactorManager()
def _register_internal(
    self,
    name: str,
    expression: Node,
    data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
    """Register one factor expression as-is, without flattening it.

    Args:
        name: Factor name; also the key used for the stored expression
            and for its execution plan.
        expression: DSL expression for the factor.
        data_specs: Data specifications; when None they are derived by
            the planner.

    Returns:
        self, so calls can be chained.
    """
    # Both lookups must see the registry *before* this factor is added.
    factor_deps = self._find_factor_dependencies(expression)
    # Names of factors registered so far; the planner must not mistake
    # them for database fields.
    already_registered = set(self.registered_expressions.keys())

    # NOTE(review): an existing entry with the same name is overwritten
    # here and in self._plans below; temporaries named "__tmp_<n>" coming
    # from different register() calls look like they can collide - confirm.
    self.registered_expressions[name] = expression

    plan = self.planner.create_plan(
        expression=expression,
        output_name=name,
        data_specs=data_specs,
        ignore_dependencies=already_registered,
    )
    plan.factor_dependencies = factor_deps

    # A factor built purely from other factors has no specs of its own,
    # so it borrows the specs of the factors it depends on.
    if not plan.data_specs and factor_deps:
        first_spec_per_table = {}
        for dep_name in factor_deps:
            if dep_name not in self._plans:
                continue
            for spec in self._plans[dep_name].data_specs:
                # Keep only the first spec seen for each table name.
                # NOTE(review): presumably specs for one table are
                # interchangeable; if they can differ (columns, lookback)
                # the later ones are dropped here - confirm.
                if spec.table not in first_spec_per_table:
                    first_spec_per_table[spec.table] = spec
        plan.data_specs = list(first_spec_per_table.values())

    self._plans[name] = plan
    return self
def register( def register(
self, self,
name: str, name: str,
expression: Node, expression: Node,
data_specs: Optional[List[DataSpec]] = None, data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine": ) -> "FactorEngine":
"""注册因子表达式。 """注册因子表达式(自动处理嵌套窗口函数)
Args: Args:
name: 因子名称 name: 因子名称
@@ -113,22 +169,16 @@ class FactorEngine:
>>> engine = FactorEngine() >>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20)) >>> engine.register("ma20", ts_mean(close, 20))
""" """
# 检测因子依赖(在注册当前因子之前检查其他已注册因子) # 使用 AST 优化器拍平嵌套窗口函数
factor_deps = self._find_factor_dependencies(expression) flattener = ExpressionFlattener()
flat_expression, tmp_factors = flattener.flatten(expression)
self.registered_expressions[name] = expression # 先注册所有临时因子(自动推导数据规格)
for tmp_name, tmp_node in tmp_factors.items():
self._register_internal(tmp_name, tmp_node, data_specs=None)
# 预创建执行计划 # 最后注册主因子
plan = self.planner.create_plan( self._register_internal(name, flat_expression, data_specs)
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
return self return self
@@ -174,7 +224,7 @@ class FactorEngine:
# 解析表达式为 Node # 解析表达式为 Node
node = self._parser.parse(dsl_expr) node = self._parser.parse(dsl_expr)
# 委托给 register 方法 # 委托给 register 方法register 会处理嵌套窗口函数拍平)
return self.register(name, node, data_specs) return self.register(name, node, data_specs)
def add_factor( def add_factor(
@@ -272,21 +322,32 @@ class FactorEngine:
if isinstance(factor_names, str): if isinstance(factor_names, str):
factor_names = [factor_names] factor_names = [factor_names]
# 1. 获取执行计划 # 1. 收集所有需要的因子(包括临时因子依赖)
all_factor_names = self._collect_all_dependencies(factor_names)
# 2. 获取执行计划
plans = [] plans = []
for name in factor_names: for name in all_factor_names:
if name not in self._plans: if name not in self._plans:
raise ValueError(f"因子未注册: {name}") raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name]) plans.append(self._plans[name])
# 2. 合并数据规格并获取数据 # 3. 合并数据规格并获取数据
all_specs = [] all_specs = []
for plan in plans: for plan in plans:
all_specs.extend(plan.data_specs) all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表 # 去重数据规格(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in all_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
# 4. 从路由器获取核心宽表
core_data = self.router.fetch_data( core_data = self.router.fetch_data(
data_specs=all_specs, data_specs=unique_specs,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
stock_codes=stock_codes, stock_codes=stock_codes,
@@ -295,14 +356,14 @@ class FactorEngine:
if len(core_data) == 0: if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码") raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算 # 5. 按依赖顺序执行计算(包含临时因子)
if len(plans) == 1: result = self._execute_with_dependencies(all_factor_names, core_data)
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
return result # 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
return result.select(cols_to_keep)
def list_registered(self) -> List[str]: def list_registered(self) -> List[str]:
"""获取已注册的因子列表。 """获取已注册的因子列表。
@@ -501,10 +562,32 @@ class FactorEngine:
return False return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]: def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
"""查找表达式依赖的其他因子。 """收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
collected: Set[str] = set()
result: List[str] = []
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 def collect_recursive(name: str):
if name in collected:
return
collected.add(name)
# 获取执行计划并递归收集强依赖
plan = self._plans.get(name)
if plan:
for dep_name in plan.factor_dependencies:
collect_recursive(dep_name)
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
result.append(name)
for name in factor_names:
collect_recursive(name)
return result
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
Args: Args:
expression: 待检查的表达式 expression: 待检查的表达式
@@ -514,13 +597,20 @@ class FactorEngine:
""" """
deps: Set[str] = set() deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子 # 1. 【新增】如果直接引用了已注册因子名称(包含 __tmp_X 或用户因子)
if (
isinstance(expression, Symbol)
and expression.name in self.registered_expressions
):
deps.add(expression.name)
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
for name, registered_expr in self.registered_expressions.items(): for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr): if self._expressions_equal(expression, registered_expr):
deps.add(name) deps.add(name)
break break
# 递归检查子节点 # 3. 递归检查子节点
if isinstance(expression, BinaryOpNode): if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left)) deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right)) deps.update(self._find_factor_dependencies(expression.right))

View File

@@ -39,6 +39,7 @@ class ExecutionPlanner:
expression: Node, expression: Node,
output_name: str = "factor", output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None, data_specs: Optional[List[DataSpec]] = None,
ignore_dependencies: Optional[Set[str]] = None,
) -> ExecutionPlan: ) -> ExecutionPlan:
"""从表达式创建执行计划。 """从表达式创建执行计划。
@@ -46,12 +47,15 @@ class ExecutionPlanner:
expression: DSL 表达式节点 expression: DSL 表达式节点
output_name: 输出因子名称 output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导 data_specs: 预定义的数据规格None 时自动推导
ignore_dependencies: 需要忽略的依赖符号集合(如已注册因子名)
Returns: Returns:
执行计划对象 执行计划对象
""" """
# 1. 提取依赖 # 1. 提取依赖时传入要忽略的符号
dependencies = self.compiler.extract_dependencies(expression) dependencies = self.compiler.extract_dependencies(
expression, ignore_symbols=ignore_dependencies
)
# 2. 翻译为 Polars 表达式 # 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression) polars_expr = self.translator.translate(expression)

367
tests/test_ast_optimizer.py Normal file
View File

@@ -0,0 +1,367 @@
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
测试因子: cs_rank(ts_delay(close, 1))
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine
from src.factors.api import close, ts_delay, cs_rank
from src.factors.dsl import FunctionNode
from src.factors.engine.ast_optimizer import ExpressionFlattener
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
) -> pl.DataFrame:
    """Build a synthetic daily-bar frame for the tests.

    Args:
        start_date: First calendar day, formatted ``YYYYMMDD``.
        end_date: Last calendar day (inclusive), formatted ``YYYYMMDD``.
        n_stocks: Number of stock codes to generate, starting at 600000.SH.

    Returns:
        One row per (weekday, stock) with columns ``ts_code``,
        ``trade_date``, ``open``, ``high``, ``low``, ``close`` and
        ``volume``. The NumPy global seed is fixed at 42, so the values
        are reproducible.
    """
    first_day = datetime.strptime(start_date, "%Y%m%d")
    last_day = datetime.strptime(end_date, "%Y%m%d")

    # Monday to Friday only (weekday() < 5).
    calendar = (
        first_day + timedelta(days=offset)
        for offset in range((last_day - first_day).days + 1)
    )
    trade_dates = [day.strftime("%Y%m%d") for day in calendar if day.weekday() < 5]

    codes = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]

    np.random.seed(42)

    records = []
    for trade_date in trade_dates:
        for ts_code in codes:
            # The order of the random draws below fixes the generated
            # values; do not reorder them.
            base_price = 10 + np.random.randn() * 5
            close_val = base_price + np.random.randn() * 0.5
            open_val = close_val + np.random.randn() * 0.2
            high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
            low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
            vol = int(1000000 + np.random.exponential(500000))

            record = {
                "ts_code": ts_code,
                "trade_date": trade_date,
                "open": round(open_val, 2),
                "high": round(high_val, 2),
                "low": round(low_val, 2),
                "close": round(close_val, 2),
                "volume": vol,
            }
            records.append(record)

    return pl.DataFrame(records)
class TestASTOptimizer:
    """Tests for the AST optimizer and its integration with FactorEngine."""

    def test_flattener_basic(self):
        """A single nested window function is extracted as one temporary."""
        flattener = ExpressionFlattener()

        # Nested expression: cs_rank(ts_delay(close, 1))
        expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))

        flat_expr, tmp_factors = flattener.flatten(expr)

        # Exactly one temporary factor was extracted.
        assert len(tmp_factors) == 1
        assert "__tmp_0" in tmp_factors

        # The main expression is still cs_rank(...).
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"

        # Its first argument now refers to the temporary by name.
        assert hasattr(flat_expr.args[0], "name")
        assert flat_expr.args[0].name == "__tmp_0"

        # The temporary holds the extracted ts_delay call.
        tmp_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp_node, FunctionNode)
        assert tmp_node.func_name == "ts_delay"

        print("[PASS] 拍平器基本功能测试")

    def test_flattener_no_nested(self):
        """A non-nested expression is left untouched."""
        from src.factors.api import ts_mean  # noqa: F401  (kept from the original test)

        flattener = ExpressionFlattener()

        # Non-nested expression: ts_mean(close, 20)
        expr = FunctionNode("ts_mean", close, 20)

        flat_expr, tmp_factors = flattener.flatten(expr)

        # Nothing was extracted.
        assert len(tmp_factors) == 0

        # The expression is still the same call.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "ts_mean"

        print("[PASS] 非嵌套表达式测试")

    def test_flattener_deeply_nested(self):
        """Three nested window functions are peeled apart layer by layer."""
        from src.factors.api import ts_mean  # noqa: F401  (kept from the original test)
        from src.factors.dsl import Symbol

        flattener = ExpressionFlattener()

        # Deep nesting: cs_rank(ts_mean(ts_delay(close, 1), 5))
        expr = FunctionNode(
            "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
        )

        flat_expr, tmp_factors = flattener.flatten(expr)

        # Two temporaries are expected:
        #   __tmp_0 = ts_delay(close, 1)
        #   __tmp_1 = ts_mean(__tmp_0, 5)
        assert len(tmp_factors) == 2
        assert "__tmp_0" in tmp_factors
        assert "__tmp_1" in tmp_factors

        # __tmp_0 is ts_delay(close, 1).
        tmp0_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp0_node, FunctionNode)
        assert tmp0_node.func_name == "ts_delay"

        # __tmp_1 is ts_mean(__tmp_0, 5).
        tmp1_node = tmp_factors["__tmp_1"]
        assert isinstance(tmp1_node, FunctionNode)
        assert tmp1_node.func_name == "ts_mean"
        assert isinstance(tmp1_node.args[0], Symbol)
        assert tmp1_node.args[0].name == "__tmp_0"

        # The main expression refers to __tmp_1.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"
        assert isinstance(flat_expr.args[0], Symbol)
        assert flat_expr.args[0].name == "__tmp_1"

        print("[PASS] 多层嵌套表达式拍平测试")

    def test_nested_window_function_engine(self):
        """The engine computes cs_rank(ts_delay(close, 1)) end to end."""
        print("\n" + "=" * 60)
        print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
        print("=" * 60)

        # 1. Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)}")

        # 2. Build the engine.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("引擎初始化完成")

        # 3. Register the nested window function from a string expression.
        print("\n注册因子: cs_rank(ts_delay(close, 1))")
        engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")

        # 4. A temporary factor must have been registered alongside it.
        registered_factors = engine.list_registered()
        print(f"已注册因子: {registered_factors}")

        tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
        assert len(tmp_factors) >= 1, "应该有临时因子被创建"
        print(f"临时因子: {tmp_factors}")

        # 5. Run the computation.
        print("\n执行计算...")
        result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"计算完成: {len(result)}")

        # 6. Check the result.
        assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"

        # Rank values should lie in [0, 1]; nulls may appear because of the lag.
        non_null_values = result["delayed_rank"].drop_nulls()
        if len(non_null_values) > 0:
            assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
            assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"

        # Report the null count (some are expected in the initial lag window).
        null_count = result["delayed_rank"].is_null().sum()
        print(f"空值数量: {null_count}")

        # Show a sample. Printing the Polars frame directly avoids needing
        # pandas/pyarrow just for display.
        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
            10
        )
        print(sample)

        print("\n" + "=" * 60)
        print("嵌套窗口函数测试通过!")
        print("=" * 60)

    def test_multiple_nested_factors(self):
        """Several nested factors can be registered and computed together."""
        print("\n" + "=" * 60)
        print("测试多个嵌套因子")
        print("=" * 60)

        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # Register two nested factors from string expressions.
        # NOTE(review): this only checks that both output columns exist; it
        # does not check that the two factors' temporaries stay distinct.
        print("\n注册因子1: cs_rank(ts_delay(close, 1))")
        engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")

        print("注册因子2: ts_mean(cs_rank(close), 5)")
        engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")

        # List what is registered.
        factors = engine.list_registered()
        print(f"\n已注册因子: {factors}")

        # Compute both factors in one call.
        result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")

        assert "rank1" in result.columns
        assert "rank_mean" in result.columns

        print(f"\n结果行数: {len(result)}")
        print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
        print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")

        print("\n" + "=" * 60)
        print("多个嵌套因子测试通过!")
        print("=" * 60)

    def test_nested_vs_native_polars(self):
        """Engine output matches a hand-written Polars computation."""
        print("\n" + "=" * 60)
        print("对比测试cs_rank(ts_delay(close, 1)) vs 原生 Polars")
        print("=" * 60)

        # 1. Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)}")

        # 2. Compute the nested factor through FactorEngine.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
        engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
        engine_result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"FactorEngine 结果: {len(engine_result)}")

        # 3. Compute the same thing step by step in plain Polars.
        print("\n使用原生 Polars 手动计算...")

        # First ts_delay(close, 1): shift within each stock.
        native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
            [pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
        )

        # Then cs_rank: rank / count within each trade date.
        native_result = native_result.with_columns(
            [
                (pl.col("delayed_close").rank() / pl.col("delayed_close").count())
                .over("trade_date")
                .alias("native_delayed_rank")
            ]
        )
        print(f"原生 Polars 结果: {len(native_result)}")

        # 4. Join both results on the row key.
        comparison = engine_result.join(
            native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
            on=["ts_code", "trade_date"],
            how="inner",
        )

        # 5. Compare values, allowing for tiny floating-point error.
        diff = comparison.with_columns(
            [
                (pl.col("delayed_rank") - pl.col("native_delayed_rank"))
                .abs()
                .alias("diff")
            ]
        )

        max_diff = diff["diff"].max()
        print(f"\n最大差异: {max_diff}")

        # Only rows where both sides are non-null can be compared (the lag
        # leaves nulls at the start).
        non_null_diff = diff.filter(pl.col("diff").is_not_null())
        # Without this guard an all-null result makes max() return None and
        # the comparison below raises TypeError instead of failing clearly.
        assert len(non_null_diff) > 0, "没有可比较的非空结果"
        worst_diff = non_null_diff["diff"].max()
        assert worst_diff < 1e-10, f"数值差异过大: {worst_diff}"

        print("\n" + "=" * 60)
        print("数值一致性验证通过!")
        print("=" * 60)

    def test_factor_reference_factor(self):
        """A factor can refer to another registered factor: fac2 = cs_rank(fac1)."""
        from src.factors.api import ts_mean

        print("\n" + "=" * 60)
        print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
        print("=" * 60)

        # Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # 1. Register the base factor fac1.
        print("\n注册基础因子 fac1 = ts_mean(close, 5)")
        engine.register("fac1", ts_mean(close, 5))

        # 2. Register fac2, which refers to fac1 by name.
        print("注册引用因子 fac2 = cs_rank(fac1)")
        engine.register("fac2", cs_rank("fac1"))

        # 3. Both factors must be registered.
        registered = engine.list_registered()
        print(f"\n已注册因子: {registered}")
        assert "fac1" in registered
        assert "fac2" in registered

        # 4. Run the computation.
        print("\n执行计算...")
        result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
        print(f"计算完成: {len(result)}")

        # 5. Check the result.
        assert "fac1" in result.columns, "结果中应有 fac1 列"
        assert "fac2" in result.columns, "结果中应有 fac2 列"

        # fac2 is a rank and should lie in [0, 1]. An all-null column would
        # make min()/max() return None, so require some values first.
        fac2_values = result["fac2"].drop_nulls()
        assert len(fac2_values) > 0, "fac2 不应全部为空"
        assert fac2_values.min() >= 0, "排名应在 [0, 1] 之间"
        assert fac2_values.max() <= 1, "排名应在 [0, 1] 之间"

        # Printing the Polars frame directly avoids needing pandas/pyarrow.
        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
            10
        )
        print(sample)

        print("\n" + "=" * 60)
        print("因子引用功能测试通过!")
        print("=" * 60)
if __name__ == "__main__":
    # Run every test in declaration order without pytest; the first
    # failing assertion stops the run.
    suite = TestASTOptimizer()
    for run_case in (
        suite.test_flattener_basic,
        suite.test_flattener_no_nested,
        suite.test_flattener_deeply_nested,
        suite.test_nested_window_function_engine,
        suite.test_multiple_nested_factors,
        suite.test_nested_vs_native_polars,
        suite.test_factor_reference_factor,
    ):
        run_case()
    print("\n所有测试通过!")