feat(training): 新增 TabM SetRank 模型并支持任务注入
- 添加 TabMSetRankModel 实现集合排序训练 - TabMRankTask 支持通过 model_class 注入兼容模型 - 启用 common.py 中的流动性因子
This commit is contained in:
@@ -271,22 +271,22 @@ SELECTED_FACTORS = [
|
||||
"pivot_reversion",
|
||||
"chip_transition",
|
||||
|
||||
# "amivest_liq_20",
|
||||
# "atr_price_impact",
|
||||
# "hui_heubel_ratio",
|
||||
# "corwin_schultz_spread_20",
|
||||
# "roll_spread_20",
|
||||
# "gibbs_effective_spread",
|
||||
# "overnight_illiq_20",
|
||||
# "illiq_volatility_20",
|
||||
# "amount_cv_20",
|
||||
# "amount_skewness_20",
|
||||
# "low_vol_days_20",
|
||||
# "liquidity_shock_momentum",
|
||||
# "downside_illiq_20",
|
||||
# "upside_illiq_20",
|
||||
# "illiq_asymmetry_20",
|
||||
# "pastor_stambaugh_proxy"
|
||||
"amivest_liq_20",
|
||||
"atr_price_impact",
|
||||
"hui_heubel_ratio",
|
||||
"corwin_schultz_spread_20",
|
||||
"roll_spread_20",
|
||||
"gibbs_effective_spread",
|
||||
"overnight_illiq_20",
|
||||
"illiq_volatility_20",
|
||||
"amount_cv_20",
|
||||
"amount_skewness_20",
|
||||
"low_vol_days_20",
|
||||
"liquidity_shock_momentum",
|
||||
"downside_illiq_20",
|
||||
"upside_illiq_20",
|
||||
"illiq_asymmetry_20",
|
||||
"pastor_stambaugh_proxy"
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||
|
||||
201
src/experiment/tabm_setrank_train.py
Normal file
201
src/experiment/tabm_setrank_train.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""TabM + SetRank 排序学习训练流程

使用模块化 Trainer 架构,基于 TabMSetRankModel 实现排序学习。
引入 SetRank 组内注意力头,其余配置与 tabm_rank_train.py 对齐。
"""

import os

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    NullFiller,
    Winsorizer,
    CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabMSetRankModel
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    TRAIN_START,
    TRAIN_END,
    VAL_START,
    VAL_END,
    TEST_START,
    TEST_END,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
    OUTPUT_DIR,
    SAVE_PREDICTIONS,
    SAVE_MODEL,
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
    TRAIN_SKIP_DAYS,
)

# Training-type identifier; handed to get_model_save_path() when the output
# configuration is built further down in this module.
TRAINING_TYPE = "tabm_setrank_rank"

# %%
# Label configuration: LABEL_NAME and LABEL_FACTOR are imported from common.py.

# Number of quantile buckets for the label (raised for finer resolution at the
# head of the ranking, per the original tuning note).
N_QUANTILES = 50

# [Top-K tuning] Label-engineering settings.
# NOTE(review): the original comment described this as enabling a squared-gain
# label (rank^2), yet the configured transform name is "exponential" -- confirm
# which transform TabMRankTask actually applies for this value.
LABEL_TRANSFORM = "exponential"
# NOTE(review): the original comment called this a retained-but-unused
# parameter on the grounds that the squared transform needs no scaling; with
# "exponential" selected, verify whether the task reads it.
LABEL_SCALE = 20.0

# Factor names passed to FactorManager as ``excluded_factors``.
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]

# TabM + SetRank model hyper-parameters; passed to TabMRankTask as
# ``model_params``. Several of the original tuning notes described values that
# differ from what is configured -- those spots carry a NOTE(review) so the
# intended value can be confirmed instead of silently guessed.
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,
    "d_block": 256,
    "dropout": 0.5,

    # ==================== Ensemble ====================
    "ensemble_size": 32,

    # ==================== SetRank head ====================
    "use_setrank": True,
    "setrank_heads": 4,
    # [Tuning 1] Hidden width of the SetRank head.
    # NOTE(review): the original note said the width was reduced from 128 to 64
    # (arguing that an overly wide head memorises day-specific cross-sectional
    # noise), but the configured value is still 128 -- confirm which is meant.
    "setrank_hidden": 128,
    # [Tuning 2] Dropout applied in the SetRank layers (raised per the
    # original note).
    "setrank_dropout": 0.5,

    # ==================== AMP and memory settings ====================
    "use_amp": True,
    "num_workers": 0,
    "pin_memory": False,

    # ==================== Training (regularisation) ====================
    # [Tuning 3] Slightly lowered learning rate, intended to reduce
    # oscillation close to the optimum.
    "learning_rate": 5e-4,
    # [Tuning 4] Weight decay (L2 penalty). The original note argued for a
    # 10x-100x increase because attention-based networks tend to latch onto
    # individual stocks.
    # NOTE(review): that note claimed the value was changed from 1e-5 to 1e-3,
    # but 1e-5 is what is configured here -- confirm the intended value.
    "weight_decay": 1e-5,

    # 150 epochs judged sufficient from the training curves (instead of 500).
    "epochs": 150,

    # ==================== Early stopping ====================
    "early_stopping_round": 30,  # patience of 30 rounds

    # ==================== NDCG evaluation ====================
    "ndcg_k": 20,

    # ==================== Loss configuration ====================
    "loss_type": "lambda",
    "lambda_sigma": 1.0,
    # [Tuning 5] Exponent applied to the DeltaNDCG weight.
    # NOTE(review): the original note spoke of raising this slightly so that
    # mis-ranking the top-5 stocks is penalised harder, but the value is 1.0 --
    # confirm whether a larger exponent was intended.
    "ndcg_weight_power": 1.0,
}

# Date ranges for each split as (start, end) pairs; the boundaries come from
# common.py. Passed to Trainer.run() in main().
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}

# Output settings handed to the Trainer. main() also reads
# "model_save_path" from here when saving the model.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_setrank_rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}


def main():
    """Run the TabM + SetRank learning-to-rank training workflow.

    Builds, in order, a FactorEngine, a FactorManager, a DataPipeline, a
    TabMRankTask (with TabMSetRankModel injected through ``model_class``) and
    a Trainer, then runs training over the module-level ``date_range``. When
    ``SAVE_MODEL`` is truthy, the trained model is saved together with the
    factor selection/definitions and the pipeline's fitted processors.

    Returns:
        The value returned by ``Trainer.run`` for this run.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("TabM + SetRank 排序学习训练")
    print(banner)

    # 1. Factor engine; its router is also given to the ST filter below.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Factor manager: selected factors, their definitions, the label
    #    factor and the exclusion list.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Data pipeline. Processors are configured as (class, kwargs) pairs --
    #    presumably applied in list order; verify against DataPipeline.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Ranking task with TabMSetRankModel injected as the model class.
    print("\n[4] 创建 TabMRankTask(TabMSetRankModel)")
    task = TabMRankTask(
        model_class=TabMSetRankModel,
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )

    # 5. Trainer wiring the pipeline, the task and the output settings.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor information (only when enabled).
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # Build the reported path from output_config rather than repeating the
    # directory and file name, so the message cannot drift from the
    # configuration given to the Trainer.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )
    print("\n" + banner)
    print("训练流程完成!")
    print(f"结果保存路径: {result_path}")
    print(banner)

    return results


# Script entry point: run the training workflow only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
Reference in New Issue
Block a user