test(debug): 添加因子回测一致性问题的调试测试套件
- 分析GTJA_alpha032等因子在不同LOOKBACK_DAYS下的差异来源 - 验证cs_rank嵌套和截面股票数量对结果的影响 - 测试ts_rank NaN处理和除法除零修复
This commit is contained in:
146
tests/debug/test_bug_fixes.py
Normal file
146
tests/debug/test_bug_fixes.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
验证 ts_rank 和除法 Bug 修复的测试
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from src.factors import FactorEngine
|
||||
|
||||
|
||||
def test_ts_rank_nan_handling():
|
||||
"""测试 ts_rank 正确处理 NaN 值,不应该将 NaN 转换为 0.0"""
|
||||
print("=" * 60)
|
||||
print("测试 ts_rank NaN 处理")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建包含 NaN 的测试数据
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"trade_date": ["20240101", "20240101", "20240101", "20240101", "20240101"]
|
||||
* 2,
|
||||
"ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5,
|
||||
"close": [
|
||||
1.0,
|
||||
2.0,
|
||||
np.nan,
|
||||
4.0,
|
||||
5.0, # 股票1:第3个是NaN
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
5.0,
|
||||
6.0,
|
||||
], # 股票2:正常数据
|
||||
}
|
||||
)
|
||||
|
||||
print("输入数据:")
|
||||
print(df)
|
||||
|
||||
# 使用引擎计算 ts_rank
|
||||
engine = FactorEngine()
|
||||
# 手动设置数据(这里用简单方式)
|
||||
|
||||
# 直接测试 translator
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.factors.parser import FormulaParser
|
||||
from src.factors.registry import FunctionRegistry
|
||||
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
ast = parser.parse("ts_rank(close, 3)")
|
||||
|
||||
translator = PolarsTranslator()
|
||||
expr = translator.translate(ast)
|
||||
|
||||
result = df.with_columns([expr.alias("ts_rank_result")])
|
||||
|
||||
print("\nts_rank 结果:")
|
||||
print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"]))
|
||||
|
||||
# 验证:NaN 输入应该产生 NaN 输出,而不是 0.0
|
||||
stock1_data = result.filter(pl.col("ts_code") == "000001.SZ")
|
||||
nan_input_row = stock1_data.filter(pl.col("close").is_nan())
|
||||
|
||||
if nan_input_row.height > 0:
|
||||
rank_value = nan_input_row["ts_rank_result"][0]
|
||||
if np.isnan(rank_value):
|
||||
print("\n[PASS] NaN 输入正确产生了 NaN 输出")
|
||||
return True
|
||||
elif rank_value == 0.0:
|
||||
print(f"\n[FAIL] NaN 输入被错误转换为 0.0!")
|
||||
return False
|
||||
else:
|
||||
print(f"\n[FAIL] 意外的输出值: {rank_value}")
|
||||
return False
|
||||
else:
|
||||
print("\n[SKIP] 没有找到 NaN 输入行")
|
||||
return False
|
||||
|
||||
|
||||
def test_division_by_zero():
|
||||
"""测试除法处理除零情况,不应该产生 NaN/inf"""
|
||||
print("\n" + "=" * 60)
|
||||
print("测试除法除零处理")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建包含零的数据
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"trade_date": ["20240101", "20240101", "20240101"],
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"numerator": [100.0, 100.0, 100.0],
|
||||
"denominator": [10.0, 0.0, 5.0], # 中间那个是0
|
||||
}
|
||||
)
|
||||
|
||||
print("输入数据:")
|
||||
print(df)
|
||||
|
||||
# 直接测试 translator
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.factors.parser import FormulaParser
|
||||
from src.factors.registry import FunctionRegistry
|
||||
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
ast = parser.parse("numerator / denominator")
|
||||
|
||||
translator = PolarsTranslator()
|
||||
expr = translator.translate(ast)
|
||||
|
||||
result = df.with_columns([expr.alias("division_result")])
|
||||
|
||||
print("\n除法结果:")
|
||||
print(result.select(["ts_code", "numerator", "denominator", "division_result"]))
|
||||
|
||||
# 验证:除零应该产生 Null,而不是 NaN/inf
|
||||
zero_row = result.filter(pl.col("denominator") == 0.0)
|
||||
if zero_row.height > 0:
|
||||
div_value = zero_row["division_result"][0]
|
||||
if div_value is None:
|
||||
print("\n[PASS] 除零正确产生了 Null")
|
||||
return True
|
||||
elif np.isnan(div_value) or np.isinf(div_value):
|
||||
print(f"\n[FAIL] 除零产生了 NaN/inf: {div_value}")
|
||||
return False
|
||||
else:
|
||||
print(f"\n[?] 除零产生了: {div_value} (可能是其他有效值)")
|
||||
return True # 也可能是预期的行为
|
||||
else:
|
||||
print("\n[SKIP] 没有找到除零行")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test1_pass = test_ts_rank_nan_handling()
|
||||
test2_pass = test_division_by_zero()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试结果汇总")
|
||||
print("=" * 60)
|
||||
print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}")
|
||||
print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}")
|
||||
|
||||
if test1_pass and test2_pass:
|
||||
print("\n所有测试通过!")
|
||||
else:
|
||||
print("\n部分测试失败,请检查实现")
|
||||
Reference in New Issue
Block a user