feat(training): 新增 TabM 排序学习模型支持并优化训练流程
- 新增 TabMRankModel、TabMRankTask 及配套损失函数与配置
- 将 DataQualityAnalyzer 从 experiment 迁移至 training 模块
- 调整数据处理器，移除过度的 NaN/null 硬填充逻辑
- 优化 RankTask 评估指标，使用分位数标签替代原始收益率
- 更新实验脚本的处理器顺序与模型超参数配置
This commit is contained in:
88
tests/test_tabm_rank_model.py
Normal file
88
tests/test_tabm_rank_model.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""测试 TabMRankModel 基础功能"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import torch
|
||||
|
||||
from src.training.components.models import TabMRankModel
|
||||
|
||||
|
||||
class TestTabMRankModel:
    """Basic smoke tests for ``TabMRankModel``."""

    def test_model_initialization(self):
        """Check the name, params and device exposed after construction."""
        model = TabMRankModel(params={"ensemble_size": 16})

        assert model.name == "tabm_rank"
        assert model.params["ensemble_size"] == 16
        # Compare on the device *type* rather than by equality:
        # torch.device("cuda:0") != torch.device("cuda"), so an equality
        # check against index-less devices fails spuriously whenever the
        # model pins an explicit GPU index.
        assert isinstance(model.device, torch.device)
        assert model.device.type in ("cpu", "cuda")

    def test_prepare_group_from_dates(self):
        """Check that per-date row counts are returned as the group array."""
        df = pl.DataFrame(
            {
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                    "20240102",
                ],
                "value": [1, 2, 3, 4, 5],
            }
        )

        group = TabMRankModel.prepare_group_from_dates(df)

        # Two rows on 20240101 followed by three rows on 20240102.
        assert np.array_equal(group, np.array([2, 3]))

    def test_convert_to_quantile_labels(self):
        """Check that an Int64 ``return_rank`` column is added to the frame."""
        df = pl.DataFrame(
            {
                "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
                "return": [0.1, 0.05, 0.0, -0.05, -0.1]
                + [0.2, 0.1, 0.0, -0.1, -0.2],
            }
        )

        result = TabMRankModel.convert_to_quantile_labels(
            df, "return", n_quantiles=5
        )

        assert "return_rank" in result.columns
        assert result["return_rank"].dtype == pl.Int64

    def test_save_load(self, tmp_path):
        """Check that ``load`` restores params and feature names (no training).

        Only the metadata side of persistence is exercised: the ``.meta``
        file is written by hand here instead of calling the model's own save
        routine, because (per the original author's note) saving requires an
        initialised model.
        """
        params = {
            "ensemble_size": 8,
            "n_blocks": 2,
            "d_block": 64,
            "epochs": 1,
        }
        model = TabMRankModel(params=params)
        model.feature_names_ = ["feat1", "feat2"]

        # Fake a short training history so the metadata payload is populated.
        model.training_history_["train_loss"] = [0.5, 0.4]
        model.training_history_["val_ndcg"] = [0.3, 0.35]

        save_path = tmp_path / "test_model"
        import pickle

        # The payload is produced by this test inside pytest's tmp_path, so
        # unpickling it later does not involve untrusted data.
        # NOTE(review): the key names presumably mirror what
        # TabMRankModel.load expects - confirm against the implementation.
        with open(save_path.with_suffix(".meta"), "wb") as f:
            pickle.dump(
                {
                    "params": model.params,
                    "feature_names": model.feature_names_,
                    "training_history": model.training_history_,
                    "device": str(model.device),
                },
                f,
            )

        # Load from the suffix-less base path and verify the metadata.
        loaded_model = TabMRankModel.load(save_path)
        assert loaded_model.params == params
        assert loaded_model.feature_names_ == ["feat1", "feat2"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's return code so that running this file directly
    # exits non-zero when a test fails, instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user