Files
ProStock/tests/training/test_lightgbm_model.py
liaozhaorun f35a6a76a6 feat(training): 实现 LightGBM 模型
- 新增 LightGBMModel:LightGBM 回归模型实现
- 支持自定义参数(objective, num_leaves, learning_rate, n_estimators 等)
- 使用 LightGBM 原生格式保存/加载模型(不依赖 pickle)
- 支持特征重要性提取
- 已注册到 ModelRegistry(@register_model("lightgbm"))
2026-03-03 22:30:37 +08:00

227 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 LightGBM 模型
验证 LightGBMModel 的训练、预测、保存和加载功能。
"""
import os
import tempfile
import numpy as np
import polars as pl
import pytest
from src.training.components.models.lightgbm import LightGBMModel
class TestLightGBMModel:
    """Tests for LightGBMModel: construction, fit, predict, persistence, registry.

    Tests that need random data draw it from a locally seeded
    ``np.random.default_rng``. This makes results reproducible and independent
    of test execution order; the global NumPy RNG is never touched.
    """

    def test_init_default(self):
        """Default construction exposes the expected params and no trained model."""
        model = LightGBMModel()
        assert model.name == "lightgbm"
        assert model.params["objective"] == "regression"
        assert model.params["metric"] == "rmse"
        assert model.params["num_leaves"] == 31
        assert model.params["learning_rate"] == 0.05
        assert model.n_estimators == 100
        assert model.model is None

    def test_init_custom(self):
        """Keyword arguments override the corresponding defaults."""
        model = LightGBMModel(
            objective="huber",
            metric="mae",
            num_leaves=50,
            learning_rate=0.1,
            n_estimators=200,
        )
        assert model.params["objective"] == "huber"
        assert model.params["metric"] == "mae"
        assert model.params["num_leaves"] == 50
        assert model.params["learning_rate"] == 0.1
        assert model.n_estimators == 200

    def test_fit_success(self):
        """fit() trains the model, records feature names and returns self."""
        # Tiny, perfectly linear regression dataset.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])

        model = LightGBMModel(n_estimators=10)
        result = model.fit(X, y)

        # fit() returns self so calls can be chained.
        assert result is model
        # The underlying model object now exists.
        assert model.model is not None
        # Feature names are recorded in training-frame column order.
        assert model.feature_names_ == ["feature1", "feature2"]

    def test_predict_before_fit(self):
        """predict() on an untrained model raises RuntimeError."""
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0],
                "feature2": [2.0, 4.0],
            }
        )
        model = LightGBMModel()
        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.predict(X)

    def test_predict_success(self):
        """predict() returns finite values that follow the learned trend."""
        rng = np.random.default_rng(42)
        n_samples = 100
        f1 = rng.standard_normal(n_samples)
        f2 = rng.standard_normal(n_samples)
        X_train = pl.DataFrame({"feature1": f1, "feature2": f2})
        # y = 2*feature1 + 3*feature2 + small noise. The target is computed in
        # NumPy, so no mixed polars/NumPy arithmetic is needed.
        noise = rng.standard_normal(n_samples) * 0.1
        y_train = pl.Series("target", 2 * f1 + 3 * f2 + noise)

        model = LightGBMModel(n_estimators=20, learning_rate=0.1)
        model.fit(X_train, y_train)

        # Two test points at clearly different ends of the feature range.
        X_test = pl.DataFrame(
            {
                "feature1": [-2.0, 3.0],
                "feature2": [-1.0, 4.0],
            }
        )
        predictions = model.predict(X_test)

        # Output format: one finite prediction per row, as a NumPy array.
        assert isinstance(predictions, np.ndarray)
        assert len(predictions) == 2
        assert all(np.isfinite(predictions))
        # Both features are larger in the second row and both coefficients are
        # positive, so its prediction must be larger too.
        assert predictions[1] > predictions[0]

    def test_feature_importance_before_fit(self):
        """feature_importance() returns None before training."""
        model = LightGBMModel()
        assert model.feature_importance() is None

    def test_feature_importance_after_fit(self):
        """After training, importances cover every feature and rank the dominant one first."""
        rng = np.random.default_rng(0)
        f1 = rng.standard_normal(100)
        f2 = rng.standard_normal(100)
        X = pl.DataFrame({"feature1": f1, "feature2": f2})
        y = pl.Series("target", f1 * 2 + f2 * 0.1)

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)
        importance = model.feature_importance()

        # One importance entry per feature, indexed by feature name.
        assert importance is not None
        assert len(importance) == 2
        assert "feature1" in importance.index
        assert "feature2" in importance.index
        # feature1 has the much larger coefficient, so it should matter at
        # least as much as feature2.
        assert importance["feature1"] >= importance["feature2"]

    def test_save_before_fit(self):
        """save() on an untrained model raises RuntimeError."""
        model = LightGBMModel()
        # Point at a throwaway directory so the test never writes into the
        # current working directory, even if save() touched the file first.
        with tempfile.TemporaryDirectory() as tmpdir:
            with pytest.raises(RuntimeError, match="模型尚未训练"):
                model.save(os.path.join(tmpdir, "dummy.txt"))

    def test_save_and_load(self):
        """A saved-then-loaded model reproduces predictions and metadata."""
        # Use enough rows for the trees to contain real splits. With LightGBM's
        # default min_data_in_leaf (20), a 5-row dataset cannot be split at all.
        # Every tree would then be a single constant leaf, and the round trip
        # would check only that constant. (Assumes the wrapper keeps that
        # LightGBM default.)
        rng = np.random.default_rng(7)
        n_samples = 100
        f1 = rng.standard_normal(n_samples)
        f2 = rng.standard_normal(n_samples)
        X = pl.DataFrame({"feature1": f1, "feature2": f2})
        y = pl.Series("target", 2 * f1 + 3 * f2)

        model = LightGBMModel(n_estimators=10, learning_rate=0.1)
        model.fit(X, y)

        # Predict on several rows before saving.
        X_test = pl.DataFrame(
            {
                "feature1": [-1.0, 0.0, 1.5],
                "feature2": [0.5, 0.0, -0.5],
            }
        )
        pred_before = model.predict(X_test)

        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "model.txt")
            model.save(save_path)
            loaded_model = LightGBMModel.load(save_path)

            # The loaded model must reproduce every prediction.
            pred_after = loaded_model.predict(X_test)
            np.testing.assert_allclose(
                pred_after, pred_before, rtol=1e-5, atol=1e-12
            )
            # Metadata is restored alongside the model.
            assert loaded_model.feature_names_ == ["feature1", "feature2"]
            assert loaded_model.n_estimators == 10

    def test_registration(self):
        """The "lightgbm" key in ModelRegistry resolves to LightGBMModel."""
        from src.training.registry import ModelRegistry

        model_class = ModelRegistry.get_model("lightgbm")
        assert model_class is LightGBMModel

    def test_fit_predict_consistency(self):
        """Repeated predict() calls on the same input return the same result."""
        rng = np.random.default_rng(1)
        f1 = rng.standard_normal(50)
        f2 = rng.standard_normal(50)
        X = pl.DataFrame({"feature1": f1, "feature2": f2})
        y = pl.Series("target", f1 + f2)

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)

        X_test = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0],
                "feature2": [1.0, 2.0, 3.0],
            }
        )
        pred1 = model.predict(X_test)
        pred2 = model.predict(X_test)
        np.testing.assert_array_almost_equal(pred1, pred2)
if __name__ == "__main__":
    # Propagate pytest's exit status. Running this file directly then exits
    # non-zero when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))