""" 验证 ts_rank 和除法 Bug 修复的测试 """ import polars as pl import numpy as np from src.factors import FactorEngine def test_ts_rank_nan_handling(): """测试 ts_rank 正确处理 NaN 值,不应该将 NaN 转换为 0.0""" print("=" * 60) print("测试 ts_rank NaN 处理") print("=" * 60) # 创建包含 NaN 的测试数据 df = pl.DataFrame( { "trade_date": ["20240101", "20240101", "20240101", "20240101", "20240101"] * 2, "ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5, "close": [ 1.0, 2.0, np.nan, 4.0, 5.0, # 股票1:第3个是NaN 2.0, 3.0, 4.0, 5.0, 6.0, ], # 股票2:正常数据 } ) print("输入数据:") print(df) # 使用引擎计算 ts_rank engine = FactorEngine() # 手动设置数据(这里用简单方式) # 直接测试 translator from src.factors.translator import PolarsTranslator from src.factors.parser import FormulaParser from src.factors.registry import FunctionRegistry parser = FormulaParser(FunctionRegistry()) ast = parser.parse("ts_rank(close, 3)") translator = PolarsTranslator() expr = translator.translate(ast) result = df.with_columns([expr.alias("ts_rank_result")]) print("\nts_rank 结果:") print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"])) # 验证:NaN 输入应该产生 NaN 输出,而不是 0.0 stock1_data = result.filter(pl.col("ts_code") == "000001.SZ") nan_input_row = stock1_data.filter(pl.col("close").is_nan()) if nan_input_row.height > 0: rank_value = nan_input_row["ts_rank_result"][0] if np.isnan(rank_value): print("\n[PASS] NaN 输入正确产生了 NaN 输出") return True elif rank_value == 0.0: print(f"\n[FAIL] NaN 输入被错误转换为 0.0!") return False else: print(f"\n[FAIL] 意外的输出值: {rank_value}") return False else: print("\n[SKIP] 没有找到 NaN 输入行") return False def test_division_by_zero(): """测试除法处理除零情况,不应该产生 NaN/inf""" print("\n" + "=" * 60) print("测试除法除零处理") print("=" * 60) # 创建包含零的数据 df = pl.DataFrame( { "trade_date": ["20240101", "20240101", "20240101"], "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"], "numerator": [100.0, 100.0, 100.0], "denominator": [10.0, 0.0, 5.0], # 中间那个是0 } ) print("输入数据:") print(df) # 直接测试 translator from src.factors.translator import PolarsTranslator from src.factors.parser import FormulaParser from src.factors.registry import FunctionRegistry parser = FormulaParser(FunctionRegistry()) ast = parser.parse("numerator / denominator") translator = PolarsTranslator() expr = translator.translate(ast) result = df.with_columns([expr.alias("division_result")]) print("\n除法结果:") print(result.select(["ts_code", "numerator", "denominator", "division_result"])) # 验证:除零应该产生 Null,而不是 NaN/inf zero_row = result.filter(pl.col("denominator") == 0.0) if zero_row.height > 0: div_value = zero_row["division_result"][0] if div_value is None: print("\n[PASS] 除零正确产生了 Null") return True elif np.isnan(div_value) or np.isinf(div_value): print(f"\n[FAIL] 除零产生了 NaN/inf: {div_value}") return False else: print(f"\n[?] 除零产生了: {div_value} (可能是其他有效值)") return True # 也可能是预期的行为 else: print("\n[SKIP] 没有找到除零行") return False if __name__ == "__main__": test1_pass = test_ts_rank_nan_handling() test2_pass = test_division_by_zero() print("\n" + "=" * 60) print("测试结果汇总") print("=" * 60) print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}") print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}") if test1_pass and test2_pass: print("\n所有测试通过!") else: print("\n部分测试失败,请检查实现")