Files
ProStock/tests/training/test_lightgbm_model.py
liaozhaorun 192718095f feat(training): 实现训练模块核心组件(commits 6-9)
- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择
- Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化
- TrainingConfig:pydantic 配置管理,含必填字段和日期验证
- experiment 模块:预留结构
- 从计划中移除 metrics 组件
- 调整 commit 序号(7-10 → 6-9)
- 更新 training/__init__.py 导出所有公开 API
2026-03-03 22:57:01 +08:00

236 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 LightGBM 模型
验证 LightGBMModel 的训练、预测、保存和加载功能。
"""
import os
import tempfile
import numpy as np
import polars as pl
import pytest
from src.training.components.models.lightgbm import LightGBMModel
class TestLightGBMModel:
    """Tests for LightGBMModel: construction, fit, predict, save and load."""

    def test_init_default(self):
        """A model built with no arguments exposes the expected defaults."""
        model = LightGBMModel()
        assert model.name == "lightgbm"
        assert model.params["objective"] == "regression"
        assert model.params["metric"] == "rmse"
        assert model.params["num_leaves"] == 31
        assert model.params["learning_rate"] == 0.05
        assert model.n_estimators == 100
        # Nothing has been trained yet.
        assert model.model is None

    def test_init_custom(self):
        """Constructor keyword arguments override the defaults."""
        model = LightGBMModel(
            objective="huber",
            metric="mae",
            num_leaves=50,
            learning_rate=0.1,
            n_estimators=200,
        )
        assert model.params["objective"] == "huber"
        assert model.params["metric"] == "mae"
        assert model.params["num_leaves"] == 50
        assert model.params["learning_rate"] == 0.1
        assert model.n_estimators == 200

    def test_fit_success(self):
        """Fitting on a small regression data set trains the model."""
        # Tiny hand-written regression data set.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])

        model = LightGBMModel(n_estimators=10)
        result = model.fit(X, y)

        # fit() returns the model itself so calls can be chained.
        assert result is model
        # A trained underlying model is now attached.
        assert model.model is not None
        # The training column names were recorded.
        assert model.feature_names_ == ["feature1", "feature2"]

    def test_predict_before_fit(self):
        """Predicting with an untrained model raises RuntimeError."""
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0],
                "feature2": [2.0, 4.0],
            }
        )
        model = LightGBMModel()
        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.predict(X)

    def test_predict_success(self):
        """A fitted model returns finite predictions ordered as expected."""
        # Seeded synthetic regression data.
        np.random.seed(42)
        n_samples = 100
        X_train = pl.DataFrame(
            {
                "feature1": np.random.randn(n_samples),
                "feature2": np.random.randn(n_samples),
            }
        )
        # y = 2*feature1 + 3*feature2 + noise
        y_train = pl.Series(
            "target",
            2 * X_train["feature1"]
            + 3 * X_train["feature2"]
            + np.random.randn(n_samples) * 0.1,
        )

        model = LightGBMModel(n_estimators=20, learning_rate=0.1)
        model.fit(X_train, y_train)

        # Two new samples with clearly different feature values.
        X_test = pl.DataFrame(
            {
                "feature1": [-2.0, 3.0],
                "feature2": [-1.0, 4.0],
            }
        )
        predictions = model.predict(X_test)

        # Result is a NumPy array with one entry per input row.
        assert isinstance(predictions, np.ndarray)
        assert len(predictions) == 2
        # Every prediction is a finite number.
        assert np.isfinite(predictions).all()
        # The second sample has larger feature values, so its prediction
        # should be larger as well.
        assert predictions[1] > predictions[0]

    def test_feature_importance_before_fit(self):
        """Feature importance is None while the model is untrained."""
        model = LightGBMModel()
        assert model.feature_importance() is None

    def test_feature_importance_after_fit(self):
        """After fitting, importance covers both features and ranks them."""
        # Fixed seed: the ranking assertion below depends on the sampled
        # data, so the data must be identical on every run.
        np.random.seed(42)
        X = pl.DataFrame(
            {
                "feature1": np.random.randn(100),
                "feature2": np.random.randn(100),
            }
        )
        y = pl.Series("target", X["feature1"] * 2 + X["feature2"] * 0.1)

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)
        importance = model.feature_importance()

        # One entry per feature, indexed by feature name.
        assert importance is not None
        assert len(importance) == 2
        assert "feature1" in importance.index
        assert "feature2" in importance.index
        # feature1 has the larger coefficient in the target, so it should
        # be at least as important as feature2.
        assert importance["feature1"] >= importance["feature2"]

    def test_save_before_fit(self):
        """Saving an untrained model raises RuntimeError."""
        model = LightGBMModel()
        # Point at a throwaway directory so that a regression which writes
        # before raising cannot leave a stray file in the working directory.
        with tempfile.TemporaryDirectory() as tmpdir:
            with pytest.raises(RuntimeError, match="模型尚未训练"):
                model.save(os.path.join(tmpdir, "dummy.txt"))

    def test_save_and_load(self):
        """A saved model reloads with identical predictions and metadata."""
        # Train a model.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [2.0, 4.0, 6.0, 8.0, 10.0])
        model = LightGBMModel(n_estimators=10, learning_rate=0.1)
        model.fit(X, y)

        # Prediction taken before saving, used as the reference value.
        X_test = pl.DataFrame(
            {
                "feature1": [6.0],
                "feature2": [12.0],
            }
        )
        pred_before = model.predict(X_test)

        # Round-trip through a file inside a temporary directory; loading
        # must happen before the directory is removed on context exit.
        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "model.txt")
            model.save(save_path)

            loaded_model = LightGBMModel.load(save_path)

            # The reloaded model predicts the same value.
            pred_after = loaded_model.predict(X_test)
            assert pred_after[0] == pytest.approx(pred_before[0], rel=1e-5)
            # Metadata was restored along with the model.
            assert loaded_model.feature_names_ == ["feature1", "feature2"]
            assert loaded_model.n_estimators == 10

    def test_registration(self):
        """The model class is registered in ModelRegistry as "lightgbm"."""
        from src.training.registry import ModelRegistry

        # Re-import the model module to make sure registration has run
        # (covers the case where another test cleared the registry).
        import importlib

        import src.training.components.models.lightgbm as lightgbm_module

        importlib.reload(lightgbm_module)

        # Compare against the class re-imported after the reload rather
        # than the module-level import made before it.
        from src.training.components.models.lightgbm import (
            LightGBMModel as ReloadedModel,
        )

        model_class = ModelRegistry.get_model("lightgbm")
        assert model_class is ReloadedModel

    def test_fit_predict_consistency(self):
        """Repeated predictions on the same input give the same result."""
        # Fixed seed so the training data is reproducible across runs.
        np.random.seed(42)
        X = pl.DataFrame(
            {
                "feature1": np.random.randn(50),
                "feature2": np.random.randn(50),
            }
        )
        y = pl.Series("target", X["feature1"] + X["feature2"])

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)

        X_test = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0],
                "feature2": [1.0, 2.0, 3.0],
            }
        )
        # Predicting twice must return the same values.
        pred1 = model.predict(X_test)
        pred2 = model.predict(X_test)
        np.testing.assert_array_almost_equal(pred1, pred2)
if __name__ == "__main__":
    # Allow running this file directly; hands the file to pytest in verbose mode.
    pytest.main([__file__, "-v"])