From f35a6a76a6b0d9fc04856d5504502758938ce337 Mon Sep 17 00:00:00 2001
From: liaozhaorun <1300336796@qq.com>
Date: Tue, 3 Mar 2026 22:30:37 +0800
Subject: [PATCH] =?UTF-8?q?feat(training):=20=E5=AE=9E=E7=8E=B0=20LightGBM?=
 =?UTF-8?q?=20=E6=A8=A1=E5=9E=8B=20-=20=E6=96=B0=E5=A2=9E=20LightGBMModel?=
 =?UTF-8?q?=EF=BC=9ALightGBM=20=E5=9B=9E=E5=BD=92=E6=A8=A1=E5=9E=8B?=
 =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20-=20=E6=94=AF=E6=8C=81=E8=87=AA=E5=AE=9A?=
 =?UTF-8?q?=E4=B9=89=E5=8F=82=E6=95=B0=EF=BC=88objective,=20num=5Fleaves,?=
 =?UTF-8?q?=20learning=5Frate,=20n=5Festimators=20=E7=AD=89=EF=BC=89=20-?=
 =?UTF-8?q?=20=E4=BD=BF=E7=94=A8=20LightGBM=20=E5=8E=9F=E7=94=9F=E6=A0=BC?=
 =?UTF-8?q?=E5=BC=8F=E4=BF=9D=E5=AD=98/=E5=8A=A0=E8=BD=BD=E6=A8=A1?=
 =?UTF-8?q?=E5=9E=8B=EF=BC=88=E4=B8=8D=E4=BE=9D=E8=B5=96=20pickle=EF=BC=89?=
 =?UTF-8?q?=20-=20=E6=94=AF=E6=8C=81=E7=89=B9=E5=BE=81=E9=87=8D=E8=A6=81?=
 =?UTF-8?q?=E6=80=A7=E6=8F=90=E5=8F=96=20-=20=E5=B7=B2=E6=B3=A8=E5=86=8C?=
 =?UTF-8?q?=E5=88=B0=20ModelRegistry=EF=BC=88@register=5Fmodel("lightgbm")?=
 =?UTF-8?q?=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/training/components/__init__.py        |   4 +
 src/training/components/models/__init__.py |   8 +
 src/training/components/models/lightgbm.py | 194 ++++++++++++++++++
 tests/training/test_lightgbm_model.py      | 226 +++++++++++++++++++++
 4 files changed, 432 insertions(+)
 create mode 100644 src/training/components/models/__init__.py
 create mode 100644 src/training/components/models/lightgbm.py
 create mode 100644 tests/training/test_lightgbm_model.py

diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py
index 3106fed..8879569 100644
--- a/src/training/components/__init__.py
+++ b/src/training/components/__init__.py
@@ -22,6 +22,9 @@ from src.training.components.processors import (
     Winsorizer,
 )
 
+# Models
+from src.training.components.models import LightGBMModel
+
 __all__ = [
     "BaseModel",
     "BaseProcessor",
@@ -31,4 +34,5 @@ __all__ = [
     "StandardScaler",
     "CrossSectionalStandardScaler",
     "Winsorizer",
+    "LightGBMModel",
 ]
diff --git a/src/training/components/models/__init__.py b/src/training/components/models/__init__.py
new file mode 100644
index 0000000..216e259
--- /dev/null
+++ b/src/training/components/models/__init__.py
@@ -0,0 +1,8 @@
+"""Model subpackage.
+
+Contains the machine learning model implementations.
+"""
+
+from src.training.components.models.lightgbm import LightGBMModel
+
+__all__ = ["LightGBMModel"]
diff --git a/src/training/components/models/lightgbm.py b/src/training/components/models/lightgbm.py
new file mode 100644
index 0000000..3162eac
--- /dev/null
+++ b/src/training/components/models/lightgbm.py
@@ -0,0 +1,194 @@
+"""LightGBM model implementation.
+
+Provides a LightGBM regression model with feature importance and native model saving.
+"""
+
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+import polars as pl
+
+from src.training.components.base import BaseModel
+from src.training.registry import register_model
+
+
+@register_model("lightgbm")
+class LightGBMModel(BaseModel):
+    """LightGBM regression model.
+
+    Implements gradient-boosted regression trees with the LightGBM library.
+    Supports custom parameters, feature importance and native-format saving.
+
+    Attributes:
+        name: Model name, "lightgbm".
+        params: Dictionary of LightGBM parameters.
+        model: The trained LightGBM Booster object (None until fit/load).
+        feature_names_: List of feature names (None until fit/load).
+    """
+
+    name = "lightgbm"
+
+    def __init__(
+        self,
+        objective: str = "regression",
+        metric: str = "rmse",
+        num_leaves: int = 31,
+        learning_rate: float = 0.05,
+        n_estimators: int = 100,
+        **kwargs,
+    ):
+        """Initialize the LightGBM model.
+
+        Args:
+            objective: Objective function, default "regression".
+            metric: Evaluation metric, default "rmse".
+            num_leaves: Number of leaves, default 31.
+            learning_rate: Learning rate, default 0.05.
+            n_estimators: Number of boosting rounds, default 100.
+            **kwargs: Additional LightGBM parameters (override the ones above).
+        """
+        self.params = {
+            "objective": objective,
+            "metric": metric,
+            "num_leaves": num_leaves,
+            "learning_rate": learning_rate,
+            "verbose": -1,  # Suppress training output
+            **kwargs,
+        }
+        self.n_estimators = n_estimators
+        self.model = None
+        self.feature_names_: Optional[list] = None
+
+    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
+        """Train the model.
+
+        Args:
+            X: Feature matrix (Polars DataFrame).
+            y: Target variable (Polars Series).
+
+        Returns:
+            self (allows call chaining).
+
+        Raises:
+            ImportError: If lightgbm is not installed. Errors raised by
+                lightgbm while training are not caught here.
+        """
+        try:
+            import lightgbm as lgb
+        except ImportError as exc:
+            raise ImportError(
+                "使用 LightGBMModel 需要安装 lightgbm: pip install lightgbm"
+            ) from exc
+
+        # Remember the feature names
+        self.feature_names_ = X.columns
+
+        # Convert to numpy
+        X_np = X.to_numpy()
+        y_np = y.to_numpy()
+
+        # Build the dataset
+        train_data = lgb.Dataset(X_np, label=y_np)
+
+        # Train
+        self.model = lgb.train(
+            self.params,
+            train_data,
+            num_boost_round=self.n_estimators,
+        )
+
+        return self
+
+    def predict(self, X: pl.DataFrame) -> np.ndarray:
+        """Predict targets for the given features.
+
+        Args:
+            X: Feature matrix (Polars DataFrame).
+
+        Returns:
+            Predictions (numpy ndarray).
+
+        Raises:
+            RuntimeError: If called before the model is trained.
+        """
+        if self.model is None:
+            raise RuntimeError("模型尚未训练,请先调用 fit()")
+
+        X_np = X.to_numpy()
+        return self.model.predict(X_np)
+
+    def feature_importance(self) -> Optional[pd.Series]:
+        """Return the gain-based feature importance.
+
+        Returns:
+            Importance values indexed by feature name, or None if the model is untrained.
+        """
+        if self.model is None or self.feature_names_ is None:
+            return None
+
+        importance = self.model.feature_importance(importance_type="gain")
+        return pd.Series(importance, index=self.feature_names_)
+
+    def save(self, path: str) -> None:
+        """Save the model in LightGBM's native format.
+
+        Uses LightGBM's native format rather than pickle, so the model
+        can be loaded in a different environment.
+
+        Args:
+            path: Destination path; metadata goes to ``path + ".meta.json"``.
+
+        Raises:
+            RuntimeError: If called before the model is trained.
+        """
+        if self.model is None:
+            raise RuntimeError("模型尚未训练,无法保存")
+
+        self.model.save_model(path)
+
+        # Also save the feature names and settings in a JSON sidecar file
+        import json
+
+        meta_path = path + ".meta.json"
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(
+                {
+                    "feature_names": self.feature_names_,
+                    "params": self.params,
+                    "n_estimators": self.n_estimators,
+                },
+                f,
+            )
+
+    @classmethod
+    def load(cls, path: str) -> "LightGBMModel":
+        """Load a model.
+
+        Loads the model from LightGBM's native format.
+
+        Args:
+            path: Path to the model file.
+
+        Returns:
+            The loaded LightGBMModel instance.
+        """
+        import lightgbm as lgb
+        import json
+
+        instance = cls()
+        instance.model = lgb.Booster(model_file=path)
+
+        # Load the metadata sidecar
+        meta_path = path + ".meta.json"
+        try:
+            with open(meta_path, "r", encoding="utf-8") as f:
+                meta = json.load(f)
+                instance.feature_names_ = meta.get("feature_names")
+                instance.params = meta.get("params", instance.params)
+                instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
+        except FileNotFoundError:
+            # No sidecar file: carry on with feature_names_ left as None
+            pass
+
+        return instance
diff --git a/tests/training/test_lightgbm_model.py b/tests/training/test_lightgbm_model.py
new file mode 100644
index 0000000..3a673f2
--- /dev/null
+++ b/tests/training/test_lightgbm_model.py
@@ -0,0 +1,226 @@
+"""Tests for the LightGBM model.
+
+Verifies training, prediction, saving and loading of LightGBMModel.
+"""
+
+import os
+import tempfile
+
+import numpy as np
+import polars as pl
+import pytest
+
+from src.training.components.models.lightgbm import LightGBMModel
+
+
+class TestLightGBMModel:
+    """Test cases for LightGBMModel."""
+
+    def test_init_default(self):
+        """Default initialization."""
+        model = LightGBMModel()
+        assert model.name == "lightgbm"
+        assert model.params["objective"] == "regression"
+        assert model.params["metric"] == "rmse"
+        assert model.params["num_leaves"] == 31
+        assert model.params["learning_rate"] == 0.05
+        assert model.n_estimators == 100
+        assert model.model is None
+
+    def test_init_custom(self):
+        """Custom parameters."""
+        model = LightGBMModel(
+            objective="huber",
+            metric="mae",
+            num_leaves=50,
+            learning_rate=0.1,
+            n_estimators=200,
+        )
+        assert model.params["objective"] == "huber"
+        assert model.params["metric"] == "mae"
+        assert model.params["num_leaves"] == 50
+        assert model.params["learning_rate"] == 0.1
+        assert model.n_estimators == 200
+
+    def test_fit_success(self):
+        """Normal training."""
+        # Build simple regression data
+        X = pl.DataFrame(
+            {
+                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
+                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
+            }
+        )
+        y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])
+
+        model = LightGBMModel(n_estimators=10)
+        result = model.fit(X, y)
+
+        # fit() returns self (allows call chaining)
+        assert result is model
+        # The model has been trained
+        assert model.model is not None
+        # The feature names have been stored
+        assert model.feature_names_ == ["feature1", "feature2"]
+
+    def test_predict_before_fit(self):
+        """Predicting before training raises."""
+        X = pl.DataFrame(
+            {
+                "feature1": [1.0, 2.0],
+                "feature2": [2.0, 4.0],
+            }
+        )
+        model = LightGBMModel()
+
+        with pytest.raises(RuntimeError, match="模型尚未训练"):
+            model.predict(X)
+
+    def test_predict_success(self):
+        """Normal prediction."""
+        # Build regression data
+        np.random.seed(42)
+        n_samples = 100
+        X_train = pl.DataFrame(
+            {
+                "feature1": np.random.randn(n_samples),
+                "feature2": np.random.randn(n_samples),
+            }
+        )
+        # y = 2*feature1 + 3*feature2 + noise
+        y_train = pl.Series(
+            "target",
+            2 * X_train["feature1"]
+            + 3 * X_train["feature2"]
+            + np.random.randn(n_samples) * 0.1,
+        )
+
+        model = LightGBMModel(n_estimators=20, learning_rate=0.1)
+        model.fit(X_train, y_train)
+
+        # Predict on new data (clearly different values)
+        X_test = pl.DataFrame(
+            {
+                "feature1": [-2.0, 3.0],
+                "feature2": [-1.0, 4.0],
+            }
+        )
+        predictions = model.predict(X_test)
+
+        # Check the shape/type of the predictions
+        assert isinstance(predictions, np.ndarray)
+        assert len(predictions) == 2
+        # Predictions are finite numbers
+        assert all(np.isfinite(predictions))
+        # Monotonicity: sample 2 has larger feature values, so a larger prediction
+        assert predictions[1] > predictions[0]
+
+    def test_feature_importance_before_fit(self):
+        """Feature importance before training is None."""
+        model = LightGBMModel()
+        assert model.feature_importance() is None
+
+    def test_feature_importance_after_fit(self):
+        """Feature importance after training."""
+        X = pl.DataFrame(
+            {
+                "feature1": np.random.randn(100),
+                "feature2": np.random.randn(100),
+            }
+        )
+        y = pl.Series("target", X["feature1"] * 2 + X["feature2"] * 0.1)
+
+        model = LightGBMModel(n_estimators=10)
+        model.fit(X, y)
+
+        importance = model.feature_importance()
+
+        # Check the format of the importance result
+        assert importance is not None
+        assert len(importance) == 2
+        assert "feature1" in importance.index
+        assert "feature2" in importance.index
+        # feature1 has the larger coefficient, so it should matter more
+        assert importance["feature1"] >= importance["feature2"]
+
+    def test_save_before_fit(self):
+        """Saving before training raises."""
+        model = LightGBMModel()
+
+        with pytest.raises(RuntimeError, match="模型尚未训练"):
+            model.save("dummy.txt")
+
+    def test_save_and_load(self):
+        """Save and load round trip."""
+        # Train a model
+        X = pl.DataFrame(
+            {
+                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
+                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
+            }
+        )
+        y = pl.Series("target", [2.0, 4.0, 6.0, 8.0, 10.0])
+
+        model = LightGBMModel(n_estimators=10, learning_rate=0.1)
+        model.fit(X, y)
+
+        # Predict before saving
+        X_test = pl.DataFrame(
+            {
+                "feature1": [6.0],
+                "feature2": [12.0],
+            }
+        )
+        pred_before = model.predict(X_test)
+
+        # Save to a temporary file
+        with tempfile.TemporaryDirectory() as tmpdir:
+            save_path = os.path.join(tmpdir, "model.txt")
+            model.save(save_path)
+
+            # Load the model
+            loaded_model = LightGBMModel.load(save_path)
+
+            # Predictions after loading match those before saving
+            pred_after = loaded_model.predict(X_test)
+            assert pred_after[0] == pytest.approx(pred_before[0], rel=1e-5)
+
+            # Metadata has been restored
+            assert loaded_model.feature_names_ == ["feature1", "feature2"]
+            assert loaded_model.n_estimators == 10
+
+    def test_registration(self):
+        """The model is registered in the registry."""
+        from src.training.registry import ModelRegistry
+
+        model_class = ModelRegistry.get_model("lightgbm")
+        assert model_class is LightGBMModel
+
+    def test_fit_predict_consistency(self):
+        """Repeated predictions agree."""
+        X = pl.DataFrame(
+            {
+                "feature1": np.random.randn(50),
+                "feature2": np.random.randn(50),
+            }
+        )
+        y = pl.Series("target", X["feature1"] + X["feature2"])
+
+        model = LightGBMModel(n_estimators=10)
+        model.fit(X, y)
+
+        X_test = pl.DataFrame(
+            {
+                "feature1": [1.0, 2.0, 3.0],
+                "feature2": [1.0, 2.0, 3.0],
+            }
+        )
+
+        # Repeated predictions should return the same result
+        pred1 = model.predict(X_test)
+        pred2 = model.predict(X_test)
+        np.testing.assert_array_almost_equal(pred1, pred2)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])