Files
ProStock/tests/test_tabm_integration.py

311 lines
8.5 KiB
Python
Raw Normal View History

"""TabM 集成测试
测试 TabMModel 与 TabMRegressionTask 的完整训练流程
"""
import os
import sys
from pathlib import Path
import numpy as np
import polars as pl
import pytest
import torch
# 确保 src 在路径中
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.training.components.models import TabMModel
from src.training.tasks import TabMRegressionTask
# ==========================================
# 测试数据准备
# ==========================================
def create_sample_data(n_samples: int = 1000, n_features: int = 20, seed: int = 42):
    """Build random train/validation/test splits for the TabM tests.

    Both the global NumPy RNG (``np.random.seed``) and the global torch RNG
    (``torch.manual_seed``) are seeded with ``seed``, so repeated calls with
    the same arguments produce the same data.

    Args:
        n_samples: Number of training rows; the validation and test splits
            each get ``n_samples // 2`` rows.
        n_features: Number of feature columns, named ``feature_0`` onward.
        seed: Seed applied to ``np.random`` and ``torch``.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``. Each ``X_*`` is a
        float32 ``pl.DataFrame``, and each ``y_*`` is a float32 ``pl.Series``
        named ``"target"``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    columns = [f"feature_{idx}" for idx in range(n_features)]
    half = n_samples // 2

    def _split(rows: int):
        # Features are drawn before targets. Each split is drawn in turn
        # (train, val, test), so the global RNG stream is consumed in a fixed order.
        features = pl.DataFrame(
            np.random.randn(rows, n_features).astype(np.float32),
            schema=columns,
        )
        target = pl.Series("target", np.random.randn(rows).astype(np.float32))
        return features, target

    X_train, y_train = _split(n_samples)
    X_val, y_val = _split(half)
    X_test, y_test = _split(half)
    return X_train, y_train, X_val, y_val, X_test, y_test
# ==========================================
# TabMModel 测试
# ==========================================
class TestTabMModel:
    """Unit tests for TabMModel."""

    def test_initialization(self):
        """The constructor stores params, picks a device and leaves the network unbuilt."""
        params = {
            "n_blocks": 2,
            "d_block": 128,
            "ensemble_size": 8,  # small ensemble to keep the test cheap
            "batch_size": 64,
            "epochs": 2,
        }
        model = TabMModel(params)
        assert model.name == "tabm"
        assert model.params == params
        assert model.device.type in ["cuda", "cpu"]
        assert model.model is None  # not built until fit() is called

    def test_fit_and_predict(self):
        """fit() records a history, and predict() returns one finite value per row."""
        # Small dataset so the test runs quickly.
        X_train, y_train, X_val, y_val, X_test, _ = create_sample_data(
            n_samples=200, n_features=10, seed=42
        )
        params = {
            "n_blocks": 1,
            "d_block": 64,
            "ensemble_size": 4,
            "batch_size": 32,
            "epochs": 2,
            "early_stopping_patience": 10,
        }
        model = TabMModel(params)

        model.fit(X_train, y_train, eval_set=(X_val, y_val))

        assert model.model is not None
        assert len(model.training_history_["train_loss"]) > 0

        predictions = model.predict(X_test)

        assert isinstance(predictions, np.ndarray)
        # The shape check also pins the length to one prediction per row.
        assert predictions.shape == (len(X_test),)
        # Catch diverged training that would otherwise pass with NaN/inf outputs.
        assert np.all(np.isfinite(predictions))

    def test_save_and_load(self, tmp_path):
        """A saved-then-loaded model keeps its params and features and predicts the same values."""
        X_train, y_train, X_val, y_val, _, _ = create_sample_data(
            n_samples=200, n_features=10, seed=42
        )
        params = {
            "n_blocks": 1,
            "d_block": 64,
            "ensemble_size": 4,
            "batch_size": 32,
            "epochs": 2,
        }

        model = TabMModel(params)
        model.fit(X_train, y_train, eval_set=(X_val, y_val))

        save_path = tmp_path / "test_tabm_model"
        model.save(str(save_path))

        loaded_model = TabMModel.load(str(save_path))

        assert loaded_model.params == params
        assert loaded_model.feature_names_ == model.feature_names_
        assert loaded_model.model is not None

        pred1 = model.predict(X_val)
        pred2 = loaded_model.predict(X_val)
        # rtol alone is too strict for predictions close to zero. A small atol
        # absorbs float noise, such as from nondeterministic GPU kernels.
        np.testing.assert_allclose(pred1, pred2, rtol=1e-5, atol=1e-6)
# ==========================================
# TabMRegressionTask 测试
# ==========================================
class TestTabMRegressionTask:
    """Unit tests for TabMRegressionTask."""

    def test_initialization(self):
        """The constructor stores params and label name, and no model exists yet."""
        params = {
            "n_blocks": 2,
            "d_block": 128,
            "ensemble_size": 8,
            "batch_size": 64,
            "epochs": 2,
        }
        task = TabMRegressionTask(model_params=params, label_name="target")
        assert task.model_params == params
        assert task.label_name == "target"
        assert task.model is None

    def test_prepare_labels(self):
        """prepare_labels() leaves regression features and labels unchanged."""
        from polars.testing import assert_frame_equal, assert_series_equal

        params = {
            "ensemble_size": 4,
            "epochs": 2,
        }
        task = TabMRegressionTask(model_params=params, label_name="target")

        X = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        y = pl.Series("target", [0.1, 0.2, 0.3])
        data = {"train": {"X": X, "y": y}}
        # Compare against independent copies. ``result == data`` is vacuous
        # if prepare_labels mutates ``data`` in place and returns it (the dict
        # identity shortcut applies). Polars ``==`` is element-wise, so it
        # cannot be used as a boolean for new objects either.
        expected_X = X.clone()
        expected_y = y.clone()

        result = task.prepare_labels(data)

        assert set(result.keys()) == {"train"}
        assert set(result["train"].keys()) == {"X", "y"}
        assert_frame_equal(result["train"]["X"], expected_X)
        assert_series_equal(result["train"]["y"], expected_y)

    def test_fit_train_and_predict(self):
        """Full fit/predict cycle via the task returns one finite prediction per row."""
        X_train, y_train, X_val, y_val, X_test, _ = create_sample_data(
            n_samples=300, n_features=10, seed=42
        )
        params = {
            "n_blocks": 1,
            "d_block": 64,
            "ensemble_size": 4,
            "batch_size": 32,
            "epochs": 3,
        }
        task = TabMRegressionTask(model_params=params, label_name="target")

        train_data = {"X": X_train, "y": y_train}
        val_data = {"X": X_val, "y": y_val}

        task.fit(train_data, val_data)

        assert task.get_model() is not None

        predictions = task.predict({"X": X_test})

        assert len(predictions) == len(X_test)
        # Catch diverged training that would otherwise pass with NaN/inf outputs.
        assert np.all(np.isfinite(np.asarray(predictions)))
# ==========================================
# 集成测试
# ==========================================
class TestTabMIntegration:
    """End-to-end integration tests for TabM."""

    def test_full_workflow(self):
        """Create task, train, inspect history, predict and sanity-check outputs."""
        X_train, y_train, X_val, y_val, X_test, y_test = create_sample_data(
            n_samples=500, n_features=15, seed=42
        )
        params = {
            "n_blocks": 2,
            "d_block": 128,
            "ensemble_size": 8,
            "batch_size": 64,
            "epochs": 5,
        }

        # 1. Create the task.
        task = TabMRegressionTask(model_params=params, label_name="target")

        # 2. Prepare the data dicts.
        train_data = {"X": X_train, "y": y_train}
        val_data = {"X": X_val, "y": y_val}

        # 3. Train.
        task.fit(train_data, val_data)

        # 4. Check that both training-history series were recorded.
        model = task.get_model()
        assert len(model.training_history_["train_loss"]) > 0
        assert len(model.training_history_["val_ic"]) > 0

        # 5. Predict.
        predictions = task.predict({"X": X_test})

        # 6. Sanity-check prediction quality.
        assert len(predictions) == len(X_test)
        assert np.all(np.isfinite(predictions))
        # Constant predictions usually mean the model did not train.
        assert np.std(predictions) > 1e-6, "预测值全为常数,可能是模型未正常训练"

        # Targets are pure noise, so the correlation may be near zero. Only
        # require it to be a valid number, then report it.
        correlation = np.corrcoef(predictions, y_test.to_numpy())[0, 1]
        assert np.isfinite(correlation)
        print(f"预测与真实值相关系数: {correlation:.4f}")

    def test_gpu_availability(self):
        """TabMModel picks CUDA when available and falls back to CPU otherwise."""
        expected_type = "cuda" if torch.cuda.is_available() else "cpu"
        params = {
            "ensemble_size": 2,
            "epochs": 1,
        }
        model = TabMModel(params)
        # Compare only the device type. torch.device("cuda") != torch.device("cuda:0"),
        # so strict device equality fails whenever the model pins an explicit index.
        assert model.device.type == expected_type
# ==========================================
# 运行测试
# ==========================================
if __name__ == "__main__":
    # Propagate pytest's exit status so script runs report failures to the shell/CI.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))