feat(training): 实现 Trainer 模块化重构 (Trainer V2)
- 新增 FactorManager 组件:统一管理多种来源因子
- 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理)
- 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask
- 新增 ResultAnalyzer 组件:特征重要性分析和结果组装
- 新增 TrainerV2:作为纯调度引擎协调各组件
- 支持回归和排序学习两种训练模式
- 采用组合模式解耦训练流程,消除代码重复
This commit is contained in:
@@ -185,131 +185,6 @@ class LightGBMLambdaRankModel(BaseModel):
|
||||
return None
|
||||
return self.model.best_score
|
||||
|
||||
def plot_metric(
    self,
    metric: Optional[str] = None,
    figsize: tuple = (10, 6),
    title: Optional[str] = None,
    ax=None,
):
    """Plot the curve of one metric stored in ``self.evals_result_``.

    Args:
        metric: Name of the metric to plot, e.g. 'ndcg@5'. When omitted, the
            first metric under the "train" entry whose name contains "ndcg"
            (case-insensitive) is chosen, falling back to the first "train"
            metric.
        figsize: Figure size, default (10, 6). Only used when ``ax`` is None
            and a new figure has to be created.
        title: Chart title. Defaults to
            "Training Metric (<METRIC>) over Iterations".
        ax: matplotlib Axes to draw on; a new one is created when None.

    Returns:
        The matplotlib Axes the metric was drawn on.

    Raises:
        RuntimeError: If ``self.model`` is None or ``self.evals_result_`` is
            None/empty.
        ValueError: If there is no "train" metric to pick from, or ``metric``
            is not among the "train" metrics.
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")

    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")

    # Look the "train" metrics up once; every check below works on this view.
    train_metrics = self.evals_result_.get("train", {})
    available = list(train_metrics)

    if metric is None:
        if not available:
            raise ValueError("没有可用的评估指标")
        # Prefer an NDCG metric, otherwise take whatever was recorded first.
        metric = next(
            (name for name in available if "ndcg" in name.lower()),
            available[0],
        )

    if metric not in train_metrics:
        raise ValueError(f"指标 '{metric}' 不存在。可用的指标: {available}")

    # Import the plotting stack only after the arguments are validated, so
    # that bad input is reported as such even where plotting is unavailable.
    import lightgbm as lgb
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    lgb.plot_metric(self.evals_result_, metric=metric, ax=ax)

    if title is None:
        title = f"Training Metric ({metric.upper()}) over Iterations"
    ax.set_title(title, fontsize=12, fontweight="bold")

    return ax
|
||||
|
||||
def plot_all_metrics(
    self,
    metrics: Optional[list] = None,
    figsize: tuple = (14, 10),
    max_cols: int = 2,
):
    """Plot several metric curves from ``self.evals_result_`` as a subplot grid.

    Args:
        metrics: Metric names to plot. When None, up to four "train" metrics
            are used: those whose name contains "ndcg" (case-insensitive) if
            any exist, otherwise the first recorded ones.
        figsize: Figure size, default (14, 10).
        max_cols: Maximum number of subplots per row, default 2. Must be >= 1.

    Returns:
        The matplotlib Figure holding one subplot per requested metric. A
        requested metric that is not among the "train" metrics gets a
        placeholder text instead of a curve; surplus grid cells are hidden.

    Raises:
        RuntimeError: If ``self.model`` is None or ``self.evals_result_`` is
            None/empty.
        ValueError: If ``max_cols`` is smaller than 1 or there is no metric
            to plot.
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")

    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")

    # A non-positive column count would otherwise surface as a
    # ZeroDivisionError or an opaque matplotlib error further down.
    if max_cols < 1:
        raise ValueError(f"max_cols 必须为正整数,当前为 {max_cols}")

    available_metrics = list(self.evals_result_.get("train", {}))

    if metrics is None:
        ndcg_metrics = [m for m in available_metrics if "ndcg" in m.lower()]
        metrics = (ndcg_metrics or available_metrics)[:4]

    if not metrics:
        raise ValueError("没有可用的评估指标")

    # Imported after validation so argument errors do not depend on the
    # plotting backend being importable.
    import matplotlib.pyplot as plt

    n_metrics = len(metrics)
    n_cols = min(max_cols, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always yields a 2-D array of Axes, so one ravel() covers
    # the single-plot, single-row and multi-row layouts alike.
    fig, grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = grid.ravel()

    for ax, metric in zip(axes, metrics):
        if metric in available_metrics:
            self.plot_metric(metric=metric, ax=ax)
            # Replace the longer default title with just the metric name.
            ax.set_title(f"{metric.upper()}", fontsize=11, fontweight="bold")
        else:
            ax.text(
                0.5,
                0.5,
                f"Metric '{metric}' not found",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )

    # Hide grid cells left over when the last row is not full.
    for spare_ax in axes[n_metrics:]:
        spare_ax.axis("off")

    plt.tight_layout()
    return fig
|
||||
|
||||
def feature_importance(self) -> Optional[pd.Series]:
|
||||
"""返回特征重要性
|
||||
|
||||
|
||||
Reference in New Issue
Block a user