"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。 测试因子: cs_rank(ts_delay(close, 1)) 这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。 """ import pytest import polars as pl import numpy as np from datetime import datetime, timedelta from src.factors.engine import FactorEngine from src.factors.api import close, ts_delay, cs_rank from src.factors.dsl import FunctionNode from src.factors.engine.ast_optimizer import ExpressionFlattener def create_mock_data( start_date: str = "20240101", end_date: str = "20240131", n_stocks: int = 5, ) -> pl.DataFrame: """创建模拟的日线数据。""" start = datetime.strptime(start_date, "%Y%m%d") end = datetime.strptime(end_date, "%Y%m%d") dates = [] current = start while current <= end: if current.weekday() < 5: # 周一到周五 dates.append(current.strftime("%Y%m%d")) current += timedelta(days=1) stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] np.random.seed(42) rows = [] for date in dates: for stock in stocks: base_price = 10 + np.random.randn() * 5 close_val = base_price + np.random.randn() * 0.5 open_val = close_val + np.random.randn() * 0.2 high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 vol = int(1000000 + np.random.exponential(500000)) rows.append( { "ts_code": stock, "trade_date": date, "open": round(open_val, 2), "high": round(high_val, 2), "low": round(low_val, 2), "close": round(close_val, 2), "volume": vol, } ) return pl.DataFrame(rows) class TestASTOptimizer: """AST 优化器测试类。""" def test_flattener_basic(self): """测试拍平器基本功能。""" from src.factors.api import close flattener = ExpressionFlattener() # 创建嵌套表达式: cs_rank(ts_delay(close, 1)) expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1)) flat_expr, tmp_factors = flattener.flatten(expr) # 验证临时因子被提取 assert len(tmp_factors) == 1 assert "__tmp_0" in tmp_factors # 验证主表达式使用了 Symbol 引用 assert isinstance(flat_expr, FunctionNode) assert flat_expr.func_name == "cs_rank" # 验证第一个参数是临时因子引用(通过 name 属性检查) assert hasattr(flat_expr.args[0], "name") assert flat_expr.args[0].name == "__tmp_0" # 验证临时因子内容 tmp_node = tmp_factors["__tmp_0"] assert isinstance(tmp_node, FunctionNode) assert tmp_node.func_name == "ts_delay" print("[PASS] 拍平器基本功能测试") def test_flattener_no_nested(self): """测试非嵌套表达式不会被拍平。""" from src.factors.api import close, ts_mean flattener = ExpressionFlattener() # 非嵌套表达式: ts_mean(close, 20) expr = FunctionNode("ts_mean", close, 20) flat_expr, tmp_factors = flattener.flatten(expr) # 验证没有临时因子被提取 assert len(tmp_factors) == 0 # 验证表达式保持不变 assert isinstance(flat_expr, FunctionNode) assert flat_expr.func_name == "ts_mean" print("[PASS] 非嵌套表达式测试") def test_flattener_deeply_nested(self): """测试多层嵌套表达式拍平。""" from src.factors.api import close, ts_mean flattener = ExpressionFlattener() # 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5)) expr = FunctionNode( "cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5) ) flat_expr, tmp_factors = flattener.flatten(expr) # 验证提取了两个临时因子(修复后正确行为) # ts_delay(close, 1) 被提取为 __tmp_0 # ts_mean(__tmp_0, 5) 被提取为 __tmp_1 assert len(tmp_factors) == 2 assert "__tmp_0" in tmp_factors assert "__tmp_1" in tmp_factors # 验证 __tmp_0 内容是 ts_delay(close, 1) tmp0_node = tmp_factors["__tmp_0"] assert isinstance(tmp0_node, FunctionNode) assert tmp0_node.func_name == "ts_delay" # 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5) tmp1_node = tmp_factors["__tmp_1"] assert isinstance(tmp1_node, FunctionNode) assert tmp1_node.func_name == "ts_mean" from src.factors.dsl import Symbol assert isinstance(tmp1_node.args[0], Symbol) assert tmp1_node.args[0].name == "__tmp_0" # 验证主表达式引用 __tmp_1 assert isinstance(flat_expr, FunctionNode) assert flat_expr.func_name == "cs_rank" assert isinstance(flat_expr.args[0], Symbol) assert flat_expr.args[0].name == "__tmp_1" print("[PASS] 多层嵌套表达式拍平测试") def test_nested_window_function_engine(self): """测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。""" print("\n" + "=" * 60) print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))") print("=" * 60) # 1. 准备数据 mock_data = create_mock_data("20240101", "20240131", n_stocks=5) print(f"\n生成模拟数据: {len(mock_data)} 行") # 2. 初始化引擎 engine = FactorEngine(data_source={"pro_bar": mock_data}) print("引擎初始化完成") # 3. 使用字符串表达式注册嵌套窗口函数 print("\n注册因子: cs_rank(ts_delay(close, 1))") engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))") # 4. 检查临时因子是否被创建 registered_factors = engine.list_registered() print(f"已注册因子: {registered_factors}") # 验证有临时因子被创建 tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")] assert len(tmp_factors) >= 1, "应该有临时因子被创建" print(f"临时因子: {tmp_factors}") # 5. 执行计算 print("\n执行计算...") result = engine.compute("delayed_rank", "20240115", "20240131") print(f"计算完成: {len(result)} 行") # 6. 验证结果 assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列" # 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null) non_null_values = result["delayed_rank"].drop_nulls() if len(non_null_values) > 0: assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间" assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间" # 检查没有过多空值(考虑到开头的滞后期) null_count = result["delayed_rank"].is_null().sum() print(f"空值数量: {null_count}") # 展示部分结果 print("\n前 10 行结果:") sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head( 10 ) print(sample.to_pandas().to_string(index=False)) print("\n" + "=" * 60) print("嵌套窗口函数测试通过!") print("=" * 60) def test_multiple_nested_factors(self): """测试同时注册多个嵌套因子。""" print("\n" + "=" * 60) print("测试多个嵌套因子") print("=" * 60) mock_data = create_mock_data("20240101", "20240131", n_stocks=5) engine = FactorEngine(data_source={"pro_bar": mock_data}) # 注册多个嵌套因子(使用字符串表达式) print("\n注册因子1: cs_rank(ts_delay(close, 1))") engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))") print("注册因子2: ts_mean(cs_rank(close), 5)") engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)") # 检查已注册因子 factors = engine.list_registered() print(f"\n已注册因子: {factors}") # 计算所有因子 result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131") assert "rank1" in result.columns assert "rank_mean" in result.columns print(f"\n结果行数: {len(result)}") print(f"rank1 空值数: {result['rank1'].is_null().sum()}") print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}") print("\n" + "=" * 60) print("多个嵌套因子测试通过!") print("=" * 60) def test_nested_vs_native_polars(self): """对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。""" print("\n" + "=" * 60) print("对比测试:cs_rank(ts_delay(close, 1)) vs 原生 Polars") print("=" * 60) # 1. 准备数据 mock_data = create_mock_data("20240101", "20240131", n_stocks=5) print(f"\n生成模拟数据: {len(mock_data)} 行") # 2. 使用 FactorEngine 计算嵌套因子 engine = FactorEngine(data_source={"pro_bar": mock_data}) print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...") engine.register("delayed_rank", cs_rank(ts_delay(close, 1))) engine_result = engine.compute("delayed_rank", "20240115", "20240131") print(f"FactorEngine 结果: {len(engine_result)} 行") # 3. 使用原生 Polars 计算(手动分步) print("\n使用原生 Polars 手动计算...") # 先计算 ts_delay(close, 1) native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns( [pl.col("close").shift(1).over("ts_code").alias("delayed_close")] ) # 再计算 cs_rank native_result = native_result.with_columns( [ (pl.col("delayed_close").rank() / pl.col("delayed_close").count()) .over("trade_date") .alias("native_delayed_rank") ] ) print(f"原生 Polars 结果: {len(native_result)} 行") # 4. 合并结果进行对比 comparison = engine_result.join( native_result.select(["ts_code", "trade_date", "native_delayed_rank"]), on=["ts_code", "trade_date"], how="inner", ) # 5. 验证数值一致性(允许微小浮点误差) diff = comparison.with_columns( [ (pl.col("delayed_rank") - pl.col("native_delayed_rank")) .abs() .alias("diff") ] ) max_diff = diff["diff"].max() print(f"\n最大差异: {max_diff}") # 过滤掉空值后比较(开头的滞后期会有空值) non_null_diff = diff.filter(pl.col("diff").is_not_null()) assert non_null_diff["diff"].max() < 1e-10, ( f"数值差异过大: {non_null_diff['diff'].max()}" ) print("\n" + "=" * 60) print("数值一致性验证通过!") print("=" * 60) def test_factor_reference_factor(self): """测试因子引用另一个因子:fac2 = cs_rank(fac1)。""" print("\n" + "=" * 60) print("测试因子引用其他因子: fac2 = cs_rank(fac1)") print("=" * 60) # 准备数据 mock_data = create_mock_data("20240101", "20240131", n_stocks=5) engine = FactorEngine(data_source={"pro_bar": mock_data}) # 1. 注册基础因子 fac1 print("\n注册基础因子 fac1 = ts_mean(close, 5)") from src.factors.api import ts_mean engine.register("fac1", ts_mean(close, 5)) # 2. 注册引用因子 fac2,引用 fac1 print("注册引用因子 fac2 = cs_rank(fac1)") engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子 # 3. 验证依赖关系 registered = engine.list_registered() print(f"\n已注册因子: {registered}") assert "fac1" in registered assert "fac2" in registered # 4. 执行计算 print("\n执行计算...") result = engine.compute(["fac1", "fac2"], "20240115", "20240131") print(f"计算完成: {len(result)} 行") # 5. 验证结果 assert "fac1" in result.columns, "结果中应有 fac1 列" assert "fac2" in result.columns, "结果中应有 fac2 列" # fac2 是排名,应在 [0, 1] 之间 assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间" assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间" print("\n前 10 行结果:") sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head( 10 ) print(sample.to_pandas().to_string(index=False)) print("\n" + "=" * 60) print("因子引用功能测试通过!") print("=" * 60) if __name__ == "__main__": test = TestASTOptimizer() test.test_flattener_basic() test.test_flattener_no_nested() test.test_flattener_deeply_nested() test.test_nested_window_function_engine() test.test_multiple_nested_factors() test.test_nested_vs_native_polars() test.test_factor_reference_factor() print("\n所有测试通过!")