Files
ProStock/src/experiment/tabpfn_regression.py
liaozhaorun 36a3ccbcc8 feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
2026-03-31 23:11:21 +08:00

426 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# %% md
# # TabPFN regression training workflow
#
# Uses TabPFN (Prior-Data Fitted Network) for regression prediction.
# TabPFN predicts via in-context learning; no conventional gradient-descent training loop is needed.
# %% md
# ## 1. Import dependencies
# %%
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabPFNModel
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
)
# Identifier for this training run type; used to build the model save path
# via get_model_save_path(TRAINING_TYPE).
TRAINING_TYPE = "tabpfn"
# %% md
# ## 2. Training-specific configuration
# %%
# Label configuration (imported from common.py for a single source of truth):
# LABEL_NAME and LABEL_FACTOR are bound together in common.py, so only the
# import is needed here.
# Factors excluded from training (kept in sync with regression.py).
# NOTE: commented-out entries are deliberately NOT excluded, i.e. those
# factors remain part of the feature set.
EXCLUDED_FACTORS = [
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
    "GTJA_alpha011",
    "GTJA_alpha012",
    "GTJA_alpha013",
    "GTJA_alpha014",
    "GTJA_alpha015",
    "GTJA_alpha016",
    "GTJA_alpha017",
    "GTJA_alpha018",
    "GTJA_alpha019",
    "GTJA_alpha020",
    "GTJA_alpha022",
    "GTJA_alpha023",
    "GTJA_alpha024",
    "GTJA_alpha025",
    "GTJA_alpha026",
    "GTJA_alpha027",
    "GTJA_alpha028",
    "GTJA_alpha029",
    "GTJA_alpha031",
    "GTJA_alpha032",
    "GTJA_alpha033",
    "GTJA_alpha034",
    "GTJA_alpha035",
    "GTJA_alpha036",
    "GTJA_alpha037",
    # "GTJA_alpha038",
    "GTJA_alpha039",
    "GTJA_alpha040",
    "GTJA_alpha041",
    "GTJA_alpha042",
    "GTJA_alpha043",
    "GTJA_alpha044",
    "GTJA_alpha045",
    "GTJA_alpha046",
    "GTJA_alpha047",
    "GTJA_alpha048",
    "GTJA_alpha049",
    "GTJA_alpha050",
    "GTJA_alpha051",
    "GTJA_alpha052",
    "GTJA_alpha053",
    "GTJA_alpha054",
    "GTJA_alpha056",
    "GTJA_alpha057",
    "GTJA_alpha058",
    "GTJA_alpha059",
    "GTJA_alpha060",
    "GTJA_alpha061",
    "GTJA_alpha062",
    "GTJA_alpha063",
    "GTJA_alpha064",
    "GTJA_alpha065",
    "GTJA_alpha066",
    "GTJA_alpha067",
    "GTJA_alpha068",
    "GTJA_alpha070",
    "GTJA_alpha071",
    "GTJA_alpha072",
    "GTJA_alpha073",
    "GTJA_alpha074",
    "GTJA_alpha076",
    "GTJA_alpha077",
    "GTJA_alpha078",
    "GTJA_alpha079",
    "GTJA_alpha080",
    "GTJA_alpha081",
    "GTJA_alpha082",
    "GTJA_alpha083",
    "GTJA_alpha084",
    "GTJA_alpha085",
    "GTJA_alpha086",
    "GTJA_alpha087",
    "GTJA_alpha088",
    "GTJA_alpha089",
    "GTJA_alpha090",
    "GTJA_alpha091",
    "GTJA_alpha092",
    "GTJA_alpha093",
    "GTJA_alpha094",
    "GTJA_alpha095",
    "GTJA_alpha096",
    "GTJA_alpha097",
    "GTJA_alpha098",
    "GTJA_alpha099",
    "GTJA_alpha100",
    "GTJA_alpha101",
    "GTJA_alpha102",
    "GTJA_alpha103",
    "GTJA_alpha104",
    "GTJA_alpha105",
    "GTJA_alpha106",
    "GTJA_alpha107",
    "GTJA_alpha108",
    "GTJA_alpha109",
    "GTJA_alpha110",
    "GTJA_alpha111",
    "GTJA_alpha112",
    # "GTJA_alpha113",
    "GTJA_alpha114",
    "GTJA_alpha115",
    "GTJA_alpha117",
    "GTJA_alpha118",
    "GTJA_alpha119",
    "GTJA_alpha120",
    # "GTJA_alpha121",
    "GTJA_alpha122",
    "GTJA_alpha123",
    "GTJA_alpha124",
    "GTJA_alpha125",
    "GTJA_alpha126",
    "GTJA_alpha127",
    "GTJA_alpha128",
    "GTJA_alpha129",
    "GTJA_alpha130",
    "GTJA_alpha131",
    "GTJA_alpha132",
    "GTJA_alpha133",
    "GTJA_alpha134",
    "GTJA_alpha135",
    "GTJA_alpha136",
    # "GTJA_alpha138",
    "GTJA_alpha139",
    # "GTJA_alpha140",
    "GTJA_alpha141",
    "GTJA_alpha142",
    "GTJA_alpha145",
    # "GTJA_alpha146",
    "GTJA_alpha148",
    "GTJA_alpha150",
    "GTJA_alpha151",
    "GTJA_alpha152",
    "GTJA_alpha153",
    "GTJA_alpha154",
    "GTJA_alpha155",
    "GTJA_alpha156",
    "GTJA_alpha157",
    "GTJA_alpha158",
    "GTJA_alpha159",
    "GTJA_alpha160",
    "GTJA_alpha161",
    "GTJA_alpha162",
    "GTJA_alpha163",
    "GTJA_alpha164",
    # "GTJA_alpha165",
    "GTJA_alpha166",
    "GTJA_alpha167",
    "GTJA_alpha168",
    "GTJA_alpha169",
    "GTJA_alpha170",
    "GTJA_alpha171",
    "GTJA_alpha173",
    "GTJA_alpha174",
    "GTJA_alpha175",
    "GTJA_alpha176",
    "GTJA_alpha177",
    "GTJA_alpha178",
    "GTJA_alpha179",
    "GTJA_alpha180",
    # "GTJA_alpha183",
    "GTJA_alpha184",
    "GTJA_alpha185",
    "GTJA_alpha187",
    "GTJA_alpha188",
    "GTJA_alpha189",
    "GTJA_alpha191",
    # Chip-distribution / winner-rate factors excluded below.
    "chip_dispersion_90",
    "chip_dispersion_70",
    "cost_skewness",
    "dispersion_change_20",
    "price_to_avg_cost",
    "price_to_median_cost",
    "mean_median_dev",
    "trap_pressure",
    "bottom_profit",
    "history_position",
    "winner_rate_surge_5",
    "winner_rate_cs_rank",
    "winner_rate_dev_20",
    "winner_rate_volatility",
    "smart_money_accumulation",
    "winner_vol_corr_20",
    "cost_base_momentum",
    "bottom_cost_stability",
    "pivot_reversion",
    "chip_transition",
]
# Hyper-parameters forwarded to TabPFNModel.
MODEL_PARAMS = {
    # ==================== Device ====================
    "device": "cuda",  # compute device: "cuda" or "cpu" (default: cuda)
    # ==================== Context limit ====================
    # Max number of training rows kept in the model's in-context set.
    # Original (garbled) note, reconstructed: 16GB GPU -> try 1000-3000;
    # 32GB GPU -> try 5000-8000.
    # NOTE(review): 100 is far below that recommended range — confirm this
    # is intentional (e.g. a smoke-test setting) and not a leftover.
    "max_context_size": 100,
}
# Train / validation / test date ranges; all bounds come from common.py.
# Consumed by Trainer.run(engine=..., date_range=...).
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output / persistence configuration consumed by the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabpfn_output.csv",  # predictions CSV written under output_dir
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
# %% md
# ## 3. Custom TabPFN task
# %%
from src.training.tasks import RegressionTask
class TabPFNTask(RegressionTask):
    """Regression task backed by TabPFN.

    Reuses the RegressionTask interface but swaps the model for
    TabPFNModel. TabPFN performs in-context learning: "fitting" means
    loading the training data into the model's context rather than
    running a gradient-descent training loop.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Initialize the TabPFN task.

        Args:
            model_params: Parameter dict forwarded to TabPFNModel.
            label_name: Name of the label column.
        """
        # Deliberately skip RegressionTask.__init__, which would create a
        # LightGBMModel we don't need; initialize the shared BaseTask
        # state directly and attach the TabPFN model instead.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """Load training data into the TabPFN context ("training").

        No gradient-descent optimization takes place.

        Args:
            train_data: Training data {"X": DataFrame, "y": Series}.
            val_data: Validation data; used for evaluation only, never
                for fitting.
        """
        features = train_data["X"]
        labels = train_data["y"]
        val_features = val_data.get("X")
        val_labels = val_data.get("y")
        # TabPFN evaluates on eval_set when validation features exist.
        eval_pair = (val_features, val_labels) if val_features is not None else None
        self.model.fit(features, labels, eval_set=eval_pair)

    def get_model(self) -> TabPFNModel:
        """Return the fitted model instance."""
        return self.model
# %% md
# ## 4. Main entry point
# %%
def main():
    """Run the full TabPFN regression training workflow.

    Builds the factor engine, data pipeline, task, and trainer; runs
    training/prediction over the configured date ranges; optionally
    persists the model together with its factor metadata; and prints
    TabPFN-specific evaluation metrics.

    Returns:
        The results object produced by Trainer.run().
    """
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning)")
    print(" 训练过程实际是加载数据到模型上下文。")
    print(" 如果训练数据超过上下文限制,会自动截取最近的数据。")
    # 1. Factor engine: computes raw factor values and exposes the data router.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()
    # 2. Factor manager: resolves the final feature set after exclusions.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )
    # 3. Data pipeline: feature/label preprocessing plus stock filtering.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label to strip extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )
    # 4. Task wrapping the TabPFN model.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )
    # 5. Trainer orchestrating pipeline + task.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )
    # 6. Run training / prediction over the configured date ranges.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)
    # 7. Optionally persist the model together with factor metadata and
    #    the fitted preprocessing state so predictions can be reproduced.
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )
    # 8. Report TabPFN-specific results.
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    # Build the reported path from output_config instead of re-hardcoding
    # the filename, so the message stays correct if the config changes.
    print(
        f"结果保存路径: "
        f"{os.path.join(output_config['output_dir'], output_config['output_filename'])}"
    )
    # Show validation-set metrics when the model exposes them.
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f" - {metric}: {value:.6f}")
    print("=" * 80)
    return results
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()