Files
ProStock/tests/debug/test_bug_fixes.py
liaozhaorun 31b25074c3 test(debug): 添加因子回测一致性问题的调试测试套件
- 分析GTJA_alpha032等因子在不同LOOKBACK_DAYS下的差异来源
- 验证cs_rank嵌套和截面股票数量对结果的影响
- 测试ts_rank NaN处理和除法除零修复
2026-03-22 02:43:23 +08:00

147 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
验证 ts_rank 和除法 Bug 修复的测试
"""
import polars as pl
import numpy as np
from src.factors import FactorEngine
def _translate(formula):
    """Parse *formula* and translate it into a Polars expression.

    Uses a fresh ``FunctionRegistry`` for the parser and ``PolarsTranslator``
    for the AST, exactly as both tests previously did inline.
    """
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry
    from src.factors.translator import PolarsTranslator

    ast = FormulaParser(FunctionRegistry()).parse(formula)
    return PolarsTranslator().translate(ast)


def test_ts_rank_nan_handling():
    """ts_rank must propagate a NaN input as a missing value, never as 0.0.

    Builds two 5-day price series (stock 1 has a NaN on day 3), evaluates
    ``ts_rank(close, 3)`` through the formula translator, and inspects the
    value produced on the NaN row. Both NaN and null count as "missing".

    Raises:
        AssertionError: if the fixture has no NaN row, or the NaN row's
            ts_rank value is 0.0 or any other non-missing value.
    """
    print("=" * 60)
    print("测试 ts_rank NaN 处理")
    print("=" * 60)

    # Distinct, increasing dates so each stock forms a genuine time series
    # for the rolling window.
    dates = [f"2024010{day}" for day in range(1, 6)]
    df = pl.DataFrame(
        {
            "trade_date": dates * 2,
            "ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5,
            "close": [
                1.0, 2.0, np.nan, 4.0, 5.0,  # stock 1: 3rd value is NaN
                2.0, 3.0, 4.0, 5.0, 6.0,  # stock 2: clean data
            ],
        }
    )
    print("输入数据:")
    print(df)

    result = df.with_columns([_translate("ts_rank(close, 3)").alias("ts_rank_result")])
    print("\nts_rank 结果:")
    print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"]))

    nan_rows = result.filter(
        (pl.col("ts_code") == "000001.SZ") & pl.col("close").is_nan()
    )
    # Fixture sanity check: without a NaN row the test would prove nothing.
    assert nan_rows.height > 0, "没有找到 NaN 输入行"

    rank_value = nan_rows["ts_rank_result"][0]
    # A null result is also a correct "missing" marker. Check it first,
    # because np.isnan(None) would raise TypeError instead of reporting.
    if rank_value is None:
        print("\n[PASS] NaN 输入产生了 Null 输出")
        return
    assert rank_value != 0.0, "NaN 输入被错误转换为 0.0"
    assert np.isnan(rank_value), f"意外的输出值: {rank_value}"
    print("\n[PASS] NaN 输入正确产生了 NaN 输出")


def test_division_by_zero():
    """Division by zero must not yield NaN/inf; null is the preferred result.

    Also checks that rows with a non-zero denominator still receive the
    ordinary quotient, so a zero-guard cannot silently break normal division.
    Any finite value on the zero row is tolerated, with a warning printed.

    Raises:
        AssertionError: if a normal quotient is wrong, the fixture has no
            zero-denominator row, or that row evaluates to NaN/inf.
    """
    print("\n" + "=" * 60)
    print("测试除法除零处理")
    print("=" * 60)

    df = pl.DataFrame(
        {
            "trade_date": ["20240101", "20240101", "20240101"],
            "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
            "numerator": [100.0, 100.0, 100.0],
            "denominator": [10.0, 0.0, 5.0],  # the middle row divides by zero
        }
    )
    print("输入数据:")
    print(df)

    result = df.with_columns(
        [_translate("numerator / denominator").alias("division_result")]
    )
    print("\n除法结果:")
    print(result.select(["ts_code", "numerator", "denominator", "division_result"]))

    # Non-zero denominators must still produce the ordinary quotient.
    normal_rows = result.filter(pl.col("denominator") != 0.0).select(
        ["numerator", "denominator", "division_result"]
    )
    for num, den, got in normal_rows.iter_rows():
        assert got is not None and np.isclose(got, num / den, rtol=1e-6), (
            f"非零除法结果错误: {num} / {den} -> {got}"
        )

    zero_rows = result.filter(pl.col("denominator") == 0.0)
    # Fixture sanity check: the zero-denominator row must be present.
    assert zero_rows.height > 0, "没有找到除零行"

    div_value = zero_rows["division_result"][0]
    if div_value is None:
        print("\n[PASS] 除零正确产生了 Null")
        return
    assert np.isfinite(div_value), f"除零产生了 NaN/inf: {div_value}"
    # The requirement is only "no NaN/inf"; another finite value may be
    # intentional, so it is reported but not failed.
    print(f"\n[?] 除零产生了: {div_value} (可能是其他有效值)")


def _run_and_report(test_func):
    """Run *test_func* outside pytest and return True if it passed.

    An ``AssertionError`` is printed as ``[FAIL]`` with its message rather
    than aborting the script, so the summary is always printed. Any other
    exception propagates, since it signals an error rather than a failure.
    """
    try:
        test_func()
    except AssertionError as exc:
        print(f"\n[FAIL] {exc}")
        return False
    return True


if __name__ == "__main__":
    test1_pass = _run_and_report(test_ts_rank_nan_handling)
    test2_pass = _run_and_report(test_division_by_zero)
    print("\n" + "=" * 60)
    print("测试结果汇总")
    print("=" * 60)
    print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}")
    print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}")
    if test1_pass and test2_pass:
        print("\n所有测试通过!")
    else:
        print("\n部分测试失败,请检查实现")