feat(training): add TabM learning-to-rank model support and streamline the training pipeline

- Add TabMRankModel, TabMRankTask, and their accompanying loss functions and configuration
- Move DataQualityAnalyzer from the experiment module into the training module
- Remove the overly aggressive NaN/null hard-fill logic from the data processors
- Switch RankTask evaluation metrics to quantile labels instead of raw returns
- Update the experiment script's processor order and model hyperparameter configuration

2026-04-04 22:39:58 +08:00
parent 9e7d4241c6
commit a66d5e9db3
16 changed files with 1663 additions and 344 deletions

View File

@@ -43,6 +43,12 @@ from src.training.core import StockPoolManager, Trainer
 # Utility functions
 from src.training.utils import check_data_quality
+# Data quality analyzer
+from src.training.data_quality_analyzer import (
+    DataQualityAnalyzer,
+    analyze_data_quality,
+)
 # Configuration
 from src.training.config import TrainingConfig
@@ -85,6 +91,9 @@ __all__ = [
     "Trainer",
     # Utility functions
     "check_data_quality",
+    # Data quality analyzer
+    "DataQualityAnalyzer",
+    "analyze_data_quality",
     # Configuration
     "TrainingConfig",
     # New: modular Trainer components (recommended)

View File

@@ -7,6 +7,11 @@ from src.training.components.models.lightgbm import LightGBMModel
 from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
 from src.training.components.models.tabpfn_model import TabPFNModel
 from src.training.components.models.tabm_model import TabMModel
+from src.training.components.models.tabm_rank_model import (
+    TabMRankModel,
+    EnsembleListNetLoss,
+    EnsembleLambdaLoss,
+)
 from src.training.components.models.cross_section_sampler import CrossSectionSampler
 from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
@@ -15,6 +20,9 @@ __all__ = [
     "LightGBMLambdaRankModel",
     "TabPFNModel",
     "TabMModel",
+    "TabMRankModel",
+    "EnsembleListNetLoss",
+    "EnsembleLambdaLoss",
     "CrossSectionSampler",
     "EnsembleQuantLoss",
 ]

View File

@@ -0,0 +1,747 @@
"""TabM ranking model implementation (TabM Rank)

Built on the TabM (Tabular Multilayer Perceptron with Ensembles) architecture,
adding a ListNet listwise ranking loss for LambdaRank-style cross-sectional
learning-to-rank. Intended for cross-sectional ranking of future stock returns.
"""
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import pickle

import numpy as np
import polars as pl
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Sampler
from tabm import TabM

from src.training.components.base import BaseModel
from src.training.registry import register_model


class GroupSampler(Sampler):
    """Group sampler dedicated to learning-to-rank.

    Guarantees that each batch contains exactly the samples belonging to one
    query (e.g., one trading day). This mirrors the semantics of LightGBM's
    `group` parameter.
    """

    def __init__(self, group_counts: np.ndarray, shuffle_groups: bool = True):
        """Initialize the group sampler.

        Args:
            group_counts: Array of per-group sample counts, e.g. [100, 150, 120]
            shuffle_groups: Whether to shuffle the training order of the groups
        """
        self.group_counts = group_counts
        self.shuffle_groups = shuffle_groups
        # Compute each group's start/end boundaries within the flat array
        self.boundaries = np.insert(np.cumsum(group_counts), 0, 0)
        self.num_groups = len(group_counts)

    def __iter__(self):
        """Iterate over batch indices.

        Yields:
            list: Indices of all samples in one group
        """
        group_indices = list(range(self.num_groups))
        if self.shuffle_groups:
            np.random.shuffle(group_indices)
        for g_idx in group_indices:
            start = self.boundaries[g_idx]
            end = self.boundaries[g_idx + 1]
            # Yield the indices of all samples in this group (day) as one full batch
            yield list(range(start, end))

    def __len__(self):
        """Return the number of groups."""
        return self.num_groups
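
# --- Reviewer note (not part of the patch): a minimal sketch of how
# GroupSampler turns a flat sample array into one-batch-per-day, using three
# hypothetical daily cross-sections of sizes 3, 2 and 4. ---
sampler = GroupSampler(np.array([3, 2, 4]), shuffle_groups=False)
print(list(sampler))
# [[0, 1, 2], [3, 4], [5, 6, 7, 8]] -- each inner list becomes one DataLoader batch
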
class EnsembleListNetLoss(nn.Module):
    """Ensemble ListNet loss (listwise ranking loss).

    Uses cross-entropy to turn cross-sectional relevance scores into a
    probability distribution. Supports TabM's ensemble dimension.
    """

    def __init__(self, topk_weight: float = 1.0):
        """Initialize the ListNet loss.

        Args:
            topk_weight: Head-sample weighting coefficient; values > 1.0
                sharpen the focus on high-relevance labels
        """
        super().__init__()
        self.topk_weight = topk_weight

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the ListNet ranking loss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted logits for all samples in the group
            targets: [BatchSize] - ground-truth rank labels for the group (0, 1, 2, ...)

        Returns:
            Scalar loss value
        """
        # If the group is too small to rank, return 0 directly
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)
        # [BatchSize] -> [BatchSize, EnsembleSize]: broadcast to align with each ensemble member
        targets_expanded = targets.unsqueeze(1).expand_as(preds)
        # 1. Target score distribution (softmax sharply up-weights high labels)
        targets_prob = F.softmax(targets_expanded, dim=0)
        # 2. [Top-K optimization] If weighting is enabled, give high labels extra weight
        if self.topk_weight > 1.0:
            # Weight by label value (higher label, higher weight),
            # normalized into the range [1, topk_weight]
            max_target = targets.max()
            min_target = targets.min()
            if max_target > min_target:
                # Linear interpolation: weight = 1 + (target - min) / (max - min) * (topk_weight - 1)
                sample_weights = 1.0 + (targets - min_target) / (
                    max_target - min_target
                ) * (self.topk_weight - 1.0)
                sample_weights = sample_weights.unsqueeze(1).expand_as(preds)
                # Normalize the weights so they sum to the sample count
                sample_weights = sample_weights * len(targets) / sample_weights.sum()
                # Apply the weights to the target probabilities
                targets_prob = targets_prob * sample_weights
        # 3. Log-probabilities of the predictions (log_softmax is numerically
        #    more stable than log(softmax))
        preds_log_prob = F.log_softmax(preds, dim=0)
        # 4. Cross-entropy loss, summed over the batch/group dimension
        loss = -torch.sum(targets_prob * preds_log_prob, dim=0)  # [EnsembleSize]
        # 5. Average the loss over all ensemble members
        return loss.mean()
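
# --- Reviewer note (not part of the patch): a tiny worked example of the
# ListNet objective above, with batch size 3 and ensemble size 2. The loss is
# the cross-entropy between softmax(targets) and softmax(preds) taken along the
# batch dimension, averaged over ensemble members, so a prediction that orders
# the group correctly scores lower than an inverted one. ---
well_ordered = torch.tensor([[2.0, 2.0], [1.0, 1.0], [0.0, 0.0]])  # [B=3, E=2]
inverted = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
targets = torch.tensor([2.0, 1.0, 0.0])  # quantile-style labels, higher = better
loss_fn = EnsembleListNetLoss()
print(loss_fn(well_ordered, targets) < loss_fn(inverted, targets))  # tensor(True)
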
class EnsembleLambdaLoss(nn.Module):
    """Ensemble LambdaLoss (supports TabM's ensemble dimension).

    A pairwise ranking loss weighted by DeltaNDCG.
    Reference: "The LambdaLoss Framework for Ranking Metric Optimization" (Google Research)

    Properties:
    - Computes the impact on NDCG of swapping each pair of samples
    - Gives higher weight to head samples (high gain and high rank)
    - Better suited to top-K stock-picking scenarios
    """

    def __init__(self, sigma: float = 1.0, ndcg_weight_power: float = 1.0):
        """Initialize the LambdaLoss.

        Args:
            sigma: Steepness of the sigmoid, controlling gradient magnitude
            ndcg_weight_power: Power applied to the DeltaNDCG weights; values > 1
                further amplify the head effect
        """
        super().__init__()
        self.sigma = sigma
        self.ndcg_weight_power = ndcg_weight_power

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the LambdaLoss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted scores
            targets: [BatchSize] - relevance labels (gains)

        Returns:
            Scalar loss value
        """
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)
        # 1. Pairwise differences
        preds_diff = preds.unsqueeze(1) - preds.unsqueeze(0)  # [B, B, E]
        # [Performance] target_diff does not need the E dimension; keep it [B, B]
        target_diff = targets.unsqueeze(1) - targets.unsqueeze(0)  # [B, B]
        # 2. Mask matrix: [B, B, 1] for later broadcasting
        mask = (target_diff > 0).float().unsqueeze(2)  # [B, B, 1]
        # 3. Compute DeltaNDCG
        # DeltaNDCG = |Gain_i - Gain_j| * |1/log(rank_i+1) - 1/log(rank_j+1)|
        with torch.no_grad():
            # [Performance] Two argsorts eliminate the for loop entirely:
            # 1st argsort: indices in descending order of score
            # 2nd argsort: inverts those indices into ranks (add 1 for 1-based ranks)
            ranks = preds.argsort(dim=0, descending=True).argsort(dim=0) + 1
            ranks = ranks.float()  # [B, E]
            # Positional discount term (log2 of rank)
            log_rank = torch.log2(ranks + 1)
            inv_log_rank_diff = torch.abs(
                1.0 / log_rank.unsqueeze(1) - 1.0 / log_rank.unsqueeze(0)
            )  # [B, B, E]
            # DeltaNDCG = |gain difference| * |positional discount difference|
            # target_diff is [B, B]; unsqueeze broadcasts it to [B, B, 1]
            delta_ndcg = torch.abs(target_diff).unsqueeze(2) * inv_log_rank_diff
            # Reshape the weight distribution via the power term
            if self.ndcg_weight_power != 1.0:
                delta_ndcg = torch.pow(delta_ndcg, self.ndcg_weight_power)
        # 4. Pairwise logistic loss weighted by DeltaNDCG
        # loss = delta_ndcg * log(1 + exp(-sigma * (preds_i - preds_j)))
        pairwise_loss = F.binary_cross_entropy_with_logits(
            self.sigma * preds_diff,
            torch.ones_like(preds_diff),
            reduction="none",
        )
        # 5. Apply the mask and the weights
        weighted_loss = pairwise_loss * mask * delta_ndcg
        # Guard against division by zero when every pair is masked out
        valid_pairs_count = mask.sum().clamp(min=1.0)
        mean_loss = weighted_loss.sum() / valid_pairs_count
        return mean_loss
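
# --- Reviewer note (not part of the patch): the double-argsort trick used in
# EnsembleLambdaLoss to turn scores into ranks without a Python loop. The first
# argsort yields the descending order; the second inverts it into ranks. ---
scores = torch.tensor([0.3, 0.9, 0.1])
ranks = scores.argsort(descending=True).argsort() + 1
print(ranks)  # tensor([2, 1, 3]) -- 0.9 ranks 1st, 0.3 ranks 2nd, 0.1 ranks 3rd
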
@register_model("tabm_rank")
class TabMRankModel(BaseModel):
    """TabM learning-to-rank model.

    A ranking model built on the TabM architecture with support for the
    ListNet loss. Designed for cross-sectional stock ranking tasks, where
    future returns are converted into quantile labels for training.

    Properties:
    - Uses the ListNet listwise ranking loss
    - Supports grouped training via a `group` parameter
    - Uses NDCG as the validation metric
    - Interface-compatible with LightGBMLambdaRank
    """

    name = "tabm_rank"

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Initialize the TabM Rank model.

        Args:
            params: Model parameter dict containing:
                - ensemble_size: Ensemble size (default: 32)
                - n_blocks: Number of MLP blocks (default: 3)
                - d_block: Neurons per block (default: 256)
                - dropout: Dropout rate (default: 0.1)
                - batch_size: Batch size (default: 2048; used only at prediction time)
                - learning_rate: Learning rate (default: 1e-3)
                - weight_decay: Weight decay (default: 1e-5)
                - epochs: Number of training epochs (default: 50)
                - early_stopping_round: Early-stopping patience (default: 10)
                - max_grad_norm: Gradient-clipping threshold (default: 1.0)
                - ndcg_k: k for NDCG@k; None means global NDCG (default: None)
                - loss_type: Loss function type (default: "listnet")
                    - "listnet": standard ListNet loss
                    - "weighted_listnet": weighted ListNet; topk_weight emphasizes the head
                    - "lambda": LambdaLoss, weighted by DeltaNDCG
                - topk_weight: Head-sample weight coefficient for weighted_listnet (default: 5.0)
                - lambda_sigma: sigma parameter of the LambdaLoss (default: 1.0)
                - ndcg_weight_power: Power applied to the DeltaNDCG weights (default: 1.0)
        """
        self.params = params or {}
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None
        # Select the loss function from the configuration
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:  # "listnet"
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)

    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader (supports packing cross-sections by query/group).

        Args:
            X: Feature array [N, n_features]
            y: Label array [N], or None
            group: Group array giving the sample count per group
            shuffle_groups: Whether to shuffle group order

        Returns:
            A DataLoader instance
        """
        # [Performance] When GPU memory allows, push the full dataset onto the
        # device up front to avoid per-batch PCIe transfers during training
        X_tensor = torch.from_numpy(X).to(self.device)
        if y is not None:
            y_tensor = torch.from_numpy(y).to(self.device)
            dataset = TensorDataset(X_tensor, y_tensor)
        else:
            dataset = TensorDataset(X_tensor)
        if group is not None:
            # For training/validation use GroupSampler: each batch is one query
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(dataset, batch_sampler=sampler)
        else:
            # For prediction without groups, fall back to plain batched inference
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model using the NDCG ranking metric.

        Args:
            val_loader: Validation data loader
            k: k for NDCG@k; None computes global NDCG

        Returns:
            Mean NDCG score
        """
        from sklearn.metrics import ndcg_score

        assert self.model is not None, "Model not trained; cannot validate"
        self.model.eval()
        ndcg_list = []
        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()
                if len(by) <= 1:
                    continue
                outputs = self.model(bx)  # [B, E, 1]
                preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy()  # [B]
                try:
                    # ndcg_score expects 2-D arrays of shape (1, n_samples)
                    score = ndcg_score([by], [preds], k=k)
                    ndcg_list.append(score)
                except ValueError:
                    pass
        return float(np.mean(ndcg_list)) if len(ndcg_list) > 0 else 0.0
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[Tuple] = None,
    ) -> "TabMRankModel":
        """Train the ranking model.

        Args:
            X: Training features (Polars DataFrame)
            y: Training labels (Polars Series); should be quantile labels (0, 1, 2, ...)
            group: Group array giving the sample count per query
            eval_set: Validation tuple (X_val, y_val, group_val) used for early stopping

        Returns:
            self (to support chaining)

        Raises:
            ValueError: If the group parameter is invalid
        """
        self.feature_names_ = list(X.columns)
        X_np = X.to_numpy().astype(np.float32)
        y_np = y.to_numpy().astype(np.float32)
        # Check and normalize the group parameter
        if group is None:
            group = np.array([len(y_np)])
        if group.sum() != len(y_np):
            raise ValueError(
                f"Sum of the group array ({group.sum()}) must equal the sample count ({len(y_np)})"
            )
        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)
        val_loader = None
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = (
                X_val.to_numpy().astype(np.float32)
                if isinstance(X_val, pl.DataFrame)
                else X_val
            )
            y_val_np = (
                y_val.to_numpy().astype(np.float32)
                if isinstance(y_val, pl.Series)
                else y_val
            )
            if group_val is None:
                group_val = np.array([len(y_val_np)])
            val_loader = self._make_loader(
                X_val_np, y_val_np, group=group_val, shuffle_groups=False
            )
        ensemble_size = self.params.get("ensemble_size", 32)
        n_features = X_np.shape[1]
        # Initialize the TabM model
        self.model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=self.params.get("n_blocks", 3),
            d_block=self.params.get("d_block", 256),
            dropout=self.params.get("dropout", 0.1),
            k=ensemble_size,
        ).to(self.device)
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.params.get("learning_rate", 1e-3),
            weight_decay=self.params.get("weight_decay", 1e-5),
        )
        epochs = self.params.get("epochs", 50)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-6
        )
        early_stopping_patience = self.params.get(
            "early_stopping_patience",
            self.params.get("early_stopping_round", 10),
        )
        best_val_ndcg = -float("inf")
        patience_counter = 0
        best_model_state = None
        ndcg_k = self.params.get("ndcg_k", None)  # None computes global NDCG
        print(f"[TabMRank] Starting training... device: {self.device}, ensemble size: {ensemble_size}")
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0
            n_train_batches = 0
            for batch in train_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch[0], batch[1]
                optimizer.zero_grad()
                outputs = self.model(bx)  # [B, E, 1]
                outputs_squeezed = outputs.squeeze(-1)  # [B, E]
                # Compute the ListNet ranking loss
                loss = self.criterion(outputs_squeezed, by)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.params.get("max_grad_norm", 1.0),
                )
                optimizer.step()
                train_loss += loss.item()
                n_train_batches += 1
            avg_train_loss = train_loss / max(n_train_batches, 1)
            self.training_history_["train_loss"].append(avg_train_loss)
            # Validation phase (NDCG-based)
            if val_loader is not None:
                val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
                self.training_history_["val_ndcg"].append(val_ndcg)
                if val_ndcg > best_val_ndcg:
                    best_val_ndcg = val_ndcg
                    patience_counter = 0
                    best_model_state = {
                        k: v.cpu().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss (ListNet): {avg_train_loss:.4f} | "
                        f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
                    )
                if patience_counter >= early_stopping_patience:
                    print(f"[TabMRank] Early stopping triggered at epoch {epoch + 1}")
                    break
            else:
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss: {avg_train_loss:.4f}"
                    )
            scheduler.step()
        # Restore the best weights
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print(f"[TabMRank] Restored best model weights (Val NDCG: {best_val_ndcg:.4f})")
        return self
    def predict(
        self, X: pl.DataFrame, group: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: Feature matrix (Polars DataFrame)
            group: Group array giving the sample count per query.
                If provided, GroupSampler is used so the prediction order
                stays consistent with the grouping.

        Returns:
            Predicted scores (numpy ndarray)

        Raises:
            RuntimeError: If called before the model is trained
            ValueError: If the prediction data is missing features
        """
        if self.model is None:
            raise RuntimeError("Model not trained; call fit() first")
        # Feature alignment check
        if self.feature_names_:
            missing_cols = [c for c in self.feature_names_ if c not in X.columns]
            if missing_cols:
                raise ValueError(f"Prediction data is missing features: {missing_cols}")
            X = X.select(self.feature_names_)
        X_np = X.to_numpy().astype(np.float32)
        loader = self._make_loader(X_np, group=group, shuffle_groups=False)
        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in loader:
                bx = batch[0].to(self.device)
                outputs = self.model(bx)  # [B, E, 1]
                # At prediction time, use the mean over ensemble members as the final score
                preds = outputs.mean(dim=1).squeeze(-1)  # [B]
                all_preds.append(preds.cpu().numpy())
        return np.concatenate(all_preds)

    def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
        """Get the training evaluation results.

        Returns:
            Evaluation dict containing train_loss and val_ndcg
        """
        return self.training_history_

    def feature_importance(self) -> None:
        """Get feature importances.

        TabM has no built-in feature importance computation; returns None.
        """
        return None

    def save(self, path: str | Path) -> None:
        """Save the model.

        Args:
            path: Save path

        Raises:
            RuntimeError: If called before the model is trained
        """
        if self.model is None:
            raise RuntimeError("Model not trained; nothing to save")
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Save the model weights
        model_path = path.with_suffix(".pt")
        torch.save(self.model.state_dict(), model_path)
        # Save the metadata
        meta_path = path.with_suffix(".meta")
        meta = {
            "params": self.params,
            "feature_names": self.feature_names_,
            "training_history": self.training_history_,
            "device": str(self.device),
        }
        with open(meta_path, "wb") as f:
            pickle.dump(meta, f)
        print(f"[TabMRank] Model saved to: {path}")

    @classmethod
    def load(cls, path: str | Path) -> "TabMRankModel":
        """Load a model.

        Args:
            path: Model path (without extension)

        Returns:
            The loaded TabMRankModel instance
        """
        path = Path(path)
        # Load the metadata
        meta_path = path.with_suffix(".meta")
        with open(meta_path, "rb") as f:
            meta = pickle.load(f)
        # Create the instance
        instance = cls(meta["params"])
        instance.feature_names_ = meta["feature_names"]
        instance.training_history_ = meta["training_history"]
        # Rebuild the model structure
        if instance.feature_names_ is not None:
            n_features = len(instance.feature_names_)
            ensemble_size = instance.params.get("ensemble_size", 32)
            instance.model = TabM.make(
                n_num_features=n_features,
                cat_cardinalities=[],
                d_out=1,
                n_blocks=instance.params.get("n_blocks", 3),
                d_block=instance.params.get("d_block", 256),
                dropout=instance.params.get("dropout", 0.1),
                k=ensemble_size,
            ).to(instance.device)
            # Load the weights
            model_path = path.with_suffix(".pt")
            instance.model.load_state_dict(
                torch.load(model_path, map_location=instance.device)
            )
        print(f"[TabMRank] Model loaded from {path}")
        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column.

        Args:
            df: DataFrame containing the date column
            date_col: Date column name, defaults to "trade_date"

        Returns:
            The group array
        """
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()
    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into quantile labels.

        Args:
            df: Input DataFrame
            label_col: Original label column name
            date_col: Date column name, defaults to "trade_date"
            n_quantiles: Number of quantiles, defaults to 20
            new_col_name: New column name, defaults to {label_col}_rank

        Returns:
            DataFrame with the quantile label column appended
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"
        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )
    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate the NDCG metric.

        Args:
            X: Feature matrix
            y: Ground-truth labels
            group: Group array
            k: k for NDCG@k

        Returns:
            NDCG score
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)
        y_true_list = []
        y_score_list = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y.to_numpy()[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx
        ndcg_scores = []
        for y_true, y_score in zip(y_true_list, y_score_list):
            if len(y_true) > 1:
                try:
                    score = ndcg_score([y_true], [y_score], k=k)
                    ndcg_scores.append(score)
                except ValueError:
                    pass
        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
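
# --- Reviewer note (not part of the patch): an end-to-end usage sketch of the
# model above; `train_df`/`val_df`/`test_df` and `feature_cols` are hypothetical
# frames/columns produced by the surrounding pipeline. Frames must be sorted by
# trade_date so the group arrays describe contiguous blocks. ---
train_df = TabMRankModel.convert_to_quantile_labels(train_df, "future_return_5")
val_df = TabMRankModel.convert_to_quantile_labels(val_df, "future_return_5")
model = TabMRankModel({"loss_type": "lambda", "epochs": 30, "ndcg_k": 20})
model.fit(
    train_df.select(feature_cols),
    train_df["future_return_5_rank"],
    group=TabMRankModel.prepare_group_from_dates(train_df),
    eval_set=(
        val_df.select(feature_cols),
        val_df["future_return_5_rank"],
        TabMRankModel.prepare_group_from_dates(val_df),
    ),
)
scores = model.predict(test_df.select(feature_cols))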

View File

@@ -289,7 +289,7 @@ class StandardScaler(BaseProcessor):
         return self

     def transform(self, X: pl.DataFrame) -> pl.DataFrame:
-        """Standardize (using parameters learned from the training set, with extra NaN protection)
+        """Standardize (using parameters learned from the training set)

         Args:
             X: Data to transform
@@ -302,18 +302,11 @@ class StandardScaler(BaseProcessor):
             if col in self.mean_ and col in self.std_:
                 # Avoid division by zero
                 std_val = self.std_[col] if self.std_[col] != 0 else 1.0
-                # Critical fix: add fill_nan(0) as insurance against NaN produced by the computation
-                expr = (
-                    ((pl.col(col) - self.mean_[col]) / std_val)
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # For columns that should be processed but have no learned statistics,
-                # cast uniformly to float and handle both NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # For columns that should be processed but have no learned statistics, cast uniformly to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -372,20 +365,14 @@ class CrossSectionalStandardScaler(BaseProcessor):
             if col in self.feature_cols and X[col].dtype.is_numeric():
                 # Cross-sectional standardization: compute mean and std independently per day
                 # Avoid division by zero: treat std as 1 when it is 0
-                # Critical fix: fill_nan before fill_null, guarding against NaN produced by the computation
                 expr = (
-                    (
-                        (pl.col(col) - pl.col(col).mean().over(self.date_col))
-                        / (pl.col(col).std().over(self.date_col) + 1e-10)
-                    )
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                    (pl.col(col) - pl.col(col).mean().over(self.date_col))
+                    / (pl.col(col).std().over(self.date_col) + 1e-10)
+                ).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # For columns that should be processed but have a mismatched type, cast to float and handle both NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # For columns that should be processed but have a mismatched type, cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -488,8 +475,8 @@ class Winsorizer(BaseProcessor):
                 expressions.append(expr)
             elif col in self.feature_cols:
                 # For columns that should be processed but have no learned bounds (all-NaN, boolean columns, etc.),
-                # cast uniformly to float and fill with 0
-                expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
+                # cast uniformly to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -522,11 +509,9 @@ class Winsorizer(BaseProcessor):
         clip_exprs = []
         for col in X.columns:
             if col in target_cols:
-                # Winsorize with the day's quantiles first; if a quantile is null (the whole day is NaN), fill with 0
                 clipped = (
                     pl.col(col)
                     .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
-                    .fill_null(0)
                     .alias(col)
                 )
                 clip_exprs.append(clipped)
View File

@@ -13,6 +13,7 @@ from src.factors import FactorEngine
 from src.training.pipeline import DataPipeline
 from src.training.tasks.base import BaseTask
 from src.training.result_analyzer import ResultAnalyzer
+from src.training.data_quality_analyzer import DataQualityAnalyzer


 class Trainer:
@@ -100,8 +101,6 @@ class Trainer:
         print("\n[Step 1.5/7] Data quality analysis...")
         try:
-            from src.experiment.data_quality_analyzer import DataQualityAnalyzer
-
             # Get the feature column names (from the training set)
             feature_cols = data["train"].get("feature_cols", [])
             label_name = self.task.label_name

View File

@@ -0,0 +1,646 @@
"""Data quality analysis module

Provides data quality checks, including:
- Dataset date-range information
- Missing-value statistics
- Zero-value statistics
- Per-date detection of all-null columns
"""
from typing import Any, Dict, List, Optional

import polars as pl
import numpy as np


class DataQualityAnalyzer:
    """Data quality analyzer.

    Analyzes quality problems in training data to help spot anomalies.

    Attributes:
        feature_cols: List of feature column names
        label_col: Label column name
        date_col: Date column name
        verbose: Whether to print details
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        date_col: str = "trade_date",
        verbose: bool = True,
    ):
        """Initialize the data quality analyzer.

        Args:
            feature_cols: List of feature column names
            label_col: Label column name
            date_col: Date column name, defaults to "trade_date"
            verbose: Whether to print details
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.date_col = date_col
        self.verbose = verbose
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Set the columns to analyze.

        Args:
            feature_cols: List of feature column names
            label_col: Label column name
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the full data quality analysis.

        Args:
            data: Data dict of the form {"train": {...}, "val": {...}, "test": {...}}
            split_names: Names of the splits to analyze, defaults to ["train", "val", "test"]

        Returns:
            Dict of analysis results
        """
        if not split_names:
            split_names = ["train", "val", "test"]
        if self.verbose:
            print("\n" + "=" * 80)
            print("Data Quality Analysis Report")
            print("=" * 80)
        self.analysis_results = {}
        # Print the dataset overview first (date ranges and other basics)
        if self.verbose:
            self._print_dataset_overview(data, split_names)
        for split_name in split_names:
            if split_name not in data:
                continue
            split_data = data[split_name]
            raw_df = split_data.get("raw_data")
            if raw_df is None:
                continue
            if self.verbose:
                print(f"\n[{split_name.upper()} split]")
                print("-" * 40)
            split_results = self._analyze_split(raw_df, split_name)
            self.analysis_results[split_name] = split_results
        if self.verbose:
            print("\n" + "=" * 80)
        return self.analysis_results
    def _print_dataset_overview(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: List[str],
    ) -> None:
        """Print the dataset overview.

        Covers each split's start date, end date, sample count, and other basics.

        Args:
            data: Data dict
            split_names: List of split names
        """
        print("\n[Dataset Overview]")
        print("-" * 40)
        overview_data = []
        for split_name in split_names:
            if split_name not in data:
                continue
            split_data = data[split_name]
            raw_df = split_data.get("raw_data")
            if raw_df is None or len(raw_df) == 0:
                overview_data.append(
                    {
                        "Split": split_name.upper(),
                        "Start date": "-",
                        "End date": "-",
                        "Trading days": 0,
                        "Samples": 0,
                        "Stocks": 0,
                    }
                )
                continue
            # Get the date range
            if self.date_col in raw_df.columns:
                dates = raw_df[self.date_col]
                start_date = dates.min()
                end_date = dates.max()
                unique_dates = dates.n_unique()
            else:
                start_date = "-"
                end_date = "-"
                unique_dates = 0
            # Get the stock count
            if "ts_code" in raw_df.columns:
                unique_stocks = raw_df["ts_code"].n_unique()
            else:
                unique_stocks = 0
            overview_data.append(
                {
                    "Split": split_name.upper(),
                    "Start date": str(start_date),
                    "End date": str(end_date),
                    "Trading days": unique_dates,
                    "Samples": len(raw_df),
                    "Stocks": unique_stocks,
                }
            )
        # Print the table
        if overview_data:
            # Compute column widths
            headers = ["Split", "Start date", "End date", "Trading days", "Samples", "Stocks"]
            col_widths = {}
            for header in headers:
                max_data_len = max(
                    len(str(row.get(header, ""))) for row in overview_data
                )
                col_widths[header] = max(len(header), max_data_len) + 2
            # Print the header row
            header_line = "  ".join(h.ljust(col_widths[h]) for h in headers)
            print(f"  {header_line}")
            print(f"  {'-' * (sum(col_widths.values()) + 2 * (len(headers) - 1))}")
            # Print the data rows
            for row in overview_data:
                line = "  ".join(
                    str(row.get(h, "")).ljust(col_widths[h]) for h in headers
                )
                print(f"  {line}")
    def _analyze_split(
        self,
        df: pl.DataFrame,
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze a single dataset split.

        Args:
            df: Data frame
            split_name: Split name

        Returns:
            Dict of analysis results
        """
        results = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }
        # Get the date range
        if self.date_col in df.columns:
            results["start_date"] = str(df[self.date_col].min())
            results["end_date"] = str(df[self.date_col].max())
            results["unique_dates"] = df[self.date_col].n_unique()
        # Get the stock count
        if "ts_code" in df.columns:
            results["unique_stocks"] = df["ts_code"].n_unique()
        # 1. Analyze missing values in the feature columns
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats
        if self.verbose:
            self._print_null_analysis(null_stats)
        # 2. Analyze zero values in the feature columns
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats
        if self.verbose:
            self._print_zero_analysis(zero_stats)
        # 3. Check whether any column is entirely null on some date
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date
        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)
        # 4. Analyze the label column
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats
            if self.verbose:
                self._print_label_analysis(label_stats)
        return results

    def _analyze_null_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze missing values.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            Dict of missing-value statistics
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }
        for col in cols:
            if col not in df.columns:
                continue
            null_count = df[col].null_count()
            if null_count > 0:
                null_pct = null_count / len(df) * 100
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = null_pct
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count
        return stats
    def _analyze_zero_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze zero values.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            Dict of zero-value statistics
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }
        for col in cols:
            if col not in df.columns:
                continue
            # Count zeros (excluding nulls)
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue
            zero_count = (non_null_series == 0).sum()
            if zero_count > 0:
                zero_pct = zero_count / len(df) * 100
                stats["zero_counts"][col] = int(zero_count)
                stats["zero_percentages"][col] = zero_pct
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += int(zero_count)
        return stats

    def _check_all_null_by_date(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Check whether any column is entirely null on some date.

        Uses a polars lazy frame for memory-safe, efficient computation.

        Args:
            df: Data frame
            cols: Column names to analyze

        Returns:
            Dict with the all-null check results
        """
        results = {
            "issues_found": False,
            "issues": [],
        }
        if self.date_col not in df.columns:
            return results
        # Filter out columns that are not in the frame
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results
        # Use a lazy frame for query optimization
        lf = df.lazy()
        # Core step: compute only null_count and total row count (tiny after aggregation)
        # Create a separate null_count aggregation expression per column
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))
        agg_lf = lf.group_by(self.date_col).agg(agg_exprs)
        # Collect the result (agg_df typically has only hundreds to thousands of rows)
        agg_df = agg_lf.collect()
        # Run the logical checks on this already-"dehydrated" small table
        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # Find dates where the null count equals the total row count
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
            ).select([self.date_col, "total_rows"])
            if not bad_dates.is_empty():
                for row in bad_dates.to_dicts():
                    issues.append(
                        {
                            "date": row[self.date_col],
                            "column": col,
                            "total_rows": row["total_rows"],
                        }
                    )
        if issues:
            results["issues_found"] = True
            results["issues"] = issues
        return results
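
    # --- Reviewer note (not part of the patch): the aggregation pattern above,
    # shown standalone on a toy frame. ---
    #
    #   df = pl.DataFrame({
    #       "trade_date": ["d1", "d1", "d2", "d2"],
    #       "f1": [1.0, 2.0, None, None],  # entirely null on d2
    #   })
    #   agg = df.lazy().group_by("trade_date").agg([
    #       pl.col("f1").null_count().alias("f1_nulls"),
    #       pl.len().alias("total_rows"),
    #   ]).collect()
    #   agg.filter(pl.col("f1_nulls") == pl.col("total_rows"))  # -> the d2 row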
    def _analyze_label(
        self,
        df: pl.DataFrame,
        label_col: str,
    ) -> Dict[str, Any]:
        """Analyze the label column.

        Args:
            df: Data frame
            label_col: Label column name

        Returns:
            Dict of label statistics
        """
        stats = {
            "total_count": len(df),
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }
        if label_col not in df.columns:
            return stats
        series = df[label_col]
        # Missing-value statistics
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0
        # Zero-value statistics
        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            zero_count = (non_null_series == 0).sum()
            stats["zero_count"] = int(zero_count)
            stats["zero_percentage"] = zero_count / len(df) * 100
            # Basic summary statistics
            stats["min"] = float(non_null_series.min())
            stats["max"] = float(non_null_series.max())
            stats["mean"] = float(non_null_series.mean())
            stats["std"] = float(non_null_series.std())
        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the missing-value analysis.

        Args:
            stats: Dict of missing-value statistics
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]
        print(f"  Missing-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Missing cells: {total_null:,} ({total_null / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with missing values: {len(null_cols)}/{len(self.feature_cols)}")
        if null_cols:
            print(f"    Top 5 features by missing values:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the zero-value analysis.

        Args:
            stats: Dict of zero-value statistics
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]
        print(f"  Zero-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Zero cells: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with zeros: {len(zero_cols)}/{len(self.feature_cols)}")
        if zero_cols:
            print(f"    Top 5 features by zero values:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the per-date all-null check results.

        Args:
            results: Dict with the all-null check results
        """
        issues = results["issues"]
        print(f"  Per-date all-null check:")
        if results["issues_found"]:
            print(f"    [WARNING] Found {len(issues)} issues:")
            # Group by date for display
            by_date = {}
            for issue in issues:
                date = issue["date"]
                if date not in by_date:
                    by_date[date] = []
                by_date[date].append(issue["column"])
            for date in sorted(by_date.keys())[:5]:  # show only the first 5 dates
                cols = by_date[date]
                print(f"    Date {date}: {len(cols)} columns entirely null")
                if len(cols) <= 3:
                    print(f"      Columns: {', '.join(cols)}")
            if len(by_date) > 5:
                print(f"    ... {len(by_date) - 5} more dates with issues")
        else:
            print(f"    [OK] No date found with an entirely null column")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the label analysis.

        Args:
            stats: Dict of label statistics
        """
        print(f"  Label column statistics ({self.label_col}):")
        print(f"    Total: {stats['total_count']:,}")
        print(f"    Missing: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f"    Zeros: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")
        if stats["mean"] is not None:
            print(f"    Min: {stats['min']:.6f}")
            print(f"    Max: {stats['max']:.6f}")
            print(f"    Mean: {stats['mean']:.6f}")
            print(f"    Std: {stats['std']:.6f}")

    def get_summary(self) -> str:
        """Get a summary of the analysis results.

        Returns:
            Summary string
        """
        if not self.analysis_results:
            return "Analysis has not been run yet"
        lines = ["Data Quality Analysis Summary", "=" * 40]
        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")
            # Add the date-range information
            if "start_date" in results and "end_date" in results:
                lines.append(
                    f"  Date range: {results['start_date']} ~ {results['end_date']}"
                )
            if "unique_dates" in results:
                lines.append(f"  Trading days: {results['unique_dates']}")
            lines.append(f"  Total rows: {results['total_rows']:,}")
            if "unique_stocks" in results:
                lines.append(f"  Stocks: {results['unique_stocks']}")
            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f"  Missing values: {null_stats['total_null_cells']:,} cells, "
                    f"{len(null_stats['columns_with_null'])} columns affected"
                )
            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f"  Zero values: {zero_stats['total_zero_cells']:,} cells, "
                    f"{len(zero_stats['columns_with_zero'])} columns affected"
                )
            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f"  [WARNING] Found {len(all_null['issues'])} date/column all-null issues"
                )
        return "\n".join(lines)
def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    date_col: str = "trade_date",
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience function: run a data quality analysis.

    Args:
        data: Data dict
        feature_cols: List of feature column names
        label_col: Label column name
        date_col: Date column name, defaults to "trade_date"
        verbose: Whether to print details

    Returns:
        Dict of analysis results
    """
    analyzer = DataQualityAnalyzer(
        feature_cols=feature_cols,
        label_col=label_col,
        date_col=date_col,
        verbose=verbose,
    )
    return analyzer.analyze(data)
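
# --- Reviewer note (not part of the patch): a usage sketch of the convenience
# entry point; the split dict follows the "raw_data" layout the analyzer
# expects, and the frames/column names are hypothetical. ---
data = {"train": {"raw_data": train_df}, "val": {"raw_data": val_df}}
report = analyze_data_quality(
    data,
    feature_cols=["f1", "f2"],
    label_col="future_return_5",
)
print(report["train"]["null_analysis"]["total_null_cells"])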

View File

@@ -7,10 +7,12 @@ from src.training.tasks.base import BaseTask
 from src.training.tasks.regression_task import RegressionTask
 from src.training.tasks.rank_task import RankTask
 from src.training.tasks.tabm_regression_task import TabMRegressionTask
+from src.training.tasks.tabm_rank_task import TabMRankTask

 __all__ = [
     "BaseTask",
     "RegressionTask",
     "RankTask",
     "TabMRegressionTask",
+    "TabMRankTask",
 ]

View File

@@ -153,7 +153,6 @@ class RankTask(BaseTask):
         if k_list is None:
             k_list = [1, 5, 10, 20]

-        y_true = test_data["y_raw"]
         y_pred = self.predict(test_data)
         groups = test_data["groups"]
@@ -166,9 +165,13 @@ class RankTask(BaseTask):
         y_true_groups = []
         y_pred_groups = []

+        # Use the quantile labels y (0-19) as the true relevance scores instead of raw returns (y_raw);
+        # this matches the model's learning target and avoids issues with negative raw returns
+        y_true_array = test_data["y"].to_numpy()
         for group_size in groups:
             end_idx = start_idx + group_size
-            y_true_groups.append(y_true.to_numpy()[start_idx:end_idx])
+            y_true_groups.append(y_true_array[start_idx:end_idx])
             y_pred_groups.append(y_pred[start_idx:end_idx])
             start_idx = end_idx
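
The switch to quantile labels also matters for the metric itself: raw 5-day
returns are frequently negative, and recent scikit-learn versions reject
negative relevance in ndcg_score (so the surrounding try/except would silently
skip those groups). A quick illustration:

from sklearn.metrics import ndcg_score

print(ndcg_score([[3, 1, 0]], [[0.9, 0.2, 0.1]]))  # fine: labels are non-negative
try:
    ndcg_score([[0.02, -0.01, -0.03]], [[0.9, 0.2, 0.1]])
except ValueError as e:
    print(e)  # negative y_true values are rejected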

View File

@@ -0,0 +1,249 @@
"""TabM learning-to-rank task implementation

Implements the TabM-based learning-to-rank training flow:
- Converts labels into quantile labels
- Builds the group array
- Uses TabMRankModel (based on the ListNet loss)
- Supports NDCG@k evaluation
"""
from typing import Any, Dict, List, Optional

import numpy as np
import polars as pl

from src.training.tasks.base import BaseTask
from src.training.components.models.tabm_rank_model import TabMRankModel


class TabMRankTask(BaseTask):
    """TabM learning-to-rank task.

    Trains a TabMRankModel for learning-to-rank.
    Continuous returns are converted into quantile labels for training.
    Supports amplified gain labels to strengthen the top-K focus.
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
        n_quantiles: int = 20,
        label_transform: Optional[str] = None,
        label_scale: float = 20.0,
    ):
        """Initialize the learning-to-rank task.

        Args:
            model_params: TabM parameter dict
            label_name: Label column name
            n_quantiles: Number of quantiles
            label_transform: Label transform type; one of:
                - None: standard quantile labels (0, 1, ..., n_quantiles-1)
                - "exponential": amplified gain labels (implemented below by squaring the rank)
            label_scale: Scaling factor for the gain transform (not used by the
                current rank-squared implementation)
        """
        super().__init__(model_params, label_name)
        self.n_quantiles = n_quantiles
        self.label_transform = label_transform
        self.label_scale = label_scale

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Prepare the labels (quantile conversion, optional amplified gain transform).

        Converts continuous returns into quantile labels and builds the group
        array. Optionally applies an amplified gain transform to sharpen the
        separation of head samples.

        Args:
            data: Data dict

        Returns:
            The processed data dict (with y_rank and groups added)
        """
        for split in ["train", "val", "test"]:
            if split not in data:
                continue
            df = data[split]["raw_data"]
            # Quantile conversion
            rank_col = f"{self.label_name}_rank"
            # 1. Base quantile labels (0 to n_quantiles-1)
            df_ranked = df.with_columns(
                pl.col(self.label_name)
                .rank(method="min")
                .over("trade_date")
                .alias("_rank")
            ).with_columns(
                ((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles)
                .floor()
                .cast(pl.Int64)
                .clip(0, self.n_quantiles - 1)
                .alias("_base_rank")
            )
            # 2. [Top-K optimization] Optional amplified gain transform
            if self.label_transform == "exponential":
                # Squaring transform: rank^2
                # e.g. rank=0 -> 0, rank=10 -> 100, rank=19 -> 361
                # Effect: the gap between high and low labels grows quadratically
                df_ranked = df_ranked.with_columns(
                    (pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col)
                )
            else:
                # Standard quantile labels
                df_ranked = df_ranked.with_columns(
                    pl.col("_base_rank").cast(pl.Float64).alias(rank_col)
                )
            # Drop the temporary columns
            df_ranked = df_ranked.drop(["_rank", "_base_rank"])
            # Update the data dict
            data[split]["raw_data"] = df_ranked
            data[split]["y"] = df_ranked[rank_col]
            data[split]["y_raw"] = df_ranked[self.label_name]  # keep the raw values
            # Build the group array
            data[split]["groups"] = self._compute_group_array(df_ranked, "trade_date")
        return data
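
    # --- Reviewer note (not part of the patch): what the label pipeline above
    # produces for one day with 6 names and n_quantiles=3; values illustrative. ---
    #
    #   returns:                [-0.05, 0.00, 0.02, 0.03, 0.08, 0.10]
    #   _rank (min, per day):   [1, 2, 3, 4, 5, 6]
    #   _base_rank:             [0, 0, 1, 1, 2, 2]   floor((rank-1)/6 * 3), clipped
    #   label (None):           [0, 0, 1, 1, 2, 2]
    #   label ("exponential"):  [0, 0, 1, 1, 4, 4]   base rank squared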
    def _compute_group_array(
        self,
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Compute the group array.

        Args:
            df: Data frame
            date_col: Date column name

        Returns:
            Group array (sample count per date)
        """
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    def fit(self, train_data: Dict, val_data: Dict) -> None:
        """Train the ranking model.

        Args:
            train_data: Training data
            val_data: Validation data
        """
        self.model = TabMRankModel(params=self.model_params)
        self.model.fit(
            train_data["X"],
            train_data["y"],
            group=train_data["groups"],
            eval_set=(val_data["X"], val_data["y"], val_data["groups"])
            if val_data
            else None,
        )

    def predict(self, test_data: Dict) -> np.ndarray:
        """Generate predictions.

        Args:
            test_data: Test data

        Returns:
            Predictions
        """
        # Pass the groups so the prediction order stays consistent with the
        # grouping, matching the validation logic
        return self.model.predict(test_data["X"], group=test_data.get("groups"))

    def evaluate_ndcg(
        self,
        test_data: Dict,
        k_list: Optional[List[int]] = None,
    ) -> Dict[str, float]:
        """Evaluate NDCG@k.

        Args:
            test_data: Test data
            k_list: List of k values, defaults to [1, 5, 10, 20]

        Returns:
            Dict of NDCG scores {"ndcg@1": score, ...}
        """
        if k_list is None:
            k_list = [1, 5, 10, 20]
        y_pred = self.predict(test_data)
        groups = test_data["groups"]
        from sklearn.metrics import ndcg_score

        results = {}
        # Split by group
        start_idx = 0
        y_true_groups = []
        y_pred_groups = []
        # Use the quantile labels y (0-19) as the true relevance scores instead of raw returns (y_raw);
        # this matches the model's learning target and avoids issues with negative raw returns
        y_true_array = test_data["y"].to_numpy()
        for group_size in groups:
            end_idx = start_idx + group_size
            y_true_groups.append(y_true_array[start_idx:end_idx])
            y_pred_groups.append(y_pred[start_idx:end_idx])
            start_idx = end_idx
        # Compute NDCG for each k
        for k in k_list:
            ndcg_scores = []
            for yt, yp in zip(y_true_groups, y_pred_groups):
                if len(yt) > 1:
                    try:
                        score = ndcg_score([yt], [yp], k=k)
                        ndcg_scores.append(score)
                    except ValueError:
                        pass
            results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
        return results

    def plot_training_metrics(self) -> None:
        """Plot the training metric curves (NDCG)."""
        if self.model and hasattr(self.model, "get_evals_result"):
            try:
                import matplotlib.pyplot as plt

                evals_result = self.model.get_evals_result()
                if not evals_result:
                    print("[WARNING] No training metrics available to plot")
                    return
                fig, ax = plt.subplots(1, 2, figsize=(12, 4))
                # Plot the training loss
                if "train_loss" in evals_result:
                    ax[0].plot(evals_result["train_loss"], label="Train Loss")
                    ax[0].set_xlabel("Epoch")
                    ax[0].set_ylabel("ListNet Loss")
                    ax[0].set_title("Training Loss")
                    ax[0].legend()
                    ax[0].grid(True)
                # Plot the validation NDCG
                if "val_ndcg" in evals_result:
                    ax[1].plot(evals_result["val_ndcg"], label="Val NDCG")
                    ax[1].set_xlabel("Epoch")
                    ax[1].set_ylabel("NDCG")
                    ax[1].set_title("Validation NDCG")
                    ax[1].legend()
                    ax[1].grid(True)
                plt.tight_layout()
                plt.show()
            except Exception as e:
                print(f"[WARNING] Could not plot training curves: {e}")