- TabMSetRank: 将 TabM 输出改为隐藏层特征,经 SetRankHead 交互后通过 final_mlp 输出 Ensemble 排序分 - SetRankHead 引入可学习残差缩放因子(Zero-init)与 Pre-Norm 结构,提升训练稳定性 - TabMRankTask 新增 Rank-Gauss 连续标签变换,支持标准分位数/指数增益/Rank-Gauss 三种标签模式 - 修复 NDCG 评估在负值标签下的计算问题 - 调整实验脚本超参数(dropout、hidden dim、weight decay)及排除因子列表 - 迁移废弃的 torch.cuda.amp 到 torch.amp,并将数据预加载至 GPU 减少循环拷贝
181 lines
5.3 KiB
Python
181 lines
5.3 KiB
Python
"""TabM 排序学习训练流程(模块化版本)
|
||
|
||
使用新的模块化 Trainer 架构,基于 TabMRankModel 实现排序学习。
|
||
TabM 支持集成学习;本脚本配置的损失函数为 LambdaLoss(loss_type="lambda"),面向 Top-K 排序优化。
|
||
"""
|
||
|
||
import os
|
||
|
||
from src.factors import FactorEngine
|
||
from src.training import (
|
||
FactorManager,
|
||
DataPipeline,
|
||
NullFiller,
|
||
Winsorizer,
|
||
CrossSectionalStandardScaler,
|
||
)
|
||
from src.training.tasks.tabm_rank_task import TabMRankTask
|
||
from src.training.core.trainer_v2 import Trainer
|
||
from src.training.components.filters import STFilter
|
||
from src.experiment.common import (
|
||
SELECTED_FACTORS,
|
||
FACTOR_DEFINITIONS,
|
||
LABEL_NAME,
|
||
LABEL_FACTOR,
|
||
TRAIN_START,
|
||
TRAIN_END,
|
||
VAL_START,
|
||
VAL_END,
|
||
TEST_START,
|
||
TEST_END,
|
||
stock_pool_filter,
|
||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||
OUTPUT_DIR,
|
||
SAVE_PREDICTIONS,
|
||
SAVE_MODEL,
|
||
get_model_save_path,
|
||
save_model_with_factors,
|
||
TOP_N,
|
||
TRAIN_SKIP_DAYS,
|
||
)
|
||
|
||
# Identifier of this training run; also passed to get_model_save_path below.
TRAINING_TYPE = "tabm_rank"

# %%
# Label configuration: LABEL_NAME and LABEL_FACTOR are bound to each other in
# common.py, so both are simply imported from there.

# Number of quantile buckets. Used by the bucketed label modes; ignored in
# Rank-Gauss mode but kept for compatibility.
N_QUANTILES = 50

# Label engineering mode. Supported values:
#   "rank_gauss"  -> Rank-Gauss continuous labels (recommended, friendlier
#                    to neural networks)
#   "exponential" -> exponential-gain labels (rank^2)
#   None          -> plain quantile labels (0, 1, ..., n_quantiles - 1)
LABEL_TRANSFORM = "rank_gauss"

# Retained parameter; unused under both "rank_gauss" and "exponential".
LABEL_SCALE = 20.0

# Factors to exclude; handed to FactorManager as `excluded_factors`.
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
|
||
|
||
# TabM Rank 模型参数配置(Top-K 优化全部开启,使用 LambdaLoss)
|
||
MODEL_PARAMS = {
|
||
# ==================== MLP 结构 ====================
|
||
"n_blocks": 4, # MLP 层数
|
||
"d_block": 256, # 每层神经元数
|
||
"dropout": 0.5, # Dropout 率
|
||
# ==================== 集成机制 ====================
|
||
"ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成)
|
||
# ==================== 训练参数 ====================
|
||
"learning_rate": 1e-4, # 学习率
|
||
"weight_decay": 1e-5, # 权重衰减
|
||
"epochs": 500, # 训练轮数
|
||
# ==================== 早停 ====================
|
||
"early_stopping_round": 50, # 早停耐心值
|
||
# NDCG 评估 - 关注 Top-20
|
||
"ndcg_k": 20, # 验证时计算 NDCG@20
|
||
# 【Top-K 优化】损失函数配置 - 使用 LambdaLoss
|
||
"loss_type": "lambda", # 使用 LambdaLoss 精准优化 Top-K
|
||
"lambda_sigma": 1.0, # Sigmoid 陡峭程度
|
||
"ndcg_weight_power": 1.0, # DeltaNDCG 权重幂次,>1 进一步放大头部效应
|
||
}
|
||
|
||
# (start, end) windows for each split, all sourced from common.py.
date_range = dict(
    train=(TRAIN_START, TRAIN_END),
    val=(VAL_START, VAL_END),
    test=(TEST_START, TEST_END),
)

# Output settings handed to Trainer; main() also reads the model save path
# from here.
output_config = dict(
    output_dir=OUTPUT_DIR,
    output_filename="tabm_rank_output.csv",
    save_predictions=SAVE_PREDICTIONS,
    save_model=SAVE_MODEL,
    model_save_path=get_model_save_path(TRAINING_TYPE),
    top_n=TOP_N,
)
|
||
|
||
|
||
def main():
    """Run the TabM learning-to-rank training workflow end to end.

    Steps:
        1. Build a FactorEngine.
        2. Build a FactorManager from SELECTED_FACTORS / FACTOR_DEFINITIONS,
           the label factor and EXCLUDED_FACTORS.
        3. Build a DataPipeline: Winsorizer(0.01, 0.99) -> NullFiller(mean)
           -> CrossSectionalStandardScaler, with an ST filter and the
           stock-pool filter.
        4. Build a TabMRankTask from MODEL_PARAMS and the label settings.
        5. Build a Trainer and run it over ``date_range``.
        6. If ``output_config["save_model"]`` is set, save the model together
           with its factor metadata and fitted processors.

    Returns:
        Whatever ``Trainer.run`` returns.
    """
    print("\n" + "=" * 80)
    print("TabM 排序学习训练(模块化版本)")
    print("=" * 80)

    # 1. Factor engine; its router is reused by the ST filter below.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Factor set definition (selected factors, label factor, exclusions).
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Preprocessing chain; processors are applied in the listed order.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Ranking task (model hyper-parameters + label engineering).
    print("\n[4] 创建 TabMRankTask")
    task = TabMRankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )

    # 5. Trainer wiring pipeline, task and output settings together.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Train / validate / test.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Persist the model together with factor metadata, if enabled.
    #    output_config["save_model"] is SAVE_MODEL; read from the config so
    #    there is a single source of truth.
    if output_config["save_model"]:
        print("\n[7] 保存模型和因子信息")
        # NOTE(review): selected_factors is the unfiltered SELECTED_FACTORS,
        # while FactorManager was built with EXCLUDED_FACTORS. If the saved
        # factor list is used to rebuild model inputs at inference time,
        # confirm the excluded factors are handled there.
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # Derive the reported path from output_config instead of re-hard-coding
    # the directory/filename, so the message cannot drift from the config.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )

    print("\n" + "=" * 80)
    print("训练流程完成!")
    print(f"结果保存路径: {result_path}")
    print("=" * 80)

    return results


if __name__ == "__main__":
    main()
|