"""测试 TabMRankModel 基础功能""" import pytest import numpy as np import polars as pl import torch from src.training.components.models import TabMRankModel class TestTabMRankModel: """TabMRankModel 测试类""" def test_model_initialization(self): """测试模型初始化""" model = TabMRankModel(params={"ensemble_size": 16}) assert model.name == "tabm_rank" assert model.params["ensemble_size"] == 16 assert model.device in [torch.device("cpu"), torch.device("cuda")] def test_prepare_group_from_dates(self): """测试从日期生成 group 数组""" df = pl.DataFrame( { "trade_date": [ "20240101", "20240101", "20240102", "20240102", "20240102", ], "value": [1, 2, 3, 4, 5], } ) group = TabMRankModel.prepare_group_from_dates(df) assert np.array_equal(group, np.array([2, 3])) def test_convert_to_quantile_labels(self): """测试转换分位数标签""" df = pl.DataFrame( { "trade_date": ["20240101"] * 5 + ["20240102"] * 5, "return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2], } ) result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5) assert "return_rank" in result.columns assert result["return_rank"].dtype == pl.Int64 def test_save_load(self, tmp_path): """测试模型保存和加载(跳过实际训练)""" params = { "ensemble_size": 8, "n_blocks": 2, "d_block": 64, "epochs": 1, } model = TabMRankModel(params=params) model.feature_names_ = ["feat1", "feat2"] # 模拟训练历史 model.training_history_["train_loss"] = [0.5, 0.4] model.training_history_["val_ndcg"] = [0.3, 0.35] # 测试保存前需要初始化模型 # 这里只测试元数据保存 save_path = tmp_path / "test_model" import pickle with open(save_path.with_suffix(".meta"), "wb") as f: pickle.dump( { "params": model.params, "feature_names": model.feature_names_, "training_history": model.training_history_, "device": str(model.device), }, f, ) # 测试加载元数据 loaded_model = TabMRankModel.load(save_path) assert loaded_model.params == params assert loaded_model.feature_names_ == ["feat1", "feat2"] if __name__ == "__main__": pytest.main([__file__, "-v"])