feat(training): 新增 TabM 排序学习模型支持并优化训练流程
- 新增 TabMRankModel、TabMRankTask 及配套损失函数与配置
- 将 DataQualityAnalyzer 从 experiment 迁移至 training 模块
- 调整数据处理器，移除过度的 NaN/null 硬填充逻辑
- 优化 RankTask 评估指标，使用分位数标签替代原始收益率
- 更新实验脚本处理器顺序与模型超参数配置
This commit is contained in:
@@ -52,36 +52,36 @@ TRAINING_TYPE = "regression"
|
||||
|
||||
# 排除的因子列表
|
||||
EXCLUDED_FACTORS = [
|
||||
'GTJA_alpha016',
|
||||
'volatility_20',
|
||||
'current_ratio',
|
||||
'GTJA_alpha001',
|
||||
'GTJA_alpha141',
|
||||
'GTJA_alpha129',
|
||||
'GTJA_alpha164',
|
||||
'amivest_liq_20',
|
||||
'GTJA_alpha012',
|
||||
'debt_to_equity',
|
||||
'turnover_deviation',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha043',
|
||||
'GTJA_alpha032',
|
||||
'GTJA_alpha028',
|
||||
'GTJA_alpha090',
|
||||
'GTJA_alpha108',
|
||||
'GTJA_alpha105',
|
||||
'GTJA_alpha091',
|
||||
'GTJA_alpha119',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha163',
|
||||
'GTJA_alpha157',
|
||||
'cost_skewness',
|
||||
'GTJA_alpha176',
|
||||
'chip_transition',
|
||||
'amount_skewness_20',
|
||||
'GTJA_alpha148',
|
||||
'mean_median_dev',
|
||||
'downside_illiq_20',
|
||||
# 'GTJA_alpha016',
|
||||
# 'volatility_20',
|
||||
# 'current_ratio',
|
||||
# 'GTJA_alpha001',
|
||||
# 'GTJA_alpha141',
|
||||
# 'GTJA_alpha129',
|
||||
# 'GTJA_alpha164',
|
||||
# 'amivest_liq_20',
|
||||
# 'GTJA_alpha012',
|
||||
# 'debt_to_equity',
|
||||
# 'turnover_deviation',
|
||||
# 'GTJA_alpha073',
|
||||
# 'GTJA_alpha043',
|
||||
# 'GTJA_alpha032',
|
||||
# 'GTJA_alpha028',
|
||||
# 'GTJA_alpha090',
|
||||
# 'GTJA_alpha108',
|
||||
# 'GTJA_alpha105',
|
||||
# 'GTJA_alpha091',
|
||||
# 'GTJA_alpha119',
|
||||
# 'GTJA_alpha104',
|
||||
# 'GTJA_alpha163',
|
||||
# 'GTJA_alpha157',
|
||||
# 'cost_skewness',
|
||||
# 'GTJA_alpha176',
|
||||
# 'chip_transition',
|
||||
# 'amount_skewness_20',
|
||||
# 'GTJA_alpha148',
|
||||
# 'mean_median_dev',
|
||||
# 'downside_illiq_20',
|
||||
]
|
||||
|
||||
# 模型参数配置
|
||||
@@ -153,15 +153,15 @@ def main():
|
||||
pipeline = DataPipeline(
|
||||
factor_manager=factor_manager,
|
||||
processor_configs=[
|
||||
(NullFiller, {"strategy": "mean"}),
|
||||
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
|
||||
(NullFiller, {"strategy": "mean"}),
|
||||
(StandardScaler, {}),
|
||||
# (CrossSectionalStandardScaler, {}),
|
||||
],
|
||||
label_processor_configs=[
|
||||
# 对 label 进行缩尾处理(去除极端收益率)
|
||||
(Winsorizer, {"lower": 0.05, "upper": 0.95}),
|
||||
# (StandardScaler, {}),
|
||||
(StandardScaler, {}),
|
||||
],
|
||||
filters=[STFilter(data_router=engine.router)],
|
||||
stock_pool_filter_func=stock_pool_filter,
|
||||
|
||||
Reference in New Issue
Block a user