feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现 - 新增 DataQualityAnalyzer 进行训练前数据质量诊断 - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性 - 支持 train_skip_days 参数跳过训练初期数据不足期 - Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
425
src/experiment/tabpfn_regression.py
Normal file
425
src/experiment/tabpfn_regression.py
Normal file
@@ -0,0 +1,425 @@
# %% md
# # TabPFN regression training pipeline
#
# Uses TabPFN (Prior-Data Fitted Network) for regression prediction.
# TabPFN predicts via in-context learning — no conventional
# gradient-descent training loop is required.
# %% md
# ## 1. Imports
# %%
import os

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    NullFiller,
    Winsorizer,
    StandardScaler,
    CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabPFNModel
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    TRAIN_START,
    TRAIN_END,
    VAL_START,
    VAL_END,
    TEST_START,
    TEST_END,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
    OUTPUT_DIR,
    SAVE_PREDICTIONS,
    SAVE_MODEL,
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
)

# Identifier for this training run type (used in model-save paths).
TRAINING_TYPE = "tabpfn"
# %% md
# ## 2. Training-specific configuration
# %%
# Label configuration is centralized in common.py: LABEL_NAME and
# LABEL_FACTOR are bound there and only imported here.

# Factors excluded from training (kept in sync with regression.py).
# Commented-out entries are deliberately re-enabled factors.
EXCLUDED_FACTORS = [
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
    "GTJA_alpha011",
    "GTJA_alpha012",
    "GTJA_alpha013",
    "GTJA_alpha014",
    "GTJA_alpha015",
    "GTJA_alpha016",
    "GTJA_alpha017",
    "GTJA_alpha018",
    "GTJA_alpha019",
    "GTJA_alpha020",
    "GTJA_alpha022",
    "GTJA_alpha023",
    "GTJA_alpha024",
    "GTJA_alpha025",
    "GTJA_alpha026",
    "GTJA_alpha027",
    "GTJA_alpha028",
    "GTJA_alpha029",
    "GTJA_alpha031",
    "GTJA_alpha032",
    "GTJA_alpha033",
    "GTJA_alpha034",
    "GTJA_alpha035",
    "GTJA_alpha036",
    "GTJA_alpha037",
    # "GTJA_alpha038",
    "GTJA_alpha039",
    "GTJA_alpha040",
    "GTJA_alpha041",
    "GTJA_alpha042",
    "GTJA_alpha043",
    "GTJA_alpha044",
    "GTJA_alpha045",
    "GTJA_alpha046",
    "GTJA_alpha047",
    "GTJA_alpha048",
    "GTJA_alpha049",
    "GTJA_alpha050",
    "GTJA_alpha051",
    "GTJA_alpha052",
    "GTJA_alpha053",
    "GTJA_alpha054",
    "GTJA_alpha056",
    "GTJA_alpha057",
    "GTJA_alpha058",
    "GTJA_alpha059",
    "GTJA_alpha060",
    "GTJA_alpha061",
    "GTJA_alpha062",
    "GTJA_alpha063",
    "GTJA_alpha064",
    "GTJA_alpha065",
    "GTJA_alpha066",
    "GTJA_alpha067",
    "GTJA_alpha068",
    "GTJA_alpha070",
    "GTJA_alpha071",
    "GTJA_alpha072",
    "GTJA_alpha073",
    "GTJA_alpha074",
    "GTJA_alpha076",
    "GTJA_alpha077",
    "GTJA_alpha078",
    "GTJA_alpha079",
    "GTJA_alpha080",
    "GTJA_alpha081",
    "GTJA_alpha082",
    "GTJA_alpha083",
    "GTJA_alpha084",
    "GTJA_alpha085",
    "GTJA_alpha086",
    "GTJA_alpha087",
    "GTJA_alpha088",
    "GTJA_alpha089",
    "GTJA_alpha090",
    "GTJA_alpha091",
    "GTJA_alpha092",
    "GTJA_alpha093",
    "GTJA_alpha094",
    "GTJA_alpha095",
    "GTJA_alpha096",
    "GTJA_alpha097",
    "GTJA_alpha098",
    "GTJA_alpha099",
    "GTJA_alpha100",
    "GTJA_alpha101",
    "GTJA_alpha102",
    "GTJA_alpha103",
    "GTJA_alpha104",
    "GTJA_alpha105",
    "GTJA_alpha106",
    "GTJA_alpha107",
    "GTJA_alpha108",
    "GTJA_alpha109",
    "GTJA_alpha110",
    "GTJA_alpha111",
    "GTJA_alpha112",
    # "GTJA_alpha113",
    "GTJA_alpha114",
    "GTJA_alpha115",
    "GTJA_alpha117",
    "GTJA_alpha118",
    "GTJA_alpha119",
    "GTJA_alpha120",
    # "GTJA_alpha121",
    "GTJA_alpha122",
    "GTJA_alpha123",
    "GTJA_alpha124",
    "GTJA_alpha125",
    "GTJA_alpha126",
    "GTJA_alpha127",
    "GTJA_alpha128",
    "GTJA_alpha129",
    "GTJA_alpha130",
    "GTJA_alpha131",
    "GTJA_alpha132",
    "GTJA_alpha133",
    "GTJA_alpha134",
    "GTJA_alpha135",
    "GTJA_alpha136",
    # "GTJA_alpha138",
    "GTJA_alpha139",
    # "GTJA_alpha140",
    "GTJA_alpha141",
    "GTJA_alpha142",
    "GTJA_alpha145",
    # "GTJA_alpha146",
    "GTJA_alpha148",
    "GTJA_alpha150",
    "GTJA_alpha151",
    "GTJA_alpha152",
    "GTJA_alpha153",
    "GTJA_alpha154",
    "GTJA_alpha155",
    "GTJA_alpha156",
    "GTJA_alpha157",
    "GTJA_alpha158",
    "GTJA_alpha159",
    "GTJA_alpha160",
    "GTJA_alpha161",
    "GTJA_alpha162",
    "GTJA_alpha163",
    "GTJA_alpha164",
    # "GTJA_alpha165",
    "GTJA_alpha166",
    "GTJA_alpha167",
    "GTJA_alpha168",
    "GTJA_alpha169",
    "GTJA_alpha170",
    "GTJA_alpha171",
    "GTJA_alpha173",
    "GTJA_alpha174",
    "GTJA_alpha175",
    "GTJA_alpha176",
    "GTJA_alpha177",
    "GTJA_alpha178",
    "GTJA_alpha179",
    "GTJA_alpha180",
    # "GTJA_alpha183",
    "GTJA_alpha184",
    "GTJA_alpha185",
    "GTJA_alpha187",
    "GTJA_alpha188",
    "GTJA_alpha189",
    "GTJA_alpha191",
    "chip_dispersion_90",
    "chip_dispersion_70",
    "cost_skewness",
    "dispersion_change_20",
    "price_to_avg_cost",
    "price_to_median_cost",
    "mean_median_dev",
    "trap_pressure",
    "bottom_profit",
    "history_position",
    "winner_rate_surge_5",
    "winner_rate_cs_rank",
    "winner_rate_dev_20",
    "winner_rate_volatility",
    "smart_money_accumulation",
    "winner_vol_corr_20",
    "cost_base_momentum",
    "bottom_cost_stability",
    "pivot_reversion",
    "chip_transition",
]
# Model parameter configuration.
MODEL_PARAMS = {
    # ==================== Device ====================
    "device": "cuda",  # compute device: "cuda" or "cpu" (defaults to cuda)
    # ==================== Context limit ====================
    # 16GB GPU: 1000-3000 suggested; 32GB may try 5000-8000.
    "max_context_size": 100,
}

# Date ranges for the train/val/test splits (bounds come from common.py).
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}

# Output configuration consumed by the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabpfn_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
# %% md
# ## 3. Custom TabPFN task
# %%
from src.training.tasks import RegressionTask


class TabPFNTask(RegressionTask):
    """TabPFN regression task.

    Inherits from RegressionTask but uses TabPFNModel as the model.
    TabPFN needs no conventional training process; it predicts via
    in-context learning.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Initialize the TabPFN task.

        Args:
            model_params: TabPFN parameter dict.
            label_name: Name of the label column.
        """
        # Deliberately skip RegressionTask.__init__: it would construct a
        # LightGBMModel that we immediately replace with TabPFNModel.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        # Evaluation history; stays None unless populated elsewhere.
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """Fit the TabPFN model.

        TabPFN is "trained" by loading the training data into the model
        context; no gradient-descent optimization takes place.

        Args:
            train_data: Training data {"X": DataFrame, "y": Series}.
            val_data: Validation data, used for evaluation only — it does
                not influence the fit.
        """
        X_train = train_data["X"]
        y_train = train_data["y"]
        X_val = val_data.get("X")
        y_val = val_data.get("y")

        # TabPFN evaluates on eval_set when validation data is available.
        self.model.fit(
            X_train, y_train, eval_set=(X_val, y_val) if X_val is not None else None
        )

    def get_model(self) -> TabPFNModel:
        """Return the fitted model instance."""
        return self.model
# %% md
# ## 4. Main entry point
# %%
def main():
    """Run the full TabPFN training workflow and return its results."""
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning),")
    print(" 训练过程实际是加载数据到模型上下文。")
    print(" 如果训练数据超过上下文限制,会自动截取最近的数据。")

    # 1. Create the FactorEngine (data access layer).
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager (factor selection / exclusion).
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline (feature/label preprocessing + filters).
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label to strip extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. Create the TabPFN task.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Persist model + factor info when enabled.
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # 8. Report TabPFN-specific metrics.
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabpfn_output.csv')}")

    # Show validation-set metrics when the model exposes them.
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f" - {metric}: {value:.6f}")

    print("=" * 80)

    return results


if __name__ == "__main__":
    main()