feat(training): TabM 排序模型架构优化与 Rank-Gauss 标签工程

- TabMSetRank: 将 TabM 输出改为隐藏层特征,经 SetRankHead 交互后通过 final_mlp 输出 Ensemble 排序分
- SetRankHead 引入可学习残差缩放因子(Zero-init)与 Pre-Norm 结构,提升训练稳定性
- TabMRankTask 新增 Rank-Gauss 连续标签变换,支持标准分位数/指数增益/Rank-Gauss 三种标签模式
- 修复 NDCG 评估在负值标签下的计算问题
- 调整实验脚本超参数(dropout、hidden dim、weight decay)及排除因子列表
- 迁移废弃的 torch.cuda.amp 到 torch.amp,并将数据预加载至 GPU 减少循环拷贝
This commit is contained in:
2026-04-05 19:01:08 +08:00
parent 598f6eefd8
commit 1fa4ff9544
7 changed files with 205 additions and 105 deletions

View File

@@ -52,36 +52,22 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
# 'GTJA_alpha016',
# 'volatility_20',
# 'current_ratio',
# 'GTJA_alpha001',
# 'GTJA_alpha141',
# 'GTJA_alpha129',
# 'GTJA_alpha164',
# 'amivest_liq_20',
# 'GTJA_alpha012',
# 'debt_to_equity',
# 'turnover_deviation',
# 'GTJA_alpha073',
# 'GTJA_alpha043',
# 'GTJA_alpha032',
# 'GTJA_alpha028',
# 'GTJA_alpha090',
# 'GTJA_alpha108',
# 'GTJA_alpha105',
# 'GTJA_alpha091',
# 'GTJA_alpha119',
# 'GTJA_alpha104',
# 'GTJA_alpha163',
# 'GTJA_alpha157',
# 'cost_skewness',
# 'GTJA_alpha176',
# 'chip_transition',
# 'amount_skewness_20',
# 'GTJA_alpha148',
# 'mean_median_dev',
# 'downside_illiq_20',
# "amivest_liq_20",
# "atr_price_impact",
# "hui_heubel_ratio",
# "corwin_schultz_spread_20",
# "roll_spread_20",
# "gibbs_effective_spread",
# "overnight_illiq_20",
# "illiq_volatility_20",
# "amount_cv_20",
# "amount_skewness_20",
# "low_vol_days_20",
# "liquidity_shock_momentum",
# "downside_illiq_20",
# "upside_illiq_20",
# "illiq_asymmetry_20",
# "pastor_stambaugh_proxy"
]
# 模型参数配置