147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
|
|
"""
Tests that verify the ts_rank NaN-handling fix and the division-by-zero fix.
"""
|
|||
|
|
|
|||
|
|
import polars as pl
|
|||
|
|
import numpy as np
|
|||
|
|
from src.factors import FactorEngine
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_rank_nan_handling():
    """Check that ts_rank keeps a NaN input as NaN instead of turning it into 0.0.

    Builds a two-stock frame in which the first stock has a NaN ``close``
    value, translates the formula ``ts_rank(close, 3)`` into a Polars
    expression, evaluates it, and inspects the result on the NaN row.

    Returns:
        bool: True when the NaN row yields NaN. False when it yields 0.0,
        any other value (including a null), or when no NaN row is found.
    """
    print("=" * 60)
    print("测试 ts_rank NaN 处理")
    print("=" * 60)

    # Test data containing a NaN.
    df = pl.DataFrame(
        {
            "trade_date": ["20240101", "20240101", "20240101", "20240101", "20240101"]
            * 2,
            "ts_code": ["000001.SZ"] * 5 + ["000002.SZ"] * 5,
            "close": [
                1.0,
                2.0,
                np.nan,
                4.0,
                5.0,  # stock 1: the third value is NaN
                2.0,
                3.0,
                4.0,
                5.0,
                6.0,
            ],  # stock 2: clean data
        }
    )

    print("输入数据:")
    print(df)

    # Exercise the translator directly: parse the formula, then translate the
    # AST into an expression, the same way the division test does.
    from src.factors.translator import PolarsTranslator
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    parser = FormulaParser(FunctionRegistry())
    ast = parser.parse("ts_rank(close, 3)")

    translator = PolarsTranslator()
    expr = translator.translate(ast)

    result = df.with_columns([expr.alias("ts_rank_result")])

    print("\nts_rank 结果:")
    print(result.select(["trade_date", "ts_code", "close", "ts_rank_result"]))

    # Verdict: a NaN input must produce a NaN output, not 0.0.
    stock1_data = result.filter(pl.col("ts_code") == "000001.SZ")
    nan_input_row = stock1_data.filter(pl.col("close").is_nan())

    if nan_input_row.height == 0:
        print("\n[SKIP] 没有找到 NaN 输入行")
        return False

    rank_value = nan_input_row["ts_rank_result"][0]

    # A null comes back as None, and np.isnan(None) raises TypeError, so
    # rule None out first and let it fall through to the "unexpected" branch.
    if rank_value is not None and np.isnan(rank_value):
        print("\n[PASS] NaN 输入正确产生了 NaN 输出")
        return True

    if rank_value == 0.0:
        print("\n[FAIL] NaN 输入被错误转换为 0.0!")
        return False

    print(f"\n[FAIL] 意外的输出值: {rank_value}")
    return False
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_division_by_zero():
    """Check what the translated division expression yields for a zero denominator.

    Evaluates the formula ``numerator / denominator`` on a three-row frame
    whose middle denominator is 0, then inspects that row's result.

    Returns:
        bool: True when the result is null, or is some value that is neither
        NaN nor infinite. False when it is NaN/inf or when no
        zero-denominator row is found.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("测试除法除零处理")
    print(banner)

    # Three stocks sharing one date; only the middle denominator is zero.
    codes = ["000001.SZ", "000002.SZ", "000003.SZ"]
    frame = pl.DataFrame(
        {
            "trade_date": ["20240101"] * len(codes),
            "ts_code": codes,
            "numerator": [100.0] * len(codes),
            "denominator": [10.0, 0.0, 5.0],
        }
    )

    print("输入数据:")
    print(frame)

    # Go through the translator directly: formula text -> AST -> expression.
    from src.factors.translator import PolarsTranslator
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    formula_ast = FormulaParser(FunctionRegistry()).parse("numerator / denominator")
    division_expr = PolarsTranslator().translate(formula_ast)
    named_expr = division_expr.alias("division_result")

    evaluated = frame.with_columns([named_expr])

    print("\n除法结果:")
    print(evaluated.select(["ts_code", "numerator", "denominator", "division_result"]))

    # Verdict: division by zero should give a null rather than NaN/inf.
    zero_rows = evaluated.filter(pl.col("denominator") == 0.0)
    if zero_rows.height <= 0:
        print("\n[SKIP] 没有找到除零行")
        return False

    quotient = zero_rows["division_result"][0]
    if quotient is None:
        print("\n[PASS] 除零正确产生了 Null")
        return True

    if np.isnan(quotient) or np.isinf(quotient):
        print(f"\n[FAIL] 除零产生了 NaN/inf: {quotient}")
        return False

    # Any other finite value is accepted as possibly intended behavior.
    print(f"\n[?] 除零产生了: {quotient} (可能是其他有效值)")
    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Run both checks and keep their boolean verdicts for the summary below.
    test1_pass = test_ts_rank_nan_handling()
    test2_pass = test_division_by_zero()

    print("\n" + "=" * 60)
    print("测试结果汇总")
    print("=" * 60)
    print(f"ts_rank NaN 处理: {'PASS' if test1_pass else 'FAIL'}")
    print(f"除法除零处理: {'PASS' if test2_pass else 'FAIL'}")

    if test1_pass and test2_pass:
        print("\n所有测试通过!")
    else:
        print("\n部分测试失败,请检查实现")
        # Exit non-zero so a shell or CI job can detect the failure; before,
        # the script reported failure but still exited with status 0.
        raise SystemExit(1)
|