feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现 - 新增 DataQualityAnalyzer 进行训练前数据质量诊断 - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性 - 支持 train_skip_days 参数跳过训练初期数据不足期 - Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
@@ -14,7 +14,7 @@ from src.factors import FactorEngine
|
||||
# =============================================================================
|
||||
# 日期范围配置(正确的 train/val/test 三分法)
|
||||
# =============================================================================
|
||||
TRAIN_START = "20200101"
|
||||
TRAIN_START = "20190101"
|
||||
TRAIN_END = "20231231"
|
||||
VAL_START = "20240101"
|
||||
VAL_END = "20241231"
|
||||
@@ -79,8 +79,8 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha002",
|
||||
"GTJA_alpha003",
|
||||
"GTJA_alpha004",
|
||||
"GTJA_alpha005",
|
||||
"GTJA_alpha006",
|
||||
# "GTJA_alpha005",
|
||||
# "GTJA_alpha006",
|
||||
"GTJA_alpha007",
|
||||
"GTJA_alpha008",
|
||||
"GTJA_alpha009",
|
||||
@@ -123,133 +123,133 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha048",
|
||||
"GTJA_alpha049",
|
||||
"GTJA_alpha050",
|
||||
"GTJA_alpha051",
|
||||
"GTJA_alpha052",
|
||||
"GTJA_alpha053",
|
||||
"GTJA_alpha054",
|
||||
"GTJA_alpha056",
|
||||
"GTJA_alpha057",
|
||||
"GTJA_alpha058",
|
||||
"GTJA_alpha059",
|
||||
"GTJA_alpha060",
|
||||
"GTJA_alpha061",
|
||||
"GTJA_alpha062",
|
||||
"GTJA_alpha063",
|
||||
"GTJA_alpha064",
|
||||
"GTJA_alpha065",
|
||||
"GTJA_alpha066",
|
||||
"GTJA_alpha067",
|
||||
"GTJA_alpha068",
|
||||
"GTJA_alpha070",
|
||||
"GTJA_alpha071",
|
||||
"GTJA_alpha072",
|
||||
"GTJA_alpha073",
|
||||
"GTJA_alpha074",
|
||||
"GTJA_alpha076",
|
||||
"GTJA_alpha077",
|
||||
"GTJA_alpha078",
|
||||
"GTJA_alpha079",
|
||||
"GTJA_alpha080",
|
||||
"GTJA_alpha081",
|
||||
"GTJA_alpha082",
|
||||
"GTJA_alpha083",
|
||||
"GTJA_alpha084",
|
||||
"GTJA_alpha085",
|
||||
"GTJA_alpha086",
|
||||
"GTJA_alpha087",
|
||||
"GTJA_alpha088",
|
||||
"GTJA_alpha089",
|
||||
"GTJA_alpha090",
|
||||
"GTJA_alpha091",
|
||||
"GTJA_alpha092",
|
||||
"GTJA_alpha093",
|
||||
"GTJA_alpha094",
|
||||
"GTJA_alpha095",
|
||||
"GTJA_alpha096",
|
||||
"GTJA_alpha097",
|
||||
"GTJA_alpha098",
|
||||
"GTJA_alpha099",
|
||||
"GTJA_alpha100",
|
||||
"GTJA_alpha101",
|
||||
"GTJA_alpha102",
|
||||
"GTJA_alpha103",
|
||||
"GTJA_alpha104",
|
||||
"GTJA_alpha105",
|
||||
"GTJA_alpha106",
|
||||
"GTJA_alpha107",
|
||||
"GTJA_alpha108",
|
||||
"GTJA_alpha109",
|
||||
"GTJA_alpha110",
|
||||
"GTJA_alpha111",
|
||||
"GTJA_alpha112",
|
||||
# "GTJA_alpha113",
|
||||
"GTJA_alpha114",
|
||||
"GTJA_alpha115",
|
||||
"GTJA_alpha117",
|
||||
"GTJA_alpha118",
|
||||
"GTJA_alpha119",
|
||||
"GTJA_alpha120",
|
||||
# "GTJA_alpha121",
|
||||
"GTJA_alpha122",
|
||||
"GTJA_alpha123",
|
||||
"GTJA_alpha124",
|
||||
"GTJA_alpha125",
|
||||
"GTJA_alpha126",
|
||||
"GTJA_alpha127",
|
||||
"GTJA_alpha128",
|
||||
"GTJA_alpha129",
|
||||
"GTJA_alpha130",
|
||||
"GTJA_alpha131",
|
||||
"GTJA_alpha132",
|
||||
"GTJA_alpha133",
|
||||
"GTJA_alpha134",
|
||||
"GTJA_alpha135",
|
||||
"GTJA_alpha136",
|
||||
# "GTJA_alpha138",
|
||||
"GTJA_alpha139",
|
||||
# "GTJA_alpha140",
|
||||
"GTJA_alpha141",
|
||||
"GTJA_alpha142",
|
||||
"GTJA_alpha145",
|
||||
# "GTJA_alpha146",
|
||||
"GTJA_alpha148",
|
||||
"GTJA_alpha150",
|
||||
"GTJA_alpha151",
|
||||
"GTJA_alpha152",
|
||||
"GTJA_alpha153",
|
||||
"GTJA_alpha154",
|
||||
"GTJA_alpha155",
|
||||
"GTJA_alpha156",
|
||||
"GTJA_alpha157",
|
||||
"GTJA_alpha158",
|
||||
"GTJA_alpha159",
|
||||
"GTJA_alpha160",
|
||||
"GTJA_alpha161",
|
||||
"GTJA_alpha162",
|
||||
"GTJA_alpha163",
|
||||
"GTJA_alpha164",
|
||||
# "GTJA_alpha165",
|
||||
"GTJA_alpha166",
|
||||
"GTJA_alpha167",
|
||||
"GTJA_alpha168",
|
||||
"GTJA_alpha169",
|
||||
"GTJA_alpha170",
|
||||
"GTJA_alpha171",
|
||||
"GTJA_alpha173",
|
||||
"GTJA_alpha174",
|
||||
"GTJA_alpha175",
|
||||
"GTJA_alpha176",
|
||||
"GTJA_alpha177",
|
||||
"GTJA_alpha178",
|
||||
"GTJA_alpha179",
|
||||
"GTJA_alpha180",
|
||||
# "GTJA_alpha183",
|
||||
"GTJA_alpha184",
|
||||
"GTJA_alpha185",
|
||||
"GTJA_alpha187",
|
||||
"GTJA_alpha188",
|
||||
"GTJA_alpha189",
|
||||
"GTJA_alpha191",
|
||||
# "GTJA_alpha051",
|
||||
# "GTJA_alpha052",
|
||||
# "GTJA_alpha053",
|
||||
# "GTJA_alpha054",
|
||||
# "GTJA_alpha056",
|
||||
# "GTJA_alpha057",
|
||||
# "GTJA_alpha058",
|
||||
# "GTJA_alpha059",
|
||||
# "GTJA_alpha060",
|
||||
# "GTJA_alpha061",
|
||||
# "GTJA_alpha062",
|
||||
# "GTJA_alpha063",
|
||||
# "GTJA_alpha064",
|
||||
# "GTJA_alpha065",
|
||||
# "GTJA_alpha066",
|
||||
# "GTJA_alpha067",
|
||||
# "GTJA_alpha068",
|
||||
# "GTJA_alpha070",
|
||||
# "GTJA_alpha071",
|
||||
# "GTJA_alpha072",
|
||||
# "GTJA_alpha073",
|
||||
# "GTJA_alpha074",
|
||||
# "GTJA_alpha076",
|
||||
# "GTJA_alpha077",
|
||||
# "GTJA_alpha078",
|
||||
# "GTJA_alpha079",
|
||||
# "GTJA_alpha080",
|
||||
# "GTJA_alpha081",
|
||||
# "GTJA_alpha082",
|
||||
# "GTJA_alpha083",
|
||||
# "GTJA_alpha084",
|
||||
# "GTJA_alpha085",
|
||||
# "GTJA_alpha086",
|
||||
# "GTJA_alpha087",
|
||||
# "GTJA_alpha088",
|
||||
# "GTJA_alpha089",
|
||||
# "GTJA_alpha090",
|
||||
# "GTJA_alpha091",
|
||||
# "GTJA_alpha092",
|
||||
# "GTJA_alpha093",
|
||||
# "GTJA_alpha094",
|
||||
# "GTJA_alpha095",
|
||||
# "GTJA_alpha096",
|
||||
# "GTJA_alpha097",
|
||||
# "GTJA_alpha098",
|
||||
# "GTJA_alpha099",
|
||||
# "GTJA_alpha100",
|
||||
# "GTJA_alpha101",
|
||||
# "GTJA_alpha102",
|
||||
# "GTJA_alpha103",
|
||||
# "GTJA_alpha104",
|
||||
# "GTJA_alpha105",
|
||||
# "GTJA_alpha106",
|
||||
# "GTJA_alpha107",
|
||||
# "GTJA_alpha108",
|
||||
# "GTJA_alpha109",
|
||||
# "GTJA_alpha110",
|
||||
# "GTJA_alpha111",
|
||||
# "GTJA_alpha112",
|
||||
# # "GTJA_alpha113",
|
||||
# "GTJA_alpha114",
|
||||
# "GTJA_alpha115",
|
||||
# "GTJA_alpha117",
|
||||
# "GTJA_alpha118",
|
||||
# "GTJA_alpha119",
|
||||
# "GTJA_alpha120",
|
||||
# # "GTJA_alpha121",
|
||||
# "GTJA_alpha122",
|
||||
# "GTJA_alpha123",
|
||||
# "GTJA_alpha124",
|
||||
# "GTJA_alpha125",
|
||||
# "GTJA_alpha126",
|
||||
# "GTJA_alpha127",
|
||||
# "GTJA_alpha128",
|
||||
# "GTJA_alpha129",
|
||||
# "GTJA_alpha130",
|
||||
# "GTJA_alpha131",
|
||||
# "GTJA_alpha132",
|
||||
# "GTJA_alpha133",
|
||||
# "GTJA_alpha134",
|
||||
# "GTJA_alpha135",
|
||||
# "GTJA_alpha136",
|
||||
# # "GTJA_alpha138",
|
||||
# "GTJA_alpha139",
|
||||
# # "GTJA_alpha140",
|
||||
# "GTJA_alpha141",
|
||||
# "GTJA_alpha142",
|
||||
# "GTJA_alpha145",
|
||||
# # "GTJA_alpha146",
|
||||
# "GTJA_alpha148",
|
||||
# "GTJA_alpha150",
|
||||
# "GTJA_alpha151",
|
||||
# "GTJA_alpha152",
|
||||
# "GTJA_alpha153",
|
||||
# "GTJA_alpha154",
|
||||
# "GTJA_alpha155",
|
||||
# "GTJA_alpha156",
|
||||
# "GTJA_alpha157",
|
||||
# "GTJA_alpha158",
|
||||
# "GTJA_alpha159",
|
||||
# "GTJA_alpha160",
|
||||
# "GTJA_alpha161",
|
||||
# # "GTJA_alpha162",
|
||||
# "GTJA_alpha163",
|
||||
# "GTJA_alpha164",
|
||||
# # "GTJA_alpha165",
|
||||
# "GTJA_alpha166",
|
||||
# "GTJA_alpha167",
|
||||
# "GTJA_alpha168",
|
||||
# "GTJA_alpha169",
|
||||
# "GTJA_alpha170",
|
||||
# "GTJA_alpha171",
|
||||
# "GTJA_alpha173",
|
||||
# "GTJA_alpha174",
|
||||
# "GTJA_alpha175",
|
||||
# "GTJA_alpha176",
|
||||
# "GTJA_alpha177",
|
||||
# "GTJA_alpha178",
|
||||
# "GTJA_alpha179",
|
||||
# "GTJA_alpha180",
|
||||
# # "GTJA_alpha183",
|
||||
# "GTJA_alpha184",
|
||||
# "GTJA_alpha185",
|
||||
# "GTJA_alpha187",
|
||||
# "GTJA_alpha188",
|
||||
# "GTJA_alpha189",
|
||||
# "GTJA_alpha191",
|
||||
"chip_dispersion_90",
|
||||
"chip_dispersion_70",
|
||||
"cost_skewness",
|
||||
@@ -488,6 +488,9 @@ MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
# Top N 配置:每日推荐股票数量
|
||||
TOP_N = 5 # 可调整为 10, 20 等
|
||||
|
||||
# 训练数据跳过天数配置
|
||||
TRAIN_SKIP_DAYS = 300 # 跳过训练数据前252天的数据,避免训练初期数据不足的问题
|
||||
|
||||
|
||||
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
"""生成输出文件路径。
|
||||
|
||||
514
src/experiment/data_quality_analyzer.py
Normal file
514
src/experiment/data_quality_analyzer.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""数据质量分析模块
|
||||
|
||||
提供数据质量检查功能,包括:
|
||||
- 缺失值统计
|
||||
- 零值统计
|
||||
- 按日期检查全空列
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataQualityAnalyzer:
|
||||
"""数据质量分析器
|
||||
|
||||
用于分析训练数据的质量问题,帮助识别数据异常。
|
||||
|
||||
Attributes:
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
verbose: 是否打印详细信息
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_cols: Optional[List[str]] = None,
|
||||
label_col: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""初始化数据质量分析器
|
||||
|
||||
Args:
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
verbose: 是否打印详细信息
|
||||
"""
|
||||
self.feature_cols = feature_cols or []
|
||||
self.label_col = label_col
|
||||
self.verbose = verbose
|
||||
self.analysis_results: Dict[str, Any] = {}
|
||||
|
||||
def set_columns(self, feature_cols: List[str], label_col: str) -> None:
|
||||
"""设置要分析的列
|
||||
|
||||
Args:
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
"""
|
||||
self.feature_cols = feature_cols
|
||||
self.label_col = label_col
|
||||
|
||||
def analyze(
|
||||
self,
|
||||
data: Dict[str, Dict[str, Any]],
|
||||
split_names: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""执行完整的数据质量分析
|
||||
|
||||
Args:
|
||||
data: 数据字典,格式为 {"train": {...}, "val": {...}, "test": {...}}
|
||||
split_names: 要分析的数据划分名称列表,默认为 ["train", "val", "test"]
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
if not split_names:
|
||||
split_names = ["train", "val", "test"]
|
||||
|
||||
if self.verbose:
|
||||
print("\n" + "=" * 80)
|
||||
print("数据质量分析报告")
|
||||
print("=" * 80)
|
||||
|
||||
self.analysis_results = {}
|
||||
|
||||
for split_name in split_names:
|
||||
if split_name not in data:
|
||||
continue
|
||||
|
||||
split_data = data[split_name]
|
||||
raw_df = split_data.get("raw_data")
|
||||
|
||||
if raw_df is None:
|
||||
continue
|
||||
|
||||
if self.verbose:
|
||||
print(f"\n[{split_name.upper()} 数据集]")
|
||||
print("-" * 40)
|
||||
|
||||
split_results = self._analyze_split(raw_df, split_name)
|
||||
self.analysis_results[split_name] = split_results
|
||||
|
||||
if self.verbose:
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
return self.analysis_results
|
||||
|
||||
def _analyze_split(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
split_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""分析单个数据集划分
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
split_name: 划分名称
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
results = {
|
||||
"total_rows": len(df),
|
||||
"total_cols": len(df.columns),
|
||||
"feature_cols": self.feature_cols,
|
||||
"label_col": self.label_col,
|
||||
"null_analysis": {},
|
||||
"zero_analysis": {},
|
||||
"all_null_by_date": {},
|
||||
}
|
||||
|
||||
# 1. 分析特征列的缺失值
|
||||
null_stats = self._analyze_null_values(df, self.feature_cols)
|
||||
results["null_analysis"] = null_stats
|
||||
|
||||
if self.verbose:
|
||||
self._print_null_analysis(null_stats)
|
||||
|
||||
# 2. 分析特征列的零值
|
||||
zero_stats = self._analyze_zero_values(df, self.feature_cols)
|
||||
results["zero_analysis"] = zero_stats
|
||||
|
||||
if self.verbose:
|
||||
self._print_zero_analysis(zero_stats)
|
||||
|
||||
# 3. 检查是否存在某天某列全为空的情况
|
||||
all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
|
||||
results["all_null_by_date"] = all_null_by_date
|
||||
|
||||
if self.verbose:
|
||||
self._print_all_null_by_date(all_null_by_date)
|
||||
|
||||
# 4. 分析标签列
|
||||
if self.label_col and self.label_col in df.columns:
|
||||
label_stats = self._analyze_label(df, self.label_col)
|
||||
results["label_analysis"] = label_stats
|
||||
|
||||
if self.verbose:
|
||||
self._print_label_analysis(label_stats)
|
||||
|
||||
return results
|
||||
|
||||
def _analyze_null_values(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
cols: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""分析缺失值
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
cols: 要分析的列名列表
|
||||
|
||||
Returns:
|
||||
缺失值统计字典
|
||||
"""
|
||||
stats = {
|
||||
"total_cells": len(df) * len(cols),
|
||||
"null_counts": {},
|
||||
"null_percentages": {},
|
||||
"columns_with_null": [],
|
||||
"total_null_cells": 0,
|
||||
}
|
||||
|
||||
for col in cols:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
|
||||
null_count = df[col].null_count()
|
||||
if null_count > 0:
|
||||
null_pct = null_count / len(df) * 100
|
||||
stats["null_counts"][col] = null_count
|
||||
stats["null_percentages"][col] = null_pct
|
||||
stats["columns_with_null"].append(col)
|
||||
stats["total_null_cells"] += null_count
|
||||
|
||||
return stats
|
||||
|
||||
def _analyze_zero_values(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
cols: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""分析零值
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
cols: 要分析的列名列表
|
||||
|
||||
Returns:
|
||||
零值统计字典
|
||||
"""
|
||||
stats = {
|
||||
"total_cells": len(df) * len(cols),
|
||||
"zero_counts": {},
|
||||
"zero_percentages": {},
|
||||
"columns_with_zero": [],
|
||||
"total_zero_cells": 0,
|
||||
}
|
||||
|
||||
for col in cols:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
|
||||
# 计算零值数量(排除空值)
|
||||
non_null_series = df[col].drop_nulls()
|
||||
if len(non_null_series) == 0:
|
||||
continue
|
||||
|
||||
zero_count = (non_null_series == 0).sum()
|
||||
if zero_count > 0:
|
||||
zero_pct = zero_count / len(df) * 100
|
||||
stats["zero_counts"][col] = int(zero_count)
|
||||
stats["zero_percentages"][col] = zero_pct
|
||||
stats["columns_with_zero"].append(col)
|
||||
stats["total_zero_cells"] += int(zero_count)
|
||||
|
||||
return stats
|
||||
|
||||
def _check_all_null_by_date(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
cols: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""检查是否存在某天某列全为空的情况
|
||||
|
||||
使用 polars lazy frame 进行内存安全的高效计算。
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
cols: 要分析的列名列表
|
||||
|
||||
Returns:
|
||||
全空检查结果字典
|
||||
"""
|
||||
results = {
|
||||
"issues_found": False,
|
||||
"issues": [],
|
||||
}
|
||||
|
||||
if "trade_date" not in df.columns:
|
||||
return results
|
||||
|
||||
# 过滤掉不在表中的列
|
||||
valid_cols = [c for c in cols if c in df.columns]
|
||||
if not valid_cols:
|
||||
return results
|
||||
|
||||
# 使用 lazy frame 进行查询优化
|
||||
lf = df.lazy()
|
||||
|
||||
# 核心步骤:只计算 null_count 和总行数 (聚合后数据量极小)
|
||||
# 为每个列创建单独的 null_count 聚合表达式
|
||||
agg_exprs = [
|
||||
pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
|
||||
]
|
||||
agg_exprs.append(pl.len().alias("total_rows"))
|
||||
|
||||
agg_lf = lf.group_by("trade_date").agg(agg_exprs)
|
||||
|
||||
# 收集结果 (此时 agg_df 行数通常只有几百到几千行)
|
||||
agg_df = agg_lf.collect()
|
||||
|
||||
# 在这个已经"脱水"的小表上进行逻辑检查
|
||||
issues = []
|
||||
for col in valid_cols:
|
||||
null_col = f"{col}_nulls"
|
||||
# 找出 null 数量等于总行数的日期
|
||||
bad_dates = agg_df.filter(
|
||||
(pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
|
||||
).select(["trade_date", "total_rows"])
|
||||
|
||||
if not bad_dates.is_empty():
|
||||
for row in bad_dates.to_dicts():
|
||||
issues.append(
|
||||
{
|
||||
"date": row["trade_date"],
|
||||
"column": col,
|
||||
"total_rows": row["total_rows"],
|
||||
}
|
||||
)
|
||||
|
||||
if issues:
|
||||
results["issues_found"] = True
|
||||
results["issues"] = issues
|
||||
|
||||
return results
|
||||
|
||||
def _analyze_label(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
label_col: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""分析标签列
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
label_col: 标签列名
|
||||
|
||||
Returns:
|
||||
标签分析字典
|
||||
"""
|
||||
stats = {
|
||||
"total_count": len(df),
|
||||
"null_count": 0,
|
||||
"null_percentage": 0.0,
|
||||
"zero_count": 0,
|
||||
"zero_percentage": 0.0,
|
||||
"min": None,
|
||||
"max": None,
|
||||
"mean": None,
|
||||
"std": None,
|
||||
}
|
||||
|
||||
if label_col not in df.columns:
|
||||
return stats
|
||||
|
||||
series = df[label_col]
|
||||
|
||||
# 缺失值统计
|
||||
null_count = series.null_count()
|
||||
stats["null_count"] = null_count
|
||||
stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0
|
||||
|
||||
# 零值统计
|
||||
non_null_series = series.drop_nulls()
|
||||
if len(non_null_series) > 0:
|
||||
zero_count = (non_null_series == 0).sum()
|
||||
stats["zero_count"] = int(zero_count)
|
||||
stats["zero_percentage"] = zero_count / len(df) * 100
|
||||
|
||||
# 基本统计量
|
||||
stats["min"] = float(non_null_series.min())
|
||||
stats["max"] = float(non_null_series.max())
|
||||
stats["mean"] = float(non_null_series.mean())
|
||||
stats["std"] = float(non_null_series.std())
|
||||
|
||||
return stats
|
||||
|
||||
def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
|
||||
"""打印缺失值分析结果
|
||||
|
||||
Args:
|
||||
stats: 缺失值统计字典
|
||||
"""
|
||||
total_cells = stats["total_cells"]
|
||||
total_null = stats["total_null_cells"]
|
||||
null_cols = stats["columns_with_null"]
|
||||
|
||||
print(f" 缺失值统计:")
|
||||
print(f" 总单元格数: {total_cells:,}")
|
||||
print(
|
||||
f" 缺失单元格数: {total_null:,} ({total_null / total_cells * 100:.2f}%)"
|
||||
)
|
||||
print(f" 有缺失值的列数: {len(null_cols)}/{len(self.feature_cols)}")
|
||||
|
||||
if null_cols:
|
||||
print(f" 缺失值最多的5个特征:")
|
||||
sorted_cols = sorted(
|
||||
stats["null_counts"].items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
for col, count in sorted_cols:
|
||||
pct = stats["null_percentages"][col]
|
||||
print(f" {col}: {count:,} ({pct:.2f}%)")
|
||||
|
||||
def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
|
||||
"""打印零值分析结果
|
||||
|
||||
Args:
|
||||
stats: 零值统计字典
|
||||
"""
|
||||
total_cells = stats["total_cells"]
|
||||
total_zero = stats["total_zero_cells"]
|
||||
zero_cols = stats["columns_with_zero"]
|
||||
|
||||
print(f" 零值统计:")
|
||||
print(f" 总单元格数: {total_cells:,}")
|
||||
print(
|
||||
f" 零值单元格数: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)"
|
||||
)
|
||||
print(f" 有零值的列数: {len(zero_cols)}/{len(self.feature_cols)}")
|
||||
|
||||
if zero_cols:
|
||||
print(f" 零值最多的5个特征:")
|
||||
sorted_cols = sorted(
|
||||
stats["zero_counts"].items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
for col, count in sorted_cols:
|
||||
pct = stats["zero_percentages"][col]
|
||||
print(f" {col}: {count:,} ({pct:.2f}%)")
|
||||
|
||||
def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
|
||||
"""打印按日期全空检查结果
|
||||
|
||||
Args:
|
||||
results: 全空检查结果字典
|
||||
"""
|
||||
issues = results["issues"]
|
||||
|
||||
print(f" 按日期全空检查:")
|
||||
if results["issues_found"]:
|
||||
print(f" [警告] 发现 {len(issues)} 个问题:")
|
||||
# 按日期分组显示
|
||||
by_date = {}
|
||||
for issue in issues:
|
||||
date = issue["date"]
|
||||
if date not in by_date:
|
||||
by_date[date] = []
|
||||
by_date[date].append(issue["column"])
|
||||
|
||||
for date in sorted(by_date.keys())[:5]: # 只显示前5个日期
|
||||
cols = by_date[date]
|
||||
print(f" 日期 {date}: {len(cols)} 列全为空")
|
||||
if len(cols) <= 3:
|
||||
print(f" 列名: {', '.join(cols)}")
|
||||
|
||||
if len(by_date) > 5:
|
||||
print(f" ... 还有 {len(by_date) - 5} 个日期存在问题")
|
||||
else:
|
||||
print(f" [正常] 未发现某天某列全为空的情况")
|
||||
|
||||
def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
|
||||
"""打印标签分析结果
|
||||
|
||||
Args:
|
||||
stats: 标签分析字典
|
||||
"""
|
||||
print(f" 标签列统计 ({self.label_col}):")
|
||||
print(f" 总数: {stats['total_count']:,}")
|
||||
print(f" 缺失值: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
|
||||
print(f" 零值: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")
|
||||
|
||||
if stats["mean"] is not None:
|
||||
print(f" 最小值: {stats['min']:.6f}")
|
||||
print(f" 最大值: {stats['max']:.6f}")
|
||||
print(f" 均值: {stats['mean']:.6f}")
|
||||
print(f" 标准差: {stats['std']:.6f}")
|
||||
|
||||
def get_summary(self) -> str:
|
||||
"""获取分析结果摘要
|
||||
|
||||
Returns:
|
||||
摘要字符串
|
||||
"""
|
||||
if not self.analysis_results:
|
||||
return "尚未执行分析"
|
||||
|
||||
lines = ["数据质量分析摘要", "=" * 40]
|
||||
|
||||
for split_name, results in self.analysis_results.items():
|
||||
lines.append(f"\n[{split_name.upper()}]")
|
||||
lines.append(f" 总行数: {results['total_rows']:,}")
|
||||
|
||||
null_stats = results.get("null_analysis", {})
|
||||
if null_stats.get("columns_with_null"):
|
||||
lines.append(
|
||||
f" 缺失值: {null_stats['total_null_cells']:,} 个单元格, "
|
||||
f"{len(null_stats['columns_with_null'])} 列受影响"
|
||||
)
|
||||
|
||||
zero_stats = results.get("zero_analysis", {})
|
||||
if zero_stats.get("columns_with_zero"):
|
||||
lines.append(
|
||||
f" 零值: {zero_stats['total_zero_cells']:,} 个单元格, "
|
||||
f"{len(zero_stats['columns_with_zero'])} 列受影响"
|
||||
)
|
||||
|
||||
all_null = results.get("all_null_by_date", {})
|
||||
if all_null.get("issues_found"):
|
||||
lines.append(
|
||||
f" [警告] 发现 {len(all_null['issues'])} 个日期列全空问题"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def analyze_data_quality(
|
||||
data: Dict[str, Dict[str, Any]],
|
||||
feature_cols: Optional[List[str]] = None,
|
||||
label_col: Optional[str] = None,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""便捷函数:执行数据质量分析
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
verbose: 是否打印详细信息
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
analyzer = DataQualityAnalyzer(
|
||||
feature_cols=feature_cols,
|
||||
label_col=label_col,
|
||||
verbose=verbose,
|
||||
)
|
||||
return analyzer.analyze(data)
|
||||
@@ -37,6 +37,7 @@ from src.experiment.common import (
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
TRAIN_SKIP_DAYS,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
@@ -155,6 +156,7 @@ def main():
|
||||
filters=[STFilter(data_router=engine.router)],
|
||||
stock_pool_filter_func=stock_pool_filter,
|
||||
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
train_skip_days=TRAIN_SKIP_DAYS,
|
||||
)
|
||||
|
||||
# 4. 创建 RankTask
|
||||
|
||||
@@ -38,6 +38,7 @@ from src.experiment.common import (
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
TRAIN_SKIP_DAYS,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
@@ -51,55 +52,55 @@ TRAINING_TYPE = "regression"
|
||||
|
||||
# 排除的因子列表
|
||||
EXCLUDED_FACTORS = [
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha032',
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha005',
|
||||
'CP',
|
||||
'BP',
|
||||
'debt_to_equity',
|
||||
'current_ratio',
|
||||
'GTJA_alpha002',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha064',
|
||||
'GTJA_alpha062',
|
||||
'GTJA_alpha043',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha120',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha105',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha077',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha090',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha083',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha126',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha164',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha177',
|
||||
'price_to_avg_cost',
|
||||
'cost_skewness',
|
||||
'GTJA_alpha191',
|
||||
'GTJA_alpha180',
|
||||
'history_position',
|
||||
'bottom_profit',
|
||||
'mean_median_dev',
|
||||
'smart_money_accumulation',
|
||||
'GTJA_alpha013',
|
||||
'GTJA_alpha099',
|
||||
'GTJA_alpha107',
|
||||
'GTJA_alpha119',
|
||||
'GTJA_alpha141',
|
||||
'GTJA_alpha130',
|
||||
'GTJA_alpha173',
|
||||
"GTJA_alpha036",
|
||||
"GTJA_alpha032",
|
||||
"GTJA_alpha010",
|
||||
"GTJA_alpha005",
|
||||
"CP",
|
||||
"BP",
|
||||
"debt_to_equity",
|
||||
"current_ratio",
|
||||
"GTJA_alpha002",
|
||||
"GTJA_alpha027",
|
||||
"GTJA_alpha064",
|
||||
"GTJA_alpha062",
|
||||
"GTJA_alpha043",
|
||||
"GTJA_alpha044",
|
||||
"GTJA_alpha120",
|
||||
"GTJA_alpha117",
|
||||
"GTJA_alpha103",
|
||||
"GTJA_alpha104",
|
||||
"GTJA_alpha105",
|
||||
"GTJA_alpha073",
|
||||
"GTJA_alpha077",
|
||||
"GTJA_alpha085",
|
||||
"GTJA_alpha090",
|
||||
"GTJA_alpha087",
|
||||
"GTJA_alpha083",
|
||||
"GTJA_alpha092",
|
||||
"GTJA_alpha133",
|
||||
"GTJA_alpha131",
|
||||
"GTJA_alpha126",
|
||||
"GTJA_alpha124",
|
||||
"GTJA_alpha162",
|
||||
"GTJA_alpha164",
|
||||
"GTJA_alpha157",
|
||||
"GTJA_alpha177",
|
||||
"price_to_avg_cost",
|
||||
"cost_skewness",
|
||||
"GTJA_alpha191",
|
||||
"GTJA_alpha180",
|
||||
"history_position",
|
||||
"bottom_profit",
|
||||
"mean_median_dev",
|
||||
"smart_money_accumulation",
|
||||
"GTJA_alpha013",
|
||||
"GTJA_alpha099",
|
||||
"GTJA_alpha107",
|
||||
"GTJA_alpha119",
|
||||
"GTJA_alpha141",
|
||||
"GTJA_alpha130",
|
||||
"GTJA_alpha173",
|
||||
]
|
||||
|
||||
# 模型参数配置
|
||||
@@ -184,6 +185,7 @@ def main():
|
||||
filters=[STFilter(data_router=engine.router)],
|
||||
stock_pool_filter_func=stock_pool_filter,
|
||||
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
train_skip_days=TRAIN_SKIP_DAYS,
|
||||
)
|
||||
|
||||
# 4. 创建 RegressionTask
|
||||
|
||||
375
src/experiment/tabm_regression.py
Normal file
375
src/experiment/tabm_regression.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# %% md
|
||||
# # TabM 回归训练流程
|
||||
#
|
||||
# 使用 TabM (Tabular MLP with Ensembles) 模型进行回归训练。
|
||||
# TabM 通过内置集成机制(ensemble_size=32)实现高效的多模型集成。
|
||||
# %% md
|
||||
# ## 1. 导入依赖
|
||||
# %%
|
||||
import os
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
FactorManager,
|
||||
DataPipeline,
|
||||
TabMRegressionTask,
|
||||
NullFiller,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
CrossSectionalStandardScaler,
|
||||
)
|
||||
from src.training.core.trainer_v2 import Trainer
|
||||
from src.training.components.filters import STFilter
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
LABEL_NAME,
|
||||
LABEL_FACTOR,
|
||||
TRAIN_START,
|
||||
TRAIN_END,
|
||||
VAL_START,
|
||||
VAL_END,
|
||||
TEST_START,
|
||||
TEST_END,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
TRAIN_SKIP_DAYS,
|
||||
)
|
||||
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "tabm_regression"
|
||||
|
||||
# %% md
|
||||
# ## 2. 训练特定配置
|
||||
# %%
|
||||
# Label 配置(从 common.py 统一导入)
|
||||
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
|
||||
|
||||
# 排除的因子列表(与 LightGBM 回归保持一致)
|
||||
EXCLUDED_FACTORS = [
|
||||
# "GTJA_alpha001",
|
||||
# "GTJA_alpha002",
|
||||
# "GTJA_alpha003",
|
||||
# "GTJA_alpha004",
|
||||
# "GTJA_alpha005",
|
||||
# "GTJA_alpha006",
|
||||
# "GTJA_alpha007",
|
||||
# "GTJA_alpha008",
|
||||
# "GTJA_alpha009",
|
||||
# "GTJA_alpha010",
|
||||
# "GTJA_alpha011",
|
||||
# "GTJA_alpha012",
|
||||
# "GTJA_alpha013",
|
||||
# "GTJA_alpha014",
|
||||
# "GTJA_alpha015",
|
||||
# "GTJA_alpha016",
|
||||
# "GTJA_alpha017",
|
||||
# "GTJA_alpha018",
|
||||
# "GTJA_alpha019",
|
||||
# "GTJA_alpha020",
|
||||
# "GTJA_alpha022",
|
||||
# "GTJA_alpha023",
|
||||
# "GTJA_alpha024",
|
||||
# "GTJA_alpha025",
|
||||
# "GTJA_alpha026",
|
||||
# "GTJA_alpha027",
|
||||
# "GTJA_alpha028",
|
||||
# "GTJA_alpha029",
|
||||
# "GTJA_alpha031",
|
||||
# "GTJA_alpha032",
|
||||
# "GTJA_alpha033",
|
||||
# "GTJA_alpha034",
|
||||
# "GTJA_alpha035",
|
||||
# "GTJA_alpha036",
|
||||
# "GTJA_alpha037",
|
||||
# # "GTJA_alpha038",
|
||||
# "GTJA_alpha039",
|
||||
# "GTJA_alpha040",
|
||||
# "GTJA_alpha041",
|
||||
# "GTJA_alpha042",
|
||||
# "GTJA_alpha043",
|
||||
# "GTJA_alpha044",
|
||||
# "GTJA_alpha045",
|
||||
# "GTJA_alpha046",
|
||||
# "GTJA_alpha047",
|
||||
# "GTJA_alpha048",
|
||||
# "GTJA_alpha049",
|
||||
# "GTJA_alpha050",
|
||||
# "GTJA_alpha051",
|
||||
# "GTJA_alpha052",
|
||||
# "GTJA_alpha053",
|
||||
# "GTJA_alpha054",
|
||||
# "GTJA_alpha056",
|
||||
# "GTJA_alpha057",
|
||||
# "GTJA_alpha058",
|
||||
# "GTJA_alpha059",
|
||||
# "GTJA_alpha060",
|
||||
# "GTJA_alpha061",
|
||||
# "GTJA_alpha062",
|
||||
# "GTJA_alpha063",
|
||||
# "GTJA_alpha064",
|
||||
# "GTJA_alpha065",
|
||||
# "GTJA_alpha066",
|
||||
# "GTJA_alpha067",
|
||||
# "GTJA_alpha068",
|
||||
# "GTJA_alpha070",
|
||||
# "GTJA_alpha071",
|
||||
# "GTJA_alpha072",
|
||||
# "GTJA_alpha073",
|
||||
# "GTJA_alpha074",
|
||||
# "GTJA_alpha076",
|
||||
# "GTJA_alpha077",
|
||||
# "GTJA_alpha078",
|
||||
# "GTJA_alpha079",
|
||||
# "GTJA_alpha080",
|
||||
# "GTJA_alpha081",
|
||||
# "GTJA_alpha082",
|
||||
# "GTJA_alpha083",
|
||||
# "GTJA_alpha084",
|
||||
# "GTJA_alpha085",
|
||||
# "GTJA_alpha086",
|
||||
# "GTJA_alpha087",
|
||||
# "GTJA_alpha088",
|
||||
# "GTJA_alpha089",
|
||||
# "GTJA_alpha090",
|
||||
# "GTJA_alpha091",
|
||||
# "GTJA_alpha092",
|
||||
# "GTJA_alpha093",
|
||||
# "GTJA_alpha094",
|
||||
# "GTJA_alpha095",
|
||||
# "GTJA_alpha096",
|
||||
# "GTJA_alpha097",
|
||||
# "GTJA_alpha098",
|
||||
# "GTJA_alpha099",
|
||||
# "GTJA_alpha100",
|
||||
# "GTJA_alpha101",
|
||||
# "GTJA_alpha102",
|
||||
# "GTJA_alpha103",
|
||||
# "GTJA_alpha104",
|
||||
# "GTJA_alpha105",
|
||||
# "GTJA_alpha106",
|
||||
# "GTJA_alpha107",
|
||||
# "GTJA_alpha108",
|
||||
# "GTJA_alpha109",
|
||||
# "GTJA_alpha110",
|
||||
# "GTJA_alpha111",
|
||||
# "GTJA_alpha112",
|
||||
# # "GTJA_alpha113",
|
||||
# "GTJA_alpha114",
|
||||
# "GTJA_alpha115",
|
||||
# "GTJA_alpha117",
|
||||
# "GTJA_alpha118",
|
||||
# "GTJA_alpha119",
|
||||
# "GTJA_alpha120",
|
||||
# # "GTJA_alpha121",
|
||||
# "GTJA_alpha122",
|
||||
# "GTJA_alpha123",
|
||||
# "GTJA_alpha124",
|
||||
# "GTJA_alpha125",
|
||||
# "GTJA_alpha126",
|
||||
# "GTJA_alpha127",
|
||||
# "GTJA_alpha128",
|
||||
# "GTJA_alpha129",
|
||||
# "GTJA_alpha130",
|
||||
# "GTJA_alpha131",
|
||||
# "GTJA_alpha132",
|
||||
# "GTJA_alpha133",
|
||||
# "GTJA_alpha134",
|
||||
# "GTJA_alpha135",
|
||||
# "GTJA_alpha136",
|
||||
# # "GTJA_alpha138",
|
||||
# "GTJA_alpha139",
|
||||
# # "GTJA_alpha140",
|
||||
# "GTJA_alpha141",
|
||||
# "GTJA_alpha142",
|
||||
# "GTJA_alpha145",
|
||||
# # "GTJA_alpha146",
|
||||
# "GTJA_alpha148",
|
||||
# "GTJA_alpha150",
|
||||
# "GTJA_alpha151",
|
||||
# "GTJA_alpha152",
|
||||
# "GTJA_alpha153",
|
||||
# "GTJA_alpha154",
|
||||
# "GTJA_alpha155",
|
||||
# "GTJA_alpha156",
|
||||
# "GTJA_alpha157",
|
||||
# "GTJA_alpha158",
|
||||
# "GTJA_alpha159",
|
||||
# "GTJA_alpha160",
|
||||
# "GTJA_alpha161",
|
||||
# "GTJA_alpha162",
|
||||
# "GTJA_alpha163",
|
||||
# "GTJA_alpha164",
|
||||
# # "GTJA_alpha165",
|
||||
# "GTJA_alpha166",
|
||||
# "GTJA_alpha167",
|
||||
# "GTJA_alpha168",
|
||||
# "GTJA_alpha169",
|
||||
# "GTJA_alpha170",
|
||||
# "GTJA_alpha171",
|
||||
# "GTJA_alpha173",
|
||||
# "GTJA_alpha174",
|
||||
# "GTJA_alpha175",
|
||||
# "GTJA_alpha176",
|
||||
# "GTJA_alpha177",
|
||||
# "GTJA_alpha178",
|
||||
# "GTJA_alpha179",
|
||||
# "GTJA_alpha180",
|
||||
# # "GTJA_alpha183",
|
||||
# "GTJA_alpha184",
|
||||
# "GTJA_alpha185",
|
||||
# "GTJA_alpha187",
|
||||
# "GTJA_alpha188",
|
||||
# "GTJA_alpha189",
|
||||
# "GTJA_alpha191",
|
||||
# "chip_dispersion_90",
|
||||
# "chip_dispersion_70",
|
||||
# "cost_skewness",
|
||||
# "dispersion_change_20",
|
||||
# "price_to_avg_cost",
|
||||
# "price_to_median_cost",
|
||||
# "mean_median_dev",
|
||||
# "trap_pressure",
|
||||
# "bottom_profit",
|
||||
# "history_position",
|
||||
# "winner_rate_surge_5",
|
||||
# "winner_rate_cs_rank",
|
||||
# "winner_rate_dev_20",
|
||||
# "winner_rate_volatility",
|
||||
# "smart_money_accumulation",
|
||||
# "winner_vol_corr_20",
|
||||
# "cost_base_momentum",
|
||||
# "bottom_cost_stability",
|
||||
# "pivot_reversion",
|
||||
# "chip_transition",
|
||||
]
|
||||
|
||||
# TabM 模型参数配置(来自用户提供的示例代码)
|
||||
MODEL_PARAMS = {
|
||||
# ==================== MLP 结构 ====================
|
||||
"n_blocks": 3, # MLP 层数
|
||||
"d_block": 256, # 每层神经元数
|
||||
"dropout": 0.3, # Dropout 率
|
||||
# ==================== 集成机制 ====================
|
||||
"ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成)
|
||||
# ==================== 训练参数 ====================
|
||||
"batch_size": 1024, # 批次大小
|
||||
"learning_rate": 1e-3, # 学习率
|
||||
"weight_decay": 1e-5, # 权重衰减
|
||||
"epochs": 100, # 训练轮数
|
||||
# ==================== 早停 ====================
|
||||
"early_stopping_patience": 30, # 早停耐心值
|
||||
}
|
||||
|
||||
# 日期范围配置
|
||||
date_range = {
|
||||
"train": (TRAIN_START, TRAIN_END),
|
||||
"val": (VAL_START, VAL_END),
|
||||
"test": (TEST_START, TEST_END),
|
||||
}
|
||||
|
||||
# 输出配置
|
||||
output_config = {
|
||||
"output_dir": OUTPUT_DIR,
|
||||
"output_filename": "tabm_regression_output.csv",
|
||||
"save_predictions": SAVE_PREDICTIONS,
|
||||
"save_model": SAVE_MODEL,
|
||||
"model_save_path": get_model_save_path(TRAINING_TYPE),
|
||||
"top_n": TOP_N,
|
||||
}
|
||||
|
||||
|
||||
def main():
    """Entry point: run the full TabM regression training workflow.

    Steps: build the factor engine and manager, assemble the data
    pipeline, train a TabM regression task, plot the training curves,
    and optionally persist the model together with its factor metadata.

    Returns:
        The results object produced by ``Trainer.run``.
    """
    print("\n" + "=" * 80)
    print("TabM 回归模型训练")
    print("=" * 80)

    # 1. Create the FactorEngine (data access / factor computation backend).
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager resolving selected/excluded factor sets.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline.
    # IMPORTANT: TabM needs standardized inputs, hence StandardScaler.
    # Processing order: NullFiller -> Winsorizer -> StandardScaler.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),  # winsorize heavy tails first
            (StandardScaler, {}),  # TabM requires standardized inputs
        ],
        label_processor_configs=[
            # Winsorize the label as well (trim extreme returns).
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Create the TabMRegressionTask.
    print("\n[4] 创建 TabMRegressionTask")
    task = TabMRegressionTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Plot training curves.
    print("\n[7] 绘制训练曲线")
    task.plot_training_metrics(
        output_path=os.path.join(OUTPUT_DIR, "tabm_training_curve.png")
    )

    # 8. Persist the model and its factor metadata (when enabled).
    if SAVE_MODEL:
        print("\n[8] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    print("\n" + "=" * 80)
    print("TabM 训练流程完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_regression_output.csv')}")
    print("=" * 80)

    return results


if __name__ == "__main__":
    main()
|
||||
425
src/experiment/tabpfn_regression.py
Normal file
425
src/experiment/tabpfn_regression.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# %% md
|
||||
# # TabPFN 回归训练流程
|
||||
#
|
||||
# 使用 TabPFN (Prior-Data Fitted Network) 进行回归预测。
|
||||
# TabPFN 通过上下文学习进行预测,无需传统梯度下降训练过程。
|
||||
# %% md
|
||||
# ## 1. 导入依赖
|
||||
# %%
|
||||
import os
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
FactorManager,
|
||||
DataPipeline,
|
||||
NullFiller,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
CrossSectionalStandardScaler,
|
||||
)
|
||||
from src.training.core.trainer_v2 import Trainer
|
||||
from src.training.components.filters import STFilter
|
||||
from src.training.components.models import TabPFNModel
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
LABEL_NAME,
|
||||
LABEL_FACTOR,
|
||||
TRAIN_START,
|
||||
TRAIN_END,
|
||||
VAL_START,
|
||||
VAL_END,
|
||||
TEST_START,
|
||||
TEST_END,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "tabpfn"
|
||||
|
||||
# %% md
|
||||
# ## 2. 训练特定配置
|
||||
# %%
|
||||
# Label 配置(从 common.py 统一导入)
|
||||
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
|
||||
|
||||
# Factors excluded from training (kept in sync with regression.py).
# NOTE: entries commented out below (alpha038/113/121/138/140/146/165/183)
# are NOT excluded, i.e. they remain active model inputs.
EXCLUDED_FACTORS = [
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
    "GTJA_alpha011",
    "GTJA_alpha012",
    "GTJA_alpha013",
    "GTJA_alpha014",
    "GTJA_alpha015",
    "GTJA_alpha016",
    "GTJA_alpha017",
    "GTJA_alpha018",
    "GTJA_alpha019",
    "GTJA_alpha020",
    "GTJA_alpha022",
    "GTJA_alpha023",
    "GTJA_alpha024",
    "GTJA_alpha025",
    "GTJA_alpha026",
    "GTJA_alpha027",
    "GTJA_alpha028",
    "GTJA_alpha029",
    "GTJA_alpha031",
    "GTJA_alpha032",
    "GTJA_alpha033",
    "GTJA_alpha034",
    "GTJA_alpha035",
    "GTJA_alpha036",
    "GTJA_alpha037",
    # "GTJA_alpha038",
    "GTJA_alpha039",
    "GTJA_alpha040",
    "GTJA_alpha041",
    "GTJA_alpha042",
    "GTJA_alpha043",
    "GTJA_alpha044",
    "GTJA_alpha045",
    "GTJA_alpha046",
    "GTJA_alpha047",
    "GTJA_alpha048",
    "GTJA_alpha049",
    "GTJA_alpha050",
    "GTJA_alpha051",
    "GTJA_alpha052",
    "GTJA_alpha053",
    "GTJA_alpha054",
    "GTJA_alpha056",
    "GTJA_alpha057",
    "GTJA_alpha058",
    "GTJA_alpha059",
    "GTJA_alpha060",
    "GTJA_alpha061",
    "GTJA_alpha062",
    "GTJA_alpha063",
    "GTJA_alpha064",
    "GTJA_alpha065",
    "GTJA_alpha066",
    "GTJA_alpha067",
    "GTJA_alpha068",
    "GTJA_alpha070",
    "GTJA_alpha071",
    "GTJA_alpha072",
    "GTJA_alpha073",
    "GTJA_alpha074",
    "GTJA_alpha076",
    "GTJA_alpha077",
    "GTJA_alpha078",
    "GTJA_alpha079",
    "GTJA_alpha080",
    "GTJA_alpha081",
    "GTJA_alpha082",
    "GTJA_alpha083",
    "GTJA_alpha084",
    "GTJA_alpha085",
    "GTJA_alpha086",
    "GTJA_alpha087",
    "GTJA_alpha088",
    "GTJA_alpha089",
    "GTJA_alpha090",
    "GTJA_alpha091",
    "GTJA_alpha092",
    "GTJA_alpha093",
    "GTJA_alpha094",
    "GTJA_alpha095",
    "GTJA_alpha096",
    "GTJA_alpha097",
    "GTJA_alpha098",
    "GTJA_alpha099",
    "GTJA_alpha100",
    "GTJA_alpha101",
    "GTJA_alpha102",
    "GTJA_alpha103",
    "GTJA_alpha104",
    "GTJA_alpha105",
    "GTJA_alpha106",
    "GTJA_alpha107",
    "GTJA_alpha108",
    "GTJA_alpha109",
    "GTJA_alpha110",
    "GTJA_alpha111",
    "GTJA_alpha112",
    # "GTJA_alpha113",
    "GTJA_alpha114",
    "GTJA_alpha115",
    "GTJA_alpha117",
    "GTJA_alpha118",
    "GTJA_alpha119",
    "GTJA_alpha120",
    # "GTJA_alpha121",
    "GTJA_alpha122",
    "GTJA_alpha123",
    "GTJA_alpha124",
    "GTJA_alpha125",
    "GTJA_alpha126",
    "GTJA_alpha127",
    "GTJA_alpha128",
    "GTJA_alpha129",
    "GTJA_alpha130",
    "GTJA_alpha131",
    "GTJA_alpha132",
    "GTJA_alpha133",
    "GTJA_alpha134",
    "GTJA_alpha135",
    "GTJA_alpha136",
    # "GTJA_alpha138",
    "GTJA_alpha139",
    # "GTJA_alpha140",
    "GTJA_alpha141",
    "GTJA_alpha142",
    "GTJA_alpha145",
    # "GTJA_alpha146",
    "GTJA_alpha148",
    "GTJA_alpha150",
    "GTJA_alpha151",
    "GTJA_alpha152",
    "GTJA_alpha153",
    "GTJA_alpha154",
    "GTJA_alpha155",
    "GTJA_alpha156",
    "GTJA_alpha157",
    "GTJA_alpha158",
    "GTJA_alpha159",
    "GTJA_alpha160",
    "GTJA_alpha161",
    "GTJA_alpha162",
    "GTJA_alpha163",
    "GTJA_alpha164",
    # "GTJA_alpha165",
    "GTJA_alpha166",
    "GTJA_alpha167",
    "GTJA_alpha168",
    "GTJA_alpha169",
    "GTJA_alpha170",
    "GTJA_alpha171",
    "GTJA_alpha173",
    "GTJA_alpha174",
    "GTJA_alpha175",
    "GTJA_alpha176",
    "GTJA_alpha177",
    "GTJA_alpha178",
    "GTJA_alpha179",
    "GTJA_alpha180",
    # "GTJA_alpha183",
    "GTJA_alpha184",
    "GTJA_alpha185",
    "GTJA_alpha187",
    "GTJA_alpha188",
    "GTJA_alpha189",
    "GTJA_alpha191",
    "chip_dispersion_90",
    "chip_dispersion_70",
    "cost_skewness",
    "dispersion_change_20",
    "price_to_avg_cost",
    "price_to_median_cost",
    "mean_median_dev",
    "trap_pressure",
    "bottom_profit",
    "history_position",
    "winner_rate_surge_5",
    "winner_rate_cs_rank",
    "winner_rate_dev_20",
    "winner_rate_volatility",
    "smart_money_accumulation",
    "winner_vol_corr_20",
    "cost_base_momentum",
    "bottom_cost_stability",
    "pivot_reversion",
    "chip_transition",
]
|
||||
|
||||
# TabPFN model hyper-parameters
MODEL_PARAMS = {
    # ==================== Device ====================
    "device": "cuda",  # compute device: "cuda" or "cpu" (default cuda)
    # ==================== Context limit ====================
    # NOTE(review): 100 is extremely small; the original guidance suggests
    # 1000-3000 on a 16GB GPU and 5000-8000 on 32GB — confirm this value
    # is intentional (e.g. a smoke-test setting) and not a leftover.
    "max_context_size": 100,
}
|
||||
|
||||
# Date-range configuration: train / validation / test split boundaries
# (constants are imported from src.experiment.common).
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
|
||||
|
||||
# Output configuration consumed by the Trainer
output_config = {
    "output_dir": OUTPUT_DIR,  # directory for all artifacts
    "output_filename": "tabpfn_output.csv",  # prediction CSV name
    "save_predictions": SAVE_PREDICTIONS,  # whether to write predictions
    "save_model": SAVE_MODEL,  # whether to persist model metadata
    "model_save_path": get_model_save_path(TRAINING_TYPE),  # derived save path
    "top_n": TOP_N,  # number of top-ranked stocks to report
}
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 3. 自定义 TabPFN 任务
|
||||
# %%
|
||||
from src.training.tasks import RegressionTask
|
||||
|
||||
|
||||
class TabPFNTask(RegressionTask):
    """Regression task backed by a TabPFN model.

    Inherits the RegressionTask interface but swaps the model for
    TabPFNModel: TabPFN predicts via in-context learning, so "fitting"
    means loading training data into the model context rather than
    running gradient descent.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Set up the task with a TabPFN model.

        Args:
            model_params: parameter dict forwarded to TabPFNModel.
            label_name: name of the label column.
        """
        # Deliberately skip RegressionTask.__init__: it would construct a
        # LightGBMModel we do not want. Initialize only the BaseTask layer.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """Load the training split into the TabPFN context.

        No gradient-based optimization happens here; the validation split
        is forwarded for evaluation only.

        Args:
            train_data: {"X": DataFrame, "y": Series} training split.
            val_data: validation split, evaluated but never trained on.
        """
        features = train_data["X"]
        labels = train_data["y"]
        val_features = val_data.get("X")
        val_labels = val_data.get("y")

        # TabPFN evaluates the validation pair through eval_set.
        eval_pair = (val_features, val_labels) if val_features is not None else None
        self.model.fit(features, labels, eval_set=eval_pair)

    def get_model(self) -> TabPFNModel:
        """Return the underlying TabPFN model instance."""
        return self.model
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 4. 主函数
|
||||
# %%
|
||||
def main():
    """Entry point: run the TabPFN regression workflow end to end.

    Builds the factor engine/manager and data pipeline, "fits" TabPFN by
    loading training data into its context, optionally saves the model
    configuration, and prints validation metrics.

    Returns:
        The results object produced by ``Trainer.run``.
    """
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning),")
    print("    训练过程实际是加载数据到模型上下文。")
    print("    如果训练数据超过上下文限制,会自动截取最近的数据。")

    # 1. Create the FactorEngine (data access / factor computation backend).
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager resolving selected/excluded factor sets.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline (NullFiller -> Winsorizer -> StandardScaler).
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label (trim extreme returns).
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. Create the TabPFNTask.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training (for TabPFN this loads the context).
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Persist model metadata and factor info (when enabled).
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # 8. Report TabPFN-specific metrics.
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabpfn_output.csv')}")

    # Show validation metrics when an eval_set was evaluated.
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f"  - {metric}: {value:.6f}")

    print("=" * 80)

    return results


if __name__ == "__main__":
    main()
|
||||
15
src/scripts/check_gpu.py
Normal file
15
src/scripts/check_gpu.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch


# Print the PyTorch version (the key diagnostic!)
print(f"PyTorch 版本: {torch.__version__}")
# A CPU-only build reports e.g.: 2.6.0+cpu
# A GPU build reports e.g.: 2.6.0+cu118 / 2.6.0+cu121 / 2.6.0+cu124

# Check whether CUDA is available in this environment.
print(f"CUDA 可用: {torch.cuda.is_available()}")

# If CUDA is present, print version and device details.
if torch.cuda.is_available():
    print(f"CUDA 版本: {torch.version.cuda}")
    print(f"GPU 数量: {torch.cuda.device_count()}")
    print(f"GPU 名称: {torch.cuda.get_device_name(0)}")
|
||||
@@ -29,7 +29,7 @@ from src.training.components.processors import (
|
||||
)
|
||||
|
||||
# 模型
|
||||
from src.training.components.models import LightGBMModel
|
||||
from src.training.components.models import LightGBMModel, TabMModel
|
||||
|
||||
# 数据过滤器
|
||||
from src.training.components.filters import BaseFilter, STFilter
|
||||
@@ -50,7 +50,7 @@ from src.training.config import TrainingConfig
|
||||
from src.training.factor_manager import FactorManager
|
||||
from src.training.pipeline import DataPipeline
|
||||
from src.training.result_analyzer import ResultAnalyzer
|
||||
from src.training.tasks import BaseTask, RegressionTask, RankTask
|
||||
from src.training.tasks import BaseTask, RegressionTask, RankTask, TabMRegressionTask
|
||||
|
||||
# 从 trainer_v2 导入新 Trainer(推荐)
|
||||
from src.training.core.trainer_v2 import Trainer as TrainerV2
|
||||
@@ -79,6 +79,7 @@ __all__ = [
|
||||
"STFilter",
|
||||
# 模型
|
||||
"LightGBMModel",
|
||||
"TabMModel",
|
||||
# 训练核心(旧版,已废弃)
|
||||
"StockPoolManager",
|
||||
"Trainer",
|
||||
@@ -93,5 +94,6 @@ __all__ = [
|
||||
"BaseTask",
|
||||
"RegressionTask",
|
||||
"RankTask",
|
||||
"TabMRegressionTask",
|
||||
"TrainerV2", # 新的 Trainer(推荐)
|
||||
]
|
||||
|
||||
@@ -5,5 +5,7 @@
|
||||
|
||||
from src.training.components.models.lightgbm import LightGBMModel
|
||||
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
|
||||
from src.training.components.models.tabpfn_model import TabPFNModel
|
||||
from src.training.components.models.tabm_model import TabMModel
|
||||
|
||||
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"]
|
||||
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel"]
|
||||
|
||||
368
src/training/components/models/tabm_model.py
Normal file
368
src/training/components/models/tabm_model.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""TabM模型实现
|
||||
|
||||
TabM (Tabular Multilayer Perceptron with Ensembles)
|
||||
基于 rtdl_revisiting_models 的 TabM 模型,支持内置集成。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tabm import TabM
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
|
||||
|
||||
@register_model("tabm")
|
||||
class TabMModel(BaseModel):
|
||||
"""TabM回归模型
|
||||
|
||||
特点:
|
||||
- 使用MLP架构
|
||||
- 内置集成机制(ensemble_size),显存开销远小于独立模型
|
||||
- 训练时所有集成成员独立优化,保持多样性
|
||||
- 预测时取集成成员均值获得稳定结果
|
||||
|
||||
Attributes:
|
||||
name: 模型名称标识
|
||||
params: 模型参数字典
|
||||
model: TabM模型实例
|
||||
device: 计算设备(cuda/cpu)
|
||||
training_history_: 训练历史记录
|
||||
feature_names_: 特征名称列表
|
||||
"""
|
||||
|
||||
name = "tabm"
|
||||
|
||||
def __init__(self, params: Dict[str, Any]):
|
||||
"""初始化TabM模型
|
||||
|
||||
Args:
|
||||
params: 模型参数字典,包含:
|
||||
- n_blocks: MLP层数 (默认: 3)
|
||||
- d_block: 每层神经元数 (默认: 256)
|
||||
- dropout: Dropout率 (默认: 0.1)
|
||||
- ensemble_size: 集成大小 (默认: 32)
|
||||
- batch_size: 批次大小 (默认: 1024)
|
||||
- learning_rate: 学习率 (默认: 1e-3)
|
||||
- weight_decay: 权重衰减 (默认: 1e-5)
|
||||
- epochs: 训练轮数 (默认: 50)
|
||||
- early_stopping_patience: 早停耐心值 (默认: 10)
|
||||
"""
|
||||
self.params = params
|
||||
self.model = None
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.training_history_: Dict[str, List[float]] = {
|
||||
"train_loss": [],
|
||||
"val_loss": [],
|
||||
}
|
||||
self.feature_names_: Optional[List[str]] = None
|
||||
|
||||
# 损失函数
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
def _make_loader(
|
||||
self, X: np.ndarray, y: Optional[np.ndarray] = None, shuffle: bool = False
|
||||
) -> DataLoader:
|
||||
"""创建DataLoader
|
||||
|
||||
Args:
|
||||
X: 特征数组 [N, n_features]
|
||||
y: 标签数组 [N] 或 None
|
||||
shuffle: 是否打乱数据
|
||||
|
||||
Returns:
|
||||
DataLoader实例
|
||||
"""
|
||||
if y is not None:
|
||||
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
|
||||
else:
|
||||
dataset = TensorDataset(torch.from_numpy(X))
|
||||
|
||||
batch_size = self.params.get("batch_size", 1024)
|
||||
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
||||
|
||||
def _validate(self, val_loader: DataLoader) -> float:
|
||||
"""验证模型
|
||||
|
||||
Args:
|
||||
val_loader: 验证数据加载器
|
||||
|
||||
Returns:
|
||||
平均验证损失
|
||||
"""
|
||||
self.model.eval()
|
||||
total_loss = 0.0
|
||||
n_batches = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
if len(batch) == 2:
|
||||
bx, by = batch
|
||||
bx, by = bx.to(self.device), by.to(self.device)
|
||||
else:
|
||||
bx = batch[0].to(self.device)
|
||||
by = None
|
||||
|
||||
# 预测时取集成成员均值
|
||||
outputs = self.model(bx) # [B, E, 1]
|
||||
preds = outputs.mean(dim=1).squeeze(-1) # [B]
|
||||
|
||||
if by is not None:
|
||||
loss = self.criterion(preds, by).item()
|
||||
total_loss += loss
|
||||
n_batches += 1
|
||||
|
||||
return total_loss / max(n_batches, 1)
|
||||
|
||||
def fit(
|
||||
self, X: pl.DataFrame, y: pl.Series, eval_set: Optional[tuple] = None
|
||||
) -> "TabMModel":
|
||||
"""训练TabM模型
|
||||
|
||||
训练策略:
|
||||
1. 对所有集成成员独立计算Loss,保持多样性
|
||||
2. 验证和预测时取ensemble成员均值
|
||||
|
||||
Args:
|
||||
X: 训练特征DataFrame
|
||||
y: 训练标签Series
|
||||
eval_set: 验证集元组 (X_val, y_val),可选
|
||||
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
# 保存特征名称
|
||||
self.feature_names_ = X.columns
|
||||
|
||||
# 【关键】数据类型强制转换为float32
|
||||
# PyTorch对float64支持较差,避免使用Polars/Numpy默认类型
|
||||
X_np = X.to_numpy().astype(np.float32)
|
||||
y_np = y.to_numpy().astype(np.float32)
|
||||
|
||||
# 创建DataLoader
|
||||
train_loader = self._make_loader(X_np, y_np, shuffle=True)
|
||||
val_loader = None
|
||||
if eval_set is not None:
|
||||
X_val, y_val = eval_set
|
||||
X_val_np = X_val.to_numpy().astype(np.float32)
|
||||
y_val_np = y_val.to_numpy().astype(np.float32)
|
||||
val_loader = self._make_loader(X_val_np, y_val_np, shuffle=False)
|
||||
|
||||
n_features = X_np.shape[1]
|
||||
ensemble_size = self.params.get("ensemble_size", 32)
|
||||
|
||||
# 初始化TabM模型,使用TabM.make()自动填充默认参数
|
||||
self.model = TabM.make(
|
||||
n_num_features=n_features,
|
||||
cat_cardinalities=[],
|
||||
d_out=1, # 回归任务输出维度为1
|
||||
n_blocks=self.params.get("n_blocks", 3),
|
||||
d_block=self.params.get("d_block", 256),
|
||||
dropout=self.params.get("dropout", 0.1),
|
||||
k=ensemble_size, # 集成大小
|
||||
).to(self.device)
|
||||
|
||||
# 优化器
|
||||
optimizer = optim.AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.params.get("learning_rate", 1e-3),
|
||||
weight_decay=self.params.get("weight_decay", 1e-5),
|
||||
)
|
||||
|
||||
# 训练参数
|
||||
epochs = self.params.get("epochs", 50)
|
||||
early_stopping_patience = self.params.get("early_stopping_patience", 10)
|
||||
|
||||
# 训练循环
|
||||
best_val_loss = float("inf")
|
||||
patience_counter = 0
|
||||
|
||||
print(f"[TabM] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}")
|
||||
|
||||
for epoch in range(epochs):
|
||||
# 训练阶段
|
||||
self.model.train()
|
||||
train_loss = 0.0
|
||||
n_train_batches = 0
|
||||
|
||||
for bx, by in train_loader:
|
||||
bx, by = bx.to(self.device), by.to(self.device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 前向传播
|
||||
# outputs形状: [Batch, Ensemble, 1]
|
||||
outputs = self.model(bx)
|
||||
outputs_squeezed = outputs.squeeze(-1) # [B, E]
|
||||
|
||||
# 【关键】针对所有集成成员计算Loss
|
||||
# 不先取均值,让每个集成成员独立收敛,保持集成多样性
|
||||
by_expanded = by.unsqueeze(-1).expand(-1, ensemble_size) # [B, E]
|
||||
loss = self.criterion(outputs_squeezed, by_expanded)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss += loss.item()
|
||||
n_train_batches += 1
|
||||
|
||||
avg_train_loss = train_loss / max(n_train_batches, 1)
|
||||
self.training_history_["train_loss"].append(avg_train_loss)
|
||||
|
||||
# 验证阶段
|
||||
if val_loader is not None:
|
||||
val_loss = self._validate(val_loader)
|
||||
self.training_history_["val_loss"].append(val_loss)
|
||||
|
||||
# 早停逻辑
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_counter = 0
|
||||
else:
|
||||
patience_counter += 1
|
||||
|
||||
if (epoch + 1) % 5 == 0 or epoch == 0:
|
||||
print(
|
||||
f"[TabM] Epoch {epoch + 1}/{epochs} | "
|
||||
f"Train Loss: {avg_train_loss:.6f} | "
|
||||
f"Val Loss: {val_loss:.6f}"
|
||||
)
|
||||
|
||||
if patience_counter >= early_stopping_patience:
|
||||
print(f"[TabM] 早停触发,epoch {epoch + 1}")
|
||||
break
|
||||
else:
|
||||
if (epoch + 1) % 5 == 0 or epoch == 0:
|
||||
print(
|
||||
f"[TabM] Epoch {epoch + 1}/{epochs} | "
|
||||
f"Train Loss: {avg_train_loss:.6f}"
|
||||
)
|
||||
|
||||
print(f"[TabM] 训练完成")
|
||||
return self
|
||||
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""生成预测
|
||||
|
||||
预测时对ensemble_size个成员取均值,获得稳定结果。
|
||||
|
||||
Args:
|
||||
X: 特征DataFrame
|
||||
|
||||
Returns:
|
||||
预测结果数组 [N]
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型未训练,请先调用fit()")
|
||||
|
||||
# 数据类型转换
|
||||
X_np = X.to_numpy().astype(np.float32)
|
||||
loader = self._make_loader(X_np, shuffle=False)
|
||||
|
||||
self.model.eval()
|
||||
all_preds = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
bx = batch[0].to(self.device)
|
||||
# 预测时取集成成员均值
|
||||
outputs = self.model(bx) # [B, E, 1]
|
||||
preds = outputs.mean(dim=1).squeeze(-1) # [B]
|
||||
all_preds.append(preds.cpu().numpy())
|
||||
|
||||
return np.concatenate(all_preds)
|
||||
|
||||
def feature_importance(self) -> Optional[pl.Series]:
|
||||
"""获取特征重要性
|
||||
|
||||
TabM没有内置特征重要性计算,返回None。
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return None
|
||||
|
||||
def save(self, path: str | Path) -> None:
|
||||
"""保存模型
|
||||
|
||||
保存模型state_dict和元数据(params, feature_names, training_history)
|
||||
|
||||
Args:
|
||||
path: 保存路径
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型未训练,无法保存")
|
||||
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存模型权重
|
||||
model_path = path.with_suffix(".pt")
|
||||
torch.save(self.model.state_dict(), model_path)
|
||||
|
||||
# 保存元数据
|
||||
meta_path = path.with_suffix(".meta")
|
||||
meta = {
|
||||
"params": self.params,
|
||||
"feature_names": self.feature_names_,
|
||||
"training_history": self.training_history_,
|
||||
"device": str(self.device),
|
||||
}
|
||||
with open(meta_path, "wb") as f:
|
||||
pickle.dump(meta, f)
|
||||
|
||||
print(f"[TabM] 模型保存到: {path}")
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "TabMModel":
|
||||
"""加载模型
|
||||
|
||||
Args:
|
||||
path: 模型路径(不含扩展名)
|
||||
|
||||
Returns:
|
||||
加载的TabMModel实例
|
||||
"""
|
||||
path = Path(path)
|
||||
|
||||
# 加载元数据
|
||||
meta_path = path.with_suffix(".meta")
|
||||
with open(meta_path, "rb") as f:
|
||||
meta = pickle.load(f)
|
||||
|
||||
# 创建实例
|
||||
instance = cls(meta["params"])
|
||||
instance.feature_names_ = meta["feature_names"]
|
||||
instance.training_history_ = meta["training_history"]
|
||||
|
||||
# 重建模型结构
|
||||
if instance.feature_names_ is not None:
|
||||
n_features = len(instance.feature_names_)
|
||||
ensemble_size = instance.params.get("ensemble_size", 32)
|
||||
|
||||
instance.model = TabM.make(
|
||||
n_num_features=n_features,
|
||||
cat_cardinalities=[],
|
||||
d_out=1, # 回归任务输出维度为1
|
||||
n_blocks=instance.params.get("n_blocks", 3),
|
||||
d_block=instance.params.get("d_block", 256),
|
||||
dropout=instance.params.get("dropout", 0.1),
|
||||
k=ensemble_size, # 集成大小
|
||||
).to(instance.device)
|
||||
|
||||
# 加载权重
|
||||
model_path = path.with_suffix(".pt")
|
||||
instance.model.load_state_dict(
|
||||
torch.load(model_path, map_location=instance.device)
|
||||
)
|
||||
|
||||
print(f"[TabM] 模型从 {path} 加载完成")
|
||||
return instance
|
||||
296
src/training/components/models/tabpfn_model.py
Normal file
296
src/training/components/models/tabpfn_model.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""TabPFN 模型实现
|
||||
|
||||
基于 TabPFN (Prior-Data Fitted Network) 的回归模型实现。
|
||||
TabPFN 利用预训练的 Transformer 网络,通过上下文学习(in-context learning)
|
||||
进行快速的小样本/中样本回归预测,无需传统训练过程。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
from scipy.stats import spearmanr
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
|
||||
os.environ["HF_TOKEN"] = "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"
|
||||
|
||||
|
||||
@register_model("tabpfn")
|
||||
class TabPFNModel(BaseModel):
|
||||
"""TabPFN 回归模型
|
||||
|
||||
使用 TabPFN 库实现基于 Prior-Data Fitted Network 的回归预测。
|
||||
该模型通过上下文学习方式进行预测,无需传统梯度下降训练。
|
||||
支持 GPU 加速和自动上下文截断处理。
|
||||
|
||||
Attributes:
|
||||
name: 模型名称 "tabpfn"
|
||||
params: TabPFN 参数字典
|
||||
model: TabPFNRegressor 实例
|
||||
feature_names_: 特征名称列表
|
||||
evals_result_: 训练评估结果
|
||||
best_score_: 最佳评估指标
|
||||
"""
|
||||
|
||||
name = "tabpfn"
|
||||
|
||||
# TabPFN 官方限制(最大样本数),可通过 ignore_pretraining_limits=True 扩展
|
||||
MAX_CONTEXT_SIZE = 10000
|
||||
|
||||
def __init__(self, params: Optional[dict] = None):
|
||||
"""初始化 TabPFN 模型
|
||||
|
||||
Args:
|
||||
params: TabPFN 参数字典,支持以下参数:
|
||||
- device: 计算设备,'cuda' 或 'cpu'(默认 'cpu')
|
||||
- model_path: 本地模型权重文件路径(可选)
|
||||
- N_ensemble: 集成数量,用于降低预测方差(默认 1)
|
||||
- max_context_size: 最大上下文样本数(默认 50000)
|
||||
|
||||
Examples:
|
||||
>>> model = TabPFNModel(params={
|
||||
... "device": "cuda",
|
||||
... "N_ensemble": 5,
|
||||
... })
|
||||
"""
|
||||
self.params = dict(params) if params is not None else {}
|
||||
self.model = None
|
||||
self.feature_names_: Optional[list] = None
|
||||
self.evals_result_: Optional[dict] = None
|
||||
self.best_score_: Optional[dict] = None
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: pl.DataFrame,
|
||||
y: pl.Series,
|
||||
eval_set: Optional[tuple] = None,
|
||||
) -> "TabPFNModel":
|
||||
"""训练/加载 TabPFN 模型
|
||||
|
||||
TabPFN 采用上下文学习,"fit" 操作实际上是加载训练数据到模型上下文。
|
||||
如果训练数据超过上下文限制,会自动截取最近的数据。
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
y: 目标变量 (Polars Series)
|
||||
eval_set: 验证集元组 (X_val, y_val),用于评估模型性能
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 tabpfn
|
||||
RuntimeError: 模型初始化或加载失败
|
||||
"""
|
||||
from tabpfn import TabPFNRegressor
|
||||
|
||||
self.feature_names_ = X.columns
|
||||
|
||||
# 转换为 numpy 数组
|
||||
X_np = X.to_numpy()
|
||||
y_np = y.to_numpy()
|
||||
|
||||
# 处理上下文大小限制
|
||||
max_context = self.params.get("max_context_size", self.MAX_CONTEXT_SIZE)
|
||||
if len(X_np) > max_context:
|
||||
print(
|
||||
f"[TabPFN] 训练数据 {len(X_np)} 超过上下文限制 {max_context},截取最近数据"
|
||||
)
|
||||
X_np = X_np[-max_context:]
|
||||
y_np = y_np[-max_context:]
|
||||
|
||||
# 初始化模型
|
||||
# TabPFNRegressor 需要设置 ignore_pretraining_limits=True 以支持超过 10,000 样本
|
||||
device = self.params.get("device", "cuda")
|
||||
ignore_limits = self.params.get("ignore_pretraining_limits", True)
|
||||
self.model = TabPFNRegressor(
|
||||
device=device,
|
||||
ignore_pretraining_limits=ignore_limits,
|
||||
n_estimators=1
|
||||
)
|
||||
|
||||
# 加载上下文(TabPFN 的 "fit" 是加载上下文)
|
||||
print("[TabPFN] 加载训练数据到上下文...")
|
||||
self.model.fit(X_np, y_np)
|
||||
|
||||
# 评估验证集
|
||||
if eval_set is not None:
|
||||
X_val, y_val = eval_set
|
||||
val_preds = self.predict(X_val)
|
||||
y_val_np = y_val.to_numpy()
|
||||
|
||||
# 计算评估指标
|
||||
mse = np.mean((y_val_np - val_preds) ** 2)
|
||||
rank_ic, p_value = spearmanr(val_preds, y_val_np)
|
||||
|
||||
self.evals_result_ = {
|
||||
"valid_0": {
|
||||
"mse": [mse],
|
||||
"rank_ic": [rank_ic],
|
||||
}
|
||||
}
|
||||
self.best_score_ = {
|
||||
"valid_0": {
|
||||
"mse": mse,
|
||||
"rank_ic": rank_ic,
|
||||
"rank_ic_pvalue": p_value,
|
||||
}
|
||||
}
|
||||
|
||||
print(f"[TabPFN] 验证集 MSE: {mse:.6f}, Rank IC: {rank_ic:.4f}")
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
预测结果 (numpy ndarray)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 模型未初始化时调用
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型尚未初始化,请先调用 fit()")
|
||||
|
||||
X_np = X.to_numpy()
|
||||
result = self.model.predict(X_np)
|
||||
return np.asarray(result)
|
||||
|
||||
def predict_with_uncertainty(
|
||||
self, X: pl.DataFrame
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""预测并返回不确定性估计
|
||||
|
||||
利用 N_ensemble 预测的标准差作为不确定性估计。
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
(predictions, uncertainties) 元组,均为 numpy ndarray
|
||||
"""
|
||||
if self.model is None:
|
||||
raise RuntimeError("模型尚未初始化,请先调用 fit()")
|
||||
|
||||
X_np = X.to_numpy()
|
||||
predictions = self.model.predict(X_np)
|
||||
|
||||
# 如果使用了 ensemble,可以通过多次预测计算标准差
|
||||
# 注意:这需要修改 TabPFNRegressor 的使用方式
|
||||
# 这里返回预测值的零不确定性作为默认行为
|
||||
uncertainties = np.zeros_like(predictions)
|
||||
|
||||
return np.asarray(predictions), np.asarray(uncertainties)
|
||||
|
||||
    def get_evals_result(self) -> Optional[dict]:
        """Return the recorded evaluation history.

        Returns:
            The evaluation-result dict, or None when fit() ran without
            an eval_set.
        """
        return self.evals_result_
|
||||
|
||||
def get_best_score(self) -> Optional[dict]:
    """Return the best validation score recorded during fit().

    Returns:
        The best-score dict, or None when no evaluation was run.
    """
    return self.best_score_
|
||||
|
||||
def evaluate(self, X: pl.DataFrame, y: pl.Series) -> dict[str, float]:
    """Evaluate model performance on a labelled dataset.

    Computes the usual regression metrics: MSE and rank IC
    (Spearman rank correlation between predictions and targets).

    Args:
        X: Feature matrix.
        y: Ground-truth target values.

    Returns:
        Metric dict with keys ``mse``, ``rank_ic`` and ``rank_ic_pvalue``.
    """
    predictions = self.predict(X)
    targets = y.to_numpy()

    squared_errors = (targets - predictions) ** 2
    corr = spearmanr(predictions, targets)

    return {
        "mse": float(np.mean(squared_errors)),
        "rank_ic": float(corr.correlation),
        "rank_ic_pvalue": float(corr.pvalue),
    }
|
||||
|
||||
def feature_importance(self) -> None:
    """TabPFN exposes no conventional feature importances.

    TabPFN is a Transformer-based in-context learner and provides no
    tree-style feature-importance metric.

    Returns:
        None
    """
    return None
|
||||
|
||||
def save(self, path: str) -> None:
    """Persist model metadata and configuration as JSON.

    The TabPFN model itself cannot be serialised, so only the following
    are written:
    - model parameter configuration
    - feature name list
    - evaluation results / best score

    Note: a saved model must be re-fit to reload its context before use.

    Args:
        path: Destination file path.
    """
    payload = {
        "model_type": self.name,
        "params": self.params,
        "feature_names": self.feature_names_,
        "evals_result": self.evals_result_,
        "best_score": self.best_score_,
    }

    # Write as UTF-8 JSON, creating parent directories as needed.
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8"
    )
|
||||
|
||||
@classmethod
def load(cls, path: str) -> "TabPFNModel":
    """Restore model configuration from a JSON file written by save().

    Note: TabPFN must be re-fit before use — this only restores the
    parameter configuration and cached metadata, not the in-context
    training data.

    Args:
        path: Path to the configuration JSON file.

    Returns:
        A TabPFNModel instance with restored configuration (not fitted).
    """
    with open(path, "r", encoding="utf-8") as f:
        save_data = json.load(f)

    instance = cls(params=save_data.get("params", {}))
    instance.feature_names_ = save_data.get("feature_names")
    instance.evals_result_ = save_data.get("evals_result")
    instance.best_score_ = save_data.get("best_score")

    # Fixed F541: the message has no placeholders, so a plain string
    # literal replaces the needless f-string (identical runtime output).
    print("[TabPFN] 已加载模型配置,需要调用 fit() 重新加载上下文")
    return instance
|
||||
@@ -3,6 +3,7 @@
|
||||
包含标准化、缩尾、缺失值填充等数据处理器。
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import polars as pl
|
||||
@@ -88,7 +89,9 @@ class NullFiller(BaseProcessor):
|
||||
"""
|
||||
if not self.by_date and self.strategy in ("mean", "median"):
|
||||
for col in self.feature_cols:
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
if col in X.columns and (
|
||||
X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
|
||||
):
|
||||
if self.strategy == "mean":
|
||||
self.stats_[col] = X[col].mean() or 0.0
|
||||
else: # median
|
||||
@@ -119,11 +122,14 @@ class NullFiller(BaseProcessor):
|
||||
raise ValueError(f"未知的填充策略: {self.strategy}")
|
||||
|
||||
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill missing values with 0, handling both NaN and null.

    Args:
        X: Input frame.

    Returns:
        Frame with missing values in numeric/boolean feature columns
        replaced by 0; other columns pass through untouched.
    """
    exprs = []
    for col in X.columns:
        is_target = col in self.feature_cols and (
            X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
        )
        if is_target:
            # fill_nan first, then fill_null, so both kinds of
            # missingness are covered in a single pass.
            exprs.append(pl.col(col).fill_nan(0).fill_null(0).alias(col))
        else:
            exprs.append(pl.col(col))

    return X.select(exprs)
|
||||
|
||||
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill missing values with the configured constant (NaN and null).

    Args:
        X: Input frame.

    Returns:
        Frame with missing values in numeric/boolean feature columns
        replaced by ``self.fill_value``; other columns are untouched.
    """
    exprs = []
    for col in X.columns:
        is_target = col in self.feature_cols and (
            X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
        )
        if is_target:
            # fill_nan first, then fill_null.
            filled = (
                pl.col(col)
                .fill_nan(self.fill_value)
                .fill_null(self.fill_value)
                .alias(col)
            )
            exprs.append(filled)
        else:
            exprs.append(pl.col(col))

    return X.select(exprs)
|
||||
|
||||
def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill with global statistics learned from the training set.

    Handles both NaN and null in one pass; columns without a learned
    statistic pass through unchanged.

    Args:
        X: Input frame.

    Returns:
        Frame with missing values replaced by the per-column statistic.
    """
    exprs = []
    for col in X.columns:
        if col in self.stats_:
            stat = self.stats_[col]
            # fill_nan first, then fill_null.
            exprs.append(
                pl.col(col).fill_nan(stat).fill_null(stat).alias(col)
            )
        else:
            exprs.append(pl.col(col))

    return X.select(exprs)
|
||||
|
||||
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""使用每天截面统计量填充"""
|
||||
# 确定需要处理的数值列
|
||||
"""使用每天截面统计量填充(同时处理 NaN 和 null)"""
|
||||
# 确定需要处理的列(仅 numeric 类型,排除 boolean)
|
||||
# 注意:boolean 类型没有 NaN 概念,fill_nan 会报错
|
||||
target_cols = [
|
||||
col
|
||||
for col in self.feature_cols
|
||||
@@ -180,10 +196,20 @@ class NullFiller(BaseProcessor):
|
||||
result = X.with_columns(stat_exprs)
|
||||
|
||||
# 使用统计量填充缺失值
|
||||
# 注意:如果某天某列全为null,统计量也会为null,所以需要链式填充
|
||||
# 同时处理 NaN 和 null
|
||||
fill_exprs = []
|
||||
for col in X.columns:
|
||||
if col in target_cols:
|
||||
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
|
||||
# 先用当天统计量填充 NaN 和 null,如果统计量也是null则用0填充
|
||||
expr = (
|
||||
pl.col(col)
|
||||
.fill_nan(pl.col(f"{col}_stat"))
|
||||
.fill_null(pl.col(f"{col}_stat"))
|
||||
.fill_nan(0) # 如果统计量是 NaN,再用 0 填充
|
||||
.fill_null(0) # 如果统计量是 null,再用 0 填充
|
||||
.alias(col)
|
||||
)
|
||||
fill_exprs.append(expr)
|
||||
else:
|
||||
fill_exprs.append(pl.col(col))
|
||||
@@ -230,17 +256,40 @@ class StandardScaler(BaseProcessor):
|
||||
self
|
||||
"""
|
||||
for col in self.feature_cols:
|
||||
# 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确)
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
col_mean = X[col].mean()
|
||||
col_std = X[col].std()
|
||||
if col_mean is not None and col_std is not None:
|
||||
# 关键修复:检查是否为 None 且不是 NaN
|
||||
# 注意:使用 try-except 处理类型转换,避免 LSP 类型检查错误
|
||||
try:
|
||||
mean_is_valid = (
|
||||
col_mean is not None
|
||||
and isinstance(col_mean, (int, float))
|
||||
and not math.isnan(col_mean)
|
||||
)
|
||||
std_is_valid = (
|
||||
col_std is not None
|
||||
and isinstance(col_std, (int, float))
|
||||
and not math.isnan(col_std)
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
mean_is_valid = False
|
||||
std_is_valid = False
|
||||
|
||||
if mean_is_valid and std_is_valid:
|
||||
self.mean_[col] = col_mean
|
||||
self.std_[col] = col_std
|
||||
else:
|
||||
# 如果统计量无效,使用默认值(mean=0, std=1)
|
||||
# 防止 transform 时产生更多 NaN
|
||||
self.mean_[col] = 0.0
|
||||
self.std_[col] = 1.0
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""标准化(使用训练集学到的参数)
|
||||
"""标准化(使用训练集学到的参数,增加 NaN 保护)
|
||||
|
||||
Args:
|
||||
X: 待转换数据
|
||||
@@ -253,7 +302,18 @@ class StandardScaler(BaseProcessor):
|
||||
if col in self.mean_ and col in self.std_:
|
||||
# 避免除以0
|
||||
std_val = self.std_[col] if self.std_[col] != 0 else 1.0
|
||||
expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
|
||||
# 关键修复:添加 fill_nan(0) 保险,防止计算产生 NaN
|
||||
expr = (
|
||||
((pl.col(col) - self.mean_[col]) / std_val)
|
||||
.fill_nan(0)
|
||||
.fill_null(0)
|
||||
.alias(col)
|
||||
)
|
||||
expressions.append(expr)
|
||||
elif col in self.feature_cols:
|
||||
# 对于应该被处理但未学习到统计量的列
|
||||
# 统一转换为float并同时处理 NaN 和 null
|
||||
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
|
||||
expressions.append(expr)
|
||||
else:
|
||||
expressions.append(pl.col(col))
|
||||
@@ -308,13 +368,24 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
||||
# 构建表达式列表
|
||||
expressions = []
|
||||
for col in X.columns:
|
||||
# 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确)
|
||||
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||
# 截面标准化:每天独立计算均值和标准差
|
||||
# 避免除以0,当std为0时设为1
|
||||
# 关键修复:先 fill_nan 再 fill_null,防止计算产生的 NaN
|
||||
expr = (
|
||||
(pl.col(col) - pl.col(col).mean().over(self.date_col))
|
||||
/ (pl.col(col).std().over(self.date_col) + 1e-10)
|
||||
).alias(col)
|
||||
(
|
||||
(pl.col(col) - pl.col(col).mean().over(self.date_col))
|
||||
/ (pl.col(col).std().over(self.date_col) + 1e-10)
|
||||
)
|
||||
.fill_nan(0)
|
||||
.fill_null(0)
|
||||
.alias(col)
|
||||
)
|
||||
expressions.append(expr)
|
||||
elif col in self.feature_cols:
|
||||
# 对于应该被处理但类型不匹配的列,转换为float并同时处理 NaN 和 null
|
||||
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
|
||||
expressions.append(expr)
|
||||
else:
|
||||
expressions.append(pl.col(col))
|
||||
@@ -384,6 +455,7 @@ class Winsorizer(BaseProcessor):
|
||||
"""
|
||||
if not self.by_date:
|
||||
for col in self.feature_cols:
|
||||
# 仅处理数值类型,排除布尔类型(quantile 不支持布尔类型)
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
self.bounds_[col] = {
|
||||
"lower": X[col].quantile(self.lower),
|
||||
@@ -414,13 +486,19 @@ class Winsorizer(BaseProcessor):
|
||||
upper = self.bounds_[col]["upper"]
|
||||
expr = pl.col(col).clip(lower, upper).alias(col)
|
||||
expressions.append(expr)
|
||||
elif col in self.feature_cols:
|
||||
# 对于应该被处理但未学习到边界的列(如全为NaN、布尔列等)
|
||||
# 统一转换为float并填充0
|
||||
expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
|
||||
expressions.append(expr)
|
||||
else:
|
||||
expressions.append(pl.col(col))
|
||||
return X.select(expressions)
|
||||
|
||||
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""每日独立缩尾"""
|
||||
# 确定需要处理的数值列
|
||||
# 确定需要处理的列(仅 numeric 类型,排除 boolean)
|
||||
# 注意:quantile 操作不支持布尔类型
|
||||
target_cols = [
|
||||
col
|
||||
for col in self.feature_cols
|
||||
@@ -444,9 +522,11 @@ class Winsorizer(BaseProcessor):
|
||||
clip_exprs = []
|
||||
for col in X.columns:
|
||||
if col in target_cols:
|
||||
# 先用当天分位数缩尾,如果分位数是null(该日全为NaN)则填充0
|
||||
clipped = (
|
||||
pl.col(col)
|
||||
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
||||
.fill_null(0)
|
||||
.alias(col)
|
||||
)
|
||||
clip_exprs.append(clipped)
|
||||
|
||||
@@ -95,6 +95,27 @@ class Trainer:
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
# Step 1.5: 数据质量分析
|
||||
if self.verbose:
|
||||
print("\n[Step 1.5/7] 数据质量分析...")
|
||||
|
||||
try:
|
||||
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
|
||||
|
||||
# 获取特征列名(从训练集)
|
||||
feature_cols = data["train"].get("feature_cols", [])
|
||||
label_name = self.task.label_name
|
||||
|
||||
analyzer = DataQualityAnalyzer(
|
||||
feature_cols=feature_cols,
|
||||
label_col=label_name,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
analyzer.analyze(data)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f" [警告] 数据质量分析失败: {e}")
|
||||
|
||||
# Step 2: 处理标签
|
||||
if self.verbose:
|
||||
print("\n[Step 2/7] 处理标签...")
|
||||
|
||||
@@ -43,6 +43,7 @@ class DataPipeline:
|
||||
label_processor_configs: Optional[
|
||||
List[Tuple[Type[BaseProcessor], Dict[str, Any]]]
|
||||
] = None,
|
||||
train_skip_days: int = 252,
|
||||
):
|
||||
"""初始化数据流水线
|
||||
|
||||
@@ -55,6 +56,7 @@ class DataPipeline:
|
||||
stock_pool_required_columns: 股票池筛选所需的额外列
|
||||
label_processor_configs: Label 数据处理器配置列表,格式与 processor_configs 相同
|
||||
例如:[(Winsorizer, {"lower": 0.01, "upper": 0.99})] 用于对 label 进行缩尾处理
|
||||
train_skip_days: 训练数据跳过前n天,用于避免训练初期数据不足的问题,默认252天
|
||||
"""
|
||||
self.factor_manager = factor_manager
|
||||
self.processor_configs = processor_configs or []
|
||||
@@ -64,6 +66,7 @@ class DataPipeline:
|
||||
self.fitted_processors: List[BaseProcessor] = []
|
||||
self.label_processor_configs = label_processor_configs or []
|
||||
self.fitted_label_processors: List[BaseProcessor] = []
|
||||
self.train_skip_days = train_skip_days
|
||||
|
||||
def prepare_data(
|
||||
self,
|
||||
@@ -220,6 +223,8 @@ class DataPipeline:
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""划分数据集
|
||||
|
||||
对于训练集,会根据 train_skip_days 参数跳过前n个交易日的数据。
|
||||
|
||||
Args:
|
||||
data: 完整数据
|
||||
date_range: 日期范围字典
|
||||
@@ -236,6 +241,33 @@ class DataPipeline:
|
||||
mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
|
||||
split_df = data.filter(mask)
|
||||
|
||||
# 对训练集跳过前n天数据
|
||||
if split_name == "train" and self.train_skip_days > 0:
|
||||
original_count = len(split_df)
|
||||
# 获取唯一的交易日列表并按日期排序
|
||||
unique_dates = split_df["trade_date"].unique().sort()
|
||||
if len(unique_dates) > self.train_skip_days:
|
||||
# 跳过前n个交易日
|
||||
start_date = unique_dates[self.train_skip_days]
|
||||
split_df = split_df.filter(pl.col("trade_date") >= start_date)
|
||||
skipped_count = original_count - len(split_df)
|
||||
if verbose:
|
||||
print(
|
||||
f" {split_name}: {len(split_df)} 条记录"
|
||||
f" (跳过前{self.train_skip_days}天,减少{skipped_count}条)"
|
||||
)
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
f" [警告] 训练数据交易日数量({len(unique_dates)})"
|
||||
f"少于跳过天数({self.train_skip_days}),未进行过滤"
|
||||
)
|
||||
if verbose:
|
||||
print(f" {split_name}: {len(split_df)} 条记录")
|
||||
else:
|
||||
if verbose:
|
||||
print(f" {split_name}: {len(split_df)} 条记录")
|
||||
|
||||
result[split_name] = {
|
||||
"X": split_df.select(feature_cols),
|
||||
"y": split_df[label_name],
|
||||
@@ -243,9 +275,6 @@ class DataPipeline:
|
||||
"feature_cols": feature_cols,
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(f" {split_name}: {len(split_df)} 条记录")
|
||||
|
||||
return result
|
||||
|
||||
def _preprocess(
|
||||
@@ -345,6 +374,30 @@ class DataPipeline:
|
||||
split_data[split_name]["X"] = split_df.select(feature_cols)
|
||||
split_data[split_name]["y"] = split_df[label_name]
|
||||
|
||||
# 删除标签为 NaN 的行
|
||||
for split_name in ["train", "val", "test"]:
|
||||
if split_name in split_data:
|
||||
y_series = split_data[split_name]["y"]
|
||||
y_nan_count = y_series.null_count()
|
||||
|
||||
if y_nan_count > 0:
|
||||
if verbose:
|
||||
print(f" 删除 {split_name} 集中 {y_nan_count} 个标签为NaN的行")
|
||||
|
||||
# 创建有效标签的mask
|
||||
valid_mask = y_series.is_not_null()
|
||||
|
||||
# 过滤所有相关数据
|
||||
split_data[split_name]["raw_data"] = split_data[split_name][
|
||||
"raw_data"
|
||||
].filter(valid_mask)
|
||||
split_data[split_name]["X"] = split_data[split_name]["X"].filter(
|
||||
valid_mask
|
||||
)
|
||||
split_data[split_name]["y"] = split_data[split_name]["y"].filter(
|
||||
valid_mask
|
||||
)
|
||||
|
||||
return split_data
|
||||
|
||||
def get_fitted_processors(self) -> List[BaseProcessor]:
|
||||
|
||||
@@ -6,9 +6,11 @@
|
||||
from src.training.tasks.base import BaseTask
|
||||
from src.training.tasks.regression_task import RegressionTask
|
||||
from src.training.tasks.rank_task import RankTask
|
||||
from src.training.tasks.tabm_regression_task import TabMRegressionTask
|
||||
|
||||
__all__ = [
|
||||
"BaseTask",
|
||||
"RegressionTask",
|
||||
"RankTask",
|
||||
"TabMRegressionTask",
|
||||
]
|
||||
|
||||
165
src/training/tasks/tabm_regression_task.py
Normal file
165
src/training/tasks/tabm_regression_task.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""TabM回归任务
|
||||
|
||||
TabM模型的回归训练任务实现。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from src.training.tasks.base import BaseTask
|
||||
from src.training.components.models.tabm_model import TabMModel
|
||||
|
||||
# Type alias for model type
|
||||
TabMModelType = TabMModel
|
||||
|
||||
|
||||
class TabMRegressionTask(BaseTask):
    """Regression task backed by the TabM model.

    Trains a TabM model for regression, supporting:
    - built-in ensemble training (ensemble_size)
    - early stopping
    - training-curve plotting

    Attributes:
        model_params: Model parameter dict.
        label_name: Name of the target column.
        model: TabMModel instance (None until fit() is called).
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
    ):
        """Initialise the TabM regression task.

        Args:
            model_params: TabM model parameters, including:
                - n_blocks: number of MLP blocks
                - d_block: neurons per block
                - dropout: dropout rate
                - ensemble_size: ensemble size
                - batch_size: batch size
                - learning_rate: learning rate
                - weight_decay: weight decay
                - epochs: number of training epochs
                - early_stopping_patience: early-stopping patience
            label_name: Name of the target column.
        """
        super().__init__(model_params, label_name)
        self.model_params = model_params
        self.label_name = label_name
        # Populated by fit(); stays None until training has run.
        self.model: Optional[TabMModelType] = None

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Prepare labels for training.

        Regression labels are already continuous, so the data passes
        through unchanged.

        Args:
            data: Data dict.

        Returns:
            The data dict, unmodified.
        """
        return data

    def fit(
        self,
        train_data: Dict[str, Any],
        val_data: Dict[str, Any],
    ) -> None:
        """Train the TabM model.

        Args:
            train_data: Training-data dict containing:
                - X: feature DataFrame
                - y: label Series
            val_data: Validation-data dict with the same keys; used as
                the eval_set for early stopping.
        """
        print("\n[TabMRegressionTask] 开始训练...")

        # Fresh model per fit() call.
        self.model = TabMModel(self.model_params)
        self.model.fit(
            X=train_data["X"],
            y=train_data["y"],
            eval_set=(val_data["X"], val_data["y"]),
        )

        print("[TabMRegressionTask] 训练完成")

    def predict(self, test_data: Dict[str, Any]) -> np.ndarray:
        """Generate predictions.

        Args:
            test_data: Test-data dict containing:
                - X: feature DataFrame

        Returns:
            Prediction array.

        Raises:
            RuntimeError: If fit() has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("模型未训练,请先调用fit()")
        return self.model.predict(test_data["X"])

    def get_model(self) -> Any:
        """Return the trained model.

        Returns:
            The TabMModel instance, or None before training.
        """
        return self.model

    def plot_training_metrics(self, output_path: Optional[str] = None) -> None:
        """Plot training and validation loss curves.

        Args:
            output_path: Where to save the figure; None shows it instead.
        """
        # Nothing to draw before fit() or when the history is empty.
        if self.model is None or not self.model.training_history_["train_loss"]:
            print("[TabMRegressionTask] 无训练历史可绘制")
            return

        history = self.model.training_history_
        epoch_axis = range(1, len(history["train_loss"]) + 1)

        figure, axis = plt.subplots(figsize=(10, 6))
        axis.plot(
            epoch_axis, history["train_loss"], "b-", label="Train Loss", linewidth=2
        )
        if history["val_loss"]:
            axis.plot(
                epoch_axis, history["val_loss"], "r-", label="Val Loss", linewidth=2
            )

        axis.set_xlabel("Epoch", fontsize=12)
        axis.set_ylabel("Loss (MSE)", fontsize=12)
        axis.set_title("TabM Training History", fontsize=14)
        axis.legend(fontsize=10)
        axis.grid(True, alpha=0.3)
        plt.tight_layout()

        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(output_path, dpi=150, bbox_inches="tight")
            print(f"[TabMRegressionTask] 训练曲线保存到: {output_path}")
        else:
            plt.show()

        plt.close()
||||
Reference in New Issue
Block a user