# %% md
# # TabPFN Regression Training Pipeline
#
# Regression prediction with TabPFN (Prior-Data Fitted Network).
# TabPFN predicts via in-context learning, so no conventional
# gradient-descent training loop is required.

# %% md
# ## 1. Import dependencies

# %%
import os

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    NullFiller,
    Winsorizer,
    StandardScaler,
    CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabPFNModel
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    TRAIN_START,
    TRAIN_END,
    VAL_START,
    VAL_END,
    TEST_START,
    TEST_END,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
    OUTPUT_DIR,
    SAVE_PREDICTIONS,
    SAVE_MODEL,
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
)

# Identifier for this training-run type (used e.g. for model save paths).
TRAINING_TYPE = "tabpfn"

# %% md
# ## 2. Training-specific configuration

# %%
# Label configuration (imported from common.py).
# LABEL_NAME and LABEL_FACTOR are bound in common.py; only the import is needed.

# Factors excluded from training (kept consistent with regression.py).
EXCLUDED_FACTORS = [
    "GTJA_alpha001", "GTJA_alpha002", "GTJA_alpha003", "GTJA_alpha004", "GTJA_alpha005",
    "GTJA_alpha006", "GTJA_alpha007", "GTJA_alpha008", "GTJA_alpha009", "GTJA_alpha010",
    "GTJA_alpha011", "GTJA_alpha012", "GTJA_alpha013", "GTJA_alpha014", "GTJA_alpha015",
    "GTJA_alpha016", "GTJA_alpha017", "GTJA_alpha018", "GTJA_alpha019", "GTJA_alpha020",
    "GTJA_alpha022", "GTJA_alpha023", "GTJA_alpha024", "GTJA_alpha025", "GTJA_alpha026",
    "GTJA_alpha027", "GTJA_alpha028", "GTJA_alpha029", "GTJA_alpha031", "GTJA_alpha032",
    "GTJA_alpha033", "GTJA_alpha034", "GTJA_alpha035", "GTJA_alpha036", "GTJA_alpha037",
    # "GTJA_alpha038",
    "GTJA_alpha039", "GTJA_alpha040", "GTJA_alpha041", "GTJA_alpha042", "GTJA_alpha043",
    "GTJA_alpha044", "GTJA_alpha045", "GTJA_alpha046", "GTJA_alpha047", "GTJA_alpha048",
    "GTJA_alpha049", "GTJA_alpha050", "GTJA_alpha051", "GTJA_alpha052", "GTJA_alpha053",
    "GTJA_alpha054", "GTJA_alpha056", "GTJA_alpha057", "GTJA_alpha058", "GTJA_alpha059",
    "GTJA_alpha060", "GTJA_alpha061", "GTJA_alpha062", "GTJA_alpha063", "GTJA_alpha064",
    "GTJA_alpha065", "GTJA_alpha066", "GTJA_alpha067", "GTJA_alpha068", "GTJA_alpha070",
    "GTJA_alpha071", "GTJA_alpha072", "GTJA_alpha073", "GTJA_alpha074", "GTJA_alpha076",
    "GTJA_alpha077", "GTJA_alpha078", "GTJA_alpha079", "GTJA_alpha080", "GTJA_alpha081",
    "GTJA_alpha082", "GTJA_alpha083", "GTJA_alpha084", "GTJA_alpha085", "GTJA_alpha086",
    "GTJA_alpha087", "GTJA_alpha088", "GTJA_alpha089", "GTJA_alpha090", "GTJA_alpha091",
    "GTJA_alpha092", "GTJA_alpha093", "GTJA_alpha094", "GTJA_alpha095", "GTJA_alpha096",
    "GTJA_alpha097", "GTJA_alpha098", "GTJA_alpha099", "GTJA_alpha100", "GTJA_alpha101",
    "GTJA_alpha102", "GTJA_alpha103", "GTJA_alpha104", "GTJA_alpha105", "GTJA_alpha106",
    "GTJA_alpha107", "GTJA_alpha108", "GTJA_alpha109", "GTJA_alpha110", "GTJA_alpha111",
    "GTJA_alpha112",
    # "GTJA_alpha113",
    "GTJA_alpha114", "GTJA_alpha115", "GTJA_alpha117", "GTJA_alpha118", "GTJA_alpha119",
    "GTJA_alpha120",
    # "GTJA_alpha121",
    "GTJA_alpha122", "GTJA_alpha123", "GTJA_alpha124", "GTJA_alpha125", "GTJA_alpha126",
    "GTJA_alpha127", "GTJA_alpha128", "GTJA_alpha129", "GTJA_alpha130", "GTJA_alpha131",
    "GTJA_alpha132", "GTJA_alpha133", "GTJA_alpha134", "GTJA_alpha135", "GTJA_alpha136",
    # "GTJA_alpha138",
    "GTJA_alpha139",
    # "GTJA_alpha140",
    "GTJA_alpha141", "GTJA_alpha142", "GTJA_alpha145",
    # "GTJA_alpha146",
    "GTJA_alpha148", "GTJA_alpha150", "GTJA_alpha151", "GTJA_alpha152", "GTJA_alpha153",
    "GTJA_alpha154", "GTJA_alpha155", "GTJA_alpha156", "GTJA_alpha157", "GTJA_alpha158",
    "GTJA_alpha159", "GTJA_alpha160", "GTJA_alpha161", "GTJA_alpha162", "GTJA_alpha163",
    "GTJA_alpha164",
    # "GTJA_alpha165",
    "GTJA_alpha166", "GTJA_alpha167", "GTJA_alpha168", "GTJA_alpha169", "GTJA_alpha170",
    "GTJA_alpha171", "GTJA_alpha173", "GTJA_alpha174", "GTJA_alpha175", "GTJA_alpha176",
    "GTJA_alpha177", "GTJA_alpha178", "GTJA_alpha179", "GTJA_alpha180",
    # "GTJA_alpha183",
    "GTJA_alpha184", "GTJA_alpha185", "GTJA_alpha187", "GTJA_alpha188", "GTJA_alpha189",
    "GTJA_alpha191",
    # Chip-distribution style factors.
    "chip_dispersion_90", "chip_dispersion_70", "cost_skewness", "dispersion_change_20",
    "price_to_avg_cost", "price_to_median_cost", "mean_median_dev", "trap_pressure",
    "bottom_profit", "history_position", "winner_rate_surge_5", "winner_rate_cs_rank",
    "winner_rate_dev_20", "winner_rate_volatility", "smart_money_accumulation",
    "winner_vol_corr_20", "cost_base_momentum", "bottom_cost_stability",
    "pivot_reversion", "chip_transition",
]

# Model parameter configuration.
MODEL_PARAMS = {
    # ==================== Device ====================
    "device": "cuda",  # Compute device: "cuda" or "cpu" (defaults to cuda)
    # ==================== Context limit ====================
    "max_context_size": 100,  # 16GB GPU: ~1000-3000 suggested; 32GB GPU: try 5000-8000
}

# Date-range configuration for the train/val/test splits.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}

# Output configuration consumed by the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabpfn_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}

# %% md
# ## 3.
# %% md
# ## 3. Custom TabPFN task

# %%
from src.training.tasks import RegressionTask


class TabPFNTask(RegressionTask):
    """TabPFN regression task.

    Inherits from RegressionTask but uses TabPFNModel as the model.
    TabPFN needs no conventional training loop; it predicts via
    in-context learning.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Initialize the TabPFN task.

        Args:
            model_params: TabPFN parameter dict.
            label_name: Name of the label column.
        """
        # Deliberately skip RegressionTask.__init__ (it would build a
        # LightGBMModel); initialize the shared state via BaseTask instead.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """Fit the TabPFN model.

        TabPFN "trains" by loading the training data into the model's
        context; no gradient-descent optimization takes place. Validation
        data is used for evaluation only, never for fitting.

        Args:
            train_data: Training data {"X": DataFrame, "y": Series}.
            val_data: Validation data; evaluated but not trained on.
        """
        X_train = train_data["X"]
        y_train = train_data["y"]
        X_val = val_data.get("X")
        y_val = val_data.get("y")

        # Fix: only pass an eval_set when BOTH features and labels exist.
        # The original guarded X_val alone, which could hand (X, None)
        # to the model when val_data lacked a "y" entry.
        if X_val is not None and y_val is not None:
            eval_set = (X_val, y_val)
        else:
            eval_set = None
        self.model.fit(X_train, y_train, eval_set=eval_set)

    def get_model(self) -> TabPFNModel:
        """Return the fitted model instance."""
        return self.model


# %% md
# ## 4. Main entry point

# %%
def main():
    """Run the full TabPFN training pipeline.

    Builds the engine, factor manager, pipeline, task and trainer, runs
    training/prediction, optionally saves the model, and prints metrics.

    Returns:
        The results object produced by Trainer.run().
    """
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning),")
    print("       训练过程实际是加载数据到模型上下文。")
    print("       如果训练数据超过上下文限制,会自动截取最近的数据。")

    # 1. Data access layer.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Factor selection / exclusion management.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Feature/label preprocessing and stock filtering.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize labels to clip extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. The TabPFN task wrapping the model.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Trainer orchestrating pipeline + task.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training (context loading) and prediction.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Optionally persist the model together with factor metadata.
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # 8. Report TabPFN-specific metrics.
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    # Fix: derive the result path from output_config rather than repeating
    # the hard-coded "tabpfn_output.csv" literal (single source of truth).
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )
    print(f"结果保存路径: {result_path}")

    # Show validation metrics when the model exposes them.
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f"  - {metric}: {value:.6f}")
    print("=" * 80)

    return results


if __name__ == "__main__":
    main()