feat(training): add TabM learning-to-rank model support and streamline the training pipeline

- Add TabMRankModel, TabMRankTask, and their companion loss functions and configuration
- Move DataQualityAnalyzer from the experiment module into the training module
- Remove the data processors' overly aggressive hard-filling of NaN/null values
- Switch RankTask evaluation metrics to quantile labels instead of raw returns
- Update the experiment scripts' processor ordering and model hyperparameter configuration
@@ -43,6 +43,12 @@ from src.training.core import StockPoolManager, Trainer
 # Utility functions
 from src.training.utils import check_data_quality

+# Data quality analyzer
+from src.training.data_quality_analyzer import (
+    DataQualityAnalyzer,
+    analyze_data_quality,
+)
+
 # Configuration
 from src.training.config import TrainingConfig

@@ -85,6 +91,9 @@ __all__ = [
     "Trainer",
     # Utility functions
     "check_data_quality",
+    # Data quality analyzer
+    "DataQualityAnalyzer",
+    "analyze_data_quality",
     # Configuration
     "TrainingConfig",
     # New: modular Trainer components (recommended)

@@ -7,6 +7,11 @@ from src.training.components.models.lightgbm import LightGBMModel
 from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
 from src.training.components.models.tabpfn_model import TabPFNModel
 from src.training.components.models.tabm_model import TabMModel
+from src.training.components.models.tabm_rank_model import (
+    TabMRankModel,
+    EnsembleListNetLoss,
+    EnsembleLambdaLoss,
+)
 from src.training.components.models.cross_section_sampler import CrossSectionSampler
 from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss

@@ -15,6 +20,9 @@ __all__ = [
     "LightGBMLambdaRankModel",
     "TabPFNModel",
     "TabMModel",
+    "TabMRankModel",
+    "EnsembleListNetLoss",
+    "EnsembleLambdaLoss",
     "CrossSectionSampler",
     "EnsembleQuantLoss",
 ]

src/training/components/models/tabm_rank_model.py (new file, 747 lines)
@@ -0,0 +1,747 @@
"""TabM ranking model implementation (TabM Rank)

Built on the TabM (Tabular Multilayer Perceptron with Ensembles) architecture,
with a ListNet listwise ranking loss for cross-sectional learning-to-rank in
the spirit of LambdaRank. Intended for cross-sectional ranking of future stock
returns.
"""

from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import pickle

import numpy as np
import polars as pl
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Sampler
from tabm import TabM

from src.training.components.base import BaseModel
from src.training.registry import register_model


class GroupSampler(Sampler):
    """Group sampler dedicated to learning-to-rank.

    Ensures that each batch contains all samples of a single query (e.g. one
    trading day). This matches the semantics of LightGBM's `group` parameter.
    """

    def __init__(self, group_counts: np.ndarray, shuffle_groups: bool = True):
        """Initialize the group sampler.

        Args:
            group_counts: Per-group sample counts, e.g. [100, 150, 120]
            shuffle_groups: Whether to shuffle the training order of the groups
        """
        self.group_counts = group_counts
        self.shuffle_groups = shuffle_groups
        # Start/end boundary of each group within the flat sample array
        self.boundaries = np.insert(np.cumsum(group_counts), 0, 0)
        self.num_groups = len(group_counts)

    def __iter__(self):
        """Iterate over batch indices.

        Yields:
            list: All sample indices belonging to one group
        """
        group_indices = list(range(self.num_groups))
        if self.shuffle_groups:
            np.random.shuffle(group_indices)

        for g_idx in group_indices:
            start = self.boundaries[g_idx]
            end = self.boundaries[g_idx + 1]
            # Yield every sample index of this group (day) as one complete batch
            yield list(range(start, end))

    def __len__(self):
        """Return the number of groups."""
        return self.num_groups

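
# [Note] Minimal usage sketch of GroupSampler (illustrative; the feature count
# and group sizes are made up). Each batch yielded by the DataLoader below is
# exactly one day's cross-section:
#
#     import numpy as np
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#
#     sampler = GroupSampler(np.array([3, 2, 4]), shuffle_groups=True)
#     dataset = TensorDataset(torch.randn(9, 16), torch.arange(9).float())
#     loader = DataLoader(dataset, batch_sampler=sampler)
#     for bx, by in loader:
#         assert bx.shape[0] in (3, 2, 4)  # one whole group per batch
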
class EnsembleListNetLoss(nn.Module):
    """Ensemble ListNet loss (listwise ranking loss).

    Uses cross-entropy to turn cross-sectional relevance into a probability
    distribution. Supports TabM's ensemble dimension.
    """

    def __init__(self, topk_weight: float = 1.0):
        """Initialize the ListNet loss.

        Args:
            topk_weight: Weight factor for head samples; values > 1.0 put more
                emphasis on high-scoring labels
        """
        super().__init__()
        self.topk_weight = topk_weight

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the ListNet ranking loss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted logits for all samples in the current group
            targets: [BatchSize] - true ranking labels of the group's samples (0, 1, 2...)

        Returns:
            Scalar loss value
        """
        # If the group has too few samples to rank, return 0 directly
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)

        # [BatchSize] -> [BatchSize, EnsembleSize]: broadcast to align with each ensemble member
        targets_expanded = targets.unsqueeze(1).expand_as(preds)

        # 1. Probability distribution over the true labels (softmax multiplies the weight of high scores)
        targets_prob = F.softmax(targets_expanded, dim=0)

        # 2. [Top-K optimization] If weighting is enabled, give high-scoring labels more weight
        if self.topk_weight > 1.0:
            # Derive weights from the label values (the higher the label, the larger the weight),
            # normalized to the range [1, topk_weight]
            max_target = targets.max()
            min_target = targets.min()
            if max_target > min_target:
                # Linear interpolation: weight = 1 + (target - min) / (max - min) * (topk_weight - 1)
                sample_weights = 1.0 + (targets - min_target) / (
                    max_target - min_target
                ) * (self.topk_weight - 1.0)
                sample_weights = sample_weights.unsqueeze(1).expand_as(preds)
                # Normalize the weights so they sum to the number of samples
                sample_weights = sample_weights * len(targets) / sample_weights.sum()
                # Apply the weights to the target probabilities
                targets_prob = targets_prob * sample_weights

        # 3. Log-probability distribution of the predictions (log_softmax is numerically more stable than log(softmax))
        preds_log_prob = F.log_softmax(preds, dim=0)

        # 4. Cross-entropy loss (summed over the batch/group dimension)
        loss = -torch.sum(targets_prob * preds_log_prob, dim=0)  # [EnsembleSize]

        # 5. Average the loss over all ensemble members
        return loss.mean()

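
# [Note] The heart of ListNet is a cross-entropy between two softmax
# distributions over the group. A tiny hand-check (illustrative, ensemble
# size E=1):
#
#     import torch
#     import torch.nn.functional as F
#
#     targets = torch.tensor([0.0, 1.0, 2.0])          # relevance labels, one group
#     preds = torch.tensor([[0.1], [0.5], [2.0]])      # [B=3, E=1] predicted logits
#     p_true = F.softmax(targets.unsqueeze(1), dim=0)  # target distribution
#     log_p = F.log_softmax(preds, dim=0)              # predicted log-distribution
#     loss = -(p_true * log_p).sum(dim=0).mean()       # same math as forward() above
#     # loss is small here because the predicted order agrees with the labels
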
class EnsembleLambdaLoss(nn.Module):
    """Ensemble LambdaLoss (supports TabM's ensemble dimension).

    A pairwise ranking loss weighted by DeltaNDCG.
    Reference: "The LambdaLoss Framework for Ranking Metric Optimization" (Google Research)

    Characteristics:
    - Computes each pair's impact on NDCG if their positions were swapped
    - Gives more weight to head samples (high gain and high rank)
    - Better suited to Top-K stock-picking scenarios
    """

    def __init__(self, sigma: float = 1.0, ndcg_weight_power: float = 1.0):
        """Initialize LambdaLoss.

        Args:
            sigma: Steepness of the sigmoid, controls gradient magnitude
            ndcg_weight_power: Power applied to the DeltaNDCG weights; > 1 further amplifies the head effect
        """
        super().__init__()
        self.sigma = sigma
        self.ndcg_weight_power = ndcg_weight_power

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the LambdaLoss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted scores
            targets: [BatchSize] - relevance labels (gain)

        Returns:
            Scalar loss value
        """
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)

        # 1. Pairwise differences
        preds_diff = preds.unsqueeze(1) - preds.unsqueeze(0)  # [B, B, E]

        # [Performance] target_diff does not need the E dimension; keep it [B, B]
        target_diff = targets.unsqueeze(1) - targets.unsqueeze(0)  # [B, B]

        # 2. Mask matrix: [B, B, 1] for easy broadcasting later
        mask = (target_diff > 0).float().unsqueeze(2)  # [B, B, 1]

        # 3. Compute Delta NDCG
        # DeltaNDCG = |Gain_i - Gain_j| * |1/log(rank_i+1) - 1/log(rank_j+1)|
        with torch.no_grad():
            # [Performance] Two argsorts eliminate the for loop entirely:
            # 1st argsort: indices in descending score order
            # 2nd argsort: inverts the indices into ranks (add 1 for the actual rank)
            ranks = preds.argsort(dim=0, descending=True).argsort(dim=0) + 1
            ranks = ranks.float()  # [B, E]

            # Position discount term (log2 of the rank)
            log_rank = torch.log2(ranks + 1)
            inv_log_rank_diff = torch.abs(
                1.0 / log_rank.unsqueeze(1) - 1.0 / log_rank.unsqueeze(0)
            )  # [B, B, E]

            # DeltaNDCG = |gain difference| * |position discount difference|
            # target_diff is [B, B]; unsqueeze broadcasts it to [B, B, 1]
            delta_ndcg = torch.abs(target_diff).unsqueeze(2) * inv_log_rank_diff

            # Apply the power to reshape the weight distribution
            if self.ndcg_weight_power != 1.0:
                delta_ndcg = torch.pow(delta_ndcg, self.ndcg_weight_power)

        # 4. Pairwise logistic loss, weighted by Delta NDCG
        # loss = delta_ndcg * log(1 + exp(-sigma * (preds_i - preds_j)))
        pairwise_loss = F.binary_cross_entropy_with_logits(
            self.sigma * preds_diff,
            torch.ones_like(preds_diff),
            reduction="none",
        )

        # 5. Apply the mask and weights
        weighted_loss = pairwise_loss * mask * delta_ndcg

        # Guard against division by zero when everything is masked out
        valid_pairs_count = mask.sum().clamp(min=1.0)
        mean_loss = weighted_loss.sum() / valid_pairs_count

        return mean_loss

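
# [Note] The double argsort used above is the standard vectorized rank trick;
# a quick sanity check (illustrative):
#
#     import torch
#
#     scores = torch.tensor([0.2, 0.9, 0.5])
#     ranks = scores.argsort(dim=0, descending=True).argsort(dim=0) + 1
#     print(ranks)  # tensor([3, 1, 2]): the highest score gets rank 1
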
@register_model("tabm_rank")
class TabMRankModel(BaseModel):
    """TabM learning-to-rank model.

    A ranking model built on the TabM architecture with ListNet loss support.
    Intended for cross-sectional stock ranking: future returns are converted
    to quantile labels for training.

    Characteristics:
    - Uses the ListNet listwise ranking loss
    - Supports grouped training via the group parameter
    - Uses NDCG as the validation metric
    - Interface-compatible with LightGBMLambdaRank
    """

    name = "tabm_rank"

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Initialize the TabM Rank model.

        Args:
            params: Model parameter dict containing:
                - ensemble_size: Ensemble size (default: 32)
                - n_blocks: Number of MLP blocks (default: 3)
                - d_block: Neurons per block (default: 256)
                - dropout: Dropout rate (default: 0.1)
                - batch_size: Batch size (default: 2048, used only for prediction)
                - learning_rate: Learning rate (default: 1e-3)
                - weight_decay: Weight decay (default: 1e-5)
                - epochs: Number of training epochs (default: 50)
                - early_stopping_round: Early-stopping patience (default: 10)
                - max_grad_norm: Gradient clipping threshold (default: 1.0)
                - ndcg_k: k for NDCG@k; None means global NDCG (default: None)
                - loss_type: Loss function type (default: "listnet")
                    - "listnet": standard ListNet loss
                    - "weighted_listnet": weighted ListNet; topk_weight strengthens the head
                    - "lambda": LambdaLoss, weighted by DeltaNDCG
                - topk_weight: Head-sample weight factor for weighted_listnet (default: 5.0)
                - lambda_sigma: sigma parameter of LambdaLoss (default: 1.0)
                - ndcg_weight_power: Power applied to DeltaNDCG weights (default: 1.0)
        """
        self.params = params or {}
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None

        # Select the loss function according to the configuration
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:  # "listnet"
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)

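
    # [Note] A configuration sketch for the LambdaLoss variant, matching the
    # parameter table above (the values are illustrative, not tuned defaults):
    #
    #     params = {
    #         "ensemble_size": 32,
    #         "n_blocks": 3,
    #         "d_block": 256,
    #         "loss_type": "lambda",     # selects EnsembleLambdaLoss in __init__
    #         "lambda_sigma": 1.0,
    #         "ndcg_weight_power": 2.0,  # > 1 amplifies the head effect
    #         "ndcg_k": 20,              # validate with NDCG@20
    #     }
    #     model = TabMRankModel(params=params)
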
    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader (supports packing cross-sections by query/group).

        Args:
            X: Feature array [N, n_features]
            y: Label array [N] or None
            group: Group array giving the sample count of each group
            shuffle_groups: Whether to shuffle the group order

        Returns:
            DataLoader instance
        """
        # [Performance] When GPU memory allows, push the full dataset onto the
        # GPU up front to avoid PCIe transfers during training
        X_tensor = torch.from_numpy(X).to(self.device)

        if y is not None:
            y_tensor = torch.from_numpy(y).to(self.device)
            dataset = TensorDataset(X_tensor, y_tensor)
        else:
            dataset = TensorDataset(X_tensor)

        if group is not None:
            # For training and validation, use GroupSampler: each batch is one query
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(dataset, batch_sampler=sampler)
        else:
            # For prediction without groups, fall back to plain batched prediction
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model (using the NDCG ranking metric).

        Args:
            val_loader: Validation data loader
            k: k for NDCG@k; None computes global NDCG

        Returns:
            Mean NDCG score
        """
        from sklearn.metrics import ndcg_score

        assert self.model is not None, "Model not trained; cannot validate"

        self.model.eval()
        ndcg_list = []

        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue

                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()

                if len(by) <= 1:
                    continue

                outputs = self.model(bx)  # [B, E, 1]
                preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy()  # [B]

                try:
                    # ndcg_score expects 2-D arrays of shape (1, n_samples)
                    score = ndcg_score([by], [preds], k=k)
                    ndcg_list.append(score)
                except ValueError:
                    pass

        return float(np.mean(ndcg_list)) if len(ndcg_list) > 0 else 0.0

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[Tuple] = None,
    ) -> "TabMRankModel":
        """Train the ranking model.

        Args:
            X: Training feature DataFrame
            y: Training labels (Polars Series); should be quantile labels (0, 1, 2, ...)
            group: Group array giving the sample count of each query
            eval_set: Validation tuple (X_val, y_val, group_val) for early stopping

        Returns:
            self (supports chained calls)

        Raises:
            ValueError: Invalid group parameter
        """
        self.feature_names_ = list(X.columns)

        X_np = X.to_numpy().astype(np.float32)
        y_np = y.to_numpy().astype(np.float32)

        # Check and normalize the group parameter
        if group is None:
            group = np.array([len(y_np)])
        if group.sum() != len(y_np):
            raise ValueError(
                f"Sum of the group array ({group.sum()}) must equal the sample count ({len(y_np)})"
            )

        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)

        val_loader = None
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = (
                X_val.to_numpy().astype(np.float32)
                if isinstance(X_val, pl.DataFrame)
                else X_val
            )
            y_val_np = (
                y_val.to_numpy().astype(np.float32)
                if isinstance(y_val, pl.Series)
                else y_val
            )

            if group_val is None:
                group_val = np.array([len(y_val_np)])
            val_loader = self._make_loader(
                X_val_np, y_val_np, group=group_val, shuffle_groups=False
            )

        ensemble_size = self.params.get("ensemble_size", 32)
        n_features = X_np.shape[1]

        # Initialize the TabM model
        self.model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=self.params.get("n_blocks", 3),
            d_block=self.params.get("d_block", 256),
            dropout=self.params.get("dropout", 0.1),
            k=ensemble_size,
        ).to(self.device)

        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.params.get("learning_rate", 1e-3),
            weight_decay=self.params.get("weight_decay", 1e-5),
        )

        epochs = self.params.get("epochs", 50)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-6
        )

        early_stopping_patience = self.params.get(
            "early_stopping_patience",
            self.params.get("early_stopping_round", 10),
        )
        best_val_ndcg = -float("inf")
        patience_counter = 0
        best_model_state = None
        ndcg_k = self.params.get("ndcg_k", None)  # None computes global NDCG

        print(f"[TabMRank] Training started... device: {self.device}, ensemble size: {ensemble_size}")

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0
            n_train_batches = 0

            for batch in train_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch[0], batch[1]

                optimizer.zero_grad()
                outputs = self.model(bx)  # [B, E, 1]
                outputs_squeezed = outputs.squeeze(-1)  # [B, E]

                # Compute the ListNet ranking loss
                loss = self.criterion(outputs_squeezed, by)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.params.get("max_grad_norm", 1.0),
                )
                optimizer.step()

                train_loss += loss.item()
                n_train_batches += 1

            avg_train_loss = train_loss / max(n_train_batches, 1)
            self.training_history_["train_loss"].append(avg_train_loss)

            # Validation phase (NDCG-based)
            if val_loader is not None:
                val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
                self.training_history_["val_ndcg"].append(val_ndcg)

                if val_ndcg > best_val_ndcg:
                    best_val_ndcg = val_ndcg
                    patience_counter = 0
                    best_model_state = {
                        k: v.cpu().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1

                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss (ListNet): {avg_train_loss:.4f} | "
                        f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
                    )

                if patience_counter >= early_stopping_patience:
                    print(f"[TabMRank] Early stopping triggered at epoch {epoch + 1}")
                    break
            else:
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss: {avg_train_loss:.4f}"
                    )

            scheduler.step()

        # Restore the best weights
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print(f"[TabMRank] Restored best model weights (Val NDCG: {best_val_ndcg:.4f})")

        return self

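
    # [Note] End-to-end, fit() expects per-day quantile labels plus a group
    # array whose entries sum to the row count. A sketch using the static
    # helpers defined below (the column names are this project's defaults):
    #
    #     df = TabMRankModel.convert_to_quantile_labels(df, label_col="future_return_5")
    #     group = TabMRankModel.prepare_group_from_dates(df, date_col="trade_date")
    #     X = df.select(feature_cols)  # feature_cols defined by the caller
    #     y = df["future_return_5_rank"]
    #     model = TabMRankModel(params={"epochs": 50}).fit(X, y, group=group)
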
    def predict(
        self, X: pl.DataFrame, group: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: Feature matrix (Polars DataFrame)
            group: Group array giving the sample count of each query.
                If provided, GroupSampler is used so the prediction order
                stays consistent with the grouping.

        Returns:
            Predicted scores (numpy ndarray)

        Raises:
            RuntimeError: Called before the model is trained
            ValueError: Prediction data is missing features
        """
        if self.model is None:
            raise RuntimeError("Model not trained; call fit() first")

        # Feature alignment check
        if self.feature_names_:
            missing_cols = [c for c in self.feature_names_ if c not in X.columns]
            if missing_cols:
                raise ValueError(f"Prediction data is missing features: {missing_cols}")
            X = X.select(self.feature_names_)

        X_np = X.to_numpy().astype(np.float32)
        loader = self._make_loader(X_np, group=group, shuffle_groups=False)

        self.model.eval()
        all_preds = []

        with torch.no_grad():
            for batch in loader:
                bx = batch[0].to(self.device)
                outputs = self.model(bx)  # [B, E, 1]
                # At prediction time, the mean over ensemble members is the final score
                preds = outputs.mean(dim=1).squeeze(-1)  # [B]
                all_preds.append(preds.cpu().numpy())

        return np.concatenate(all_preds)

    def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
        """Get training evaluation results.

        Returns:
            Evaluation result dict containing train_loss and val_ndcg
        """
        return self.training_history_

    def feature_importance(self) -> None:
        """Get feature importances.

        TabM has no built-in feature importance computation; returns None.
        """
        return None

    def save(self, path: str | Path) -> None:
        """Save the model.

        Args:
            path: Save path

        Raises:
            RuntimeError: Called before the model is trained
        """
        if self.model is None:
            raise RuntimeError("Model not trained; nothing to save")

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Save model weights
        model_path = path.with_suffix(".pt")
        torch.save(self.model.state_dict(), model_path)

        # Save metadata
        meta_path = path.with_suffix(".meta")
        meta = {
            "params": self.params,
            "feature_names": self.feature_names_,
            "training_history": self.training_history_,
            "device": str(self.device),
        }
        with open(meta_path, "wb") as f:
            pickle.dump(meta, f)

        print(f"[TabMRank] Model saved to: {path}")

    @classmethod
    def load(cls, path: str | Path) -> "TabMRankModel":
        """Load a model.

        Args:
            path: Model path (without extension)

        Returns:
            The loaded TabMRankModel instance
        """
        path = Path(path)

        # Load metadata
        meta_path = path.with_suffix(".meta")
        with open(meta_path, "rb") as f:
            meta = pickle.load(f)

        # Create the instance
        instance = cls(meta["params"])
        instance.feature_names_ = meta["feature_names"]
        instance.training_history_ = meta["training_history"]

        # Rebuild the model structure
        if instance.feature_names_ is not None:
            n_features = len(instance.feature_names_)
            ensemble_size = instance.params.get("ensemble_size", 32)

            instance.model = TabM.make(
                n_num_features=n_features,
                cat_cardinalities=[],
                d_out=1,
                n_blocks=instance.params.get("n_blocks", 3),
                d_block=instance.params.get("d_block", 256),
                dropout=instance.params.get("dropout", 0.1),
                k=ensemble_size,
            ).to(instance.device)

            # Load the weights
            model_path = path.with_suffix(".pt")
            instance.model.load_state_dict(
                torch.load(model_path, map_location=instance.device)
            )

        print(f"[TabMRank] Model loaded from {path}")
        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column.

        Args:
            df: DataFrame containing the date column
            date_col: Date column name, defaults to "trade_date"

        Returns:
            Group array
        """
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into quantile labels.

        Args:
            df: Input DataFrame
            label_col: Original label column name
            date_col: Date column name, defaults to "trade_date"
            n_quantiles: Number of quantiles, defaults to 20
            new_col_name: New column name, defaults to {label_col}_rank

        Returns:
            DataFrame with the quantile label column added
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"

        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )

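
    # [Note] A small check of the per-day qcut conversion (illustrative data);
    # qcut's categorical output is mapped to integer bins via to_physical():
    #
    #     import polars as pl
    #
    #     df = pl.DataFrame({
    #         "trade_date": ["d1"] * 4 + ["d2"] * 4,
    #         "future_return_5": [0.01, -0.02, 0.03, 0.00, 0.05, -0.01, 0.02, 0.04],
    #     })
    #     out = TabMRankModel.convert_to_quantile_labels(
    #         df, label_col="future_return_5", n_quantiles=4
    #     )
    #     print(out["future_return_5_rank"])  # integer bins 0..3 within each day
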
    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate the NDCG metric.

        Args:
            X: Feature matrix
            y: True labels
            group: Group array
            k: k for NDCG@k

        Returns:
            NDCG score
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)

        y_true_list = []
        y_score_list = []

        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y.to_numpy()[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx

        ndcg_scores = []
        for y_true, y_score in zip(y_true_list, y_score_list):
            if len(y_true) > 1:
                try:
                    score = ndcg_score([y_true], [y_score], k=k)
                    ndcg_scores.append(score)
                except ValueError:
                    pass

        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
@@ -289,7 +289,7 @@ class StandardScaler(BaseProcessor):
         return self

     def transform(self, X: pl.DataFrame) -> pl.DataFrame:
-        """Standardize (using parameters learned on the training set, with NaN protection)
+        """Standardize (using parameters learned on the training set)

         Args:
             X: Data to transform
@@ -302,18 +302,11 @@ class StandardScaler(BaseProcessor):
             if col in self.mean_ and col in self.std_:
                 # Avoid division by zero
                 std_val = self.std_[col] if self.std_[col] != 0 else 1.0
-                # Critical fix: add a fill_nan(0) safety net to stop the computation from producing NaN
-                expr = (
-                    ((pl.col(col) - self.mean_[col]) / std_val)
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # For columns that should be processed but have no learned statistics,
-                # uniformly cast to float and handle both NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # For columns that should be processed but have no learned statistics, uniformly cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -372,20 +365,14 @@ class CrossSectionalStandardScaler(BaseProcessor):
             if col in self.feature_cols and X[col].dtype.is_numeric():
                 # Cross-sectional standardization: mean and std computed independently per day
                 # Avoid division by zero; when std is 0 treat it as 1
-                # Critical fix: fill_nan then fill_null to stop the computation from producing NaN
-                expr = (
-                    (
-                        (pl.col(col) - pl.col(col).mean().over(self.date_col))
-                        / (pl.col(col).std().over(self.date_col) + 1e-10)
-                    )
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                expr = (
+                    (pl.col(col) - pl.col(col).mean().over(self.date_col))
+                    / (pl.col(col).std().over(self.date_col) + 1e-10)
+                ).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # For columns that should be processed but have a mismatched type, cast to float and handle both NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # For columns that should be processed but have a mismatched type, cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -488,8 +475,8 @@ class Winsorizer(BaseProcessor):
                 expressions.append(expr)
             elif col in self.feature_cols:
                 # For columns that should be processed but have no learned bounds (e.g. all-NaN or boolean columns),
-                # uniformly cast to float and fill with 0
-                expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
+                # uniformly cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -522,11 +509,9 @@ class Winsorizer(BaseProcessor):
         clip_exprs = []
         for col in X.columns:
             if col in target_cols:
-                # Winsorize with that day's quantiles first; if a quantile is null (the whole day is NaN), fill with 0
                 clipped = (
                     pl.col(col)
                     .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
-                    .fill_null(0)
                     .alias(col)
                 )
                 clip_exprs.append(clipped)

@@ -13,6 +13,7 @@ from src.factors import FactorEngine
 from src.training.pipeline import DataPipeline
 from src.training.tasks.base import BaseTask
 from src.training.result_analyzer import ResultAnalyzer
+from src.training.data_quality_analyzer import DataQualityAnalyzer


 class Trainer:
@@ -100,8 +101,6 @@ class Trainer:
         print("\n[Step 1.5/7] Data quality analysis...")

         try:
-            from src.experiment.data_quality_analyzer import DataQualityAnalyzer
-
             # Get the feature column names (from the training set)
             feature_cols = data["train"].get("feature_cols", [])
             label_name = self.task.label_name

src/training/data_quality_analyzer.py (new file, 646 lines)
@@ -0,0 +1,646 @@
"""Data quality analysis module

Provides data quality checks, including:
- Dataset date-range information
- Missing-value statistics
- Zero-value statistics
- Per-date detection of all-null columns
"""

from typing import Any, Dict, List, Optional
import polars as pl
import numpy as np


class DataQualityAnalyzer:
    """Data quality analyzer.

    Analyzes quality issues in training data and helps spot data anomalies.

    Attributes:
        feature_cols: List of feature column names
        label_col: Label column name
        date_col: Date column name
        verbose: Whether to print details
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        date_col: str = "trade_date",
        verbose: bool = True,
    ):
        """Initialize the data quality analyzer.

        Args:
            feature_cols: List of feature column names
            label_col: Label column name
            date_col: Date column name, defaults to "trade_date"
            verbose: Whether to print details
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.date_col = date_col
        self.verbose = verbose
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Set the columns to analyze.

        Args:
            feature_cols: List of feature column names
            label_col: Label column name
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the full data quality analysis.

        Args:
            data: Data dict of the form {"train": {...}, "val": {...}, "test": {...}}
            split_names: Names of the splits to analyze, defaults to ["train", "val", "test"]

        Returns:
            Analysis result dict
        """
        if not split_names:
            split_names = ["train", "val", "test"]

        if self.verbose:
            print("\n" + "=" * 80)
            print("Data quality analysis report")
            print("=" * 80)

        self.analysis_results = {}

        # First print a dataset overview (date ranges and other basics)
        if self.verbose:
            self._print_dataset_overview(data, split_names)

        for split_name in split_names:
            if split_name not in data:
                continue

            split_data = data[split_name]
            raw_df = split_data.get("raw_data")

            if raw_df is None:
                continue

            if self.verbose:
                print(f"\n[{split_name.upper()} dataset]")
                print("-" * 40)

            split_results = self._analyze_split(raw_df, split_name)
            self.analysis_results[split_name] = split_results

        if self.verbose:
            print("\n" + "=" * 80)

        return self.analysis_results

    def _print_dataset_overview(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: List[str],
    ) -> None:
        """Print a dataset overview.

        Includes each split's start date, end date, sample count, and other basics.

        Args:
            data: Data dict
            split_names: List of split names
        """
        print("\n[Dataset overview]")
        print("-" * 40)

        overview_data = []

        for split_name in split_names:
            if split_name not in data:
                continue

            split_data = data[split_name]
            raw_df = split_data.get("raw_data")

            if raw_df is None or len(raw_df) == 0:
                overview_data.append(
                    {
                        "Split": split_name.upper(),
                        "Start date": "-",
                        "End date": "-",
                        "Rows": 0,
                        "Stocks": 0,
                    }
                )
                continue

            # Get the date range
            if self.date_col in raw_df.columns:
                dates = raw_df[self.date_col]
                start_date = dates.min()
                end_date = dates.max()
                unique_dates = dates.n_unique()
            else:
                start_date = "-"
                end_date = "-"
                unique_dates = 0

            # Get the number of stocks
            if "ts_code" in raw_df.columns:
                unique_stocks = raw_df["ts_code"].n_unique()
            else:
                unique_stocks = 0

            overview_data.append(
                {
                    "Split": split_name.upper(),
                    "Start date": str(start_date),
                    "End date": str(end_date),
                    "Trading days": unique_dates,
                    "Rows": len(raw_df),
                    "Stocks": unique_stocks,
                }
            )

        # Print the table
        if overview_data:
            # Compute column widths
            headers = ["Split", "Start date", "End date", "Trading days", "Rows", "Stocks"]
            col_widths = {}
            for header in headers:
                max_data_len = max(
                    len(str(row.get(header, ""))) for row in overview_data
                )
                col_widths[header] = max(len(header), max_data_len) + 2

            # Print the header row
            header_line = "  ".join(h.ljust(col_widths[h]) for h in headers)
            print(f"  {header_line}")
            print(f"  {'-' * (sum(col_widths.values()) + 2 * (len(headers) - 1))}")

            # Print the data rows
            for row in overview_data:
                line = "  ".join(
                    str(row.get(h, "")).ljust(col_widths[h]) for h in headers
                )
                print(f"  {line}")

    def _analyze_split(
        self,
        df: pl.DataFrame,
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze a single dataset split.

        Args:
            df: Data frame
            split_name: Split name

        Returns:
            Analysis result dict
        """
        results = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }

        # Get the date range
        if self.date_col in df.columns:
            results["start_date"] = str(df[self.date_col].min())
            results["end_date"] = str(df[self.date_col].max())
            results["unique_dates"] = df[self.date_col].n_unique()

        # Get the number of stocks
        if "ts_code" in df.columns:
            results["unique_stocks"] = df["ts_code"].n_unique()

        # 1. Analyze missing values in the feature columns
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats

        if self.verbose:
            self._print_null_analysis(null_stats)

        # 2. Analyze zero values in the feature columns
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats

        if self.verbose:
            self._print_zero_analysis(zero_stats)

        # 3. Check whether any column is entirely null on some date
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date

        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)

        # 4. Analyze the label column
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats

            if self.verbose:
                self._print_label_analysis(label_stats)

        return results

    def _analyze_null_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze missing values.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            Missing-value statistics dict
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            null_count = df[col].null_count()
            if null_count > 0:
                null_pct = null_count / len(df) * 100
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = null_pct
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count

        return stats

    def _analyze_zero_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze zero values.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            Zero-value statistics dict
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            # Count zero values (excluding nulls)
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue

            zero_count = (non_null_series == 0).sum()
            if zero_count > 0:
                zero_pct = zero_count / len(df) * 100
                stats["zero_counts"][col] = int(zero_count)
                stats["zero_percentages"][col] = zero_pct
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += int(zero_count)

        return stats

    def _check_all_null_by_date(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Check whether any column is entirely null on some date.

        Uses a polars lazy frame for memory-safe, efficient computation.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            All-null check result dict
        """
        results = {
            "issues_found": False,
            "issues": [],
        }

        if self.date_col not in df.columns:
            return results

        # Filter out columns that are not in the frame
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results

        # Use a lazy frame for query optimization
        lf = df.lazy()

        # Key step: compute only null_count and total rows (the aggregate is tiny)
        # Build a separate null_count aggregation expression per column
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))

        agg_lf = lf.group_by(self.date_col).agg(agg_exprs)

        # Collect the result (agg_df usually has only a few hundred to a few thousand rows)
        agg_df = agg_lf.collect()

        # Run the logical checks on this already "dehydrated" small table
        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # Find dates where the null count equals the total row count
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
            ).select([self.date_col, "total_rows"])

            if not bad_dates.is_empty():
                for row in bad_dates.to_dicts():
                    issues.append(
                        {
                            "date": row[self.date_col],
                            "column": col,
                            "total_rows": row["total_rows"],
                        }
                    )

        if issues:
            results["issues_found"] = True
            results["issues"] = issues

        return results

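
    # [Note] Condensed sketch of the aggregate-then-check pattern above, which
    # keeps memory flat regardless of row count (illustrative data):
    #
    #     import polars as pl
    #
    #     df = pl.DataFrame({
    #         "trade_date": ["d1", "d1", "d2", "d2"],
    #         "f1": [1.0, 2.0, None, None],  # f1 is entirely null on d2
    #     })
    #     agg = (
    #         df.lazy()
    #         .group_by("trade_date")
    #         .agg([pl.col("f1").null_count().alias("f1_nulls"), pl.len().alias("total_rows")])
    #         .collect()
    #     )
    #     print(agg.filter(pl.col("f1_nulls") == pl.col("total_rows")))  # flags d2
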
    def _analyze_label(
        self,
        df: pl.DataFrame,
        label_col: str,
    ) -> Dict[str, Any]:
        """Analyze the label column.

        Args:
            df: Data frame
            label_col: Label column name

        Returns:
            Label analysis dict
        """
        stats = {
            "total_count": len(df),
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }

        if label_col not in df.columns:
            return stats

        series = df[label_col]

        # Missing-value statistics
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0

        # Zero-value statistics
        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            zero_count = (non_null_series == 0).sum()
            stats["zero_count"] = int(zero_count)
            stats["zero_percentage"] = zero_count / len(df) * 100

            # Basic statistics
            stats["min"] = float(non_null_series.min())
            stats["max"] = float(non_null_series.max())
            stats["mean"] = float(non_null_series.mean())
            stats["std"] = float(non_null_series.std())

        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print missing-value analysis results.

        Args:
            stats: Missing-value statistics dict
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]

        print("  Missing-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Missing cells: {total_null:,} ({total_null / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with missing values: {len(null_cols)}/{len(self.feature_cols)}")

        if null_cols:
            print("    Top 5 features by missing values:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print zero-value analysis results.

        Args:
            stats: Zero-value statistics dict
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]

        print("  Zero-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Zero cells: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with zeros: {len(zero_cols)}/{len(self.feature_cols)}")

        if zero_cols:
            print("    Top 5 features by zero values:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the per-date all-null check results.

        Args:
            results: All-null check result dict
        """
        issues = results["issues"]

        print("  Per-date all-null check:")
        if results["issues_found"]:
            print(f"    [WARNING] Found {len(issues)} issues:")
            # Group the display by date
            by_date = {}
            for issue in issues:
                date = issue["date"]
                if date not in by_date:
                    by_date[date] = []
                by_date[date].append(issue["column"])

            for date in sorted(by_date.keys())[:5]:  # show only the first 5 dates
                cols = by_date[date]
                print(f"      Date {date}: {len(cols)} columns entirely null")
                if len(cols) <= 3:
                    print(f"        Columns: {', '.join(cols)}")

            if len(by_date) > 5:
                print(f"      ... {len(by_date) - 5} more dates with issues")
        else:
            print("    [OK] No column is entirely null on any date")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print label analysis results.

        Args:
            stats: Label analysis dict
        """
        print(f"  Label column statistics ({self.label_col}):")
        print(f"    Total: {stats['total_count']:,}")
        print(f"    Missing: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f"    Zeros: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")

        if stats["mean"] is not None:
            print(f"    Min: {stats['min']:.6f}")
            print(f"    Max: {stats['max']:.6f}")
            print(f"    Mean: {stats['mean']:.6f}")
            print(f"    Std: {stats['std']:.6f}")

    def get_summary(self) -> str:
        """Get a summary of the analysis results.

        Returns:
            Summary string
        """
        if not self.analysis_results:
            return "Analysis has not been run yet"

        lines = ["Data quality analysis summary", "=" * 40]

        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")

            # Add date-range information
            if "start_date" in results and "end_date" in results:
                lines.append(
                    f"  Date range: {results['start_date']} ~ {results['end_date']}"
                )
            if "unique_dates" in results:
                lines.append(f"  Trading days: {results['unique_dates']}")

            lines.append(f"  Total rows: {results['total_rows']:,}")

            if "unique_stocks" in results:
                lines.append(f"  Stocks: {results['unique_stocks']}")

            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f"  Missing values: {null_stats['total_null_cells']:,} cells, "
                    f"{len(null_stats['columns_with_null'])} columns affected"
                )

            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f"  Zero values: {zero_stats['total_zero_cells']:,} cells, "
                    f"{len(zero_stats['columns_with_zero'])} columns affected"
                )

            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f"  [WARNING] {len(all_null['issues'])} date/column all-null issues found"
                )

        return "\n".join(lines)


def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    date_col: str = "trade_date",
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience function: run a data quality analysis.

    Args:
        data: Data dict
        feature_cols: List of feature column names
        label_col: Label column name
        date_col: Date column name, defaults to "trade_date"
        verbose: Whether to print details

    Returns:
        Analysis result dict
    """
    analyzer = DataQualityAnalyzer(
        feature_cols=feature_cols,
        label_col=label_col,
        date_col=date_col,
        verbose=verbose,
    )
    return analyzer.analyze(data)
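
# [Note] Usage sketch of the convenience function (the data-dict layout follows
# the docstring; the contents are illustrative):
#
#     import polars as pl
#
#     data = {
#         "train": {
#             "raw_data": pl.DataFrame({
#                 "trade_date": ["d1", "d2"],
#                 "ts_code": ["000001.SZ", "000001.SZ"],
#                 "f1": [1.0, None],
#                 "future_return_5": [0.01, -0.02],
#             })
#         }
#     }
#     report = analyze_data_quality(
#         data, feature_cols=["f1"], label_col="future_return_5"
#     )
#     print(report["train"]["null_analysis"]["total_null_cells"])  # 1
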
@@ -7,10 +7,12 @@ from src.training.tasks.base import BaseTask
 from src.training.tasks.regression_task import RegressionTask
 from src.training.tasks.rank_task import RankTask
 from src.training.tasks.tabm_regression_task import TabMRegressionTask
+from src.training.tasks.tabm_rank_task import TabMRankTask

 __all__ = [
     "BaseTask",
     "RegressionTask",
     "RankTask",
     "TabMRegressionTask",
+    "TabMRankTask",
 ]

@@ -153,7 +153,6 @@ class RankTask(BaseTask):
         if k_list is None:
             k_list = [1, 5, 10, 20]

-        y_true = test_data["y_raw"]
         y_pred = self.predict(test_data)
         groups = test_data["groups"]

@@ -166,9 +165,13 @@ class RankTask(BaseTask):
         y_true_groups = []
         y_pred_groups = []

+        # Use the quantile labels y (0-19) as the true relevance scores instead of raw returns y_raw.
+        # This matches the model's training objective and avoids the impact of negative raw returns.
+        y_true_array = test_data["y"].to_numpy()
+
         for group_size in groups:
             end_idx = start_idx + group_size
-            y_true_groups.append(y_true.to_numpy()[start_idx:end_idx])
+            y_true_groups.append(y_true_array[start_idx:end_idx])
             y_pred_groups.append(y_pred[start_idx:end_idx])
             start_idx = end_idx

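
# [Note] The switch to quantile labels matters because NDCG treats relevance as
# a gain, and recent scikit-learn versions reject negative y_true outright,
# while raw returns are frequently negative. Illustrative numbers:
#
#     from sklearn.metrics import ndcg_score
#
#     quantile_labels = [[3, 0, 2, 1]]   # per-day quantile ranks, always >= 0
#     scores = [[0.9, 0.1, 0.7, 0.3]]    # model predictions
#     print(ndcg_score(quantile_labels, scores, k=2))  # well-defined
#     # ndcg_score([[-0.02, ...]], scores) raises on negative relevance
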
src/training/tasks/tabm_rank_task.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""TabM learning-to-rank task implementation

Implements the TabM-based learning-to-rank training flow:
- Converts labels to quantile labels
- Builds the group array
- Uses TabMRankModel (based on the ListNet loss)
- Supports NDCG@k evaluation
"""

from typing import Any, Dict, List, Optional
import numpy as np
import polars as pl

from src.training.tasks.base import BaseTask
from src.training.components.models.tabm_rank_model import TabMRankModel


class TabMRankTask(BaseTask):
    """TabM learning-to-rank task.

    Trains with TabMRankModel for learning-to-rank.
    Converts continuous returns into quantile labels for training.
    Supports amplified gain labels to sharpen the Top-K focus.
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
        n_quantiles: int = 20,
        label_transform: Optional[str] = None,
        label_scale: float = 20.0,
    ):
        """Initialize the learning-to-rank task.

        Args:
            model_params: TabM parameter dict
            label_name: Label column name
            n_quantiles: Number of quantiles
            label_transform: Label transform type, one of:
                - None: standard quantile labels (0, 1, ..., n_quantiles-1)
                - "exponential": amplified gain, implemented below as rank ** 2
            label_scale: Scale factor for the gain transform (currently unused
                by the rank ** 2 implementation)
        """
        super().__init__(model_params, label_name)
        self.n_quantiles = n_quantiles
        self.label_transform = label_transform
        self.label_scale = label_scale

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Prepare labels (convert to quantile labels, optionally with an amplified-gain transform).

        Converts continuous returns into quantile labels and builds the group array.
        Supports an amplified-gain transform to sharpen the separation of head samples.

        Args:
            data: Data dict

        Returns:
            Processed data dict (with y_rank and groups added)
        """
        for split in ["train", "val", "test"]:
            if split not in data:
                continue

            df = data[split]["raw_data"]

            # Quantile conversion
            rank_col = f"{self.label_name}_rank"

            # 1. Base quantile labels (0 to n_quantiles-1)
            df_ranked = df.with_columns(
                pl.col(self.label_name)
                .rank(method="min")
                .over("trade_date")
                .alias("_rank")
            ).with_columns(
                ((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles)
                .floor()
                .cast(pl.Int64)
                .clip(0, self.n_quantiles - 1)
                .alias("_base_rank")
            )

            # 2. [Top-K optimization] Optional amplified-gain transform
            if self.label_transform == "exponential":
                # Square transform: rank^2
                # e.g. rank=0 -> 0, rank=10 -> 100, rank=19 -> 361
                # Effect: the gap between high- and low-scoring samples grows quadratically
                df_ranked = df_ranked.with_columns(
                    (pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col)
                )
            else:
                # Standard quantile labels
                df_ranked = df_ranked.with_columns(
                    pl.col("_base_rank").cast(pl.Float64).alias(rank_col)
                )

            # Clean up temporary columns
            df_ranked = df_ranked.drop(["_rank", "_base_rank"])

            # Update the data
            data[split]["raw_data"] = df_ranked
            data[split]["y"] = df_ranked[rank_col]
            data[split]["y_raw"] = df_ranked[self.label_name]  # keep the raw values

            # Build the group array
            data[split]["groups"] = self._compute_group_array(df_ranked, "trade_date")

        return data

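
    # [Note] The rank-to-label arithmetic can be checked in isolation; with
    # n_quantiles=20 and 100 names in a day's cross-section, the top name gets
    # bin 19 (illustrative):
    #
    #     import math
    #
    #     n_quantiles, n_stocks = 20, 100
    #     rank = 100  # rank(method="min") of the best name
    #     base = min(math.floor((rank - 1) / n_stocks * n_quantiles), n_quantiles - 1)
    #     print(base)       # 19: top bin
    #     print(base ** 2)  # 361: the "exponential" (squared) gain variant
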
    def _compute_group_array(
        self,
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Compute the group array.

        Args:
            df: Data frame
            date_col: Date column name

        Returns:
            Group array (sample count per date)
        """
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    def fit(self, train_data: Dict, val_data: Dict) -> None:
        """Train the ranking model.

        Args:
            train_data: Training data
            val_data: Validation data
        """
        self.model = TabMRankModel(params=self.model_params)

        self.model.fit(
            train_data["X"],
            train_data["y"],
            group=train_data["groups"],
            eval_set=(val_data["X"], val_data["y"], val_data["groups"])
            if val_data
            else None,
        )

    def predict(self, test_data: Dict) -> np.ndarray:
        """Generate predictions.

        Args:
            test_data: Test data

        Returns:
            Predictions
        """
        # Pass the groups parameter so the prediction order matches the grouping,
        # consistent with the validation logic
        return self.model.predict(test_data["X"], group=test_data.get("groups"))

    def evaluate_ndcg(
        self,
        test_data: Dict,
        k_list: Optional[List[int]] = None,
    ) -> Dict[str, float]:
        """Evaluate NDCG@k.

        Args:
            test_data: Test data
            k_list: List of k values, defaults to [1, 5, 10, 20]

        Returns:
            Dict of NDCG scores {"ndcg@1": score, ...}
        """
        if k_list is None:
            k_list = [1, 5, 10, 20]

        y_pred = self.predict(test_data)
        groups = test_data["groups"]

        from sklearn.metrics import ndcg_score

        results = {}

        # Split by group
        start_idx = 0
        y_true_groups = []
        y_pred_groups = []

        # Use the quantile labels y (0-19) as the true relevance scores instead of raw returns y_raw.
        # This matches the model's training objective and avoids the impact of negative raw returns.
        y_true_array = test_data["y"].to_numpy()

        for group_size in groups:
            end_idx = start_idx + group_size
            y_true_groups.append(y_true_array[start_idx:end_idx])
            y_pred_groups.append(y_pred[start_idx:end_idx])
            start_idx = end_idx

        # Compute NDCG for each k
        for k in k_list:
            ndcg_scores = []
            for yt, yp in zip(y_true_groups, y_pred_groups):
                if len(yt) > 1:
                    try:
                        score = ndcg_score([yt], [yp], k=k)
                        ndcg_scores.append(score)
                    except ValueError:
                        pass

            results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

        return results

    def plot_training_metrics(self) -> None:
        """Plot training metric curves (NDCG)."""
        if self.model and hasattr(self.model, "get_evals_result"):
            try:
                import matplotlib.pyplot as plt

                evals_result = self.model.get_evals_result()
                if not evals_result:
                    print("[WARNING] No training metric data available to plot")
                    return

                fig, ax = plt.subplots(1, 2, figsize=(12, 4))

                # Plot the training loss
                if "train_loss" in evals_result:
                    ax[0].plot(evals_result["train_loss"], label="Train Loss")
                    ax[0].set_xlabel("Epoch")
                    ax[0].set_ylabel("ListNet Loss")
                    ax[0].set_title("Training Loss")
                    ax[0].legend()
                    ax[0].grid(True)

                # Plot the validation NDCG
                if "val_ndcg" in evals_result:
                    ax[1].plot(evals_result["val_ndcg"], label="Val NDCG")
                    ax[1].set_xlabel("Epoch")
                    ax[1].set_ylabel("NDCG")
                    ax[1].set_title("Validation NDCG")
                    ax[1].legend()
                    ax[1].grid(True)

                plt.tight_layout()
                plt.show()
            except Exception as e:
                print(f"[WARNING] Failed to plot training curves: {e}")