"""TabM 集成测试 测试 TabMModel 和 TabMRegressionTask 的完整训练流程。 """ import os import sys from pathlib import Path import numpy as np import polars as pl import pytest import torch # 确保 src 在路径中 sys.path.insert(0, str(Path(__file__).parent.parent)) from src.training.components.models import TabMModel from src.training.tasks import TabMRegressionTask # ========================================== # 测试数据准备 # ========================================== def create_sample_data(n_samples: int = 1000, n_features: int = 20, seed: int = 42): """创建样本数据用于测试 Args: n_samples: 样本数量 n_features: 特征数量 seed: 随机种子 Returns: (train_X, train_y, val_X, val_y, test_X, test_y) """ np.random.seed(seed) torch.manual_seed(seed) # 创建特征矩阵 X_train = pl.DataFrame( np.random.randn(n_samples, n_features).astype(np.float32), schema=[f"feature_{i}" for i in range(n_features)], ) y_train = pl.Series("target", np.random.randn(n_samples).astype(np.float32)) X_val = pl.DataFrame( np.random.randn(n_samples // 2, n_features).astype(np.float32), schema=[f"feature_{i}" for i in range(n_features)], ) y_val = pl.Series("target", np.random.randn(n_samples // 2).astype(np.float32)) X_test = pl.DataFrame( np.random.randn(n_samples // 2, n_features).astype(np.float32), schema=[f"feature_{i}" for i in range(n_features)], ) y_test = pl.Series("target", np.random.randn(n_samples // 2).astype(np.float32)) return X_train, y_train, X_val, y_val, X_test, y_test # ========================================== # TabMModel 测试 # ========================================== class TestTabMModel: """TabMModel 单元测试""" def test_initialization(self): """测试模型初始化""" params = { "n_blocks": 2, "d_block": 128, "ensemble_size": 8, # 小规模集成用于测试 "batch_size": 64, "epochs": 2, } model = TabMModel(params) assert model.name == "tabm" assert model.params == params assert model.device.type in ["cuda", "cpu"] assert model.model is None # 未训练时为 None def test_fit_and_predict(self): """测试训练和预测""" # 创建小规模数据 X_train, y_train, X_val, y_val, X_test, _ = create_sample_data( n_samples=200, n_features=10, seed=42 ) params = { "n_blocks": 1, "d_block": 64, "ensemble_size": 4, "batch_size": 32, "epochs": 2, "early_stopping_patience": 10, } model = TabMModel(params) # 训练 model.fit(X_train, y_train, eval_set=(X_val, y_val)) # 验证模型已训练 assert model.model is not None assert len(model.training_history_["train_loss"]) > 0 # 预测 predictions = model.predict(X_test) # 验证预测结果 assert isinstance(predictions, np.ndarray) assert len(predictions) == len(X_test) assert predictions.shape == (len(X_test),) def test_save_and_load(self, tmp_path): """测试模型保存和加载""" # 创建数据 X_train, y_train, X_val, y_val, _, _ = create_sample_data( n_samples=200, n_features=10, seed=42 ) params = { "n_blocks": 1, "d_block": 64, "ensemble_size": 4, "batch_size": 32, "epochs": 2, } # 训练模型 model = TabMModel(params) model.fit(X_train, y_train, eval_set=(X_val, y_val)) # 保存 save_path = tmp_path / "test_tabm_model" model.save(str(save_path)) # 加载 loaded_model = TabMModel.load(str(save_path)) # 验证加载的模型 assert loaded_model.params == params assert loaded_model.feature_names_ == model.feature_names_ assert loaded_model.model is not None # 预测结果应该一致 pred1 = model.predict(X_val) pred2 = loaded_model.predict(X_val) np.testing.assert_allclose(pred1, pred2, rtol=1e-5) # ========================================== # TabMRegressionTask 测试 # ========================================== class TestTabMRegressionTask: """TabMRegressionTask 单元测试""" def test_initialization(self): """测试任务初始化""" params = { "n_blocks": 2, "d_block": 128, "ensemble_size": 8, "batch_size": 64, "epochs": 2, } task = TabMRegressionTask(model_params=params, label_name="target") assert task.model_params == params assert task.label_name == "target" assert task.model is None def test_prepare_labels(self): """测试标签准备(回归任务不做转换)""" params = { "ensemble_size": 4, "epochs": 2, } task = TabMRegressionTask(model_params=params, label_name="target") # 创建测试数据 data = { "train": { "X": pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), "y": pl.Series("target", [0.1, 0.2, 0.3]), } } result = task.prepare_labels(data) # 回归任务不做转换,数据应该保持不变 assert result == data def test_fit_train_and_predict(self): """测试完整训练和预测流程""" # 创建数据 X_train, y_train, X_val, y_val, X_test, y_test = create_sample_data( n_samples=300, n_features=10, seed=42 ) params = { "n_blocks": 1, "d_block": 64, "ensemble_size": 4, "batch_size": 32, "epochs": 3, } task = TabMRegressionTask(model_params=params, label_name="target") # 准备数据格式 train_data = {"X": X_train, "y": y_train} val_data = {"X": X_val, "y": y_val} # 训练 task.fit(train_data, val_data) # 验证模型已训练 assert task.get_model() is not None # 预测 predictions = task.predict({"X": X_test}) # 验证预测结果 assert len(predictions) == len(X_test) # ========================================== # 集成测试 # ========================================== class TestTabMIntegration: """TabM 集成测试""" def test_full_workflow(self): """测试完整工作流程""" # 创建数据 X_train, y_train, X_val, y_val, X_test, y_test = create_sample_data( n_samples=500, n_features=15, seed=42 ) params = { "n_blocks": 2, "d_block": 128, "ensemble_size": 8, "batch_size": 64, "epochs": 5, } # 1. 创建 Task task = TabMRegressionTask(model_params=params, label_name="target") # 2. 准备数据 train_data = {"X": X_train, "y": y_train} val_data = {"X": X_val, "y": y_val} # 3. 训练 task.fit(train_data, val_data) # 4. 验证训练历史 model = task.get_model() assert len(model.training_history_["train_loss"]) > 0 assert len(model.training_history_["val_ic"]) > 0 # 5. 预测 predictions = task.predict({"X": X_test}) # 6. 验证预测质量 # 简单验证:预测值不应全为常数 assert np.std(predictions) > 1e-6, "预测值全为常数,可能是模型未正常训练" # 验证预测值与真实值存在一定相关性 correlation = np.corrcoef(predictions, y_test.to_numpy())[0, 1] # 注意:随机数据的相关性可能很低,这是正常的 print(f"预测与真实值相关系数: {correlation:.4f}") def test_gpu_availability(self): """测试 GPU 可用性""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") params = { "ensemble_size": 2, "epochs": 1, } model = TabMModel(params) assert model.device == device expected_type = "cuda" if torch.cuda.is_available() else "cpu" assert model.device.type == expected_type # ========================================== # 运行测试 # ========================================== if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"])