feat(factors): 实现 AST 拍平优化支持嵌套窗口函数

- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
This commit is contained in:
2026-03-14 01:06:17 +08:00
parent 282fe1fef5
commit c8808d07eb
5 changed files with 742 additions and 43 deletions

367
tests/test_ast_optimizer.py Normal file
View File

@@ -0,0 +1,367 @@
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
测试因子: cs_rank(ts_delay(close, 1))
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine
from src.factors.api import close, ts_delay, cs_rank
from src.factors.dsl import FunctionNode
from src.factors.engine.ast_optimizer import ExpressionFlattener
def create_mock_data(
start_date: str = "20240101",
end_date: str = "20240131",
n_stocks: int = 5,
) -> pl.DataFrame:
"""创建模拟的日线数据。"""
start = datetime.strptime(start_date, "%Y%m%d")
end = datetime.strptime(end_date, "%Y%m%d")
dates = []
current = start
while current <= end:
if current.weekday() < 5: # 周一到周五
dates.append(current.strftime("%Y%m%d"))
current += timedelta(days=1)
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
np.random.seed(42)
rows = []
for date in dates:
for stock in stocks:
base_price = 10 + np.random.randn() * 5
close_val = base_price + np.random.randn() * 0.5
open_val = close_val + np.random.randn() * 0.2
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
vol = int(1000000 + np.random.exponential(500000))
rows.append(
{
"ts_code": stock,
"trade_date": date,
"open": round(open_val, 2),
"high": round(high_val, 2),
"low": round(low_val, 2),
"close": round(close_val, 2),
"volume": vol,
}
)
return pl.DataFrame(rows)
class TestASTOptimizer:
"""AST 优化器测试类。"""
def test_flattener_basic(self):
"""测试拍平器基本功能。"""
from src.factors.api import close
flattener = ExpressionFlattener()
# 创建嵌套表达式: cs_rank(ts_delay(close, 1))
expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证临时因子被提取
assert len(tmp_factors) == 1
assert "__tmp_0" in tmp_factors
# 验证主表达式使用了 Symbol 引用
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
# 验证第一个参数是临时因子引用(通过 name 属性检查)
assert hasattr(flat_expr.args[0], "name")
assert flat_expr.args[0].name == "__tmp_0"
# 验证临时因子内容
tmp_node = tmp_factors["__tmp_0"]
assert isinstance(tmp_node, FunctionNode)
assert tmp_node.func_name == "ts_delay"
print("[PASS] 拍平器基本功能测试")
def test_flattener_no_nested(self):
"""测试非嵌套表达式不会被拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 非嵌套表达式: ts_mean(close, 20)
expr = FunctionNode("ts_mean", close, 20)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证没有临时因子被提取
assert len(tmp_factors) == 0
# 验证表达式保持不变
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "ts_mean"
print("[PASS] 非嵌套表达式测试")
def test_flattener_deeply_nested(self):
"""测试多层嵌套表达式拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5))
expr = FunctionNode(
"cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证提取了两个临时因子(修复后正确行为)
# ts_delay(close, 1) 被提取为 __tmp_0
# ts_mean(__tmp_0, 5) 被提取为 __tmp_1
assert len(tmp_factors) == 2
assert "__tmp_0" in tmp_factors
assert "__tmp_1" in tmp_factors
# 验证 __tmp_0 内容是 ts_delay(close, 1)
tmp0_node = tmp_factors["__tmp_0"]
assert isinstance(tmp0_node, FunctionNode)
assert tmp0_node.func_name == "ts_delay"
# 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5)
tmp1_node = tmp_factors["__tmp_1"]
assert isinstance(tmp1_node, FunctionNode)
assert tmp1_node.func_name == "ts_mean"
from src.factors.dsl import Symbol
assert isinstance(tmp1_node.args[0], Symbol)
assert tmp1_node.args[0].name == "__tmp_0"
# 验证主表达式引用 __tmp_1
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
assert isinstance(flat_expr.args[0], Symbol)
assert flat_expr.args[0].name == "__tmp_1"
print("[PASS] 多层嵌套表达式拍平测试")
def test_nested_window_function_engine(self):
"""测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。"""
print("\n" + "=" * 60)
print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 初始化引擎
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("引擎初始化完成")
# 3. 使用字符串表达式注册嵌套窗口函数
print("\n注册因子: cs_rank(ts_delay(close, 1))")
engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")
# 4. 检查临时因子是否被创建
registered_factors = engine.list_registered()
print(f"已注册因子: {registered_factors}")
# 验证有临时因子被创建
tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
assert len(tmp_factors) >= 1, "应该有临时因子被创建"
print(f"临时因子: {tmp_factors}")
# 5. 执行计算
print("\n执行计算...")
result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 6. 验证结果
assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"
# 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null
non_null_values = result["delayed_rank"].drop_nulls()
if len(non_null_values) > 0:
assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"
# 检查没有过多空值(考虑到开头的滞后期)
null_count = result["delayed_rank"].is_null().sum()
print(f"空值数量: {null_count}")
# 展示部分结果
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("嵌套窗口函数测试通过!")
print("=" * 60)
def test_multiple_nested_factors(self):
"""测试同时注册多个嵌套因子。"""
print("\n" + "=" * 60)
print("测试多个嵌套因子")
print("=" * 60)
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 注册多个嵌套因子(使用字符串表达式)
print("\n注册因子1: cs_rank(ts_delay(close, 1))")
engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")
print("注册因子2: ts_mean(cs_rank(close), 5)")
engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")
# 检查已注册因子
factors = engine.list_registered()
print(f"\n已注册因子: {factors}")
# 计算所有因子
result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")
assert "rank1" in result.columns
assert "rank_mean" in result.columns
print(f"\n结果行数: {len(result)}")
print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")
print("\n" + "=" * 60)
print("多个嵌套因子测试通过!")
print("=" * 60)
def test_nested_vs_native_polars(self):
"""对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。"""
print("\n" + "=" * 60)
print("对比测试cs_rank(ts_delay(close, 1)) vs 原生 Polars")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 使用 FactorEngine 计算嵌套因子
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
engine_result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"FactorEngine 结果: {len(engine_result)}")
# 3. 使用原生 Polars 计算(手动分步)
print("\n使用原生 Polars 手动计算...")
# 先计算 ts_delay(close, 1)
native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
[pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
)
# 再计算 cs_rank
native_result = native_result.with_columns(
[
(pl.col("delayed_close").rank() / pl.col("delayed_close").count())
.over("trade_date")
.alias("native_delayed_rank")
]
)
print(f"原生 Polars 结果: {len(native_result)}")
# 4. 合并结果进行对比
comparison = engine_result.join(
native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
on=["ts_code", "trade_date"],
how="inner",
)
# 5. 验证数值一致性(允许微小浮点误差)
diff = comparison.with_columns(
[
(pl.col("delayed_rank") - pl.col("native_delayed_rank"))
.abs()
.alias("diff")
]
)
max_diff = diff["diff"].max()
print(f"\n最大差异: {max_diff}")
# 过滤掉空值后比较(开头的滞后期会有空值)
non_null_diff = diff.filter(pl.col("diff").is_not_null())
assert non_null_diff["diff"].max() < 1e-10, (
f"数值差异过大: {non_null_diff['diff'].max()}"
)
print("\n" + "=" * 60)
print("数值一致性验证通过!")
print("=" * 60)
def test_factor_reference_factor(self):
"""测试因子引用另一个因子fac2 = cs_rank(fac1)。"""
print("\n" + "=" * 60)
print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
print("=" * 60)
# 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 1. 注册基础因子 fac1
print("\n注册基础因子 fac1 = ts_mean(close, 5)")
from src.factors.api import ts_mean
engine.register("fac1", ts_mean(close, 5))
# 2. 注册引用因子 fac2引用 fac1
print("注册引用因子 fac2 = cs_rank(fac1)")
engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子
# 3. 验证依赖关系
registered = engine.list_registered()
print(f"\n已注册因子: {registered}")
assert "fac1" in registered
assert "fac2" in registered
# 4. 执行计算
print("\n执行计算...")
result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 5. 验证结果
assert "fac1" in result.columns, "结果中应有 fac1 列"
assert "fac2" in result.columns, "结果中应有 fac2 列"
# fac2 是排名,应在 [0, 1] 之间
assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间"
assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间"
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("因子引用功能测试通过!")
print("=" * 60)
if __name__ == "__main__":
test = TestASTOptimizer()
test.test_flattener_basic()
test.test_flattener_no_nested()
test.test_flattener_deeply_nested()
test.test_nested_window_function_engine()
test.test_multiple_nested_factors()
test.test_nested_vs_native_polars()
test.test_factor_reference_factor()
print("\n所有测试通过!")