feat(training): 新增 TabM SetRank 模型并支持任务注入

- 添加 TabMSetRankModel 实现集合排序训练
- TabMRankTask 支持通过 model_class 注入兼容模型
- 启用 common.py 中的流动性因子
This commit is contained in:
2026-04-05 01:03:17 +08:00
parent a66d5e9db3
commit 94d5d13bb1
6 changed files with 718 additions and 20 deletions

View File

@@ -271,22 +271,22 @@ SELECTED_FACTORS = [
"pivot_reversion",
"chip_transition",
# "amivest_liq_20",
# "atr_price_impact",
# "hui_heubel_ratio",
# "corwin_schultz_spread_20",
# "roll_spread_20",
# "gibbs_effective_spread",
# "overnight_illiq_20",
# "illiq_volatility_20",
# "amount_cv_20",
# "amount_skewness_20",
# "low_vol_days_20",
# "liquidity_shock_momentum",
# "downside_illiq_20",
# "upside_illiq_20",
# "illiq_asymmetry_20",
# "pastor_stambaugh_proxy"
"amivest_liq_20",
"atr_price_impact",
"hui_heubel_ratio",
"corwin_schultz_spread_20",
"roll_spread_20",
"gibbs_effective_spread",
"overnight_illiq_20",
"illiq_volatility_20",
"amount_cv_20",
"amount_skewness_20",
"low_vol_days_20",
"liquidity_shock_momentum",
"downside_illiq_20",
"upside_illiq_20",
"illiq_asymmetry_20",
"pastor_stambaugh_proxy"
]
# 因子定义字典完整因子库用于存放尚未注册到metadata的因子

View File

@@ -0,0 +1,201 @@
"""TabM + SetRank 排序学习训练流程
使用模块化 Trainer 架构,基于 TabMSetRankModel 实现排序学习。
引入 SetRank 组内注意力头,其余配置与 tabm_rank_train.py 对齐。
"""
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabMSetRankModel
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# Identifier for this training type (used to derive output/model save paths).
TRAINING_TYPE = "tabm_setrank_rank"
# %%
# Label configuration (imported centrally from common.py).
# Quantile configuration (higher resolution to better separate the top ranks).
N_QUANTILES = 50
# [Top-K optimization] Label-engineering configuration — exponential gain enabled by default.
LABEL_TRANSFORM = "exponential"  # enable exponential-gain labels (rank^2-style)
LABEL_SCALE = 20.0  # retained parameter (currently unused; the squared transform needs no scaling)
# Factors excluded from training.
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
# TabM + SetRank model hyper-parameters.
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,
    "d_block": 256,
    "dropout": 0.5,
    # ==================== Ensembling ====================
    "ensemble_size": 32,
    # ==================== SetRank head (reduced size to curb overfitting) ====================
    "use_setrank": True,
    "setrank_heads": 4,
    # NOTE(review): the original comment claimed this hidden width was lowered
    # from 128 to 64 (narrower to avoid memorizing day-specific noise), but the
    # value below is still 128 — confirm which is intended.
    "setrank_hidden": 128,
    # Larger dropout on the SetRank layer for stronger regularization.
    "setrank_dropout": 0.5,
    # ==================== AMP and memory optimization ====================
    "use_amp": True,
    "num_workers": 0,
    "pin_memory": False,
    # ==================== Training parameters (strong regularization) ====================
    # Slightly lower learning rate to reduce oscillation near the optimum.
    "learning_rate": 5e-4,
    # NOTE(review): the original comment said weight decay was raised 10x-100x
    # (i.e. to 1e-3) to keep the attention head from latching onto specific
    # stocks, but the value below is still 1e-5 — confirm which is intended.
    "weight_decay": 1e-5,
    "epochs": 150,  # 150 epochs suffices per the observed training curves
    # ==================== Early stopping ====================
    "early_stopping_round": 30,  # patience of 30 epochs
    # ==================== NDCG evaluation ====================
    "ndcg_k": 20,
    # ==================== Loss configuration ====================
    "loss_type": "lambda",
    "lambda_sigma": 1.0,
    # Power applied to the DeltaNDCG weight; larger values penalize
    # mis-ranking the very top stocks more heavily.
    "ndcg_weight_power": 1.0,
}
# Date-range configuration for the train/val/test splits.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output configuration consumed by the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_setrank_rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
def main():
    """Entry point: assemble the pipeline, run TabM + SetRank rank training, save artifacts."""
    print("\n" + "=" * 80)
    print("TabM + SetRank 排序学习训练")
    print("=" * 80)
    # Step 1: factor engine — raw data access layer.
    print("\n[1] 创建 FactorEngine")
    factor_engine = FactorEngine()
    # Step 2: factor manager resolving the selected factor set minus exclusions.
    print("\n[2] 创建 FactorManager")
    manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )
    # Step 3: preprocessing pipeline (winsorize -> fill nulls -> cross-sectional
    # standardize) plus the ST-stock filter and pool filter.
    print("\n[3] 创建 DataPipeline")
    data_pipeline = DataPipeline(
        factor_manager=manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=factor_engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )
    # Step 4: ranking task with the SetRank-enabled model injected via model_class.
    print("\n[4] 创建 TabMRankTaskTabMSetRankModel")
    rank_task = TabMRankTask(
        model_class=TabMSetRankModel,
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )
    # Step 5: trainer orchestrating data prep, fitting, and prediction output.
    print("\n[5] 创建 Trainer")
    runner = Trainer(
        data_pipeline=data_pipeline,
        task=rank_task,
        output_config=output_config,
        verbose=True,
    )
    # Step 6: execute the training run over the configured date ranges.
    print("\n[6] 执行训练")
    run_results = runner.run(engine=factor_engine, date_range=date_range)
    # Step 7: optionally persist the model together with its factor metadata.
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=rank_task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=data_pipeline.get_fitted_processors(),
        )
    print("\n" + "=" * 80)
    print("训练流程完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_setrank_rank_output.csv')}")
    print("=" * 80)
    return run_results
if __name__ == "__main__":
    main()

View File

@@ -12,6 +12,7 @@ from src.training.components.models.tabm_rank_model import (
EnsembleListNetLoss,
EnsembleLambdaLoss,
)
from src.training.components.models.tabm_setrank_model import TabMSetRankModel
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
@@ -21,6 +22,7 @@ __all__ = [
"TabPFNModel",
"TabMModel",
"TabMRankModel",
"TabMSetRankModel",
"EnsembleListNetLoss",
"EnsembleLambdaLoss",
"CrossSectionSampler",

View File

@@ -0,0 +1,489 @@
"""TabM + SetRank 排序模型实现 (TabM_SetRank)
基于 TabM 特征提取 + SetRank 组内注意力建模。
引入 ListNet/LambdaLoss 列表级排序损失,支持 AMP 混合精度与显存优化。
适用于股票未来收益率的截面排序预测。
"""
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import pickle
import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Sampler
from torch.cuda.amp import GradScaler
from tabm import TabM
from src.training.components.base import BaseModel
from src.training.registry import register_model
from src.training.components.models.tabm_rank_model import (
EnsembleListNetLoss,
EnsembleLambdaLoss,
GroupSampler,
)
class SetRankHead(nn.Module):
    """Standard SetRank in-group contextual attention head.

    Input:  ``[N, E]`` ensemble logits from TabM for one day's cross-section.
    Output: ``[N, E]`` logits corrected by cross-stock interaction.

    The cross-section is an unordered set, so no positional encoding is used —
    the head stays permutation-equivariant.
    """
    def __init__(
        self,
        d_in: int,  # TabM ensemble size (e.g. 32)
        n_heads: int = 4,
        d_ff: int = 128,  # hidden feature width of the SetRank block
        dropout: float = 0.1,
    ):
        super().__init__()
        # Lift the per-stock logits into a richer semantic feature space.
        self.feature_proj = nn.Linear(d_in, d_ff)
        # Multi-head self-attention across the stocks of one day (SetRank core).
        self.attn = nn.MultiheadAttention(
            embed_dim=d_ff, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_ff)
        # Standard Transformer position-wise feed-forward block.
        self.ffn = nn.Sequential(
            nn.Linear(d_ff, d_ff * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff * 2, d_ff),
        )
        self.norm2 = nn.LayerNorm(d_ff)
        # Project back to the ensemble dimension: [N, d_ff] -> [N, E].
        self.output_proj = nn.Linear(d_ff, d_in)
        self._init_weights()
    def _init_weights(self):
        """Xavier-init every Linear weight; unit/zero-init every LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one attention + FFN block and return ``x`` plus a correction term."""
        # [N, E] -> [1, N, d_ff]: a batch of one cross-section, sequence of N stocks.
        feats = self.feature_proj(x).unsqueeze(0)
        attn_out, _ = self.attn(feats, feats, feats)
        feats = self.norm1(feats + attn_out)         # residual block 1 (attention)
        feats = self.norm2(feats + self.ffn(feats))  # residual block 2 (FFN)
        # Drop the batch dim and map back to the ensemble dimension: [N, E].
        correction = self.output_proj(feats.squeeze(0))
        # Residual connection: keep TabM's raw scores, SetRank only corrects them
        # (greatly improves training stability).
        return x + correction
class TabMSetRankNet(nn.Module):
    """TabM backbone followed by a SetRank cross-sectional correction head.

    ``forward`` maps numeric features ``[N, F]`` to ensemble ranking
    scores ``[N, E]`` where ``E`` is the ensemble size.
    """
    def __init__(
        self,
        n_features: int,
        ensemble_size: int,
        n_blocks: int,
        d_block: int,
        dropout: float,
        setrank_heads: int,
        setrank_hidden: int,
        setrank_dropout: float,
    ):
        super().__init__()
        # TabM ensemble feature extractor: k parallel members, scalar output each.
        self.tabm = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=n_blocks,
            d_block=d_block,
            dropout=dropout,
            k=ensemble_size,
        )
        # Cross-sectional attention corrector applied to the ensemble logits.
        self.setrank_head = SetRankHead(
            d_in=ensemble_size,
            n_heads=setrank_heads,
            d_ff=setrank_hidden,
            dropout=setrank_dropout,
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract per-stock ensemble logits, then apply the SetRank correction."""
        ensemble_logits = self.tabm(x).squeeze(-1)  # [N, E]
        return self.setrank_head(ensemble_logits)   # [N, E]
@register_model("tabm_setrank_rank")
class TabMSetRankModel(BaseModel):
    """TabM + SetRank learning-to-rank model.

    Wraps TabM and SetRankHead in a single nn.Module (TabMSetRankNet).
    Supports AMP mixed precision and memory optimizations.
    """
    name = "tabm_setrank_rank"
    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Build the network eagerly.

        Args:
            params: hyper-parameter dict; must contain 'n_features' because the
                network is constructed here rather than lazily in fit().
        Raises:
            ValueError: if 'n_features' is missing from params.
        """
        self.params = params or {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Per-epoch curves appended during fit(); exposed via get_evals_result().
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None
        self.use_amp = self.params.get("use_amp", True)
        self.scaler = GradScaler(enabled=self.use_amp)
        # Select the ensemble listwise ranking loss (defaults to plain ListNet).
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)
        n_features = self.params.get("n_features")
        if n_features is None:
            raise ValueError("TabMSetRankModel 初始化需要在 params 中传入 'n_features'")
        ensemble_size = self.params.get("ensemble_size", 32)
        self.model = TabMSetRankNet(
            n_features=n_features,
            ensemble_size=ensemble_size,
            n_blocks=self.params.get("n_blocks", 3),
            d_block=self.params.get("d_block", 256),
            dropout=self.params.get("dropout", 0.1),
            setrank_heads=self.params.get("setrank_heads", 4),
            setrank_hidden=self.params.get("setrank_hidden", 128),
            setrank_dropout=self.params.get("setrank_dropout", 0.1),
        ).to(self.device)
    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader (supports packing cross-sections as query groups).

        When `group` is given, each batch is one whole group (one day's
        cross-section) via GroupSampler; otherwise a plain fixed-size loader.
        """
        X_tensor = torch.from_numpy(X)
        if y is not None:
            y_tensor = torch.from_numpy(y)
            dataset = TensorDataset(X_tensor, y_tensor)
        else:
            dataset = TensorDataset(X_tensor)
        if group is not None:
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(
                dataset,
                batch_sampler=sampler,
                num_workers=self.params.get("num_workers", 0),
                pin_memory=self.params.get("pin_memory", False),
            )
        else:
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=self.params.get("num_workers", 0),
                pin_memory=self.params.get("pin_memory", False),
            )
    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model using per-group NDCG@k.

        Returns the mean NDCG over scorable groups, or 0.0 if none.
        """
        from sklearn.metrics import ndcg_score
        assert self.model is not None, "模型未训练,无法验证"
        self.model.eval()
        ndcg_list = []
        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()
                # NDCG is undefined for a single-element group.
                if len(by) <= 1:
                    continue
                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]
                # Average over the ensemble dimension to get the final ranking score [N].
                preds = scores.mean(dim=1).cpu().numpy()
                try:
                    ndcg_list.append(ndcg_score([by], [preds], k=k))
                except ValueError:
                    # e.g. sklearn rejects negative relevance; skip the group.
                    pass
        return float(np.mean(ndcg_list)) if ndcg_list else 0.0
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[Tuple] = None,
    ) -> "TabMSetRankModel":
        """Train the ranking model.

        Args:
            X: training features, one row per sample.
            y: training labels (quantile/gain targets).
            group: per-group sample counts; must sum to len(y). Defaults to a
                single group spanning the whole training set.
            eval_set: optional (X_val, y_val, group_val) enabling NDCG-based
                early stopping and best-weight restoration.
        Returns:
            self, with the best validation weights restored when available.
        Raises:
            ValueError: if group sizes don't sum to len(y), or the input
                feature width disagrees with the configured 'n_features'.
        """
        self.feature_names_ = list(X.columns)
        X_np = X.to_numpy().astype(np.float32)
        y_np = y.to_numpy().astype(np.float32)
        if group is None:
            # Treat the whole training set as one query group.
            group = np.array([len(y_np)])
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )
        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)
        val_loader = None
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = (
                X_val.to_numpy().astype(np.float32)
                if isinstance(X_val, pl.DataFrame)
                else X_val
            )
            y_val_np = (
                y_val.to_numpy().astype(np.float32)
                if isinstance(y_val, pl.Series)
                else y_val
            )
            if group_val is None:
                group_val = np.array([len(y_val_np)])
            val_loader = self._make_loader(
                X_val_np, y_val_np, group=group_val, shuffle_groups=False
            )
        # The network was sized at construction time; guard against a mismatch.
        n_features = X_np.shape[1]
        expected_n_features = self.params.get("n_features")
        if expected_n_features is not None and n_features != expected_n_features:
            raise ValueError(
                f"输入特征维度 ({n_features}) 与初始化时指定的 n_features ({expected_n_features}) 不一致"
            )
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.params.get("learning_rate", 1e-3),
            weight_decay=self.params.get("weight_decay", 1e-5),
        )
        epochs = self.params.get("epochs", 50)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-6
        )
        # "early_stopping_patience" takes priority; falls back to "early_stopping_round".
        early_stopping_patience = self.params.get(
            "early_stopping_patience",
            self.params.get("early_stopping_round", 10),
        )
        best_val_ndcg = -float("inf")
        patience_counter = 0
        best_model_state = None
        ndcg_k = self.params.get("ndcg_k", None)
        print(f"[TabM+SetRank] 开始训练... 设备: {self.device}, AMP: {self.use_amp}")
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            n_train_batches = 0
            for batch in train_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch[0].to(self.device), batch[1].to(self.device)
                optimizer.zero_grad()
                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]
                    loss = self.criterion(scores, by)
                # AMP: scale the loss for backward, then unscale before clipping.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.params.get("max_grad_norm", 1.0),
                )
                self.scaler.step(optimizer)
                self.scaler.update()
                train_loss += loss.item()
                n_train_batches += 1
            avg_train_loss = train_loss / max(n_train_batches, 1)
            self.training_history_["train_loss"].append(avg_train_loss)
            if val_loader is not None:
                val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
                self.training_history_["val_ndcg"].append(val_ndcg)
                if val_ndcg > best_val_ndcg:
                    best_val_ndcg = val_ndcg
                    patience_counter = 0
                    # Snapshot the best weights on CPU for restoration after training.
                    best_model_state = {
                        k: v.cpu().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | "
                        f"Loss: {avg_train_loss:.4f} | "
                        f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
                    )
                if patience_counter >= early_stopping_patience:
                    print(f"[TabM+SetRank] 早停于 epoch {epoch + 1}")
                    break
            else:
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | "
                        f"Loss: {avg_train_loss:.4f}"
                    )
            scheduler.step()
        if best_model_state:
            self.model.load_state_dict(best_model_state)
            print(f"[TabM+SetRank] 恢复最佳权重 (Val NDCG: {best_val_ndcg:.4f})")
        return self
    def predict(
        self, X: pl.DataFrame, group: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Predict ranking scores, one per row of X (ensemble-mean score)."""
        if self.model is None:
            raise RuntimeError("模型未训练请先调用fit()")
        if self.feature_names_:
            # Enforce the training-time feature set and column order.
            missing = [c for c in self.feature_names_ if c not in X.columns]
            if missing:
                raise ValueError(f"缺失特征: {missing}")
            X = X.select(self.feature_names_)
        X_np = X.to_numpy().astype(np.float32)
        loader = self._make_loader(X_np, group=group, shuffle_groups=False)
        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in loader:
                bx = batch[0].to(self.device)
                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]
                # Average over the ensemble dimension to get the final ranking score [N].
                preds = scores.mean(dim=1)
                all_preds.append(preds.cpu().numpy())
        return np.concatenate(all_preds)
    def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
        """Return the per-epoch training history (train_loss, val_ndcg)."""
        return self.training_history_
    def feature_importance(self) -> None:
        """TabM has no built-in feature-importance computation; returns None."""
        return None
    def save(self, path: str | Path) -> None:
        """Save weights to `<path>.pt` and metadata (params, feature names,
        history, device) to `<path>.meta` via pickle."""
        if self.model is None:
            raise RuntimeError("模型未训练,无法保存")
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), path.with_suffix(".pt"))
        with open(path.with_suffix(".meta"), "wb") as f:
            pickle.dump(
                {
                    "params": self.params,
                    "feature_names": self.feature_names_,
                    "training_history": self.training_history_,
                    "device": str(self.device),
                },
                f,
            )
    @classmethod
    def load(cls, path: str | Path) -> "TabMSetRankModel":
        """Load a model previously written by save().

        Rebuilds the network from the pickled metadata (feature names fix the
        input width), then restores the saved state dict.
        """
        path = Path(path)
        with open(path.with_suffix(".meta"), "rb") as f:
            meta = pickle.load(f)
        meta_params = meta["params"]
        feature_names = meta["feature_names"]
        assert feature_names is not None, "加载的模型缺少 feature_names无法重建网络"
        # Reconstruct with exactly the trained feature width.
        meta_params["n_features"] = len(feature_names)
        instance = cls(meta_params)
        instance.feature_names_ = feature_names
        instance.training_history_ = meta["training_history"]
        ckpt = torch.load(path.with_suffix(".pt"), map_location=instance.device)
        instance.model.load_state_dict(ckpt)
        return instance
    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate the mean NDCG@k across the given query groups."""
        from sklearn.metrics import ndcg_score
        y_pred = self.predict(X)
        y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y
        # Split the flat label/score arrays back into per-group slices.
        y_true_list = []
        y_score_list = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y_true_array[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx
        ndcg_scores = []
        for yt, yp in zip(y_true_list, y_score_list):
            if len(yt) > 1:
                try:
                    ndcg_scores.append(ndcg_score([yt], [yp], k=k))
                except ValueError:
                    pass
        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

View File

@@ -7,7 +7,7 @@
- 支持 NDCG@k 评估
"""
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Type
import numpy as np
import polars as pl
@@ -18,7 +18,7 @@ from src.training.components.models.tabm_rank_model import TabMRankModel
class TabMRankTask(BaseTask):
"""TabM 排序学习任务
使用 TabMRankModel 进行排序学习训练。
使用 TabMRankModel(或兼容模型)进行排序学习训练。
将连续收益率转换为分位数标签进行训练。
支持指数化增益标签以增强 Top-K 关注。
"""
@@ -30,6 +30,7 @@ class TabMRankTask(BaseTask):
n_quantiles: int = 20,
label_transform: Optional[str] = None,
label_scale: float = 20.0,
model_class: Optional[Type] = None,
):
"""初始化排序学习任务
@@ -41,11 +42,14 @@ class TabMRankTask(BaseTask):
- None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
- "exponential": 指数化增益: 2^(rank/scale) - 1
label_scale: 指数变换的缩放因子,用于控制增益幅度
model_class: 模型类,默认 TabMRankModel。可传入兼容接口的其他模型
如 TabMSetRankModel。
"""
super().__init__(model_params, label_name)
self.n_quantiles = n_quantiles
self.label_transform = label_transform
self.label_scale = label_scale
self.model_class = model_class or TabMRankModel
def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
"""准备标签(转换为分位数标签,可选指数化增益变换)
@@ -135,7 +139,8 @@ class TabMRankTask(BaseTask):
train_data: 训练数据
val_data: 验证数据
"""
self.model = TabMRankModel(params=self.model_params)
self.model_params["n_features"] = train_data["X"].shape[1]
self.model = self.model_class(params=self.model_params)
self.model.fit(
train_data["X"],