feat(training): 新增 TabM 模型支持及数据质量优化

- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
2026-03-31 23:11:21 +08:00
parent 9e0114c745
commit 36a3ccbcc8
22 changed files with 4421 additions and 204 deletions

View File

@@ -14,7 +14,7 @@ from src.factors import FactorEngine
# =============================================================================
# 日期范围配置(正确的 train/val/test 三分法)
# =============================================================================
TRAIN_START = "20200101"
TRAIN_START = "20190101"
TRAIN_END = "20231231"
VAL_START = "20240101"
VAL_END = "20241231"
@@ -79,8 +79,8 @@ SELECTED_FACTORS = [
"GTJA_alpha002",
"GTJA_alpha003",
"GTJA_alpha004",
"GTJA_alpha005",
"GTJA_alpha006",
# "GTJA_alpha005",
# "GTJA_alpha006",
"GTJA_alpha007",
"GTJA_alpha008",
"GTJA_alpha009",
@@ -123,133 +123,133 @@ SELECTED_FACTORS = [
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
# "GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
# "GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
# "GTJA_alpha138",
"GTJA_alpha139",
# "GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
# "GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
"GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
# "GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
# "GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# # "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
@@ -488,6 +488,9 @@ MODEL_SAVE_DIR = "models" # 模型保存目录
# Top N 配置:每日推荐股票数量
TOP_N = 5 # 可调整为 10, 20 等
# 训练数据跳过天数配置
TRAIN_SKIP_DAYS = 300 # 跳过训练数据前 300 天的数据,避免训练初期数据不足的问题
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
"""生成输出文件路径。

View File

@@ -0,0 +1,514 @@
"""数据质量分析模块
提供数据质量检查功能,包括:
- 缺失值统计
- 零值统计
- 按日期检查全空列
"""
from typing import Any, Dict, List, Optional
import polars as pl
import numpy as np
class DataQualityAnalyzer:
    """Data quality analyzer.

    Analyzes quality problems in training data (missing values, zero values,
    and per-date all-null columns) to help surface data anomalies before
    model training.

    Attributes:
        feature_cols: List of feature column names to analyze.
        label_col: Name of the label column.
        verbose: Whether to print a detailed report while analyzing.
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        verbose: bool = True,
    ):
        """Initialize the data quality analyzer.

        Args:
            feature_cols: List of feature column names (defaults to empty list).
            label_col: Name of the label column.
            verbose: Whether to print detailed information.
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.verbose = verbose
        # Populated by analyze(); maps split name -> per-split result dict.
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Set the columns to analyze.

        Args:
            feature_cols: List of feature column names.
            label_col: Name of the label column.
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the complete data quality analysis.

        Args:
            data: Data dictionary shaped like {"train": {...}, "val": {...},
                "test": {...}}; each split dict is expected to carry the raw
                frame under the "raw_data" key.
            split_names: Names of the splits to analyze; defaults to
                ["train", "val", "test"].

        Returns:
            Dict mapping split name to that split's analysis results.
        """
        if not split_names:
            split_names = ["train", "val", "test"]
        if self.verbose:
            print("\n" + "=" * 80)
            print("数据质量分析报告")
            print("=" * 80)
        self.analysis_results = {}
        for split_name in split_names:
            # Silently skip splits that are absent or have no raw frame.
            if split_name not in data:
                continue
            split_data = data[split_name]
            raw_df = split_data.get("raw_data")
            if raw_df is None:
                continue
            if self.verbose:
                print(f"\n[{split_name.upper()} 数据集]")
                print("-" * 40)
            split_results = self._analyze_split(raw_df, split_name)
            self.analysis_results[split_name] = split_results
        if self.verbose:
            print("\n" + "=" * 80)
        return self.analysis_results

    def _analyze_split(
        self,
        df: pl.DataFrame,
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze a single dataset split.

        Args:
            df: The split's data frame.
            split_name: Name of the split (used only for reporting).

        Returns:
            Result dict with null/zero/all-null-by-date analyses, plus a
            "label_analysis" entry when the label column is present.
        """
        results = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }
        # 1. Missing-value analysis over the feature columns.
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats
        if self.verbose:
            self._print_null_analysis(null_stats)
        # 2. Zero-value analysis over the feature columns.
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats
        if self.verbose:
            self._print_zero_analysis(zero_stats)
        # 3. Check whether any column is entirely null on some trading date.
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date
        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)
        # 4. Analyze the label column (only when configured and present).
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats
            if self.verbose:
                self._print_label_analysis(label_stats)
        return results

    def _analyze_null_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze missing (null) values.

        Args:
            df: Data frame to inspect.
            cols: Column names to inspect (names missing from df are skipped).

        Returns:
            Dict with per-column null counts/percentages and totals. Only
            columns that actually contain nulls are recorded.
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }
        for col in cols:
            if col not in df.columns:
                continue
            null_count = df[col].null_count()
            if null_count > 0:
                null_pct = null_count / len(df) * 100
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = null_pct
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count
        return stats

    def _analyze_zero_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze zero values.

        Args:
            df: Data frame to inspect.
            cols: Column names to inspect (names missing from df are skipped).

        Returns:
            Dict with per-column zero counts/percentages and totals.
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }
        for col in cols:
            if col not in df.columns:
                continue
            # Count zeros among non-null values only.
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue
            zero_count = (non_null_series == 0).sum()
            if zero_count > 0:
                # NOTE: percentage is relative to ALL rows (len(df)), not just
                # the non-null rows the zeros were counted over.
                zero_pct = zero_count / len(df) * 100
                stats["zero_counts"][col] = int(zero_count)
                stats["zero_percentages"][col] = zero_pct
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += int(zero_count)
        return stats

    def _check_all_null_by_date(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Check whether any column is entirely null on some trading date.

        Uses a polars lazy frame for a memory-safe, efficient computation.

        Args:
            df: Data frame to inspect; must contain a "trade_date" column,
                otherwise the check is skipped and an empty result returned.
            cols: Column names to inspect.

        Returns:
            Dict with "issues_found" flag and an "issues" list of
            {date, column, total_rows} entries.
        """
        results = {
            "issues_found": False,
            "issues": [],
        }
        if "trade_date" not in df.columns:
            return results
        # Drop names that are not actually in the frame.
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results
        # Use a lazy frame so polars can optimize the query.
        lf = df.lazy()
        # Core step: aggregate only null counts and the per-date row count
        # (the aggregated table is tiny). One null_count expression per column.
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))
        agg_lf = lf.group_by("trade_date").agg(agg_exprs)
        # Collect; agg_df usually has only a few hundred to a few thousand rows.
        agg_df = agg_lf.collect()
        # Run the logical check on the already-condensed small table.
        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # Dates where the null count equals the (non-zero) row count,
            # i.e. the column is entirely null on that date.
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
            ).select(["trade_date", "total_rows"])
            if not bad_dates.is_empty():
                for row in bad_dates.to_dicts():
                    issues.append(
                        {
                            "date": row["trade_date"],
                            "column": col,
                            "total_rows": row["total_rows"],
                        }
                    )
        if issues:
            results["issues_found"] = True
            results["issues"] = issues
        return results

    def _analyze_label(
        self,
        df: pl.DataFrame,
        label_col: str,
    ) -> Dict[str, Any]:
        """Analyze the label column.

        Args:
            df: Data frame to inspect.
            label_col: Name of the label column.

        Returns:
            Dict with null/zero counts and basic statistics; the statistics
            stay None when the column is missing or entirely null.
        """
        stats = {
            "total_count": len(df),
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }
        if label_col not in df.columns:
            return stats
        series = df[label_col]
        # Missing-value statistics.
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0
        # Zero-value statistics (over non-null values).
        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            zero_count = (non_null_series == 0).sum()
            stats["zero_count"] = int(zero_count)
            stats["zero_percentage"] = zero_count / len(df) * 100
            # Basic summary statistics.
            stats["min"] = float(non_null_series.min())
            stats["max"] = float(non_null_series.max())
            stats["mean"] = float(non_null_series.mean())
            stats["std"] = float(non_null_series.std())
        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the missing-value analysis.

        Args:
            stats: Result dict produced by _analyze_null_values.

        NOTE(review): divides by total_cells — raises ZeroDivisionError when
        df or feature_cols is empty; confirm callers never pass empty inputs.
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]
        print(f" 缺失值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(
            f" 缺失单元格数: {total_null:,} ({total_null / total_cells * 100:.2f}%)"
        )
        print(f" 有缺失值的列数: {len(null_cols)}/{len(self.feature_cols)}")
        if null_cols:
            # Show the five columns with the most nulls.
            print(f" 缺失值最多的5个特征:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the zero-value analysis.

        Args:
            stats: Result dict produced by _analyze_zero_values.
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]
        print(f" 零值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(
            f" 零值单元格数: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)"
        )
        print(f" 有零值的列数: {len(zero_cols)}/{len(self.feature_cols)}")
        if zero_cols:
            # Show the five columns with the most zeros.
            print(f" 零值最多的5个特征:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the per-date all-null check results.

        Args:
            results: Result dict produced by _check_all_null_by_date.
        """
        issues = results["issues"]
        print(f" 按日期全空检查:")
        if results["issues_found"]:
            print(f" [警告] 发现 {len(issues)} 个问题:")
            # Group issues by date for a compact report.
            by_date = {}
            for issue in issues:
                date = issue["date"]
                if date not in by_date:
                    by_date[date] = []
                by_date[date].append(issue["column"])
            for date in sorted(by_date.keys())[:5]:  # show at most 5 dates
                cols = by_date[date]
                print(f" 日期 {date}: {len(cols)} 列全为空")
                if len(cols) <= 3:
                    print(f" 列名: {', '.join(cols)}")
            if len(by_date) > 5:
                print(f" ... 还有 {len(by_date) - 5} 个日期存在问题")
        else:
            print(f" [正常] 未发现某天某列全为空的情况")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the label-column analysis.

        Args:
            stats: Result dict produced by _analyze_label.
        """
        print(f" 标签列统计 ({self.label_col}):")
        print(f" 总数: {stats['total_count']:,}")
        print(f" 缺失值: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f" 零值: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")
        if stats["mean"] is not None:
            print(f" 最小值: {stats['min']:.6f}")
            print(f" 最大值: {stats['max']:.6f}")
            print(f" 均值: {stats['mean']:.6f}")
            print(f" 标准差: {stats['std']:.6f}")

    def get_summary(self) -> str:
        """Return a compact textual summary of the last analysis.

        Returns:
            Summary string, or a placeholder when analyze() has not run yet.
        """
        if not self.analysis_results:
            return "尚未执行分析"
        lines = ["数据质量分析摘要", "=" * 40]
        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")
            lines.append(f" 总行数: {results['total_rows']:,}")
            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f" 缺失值: {null_stats['total_null_cells']:,} 个单元格, "
                    f"{len(null_stats['columns_with_null'])} 列受影响"
                )
            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f" 零值: {zero_stats['total_zero_cells']:,} 个单元格, "
                    f"{len(zero_stats['columns_with_zero'])} 列受影响"
                )
            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f" [警告] 发现 {len(all_null['issues'])} 个日期列全空问题"
                )
        return "\n".join(lines)
def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience wrapper: run a data quality analysis in one call.

    Args:
        data: Data dictionary keyed by split name ("train"/"val"/"test").
        feature_cols: List of feature column names to analyze.
        label_col: Name of the label column.
        verbose: Whether to print detailed information.

    Returns:
        The analysis results dictionary (split name -> per-split results).
    """
    return DataQualityAnalyzer(
        feature_cols=feature_cols,
        label_col=label_col,
        verbose=verbose,
    ).analyze(data)

View File

@@ -37,6 +37,7 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
@@ -155,6 +156,7 @@ def main():
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 RankTask

View File

@@ -38,6 +38,7 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# 训练类型标识
@@ -51,55 +52,55 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
'GTJA_alpha036',
'GTJA_alpha032',
'GTJA_alpha010',
'GTJA_alpha005',
'CP',
'BP',
'debt_to_equity',
'current_ratio',
'GTJA_alpha002',
'GTJA_alpha027',
'GTJA_alpha064',
'GTJA_alpha062',
'GTJA_alpha043',
'GTJA_alpha044',
'GTJA_alpha120',
'GTJA_alpha117',
'GTJA_alpha103',
'GTJA_alpha104',
'GTJA_alpha105',
'GTJA_alpha073',
'GTJA_alpha077',
'GTJA_alpha085',
'GTJA_alpha090',
'GTJA_alpha087',
'GTJA_alpha083',
'GTJA_alpha092',
'GTJA_alpha133',
'GTJA_alpha131',
'GTJA_alpha126',
'GTJA_alpha124',
'GTJA_alpha162',
'GTJA_alpha164',
'GTJA_alpha157',
'GTJA_alpha177',
'price_to_avg_cost',
'cost_skewness',
'GTJA_alpha191',
'GTJA_alpha180',
'history_position',
'bottom_profit',
'mean_median_dev',
'smart_money_accumulation',
'GTJA_alpha013',
'GTJA_alpha099',
'GTJA_alpha107',
'GTJA_alpha119',
'GTJA_alpha141',
'GTJA_alpha130',
'GTJA_alpha173',
"GTJA_alpha036",
"GTJA_alpha032",
"GTJA_alpha010",
"GTJA_alpha005",
"CP",
"BP",
"debt_to_equity",
"current_ratio",
"GTJA_alpha002",
"GTJA_alpha027",
"GTJA_alpha064",
"GTJA_alpha062",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha120",
"GTJA_alpha117",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha073",
"GTJA_alpha077",
"GTJA_alpha085",
"GTJA_alpha090",
"GTJA_alpha087",
"GTJA_alpha083",
"GTJA_alpha092",
"GTJA_alpha133",
"GTJA_alpha131",
"GTJA_alpha126",
"GTJA_alpha124",
"GTJA_alpha162",
"GTJA_alpha164",
"GTJA_alpha157",
"GTJA_alpha177",
"price_to_avg_cost",
"cost_skewness",
"GTJA_alpha191",
"GTJA_alpha180",
"history_position",
"bottom_profit",
"mean_median_dev",
"smart_money_accumulation",
"GTJA_alpha013",
"GTJA_alpha099",
"GTJA_alpha107",
"GTJA_alpha119",
"GTJA_alpha141",
"GTJA_alpha130",
"GTJA_alpha173",
]
# 模型参数配置
@@ -184,6 +185,7 @@ def main():
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,
stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
train_skip_days=TRAIN_SKIP_DAYS,
)
# 4. 创建 RegressionTask

View File

@@ -0,0 +1,375 @@
# %% md
# # TabM 回归训练流程
#
# 使用 TabM (Tabular MLP with Ensembles) 模型进行回归训练。
# TabM 通过内置集成机制(ensemble_size=32)实现高效的多模型集成。
# %% md
# ## 1. 导入依赖
# %%
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
TabMRegressionTask,
NullFiller,
Winsorizer,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
# 训练类型标识
TRAINING_TYPE = "tabm_regression"
# %% md
# ## 2. 训练特定配置
# %%
# Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 排除的因子列表(与 LightGBM 回归保持一致)
EXCLUDED_FACTORS = [
# "GTJA_alpha001",
# "GTJA_alpha002",
# "GTJA_alpha003",
# "GTJA_alpha004",
# "GTJA_alpha005",
# "GTJA_alpha006",
# "GTJA_alpha007",
# "GTJA_alpha008",
# "GTJA_alpha009",
# "GTJA_alpha010",
# "GTJA_alpha011",
# "GTJA_alpha012",
# "GTJA_alpha013",
# "GTJA_alpha014",
# "GTJA_alpha015",
# "GTJA_alpha016",
# "GTJA_alpha017",
# "GTJA_alpha018",
# "GTJA_alpha019",
# "GTJA_alpha020",
# "GTJA_alpha022",
# "GTJA_alpha023",
# "GTJA_alpha024",
# "GTJA_alpha025",
# "GTJA_alpha026",
# "GTJA_alpha027",
# "GTJA_alpha028",
# "GTJA_alpha029",
# "GTJA_alpha031",
# "GTJA_alpha032",
# "GTJA_alpha033",
# "GTJA_alpha034",
# "GTJA_alpha035",
# "GTJA_alpha036",
# "GTJA_alpha037",
# # "GTJA_alpha038",
# "GTJA_alpha039",
# "GTJA_alpha040",
# "GTJA_alpha041",
# "GTJA_alpha042",
# "GTJA_alpha043",
# "GTJA_alpha044",
# "GTJA_alpha045",
# "GTJA_alpha046",
# "GTJA_alpha047",
# "GTJA_alpha048",
# "GTJA_alpha049",
# "GTJA_alpha050",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
# "chip_dispersion_90",
# "chip_dispersion_70",
# "cost_skewness",
# "dispersion_change_20",
# "price_to_avg_cost",
# "price_to_median_cost",
# "mean_median_dev",
# "trap_pressure",
# "bottom_profit",
# "history_position",
# "winner_rate_surge_5",
# "winner_rate_cs_rank",
# "winner_rate_dev_20",
# "winner_rate_volatility",
# "smart_money_accumulation",
# "winner_vol_corr_20",
# "cost_base_momentum",
# "bottom_cost_stability",
# "pivot_reversion",
# "chip_transition",
]
# TabM model parameter configuration (taken from the user-provided example code).
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.3,  # dropout rate
    # ==================== Built-in ensembling ====================
    "ensemble_size": 32,  # internal ensemble size (simulates a 32-model ensemble)
    # ==================== Training parameters ====================
    "batch_size": 1024,  # mini-batch size
    "learning_rate": 1e-3,  # optimizer learning rate
    "weight_decay": 1e-5,  # L2 weight decay
    "epochs": 100,  # number of training epochs
    # ==================== Early stopping ====================
    "early_stopping_patience": 30,  # epochs without improvement before stopping
}
# Date-range configuration for the train/val/test splits.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output configuration consumed by the Trainer.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_regression_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
def main():
    """Entry point: run the full TabM regression training workflow.

    Builds the factor engine/manager, a DataPipeline with standardized
    inputs, a TabMRegressionTask, trains via Trainer, plots the training
    curves, and optionally saves the model with its factor metadata.

    Returns:
        The results object returned by Trainer.run.
    """
    print("\n" + "=" * 80)
    print("TabM 回归模型训练")
    print("=" * 80)
    # 1. Create the FactorEngine (factor computation backend).
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()
    # 2. Create the FactorManager over the selected/excluded factor sets.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )
    # 3. Create the DataPipeline.
    # IMPORTANT: TabM requires standardized inputs, hence StandardScaler.
    # Processing order: NullFiller -> Winsorizer -> StandardScaler.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),  # winsorize heavy tails first
            (StandardScaler, {}),  # TabM needs standardized inputs
        ],
        label_processor_configs=[
            # Winsorize the label (clip extreme returns).
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )
    # 4. Create the TabMRegressionTask.
    print("\n[4] 创建 TabMRegressionTask")
    task = TabMRegressionTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )
    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )
    # 6. Run the training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)
    # 7. Plot the training curves.
    print("\n[7] 绘制训练曲线")
    task.plot_training_metrics(
        output_path=os.path.join(OUTPUT_DIR, "tabm_training_curve.png")
    )
    # 8. Save the model and factor metadata (when enabled).
    if SAVE_MODEL:
        print("\n[8] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )
    print("\n" + "=" * 80)
    print("TabM 训练流程完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_regression_output.csv')}")
    print("=" * 80)
    return results


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,425 @@
# %% md
# # TabPFN 回归训练流程
#
# 使用 TabPFN (Prior-Data Fitted Network) 进行回归预测。
# TabPFN 通过上下文学习进行预测,无需传统梯度下降训练过程。
# %% md
# ## 1. 导入依赖
# %%
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.training.components.models import TabPFNModel
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
)
# 训练类型标识
TRAINING_TYPE = "tabpfn"
# %% md
# ## 2. 训练特定配置
# %%
# Label 配置(从 common.py 统一导入)
# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
# 排除的因子列表(与 regression.py 保持一致)
EXCLUDED_FACTORS = [
"GTJA_alpha001",
"GTJA_alpha002",
"GTJA_alpha003",
"GTJA_alpha004",
"GTJA_alpha005",
"GTJA_alpha006",
"GTJA_alpha007",
"GTJA_alpha008",
"GTJA_alpha009",
"GTJA_alpha010",
"GTJA_alpha011",
"GTJA_alpha012",
"GTJA_alpha013",
"GTJA_alpha014",
"GTJA_alpha015",
"GTJA_alpha016",
"GTJA_alpha017",
"GTJA_alpha018",
"GTJA_alpha019",
"GTJA_alpha020",
"GTJA_alpha022",
"GTJA_alpha023",
"GTJA_alpha024",
"GTJA_alpha025",
"GTJA_alpha026",
"GTJA_alpha027",
"GTJA_alpha028",
"GTJA_alpha029",
"GTJA_alpha031",
"GTJA_alpha032",
"GTJA_alpha033",
"GTJA_alpha034",
"GTJA_alpha035",
"GTJA_alpha036",
"GTJA_alpha037",
# "GTJA_alpha038",
"GTJA_alpha039",
"GTJA_alpha040",
"GTJA_alpha041",
"GTJA_alpha042",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha045",
"GTJA_alpha046",
"GTJA_alpha047",
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
# "GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
# "GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
# "GTJA_alpha138",
"GTJA_alpha139",
# "GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
# "GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
"GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
# "GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
# "GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
"dispersion_change_20",
"price_to_avg_cost",
"price_to_median_cost",
"mean_median_dev",
"trap_pressure",
"bottom_profit",
"history_position",
"winner_rate_surge_5",
"winner_rate_cs_rank",
"winner_rate_dev_20",
"winner_rate_volatility",
"smart_money_accumulation",
"winner_vol_corr_20",
"cost_base_momentum",
"bottom_cost_stability",
"pivot_reversion",
"chip_transition",
]
# 模型参数配置
MODEL_PARAMS = {
# ==================== 设备配置 ====================
"device": "cuda", # 计算设备: "cuda" 或 "cpu"(默认 cuda
# ==================== 上下文限制 ====================
"max_context_size": 100, # 16GB GPU 建议 1000-300032GB 可尝试 5000-8000
}
# 日期范围配置
date_range = {
"train": (TRAIN_START, TRAIN_END),
"val": (VAL_START, VAL_END),
"test": (TEST_START, TEST_END),
}
# 输出配置
output_config = {
"output_dir": OUTPUT_DIR,
"output_filename": "tabpfn_output.csv",
"save_predictions": SAVE_PREDICTIONS,
"save_model": SAVE_MODEL,
"model_save_path": get_model_save_path(TRAINING_TYPE),
"top_n": TOP_N,
}
# %% md
# ## 3. 自定义 TabPFN 任务
# %%
from src.training.tasks import RegressionTask
class TabPFNTask(RegressionTask):
    """TabPFN regression task.

    Inherits from RegressionTask but uses TabPFNModel as the underlying
    model. TabPFN needs no conventional training loop; it predicts via
    in-context learning.
    """

    def __init__(self, model_params: dict, label_name: str):
        """Initialize the TabPFN task.

        Args:
            model_params: TabPFN parameter dict (passed to TabPFNModel).
            label_name: Name of the label column.
        """
        # Deliberately skip RegressionTask.__init__ and initialize via
        # BaseTask directly, to avoid creating a LightGBMModel.
        from src.training.tasks.base import BaseTask

        BaseTask.__init__(self, model_params, label_name)
        self.evals_result: dict | None = None
        self.model = TabPFNModel(params=model_params)

    def fit(self, train_data: dict, val_data: dict) -> None:
        """"Train" the TabPFN model.

        TabPFN "trains" by loading the training data into the model context;
        no gradient-descent optimization takes place.

        Args:
            train_data: Training data {"X": DataFrame, "y": Series}.
            val_data: Validation data, used for evaluation only; passed as
                eval_set when its "X" is present.
        """
        X_train = train_data["X"]
        y_train = train_data["y"]
        X_val = val_data.get("X")
        y_val = val_data.get("y")
        # TabPFN uses eval_set for validation.
        self.model.fit(
            X_train, y_train, eval_set=(X_val, y_val) if X_val is not None else None
        )

    def get_model(self) -> TabPFNModel:
        """Return the fitted model instance."""
        return self.model
# %% md
# ## 4. 主函数
# %%
def main():
    """Run the TabPFN regression training experiment end to end.

    Steps: build the factor engine, assemble the data pipeline, train the
    TabPFN task, optionally persist the model, and report validation metrics.

    Returns:
        The results object produced by ``Trainer.run``.
    """
    print("\n" + "=" * 80)
    print("TabPFN 回归模型训练")
    print("=" * 80)
    print("\n[说明] TabPFN 使用上下文学习(In-Context Learning)")
    print(" 训练过程实际是加载数据到模型上下文。")
    print(" 如果训练数据超过上下文限制,会自动截取最近的数据。")
    # 1. Factor engine: data access layer used by factors and filters.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()
    # 2. Factor manager: resolves the selected factor set and the label factor.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )
    # 3. Data pipeline: features go through NaN fill -> winsorize -> standardize;
    #    the label is only winsorized (clip extreme returns).
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label to remove extreme returns.
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )
    # 4. Task wrapping the TabPFN model with the configured parameters.
    print("\n[4] 创建 TabPFNTask")
    task = TabPFNTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
    )
    # 5. Trainer orchestrates data preparation, fitting and evaluation.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )
    # 6. Execute the full training flow over the configured date ranges.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)
    # 7. Optionally persist the model together with its factor metadata.
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )
    # 8. Report TabPFN-specific validation metrics, if fit() computed them.
    print("\n" + "=" * 80)
    print("TabPFN 训练完成!")
    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabpfn_output.csv')}")
    # best_score is populated only when an eval_set was provided to fit().
    model = task.get_model()
    best_score = model.get_best_score()
    if best_score:
        print("\n[验证集评估指标]")
        for metric, value in best_score.get("valid_0", {}).items():
            print(f" - {metric}: {value:.6f}")
    print("=" * 80)
    return results


if __name__ == "__main__":
    main()

15
src/scripts/check_gpu.py Normal file
View File

@@ -0,0 +1,15 @@
import torch

# PyTorch version string — the build suffix reveals CPU vs GPU wheels.
print(f"PyTorch 版本: {torch.__version__}")
# CPU builds report e.g. 2.6.0+cpu
# GPU builds report e.g. 2.6.0+cu118 / 2.6.0+cu121 / 2.6.0+cu124
# Whether a CUDA device is usable from this build/driver combination.
print(f"CUDA 可用: {torch.cuda.is_available()}")
# When CUDA is available, print toolkit version and the device inventory.
if torch.cuda.is_available():
    print(f"CUDA 版本: {torch.version.cuda}")
    print(f"GPU 数量: {torch.cuda.device_count()}")
    print(f"GPU 名称: {torch.cuda.get_device_name(0)}")

View File

@@ -29,7 +29,7 @@ from src.training.components.processors import (
)
# 模型
from src.training.components.models import LightGBMModel
from src.training.components.models import LightGBMModel, TabMModel
# 数据过滤器
from src.training.components.filters import BaseFilter, STFilter
@@ -50,7 +50,7 @@ from src.training.config import TrainingConfig
from src.training.factor_manager import FactorManager
from src.training.pipeline import DataPipeline
from src.training.result_analyzer import ResultAnalyzer
from src.training.tasks import BaseTask, RegressionTask, RankTask
from src.training.tasks import BaseTask, RegressionTask, RankTask, TabMRegressionTask
# 从 trainer_v2 导入新 Trainer推荐
from src.training.core.trainer_v2 import Trainer as TrainerV2
@@ -79,6 +79,7 @@ __all__ = [
"STFilter",
# 模型
"LightGBMModel",
"TabMModel",
# 训练核心(旧版,已废弃)
"StockPoolManager",
"Trainer",
@@ -93,5 +94,6 @@ __all__ = [
"BaseTask",
"RegressionTask",
"RankTask",
"TabMRegressionTask",
"TrainerV2", # 新的 Trainer推荐
]

View File

@@ -5,5 +5,7 @@
from src.training.components.models.lightgbm import LightGBMModel
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
from src.training.components.models.tabpfn_model import TabPFNModel
from src.training.components.models.tabm_model import TabMModel
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"]
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel"]

View File

@@ -0,0 +1,368 @@
"""TabM模型实现
TabM (Tabular Multilayer Perceptron with Ensembles)
基于 rtdl_revisiting_models 的 TabM 模型,支持内置集成。
"""
from typing import Dict, Any, List, Optional
from pathlib import Path
import pickle
import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tabm import TabM
from src.training.components.base import BaseModel
from src.training.registry import register_model
@register_model("tabm")
class TabMModel(BaseModel):
"""TabM回归模型
特点:
- 使用MLP架构
- 内置集成机制(ensemble_size),显存开销远小于独立模型
- 训练时所有集成成员独立优化,保持多样性
- 预测时取集成成员均值获得稳定结果
Attributes:
name: 模型名称标识
params: 模型参数字典
model: TabM模型实例
device: 计算设备(cuda/cpu)
training_history_: 训练历史记录
feature_names_: 特征名称列表
"""
name = "tabm"
def __init__(self, params: Dict[str, Any]):
"""初始化TabM模型
Args:
params: 模型参数字典,包含:
- n_blocks: MLP层数 (默认: 3)
- d_block: 每层神经元数 (默认: 256)
- dropout: Dropout率 (默认: 0.1)
- ensemble_size: 集成大小 (默认: 32)
- batch_size: 批次大小 (默认: 1024)
- learning_rate: 学习率 (默认: 1e-3)
- weight_decay: 权重衰减 (默认: 1e-5)
- epochs: 训练轮数 (默认: 50)
- early_stopping_patience: 早停耐心值 (默认: 10)
"""
self.params = params
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.training_history_: Dict[str, List[float]] = {
"train_loss": [],
"val_loss": [],
}
self.feature_names_: Optional[List[str]] = None
# 损失函数
self.criterion = nn.MSELoss()
def _make_loader(
self, X: np.ndarray, y: Optional[np.ndarray] = None, shuffle: bool = False
) -> DataLoader:
"""创建DataLoader
Args:
X: 特征数组 [N, n_features]
y: 标签数组 [N] 或 None
shuffle: 是否打乱数据
Returns:
DataLoader实例
"""
if y is not None:
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
else:
dataset = TensorDataset(torch.from_numpy(X))
batch_size = self.params.get("batch_size", 1024)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
def _validate(self, val_loader: DataLoader) -> float:
"""验证模型
Args:
val_loader: 验证数据加载器
Returns:
平均验证损失
"""
self.model.eval()
total_loss = 0.0
n_batches = 0
with torch.no_grad():
for batch in val_loader:
if len(batch) == 2:
bx, by = batch
bx, by = bx.to(self.device), by.to(self.device)
else:
bx = batch[0].to(self.device)
by = None
# 预测时取集成成员均值
outputs = self.model(bx) # [B, E, 1]
preds = outputs.mean(dim=1).squeeze(-1) # [B]
if by is not None:
loss = self.criterion(preds, by).item()
total_loss += loss
n_batches += 1
return total_loss / max(n_batches, 1)
def fit(
self, X: pl.DataFrame, y: pl.Series, eval_set: Optional[tuple] = None
) -> "TabMModel":
"""训练TabM模型
训练策略:
1. 对所有集成成员独立计算Loss保持多样性
2. 验证和预测时取ensemble成员均值
Args:
X: 训练特征DataFrame
y: 训练标签Series
eval_set: 验证集元组 (X_val, y_val),可选
Returns:
self
"""
# 保存特征名称
self.feature_names_ = X.columns
# 【关键】数据类型强制转换为float32
# PyTorch对float64支持较差避免使用Polars/Numpy默认类型
X_np = X.to_numpy().astype(np.float32)
y_np = y.to_numpy().astype(np.float32)
# 创建DataLoader
train_loader = self._make_loader(X_np, y_np, shuffle=True)
val_loader = None
if eval_set is not None:
X_val, y_val = eval_set
X_val_np = X_val.to_numpy().astype(np.float32)
y_val_np = y_val.to_numpy().astype(np.float32)
val_loader = self._make_loader(X_val_np, y_val_np, shuffle=False)
n_features = X_np.shape[1]
ensemble_size = self.params.get("ensemble_size", 32)
# 初始化TabM模型使用TabM.make()自动填充默认参数
self.model = TabM.make(
n_num_features=n_features,
cat_cardinalities=[],
d_out=1, # 回归任务输出维度为1
n_blocks=self.params.get("n_blocks", 3),
d_block=self.params.get("d_block", 256),
dropout=self.params.get("dropout", 0.1),
k=ensemble_size, # 集成大小
).to(self.device)
# 优化器
optimizer = optim.AdamW(
self.model.parameters(),
lr=self.params.get("learning_rate", 1e-3),
weight_decay=self.params.get("weight_decay", 1e-5),
)
# 训练参数
epochs = self.params.get("epochs", 50)
early_stopping_patience = self.params.get("early_stopping_patience", 10)
# 训练循环
best_val_loss = float("inf")
patience_counter = 0
print(f"[TabM] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}")
for epoch in range(epochs):
# 训练阶段
self.model.train()
train_loss = 0.0
n_train_batches = 0
for bx, by in train_loader:
bx, by = bx.to(self.device), by.to(self.device)
optimizer.zero_grad()
# 前向传播
# outputs形状: [Batch, Ensemble, 1]
outputs = self.model(bx)
outputs_squeezed = outputs.squeeze(-1) # [B, E]
# 【关键】针对所有集成成员计算Loss
# 不先取均值,让每个集成成员独立收敛,保持集成多样性
by_expanded = by.unsqueeze(-1).expand(-1, ensemble_size) # [B, E]
loss = self.criterion(outputs_squeezed, by_expanded)
loss.backward()
optimizer.step()
train_loss += loss.item()
n_train_batches += 1
avg_train_loss = train_loss / max(n_train_batches, 1)
self.training_history_["train_loss"].append(avg_train_loss)
# 验证阶段
if val_loader is not None:
val_loss = self._validate(val_loader)
self.training_history_["val_loss"].append(val_loss)
# 早停逻辑
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
else:
patience_counter += 1
if (epoch + 1) % 5 == 0 or epoch == 0:
print(
f"[TabM] Epoch {epoch + 1}/{epochs} | "
f"Train Loss: {avg_train_loss:.6f} | "
f"Val Loss: {val_loss:.6f}"
)
if patience_counter >= early_stopping_patience:
print(f"[TabM] 早停触发epoch {epoch + 1}")
break
else:
if (epoch + 1) % 5 == 0 or epoch == 0:
print(
f"[TabM] Epoch {epoch + 1}/{epochs} | "
f"Train Loss: {avg_train_loss:.6f}"
)
print(f"[TabM] 训练完成")
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""生成预测
预测时对ensemble_size个成员取均值获得稳定结果。
Args:
X: 特征DataFrame
Returns:
预测结果数组 [N]
"""
if self.model is None:
raise RuntimeError("模型未训练请先调用fit()")
# 数据类型转换
X_np = X.to_numpy().astype(np.float32)
loader = self._make_loader(X_np, shuffle=False)
self.model.eval()
all_preds = []
with torch.no_grad():
for batch in loader:
bx = batch[0].to(self.device)
# 预测时取集成成员均值
outputs = self.model(bx) # [B, E, 1]
preds = outputs.mean(dim=1).squeeze(-1) # [B]
all_preds.append(preds.cpu().numpy())
return np.concatenate(all_preds)
def feature_importance(self) -> Optional[pl.Series]:
"""获取特征重要性
TabM没有内置特征重要性计算返回None。
Returns:
None
"""
return None
def save(self, path: str | Path) -> None:
"""保存模型
保存模型state_dict和元数据params, feature_names, training_history
Args:
path: 保存路径
"""
if self.model is None:
raise RuntimeError("模型未训练,无法保存")
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
# 保存模型权重
model_path = path.with_suffix(".pt")
torch.save(self.model.state_dict(), model_path)
# 保存元数据
meta_path = path.with_suffix(".meta")
meta = {
"params": self.params,
"feature_names": self.feature_names_,
"training_history": self.training_history_,
"device": str(self.device),
}
with open(meta_path, "wb") as f:
pickle.dump(meta, f)
print(f"[TabM] 模型保存到: {path}")
@classmethod
def load(cls, path: str) -> "TabMModel":
"""加载模型
Args:
path: 模型路径(不含扩展名)
Returns:
加载的TabMModel实例
"""
path = Path(path)
# 加载元数据
meta_path = path.with_suffix(".meta")
with open(meta_path, "rb") as f:
meta = pickle.load(f)
# 创建实例
instance = cls(meta["params"])
instance.feature_names_ = meta["feature_names"]
instance.training_history_ = meta["training_history"]
# 重建模型结构
if instance.feature_names_ is not None:
n_features = len(instance.feature_names_)
ensemble_size = instance.params.get("ensemble_size", 32)
instance.model = TabM.make(
n_num_features=n_features,
cat_cardinalities=[],
d_out=1, # 回归任务输出维度为1
n_blocks=instance.params.get("n_blocks", 3),
d_block=instance.params.get("d_block", 256),
dropout=instance.params.get("dropout", 0.1),
k=ensemble_size, # 集成大小
).to(instance.device)
# 加载权重
model_path = path.with_suffix(".pt")
instance.model.load_state_dict(
torch.load(model_path, map_location=instance.device)
)
print(f"[TabM] 模型从 {path} 加载完成")
return instance

View File

@@ -0,0 +1,296 @@
"""TabPFN 模型实现
基于 TabPFN (Prior-Data Fitted Network) 的回归模型实现。
TabPFN 利用预训练的 Transformer 网络,通过上下文学习(in-context learning)
进行快速的小样本/中样本回归预测,无需传统训练过程。
"""
import json
import os
from pathlib import Path
from typing import Any, Optional
import numpy as np
import pandas as pd
import polars as pl
from scipy.stats import spearmanr
from src.training.components.base import BaseModel
from src.training.registry import register_model
os.environ["HF_TOKEN"] = "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"
@register_model("tabpfn")
class TabPFNModel(BaseModel):
"""TabPFN 回归模型
使用 TabPFN 库实现基于 Prior-Data Fitted Network 的回归预测。
该模型通过上下文学习方式进行预测,无需传统梯度下降训练。
支持 GPU 加速和自动上下文截断处理。
Attributes:
name: 模型名称 "tabpfn"
params: TabPFN 参数字典
model: TabPFNRegressor 实例
feature_names_: 特征名称列表
evals_result_: 训练评估结果
best_score_: 最佳评估指标
"""
name = "tabpfn"
# TabPFN 官方限制(最大样本数),可通过 ignore_pretraining_limits=True 扩展
MAX_CONTEXT_SIZE = 10000
def __init__(self, params: Optional[dict] = None):
"""初始化 TabPFN 模型
Args:
params: TabPFN 参数字典,支持以下参数:
- device: 计算设备,'cuda''cpu'(默认 'cpu'
- model_path: 本地模型权重文件路径(可选)
- N_ensemble: 集成数量,用于降低预测方差(默认 1
- max_context_size: 最大上下文样本数(默认 50000
Examples:
>>> model = TabPFNModel(params={
... "device": "cuda",
... "N_ensemble": 5,
... })
"""
self.params = dict(params) if params is not None else {}
self.model = None
self.feature_names_: Optional[list] = None
self.evals_result_: Optional[dict] = None
self.best_score_: Optional[dict] = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
eval_set: Optional[tuple] = None,
) -> "TabPFNModel":
"""训练/加载 TabPFN 模型
TabPFN 采用上下文学习,"fit" 操作实际上是加载训练数据到模型上下文。
如果训练数据超过上下文限制,会自动截取最近的数据。
Args:
X: 特征矩阵 (Polars DataFrame)
y: 目标变量 (Polars Series)
eval_set: 验证集元组 (X_val, y_val),用于评估模型性能
Returns:
self (支持链式调用)
Raises:
ImportError: 未安装 tabpfn
RuntimeError: 模型初始化或加载失败
"""
from tabpfn import TabPFNRegressor
self.feature_names_ = X.columns
# 转换为 numpy 数组
X_np = X.to_numpy()
y_np = y.to_numpy()
# 处理上下文大小限制
max_context = self.params.get("max_context_size", self.MAX_CONTEXT_SIZE)
if len(X_np) > max_context:
print(
f"[TabPFN] 训练数据 {len(X_np)} 超过上下文限制 {max_context},截取最近数据"
)
X_np = X_np[-max_context:]
y_np = y_np[-max_context:]
# 初始化模型
# TabPFNRegressor 需要设置 ignore_pretraining_limits=True 以支持超过 10,000 样本
device = self.params.get("device", "cuda")
ignore_limits = self.params.get("ignore_pretraining_limits", True)
self.model = TabPFNRegressor(
device=device,
ignore_pretraining_limits=ignore_limits,
n_estimators=1
)
# 加载上下文TabPFN 的 "fit" 是加载上下文)
print("[TabPFN] 加载训练数据到上下文...")
self.model.fit(X_np, y_np)
# 评估验证集
if eval_set is not None:
X_val, y_val = eval_set
val_preds = self.predict(X_val)
y_val_np = y_val.to_numpy()
# 计算评估指标
mse = np.mean((y_val_np - val_preds) ** 2)
rank_ic, p_value = spearmanr(val_preds, y_val_np)
self.evals_result_ = {
"valid_0": {
"mse": [mse],
"rank_ic": [rank_ic],
}
}
self.best_score_ = {
"valid_0": {
"mse": mse,
"rank_ic": rank_ic,
"rank_ic_pvalue": p_value,
}
}
print(f"[TabPFN] 验证集 MSE: {mse:.6f}, Rank IC: {rank_ic:.4f}")
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测
Args:
X: 特征矩阵 (Polars DataFrame)
Returns:
预测结果 (numpy ndarray)
Raises:
RuntimeError: 模型未初始化时调用
"""
if self.model is None:
raise RuntimeError("模型尚未初始化,请先调用 fit()")
X_np = X.to_numpy()
result = self.model.predict(X_np)
return np.asarray(result)
def predict_with_uncertainty(
self, X: pl.DataFrame
) -> tuple[np.ndarray, np.ndarray]:
"""预测并返回不确定性估计
利用 N_ensemble 预测的标准差作为不确定性估计。
Args:
X: 特征矩阵 (Polars DataFrame)
Returns:
(predictions, uncertainties) 元组,均为 numpy ndarray
"""
if self.model is None:
raise RuntimeError("模型尚未初始化,请先调用 fit()")
X_np = X.to_numpy()
predictions = self.model.predict(X_np)
# 如果使用了 ensemble可以通过多次预测计算标准差
# 注意:这需要修改 TabPFNRegressor 的使用方式
# 这里返回预测值的零不确定性作为默认行为
uncertainties = np.zeros_like(predictions)
return np.asarray(predictions), np.asarray(uncertainties)
def get_evals_result(self) -> Optional[dict]:
"""获取训练评估结果
Returns:
评估结果字典,如果未进行评估返回 None
"""
return self.evals_result_
def get_best_score(self) -> Optional[dict]:
"""获取最佳评分
Returns:
最佳评分字典,如果未进行评估返回 None
"""
return self.best_score_
def evaluate(self, X: pl.DataFrame, y: pl.Series) -> dict[str, float]:
"""评估模型性能
计算回归任务常用指标MSE 和 Rank IC。
Args:
X: 特征矩阵
y: 真实目标值
Returns:
评估指标字典,包含 mse 和 rank_ic
"""
preds = self.predict(X)
y_np = y.to_numpy()
mse = float(np.mean((y_np - preds) ** 2))
rank_ic_result = spearmanr(preds, y_np)
rank_ic = float(rank_ic_result.correlation)
p_value = float(rank_ic_result.pvalue)
return {
"mse": mse,
"rank_ic": rank_ic,
"rank_ic_pvalue": p_value,
}
def feature_importance(self) -> None:
"""TabPFN 不支持传统特征重要性
TabPFN 是基于 Transformer 的上下文学习模型,
不提供类似决策树的特征重要性指标。
Returns:
None
"""
return None
def save(self, path: str) -> None:
"""保存模型元数据和配置
TabPFN 模型本身不支持序列化保存,因此只保存:
- 模型参数配置
- 特征名称列表
- 上下文数据摘要(样本数、特征数)
注意:实际使用时需要重新 fit 来加载上下文。
Args:
path: 保存路径
"""
save_data = {
"model_type": self.name,
"params": self.params,
"feature_names": self.feature_names_,
"evals_result": self.evals_result_,
"best_score": self.best_score_,
}
# 保存为 JSON
Path(path).parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(save_data, f, indent=2, ensure_ascii=False)
@classmethod
def load(cls, path: str) -> "TabPFNModel":
"""加载模型配置
注意TabPFN 模型需要重新 fit 才能使用,
此方法仅恢复模型参数配置。
Args:
path: 配置文件路径
Returns:
配置恢复的 TabPFNModel 实例(未 fit
"""
with open(path, "r", encoding="utf-8") as f:
save_data = json.load(f)
instance = cls(params=save_data.get("params", {}))
instance.feature_names_ = save_data.get("feature_names")
instance.evals_result_ = save_data.get("evals_result")
instance.best_score_ = save_data.get("best_score")
print(f"[TabPFN] 已加载模型配置,需要调用 fit() 重新加载上下文")
return instance

View File

@@ -3,6 +3,7 @@
包含标准化、缩尾、缺失值填充等数据处理器。
"""
import math
from typing import List, Literal, Optional, Union
import polars as pl
@@ -88,7 +89,9 @@ class NullFiller(BaseProcessor):
"""
if not self.by_date and self.strategy in ("mean", "median"):
for col in self.feature_cols:
if col in X.columns and X[col].dtype.is_numeric():
if col in X.columns and (
X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
):
if self.strategy == "mean":
self.stats_[col] = X[col].mean() or 0.0
else: # median
@@ -119,11 +122,14 @@ class NullFiller(BaseProcessor):
raise ValueError(f"未知的填充策略: {self.strategy}")
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用0填充缺失值"""
"""使用0填充缺失值(同时处理 NaN 和 null"""
expressions = []
for col in X.columns:
if col in self.feature_cols and X[col].dtype.is_numeric():
expr = pl.col(col).fill_null(0).alias(col)
if col in self.feature_cols and (
X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
):
# 先 fill_nan 再 fill_null确保两种缺失值都被处理
expr = pl.col(col).fill_nan(0).fill_null(0).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -131,11 +137,19 @@ class NullFiller(BaseProcessor):
return X.select(expressions)
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用指定值填充缺失值"""
"""使用指定值填充缺失值(同时处理 NaN 和 null"""
expressions = []
for col in X.columns:
if col in self.feature_cols and X[col].dtype.is_numeric():
expr = pl.col(col).fill_null(self.fill_value).alias(col)
if col in self.feature_cols and (
X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean
):
# 先 fill_nan 再 fill_null
expr = (
pl.col(col)
.fill_nan(self.fill_value)
.fill_null(self.fill_value)
.alias(col)
)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -143,12 +157,13 @@ class NullFiller(BaseProcessor):
return X.select(expressions)
def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用全局统计量填充(训练集学到的统计量)"""
"""使用全局统计量填充(训练集学到的统计量,同时处理 NaN 和 null"""
expressions = []
for col in X.columns:
if col in self.stats_:
fill_val = self.stats_[col]
expr = pl.col(col).fill_null(fill_val).alias(col)
# 先 fill_nan 再 fill_null
expr = pl.col(col).fill_nan(fill_val).fill_null(fill_val).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -156,8 +171,9 @@ class NullFiller(BaseProcessor):
return X.select(expressions)
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用每天截面统计量填充"""
# 确定需要处理的数值
"""使用每天截面统计量填充(同时处理 NaN 和 null"""
# 确定需要处理的列(仅 numeric 类型,排除 boolean
# 注意boolean 类型没有 NaN 概念fill_nan 会报错
target_cols = [
col
for col in self.feature_cols
@@ -180,10 +196,20 @@ class NullFiller(BaseProcessor):
result = X.with_columns(stat_exprs)
# 使用统计量填充缺失值
# 注意如果某天某列全为null统计量也会为null所以需要链式填充
# 同时处理 NaN 和 null
fill_exprs = []
for col in X.columns:
if col in target_cols:
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
# 先用当天统计量填充 NaN 和 null如果统计量也是null则用0填充
expr = (
pl.col(col)
.fill_nan(pl.col(f"{col}_stat"))
.fill_null(pl.col(f"{col}_stat"))
.fill_nan(0) # 如果统计量是 NaN再用 0 填充
.fill_null(0) # 如果统计量是 null再用 0 填充
.alias(col)
)
fill_exprs.append(expr)
else:
fill_exprs.append(pl.col(col))
@@ -230,17 +256,40 @@ class StandardScaler(BaseProcessor):
self
"""
for col in self.feature_cols:
# 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确)
if col in X.columns and X[col].dtype.is_numeric():
col_mean = X[col].mean()
col_std = X[col].std()
if col_mean is not None and col_std is not None:
# 关键修复:检查是否为 None 且不是 NaN
# 注意:使用 try-except 处理类型转换,避免 LSP 类型检查错误
try:
mean_is_valid = (
col_mean is not None
and isinstance(col_mean, (int, float))
and not math.isnan(col_mean)
)
std_is_valid = (
col_std is not None
and isinstance(col_std, (int, float))
and not math.isnan(col_std)
)
except (TypeError, ValueError):
mean_is_valid = False
std_is_valid = False
if mean_is_valid and std_is_valid:
self.mean_[col] = col_mean
self.std_[col] = col_std
else:
# 如果统计量无效使用默认值mean=0, std=1
# 防止 transform 时产生更多 NaN
self.mean_[col] = 0.0
self.std_[col] = 1.0
return self
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""标准化(使用训练集学到的参数)
"""标准化(使用训练集学到的参数,增加 NaN 保护
Args:
X: 待转换数据
@@ -253,7 +302,18 @@ class StandardScaler(BaseProcessor):
if col in self.mean_ and col in self.std_:
# 避免除以0
std_val = self.std_[col] if self.std_[col] != 0 else 1.0
expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
# 关键修复:添加 fill_nan(0) 保险,防止计算产生 NaN
expr = (
((pl.col(col) - self.mean_[col]) / std_val)
.fill_nan(0)
.fill_null(0)
.alias(col)
)
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但未学习到统计量的列
# 统一转换为float并同时处理 NaN 和 null
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -308,13 +368,24 @@ class CrossSectionalStandardScaler(BaseProcessor):
# 构建表达式列表
expressions = []
for col in X.columns:
# 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确)
if col in self.feature_cols and X[col].dtype.is_numeric():
# 截面标准化:每天独立计算均值和标准差
# 避免除以0当std为0时设为1
# 关键修复:先 fill_nan 再 fill_null防止计算产生的 NaN
expr = (
(pl.col(col) - pl.col(col).mean().over(self.date_col))
/ (pl.col(col).std().over(self.date_col) + 1e-10)
).alias(col)
(
(pl.col(col) - pl.col(col).mean().over(self.date_col))
/ (pl.col(col).std().over(self.date_col) + 1e-10)
)
.fill_nan(0)
.fill_null(0)
.alias(col)
)
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但类型不匹配的列转换为float并同时处理 NaN 和 null
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -384,6 +455,7 @@ class Winsorizer(BaseProcessor):
"""
if not self.by_date:
for col in self.feature_cols:
# 仅处理数值类型排除布尔类型quantile 不支持布尔类型)
if col in X.columns and X[col].dtype.is_numeric():
self.bounds_[col] = {
"lower": X[col].quantile(self.lower),
@@ -414,13 +486,19 @@ class Winsorizer(BaseProcessor):
upper = self.bounds_[col]["upper"]
expr = pl.col(col).clip(lower, upper).alias(col)
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但未学习到边界的列如全为NaN、布尔列等
# 统一转换为float并填充0
expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
return X.select(expressions)
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
"""每日独立缩尾"""
# 确定需要处理的数值
# 确定需要处理的列(仅 numeric 类型,排除 boolean
# 注意quantile 操作不支持布尔类型
target_cols = [
col
for col in self.feature_cols
@@ -444,9 +522,11 @@ class Winsorizer(BaseProcessor):
clip_exprs = []
for col in X.columns:
if col in target_cols:
# 先用当天分位数缩尾如果分位数是null该日全为NaN则填充0
clipped = (
pl.col(col)
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
.fill_null(0)
.alias(col)
)
clip_exprs.append(clipped)

View File

@@ -95,6 +95,27 @@ class Trainer:
verbose=self.verbose,
)
# Step 1.5: 数据质量分析
if self.verbose:
print("\n[Step 1.5/7] 数据质量分析...")
try:
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
# 获取特征列名(从训练集)
feature_cols = data["train"].get("feature_cols", [])
label_name = self.task.label_name
analyzer = DataQualityAnalyzer(
feature_cols=feature_cols,
label_col=label_name,
verbose=self.verbose,
)
analyzer.analyze(data)
except Exception as e:
if self.verbose:
print(f" [警告] 数据质量分析失败: {e}")
# Step 2: 处理标签
if self.verbose:
print("\n[Step 2/7] 处理标签...")

View File

@@ -43,6 +43,7 @@ class DataPipeline:
label_processor_configs: Optional[
List[Tuple[Type[BaseProcessor], Dict[str, Any]]]
] = None,
train_skip_days: int = 252,
):
"""初始化数据流水线
@@ -55,6 +56,7 @@ class DataPipeline:
stock_pool_required_columns: 股票池筛选所需的额外列
label_processor_configs: Label 数据处理器配置列表,格式与 processor_configs 相同
例如:[(Winsorizer, {"lower": 0.01, "upper": 0.99})] 用于对 label 进行缩尾处理
train_skip_days: 训练数据跳过前n天用于避免训练初期数据不足的问题默认252天
"""
self.factor_manager = factor_manager
self.processor_configs = processor_configs or []
@@ -64,6 +66,7 @@ class DataPipeline:
self.fitted_processors: List[BaseProcessor] = []
self.label_processor_configs = label_processor_configs or []
self.fitted_label_processors: List[BaseProcessor] = []
self.train_skip_days = train_skip_days
def prepare_data(
self,
@@ -220,6 +223,8 @@ class DataPipeline:
) -> Dict[str, Dict[str, Any]]:
"""划分数据集
对于训练集,会根据 train_skip_days 参数跳过前n个交易日的数据。
Args:
data: 完整数据
date_range: 日期范围字典
@@ -236,6 +241,33 @@ class DataPipeline:
mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
split_df = data.filter(mask)
# 对训练集跳过前n天数据
if split_name == "train" and self.train_skip_days > 0:
original_count = len(split_df)
# 获取唯一的交易日列表并按日期排序
unique_dates = split_df["trade_date"].unique().sort()
if len(unique_dates) > self.train_skip_days:
# 跳过前n个交易日
start_date = unique_dates[self.train_skip_days]
split_df = split_df.filter(pl.col("trade_date") >= start_date)
skipped_count = original_count - len(split_df)
if verbose:
print(
f" {split_name}: {len(split_df)} 条记录"
f" (跳过前{self.train_skip_days}天,减少{skipped_count}条)"
)
else:
if verbose:
print(
f" [警告] 训练数据交易日数量({len(unique_dates)})"
f"少于跳过天数({self.train_skip_days}),未进行过滤"
)
if verbose:
print(f" {split_name}: {len(split_df)} 条记录")
else:
if verbose:
print(f" {split_name}: {len(split_df)} 条记录")
result[split_name] = {
"X": split_df.select(feature_cols),
"y": split_df[label_name],
@@ -243,9 +275,6 @@ class DataPipeline:
"feature_cols": feature_cols,
}
if verbose:
print(f" {split_name}: {len(split_df)} 条记录")
return result
def _preprocess(
@@ -345,6 +374,30 @@ class DataPipeline:
split_data[split_name]["X"] = split_df.select(feature_cols)
split_data[split_name]["y"] = split_df[label_name]
# 删除标签为 NaN 的行
for split_name in ["train", "val", "test"]:
if split_name in split_data:
y_series = split_data[split_name]["y"]
y_nan_count = y_series.null_count()
if y_nan_count > 0:
if verbose:
print(f" 删除 {split_name} 集中 {y_nan_count} 个标签为NaN的行")
# 创建有效标签的mask
valid_mask = y_series.is_not_null()
# 过滤所有相关数据
split_data[split_name]["raw_data"] = split_data[split_name][
"raw_data"
].filter(valid_mask)
split_data[split_name]["X"] = split_data[split_name]["X"].filter(
valid_mask
)
split_data[split_name]["y"] = split_data[split_name]["y"].filter(
valid_mask
)
return split_data
def get_fitted_processors(self) -> List[BaseProcessor]:

View File

@@ -6,9 +6,11 @@
from src.training.tasks.base import BaseTask
from src.training.tasks.regression_task import RegressionTask
from src.training.tasks.rank_task import RankTask
from src.training.tasks.tabm_regression_task import TabMRegressionTask
__all__ = [
"BaseTask",
"RegressionTask",
"RankTask",
"TabMRegressionTask",
]

View File

@@ -0,0 +1,165 @@
"""TabM回归任务
TabM模型的回归训练任务实现。
"""
from typing import Dict, Any, Optional
from pathlib import Path
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
from src.training.tasks.base import BaseTask
from src.training.components.models.tabm_model import TabMModel
# Type alias for model type
TabMModelType = TabMModel
class TabMRegressionTask(BaseTask):
    """Regression task driven by a TabM model.

    Wraps :class:`TabMModel` training, supporting the model's built-in
    ensembling (ensemble_size), early stopping, and a helper that plots the
    recorded loss curves.

    Attributes:
        model_params: hyper-parameter dict handed to TabMModel.
        label_name: name of the target column.
        model: fitted TabMModel, or None before ``fit``.
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
    ):
        """Store the configuration; the model itself is built lazily in ``fit``.

        Args:
            model_params: TabM hyper-parameters (n_blocks, d_block, dropout,
                ensemble_size, batch_size, learning_rate, weight_decay,
                epochs, early_stopping_patience).
            label_name: target column name.
        """
        super().__init__(model_params, label_name)
        self.model_params = model_params
        self.label_name = label_name
        self.model: Optional[TabMModelType] = None

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Return the data unchanged.

        Regression works on the raw continuous label, so no conversion is
        needed.

        Args:
            data: split-name -> split-dict mapping.

        Returns:
            The same mapping, untouched.
        """
        return data

    def fit(
        self,
        train_data: Dict[str, Any],
        val_data: Dict[str, Any],
    ) -> None:
        """Train a fresh TabM model on the given splits.

        Args:
            train_data: training split {"X": DataFrame, "y": Series}.
            val_data: validation split {"X": DataFrame, "y": Series}.
        """
        print("\n[TabMRegressionTask] 开始训练...")
        self.model = TabMModel(self.model_params)
        self.model.fit(
            X=train_data["X"],
            y=train_data["y"],
            eval_set=(val_data["X"], val_data["y"]),
        )
        print("[TabMRegressionTask] 训练完成")

    def predict(self, test_data: Dict[str, Any]) -> np.ndarray:
        """Generate predictions for the test split.

        Args:
            test_data: split dict containing "X" (feature DataFrame).

        Returns:
            Prediction array.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("模型未训练请先调用fit()")
        return self.model.predict(test_data["X"])

    def get_model(self) -> Any:
        """Return the trained TabMModel instance (or None before ``fit``)."""
        return self.model

    def plot_training_metrics(self, output_path: Optional[str] = None) -> None:
        """Plot the train/validation loss curves.

        Args:
            output_path: file to save the figure to; None shows it instead.
        """
        no_history = (
            self.model is None
            or not self.model.training_history_["train_loss"]
        )
        if no_history:
            print("[TabMRegressionTask] 无训练历史可绘制")
            return
        hist = self.model.training_history_
        _fig, axis = plt.subplots(figsize=(10, 6))
        xs = range(1, len(hist["train_loss"]) + 1)
        axis.plot(xs, hist["train_loss"], "b-", label="Train Loss", linewidth=2)
        if hist["val_loss"]:
            axis.plot(xs, hist["val_loss"], "r-", label="Val Loss", linewidth=2)
        axis.set_xlabel("Epoch", fontsize=12)
        axis.set_ylabel("Loss (MSE)", fontsize=12)
        axis.set_title("TabM Training History", fontsize=14)
        axis.legend(fontsize=10)
        axis.grid(True, alpha=0.3)
        plt.tight_layout()
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(output_path, dpi=150, bbox_inches="tight")
            print(f"[TabMRegressionTask] 训练曲线保存到: {output_path}")
        else:
            plt.show()
        plt.close()