feat(training): TabM 排序模型架构优化与 Rank-Gauss 标签工程
- TabMSetRank: 将 TabM 输出改为隐藏层特征,经 SetRankHead 交互后通过 final_mlp 输出 Ensemble 排序分
- SetRankHead 引入可学习残差缩放因子(Zero-init)与 Pre-Norm 结构,提升训练稳定性
- TabMRankTask 新增 Rank-Gauss 连续标签变换,支持标准分位数/指数增益/Rank-Gauss 三种标签模式
- 修复 NDCG 评估在负值标签下的计算问题
- 调整实验脚本超参数(dropout、hidden dim、weight decay)及排除因子列表
- 迁移废弃的 torch.cuda.amp 到 torch.amp,并将数据预加载至 GPU 减少循环拷贝
This commit is contained in:
@@ -54,9 +54,22 @@ N_QUANTILES = 20
|
||||
|
||||
# 排除的因子列表
|
||||
EXCLUDED_FACTORS = [
|
||||
# 'debt_to_equity',
|
||||
# 'GTJA_alpha016',
|
||||
# 'GTJA_alpha141',
|
||||
"amivest_liq_20",
|
||||
"atr_price_impact",
|
||||
"hui_heubel_ratio",
|
||||
"corwin_schultz_spread_20",
|
||||
"roll_spread_20",
|
||||
"gibbs_effective_spread",
|
||||
"overnight_illiq_20",
|
||||
"illiq_volatility_20",
|
||||
"amount_cv_20",
|
||||
"amount_skewness_20",
|
||||
"low_vol_days_20",
|
||||
"liquidity_shock_momentum",
|
||||
"downside_illiq_20",
|
||||
"upside_illiq_20",
|
||||
"illiq_asymmetry_20",
|
||||
"pastor_stambaugh_proxy"
|
||||
|
||||
]
|
||||
|
||||
|
||||
@@ -52,36 +52,22 @@ TRAINING_TYPE = "regression"
|
||||
|
||||
# 排除的因子列表
|
||||
EXCLUDED_FACTORS = [
|
||||
# 'GTJA_alpha016',
|
||||
# 'volatility_20',
|
||||
# 'current_ratio',
|
||||
# 'GTJA_alpha001',
|
||||
# 'GTJA_alpha141',
|
||||
# 'GTJA_alpha129',
|
||||
# 'GTJA_alpha164',
|
||||
# 'amivest_liq_20',
|
||||
# 'GTJA_alpha012',
|
||||
# 'debt_to_equity',
|
||||
# 'turnover_deviation',
|
||||
# 'GTJA_alpha073',
|
||||
# 'GTJA_alpha043',
|
||||
# 'GTJA_alpha032',
|
||||
# 'GTJA_alpha028',
|
||||
# 'GTJA_alpha090',
|
||||
# 'GTJA_alpha108',
|
||||
# 'GTJA_alpha105',
|
||||
# 'GTJA_alpha091',
|
||||
# 'GTJA_alpha119',
|
||||
# 'GTJA_alpha104',
|
||||
# 'GTJA_alpha163',
|
||||
# 'GTJA_alpha157',
|
||||
# 'cost_skewness',
|
||||
# 'GTJA_alpha176',
|
||||
# 'chip_transition',
|
||||
# 'amount_skewness_20',
|
||||
# 'GTJA_alpha148',
|
||||
# 'mean_median_dev',
|
||||
# 'downside_illiq_20',
|
||||
# "amivest_liq_20",
|
||||
# "atr_price_impact",
|
||||
# "hui_heubel_ratio",
|
||||
# "corwin_schultz_spread_20",
|
||||
# "roll_spread_20",
|
||||
# "gibbs_effective_spread",
|
||||
# "overnight_illiq_20",
|
||||
# "illiq_volatility_20",
|
||||
# "amount_cv_20",
|
||||
# "amount_skewness_20",
|
||||
# "low_vol_days_20",
|
||||
# "liquidity_shock_momentum",
|
||||
# "downside_illiq_20",
|
||||
# "upside_illiq_20",
|
||||
# "illiq_asymmetry_20",
|
||||
# "pastor_stambaugh_proxy"
|
||||
]
|
||||
|
||||
# 模型参数配置
|
||||
|
||||
@@ -46,12 +46,16 @@ TRAINING_TYPE = "tabm_rank"
|
||||
# Label 配置(从 common.py 统一导入)
|
||||
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
|
||||
|
||||
# 分位数配置(提高分辨率以更好地区分头部)
|
||||
# 分位数配置(分桶模式下使用;Rank-Gauss 模式下不使用,但保留兼容性)
|
||||
N_QUANTILES = 50
|
||||
|
||||
# 【Top-K 优化】标签工程配置 - 默认启用平方增益
|
||||
LABEL_TRANSFORM = "exponential" # 启用平方增益标签 (rank^2)
|
||||
LABEL_SCALE = 20.0 # 保留参数(当前未使用,平方变换不需要缩放)
|
||||
# 标签工程配置
|
||||
# 可选值:
|
||||
# - "rank_gauss": Rank-Gauss 连续化标签(推荐,神经网络更友好)
|
||||
# - "exponential": 指数化增益标签 (rank^2)
|
||||
# - None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
|
||||
LABEL_TRANSFORM = "rank_gauss"
|
||||
LABEL_SCALE = 20.0 # 保留参数(rank_gauss / exponential 下均未使用)
|
||||
|
||||
# 排除的因子列表
|
||||
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
|
||||
|
||||
@@ -61,7 +61,7 @@ MODEL_PARAMS = {
|
||||
# ==================== MLP 结构 ====================
|
||||
"n_blocks": 3,
|
||||
"d_block": 256,
|
||||
"dropout": 0.5,
|
||||
"dropout": 0.3,
|
||||
|
||||
# ==================== 集成机制 ====================
|
||||
"ensemble_size": 32,
|
||||
@@ -71,9 +71,9 @@ MODEL_PARAMS = {
|
||||
"setrank_heads": 4,
|
||||
# 【优化1】将隐藏维度从 128 降到 64。
|
||||
# 截面特征对比不需要那么宽的维度,太宽会导致模型记忆当天特有的无效噪音。
|
||||
"setrank_hidden": 128,
|
||||
"setrank_hidden": 256,
|
||||
# 【优化2】增大 SetRank 层的 Dropout
|
||||
"setrank_dropout": 0.5,
|
||||
"setrank_dropout": 0.3,
|
||||
|
||||
# ==================== AMP 与显存优化 ====================
|
||||
"use_amp": True,
|
||||
@@ -85,7 +85,7 @@ MODEL_PARAMS = {
|
||||
"learning_rate": 5e-4,
|
||||
# 【优化4】核心操作!将 L2 惩罚(权重衰减)放大 10 倍甚至 100 倍!
|
||||
# 带有 Attention 的网络极容易对某些特定股票产生依赖,强烈的 Weight Decay 能逼迫模型关注全局特征。
|
||||
"weight_decay": 1e-5, # 原为 1e-5,现改为 1e-3
|
||||
"weight_decay": 1e-4, # 原为 1e-5,现改为 1e-4
|
||||
|
||||
"epochs": 150, # 不需要 500 次,从图中看 150 绝对够了
|
||||
|
||||
|
||||
Reference in New Issue
Block a user