"""测试新增的时间序列函数和智能分发逻辑。""" import pytest import polars as pl import numpy as np from src.factors.dsl import Symbol, FunctionNode from src.factors.translator import PolarsTranslator def test_ts_sma_translate(): """测试 ts_sma 翻译正确。""" close = Symbol("close") expr = FunctionNode("ts_sma", close, 10, 5) translator = PolarsTranslator() result = translator.translate(expr) assert isinstance(result, pl.Expr) def test_ts_wma_translate(): """测试 ts_wma 翻译正确。""" close = Symbol("close") expr = FunctionNode("ts_wma", close, 20) translator = PolarsTranslator() result = translator.translate(expr) assert isinstance(result, pl.Expr) def test_ts_sumac_translate(): """测试 ts_sumac 翻译正确。""" close = Symbol("close") expr = FunctionNode("ts_sumac", close) translator = PolarsTranslator() result = translator.translate(expr) assert isinstance(result, pl.Expr) def test_max_intelligent_dispatch(): """测试 max_ 智能分发: int -> ts_max,其他 -> element-wise max。""" from src.factors.api import max_, close # 正整数 -> ts_max result = max_(close, 20) assert result.func_name == "ts_max" # 零或负数 -> element-wise max result = max_(close, 0) assert result.func_name == "max" result = max_(close, -1) assert result.func_name == "max" # 浮点数 -> element-wise max result = max_(close, 10.5) assert result.func_name == "max" def test_min_intelligent_dispatch(): """测试 min_ 智能分发: int -> ts_min,其他 -> element-wise min。""" from src.factors.api import min_, close # 正整数 -> ts_min result = min_(close, 20) assert result.func_name == "ts_min" # 零或负数 -> element-wise min result = min_(close, 0) assert result.func_name == "min" def create_test_data() -> pl.DataFrame: """创建测试数据。""" np.random.seed(42) n = 100 return pl.DataFrame( { "ts_code": ["000001.SZ"] * n, "trade_date": list(range(20240101, 20240101 + n)), "close": np.random.randn(n).cumsum() + 100, } ) def test_ts_sma_computation(): """测试 ts_sma 计算与原生 Polars 一致。""" df = create_test_data() translator = PolarsTranslator() # 翻译因子 close = Symbol("close") expr_node = FunctionNode("ts_sma", close, 10, 5) expr = translator.translate(expr_node) # 使用翻译后的表达式计算 result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")]) # 原生 Polars 计算 native = df.with_columns( [pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")] ) # 对比结果 assert np.allclose( result["ts_sma_result"].to_numpy()[9:], native["native_sma"].to_numpy()[9:], equal_nan=True, ) def test_ts_sumac_computation(): """测试 ts_sumac 计算与原生 Polars 一致。""" df = create_test_data() translator = PolarsTranslator() close = Symbol("close") expr_node = FunctionNode("ts_sumac", close) expr = translator.translate(expr_node) result = df.select( ["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")] ) native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")]) assert np.allclose( result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy() ) if __name__ == "__main__": pytest.main([__file__, "-v"])