Files
ProStock/tests/test_ast_optimizer.py
liaozhaorun c8808d07eb feat(factors): 实现 AST 拍平优化支持嵌套窗口函数
- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
2026-03-14 01:06:17 +08:00

368 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
测试因子: cs_rank(ts_delay(close, 1))
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine
from src.factors.api import close, ts_delay, cs_rank
from src.factors.dsl import FunctionNode
from src.factors.engine.ast_optimizer import ExpressionFlattener
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
) -> pl.DataFrame:
    """Build a deterministic table of mock daily bars.

    One row is generated for every (weekday, stock) pair between
    ``start_date`` and ``end_date`` (both inclusive); Saturdays and Sundays
    are skipped. Prices are drawn from a privately seeded random stream, so
    repeated calls with the same arguments yield the same rows and the
    process-wide NumPy random state is left untouched.

    Args:
        start_date: First calendar day, formatted as ``YYYYMMDD``.
        end_date: Last calendar day (inclusive), formatted as ``YYYYMMDD``.
        n_stocks: Number of synthetic stock codes (``600000.SH``,
            ``600001.SH``, ...).

    Returns:
        A DataFrame built from row dicts with the keys ``ts_code``,
        ``trade_date``, ``open``, ``high``, ``low``, ``close`` and
        ``volume``.

    Raises:
        ValueError: If either date string does not match ``%Y%m%d``.
    """
    first_day = datetime.strptime(start_date, "%Y%m%d")
    last_day = datetime.strptime(end_date, "%Y%m%d")

    # Inclusive calendar span; a reversed range yields no days at all.
    span_days = (last_day - first_day).days + 1
    calendar = (first_day + timedelta(days=offset) for offset in range(span_days))
    # weekday() < 5 keeps Monday through Friday.
    trade_dates = [day.strftime("%Y%m%d") for day in calendar if day.weekday() < 5]

    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]

    # A private RandomState seeded with 42 produces the same stream as
    # np.random.seed(42) followed by the module-level functions, without
    # reseeding the global generator shared with other code.
    rng = np.random.RandomState(42)

    rows = []
    for trade_date in trade_dates:
        for ts_code in stocks:
            # The draw order below is part of the fixture: changing it
            # changes every generated value.
            base_price = 10 + rng.randn() * 5
            close_val = base_price + rng.randn() * 0.5
            open_val = close_val + rng.randn() * 0.2
            high_val = max(open_val, close_val) + abs(rng.randn()) * 0.3
            low_val = min(open_val, close_val) - abs(rng.randn()) * 0.3
            vol = int(1000000 + rng.exponential(500000))
            rows.append(
                {
                    "ts_code": ts_code,
                    "trade_date": trade_date,
                    "open": round(open_val, 2),
                    "high": round(high_val, 2),
                    "low": round(low_val, 2),
                    "close": round(close_val, 2),
                    "volume": vol,
                }
            )
    return pl.DataFrame(rows)
class TestASTOptimizer:
    """Tests for the AST optimizer's flattening of nested window functions."""

    def test_flattener_basic(self):
        """Flatten ``cs_rank(ts_delay(close, 1))`` into one temporary factor."""
        flattener = ExpressionFlattener()
        # Nested expression: cs_rank(ts_delay(close, 1))
        expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
        flat_expr, tmp_factors = flattener.flatten(expr)

        # Exactly one temporary factor must have been extracted.
        assert len(tmp_factors) == 1
        assert "__tmp_0" in tmp_factors

        # The outer call survives and now refers to the temporary by name.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"
        # The first argument is checked through its ``name`` attribute.
        assert hasattr(flat_expr.args[0], "name")
        assert flat_expr.args[0].name == "__tmp_0"

        # The temporary factor holds the inner ts_delay call.
        tmp_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp_node, FunctionNode)
        assert tmp_node.func_name == "ts_delay"
        print("[PASS] 拍平器基本功能测试")

    def test_flattener_no_nested(self):
        """A non-nested expression must pass through without temporaries."""
        flattener = ExpressionFlattener()
        # Non-nested expression: ts_mean(close, 20)
        expr = FunctionNode("ts_mean", close, 20)
        flat_expr, tmp_factors = flattener.flatten(expr)

        # Nothing should have been extracted.
        assert len(tmp_factors) == 0

        # The expression is still the same function call.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "ts_mean"
        print("[PASS] 非嵌套表达式测试")

    def test_flattener_deeply_nested(self):
        """Two levels of nesting must yield two chained temporary factors."""
        from src.factors.dsl import Symbol

        flattener = ExpressionFlattener()
        # Deep nesting: cs_rank(ts_mean(ts_delay(close, 1), 5))
        expr = FunctionNode(
            "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
        )
        flat_expr, tmp_factors = flattener.flatten(expr)

        # Expected extraction:
        #   ts_delay(close, 1)  -> __tmp_0
        #   ts_mean(__tmp_0, 5) -> __tmp_1
        assert len(tmp_factors) == 2
        assert "__tmp_0" in tmp_factors
        assert "__tmp_1" in tmp_factors

        # __tmp_0 holds ts_delay(close, 1).
        tmp0_node = tmp_factors["__tmp_0"]
        assert isinstance(tmp0_node, FunctionNode)
        assert tmp0_node.func_name == "ts_delay"

        # __tmp_1 holds ts_mean(__tmp_0, 5).
        tmp1_node = tmp_factors["__tmp_1"]
        assert isinstance(tmp1_node, FunctionNode)
        assert tmp1_node.func_name == "ts_mean"
        assert isinstance(tmp1_node.args[0], Symbol)
        assert tmp1_node.args[0].name == "__tmp_0"

        # The main expression refers to __tmp_1.
        assert isinstance(flat_expr, FunctionNode)
        assert flat_expr.func_name == "cs_rank"
        assert isinstance(flat_expr.args[0], Symbol)
        assert flat_expr.args[0].name == "__tmp_1"
        print("[PASS] 多层嵌套表达式拍平测试")

    def test_nested_window_function_engine(self):
        """Register ``cs_rank(ts_delay(close, 1))`` as a string and compute it."""
        print("\n" + "=" * 60)
        print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
        print("=" * 60)

        # 1. Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)}")

        # 2. Initialise the engine.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("引擎初始化完成")

        # 3. Register the nested window function from a string expression.
        print("\n注册因子: cs_rank(ts_delay(close, 1))")
        engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")

        # 4. A temporary factor must have been registered alongside it.
        registered_factors = engine.list_registered()
        print(f"已注册因子: {registered_factors}")
        tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
        assert len(tmp_factors) >= 1, "应该有临时因子被创建"
        print(f"临时因子: {tmp_factors}")

        # 5. Run the computation.
        print("\n执行计算...")
        result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"计算完成: {len(result)}")

        # 6. Validate the result.
        assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"

        # Ranks must lie in [0, 1]; nulls are tolerated because of the lag.
        # Require at least one non-null value so the range check cannot pass
        # vacuously on an all-null column.
        non_null_values = result["delayed_rank"].drop_nulls()
        assert len(non_null_values) > 0, "delayed_rank 不应全部为空"
        assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
        assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"

        # Report how many nulls the lag produced (informational only).
        null_count = result["delayed_rank"].is_null().sum()
        print(f"空值数量: {null_count}")

        # Show a sample of the output.
        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
            10
        )
        print(sample.to_pandas().to_string(index=False))
        print("\n" + "=" * 60)
        print("嵌套窗口函数测试通过!")
        print("=" * 60)

    def test_multiple_nested_factors(self):
        """Register two nested factors on one engine and compute both."""
        print("\n" + "=" * 60)
        print("测试多个嵌套因子")
        print("=" * 60)
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # Register several nested factors from string expressions.
        print("\n注册因子1: cs_rank(ts_delay(close, 1))")
        engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")
        print("注册因子2: ts_mean(cs_rank(close), 5)")
        engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")

        # List what ended up registered.
        factors = engine.list_registered()
        print(f"\n已注册因子: {factors}")

        # Compute both factors in a single call.
        result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")
        assert "rank1" in result.columns
        assert "rank_mean" in result.columns
        print(f"\n结果行数: {len(result)}")
        print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
        print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")
        print("\n" + "=" * 60)
        print("多个嵌套因子测试通过!")
        print("=" * 60)

    def test_nested_vs_native_polars(self):
        """Compare the engine's nested factor with a hand-written Polars pipeline."""
        print("\n" + "=" * 60)
        print("对比测试cs_rank(ts_delay(close, 1)) vs 原生 Polars")
        print("=" * 60)

        # 1. Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f"\n生成模拟数据: {len(mock_data)}")

        # 2. Compute the nested factor through FactorEngine.
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
        engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
        engine_result = engine.compute("delayed_rank", "20240115", "20240131")
        print(f"FactorEngine 结果: {len(engine_result)}")

        # 3. Compute the same thing step by step with plain Polars.
        print("\n使用原生 Polars 手动计算...")
        # First ts_delay(close, 1): shift by one row within each stock.
        native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
            [pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
        )
        # Then cs_rank: rank divided by count within each trade date.
        native_result = native_result.with_columns(
            [
                (pl.col("delayed_close").rank() / pl.col("delayed_close").count())
                .over("trade_date")
                .alias("native_delayed_rank")
            ]
        )
        print(f"原生 Polars 结果: {len(native_result)}")

        # 4. Join both results for comparison.
        comparison = engine_result.join(
            native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
            on=["ts_code", "trade_date"],
            how="inner",
        )

        # 5. Check numerical agreement, allowing tiny floating-point error.
        diff = comparison.with_columns(
            [
                (pl.col("delayed_rank") - pl.col("native_delayed_rank"))
                .abs()
                .alias("diff")
            ]
        )
        max_diff = diff["diff"].max()
        print(f"\n最大差异: {max_diff}")

        # Compare only non-null rows (the initial lag period yields nulls).
        non_null_diff = diff.filter(pl.col("diff").is_not_null())
        # Without this guard an empty comparison makes max() return None and
        # the test dies with a TypeError instead of a readable failure.
        assert len(non_null_diff) > 0, "没有可比较的非空结果"
        worst_diff = non_null_diff["diff"].max()
        assert worst_diff < 1e-10, f"数值差异过大: {worst_diff}"
        print("\n" + "=" * 60)
        print("数值一致性验证通过!")
        print("=" * 60)

    def test_factor_reference_factor(self):
        """A factor may reference another registered factor by name."""
        from src.factors.api import ts_mean

        print("\n" + "=" * 60)
        print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
        print("=" * 60)

        # Prepare data.
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        engine = FactorEngine(data_source={"pro_bar": mock_data})

        # 1. Register the base factor fac1.
        print("\n注册基础因子 fac1 = ts_mean(close, 5)")
        engine.register("fac1", ts_mean(close, 5))

        # 2. Register fac2, which refers to fac1 by its string name.
        print("注册引用因子 fac2 = cs_rank(fac1)")
        engine.register("fac2", cs_rank("fac1"))

        # 3. Both factors must be registered.
        registered = engine.list_registered()
        print(f"\n已注册因子: {registered}")
        assert "fac1" in registered
        assert "fac2" in registered

        # 4. Run the computation.
        print("\n执行计算...")
        result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
        print(f"计算完成: {len(result)}")

        # 5. Validate the result.
        assert "fac1" in result.columns, "结果中应有 fac1 列"
        assert "fac2" in result.columns, "结果中应有 fac2 列"

        # fac2 is a rank and must lie in [0, 1]. Drop nulls and require data
        # first, so an all-null column fails clearly rather than raising a
        # TypeError from comparing None with a number.
        fac2_values = result["fac2"].drop_nulls()
        assert len(fac2_values) > 0, "fac2 不应全部为空"
        assert fac2_values.min() >= 0, "排名应在 [0, 1] 之间"
        assert fac2_values.max() <= 1, "排名应在 [0, 1] 之间"

        print("\n前 10 行结果:")
        sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
            10
        )
        print(sample.to_pandas().to_string(index=False))
        print("\n" + "=" * 60)
        print("因子引用功能测试通过!")
        print("=" * 60)
if __name__ == "__main__":
    # Allow running the suite directly as a script: execute every test
    # method once, in declaration order, on a single instance.
    suite = TestASTOptimizer()
    for run_case in (
        suite.test_flattener_basic,
        suite.test_flattener_no_nested,
        suite.test_flattener_deeply_nested,
        suite.test_nested_window_function_engine,
        suite.test_multiple_nested_factors,
        suite.test_nested_vs_native_polars,
        suite.test_factor_reference_factor,
    ):
        run_case()
    print("\n所有测试通过!")