feat(training): 新增 TabM 排序学习模型支持并优化训练流程

- 新增 TabMRankModel、TabMRankTask 及配套损失函数与配置
- 将 DataQualityAnalyzer 从 experiment 迁移至 training 模块
- 调整数据处理器移除过度的 NaN/null 硬填充逻辑
- 优化 RankTask 评估指标使用分位数标签替代原始收益率
- 更新实验脚本处理器顺序与模型超参数配置
This commit is contained in:
2026-04-04 22:39:58 +08:00
parent 9e7d4241c6
commit a66d5e9db3
16 changed files with 1663 additions and 344 deletions

View File

@@ -0,0 +1,88 @@
"""测试 TabMRankModel 基础功能"""
import pytest
import numpy as np
import polars as pl
import torch
from src.training.components.models import TabMRankModel
class TestTabMRankModel:
"""TabMRankModel 测试类"""
def test_model_initialization(self):
"""测试模型初始化"""
model = TabMRankModel(params={"ensemble_size": 16})
assert model.name == "tabm_rank"
assert model.params["ensemble_size"] == 16
assert model.device in [torch.device("cpu"), torch.device("cuda")]
def test_prepare_group_from_dates(self):
"""测试从日期生成 group 数组"""
df = pl.DataFrame(
{
"trade_date": [
"20240101",
"20240101",
"20240102",
"20240102",
"20240102",
],
"value": [1, 2, 3, 4, 5],
}
)
group = TabMRankModel.prepare_group_from_dates(df)
assert np.array_equal(group, np.array([2, 3]))
def test_convert_to_quantile_labels(self):
"""测试转换分位数标签"""
df = pl.DataFrame(
{
"trade_date": ["20240101"] * 5 + ["20240102"] * 5,
"return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2],
}
)
result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5)
assert "return_rank" in result.columns
assert result["return_rank"].dtype == pl.Int64
def test_save_load(self, tmp_path):
"""测试模型保存和加载(跳过实际训练)"""
params = {
"ensemble_size": 8,
"n_blocks": 2,
"d_block": 64,
"epochs": 1,
}
model = TabMRankModel(params=params)
model.feature_names_ = ["feat1", "feat2"]
# 模拟训练历史
model.training_history_["train_loss"] = [0.5, 0.4]
model.training_history_["val_ndcg"] = [0.3, 0.35]
# 测试保存前需要初始化模型
# 这里只测试元数据保存
save_path = tmp_path / "test_model"
import pickle
with open(save_path.with_suffix(".meta"), "wb") as f:
pickle.dump(
{
"params": model.params,
"feature_names": model.feature_names_,
"training_history": model.training_history_,
"device": str(model.device),
},
f,
)
# 测试加载元数据
loaded_model = TabMRankModel.load(save_path)
assert loaded_model.params == params
assert loaded_model.feature_names_ == ["feat1", "feat2"]
if __name__ == "__main__":
pytest.main([__file__, "-v"])