feat(training): LightGBM支持验证集早停
- 为fit方法添加eval_set参数,支持验证集评估和早停 - 因子引擎简化初始化,移除metadata_path参数 - 回归实验精简因子定义,移除冗余因子库
This commit is contained in:
@@ -49,12 +49,18 @@ class LightGBMModel(BaseModel):
|
||||
self.model = None
|
||||
self.feature_names_: Optional[list] = None
|
||||
|
||||
def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
|
||||
def fit(
|
||||
self,
|
||||
X: pl.DataFrame,
|
||||
y: pl.Series,
|
||||
eval_set: Optional[tuple] = None,
|
||||
) -> "LightGBMModel":
|
||||
"""训练模型
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
y: 目标变量 (Polars Series)
|
||||
eval_set: 验证集元组 (X_val, y_val),用于早停
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
@@ -76,6 +82,14 @@ class LightGBMModel(BaseModel):
|
||||
|
||||
train_data = lgb.Dataset(X_np, label=y_np)
|
||||
|
||||
# 准备验证集
|
||||
valid_sets = None
|
||||
if eval_set is not None:
|
||||
X_val, y_val = eval_set
|
||||
X_val_np = X_val.to_numpy()
|
||||
y_val_np = y_val.to_numpy()
|
||||
valid_sets = lgb.Dataset(X_val_np, label=y_val_np, reference=train_data)
|
||||
|
||||
# 从 params 中提取 num_boost_round,默认 100
|
||||
num_boost_round = self.params.pop("n_estimators", 100)
|
||||
|
||||
@@ -83,6 +97,7 @@ class LightGBMModel(BaseModel):
|
||||
self.params,
|
||||
train_data,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[valid_sets] if valid_sets else None,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user