feat(training): 新增 TabM 模型支持及数据质量优化

- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
2026-03-31 23:11:21 +08:00
parent 9e0114c745
commit 36a3ccbcc8
22 changed files with 4421 additions and 204 deletions

View File

@@ -14,7 +14,7 @@ from src.factors import FactorEngine
# =============================================================================
# 日期范围配置(正确的 train/val/test 三分法)
# =============================================================================
TRAIN_START = "20200101"
TRAIN_START = "20190101"
TRAIN_END = "20231231"
VAL_START = "20240101"
VAL_END = "20241231"
@@ -79,8 +79,8 @@ SELECTED_FACTORS = [
"GTJA_alpha002",
"GTJA_alpha003",
"GTJA_alpha004",
"GTJA_alpha005",
"GTJA_alpha006",
# "GTJA_alpha005",
# "GTJA_alpha006",
"GTJA_alpha007",
"GTJA_alpha008",
"GTJA_alpha009",
@@ -123,133 +123,133 @@ SELECTED_FACTORS = [
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
# "GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
# "GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
# "GTJA_alpha138",
"GTJA_alpha139",
# "GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
# "GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
"GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
# "GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
# "GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# # "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
@@ -488,6 +488,9 @@ MODEL_SAVE_DIR = "models" # 模型保存目录
# Top-N setting: number of stocks recommended per day.
TOP_N = 5  # may be changed to 10, 20, etc.
# Number of leading days of training data to skip.
# The earlier wording of this comment said 252 days, which did not match the
# value below; the value in effect is 300.
TRAIN_SKIP_DAYS = 300  # skip the first 300 days of training data, where data is still insufficient
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
"""生成输出文件路径。

View File

@@ -0,0 +1,514 @@
"""数据质量分析模块
提供数据质量检查功能,包括:
- 缺失值统计
- 零值统计
- 按日期检查全空列
"""
from typing import Any, Dict, List, Optional
import polars as pl
import numpy as np
class DataQualityAnalyzer:
    """Diagnose data-quality problems in dataset splits before training.

    For each split the analyzer reports, per feature column, the number of
    null cells and zero cells, the (trade date, column) pairs where a column
    is null on every row of that date, and summary statistics of the label.

    NOTE(review): "null" here is whatever ``Series.null_count`` counts.
    Floating-point NaN values are presumably not included in that count;
    confirm whether NaN should be reported separately.

    Attributes:
        feature_cols: Names of the feature columns to inspect.
        label_col: Name of the label column, or None to skip label analysis.
        verbose: Whether to print a human-readable report while analyzing.
        analysis_results: Results of the most recent ``analyze`` call, keyed
            by split name.
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        verbose: bool = True,
    ):
        """Create an analyzer.

        Args:
            feature_cols: Names of the feature columns; None means no columns.
            label_col: Name of the label column.
            verbose: Whether to print the report while analyzing.
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.verbose = verbose
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Replace the columns to analyze.

        Args:
            feature_cols: Names of the feature columns.
            label_col: Name of the label column.
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    @staticmethod
    def _percent(part: float, whole: float) -> float:
        """Return ``part`` as a percentage of ``whole``; 0.0 when ``whole`` is 0."""
        return part / whole * 100 if whole else 0.0

    @staticmethod
    def _optional_float(value: Any) -> Optional[float]:
        """Convert ``value`` to float, passing None through unchanged."""
        return None if value is None else float(value)

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the full analysis over the requested splits.

        Splits that are missing from ``data``, or whose entry has no
        ``"raw_data"`` value, are skipped silently.

        Args:
            data: Mapping of split name to a dict holding the split's frame
                under the ``"raw_data"`` key.
            split_names: Splits to analyze; defaults to train, val and test.

        Returns:
            Mapping of split name to that split's result dictionary. The same
            object is stored in ``self.analysis_results``.
        """
        if not split_names:
            split_names = ["train", "val", "test"]
        if self.verbose:
            print("\n" + "=" * 80)
            print("数据质量分析报告")
            print("=" * 80)
        self.analysis_results = {}
        for split_name in split_names:
            if split_name not in data:
                continue
            raw_df = data[split_name].get("raw_data")
            if raw_df is None:
                continue
            if self.verbose:
                print(f"\n[{split_name.upper()} 数据集]")
                print("-" * 40)
            self.analysis_results[split_name] = self._analyze_split(
                raw_df, split_name
            )
        if self.verbose:
            print("\n" + "=" * 80)
        return self.analysis_results

    def _analyze_split(
        self,
        df: "pl.DataFrame",
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze one split and, when verbose, print each section.

        Args:
            df: Frame holding the split's rows.
            split_name: Name of the split (not used in the computation).

        Returns:
            Result dictionary with row/column totals and the null, zero,
            all-null-by-date and (when the label column exists) label
            sections.
        """
        results = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }
        # 1. Null cells in the feature columns.
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats
        if self.verbose:
            self._print_null_analysis(null_stats)
        # 2. Zero cells in the feature columns.
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats
        if self.verbose:
            self._print_zero_analysis(zero_stats)
        # 3. Columns that are entirely null on some trade date.
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date
        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)
        # 4. Label column statistics.
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats
            if self.verbose:
                self._print_label_analysis(label_stats)
        return results

    def _analyze_null_values(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Count null cells per column.

        Columns that are not present in ``df`` are ignored and do not
        contribute to ``total_cells``.

        Args:
            df: Frame to inspect.
            cols: Candidate column names.

        Returns:
            Dictionary with the total cell count, per-column null counts and
            percentages (only for columns that have nulls), the list of
            affected columns and the total number of null cells.
        """
        present = [col for col in cols if col in df.columns]
        total_rows = len(df)
        stats = {
            # Only columns that actually exist contribute cells.
            "total_cells": total_rows * len(present),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }
        for col in present:
            null_count = df[col].null_count()
            if null_count > 0:
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = null_count / total_rows * 100
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count
        return stats

    def _analyze_zero_values(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Count cells equal to zero per column, ignoring null cells.

        Columns that are not present in ``df`` are ignored and do not
        contribute to ``total_cells``. Percentages are relative to the total
        row count, nulls included.

        Args:
            df: Frame to inspect.
            cols: Candidate column names.

        Returns:
            Dictionary with the total cell count, per-column zero counts and
            percentages (only for columns that have zeros), the list of
            affected columns and the total number of zero cells.
        """
        present = [col for col in cols if col in df.columns]
        total_rows = len(df)
        stats = {
            "total_cells": total_rows * len(present),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }
        for col in present:
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue
            zero_count = (non_null_series == 0).sum()
            if zero_count > 0:
                stats["zero_counts"][col] = int(zero_count)
                stats["zero_percentages"][col] = zero_count / total_rows * 100
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += int(zero_count)
        return stats

    def _check_all_null_by_date(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Find (trade date, column) pairs where the column is entirely null.

        The frame is grouped by ``trade_date`` and only per-date null counts
        and row counts are collected, so the per-column checks run on the
        aggregated table rather than on the full frame.

        Args:
            df: Frame to inspect; nothing is checked unless it has a
                ``trade_date`` column.
            cols: Candidate column names; names absent from ``df`` are
                ignored.

        Returns:
            Dictionary with ``issues_found`` (bool) and ``issues``, a list of
            dicts holding ``date``, ``column`` and ``total_rows``.
        """
        results = {
            "issues_found": False,
            "issues": [],
        }
        if "trade_date" not in df.columns:
            return results
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results
        # One null-count aggregate per column, plus the row count per date.
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))
        agg_df = df.lazy().group_by("trade_date").agg(agg_exprs).collect()
        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # A date is bad for this column when every row of it is null.
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
            ).select(["trade_date", "total_rows"])
            if not bad_dates.is_empty():
                for row in bad_dates.to_dicts():
                    issues.append(
                        {
                            "date": row["trade_date"],
                            "column": col,
                            "total_rows": row["total_rows"],
                        }
                    )
        if issues:
            results["issues_found"] = True
            results["issues"] = issues
        return results

    def _analyze_label(
        self,
        df: "pl.DataFrame",
        label_col: str,
    ) -> Dict[str, Any]:
        """Compute null/zero counts and basic statistics of the label column.

        Args:
            df: Frame to inspect.
            label_col: Name of the label column.

        Returns:
            Dictionary with the total row count, null and zero counts and
            percentages, and min/max/mean/std of the non-null values. A
            statistic stays None when the series reports it as None (for
            example when there are no non-null values).
        """
        stats = {
            "total_count": len(df),
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }
        if label_col not in df.columns:
            return stats
        series = df[label_col]
        total_rows = len(df)
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = self._percent(null_count, total_rows)
        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            zero_count = (non_null_series == 0).sum()
            stats["zero_count"] = int(zero_count)
            stats["zero_percentage"] = zero_count / total_rows * 100
            # A statistic may come back as None (e.g. std of a single value),
            # so convert defensively instead of calling float() on None.
            stats["min"] = self._optional_float(non_null_series.min())
            stats["max"] = self._optional_float(non_null_series.max())
            stats["mean"] = self._optional_float(non_null_series.mean())
            stats["std"] = self._optional_float(non_null_series.std())
        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the null-value section of the report.

        Args:
            stats: Dictionary produced by ``_analyze_null_values``.
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]
        # total_cells is 0 for an empty split or when no feature column is
        # configured; report 0% instead of dividing by zero.
        null_ratio = self._percent(total_null, total_cells)
        print(f" 缺失值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(f" 缺失单元格数: {total_null:,} ({null_ratio:.2f}%)")
        print(f" 有缺失值的列数: {len(null_cols)}/{len(self.feature_cols)}")
        if null_cols:
            print(f" 缺失值最多的5个特征:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the zero-value section of the report.

        Args:
            stats: Dictionary produced by ``_analyze_zero_values``.
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]
        # Same zero-denominator guard as in the null section.
        zero_ratio = self._percent(total_zero, total_cells)
        print(f" 零值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(f" 零值单元格数: {total_zero:,} ({zero_ratio:.2f}%)")
        print(f" 有零值的列数: {len(zero_cols)}/{len(self.feature_cols)}")
        if zero_cols:
            print(f" 零值最多的5个特征:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the all-null-by-date section of the report.

        At most the first five dates (in sorted order) are listed; column
        names are shown only for dates with three or fewer affected columns.

        Args:
            results: Dictionary produced by ``_check_all_null_by_date``.
        """
        issues = results["issues"]
        print(f" 按日期全空检查:")
        if results["issues_found"]:
            print(f" [警告] 发现 {len(issues)} 个问题:")
            # Group the affected columns by date.
            by_date = {}
            for issue in issues:
                by_date.setdefault(issue["date"], []).append(issue["column"])
            for date in sorted(by_date.keys())[:5]:
                cols = by_date[date]
                print(f" 日期 {date}: {len(cols)} 列全为空")
                if len(cols) <= 3:
                    print(f" 列名: {', '.join(cols)}")
            if len(by_date) > 5:
                print(f" ... 还有 {len(by_date) - 5} 个日期存在问题")
        else:
            print(f" [正常] 未发现某天某列全为空的情况")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the label section of the report.

        Statistics whose value is None are omitted from the output.

        Args:
            stats: Dictionary produced by ``_analyze_label``.
        """
        print(f" 标签列统计 ({self.label_col}):")
        print(f" 总数: {stats['total_count']:,}")
        print(f" 缺失值: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f" 零值: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")
        for caption, key in (
            ("最小值", "min"),
            ("最大值", "max"),
            ("均值", "mean"),
            ("标准差", "std"),
        ):
            if stats[key] is not None:
                print(f" {caption}: {stats[key]:.6f}")

    def get_summary(self) -> str:
        """Build a short text summary of the most recent analysis.

        Returns:
            A multi-line summary, or a fixed placeholder message when no
            analysis results are stored.
        """
        if not self.analysis_results:
            return "尚未执行分析"
        lines = ["数据质量分析摘要", "=" * 40]
        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")
            lines.append(f" 总行数: {results['total_rows']:,}")
            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f" 缺失值: {null_stats['total_null_cells']:,} 个单元格, "
                    f"{len(null_stats['columns_with_null'])} 列受影响"
                )
            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f" 零值: {zero_stats['total_zero_cells']:,} 个单元格, "
                    f"{len(zero_stats['columns_with_zero'])} 列受影响"
                )
            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f" [警告] 发现 {len(all_null['issues'])} 个日期列全空问题"
                )
        return "\n".join(lines)
def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience wrapper: build a DataQualityAnalyzer and run it once.

    Args:
        data: Mapping of split name to that split's data dictionary.
        feature_cols: Names of the feature columns to inspect.
        label_col: Name of the label column.
        verbose: Whether the analyzer prints its report.

    Returns:
        The result dictionary returned by ``DataQualityAnalyzer.analyze``.
    """
    return DataQualityAnalyzer(
        feature_cols=feature_cols, label_col=label_col, verbose=verbose
    ).analyze(data)

View File

@@ -37,6 +37,7 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
@@ -155,6 +156,7 @@ def main():
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 RankTask

View File

@@ -38,6 +38,7 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
@@ -51,55 +52,55 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
'GTJA_alpha036',
'GTJA_alpha032',
'GTJA_alpha010',
'GTJA_alpha005',
'CP',
'BP',
'debt_to_equity',
'current_ratio',
'GTJA_alpha002',
'GTJA_alpha027',
'GTJA_alpha064',
'GTJA_alpha062',
'GTJA_alpha043',
'GTJA_alpha044',
'GTJA_alpha120',
'GTJA_alpha117',
'GTJA_alpha103',
'GTJA_alpha104',
'GTJA_alpha105',
'GTJA_alpha073',
'GTJA_alpha077',
'GTJA_alpha085',
'GTJA_alpha090',
'GTJA_alpha087',
'GTJA_alpha083',
'GTJA_alpha092',
'GTJA_alpha133',
'GTJA_alpha131',
'GTJA_alpha126',
'GTJA_alpha124',
'GTJA_alpha162',
'GTJA_alpha164',
'GTJA_alpha157',
'GTJA_alpha177',
'price_to_avg_cost',
'cost_skewness',
'GTJA_alpha191',
'GTJA_alpha180',
'history_position',
'bottom_profit',
'mean_median_dev',
'smart_money_accumulation',
'GTJA_alpha013',
'GTJA_alpha099',
'GTJA_alpha107',
'GTJA_alpha119',
'GTJA_alpha141',
'GTJA_alpha130',
'GTJA_alpha173',
"GTJA_alpha036",
"GTJA_alpha032",
"GTJA_alpha010",
"GTJA_alpha005",
"CP",
"BP",
"debt_to_equity",
"current_ratio",
"GTJA_alpha002",
"GTJA_alpha027",
"GTJA_alpha064",
"GTJA_alpha062",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha120",
"GTJA_alpha117",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha073",
"GTJA_alpha077",
"GTJA_alpha085",
"GTJA_alpha090",
"GTJA_alpha087",
"GTJA_alpha083",
"GTJA_alpha092",
"GTJA_alpha133",
"GTJA_alpha131",
"GTJA_alpha126",
"GTJA_alpha124",
"GTJA_alpha162",
"GTJA_alpha164",
"GTJA_alpha157",
"GTJA_alpha177",
"price_to_avg_cost",
"cost_skewness",
"GTJA_alpha191",
"GTJA_alpha180",
"history_position",
"bottom_profit",
"mean_median_dev",
"smart_money_accumulation",
"GTJA_alpha013",
"GTJA_alpha099",
"GTJA_alpha107",
"GTJA_alpha119",
"GTJA_alpha141",
"GTJA_alpha130",
"GTJA_alpha173",
]
# 模型参数配置
@@ -184,6 +185,7 @@ def main():
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 RegressionTask

View File

@@ -0,0 +1,375 @@
# %% md
# # TabM 回归训练流程
#
# 使用 TabM (Tabular MLP with Ensembles) 模型进行回归训练。
# TabM 通过内置集成机制(ensemble_size=32)实现高效的多模型集成。
# %% md
# ## 1. 导入依赖
# %%
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
TabMRegressionTask,
NullFiller,
Winsorizer,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
# 训练类型标识
TRAINING_TYPE = "tabm_regression"
# %% md
# ## 2. 训练特定配置
# %%
# Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 排除的因子列表(当前所有条目均已注释,即不排除任何因子;如需与 LightGBM 回归保持一致,请取消相应条目的注释)
EXCLUDED_FACTORS = [
# "GTJA_alpha001",
# "GTJA_alpha002",
# "GTJA_alpha003",
# "GTJA_alpha004",
# "GTJA_alpha005",
# "GTJA_alpha006",
# "GTJA_alpha007",
# "GTJA_alpha008",
# "GTJA_alpha009",
# "GTJA_alpha010",
# "GTJA_alpha011",
# "GTJA_alpha012",
# "GTJA_alpha013",
# "GTJA_alpha014",
# "GTJA_alpha015",
# "GTJA_alpha016",
# "GTJA_alpha017",
# "GTJA_alpha018",
# "GTJA_alpha019",
# "GTJA_alpha020",
# "GTJA_alpha022",
# "GTJA_alpha023",
# "GTJA_alpha024",
# "GTJA_alpha025",
# "GTJA_alpha026",
# "GTJA_alpha027",
# "GTJA_alpha028",
# "GTJA_alpha029",
# "GTJA_alpha031",
# "GTJA_alpha032",
# "GTJA_alpha033",
# "GTJA_alpha034",
# "GTJA_alpha035",
# "GTJA_alpha036",
# "GTJA_alpha037",
# # "GTJA_alpha038",
# "GTJA_alpha039",
# "GTJA_alpha040",
# "GTJA_alpha041",
# "GTJA_alpha042",
# "GTJA_alpha043",
# "GTJA_alpha044",
# "GTJA_alpha045",
# "GTJA_alpha046",
# "GTJA_alpha047",
# "GTJA_alpha048",
# "GTJA_alpha049",
# "GTJA_alpha050",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
# "chip_dispersion_90",
# "chip_dispersion_70",
# "cost_skewness",
# "dispersion_change_20",
# "price_to_avg_cost",
# "price_to_median_cost",
# "mean_median_dev",
# "trap_pressure",
# "bottom_profit",
# "history_position",
# "winner_rate_surge_5",
# "winner_rate_cs_rank",
# "winner_rate_dev_20",
# "winner_rate_volatility",
# "smart_money_accumulation",
# "winner_vol_corr_20",
# "cost_base_momentum",
# "bottom_cost_stability",
# "pivot_reversion",
# "chip_transition",
]
# TabM model hyper-parameters (taken from example code supplied by the user).
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.3,  # dropout rate
    # ==================== Ensembling ====================
    "ensemble_size": 32,  # built-in ensemble size (emulates an ensemble of 32 models)
    # ==================== Training ====================
    "batch_size": 1024,  # batch size
    "learning_rate": 1e-3,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 100,  # number of training epochs
    # ==================== Early stopping ====================
    "early_stopping_patience": 30,  # early-stopping patience
}
# Date range of each split, as (start, end) pairs.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output settings handed to the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_regression_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
def main():
    """Run the TabM regression training workflow end to end.

    Steps, in order: create the FactorEngine, FactorManager and DataPipeline,
    wrap ``MODEL_PARAMS`` in a TabMRegressionTask, run the Trainer over the
    module-level ``date_range``, plot the training curve and, when
    ``SAVE_MODEL`` is set, save the model together with the factor metadata
    and the pipeline's fitted processors.

    Returns:
        Whatever ``Trainer.run`` returns for this run.
    """
    print("\n" + "=" * 80)
    print("TabM 回归模型训练")
    print("=" * 80)

    # 1. Create the FactorEngine.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline.
    # Key point from the original author: TabM needs standardized inputs,
    # hence StandardScaler. Processing order:
    # NullFiller -> Winsorizer -> StandardScaler.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            # Winsorize first to tame heavy-tailed distributions.
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label to remove extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Create the TabMRegressionTask.
    print("\n[4] 创建 TabMRegressionTask")
    task = TabMRegressionTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Plot the training curve.
    print("\n[7] 绘制训练曲线")
    task.plot_training_metrics(
        output_path=os.path.join(OUTPUT_DIR, "tabm_training_curve.png")
    )

    # 8. Save the model and factor information (when enabled).
    if SAVE_MODEL:
        print("\n[8] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # Report the path from output_config so the message cannot drift from
    # the filename actually given to the Trainer.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )
    print("\n" + "=" * 80)
    print("TabM 训练流程完成!")
    print(f"结果保存路径: {result_path}")
    print("=" * 80)
    return results
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,425 @@
# %% md
# # TabPFN 回归训练流程
#
# 使用 TabPFN (Prior-Data Fitted Network) 进行回归预测。
# TabPFN 通过上下文学习进行预测,无需传统梯度下降训练过程。
# %% md
# ## 1. 导入依赖
# %%
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabPFNModel
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
)
# 训练类型标识
TRAINING_TYPE = "tabpfn"
# %% md
# ## 2. 训练特定配置
# %%
# Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 排除的因子列表(注意:与 regression.py 的排除列表并不相同,此处排除了绝大部分因子)
EXCLUDED_FACTORS = [
"GTJA_alpha001",
"GTJA_alpha002",
"GTJA_alpha003",
"GTJA_alpha004",
"GTJA_alpha005",
"GTJA_alpha006",
"GTJA_alpha007",
"GTJA_alpha008",
"GTJA_alpha009",
"GTJA_alpha010",
"GTJA_alpha011",
"GTJA_alpha012",
"GTJA_alpha013",
"GTJA_alpha014",
"GTJA_alpha015",
"GTJA_alpha016",
"GTJA_alpha017",
"GTJA_alpha018",
"GTJA_alpha019",
"GTJA_alpha020",
"GTJA_alpha022",
"GTJA_alpha023",
"GTJA_alpha024",
"GTJA_alpha025",
"GTJA_alpha026",
"GTJA_alpha027",
"GTJA_alpha028",
"GTJA_alpha029",
"GTJA_alpha031",
"GTJA_alpha032",
"GTJA_alpha033",
"GTJA_alpha034",
"GTJA_alpha035",
"GTJA_alpha036",
"GTJA_alpha037",
# "GTJA_alpha038",
"GTJA_alpha039",
"GTJA_alpha040",
"GTJA_alpha041",
"GTJA_alpha042",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha045",
"GTJA_alpha046",
"GTJA_alpha047",
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
# "GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
# "GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
# "GTJA_alpha138",
"GTJA_alpha139",
# "GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
# "GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
"GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
# "GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
# "GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
"dispersion_change_20",
"price_to_avg_cost",
"price_to_median_cost",
"mean_median_dev",
"trap_pressure",
"bottom_profit",
"history_position",
"winner_rate_surge_5",
"winner_rate_cs_rank",
"winner_rate_dev_20",
"winner_rate_volatility",
"smart_money_accumulation",
"winner_vol_corr_20",
"cost_base_momentum",
"bottom_cost_stability",
"pivot_reversion",
"chip_transition",
]
# Model parameter configuration.
MODEL_PARAMS = {
    # ==================== Device ====================
    "device": "cuda",  # compute device: "cuda" or "cpu" (default: cuda)
    # ==================== Context limit ====================
    # Original guidance: 1000-3000 is suggested for a 16GB GPU; 5000-8000 may
    # be tried on 32GB.
    # NOTE(review): the value 100 is well below that suggested range —
    # confirm it is intentional.
    "max_context_size": 100,
}
# Date range of each split, as (start, end) pairs.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output settings handed to the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabpfn_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
# %% md
# ## 3. 自定义 TabPFN 任务
# %%
from src.training.tasks import RegressionTask
class TabPFNTask(RegressionTask):
    """Regression task that uses a TabPFNModel as its model.

    Subclasses RegressionTask but swaps in TabPFNModel. Per the original
    author's note, TabPFN predicts through in-context learning and does not
    need a conventional training procedure.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Set up the task with a TabPFNModel.

        Args:
            model_params: TabPFN parameter dictionary.
            label_name: Name of the label column.
        """
        # RegressionTask.__init__ is deliberately not called; the original
        # author's stated reason is to avoid creating a LightGBMModel.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """Fit the TabPFN model on the training data.

        Per the original author's note, "training" here loads the training
        data into the model's context rather than running gradient descent.

        Args:
            train_data: Training data with ``"X"`` and ``"y"`` entries.
            val_data: Validation data; its ``"X"``/``"y"`` entries are passed
                as ``eval_set`` when ``"X"`` is present.
        """
        train_features = train_data["X"]
        train_targets = train_data["y"]
        val_features = val_data.get("X")
        val_targets = val_data.get("y")
        # Only hand over an eval_set when validation features exist.
        if val_features is None:
            validation = None
        else:
            validation = (val_features, val_targets)
        self.model.fit(train_features, train_targets, eval_set=validation)

    def get_model(self) -> TabPFNModel:
        """Return the model instance held by this task."""
        return self.model
# %% md
# ## 4. 主函数
# %%
def main():
    """Run the TabPFN regression workflow end to end.

    Steps, in order: create the FactorEngine, FactorManager and DataPipeline,
    wrap ``MODEL_PARAMS`` in a TabPFNTask, run the Trainer over the
    module-level ``date_range``, save the model with its factor metadata when
    ``SAVE_MODEL`` is set, and print the validation metrics reported by the
    model, if any.

    Returns:
        Whatever ``Trainer.run`` returns for this run.
    """
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning)")
    print(" 训练过程实际是加载数据到模型上下文。")
    print(" 如果训练数据超过上下文限制,会自动截取最近的数据。")

    # 1. Create the FactorEngine.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline.
    # NOTE(review): unlike the other training scripts touched in the same
    # change, no train_skip_days argument is passed here — confirm whether
    # that is intentional.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label to remove extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. Create the TabPFNTask.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor information (when enabled).
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # 8. Final report. The path is read from output_config so the message
    # cannot drift from the filename actually given to the Trainer.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    print(f"结果保存路径: {result_path}")

    # Show validation metrics when the model reports any.
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f" - {metric}: {value:.6f}")
    print("=" * 80)
    return results
if __name__ == "__main__":
main()