"""Integration tests for the Phase 1-2 factor functions.

Every newly implemented function is evaluated through its string factor
expression (DSL) form and compared against a reference computed directly
with Polars (or with TA-Lib for the TA-Lib backed functions).

Coverage:
1. Math functions: atan, log1p
2. Statistical functions: ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
3. TA-Lib functions: ts_atr, ts_rsi, ts_obv
"""

import numpy as np
import polars as pl
import pytest

from src.factors import FormulaParser, FunctionRegistry
from src.factors.translator import PolarsTranslator, HAS_TALIB
from src.factors.engine import FactorEngine
from src.data.catalog import DatabaseCatalog


# ============== Test data helpers ==============


def create_test_data() -> pl.DataFrame:
    """Build a synthetic multi-stock, multi-day price frame.

    The frame has one row per (stock, calendar day) for four stock codes
    over 2024-01-01 .. 2024-01-31, with the columns ``ts_code``,
    ``trade_date``, ``close``, ``open``, ``high``, ``low`` and ``vol``.
    The global NumPy RNG is seeded with 42, so the output is deterministic.

    Returns:
        The generated ``pl.DataFrame``.
    """
    np.random.seed(42)
    dates = pl.date_range(
        start=pl.date(2024, 1, 1),
        end=pl.date(2024, 1, 31),
        interval="1d",
        eager=True,
    )
    stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"]

    data = []
    for stock in stocks:
        base_price = 100 + np.random.randn() * 10
        for i, date in enumerate(dates):
            # Random noise around the base price plus a small upward drift.
            price = base_price + np.random.randn() * 5 + i * 0.1
            data.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "close": price,
                    "open": price * (1 + np.random.randn() * 0.01),
                    "high": price * (1 + abs(np.random.randn()) * 0.02),
                    "low": price * (1 - abs(np.random.randn()) * 0.02),
                    "vol": int(1000000 + np.random.randn() * 500000),
                }
            )
    return pl.DataFrame(data)


def _trade_dates(n_days: int, n_stocks: int = 1) -> list:
    """Return the dates 2024-01-01 .. 2024-01-``n_days``, repeated per stock.

    Args:
        n_days: Day-of-month of the last date (must be a valid January day).
        n_stocks: How many times the date sequence is repeated back to back.
    """
    days = pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, n_days), eager=True)
    return list(days) * n_stocks


def _evaluate_formula(df: pl.DataFrame, formula: str) -> np.ndarray:
    """Evaluate a DSL ``formula`` on ``df`` and return the values as an array.

    The formula is parsed with ``FormulaParser(FunctionRegistry())``,
    translated with ``PolarsTranslator`` and attached as the column
    ``dsl_result``; the column is read back through pandas.
    """
    parser = FormulaParser(FunctionRegistry())
    polars_expr = PolarsTranslator().translate(parser.parse(formula))
    return df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"].values


def _evaluate_reference(df: pl.DataFrame, expr: pl.Expr) -> np.ndarray:
    """Evaluate a hand-written Polars ``expr`` on ``df`` as the reference.

    The expression is attached as the column ``pl_result`` and read back
    through pandas, mirroring ``_evaluate_formula``.
    """
    return df.with_columns(pl_result=expr).to_pandas()["pl_result"].values


def _per_stock_expected(df: pl.DataFrame, compute, stocks=("A", "B")) -> np.ndarray:
    """Apply ``compute`` to each stock's rows and concatenate the results.

    Args:
        df: Frame with a ``ts_code`` column whose rows are ordered by stock
            in the same order as ``stocks``.
        compute: Callable taking one stock's pandas DataFrame and returning
            an array with one value per row.
        stocks: Stock codes to process, in row order.
    """
    chunks = []
    for stock in stocks:
        stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
        chunks.append(compute(stock_df))
    return np.concatenate(chunks)


# ============== Math function tests ==============


def test_atan_function() -> None:
    """atan: DSL result must equal Polars ``arctan``."""
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": _trade_dates(5),
            "value": [0.0, 1.0, -1.0, 0.5, -0.5],
        }
    )

    result_dsl = _evaluate_formula(df, "atan(value)")
    result_pl = _evaluate_reference(df, pl.col("value").arctan())

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


def test_log1p_function() -> None:
    """log1p: DSL result must equal Polars ``log1p`` (log(1 + x))."""
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": _trade_dates(5),
            "value": [0.0, 0.1, -0.1, 1.0, -0.5],
        }
    )

    result_dsl = _evaluate_formula(df, "log1p(value)")
    result_pl = _evaluate_reference(df, pl.col("value").log1p())

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


# ============== Statistical function tests ==============


def test_ts_var_function() -> None:
    """ts_var: rolling variance per stock must equal ``rolling_var``."""
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": _trade_dates(10, n_stocks=2),
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_var(close, 5)")
    result_pl = _evaluate_reference(
        df, pl.col("close").rolling_var(window_size=5).over("ts_code")
    )

    # Both sides keep the frame's row order, so comparing the full column
    # checks every stock at once.
    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


def test_ts_skew_function() -> None:
    """ts_skew: rolling skewness per stock must equal ``rolling_skew``."""
    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": _trade_dates(20, n_stocks=2),
            "close": np.random.randn(40),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_skew(close, 10)")
    result_pl = _evaluate_reference(
        df, pl.col("close").rolling_skew(window_size=10).over("ts_code")
    )

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


def test_ts_kurt_function() -> None:
    """ts_kurt: rolling kurtosis per stock must equal a ``rolling_map`` reference."""
    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": _trade_dates(20, n_stocks=2),
            "close": np.random.randn(40),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_kurt(close, 10)")
    # Reference: kurtosis of each window, NaN when fewer than 4 non-null values.
    result_pl = _evaluate_reference(
        df,
        pl.col("close")
        .rolling_map(
            lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
            window_size=10,
        )
        .over("ts_code"),
    )

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


def test_ts_pct_change_function() -> None:
    """ts_pct_change: one-period percentage change computed per stock."""
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5 + ["B"] * 5,
            "trade_date": _trade_dates(5, n_stocks=2),
            "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
        }
    )

    result_dsl = _evaluate_formula(df, "ts_pct_change(close, 1)")

    # The window must wrap the WHOLE ratio. Applying ``.over`` only to the
    # denominator would leave the numerator's shift ungrouped, so it would
    # read the previous stock's last close at each group boundary.
    prev_close = pl.col("close").shift(1)
    result_pl = _evaluate_reference(
        df, ((pl.col("close") - prev_close) / prev_close).over("ts_code")
    )

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


def test_ts_ema_function() -> None:
    """ts_ema: exponential moving average per stock must equal ``ewm_mean``."""
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": _trade_dates(10, n_stocks=2),
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_ema(close, 5)")
    result_pl = _evaluate_reference(
        df, pl.col("close").ewm_mean(span=5).over("ts_code")
    )

    np.testing.assert_array_almost_equal(result_dsl, result_pl, decimal=10)


# ============== TA-Lib function tests ==============


@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_atr_function() -> None:
    """ts_atr: average true range must match ``talib.ATR`` run per stock."""
    import talib

    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": _trade_dates(20, n_stocks=2),
            "high": 100 + np.random.randn(40) * 2,
            "low": 98 + np.random.randn(40) * 2,
            "close": 99 + np.random.randn(40) * 2,
        }
    )

    result_dsl = _evaluate_formula(df, "ts_atr(high, low, close, 14)")
    result_expected = _per_stock_expected(
        df,
        lambda stock_df: talib.ATR(
            stock_df["high"].values,
            stock_df["low"].values,
            stock_df["close"].values,
            timeperiod=14,
        ),
    )

    # A looser tolerance is allowed for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(result_dsl, result_expected, decimal=5)


@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_rsi_function() -> None:
    """ts_rsi: relative strength index must match ``talib.RSI`` run per stock."""
    import talib

    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": _trade_dates(30, n_stocks=2),
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_rsi(close, 14)")
    result_expected = _per_stock_expected(
        df,
        lambda stock_df: talib.RSI(stock_df["close"].values, timeperiod=14),
    )

    # A looser tolerance is allowed for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(result_dsl, result_expected, decimal=5)


@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_obv_function() -> None:
    """ts_obv: on-balance volume must match ``talib.OBV`` run per stock."""
    import talib

    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": _trade_dates(20, n_stocks=2),
            "close": 100 + np.cumsum(np.random.randn(40)),
            "vol": np.random.randint(100000, 1000000, 40).astype(float),
        }
    )

    result_dsl = _evaluate_formula(df, "ts_obv(close, vol)")
    result_expected = _per_stock_expected(
        df,
        lambda stock_df: talib.OBV(
            stock_df["close"].values,
            stock_df["vol"].values,
        ),
    )

    # A looser tolerance is allowed for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(result_dsl, result_expected, decimal=5)


# ============== Combined expression test ==============


def test_complex_factor_expressions() -> None:
    """A nested expression mixing atan, ts_ema and ts_delay must evaluate."""
    parser = FormulaParser(FunctionRegistry())

    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": _trade_dates(30, n_stocks=2),
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    # act_factor1:
    # atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50
    expr = parser.parse(
        "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
    )
    polars_expr = PolarsTranslator().translate(expr)
    result = df.with_columns(factor=polars_expr)

    # The factor column exists and has one row per input row.
    assert len(result) == 60
    assert "factor" in result.columns
    # Shape alone would also pass for an all-null / all-NaN column, so require
    # that the expression produced at least one real number.
    assert np.isfinite(result["factor"].to_numpy()).any()

    print("复杂因子表达式测试通过")


# ============== Script entry point ==============

if __name__ == "__main__":
    print("运行 Phase 1-2 因子函数测试...")
    print("=" * 80)

    # Math function tests
    print("\n[数学函数测试]")
    test_atan_function()
    print(" ✅ atan 测试通过")
    test_log1p_function()
    print(" ✅ log1p 测试通过")

    # Statistical function tests
    print("\n[统计函数测试]")
    test_ts_var_function()
    print(" ✅ ts_var 测试通过")
    test_ts_skew_function()
    print(" ✅ ts_skew 测试通过")
    test_ts_kurt_function()
    print(" ✅ ts_kurt 测试通过")
    test_ts_pct_change_function()
    print(" ✅ ts_pct_change 测试通过")
    test_ts_ema_function()
    print(" ✅ ts_ema 测试通过")

    # TA-Lib function tests. Use the HAS_TALIB flag imported from the
    # translator (the same one the skipif decorators use) instead of
    # re-probing the import and rebinding the module-level name.
    print("\n[TA-Lib 函数测试]")
    if HAS_TALIB:
        test_ts_atr_function()
        print(" ✅ ts_atr 测试通过")
        test_ts_rsi_function()
        print(" ✅ ts_rsi 测试通过")
        test_ts_obv_function()
        print(" ✅ ts_obv 测试通过")
    else:
        print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")

    # Combined expression test
    print("\n[综合测试]")
    test_complex_factor_expressions()
    print(" ✅ 复杂因子表达式测试通过")

    print("\n" + "=" * 80)
    print("所有测试通过!")