- 分析GTJA_alpha032等因子在不同LOOKBACK_DAYS下的差异来源 - 验证cs_rank嵌套和截面股票数量对结果的影响 - 测试ts_rank NaN处理和除法除零修复
147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
"""
|
||
验证 ts_rank 和除法 Bug 修复的测试
|
||
"""
|
||
|
||
import polars as pl
|
||
import numpy as np
|
||
from src.factors import FactorEngine
|
||
|
||
|
||
def test_ts_rank_nan_handling():
|
||
"""测试 ts_rank 正确处理 NaN 值,不应该将 NaN 转换为 0.0"""
|
||
print("=" * 60)
|
||
print("测试 ts_rank NaN 处理")
|
||
print("=" * 60)
|
||
|
||
# 创建包含 NaN 的测试数据
|
||
df = pl.DataFrame(
|
||
{
|
||
"trade_date": ["20240101", "20240101", "20240101", "20240101", "20240101"]
|
||
* 2,
|
||
"ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5,
|
||
"close": [
|
||
1.0,
|
||
2.0,
|
||
np.nan,
|
||
4.0,
|
||
5.0, # 股票1:第3个是NaN
|
||
2.0,
|
||
3.0,
|
||
4.0,
|
||
5.0,
|
||
6.0,
|
||
], # 股票2:正常数据
|
||
}
|
||
)
|
||
|
||
print("输入数据:")
|
||
print(df)
|
||
|
||
# 使用引擎计算 ts_rank
|
||
engine = FactorEngine()
|
||
# 手动设置数据(这里用简单方式)
|
||
|
||
# 直接测试 translator
|
||
from src.factors.translator import PolarsTranslator
|
||
from src.factors.parser import FormulaParser
|
||
from src.factors.registry import FunctionRegistry
|
||
|
||
parser = FormulaParser(FunctionRegistry())
|
||
ast = parser.parse("ts_rank(close, 3)")
|
||
|
||
translator = PolarsTranslator()
|
||
expr = translator.translate(ast)
|
||
|
||
result = df.with_columns([expr.alias("ts_rank_result")])
|
||
|
||
print("\nts_rank 结果:")
|
||
print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"]))
|
||
|
||
# 验证:NaN 输入应该产生 NaN 输出,而不是 0.0
|
||
stock1_data = result.filter(pl.col("ts_code") == "000001.SZ")
|
||
nan_input_row = stock1_data.filter(pl.col("close").is_nan())
|
||
|
||
if nan_input_row.height > 0:
|
||
rank_value = nan_input_row["ts_rank_result"][0]
|
||
if np.isnan(rank_value):
|
||
print("\n[PASS] NaN 输入正确产生了 NaN 输出")
|
||
return True
|
||
elif rank_value == 0.0:
|
||
print(f"\n[FAIL] NaN 输入被错误转换为 0.0!")
|
||
return False
|
||
else:
|
||
print(f"\n[FAIL] 意外的输出值: {rank_value}")
|
||
return False
|
||
else:
|
||
print("\n[SKIP] 没有找到 NaN 输入行")
|
||
return False
|
||
|
||
|
||
def test_division_by_zero():
|
||
"""测试除法处理除零情况,不应该产生 NaN/inf"""
|
||
print("\n" + "=" * 60)
|
||
print("测试除法除零处理")
|
||
print("=" * 60)
|
||
|
||
# 创建包含零的数据
|
||
df = pl.DataFrame(
|
||
{
|
||
"trade_date": ["20240101", "20240101", "20240101"],
|
||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||
"numerator": [100.0, 100.0, 100.0],
|
||
"denominator": [10.0, 0.0, 5.0], # 中间那个是0
|
||
}
|
||
)
|
||
|
||
print("输入数据:")
|
||
print(df)
|
||
|
||
# 直接测试 translator
|
||
from src.factors.translator import PolarsTranslator
|
||
from src.factors.parser import FormulaParser
|
||
from src.factors.registry import FunctionRegistry
|
||
|
||
parser = FormulaParser(FunctionRegistry())
|
||
ast = parser.parse("numerator / denominator")
|
||
|
||
translator = PolarsTranslator()
|
||
expr = translator.translate(ast)
|
||
|
||
result = df.with_columns([expr.alias("division_result")])
|
||
|
||
print("\n除法结果:")
|
||
print(result.select(["ts_code", "numerator", "denominator", "division_result"]))
|
||
|
||
# 验证:除零应该产生 Null,而不是 NaN/inf
|
||
zero_row = result.filter(pl.col("denominator") == 0.0)
|
||
if zero_row.height > 0:
|
||
div_value = zero_row["division_result"][0]
|
||
if div_value is None:
|
||
print("\n[PASS] 除零正确产生了 Null")
|
||
return True
|
||
elif np.isnan(div_value) or np.isinf(div_value):
|
||
print(f"\n[FAIL] 除零产生了 NaN/inf: {div_value}")
|
||
return False
|
||
else:
|
||
print(f"\n[?] 除零产生了: {div_value} (可能是其他有效值)")
|
||
return True # 也可能是预期的行为
|
||
else:
|
||
print("\n[SKIP] 没有找到除零行")
|
||
return False
|
||
|
||
|
||
if __name__ == "__main__":
|
||
test1_pass = test_ts_rank_nan_handling()
|
||
test2_pass = test_division_by_zero()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试结果汇总")
|
||
print("=" * 60)
|
||
print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}")
|
||
print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}")
|
||
|
||
if test1_pass and test2_pass:
|
||
print("\n所有测试通过!")
|
||
else:
|
||
print("\n部分测试失败,请检查实现")
|