feat(training): 新增 TabM 排序学习模型支持并优化训练流程

- 新增 TabMRankModel、TabMRankTask 及配套损失函数与配置
- 将 DataQualityAnalyzer 从 experiment 迁移至 training 模块
- 调整数据处理器移除过度的 NaN/null 硬填充逻辑
- 优化 RankTask 评估指标使用分位数标签替代原始收益率
- 更新实验脚本处理器顺序与模型超参数配置
This commit is contained in:
2026-04-04 22:39:58 +08:00
parent 9e7d4241c6
commit a66d5e9db3
16 changed files with 1663 additions and 344 deletions

View File

@@ -52,36 +52,36 @@ TRAINING_TYPE = "regression"
# List of excluded factor names.
# NOTE(review): this text looks like a diff hunk whose +/- markers were
# stripped, so the active entries and the commented-out entries below are
# probably the before/after images of the same lines — confirm against the
# real file which copy is live.
EXCLUDED_FACTORS = [
'GTJA_alpha016',
'volatility_20',
'current_ratio',
'GTJA_alpha001',
'GTJA_alpha141',
'GTJA_alpha129',
'GTJA_alpha164',
'amivest_liq_20',
'GTJA_alpha012',
'debt_to_equity',
'turnover_deviation',
'GTJA_alpha073',
'GTJA_alpha043',
'GTJA_alpha032',
'GTJA_alpha028',
'GTJA_alpha090',
'GTJA_alpha108',
'GTJA_alpha105',
'GTJA_alpha091',
'GTJA_alpha119',
'GTJA_alpha104',
'GTJA_alpha163',
'GTJA_alpha157',
'cost_skewness',
'GTJA_alpha176',
'chip_transition',
'amount_skewness_20',
'GTJA_alpha148',
'mean_median_dev',
'downside_illiq_20',
# The same 30 names, in the same order, repeated as commented-out entries.
# 'GTJA_alpha016',
# 'volatility_20',
# 'current_ratio',
# 'GTJA_alpha001',
# 'GTJA_alpha141',
# 'GTJA_alpha129',
# 'GTJA_alpha164',
# 'amivest_liq_20',
# 'GTJA_alpha012',
# 'debt_to_equity',
# 'turnover_deviation',
# 'GTJA_alpha073',
# 'GTJA_alpha043',
# 'GTJA_alpha032',
# 'GTJA_alpha028',
# 'GTJA_alpha090',
# 'GTJA_alpha108',
# 'GTJA_alpha105',
# 'GTJA_alpha091',
# 'GTJA_alpha119',
# 'GTJA_alpha104',
# 'GTJA_alpha163',
# 'GTJA_alpha157',
# 'cost_skewness',
# 'GTJA_alpha176',
# 'chip_transition',
# 'amount_skewness_20',
# 'GTJA_alpha148',
# 'mean_median_dev',
# 'downside_illiq_20',
]
# 模型参数配置
@@ -153,15 +153,15 @@ def main():
pipeline = DataPipeline(
factor_manager=factor_manager,
processor_configs=[
(NullFiller, {"strategy": "mean"}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
(NullFiller, {"strategy": "mean"}),
(StandardScaler, {}),
# (CrossSectionalStandardScaler, {}),
],
label_processor_configs=[
# 对 label 进行缩尾处理(去除极端收益率)
(Winsorizer, {"lower": 0.05, "upper": 0.95}),
# (StandardScaler, {}),
(StandardScaler, {}),
],
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,