Files
ProStock/tests/debug/test_bug_fixes.py

147 lines
4.4 KiB
Python
Raw Normal View History

"""
验证 ts_rank 和除法 Bug 修复的测试
"""
import polars as pl
import numpy as np
from src.factors import FactorEngine
def test_ts_rank_nan_handling():
"""测试 ts_rank 正确处理 NaN 值,不应该将 NaN 转换为 0.0"""
print("=" * 60)
print("测试 ts_rank NaN 处理")
print("=" * 60)
# 创建包含 NaN 的测试数据
df = pl.DataFrame(
{
"trade_date": ["20240101", "20240101", "20240101", "20240101", "20240101"]
* 2,
"ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5,
"close": [
1.0,
2.0,
np.nan,
4.0,
5.0, # 股票1第3个是NaN
2.0,
3.0,
4.0,
5.0,
6.0,
], # 股票2正常数据
}
)
print("输入数据:")
print(df)
# 使用引擎计算 ts_rank
engine = FactorEngine()
# 手动设置数据(这里用简单方式)
# 直接测试 translator
from src.factors.translator import PolarsTranslator
from src.factors.parser import FormulaParser
from src.factors.registry import FunctionRegistry
parser = FormulaParser(FunctionRegistry())
ast = parser.parse("ts_rank(close, 3)")
translator = PolarsTranslator()
expr = translator.translate(ast)
result = df.with_columns([expr.alias("ts_rank_result")])
print("\nts_rank 结果:")
print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"]))
# 验证NaN 输入应该产生 NaN 输出,而不是 0.0
stock1_data = result.filter(pl.col("ts_code") == "000001.SZ")
nan_input_row = stock1_data.filter(pl.col("close").is_nan())
if nan_input_row.height > 0:
rank_value = nan_input_row["ts_rank_result"][0]
if np.isnan(rank_value):
print("\n[PASS] NaN 输入正确产生了 NaN 输出")
return True
elif rank_value == 0.0:
print(f"\n[FAIL] NaN 输入被错误转换为 0.0")
return False
else:
print(f"\n[FAIL] 意外的输出值: {rank_value}")
return False
else:
print("\n[SKIP] 没有找到 NaN 输入行")
return False
def test_division_by_zero():
"""测试除法处理除零情况,不应该产生 NaN/inf"""
print("\n" + "=" * 60)
print("测试除法除零处理")
print("=" * 60)
# 创建包含零的数据
df = pl.DataFrame(
{
"trade_date": ["20240101", "20240101", "20240101"],
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
"numerator": [100.0, 100.0, 100.0],
"denominator": [10.0, 0.0, 5.0], # 中间那个是0
}
)
print("输入数据:")
print(df)
# 直接测试 translator
from src.factors.translator import PolarsTranslator
from src.factors.parser import FormulaParser
from src.factors.registry import FunctionRegistry
parser = FormulaParser(FunctionRegistry())
ast = parser.parse("numerator / denominator")
translator = PolarsTranslator()
expr = translator.translate(ast)
result = df.with_columns([expr.alias("division_result")])
print("\n除法结果:")
print(result.select(["ts_code", "numerator", "denominator", "division_result"]))
# 验证:除零应该产生 Null而不是 NaN/inf
zero_row = result.filter(pl.col("denominator") == 0.0)
if zero_row.height > 0:
div_value = zero_row["division_result"][0]
if div_value is None:
print("\n[PASS] 除零正确产生了 Null")
return True
elif np.isnan(div_value) or np.isinf(div_value):
print(f"\n[FAIL] 除零产生了 NaN/inf: {div_value}")
return False
else:
print(f"\n[?] 除零产生了: {div_value} (可能是其他有效值)")
return True # 也可能是预期的行为
else:
print("\n[SKIP] 没有找到除零行")
return False
if __name__ == "__main__":
test1_pass = test_ts_rank_nan_handling()
test2_pass = test_division_by_zero()
print("\n" + "=" * 60)
print("测试结果汇总")
print("=" * 60)
print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}")
print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}")
if test1_pass and test2_pass:
print("\n所有测试通过!")
else:
print("\n部分测试失败,请检查实现")