Files
ProStock/src/experiment/tabm_rank_train.py
liaozhaorun 1fa4ff9544 feat(training): TabM 排序模型架构优化与 Rank-Gauss 标签工程
- TabMSetRank: 将 TabM 输出改为隐藏层特征,经 SetRankHead 交互后通过 final_mlp 输出 Ensemble 排序分
- SetRankHead 引入可学习残差缩放因子(Zero-init)与 Pre-Norm 结构,提升训练稳定性
- TabMRankTask 新增 Rank-Gauss 连续标签变换,支持标准分位数/指数增益/Rank-Gauss 三种标签模式
- 修复 NDCG 评估在负值标签下的计算问题
- 调整实验脚本超参数(dropout、hidden dim、weight decay)及排除因子列表
- 迁移废弃的 torch.cuda.amp 到 torch.amp,并将数据预加载至 GPU 减少循环拷贝
2026-04-05 19:01:08 +08:00

181 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""TabM 排序学习训练流程(模块化版本)
使用新的模块化 Trainer 架构,基于 TabMRankModel 实现排序学习。
TabM 使用 ListNet 损失函数,支持集成学习。
"""
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
TRAINING_TYPE = "tabm_rank"
# %%
# Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 分位数配置分桶模式下使用Rank-Gauss 模式下不使用,但保留兼容性)
N_QUANTILES = 50
# 标签工程配置
# 可选值:
# - "rank_gauss": Rank-Gauss 连续化标签(推荐,神经网络更友好)
# - "exponential": 指数化增益标签 (rank^2)
# - None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
LABEL_TRANSFORM = "rank_gauss"
LABEL_SCALE = 20.0 # 保留参数rank_gauss / exponential 下均未使用)
# 排除的因子列表
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
# TabM Rank 模型参数配置Top-K 优化全部开启,使用 LambdaLoss
MODEL_PARAMS = {
# ==================== MLP 结构 ====================
"n_blocks": 4, # MLP 层数
"d_block": 256, # 每层神经元数
"dropout": 0.5, # Dropout 率
# ==================== 集成机制 ====================
"ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成)
# ==================== 训练参数 ====================
"learning_rate": 1e-4, # 学习率
"weight_decay": 1e-5, # 权重衰减
"epochs": 500, # 训练轮数
# ==================== 早停 ====================
"early_stopping_round": 50, # 早停耐心值
# NDCG 评估 - 关注 Top-20
"ndcg_k": 20, # 验证时计算 NDCG@20
# 【Top-K 优化】损失函数配置 - 使用 LambdaLoss
"loss_type": "lambda", # 使用 LambdaLoss 精准优化 Top-K
"lambda_sigma": 1.0, # Sigmoid 陡峭程度
"ndcg_weight_power": 1.0, # DeltaNDCG 权重幂次,>1 进一步放大头部效应
}
# 日期范围配置
date_range = {
"train": (TRAIN_START, TRAIN_END),
"val": (VAL_START, VAL_END),
"test": (TEST_START, TEST_END),
}
# 输出配置
output_config = {
"output_dir": OUTPUT_DIR,
"output_filename": "tabm_rank_output.csv",
"save_predictions": SAVE_PREDICTIONS,
"save_model": SAVE_MODEL,
"model_save_path": get_model_save_path(TRAINING_TYPE),
"top_n": TOP_N,
}
def main():
"""主函数"""
print("\n" + "=" * 80)
print("TabM 排序学习训练(模块化版本)")
print("=" * 80)
# 1. 创建 FactorEngine
print("\n[1] 创建 FactorEngine")
engine = FactorEngine()
# 2. 创建 FactorManager
print("\n[2] 创建 FactorManager")
factor_manager = FactorManager(
selected_factors=SELECTED_FACTORS,
factor_definitions=FACTOR_DEFINITIONS,
label_factor=LABEL_FACTOR,
excluded_factors=EXCLUDED_FACTORS,
)
# 3. 创建 DataPipeline
print("\n[3] 创建 DataPipeline")
pipeline = DataPipeline(
factor_manager=factor_manager,
processor_configs=[
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
(NullFiller, {"strategy": "mean"}),
(CrossSectionalStandardScaler, {}),
],
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 TabMRankTask
print("\n[4] 创建 TabMRankTask")
task = TabMRankTask(
model_params=MODEL_PARAMS,
label_name=LABEL_NAME,
n_quantiles=N_QUANTILES,
label_transform=LABEL_TRANSFORM,
label_scale=LABEL_SCALE,
)
# 5. 创建 Trainer
print("\n[5] 创建 Trainer")
trainer = Trainer(
data_pipeline=pipeline,
task=task,
output_config=output_config,
verbose=True,
)
# 6. 执行训练
print("\n[6] 执行训练")
results = trainer.run(engine=engine, date_range=date_range)
# 7. 保存模型和因子信息(如果启用)
if SAVE_MODEL:
print("\n[7] 保存模型和因子信息")
save_model_with_factors(
model=task.get_model(),
model_path=output_config["model_save_path"],
selected_factors=SELECTED_FACTORS,
factor_definitions=FACTOR_DEFINITIONS,
fitted_processors=pipeline.get_fitted_processors(),
)
print("\n" + "=" * 80)
print("训练流程完成!")
print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_rank_output.csv')}")
print("=" * 80)
return results
if __name__ == "__main__":
main()