feat(training): 新增 TabM 模型支持及数据质量优化

- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器对 NaN/null 的双重处理,增强数据鲁棒性
- 支持通过 train_skip_days 参数跳过训练初期数据不足的阶段
- Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
2026-03-31 23:11:21 +08:00
parent 9e0114c745
commit 36a3ccbcc8
22 changed files with 4421 additions and 204 deletions

View File

@@ -38,6 +38,7 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
@@ -51,55 +52,55 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
'GTJA_alpha036',
'GTJA_alpha032',
'GTJA_alpha010',
'GTJA_alpha005',
'CP',
'BP',
'debt_to_equity',
'current_ratio',
'GTJA_alpha002',
'GTJA_alpha027',
'GTJA_alpha064',
'GTJA_alpha062',
'GTJA_alpha043',
'GTJA_alpha044',
'GTJA_alpha120',
'GTJA_alpha117',
'GTJA_alpha103',
'GTJA_alpha104',
'GTJA_alpha105',
'GTJA_alpha073',
'GTJA_alpha077',
'GTJA_alpha085',
'GTJA_alpha090',
'GTJA_alpha087',
'GTJA_alpha083',
'GTJA_alpha092',
'GTJA_alpha133',
'GTJA_alpha131',
'GTJA_alpha126',
'GTJA_alpha124',
'GTJA_alpha162',
'GTJA_alpha164',
'GTJA_alpha157',
'GTJA_alpha177',
'price_to_avg_cost',
'cost_skewness',
'GTJA_alpha191',
'GTJA_alpha180',
'history_position',
'bottom_profit',
'mean_median_dev',
'smart_money_accumulation',
'GTJA_alpha013',
'GTJA_alpha099',
'GTJA_alpha107',
'GTJA_alpha119',
'GTJA_alpha141',
'GTJA_alpha130',
'GTJA_alpha173',
"GTJA_alpha036",
"GTJA_alpha032",
"GTJA_alpha010",
"GTJA_alpha005",
"CP",
"BP",
"debt_to_equity",
"current_ratio",
"GTJA_alpha002",
"GTJA_alpha027",
"GTJA_alpha064",
"GTJA_alpha062",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha120",
"GTJA_alpha117",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha073",
"GTJA_alpha077",
"GTJA_alpha085",
"GTJA_alpha090",
"GTJA_alpha087",
"GTJA_alpha083",
"GTJA_alpha092",
"GTJA_alpha133",
"GTJA_alpha131",
"GTJA_alpha126",
"GTJA_alpha124",
"GTJA_alpha162",
"GTJA_alpha164",
"GTJA_alpha157",
"GTJA_alpha177",
"price_to_avg_cost",
"cost_skewness",
"GTJA_alpha191",
"GTJA_alpha180",
"history_position",
"bottom_profit",
"mean_median_dev",
"smart_money_accumulation",
"GTJA_alpha013",
"GTJA_alpha099",
"GTJA_alpha107",
"GTJA_alpha119",
"GTJA_alpha141",
"GTJA_alpha130",
"GTJA_alpha173",
]
# 模型参数配置
@@ -184,6 +185,7 @@ def main():
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 RegressionTask