- 新增 TabMRankModel、TabMRankTask 及配套损失函数与配置 - 将 DataQualityAnalyzer 从 experiment 迁移至 training 模块 - 调整数据处理器移除过度的 NaN/null 硬填充逻辑 - 优化 RankTask 评估指标使用分位数标签替代原始收益率 - 更新实验脚本处理器顺序与模型超参数配置
89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
"""测试 TabMRankModel 基础功能"""
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import polars as pl
|
|
import torch
|
|
|
|
from src.training.components.models import TabMRankModel
|
|
|
|
|
|
class TestTabMRankModel:
    """Basic functional tests for TabMRankModel."""

    def test_model_initialization(self):
        """Check the name, params and device of a freshly constructed model."""
        model = TabMRankModel(params={"ensemble_size": 16})

        assert model.name == "tabm_rank"
        assert model.params["ensemble_size"] == 16
        # Compare on the device *type* rather than by equality:
        # torch.device("cuda:0") != torch.device("cuda"), so an equality check
        # against un-indexed devices would reject a model placed on a
        # specific GPU.
        assert isinstance(model.device, torch.device)
        assert model.device.type in ("cpu", "cuda")

    def test_prepare_group_from_dates(self):
        """Check the group array built from the trade_date column."""
        df = pl.DataFrame(
            {
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                    "20240102",
                ],
                "value": [1, 2, 3, 4, 5],
            }
        )

        group = TabMRankModel.prepare_group_from_dates(df)

        # The frame holds two rows for 20240101 and three rows for 20240102.
        assert np.array_equal(group, np.array([2, 3]))

    def test_convert_to_quantile_labels(self):
        """Check that quantile conversion adds an Int64 ``return_rank`` column."""
        df = pl.DataFrame(
            {
                "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
                "return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2],
            }
        )

        result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5)

        assert "return_rank" in result.columns
        assert result["return_rank"].dtype == pl.Int64

    def test_save_load(self, tmp_path):
        """Test model save and load (actual training is skipped)."""
        import pickle

        params = {
            "ensemble_size": 8,
            "n_blocks": 2,
            "d_block": 64,
            "epochs": 1,
        }
        model = TabMRankModel(params=params)
        model.feature_names_ = ["feat1", "feat2"]

        # Simulate a training history.
        model.training_history_["train_loss"] = [0.5, 0.4]
        model.training_history_["val_ndcg"] = [0.3, 0.35]

        # Testing a real save would require the model to be initialised first,
        # so only the metadata file is written here, by hand.
        save_path = tmp_path / "test_model"
        metadata = {
            "params": model.params,
            "feature_names": model.feature_names_,
            "training_history": model.training_history_,
            "device": str(model.device),
        }
        with open(save_path.with_suffix(".meta"), "wb") as f:
            pickle.dump(metadata, f)

        # Test loading the metadata back.
        loaded_model = TabMRankModel.load(save_path)
        assert loaded_model.params == params
        assert loaded_model.feature_names_ == ["feat1", "feat2"]
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file as a script invokes pytest on it in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
|