feat(training): 添加数据质量检查工具并重构实验脚本
- 新增 check_data_quality 函数用于检测全空/全零/全NaN数据质量问题
- 重构 register_factors 函数,消除 FEATURE_COLS 和 PROCESSORS 冗余定义
- 修复实验脚本中特征列表不一致的问题,确保处理器覆盖所有特征
- 优化 LambdaRank 模型参数配置
This commit is contained in:
@@ -37,6 +37,9 @@ from src.training.components.filters import BaseFilter, STFilter
|
||||
# 训练核心
|
||||
from src.training.core import StockPoolManager, Trainer
|
||||
|
||||
# 工具函数
|
||||
from src.training.utils import check_data_quality
|
||||
|
||||
# 配置
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
@@ -67,6 +70,8 @@ __all__ = [
|
||||
# 训练核心
|
||||
"StockPoolManager",
|
||||
"Trainer",
|
||||
# 工具函数
|
||||
"check_data_quality",
|
||||
# 配置
|
||||
"TrainingConfig",
|
||||
]
|
||||
|
||||
@@ -98,6 +98,7 @@ class LightGBMLambdaRankModel(BaseModel):
|
||||
|
||||
self.model = None
|
||||
self.feature_names_: Optional[list] = None
|
||||
self.evals_result_: Optional[dict] = None # 存储训练评估结果
|
||||
|
||||
def fit(
|
||||
self,
|
||||
@@ -155,8 +156,9 @@ class LightGBMLambdaRankModel(BaseModel):
|
||||
# 创建训练数据集
|
||||
train_data = lgb.Dataset(X_np, label=y_np, group=group)
|
||||
|
||||
# 准备验证集
|
||||
# 准备验证集和验证集名称
|
||||
valid_sets = [train_data]
|
||||
valid_names = ["train"]
|
||||
if eval_set is not None:
|
||||
X_val, y_val, group_val = eval_set
|
||||
X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
|
||||
@@ -169,15 +171,23 @@ class LightGBMLambdaRankModel(BaseModel):
|
||||
|
||||
val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
|
||||
valid_sets.append(val_data)
|
||||
valid_names.append("val")
|
||||
|
||||
# 初始化评估结果存储
|
||||
self.evals_result_ = {}
|
||||
|
||||
# 训练
|
||||
callbacks = [lgb.early_stopping(stopping_rounds=self.early_stopping_rounds)]
|
||||
callbacks = [
|
||||
lgb.early_stopping(stopping_rounds=self.early_stopping_rounds),
|
||||
lgb.record_evaluation(self.evals_result_),
|
||||
]
|
||||
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
train_data,
|
||||
num_boost_round=self.n_estimators,
|
||||
valid_sets=valid_sets,
|
||||
valid_names=valid_names,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@@ -201,6 +211,185 @@ class LightGBMLambdaRankModel(BaseModel):
|
||||
X_np = X.to_numpy()
|
||||
return self.model.predict(X_np)
|
||||
|
||||
def get_evals_result(self) -> Optional[dict]:
    """Return the evaluation history stored on this model.

    Returns:
        The ``evals_result_`` attribute unchanged. Per its intended format it
        maps dataset name to metric histories, e.g.
        ``{'train': {'metric_name': [...]}, 'val': {'metric_name': [...]}}``.
        ``None`` when the model has not been trained yet.
    """
    history = self.evals_result_
    return history
|
||||
|
||||
def plot_metric(
    self,
    metric: Optional[str] = None,
    figsize: tuple = (10, 6),
    title: Optional[str] = None,
    ax=None,
):
    """Plot one recorded training metric against boosting iterations.

    The drawing itself is delegated to LightGBM's ``plot_metric`` helper,
    fed with the evaluation history stored in ``self.evals_result_``.

    Args:
        metric: Name of the metric to plot, e.g. ``'ndcg@5'`` or ``'ndcg@10'``.
            When ``None``, the first recorded metric whose name contains
            "ndcg" is used, falling back to the first recorded metric.
        figsize: Figure size used only when a new figure has to be created
            (i.e. ``ax`` is ``None``). Defaults to ``(10, 6)``.
        title: Plot title. When ``None`` a title is generated from the
            metric name.
        ax: Existing matplotlib Axes to draw on. When ``None`` a new figure
            and axes are created.

    Returns:
        The matplotlib Axes that was drawn on.

    Raises:
        RuntimeError: The model has not been trained, or no evaluation
            history is available.
        ValueError: No metric is recorded for the ``"train"`` dataset, or the
            requested metric is not among the recorded ones.

    Examples:
        >>> model.plot_metric('ndcg@20')  # plot the ndcg@20 curve
        >>> model.plot_metric()           # pick a metric automatically
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")
    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")

    # Imported lazily so the plotting stack is only needed when plotting.
    import lightgbm as lgb
    import matplotlib.pyplot as plt

    # Metric availability is judged from the "train" dataset's history.
    train_history = self.evals_result_.get("train", {})
    known_metrics = list(train_history.keys())

    if metric is None:
        if not known_metrics:
            raise ValueError("没有可用的评估指标")
        # Prefer the first NDCG-style metric; otherwise take the first one.
        metric = next(
            (name for name in known_metrics if "ndcg" in name.lower()),
            known_metrics[0],
        )

    if metric not in train_history:
        raise ValueError(f"指标 '{metric}' 不存在。可用的指标: {known_metrics}")

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    lgb.plot_metric(self.evals_result_, metric=metric, ax=ax)

    heading = (
        title
        if title is not None
        else f"Training Metric ({metric.upper()}) over Iterations"
    )
    ax.set_title(heading, fontsize=12, fontweight="bold")
    return ax
|
||||
|
||||
def plot_all_metrics(
    self,
    metrics: Optional[list] = None,
    figsize: tuple = (14, 10),
    max_cols: int = 2,
):
    """Plot several recorded training metrics as a grid of subplots.

    Each metric gets its own subplot, drawn through ``self.plot_metric``.

    Args:
        metrics: Metric names to plot. When ``None``, up to four recorded
            metrics whose name contains "ndcg" are used, falling back to the
            first four recorded metrics of any kind.
        figsize: Size of the whole figure. Defaults to ``(14, 10)``.
        max_cols: Maximum number of subplots per row; must be at least 1.
            Defaults to 2.

    Returns:
        The matplotlib Figure holding all subplots.

    Raises:
        RuntimeError: The model has not been trained, or no evaluation
            history is available.
        ValueError: ``max_cols`` is smaller than 1, or there is no metric to
            plot.
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")
    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")
    # A non-positive column count would otherwise surface as a
    # ZeroDivisionError (or an invalid subplot grid) further down.
    if max_cols < 1:
        raise ValueError(f"max_cols 必须 >= 1,当前为 {max_cols}")

    # Imported lazily so matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    # Metric availability is judged from the "train" dataset's history.
    available_metrics = list(self.evals_result_.get("train", {}).keys())

    if metrics is None:
        ndcg_metrics = [m for m in available_metrics if "ndcg" in m.lower()]
        metrics = (ndcg_metrics or available_metrics)[:4]

    if not metrics:
        raise ValueError("没有可用的评估指标")

    # Grid layout: at most ``max_cols`` columns, as many rows as needed.
    n_metrics = len(metrics)
    n_cols = min(max_cols, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so a single
    # flatten() covers the 1x1, 1xN and Nx1 layouts alike.
    fig, grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = grid.flatten()

    for ax, metric in zip(axes, metrics):
        if metric in available_metrics:
            self.plot_metric(metric=metric, ax=ax)
            # Replace plot_metric's long default title with a compact one.
            ax.set_title(f"{metric.upper()}", fontsize=11, fontweight="bold")
        else:
            # Keep the slot but show a placeholder for unknown metrics.
            ax.text(
                0.5,
                0.5,
                f"Metric '{metric}' not found",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )

    # Hide grid cells that have no metric assigned.
    for spare_ax in axes[n_metrics:]:
        spare_ax.axis("off")

    plt.tight_layout()
    return fig
|
||||
|
||||
def get_best_iteration(self) -> Optional[int]:
    """Return the best boosting iteration reported by the trained model.

    Returns:
        ``self.model.best_iteration`` (which accounts for early stopping),
        or ``None`` when the model has not been trained.
    """
    booster = self.model
    return None if booster is None else booster.best_iteration
|
||||
|
||||
def get_best_score(self) -> Optional[dict]:
    """Return the best evaluation scores reported by the trained model.

    Returns:
        ``self.model.best_score``, a dict of the form
        ``{dataset_name: {'metric': value}, ...}``, or ``None`` when the
        model has not been trained.
        NOTE(review): the dataset keys are presumably the names passed as
        ``valid_names`` during training (e.g. ``'train'`` / ``'val'``)
        rather than ``'valid_0'`` / ``'valid_1'`` — confirm against ``fit``.
    """
    booster = self.model
    return None if booster is None else booster.best_score
|
||||
|
||||
def feature_importance(self) -> Optional[pd.Series]:
|
||||
"""返回特征重要性
|
||||
|
||||
|
||||
171
src/training/utils.py
Normal file
171
src/training/utils.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""训练模块工具函数
|
||||
|
||||
提供数据质量检查、验证等通用工具函数。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
def check_data_quality(
    df: pl.DataFrame,
    feature_cols: List[str],
    date_col: str = "trade_date",
    check_all_null: bool = True,
    check_all_zero: bool = True,
    check_all_nan: bool = True,
    raise_on_error: bool = True,
) -> Dict[str, List[Dict[str, Union[str, int]]]]:
    """Detect factors that are entirely null, zero or NaN on a given date.

    Run this check BEFORE fillna, standardisation or similar processing;
    otherwise those steps mask the problem.

    Only numeric columns that actually exist in ``df`` are inspected. Rows
    whose date is null are ignored, because they cannot be matched by the
    per-date equality filter.

    Args:
        df: Data to inspect.
        feature_cols: Names of the feature (factor) columns to check.
        date_col: Name of the date column. Defaults to ``"trade_date"``.
        check_all_null: Report (date, factor) pairs whose values are all null.
        check_all_zero: Report pairs whose non-null values are all zero.
        check_all_nan: Report pairs whose non-null values are all NaN. In
            Polars NaN is a floating-point value distinct from null, so this
            only applies to float columns.
        raise_on_error: Raise ``ValueError`` when any issue is found.

    Returns:
        A dict with one list of findings per category::

            {
                "all_null": [{"date": "20240101", "factor": "factor_name", "count": 500}],
                "all_zero": [{"date": "20240101", "factor": "factor_name", "count": 500}],
                "all_nan": [{"date": "20240101", "factor": "factor_name", "count": 500}],
            }

        ``count`` is the number of rows on that date for ``all_null`` and the
        number of non-null values for the other two categories.

    Raises:
        ValueError: An issue was found and ``raise_on_error`` is ``True``.

    Examples:
        >>> import polars as pl
        >>> df = pl.DataFrame({
        ...     "trade_date": ["20240101", "20240101", "20240102"],
        ...     "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ"],
        ...     "factor1": [1.0, 2.0, None],
        ...     "factor2": [0.0, 0.0, 1.0],
        ... })
        >>> result = check_data_quality(
        ...     df, ["factor1", "factor2"], raise_on_error=False
        ... )  # reports factor1 all-null on 20240102, factor2 all-zero on 20240101
    """
    issues = {
        "all_null": [],
        "all_zero": [],
        "all_nan": [],
    }

    # Only inspect requested columns that are present in the frame.
    existing_cols = [col for col in feature_cols if col in df.columns]
    if not existing_cols:
        return issues

    # A column's dtype does not vary by date, so filter non-numeric
    # columns once instead of once per day.
    numeric_cols = [col for col in existing_cols if df[col].dtype.is_numeric()]

    # Sorted for a deterministic report; null dates are dropped because an
    # equality filter never matches them and would yield empty "days".
    for date in df[date_col].unique().drop_nulls().sort():
        day_data = df.filter(pl.col(date_col) == date)
        day_str = str(date)

        for col in numeric_cols:
            col_data = day_data[col]
            non_null_count = col_data.count()

            if non_null_count == 0:
                # No valid value at all for this factor on this date.
                if check_all_null:
                    issues["all_null"].append(
                        {
                            "date": day_str,
                            "factor": col,
                            "count": len(day_data),
                        }
                    )
                # Nothing left to test for zero/NaN either way.
                continue

            # NaN is a float value (not null) in Polars; count NaNs among the
            # non-null entries. Nulls in the boolean mask are skipped by sum().
            if check_all_nan and col_data.dtype.is_float():
                if col_data.is_nan().sum() == non_null_count:
                    issues["all_nan"].append(
                        {
                            "date": day_str,
                            "factor": col,
                            "count": non_null_count,
                        }
                    )
                    # An all-NaN column cannot also be all-zero.
                    continue

            # Sum of absolute values is zero only if every non-null value is 0.
            if check_all_zero and col_data.abs().sum() == 0:
                issues["all_zero"].append(
                    {
                        "date": day_str,
                        "factor": col,
                        "count": non_null_count,
                    }
                )

    total_issues = sum(len(found) for found in issues.values())

    if total_issues > 0:
        report_lines = ["\n" + "=" * 80, "数据质量检查报告", "=" * 80]

        _append_issue_section(
            report_lines,
            issues["all_null"],
            "严重",
            "全空",
            " (某天的某个因子所有值都是 null,可能是数据缺失或计算错误)",
        )
        _append_issue_section(
            report_lines,
            issues["all_zero"],
            "警告",
            "全零",
            " (某天的某个因子所有值都是 0,可能是计算错误或数据源问题)",
        )
        _append_issue_section(
            report_lines,
            issues["all_nan"],
            "警告",
            "全NaN",
            " (某天的某个因子所有值都是 NaN,可能是数值计算错误)",
        )

        report_lines.extend(
            [
                "\n" + "-" * 80,
                "建议处理方式:",
                " 1. 检查因子定义和数据源,确认计算逻辑是否正确",
                " 2. 如果是预期内的缺失(如新股无历史数据),考虑调整因子计算窗口",
                " 3. 如果是数据同步问题,重新同步相关数据",
                " 4. 可以使用 filter 排除问题日期或因子",
                "=" * 80,
            ]
        )

        print("\n".join(report_lines))

        if raise_on_error:
            raise ValueError(
                f"数据质量检查失败: 发现 {total_issues} 个问题,"
                f"详见上方报告。如需忽略,请设置 raise_on_error=False"
            )

    return issues


def _append_issue_section(report_lines, items, level, kind, hint, limit=10):
    """Append one report section: header, hint, first ``limit`` findings, overflow note."""
    if not items:
        return
    report_lines.append(f"\n[{level}] 发现 {len(items)} 个{kind}因子:")
    report_lines.append(hint)
    for issue in items[:limit]:
        report_lines.append(
            f" - 日期 {issue['date']}: {issue['factor']} (样本数: {issue['count']})"
        )
    if len(items) > limit:
        report_lines.append(f" ... 还有 {len(items) - limit} 个")
|
||||
Reference in New Issue
Block a user