feat(training): 实现 LightGBM 模型
- 新增 LightGBMModel:LightGBM 回归模型实现
- 支持自定义参数(objective, num_leaves, learning_rate, n_estimators 等)
- 使用 LightGBM 原生格式保存/加载模型(不依赖 pickle)
- 支持特征重要性提取
- 已注册到 ModelRegistry(@register_model("lightgbm"))
This commit is contained in:
@@ -22,6 +22,9 @@ from src.training.components.processors import (
|
||||
Winsorizer,
|
||||
)
|
||||
|
||||
# 模型
|
||||
from src.training.components.models import LightGBMModel
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
@@ -31,4 +34,5 @@ __all__ = [
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
"LightGBMModel",
|
||||
]
|
||||
|
||||
8
src/training/components/models/__init__.py
Normal file
8
src/training/components/models/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""模型子模块
|
||||
|
||||
包含各种机器学习模型的实现。
|
||||
"""
|
||||
|
||||
from src.training.components.models.lightgbm import LightGBMModel
|
||||
|
||||
__all__ = ["LightGBMModel"]
|
||||
194
src/training/components/models/lightgbm.py
Normal file
194
src/training/components/models/lightgbm.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""LightGBM 模型实现
|
||||
|
||||
提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
|
||||
|
||||
@register_model("lightgbm")
|
||||
class LightGBMModel(BaseModel):
|
||||
"""LightGBM 回归模型
|
||||
|
||||
使用 LightGBM 库实现梯度提升回归树。
|
||||
支持自定义参数、特征重要性提取和原生模型格式保存。
|
||||
|
||||
Attributes:
|
||||
name: 模型名称 "lightgbm"
|
||||
params: LightGBM 参数字典
|
||||
model: 训练后的 LightGBM Booster 对象
|
||||
feature_names_: 特征名称列表
|
||||
"""
|
||||
|
||||
name = "lightgbm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
objective: str = "regression",
|
||||
metric: str = "rmse",
|
||||
num_leaves: int = 31,
|
||||
learning_rate: float = 0.05,
|
||||
n_estimators: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
"""初始化 LightGBM 模型
|
||||
|
||||
Args:
|
||||
objective: 目标函数,默认 "regression"
|
||||
metric: 评估指标,默认 "rmse"
|
||||
num_leaves: 叶子节点数,默认 31
|
||||
learning_rate: 学习率,默认 0.05
|
||||
n_estimators: 迭代次数,默认 100
|
||||
**kwargs: 其他 LightGBM 参数
|
||||
"""
|
||||
self.params = {
|
||||
"objective": objective,
|
||||
"metric": metric,
|
||||
"num_leaves": num_leaves,
|
||||
"learning_rate": learning_rate,
|
||||
"verbose": -1, # 抑制训练输出
|
||||
**kwargs,
|
||||
}
|
||||
self.n_estimators = n_estimators
|
||||
self.model = None
|
||||
self.feature_names_: Optional[list] = None
|
||||
|
||||
def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
|
||||
"""训练模型
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
y: 目标变量 (Polars Series)
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 lightgbm
|
||||
RuntimeError: 训练失败
|
||||
"""
|
||||
try:
|
||||
import lightgbm as lgb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"使用 LightGBMModel 需要安装 lightgbm: pip install lightgbm"
|
||||
)
|
||||
|
||||
# 保存特征名称
|
||||
self.feature_names_ = X.columns
|
||||
|
||||
# 转换为 numpy
|
||||
X_np = X.to_numpy()
|
||||
y_np = y.to_numpy()
|
||||
|
||||
# 创建数据集
|
||||
train_data = lgb.Dataset(X_np, label=y_np)
|
||||
|
||||
# 训练
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
train_data,
|
||||
num_boost_round=self.n_estimators,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
预测结果 (numpy ndarray)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 模型未训练时调用
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型尚未训练,请先调用 fit()")
|
||||
|
||||
X_np = X.to_numpy()
|
||||
return self.model.predict(X_np)
|
||||
|
||||
def feature_importance(self) -> Optional[pd.Series]:
|
||||
"""返回特征重要性
|
||||
|
||||
Returns:
|
||||
特征重要性序列,如果模型未训练则返回 None
|
||||
"""
|
||||
if self.model is None or self.feature_names_ is None:
|
||||
return None
|
||||
|
||||
importance = self.model.feature_importance(importance_type="gain")
|
||||
return pd.Series(importance, index=self.feature_names_)
|
||||
|
||||
def save(self, path: str) -> None:
|
||||
"""保存模型(使用 LightGBM 原生格式)
|
||||
|
||||
使用 LightGBM 的原生格式保存,不依赖 pickle,
|
||||
可以在不同环境中加载。
|
||||
|
||||
Args:
|
||||
path: 保存路径
|
||||
|
||||
Raises:
|
||||
RuntimeError: 模型未训练时调用
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型尚未训练,无法保存")
|
||||
|
||||
self.model.save_model(path)
|
||||
|
||||
# 同时保存特征名称(LightGBM 原生格式不保存这个)
|
||||
import json
|
||||
|
||||
meta_path = path + ".meta.json"
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"feature_names": self.feature_names_,
|
||||
"params": self.params,
|
||||
"n_estimators": self.n_estimators,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "LightGBMModel":
|
||||
"""加载模型
|
||||
|
||||
从 LightGBM 原生格式加载模型。
|
||||
|
||||
Args:
|
||||
path: 模型文件路径
|
||||
|
||||
Returns:
|
||||
加载的 LightGBMModel 实例
|
||||
"""
|
||||
import lightgbm as lgb
|
||||
import json
|
||||
|
||||
instance = cls()
|
||||
instance.model = lgb.Booster(model_file=path)
|
||||
|
||||
# 加载元数据
|
||||
meta_path = path + ".meta.json"
|
||||
try:
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
instance.feature_names_ = meta.get("feature_names")
|
||||
instance.params = meta.get("params", instance.params)
|
||||
instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
|
||||
except FileNotFoundError:
|
||||
# 如果没有元数据文件,继续运行(feature_names_ 为 None)
|
||||
pass
|
||||
|
||||
return instance
|
||||
Reference in New Issue
Block a user