feat(training): TabM 排序模型架构优化与 Rank-Gauss 标签工程

- TabMSetRank: 将 TabM 输出改为隐藏层特征,经 SetRankHead 交互后通过 final_mlp 输出 Ensemble 排序分
- SetRankHead 引入可学习残差缩放因子(Zero-init)与 Pre-Norm 结构,提升训练稳定性
- TabMRankTask 新增 Rank-Gauss 连续标签变换,支持标准分位数/指数增益/Rank-Gauss 三种标签模式
- 修复 NDCG 评估在负值标签下的计算问题
- 调整实验脚本超参数(dropout、hidden dim、weight decay)及排除因子列表
- 迁移废弃的 torch.cuda.amp 到 torch.amp,并将数据预加载至 GPU 减少循环拷贝
This commit is contained in:
2026-04-05 19:01:08 +08:00
parent 598f6eefd8
commit 1fa4ff9544
7 changed files with 205 additions and 105 deletions

View File

@@ -54,9 +54,22 @@ N_QUANTILES = 20
# 排除的因子列表 # 排除的因子列表
EXCLUDED_FACTORS = [ EXCLUDED_FACTORS = [
# 'debt_to_equity', "amivest_liq_20",
# 'GTJA_alpha016', "atr_price_impact",
# 'GTJA_alpha141', "hui_heubel_ratio",
"corwin_schultz_spread_20",
"roll_spread_20",
"gibbs_effective_spread",
"overnight_illiq_20",
"illiq_volatility_20",
"amount_cv_20",
"amount_skewness_20",
"low_vol_days_20",
"liquidity_shock_momentum",
"downside_illiq_20",
"upside_illiq_20",
"illiq_asymmetry_20",
"pastor_stambaugh_proxy"
] ]

View File

@@ -52,36 +52,22 @@ TRAINING_TYPE = "regression"
# 排除的因子列表 # 排除的因子列表
EXCLUDED_FACTORS = [ EXCLUDED_FACTORS = [
# 'GTJA_alpha016', # "amivest_liq_20",
# 'volatility_20', # "atr_price_impact",
# 'current_ratio', # "hui_heubel_ratio",
# 'GTJA_alpha001', # "corwin_schultz_spread_20",
# 'GTJA_alpha141', # "roll_spread_20",
# 'GTJA_alpha129', # "gibbs_effective_spread",
# 'GTJA_alpha164', # "overnight_illiq_20",
# 'amivest_liq_20', # "illiq_volatility_20",
# 'GTJA_alpha012', # "amount_cv_20",
# 'debt_to_equity', # "amount_skewness_20",
# 'turnover_deviation', # "low_vol_days_20",
# 'GTJA_alpha073', # "liquidity_shock_momentum",
# 'GTJA_alpha043', # "downside_illiq_20",
# 'GTJA_alpha032', # "upside_illiq_20",
# 'GTJA_alpha028', # "illiq_asymmetry_20",
# 'GTJA_alpha090', # "pastor_stambaugh_proxy"
# 'GTJA_alpha108',
# 'GTJA_alpha105',
# 'GTJA_alpha091',
# 'GTJA_alpha119',
# 'GTJA_alpha104',
# 'GTJA_alpha163',
# 'GTJA_alpha157',
# 'cost_skewness',
# 'GTJA_alpha176',
# 'chip_transition',
# 'amount_skewness_20',
# 'GTJA_alpha148',
# 'mean_median_dev',
# 'downside_illiq_20',
] ]
# 模型参数配置 # 模型参数配置

View File

@@ -46,12 +46,16 @@ TRAINING_TYPE = "tabm_rank"
# Label 配置(从 common.py 统一导入) # Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入 # LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 分位数配置(提高分辨率以更好地区分头部) # 分位数配置(分桶模式下使用;Rank-Gauss 模式下不使用,但保留兼容性)
N_QUANTILES = 50 N_QUANTILES = 50
# 【Top-K 优化】标签工程配置 - 默认启用平方增益 # 标签工程配置
LABEL_TRANSFORM = "exponential" # 启用平方增益标签 (rank^2) # 可选值:
LABEL_SCALE = 20.0 # 保留参数(当前未使用,平方变换不需要缩放) # - "rank_gauss": Rank-Gauss 连续化标签(推荐,神经网络更友好)
# - "exponential": 指数化增益标签 (rank^2)
# - None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
LABEL_TRANSFORM = "rank_gauss"
LABEL_SCALE = 20.0 # 保留参数(rank_gauss / exponential 下均未使用)
# 排除的因子列表 # 排除的因子列表
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"] EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]

View File

@@ -61,7 +61,7 @@ MODEL_PARAMS = {
# ==================== MLP 结构 ==================== # ==================== MLP 结构 ====================
"n_blocks": 3, "n_blocks": 3,
"d_block": 256, "d_block": 256,
"dropout": 0.5, "dropout": 0.3,
# ==================== 集成机制 ==================== # ==================== 集成机制 ====================
"ensemble_size": 32, "ensemble_size": 32,
@@ -71,9 +71,9 @@ MODEL_PARAMS = {
"setrank_heads": 4, "setrank_heads": 4,
# 【优化1】将隐藏维度从 128 降到 64。 # 【优化1】将隐藏维度从 128 降到 64。
# 截面特征对比不需要那么宽的维度,太宽会导致模型记忆当天特有的无效噪音。 # 截面特征对比不需要那么宽的维度,太宽会导致模型记忆当天特有的无效噪音。
"setrank_hidden": 128, "setrank_hidden": 256,
# 【优化2】增大 SetRank 层的 Dropout # 【优化2】增大 SetRank 层的 Dropout
"setrank_dropout": 0.5, "setrank_dropout": 0.3,
# ==================== AMP 与显存优化 ==================== # ==================== AMP 与显存优化 ====================
"use_amp": True, "use_amp": True,
@@ -85,7 +85,7 @@ MODEL_PARAMS = {
"learning_rate": 5e-4, "learning_rate": 5e-4,
# 【优化4】核心操作将 L2 惩罚(权重衰减)放大 10 倍甚至 100 倍! # 【优化4】核心操作将 L2 惩罚(权重衰减)放大 10 倍甚至 100 倍!
# 带有 Attention 的网络极容易对某些特定股票产生依赖,强烈的 Weight Decay 能逼迫模型关注全局特征。 # 带有 Attention 的网络极容易对某些特定股票产生依赖,强烈的 Weight Decay 能逼迫模型关注全局特征。
"weight_decay": 1e-5, # 原为 1e-5现改为 1e-3 "weight_decay": 1e-4, # 原为 1e-5现改为 1e-3
"epochs": 150, # 不需要 500 次,从图中看 150 绝对够了 "epochs": 150, # 不需要 500 次,从图中看 150 绝对够了

View File

@@ -739,7 +739,12 @@ class TabMRankModel(BaseModel):
for y_true, y_score in zip(y_true_list, y_score_list): for y_true, y_score in zip(y_true_list, y_score_list):
if len(y_true) > 1: if len(y_true) > 1:
try: try:
score = ndcg_score([y_true], [y_score], k=k) # 若标签存在负值(如 Rank-Gauss),平移到正值区间再计算 NDCG
if y_true.min() < 0:
y_true_shifted = y_true - y_true.min() + 1e-5
score = ndcg_score([y_true_shifted], [y_score], k=k)
else:
score = ndcg_score([y_true], [y_score], k=k)
ndcg_scores.append(score) ndcg_scores.append(score)
except ValueError: except ValueError:
pass pass

View File

@@ -16,7 +16,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Sampler from torch.utils.data import DataLoader, TensorDataset, Sampler
from torch.cuda.amp import GradScaler
# 注:torch.cuda.amp.autocast 已被废弃,请统一使用 torch.amp.autocast('cuda', ...)
# 同时从子模块显式导入以兼容各类 IDE 的类型提示
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast
from tabm import TabM from tabm import TabM
from src.training.components.base import BaseModel from src.training.components.base import BaseModel
@@ -65,15 +69,23 @@ class SetRankHead(nn.Module):
# 4. 恢复集成维度: [N, d_ff] -> [N, E] # 4. 恢复集成维度: [N, d_ff] -> [N, E]
self.output_proj = nn.Linear(d_ff, d_in) self.output_proj = nn.Linear(d_ff, d_in)
# 5. 可学习的残差缩放因子,初始化为 0训练初期退化为纯 TabM
self.res_scale = nn.Parameter(torch.zeros(1))
self._init_weights() self._init_weights()
def _init_weights(self): def _init_weights(self):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
# Zero-init 输出层,确保训练启动时为恒等映射
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# x 初始维度: [N, E] # x 初始维度: [N, E]
@@ -85,15 +97,13 @@ class SetRankHead(nn.Module):
# B=1 (代表一天的截面), L=N (股票数量), D=d_ff # B=1 (代表一天的截面), L=N (股票数量), D=d_ff
h = h.unsqueeze(0) h = h.unsqueeze(0)
# Attention 交互 # Pre-Norm Attention 交互
attn_out, _ = self.attn(h, h, h) attn_out, _ = self.attn(self.norm1(h), self.norm1(h), self.norm1(h))
h = h + attn_out
# 残差块 1 # Pre-Norm FFN
h = self.norm1(h + attn_out) ffn_out = self.ffn(self.norm2(h))
h = h + ffn_out
# 残差块 2 (FFN)
ffn_out = self.ffn(h)
h = self.norm2(h + ffn_out)
# 降维并去除 Batch 维度: [1, N, d_ff] -> [N, d_ff] # 降维并去除 Batch 维度: [1, N, d_ff] -> [N, d_ff]
h = h.squeeze(0) h = h.squeeze(0)
@@ -101,8 +111,8 @@ class SetRankHead(nn.Module):
# 输出回 Ensemble 维度: [N, E] # 输出回 Ensemble 维度: [N, E]
out_logits = self.output_proj(h) out_logits = self.output_proj(h)
# 引入残差连接,保留原始 TabM 的预测分,仅用 SetRank 提供修正项 (极大地增加训练稳定性) # 带有可学习缩放因子的残差连接,初始阶段退化为纯 TabM
return x + out_logits return x + self.res_scale * out_logits
class TabMSetRankNet(nn.Module): class TabMSetRankNet(nn.Module):
@@ -116,32 +126,50 @@ class TabMSetRankNet(nn.Module):
setrank_heads: int, setrank_heads: int,
setrank_hidden: int, setrank_hidden: int,
setrank_dropout: float, setrank_dropout: float,
mlp_hidden: Optional[int] = None,
): ):
super().__init__() super().__init__()
# TabM 特征提取器 mlp_hidden = mlp_hidden or d_block
# TabM 特征提取器 (作为隐藏层,输出维度为 d_block)
self.tabm = TabM.make( self.tabm = TabM.make(
n_num_features=n_features, n_num_features=n_features,
cat_cardinalities=[], cat_cardinalities=[],
d_out=1, d_out=d_block,
n_blocks=n_blocks, n_blocks=n_blocks,
d_block=d_block, d_block=d_block,
dropout=dropout, dropout=dropout,
k=ensemble_size, k=ensemble_size,
) )
# 截面注意力修正器 # 截面注意力修正器 (处理 TabM 隐藏特征)
self.setrank_head = SetRankHead( self.setrank_head = SetRankHead(
d_in=ensemble_size, d_in=d_block,
n_heads=setrank_heads, n_heads=setrank_heads,
d_ff=setrank_hidden, d_ff=setrank_hidden,
dropout=setrank_dropout, dropout=setrank_dropout,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: # 最终两次 MLP 输出 Ensemble 排序分
# 1. 独立特征提取 -> [N, E] self.final_mlp = nn.Sequential(
tabm_out = self.tabm(x).squeeze(-1) nn.Linear(d_block, mlp_hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, ensemble_size),
)
# 2. 截面上下文交互修正 -> [N, E] def forward(self, x: torch.Tensor) -> torch.Tensor:
scores = self.setrank_head(tabm_out) # 1. TabM 独立特征提取 -> [N, E, d_block]
tabm_out = self.tabm(x)
# 2. 对 ensemble 维度取平均,得到原始隐藏特征 -> [N, d_block]
tabm_hidden = tabm_out.mean(dim=1)
# 3. 截面上下文交互修正 -> [N, d_block]
# SetRankHead 内部实现残差: x + res_scale * transform(x)
setrank_hidden = self.setrank_head(tabm_hidden)
# 4. 两次 MLP 输出最终 Ensemble 排序分 -> [N, E]
scores = self.final_mlp(setrank_hidden)
return scores return scores
@@ -193,6 +221,7 @@ class TabMSetRankModel(BaseModel):
setrank_heads=self.params.get("setrank_heads", 4), setrank_heads=self.params.get("setrank_heads", 4),
setrank_hidden=self.params.get("setrank_hidden", 128), setrank_hidden=self.params.get("setrank_hidden", 128),
setrank_dropout=self.params.get("setrank_dropout", 0.1), setrank_dropout=self.params.get("setrank_dropout", 0.1),
mlp_hidden=self.params.get("mlp_hidden", None),
).to(self.device) ).to(self.device)
def _make_loader( def _make_loader(
@@ -202,10 +231,13 @@ class TabMSetRankModel(BaseModel):
group: Optional[np.ndarray] = None, group: Optional[np.ndarray] = None,
shuffle_groups: bool = False, shuffle_groups: bool = False,
) -> DataLoader: ) -> DataLoader:
"""创建 DataLoader (支持 Query/Group 截面打包)""" """创建 DataLoader (支持 Query/Group 截面打包)
X_tensor = torch.from_numpy(X)
数据在一开始就直接放到 GPU,避免训练/预测循环中反复拷贝。
"""
X_tensor = torch.from_numpy(X).to(self.device)
if y is not None: if y is not None:
y_tensor = torch.from_numpy(y) y_tensor = torch.from_numpy(y).to(self.device)
dataset = TensorDataset(X_tensor, y_tensor) dataset = TensorDataset(X_tensor, y_tensor)
else: else:
dataset = TensorDataset(X_tensor) dataset = TensorDataset(X_tensor)
@@ -241,12 +273,11 @@ class TabMSetRankModel(BaseModel):
if len(batch) != 2: if len(batch) != 2:
continue continue
bx, by = batch bx, by = batch
bx = bx.to(self.device)
by = by.cpu().numpy() by = by.cpu().numpy()
if len(by) <= 1: if len(by) <= 1:
continue continue
with torch.amp.autocast("cuda", enabled=self.use_amp): with autocast("cuda", enabled=self.use_amp):
scores = self.model(bx) # [N, E] scores = self.model(bx) # [N, E]
# 对 Ensemble 维度求平均,得到最终排序分 [N] # 对 Ensemble 维度求平均,得到最终排序分 [N]
@@ -331,10 +362,10 @@ class TabMSetRankModel(BaseModel):
for batch in train_loader: for batch in train_loader:
if len(batch) != 2: if len(batch) != 2:
continue continue
bx, by = batch[0].to(self.device), batch[1].to(self.device) bx, by = batch
optimizer.zero_grad() optimizer.zero_grad()
with torch.amp.autocast("cuda", enabled=self.use_amp): with autocast("cuda", enabled=self.use_amp):
scores = self.model(bx) # [N, E] scores = self.model(bx) # [N, E]
loss = self.criterion(scores, by) loss = self.criterion(scores, by)
@@ -390,9 +421,19 @@ class TabMSetRankModel(BaseModel):
def predict( def predict(
self, X: pl.DataFrame, group: Optional[np.ndarray] = None self, X: pl.DataFrame, group: Optional[np.ndarray] = None
) -> np.ndarray: ) -> np.ndarray:
"""预测排序分数""" """预测排序分数
对于 TabM + SetRank 模型group 参数是必需的。
SetRankHead 中的截面注意力机制假设每个 batch 就是一天的所有股票,
若缺失 group 退化为普通 batch_size=2048 预测,会导致跨天样本错误交互。
"""
if self.model is None: if self.model is None:
raise RuntimeError("模型未训练请先调用fit()") raise RuntimeError("模型未训练请先调用fit()")
if group is None:
raise ValueError(
"TabMSetRankModel.predict() 需要提供 group 数组,"
"确保每个 batch 为完整的一天截面。请通过 test_data['groups'] 传入。"
)
if self.feature_names_: if self.feature_names_:
missing = [c for c in self.feature_names_ if c not in X.columns] missing = [c for c in self.feature_names_ if c not in X.columns]
if missing: if missing:
@@ -406,8 +447,8 @@ class TabMSetRankModel(BaseModel):
with torch.no_grad(): with torch.no_grad():
for batch in loader: for batch in loader:
bx = batch[0].to(self.device) bx = batch[0]
with torch.amp.autocast("cuda", enabled=self.use_amp): with autocast("cuda", enabled=self.use_amp):
scores = self.model(bx) # [N, E] scores = self.model(bx) # [N, E]
# 对 Ensemble 维度求平均,得到最终排序分 [N] # 对 Ensemble 维度求平均,得到最终排序分 [N]
@@ -468,7 +509,7 @@ class TabMSetRankModel(BaseModel):
"""评估 NDCG 指标""" """评估 NDCG 指标"""
from sklearn.metrics import ndcg_score from sklearn.metrics import ndcg_score
y_pred = self.predict(X) y_pred = self.predict(X, group=group)
y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y
y_true_list = [] y_true_list = []
y_score_list = [] y_score_list = []

View File

@@ -51,11 +51,44 @@ class TabMRankTask(BaseTask):
self.label_scale = label_scale self.label_scale = label_scale
self.model_class = model_class or TabMRankModel self.model_class = model_class or TabMRankModel
def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]: @staticmethod
"""准备标签(转换为分位数标签,可选指数化增益变换) def _rank_gauss(arr: np.ndarray) -> np.ndarray:
"""对一维数组做 Rank-Gauss 变换
连续收益率转换为分位数标签,并生成 group 数组 数组值转换为基于排名的标准正态分布值,结果近似服从 N(0, 1)
支持指数化增益变换以增强头部样本的区分度。
Args:
arr: 输入一维数组
Returns:
Rank-Gauss 变换后的数组
"""
from scipy.special import erfinv
from scipy.stats import rankdata
n = len(arr)
if n <= 1:
return np.zeros(n, dtype=np.float64)
# 1. 获取平均排名 (1-based)
ranks = rankdata(arr, method="average")
# 2. 映射到 (0, 1),使用 (ranks - 0.5) / n 避免边界
p = (ranks - 0.5) / n
# 3. 映射到 (-1, 1)
p = 2.0 * p - 1.0
# 4. 裁剪边界避免 erfinv 输入越界
eps = 1e-6
p = np.clip(p, -1.0 + eps, 1.0 - eps)
# 5. Rank-Gauss: erfinv(p) * sqrt(2)
return erfinv(p) * np.sqrt(2.0)
def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
"""准备标签(转换为分位数标签或 Rank-Gauss 连续标签)
支持三种模式:
- 标准分位数标签 (0, 1, ..., n_quantiles-1)
- 指数化增益变换 (rank^2)
- Rank-Gauss 连续标签(截面正态化,不依赖分桶数)
Args: Args:
data: 数据字典 data: 数据字典
@@ -68,40 +101,58 @@ class TabMRankTask(BaseTask):
continue continue
df = data[split]["raw_data"] df = data[split]["raw_data"]
# 分位数转换
rank_col = f"{self.label_name}_rank" rank_col = f"{self.label_name}_rank"
# 1. 基础分位数标签 (0 到 n_quantiles-1) if self.label_transform == "rank_gauss":
df_ranked = df.with_columns( # Rank-Gauss 连续化标签:先截面高斯化,再缩放到 (0, 1) 全局正区间
pl.col(self.label_name) df_ranked = df.with_columns(
.rank(method="min") pl.col(self.label_name)
.over("trade_date") .map_batches(
.alias("_rank") lambda s: pl.Series(self._rank_gauss(s.to_numpy())),
).with_columns( is_elementwise=False,
((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles) )
.floor() .map_batches(
.cast(pl.Int64) lambda s: pl.Series((s.to_numpy() + 4.0) / 8.0).clip(0.0, 1.0),
.clip(0, self.n_quantiles - 1) is_elementwise=False,
.alias("_base_rank") )
) .over("trade_date")
.alias(rank_col)
# 2. 【Top-K 优化】可选指数化增益变换
if self.label_transform == "exponential":
# 平方变换: rank^2
# 例如 rank=0 -> 0, rank=10 -> 100, rank=19 -> 361
# 效果:高分样本与低分样本的差距被平方级拉大
df_ranked = df_ranked.with_columns(
(pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col)
) )
else: else:
# 标准分位数标签 # 1. 基础分位数标签 (0 到 n_quantiles-1)
df_ranked = df_ranked.with_columns( df_ranked = df.with_columns(
pl.col("_base_rank").cast(pl.Float64).alias(rank_col) pl.col(self.label_name)
.rank(method="min")
.over("trade_date")
.alias("_rank")
).with_columns(
(
(pl.col("_rank") - 1)
/ pl.len().over("trade_date")
* self.n_quantiles
)
.floor()
.cast(pl.Int64)
.clip(0, self.n_quantiles - 1)
.alias("_base_rank")
) )
# 清理临时列 # 2. 【Top-K 优化】可选指数化增益变换
df_ranked = df_ranked.drop(["_rank", "_base_rank"]) if self.label_transform == "exponential":
# 平方变换: rank^2
# 例如 rank=0 -> 0, rank=10 -> 100, rank=19 -> 361
# 效果:高分样本与低分样本的差距被平方级拉大
df_ranked = df_ranked.with_columns(
(pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col)
)
else:
# 标准分位数标签
df_ranked = df_ranked.with_columns(
pl.col("_base_rank").cast(pl.Float64).alias(rank_col)
)
# 清理临时列
df_ranked = df_ranked.drop(["_rank", "_base_rank"])
# 更新数据 # 更新数据
data[split]["raw_data"] = df_ranked data[split]["raw_data"] = df_ranked