- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择 - Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化 - TrainingConfig:pydantic 配置管理,含必填字段和日期验证 - experiment 模块:预留结构 - 从计划中移除 metrics 组件 - 调整 commit 序号(7-10 → 6-9) - 更新 training/__init__.py 导出所有公开 API
236 lines
7.2 KiB
Python
236 lines
7.2 KiB
Python
"""测试 LightGBM 模型

验证 LightGBMModel 的训练、预测、保存和加载功能。
"""
|
||
|
||
import os
|
||
import tempfile
|
||
|
||
import numpy as np
|
||
import polars as pl
|
||
import pytest
|
||
|
||
from src.training.components.models.lightgbm import LightGBMModel
|
||
|
||
|
||
class TestLightGBMModel:
    """Tests for ``LightGBMModel``.

    Covers construction, fitting, prediction, feature importance,
    save/load round-tripping and registry lookup.

    Random fixtures are drawn from a locally seeded
    ``np.random.default_rng`` so every run sees the same data and no test
    mutates NumPy's global random state.
    """

    def test_init_default(self):
        """A model built without arguments exposes the expected defaults."""
        model = LightGBMModel()
        assert model.name == "lightgbm"
        assert model.params["objective"] == "regression"
        assert model.params["metric"] == "rmse"
        assert model.params["num_leaves"] == 31
        assert model.params["learning_rate"] == 0.05
        assert model.n_estimators == 100
        assert model.model is None

    def test_init_custom(self):
        """Keyword arguments passed to the constructor override the defaults."""
        model = LightGBMModel(
            objective="huber",
            metric="mae",
            num_leaves=50,
            learning_rate=0.1,
            n_estimators=200,
        )
        assert model.params["objective"] == "huber"
        assert model.params["metric"] == "mae"
        assert model.params["num_leaves"] == 50
        assert model.params["learning_rate"] == 0.1
        assert model.n_estimators == 200

    def test_fit_success(self):
        """``fit`` returns the model itself and records the trained state."""
        # Small, simple regression data set.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])

        model = LightGBMModel(n_estimators=10)
        result = model.fit(X, y)

        # fit() returns self so calls can be chained.
        assert result is model
        # The underlying model has been created.
        assert model.model is not None
        # The feature names were captured from the frame's columns.
        assert model.feature_names_ == ["feature1", "feature2"]

    def test_predict_before_fit(self):
        """Predicting with an unfitted model raises ``RuntimeError``."""
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0],
                "feature2": [2.0, 4.0],
            }
        )
        model = LightGBMModel()

        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.predict(X)

    def test_predict_success(self):
        """A fitted model returns finite predictions ordered like the inputs."""
        # Local generator: reproducible without touching the global RNG.
        rng = np.random.default_rng(42)
        n_samples = 100
        X_train = pl.DataFrame(
            {
                "feature1": rng.standard_normal(n_samples),
                "feature2": rng.standard_normal(n_samples),
            }
        )
        # y = 2*feature1 + 3*feature2 + noise
        y_train = pl.Series(
            "target",
            2 * X_train["feature1"]
            + 3 * X_train["feature2"]
            + rng.standard_normal(n_samples) * 0.1,
        )

        model = LightGBMModel(n_estimators=20, learning_rate=0.1)
        model.fit(X_train, y_train)

        # Predict on new rows whose values are clearly far apart.
        X_test = pl.DataFrame(
            {
                "feature1": [-2.0, 3.0],
                "feature2": [-1.0, 4.0],
            }
        )
        predictions = model.predict(X_test)

        # Output format: one finite value per input row, as a NumPy array.
        assert isinstance(predictions, np.ndarray)
        assert len(predictions) == 2
        assert np.isfinite(predictions).all()
        # Monotonicity: the second row has larger feature values and the
        # target grows with both features, so its prediction must be larger.
        assert predictions[1] > predictions[0]

    def test_feature_importance_before_fit(self):
        """``feature_importance`` returns ``None`` while the model is unfitted."""
        model = LightGBMModel()
        assert model.feature_importance() is None

    def test_feature_importance_after_fit(self):
        """After fitting, importance is reported per feature by name."""
        # Seeded so the importance ordering asserted below is reproducible.
        rng = np.random.default_rng(42)
        X = pl.DataFrame(
            {
                "feature1": rng.standard_normal(100),
                "feature2": rng.standard_normal(100),
            }
        )
        y = pl.Series("target", X["feature1"] * 2 + X["feature2"] * 0.1)

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)

        importance = model.feature_importance()

        # Format: one entry per feature, indexed by feature name.
        assert importance is not None
        assert len(importance) == 2
        assert "feature1" in importance.index
        assert "feature2" in importance.index
        # feature1 has the larger coefficient, so it should not rank lower.
        assert importance["feature1"] >= importance["feature2"]

    def test_save_before_fit(self):
        """Saving an unfitted model raises ``RuntimeError``."""
        model = LightGBMModel()

        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.save("dummy.txt")

    def test_save_and_load(self):
        """A saved model reloads with the same prediction and metadata."""
        # Train a model.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [2.0, 4.0, 6.0, 8.0, 10.0])

        model = LightGBMModel(n_estimators=10, learning_rate=0.1)
        model.fit(X, y)

        # Reference prediction taken before saving.
        X_test = pl.DataFrame(
            {
                "feature1": [6.0],
                "feature2": [12.0],
            }
        )
        pred_before = model.predict(X_test)

        # Save into a temporary directory that is cleaned up afterwards.
        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "model.txt")
            model.save(save_path)

            # Load the model back from disk.
            loaded_model = LightGBMModel.load(save_path)

            # The reloaded model must reproduce the earlier prediction.
            pred_after = loaded_model.predict(X_test)
            assert pred_after[0] == pytest.approx(pred_before[0], rel=1e-5)

            # Metadata is restored as well.
            assert loaded_model.feature_names_ == ["feature1", "feature2"]
            assert loaded_model.n_estimators == 10

    def test_registration(self):
        """The model class can be looked up in ``ModelRegistry`` as "lightgbm"."""
        from src.training.registry import ModelRegistry

        # Re-import the model module to make sure it is registered (covers
        # the case where another test cleared the registry).
        import importlib
        import src.training.components.models.lightgbm as lightgbm_module

        importlib.reload(lightgbm_module)
        # NOTE(review): reload creates a new class object, so the comparison
        # uses the freshly imported name rather than the module-level
        # ``LightGBMModel`` bound before the reload.
        from src.training.components.models.lightgbm import (
            LightGBMModel as ReloadedModel,
        )

        model_class = ModelRegistry.get_model("lightgbm")
        assert model_class is ReloadedModel

    def test_fit_predict_consistency(self):
        """Repeated ``predict`` calls on the same input give the same output."""
        # Seeded so every run trains on the same data.
        rng = np.random.default_rng(42)
        X = pl.DataFrame(
            {
                "feature1": rng.standard_normal(50),
                "feature2": rng.standard_normal(50),
            }
        )
        y = pl.Series("target", X["feature1"] + X["feature2"])

        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)

        X_test = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0],
                "feature2": [1.0, 2.0, 3.0],
            }
        )

        # Predicting twice must return the same values.
        pred1 = model.predict(X_test)
        pred2 = model.predict(X_test)
        np.testing.assert_array_almost_equal(pred1, pred2)
|
||
|
||
|
||
if __name__ == "__main__":
    # pytest.main returns the session's exit code; propagate it so that
    # running this file directly reports test failures to the shell instead
    # of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|