feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
519
tests/test_tabm_nan_debug.py
Normal file
519
tests/test_tabm_nan_debug.py
Normal file
@@ -0,0 +1,519 @@
|
||||
"""TabM NaN 问题诊断测试 + train_skip_days 功能验证
|
||||
|
||||
诊断 loss 为 nan 的根因:
|
||||
1. 标签中是否有 NaN 或极端值
|
||||
2. 标准化后是否有 NaN
|
||||
3. 是否有方差为0或接近0的列
|
||||
4. train_skip_days 功能是否正常工作
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
FactorManager,
|
||||
DataPipeline,
|
||||
TabMRegressionTask,
|
||||
NullFiller,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
)
|
||||
from src.training.components.filters import STFilter
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
LABEL_NAME,
|
||||
LABEL_FACTOR,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
)
|
||||
|
||||
|
||||
# TabM model hyperparameters (reduced for fast tests).
MODEL_PARAMS = {
    "n_blocks": 2,
    "d_block": 128,
    "dropout": 0.1,
    "ensemble_size": 4,  # smaller ensemble for speed
    "batch_size": 256,
    "learning_rate": 1e-4,  # lowered learning rate
    "weight_decay": 1e-5,
    "epochs": 3,
    "early_stopping_patience": 5,
}

# Date windows used by the tests, as (start, end) "YYYYMMDD" strings.
TEST_DATE_RANGE = {
    "train": ("20200101", "20200630"),
    "val": ("20200701", "20200731"),
    "test": ("20200801", "20200831"),
}

# Factors excluded for the small-scale test run: GTJA_alpha001 .. GTJA_alpha010.
EXCLUDED_FACTORS = [f"GTJA_alpha{i:03d}" for i in range(1, 11)]
|
||||
|
||||
|
||||
class TestTabMNanDebug:
    """TabM NaN root-cause diagnosis tests, driven through DataPipeline.

    Diagnoses why training loss becomes NaN:
    1. NaN / extreme values in labels
    2. NaN surviving standardization
    3. zero- or near-zero-variance columns
    4. whether train_skip_days behaves as intended
    """

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Build the FactorEngine / FactorManager pair shared by the class."""
        engine = FactorEngine()

        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )

        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    @staticmethod
    def _build_full_pipeline(engine, factor_manager, skip_days):
        """Build the standard three-processor DataPipeline used by these tests.

        Factored out so the fixture and the skip-days comparison test cannot
        drift apart in their pipeline configuration.
        """
        return DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )

    @pytest.fixture(scope="class")
    def pipeline_data(self, engine_and_factor_manager):
        """Prepare data with DataPipeline; drop NaN labels and all-NaN columns.

        Bug fix: all-NaN feature columns are now collected across *all* splits
        first, and the union is removed from *every* split.  Previously each
        split was filtered sequentially against a shrinking ``feature_cols``
        list, so a column that was all-NaN only in a later split stayed in the
        earlier splits' X — the per-split feature matrices could end up with
        different widths than the shared ``feature_cols`` list.
        """
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        # Baseline pipeline: train_skip_days=0 (no skipping).
        pipeline = self._build_full_pipeline(engine, factor_manager, 0)

        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        feature_cols = factor_manager.register_to_engine(engine, verbose=False)

        print("\n[DataPipeline] 检查并过滤全为NaN的特征列...")

        # Pass 1: drop rows whose label is NaN, and collect the union of
        # all-NaN feature columns over every split.
        all_nan_union = set()
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            y_series = data[split_name]["y"]
            raw_data = data[split_name]["raw_data"]

            y_nan_count = y_series.null_count()
            if y_nan_count > 0:
                print(f"  {split_name}: 发现 {y_nan_count} 个标签为NaN的行,将被删除")
                valid_mask = y_series.is_not_null()
                X_df = X_df.filter(valid_mask)
                y_series = y_series.filter(valid_mask)
                raw_data = raw_data.filter(valid_mask)
                data[split_name]["X"] = X_df
                data[split_name]["y"] = y_series
                data[split_name]["raw_data"] = raw_data

            # Columns where every value is null: the factor computation failed
            # for the whole split; there is nothing meaningful to repair.
            total_rows = len(X_df)
            split_all_nan = [
                col for col in X_df.columns if X_df[col].null_count() == total_rows
            ]
            if split_all_nan:
                print(
                    f"  {split_name}: 发现 {len(split_all_nan)} 个全为NaN的列,将被删除"
                )
                print(
                    f"    列名: {split_all_nan[:5]}{'...' if len(split_all_nan) > 5 else ''}"
                )
            all_nan_union.update(split_all_nan)

        # Keep only surviving feature columns, preserving the original order.
        feature_cols = [c for c in feature_cols if c not in all_nan_union]

        # Pass 2: apply the shared surviving column set to every split so all
        # splits expose identical feature matrices.
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"].select(feature_cols)
            raw_data = data[split_name]["raw_data"]
            # raw_data keeps its non-feature columns (trade_date, ts_code, ...).
            raw_data = raw_data.select(
                [c for c in raw_data.columns if c not in all_nan_union]
            )

            data[split_name]["X"] = X_df
            data[split_name]["raw_data"] = raw_data
            data[split_name]["feature_cols"] = feature_cols

            # Sanity check: no NaN may survive the processing chain.
            X_np = X_df.to_numpy()
            nan_count = np.isnan(X_np).sum()
            assert nan_count == 0, f"{split_name} 中仍有 {nan_count} 个NaN"

            print(f"  过滤后特征数: {len(feature_cols)}")
            print("  [通过] 所有特征列均无NaN")

        return {
            "pipeline": pipeline,
            "data": data,
            "feature_cols": feature_cols,
            "engine": engine,
        }

    def test_pipeline_data_structure(self, pipeline_data):
        """Diagnosis 0: structural sanity of the DataPipeline output."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]

        print("\n[诊断0] DataPipeline 数据结构检查:")

        # All three splits must be present.
        assert "train" in data, "缺少 train 数据"
        assert "val" in data, "缺少 val 数据"
        assert "test" in data, "缺少 test 数据"

        for split_name in ["train", "val", "test"]:
            split_data = data[split_name]
            assert "X" in split_data, f"{split_name} 缺少 X"
            assert "y" in split_data, f"{split_name} 缺少 y"
            assert "raw_data" in split_data, f"{split_name} 缺少 raw_data"
            assert "feature_cols" in split_data, f"{split_name} 缺少 feature_cols"

            X = split_data["X"]
            y = split_data["y"]

            print(f"  {split_name}:")
            print(f"    X 形状: {X.shape}")
            print(f"    y 形状: {len(y)}")
            print(f"    特征数: {len(split_data['feature_cols'])}")

            # Features, labels and the shared column list must agree in size.
            assert X.shape[0] == len(y), f"{split_name} X 和 y 行数不匹配"
            assert X.shape[1] == len(feature_cols), f"{split_name} 特征数不匹配"

        print("  [通过] 数据结构正确")

    def test_label_quality_with_pipeline(self, pipeline_data):
        """Diagnosis 1: labels carry no NaN / inf after the pipeline."""
        data = pipeline_data["data"]

        for split_name in ["train", "val", "test"]:
            y = data[split_name]["y"]
            y_np = y.to_numpy()

            print(f"\n[诊断1-{split_name}] 标签数据质量:")
            print(f"  总数: {len(y)}")
            print(f"  NaN数量: {y.null_count()}")
            print(f"  均值: {y.mean():.6f}")
            print(f"  标准差: {y.std():.6f}")
            print(f"  最小值: {y.min():.6f}")
            print(f"  最大值: {y.max():.6f}")

            inf_count = np.isinf(y_np).sum()
            nan_count = np.isnan(y_np).sum()
            print(f"  inf数量: {inf_count}")
            print(f"  nan数量: {nan_count}")

            assert inf_count == 0, f"{split_name} 标签含 inf: {inf_count}"
            assert nan_count == 0, f"{split_name} 标签含 nan: {nan_count}"

    def test_processed_data_quality_with_pipeline(self, pipeline_data):
        """Diagnosis 2: processed feature matrices carry no NaN / inf."""
        data = pipeline_data["data"]

        for split_name in ["train", "val", "test"]:
            X = data[split_name]["X"]
            y = data[split_name]["y"]

            # float32 conversion mirrors what the training loop feeds to torch.
            X_np = X.to_numpy().astype(np.float32)
            y_np = y.to_numpy().astype(np.float32)

            print(f"\n[诊断2-{split_name}] 处理后数据质量:")
            print(f"  X 形状: {X_np.shape}, dtype: {X_np.dtype}")
            print(f"  y 形状: {y_np.shape}, dtype: {y_np.dtype}")
            print(f"  X中NaN: {np.isnan(X_np).sum()}")
            print(f"  X中Inf: {np.isinf(X_np).sum()}")
            print(f"  y中NaN: {np.isnan(y_np).sum()}")
            print(f"  y中Inf: {np.isinf(y_np).sum()}")

            assert np.isnan(X_np).sum() == 0, f"{split_name} X含NaN"
            assert np.isnan(y_np).sum() == 0, f"{split_name} y含NaN"
            assert np.isinf(X_np).sum() == 0, f"{split_name} X含Inf"
            assert np.isinf(y_np).sum() == 0, f"{split_name} y含Inf"

    def test_train_skip_days_functionality(self, engine_and_factor_manager):
        """Diagnosis 3: train_skip_days drops exactly the first N trading days
        of the training split and leaves val / test untouched."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        print("\n[诊断3] train_skip_days 功能验证:")

        skip_days = 50

        # Pipeline that skips the first `skip_days` training days.
        data_with_skip = self._build_full_pipeline(
            engine, factor_manager, skip_days
        ).prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        # Reference pipeline with no skipping, for comparison.
        data_no_skip = self._build_full_pipeline(
            engine, factor_manager, 0
        ).prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )

        train_with_skip = data_with_skip["train"]["raw_data"]
        train_no_skip = data_no_skip["train"]["raw_data"]

        print("\n  对比结果:")
        print(f"    不跳过时训练数据: {len(train_no_skip)} 条")
        print(f"    跳过{skip_days}天后: {len(train_with_skip)} 条")
        print(f"    减少: {len(train_no_skip) - len(train_with_skip)} 条")

        # val / test must be unaffected by the training-only skip.
        assert len(data_with_skip["val"]["raw_data"]) == len(
            data_no_skip["val"]["raw_data"]
        ), "val 数据不应该受 train_skip_days 影响"
        assert len(data_with_skip["test"]["raw_data"]) == len(
            data_no_skip["test"]["raw_data"]
        ), "test 数据不应该受 train_skip_days 影响"

        # Verify the skipped run starts exactly `skip_days` trading days later.
        if len(train_no_skip) > 0:
            dates_no_skip = sorted(train_no_skip["trade_date"].unique())
            dates_with_skip = sorted(train_with_skip["trade_date"].unique())

            print("\n  日期对比:")
            print(
                f"    不跳过 - 最早日期: {dates_no_skip[0]}, 共 {len(dates_no_skip)} 个交易日"
            )
            print(
                f"    跳过 - 最早日期: {dates_with_skip[0]}, 共 {len(dates_with_skip)} 个交易日"
            )

            if len(dates_no_skip) > skip_days:
                expected_start_date = dates_no_skip[skip_days]
                assert dates_with_skip[0] == expected_start_date, (
                    f"预期从 {expected_start_date} 开始,实际从 {dates_with_skip[0]} 开始"
                )
                print(f"  [通过] 正确跳过前 {skip_days} 个交易日")

    def test_training_with_pipeline(self, pipeline_data):
        """Diagnosis 4: a few TabM training steps on pipeline data stay NaN-free."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from tabm import TabM

        data = pipeline_data["data"]

        X_train_df = data["train"]["X"]
        y_train_series = data["train"]["y"]

        # Defensive re-check: the fixture already removed NaN labels, but this
        # test must not silently train on them if the fixture ever changes.
        y_nan_count = y_train_series.null_count()
        if y_nan_count > 0:
            print(f"  发现 {y_nan_count} 个标签为NaN的行,将被删除")
            valid_mask = y_train_series.is_not_null()
            X_train_df = X_train_df.filter(valid_mask)
            y_train_series = y_train_series.filter(valid_mask)

        X_train = X_train_df.to_numpy().astype(np.float32)
        y_train = y_train_series.to_numpy().astype(np.float32)

        # Cap the sample count to keep the test fast.
        X_train = X_train[:1000]
        y_train = y_train[:1000]

        print(f"\n[诊断4] DataPipeline 数据训练测试:")
        print(f"  X 形状: {X_train.shape}")
        print(f"  y 形状: {y_train.shape}")
        print(f"  X中NaN: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}")
        print(f"  y中NaN: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}")

        # Build the small TabM from the shared MODEL_PARAMS instead of
        # repeating the magic numbers, so the test stays in sync with the
        # configuration at the top of the file.
        n_features = X_train.shape[1]
        model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=MODEL_PARAMS["n_blocks"],
            d_block=MODEL_PARAMS["d_block"],
            dropout=MODEL_PARAMS["dropout"],
            k=MODEL_PARAMS["ensemble_size"],
        )

        # Train a few batches only.
        dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        loader = DataLoader(
            dataset, batch_size=MODEL_PARAMS["batch_size"], shuffle=True
        )

        optimizer = torch.optim.AdamW(
            model.parameters(), lr=MODEL_PARAMS["learning_rate"]
        )
        criterion = torch.nn.MSELoss()

        model.train()
        losses = []

        for batch_idx, (bx, by) in enumerate(loader):
            optimizer.zero_grad()
            outputs = model(bx)  # [B, E, 1]
            outputs_squeezed = outputs.squeeze(-1)  # [B, E]
            # Bug fix: the target was previously expanded with a hard-coded
            # ensemble size of 4; derive it from the model output instead so
            # changing MODEL_PARAMS["ensemble_size"] cannot break the shapes.
            by_expanded = by.unsqueeze(-1).expand(-1, outputs_squeezed.shape[1])
            loss = criterion(outputs_squeezed, by_expanded)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            print(f"  批次 {batch_idx + 1} loss: {loss.item():.6f}")

            assert not torch.isnan(loss), f"批次 {batch_idx + 1} loss 为 nan!"

            if batch_idx >= 2:  # only the first 3 batches
                break

        print(f"  [通过] 所有批次 loss 正常,无 NaN")
||||
|
||||
class TestTrainSkipDaysEdgeCases:
    """Edge-case tests for the train_skip_days option."""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Shared FactorEngine / FactorManager pair for this class."""
        engine = FactorEngine()

        manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )

        return {"engine": engine, "factor_manager": manager}

    def _run_pipeline(self, env, skip_days, verbose):
        """Build a minimal single-processor pipeline and prepare the data."""
        engine = env["engine"]
        pipeline = DataPipeline(
            factor_manager=env["factor_manager"],
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )
        return pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=verbose,
        )

    def test_skip_more_than_available_days(self, engine_and_factor_manager):
        """Skipping more days than the window holds must not crash — the
        implementation is expected to warn and keep the data."""
        print("\n[边界测试] 跳过天数超过可用天数:")

        # 1000 far exceeds the trading days in the test window.
        result = self._run_pipeline(engine_and_factor_manager, 1000, True)

        rows = result["train"]["raw_data"]
        print(f"  训练数据量: {len(rows)} 条")
        print(f"  [通过] 程序未崩溃")

    def test_skip_zero_days(self, engine_and_factor_manager):
        """skip=0 must behave exactly like no skipping at all."""
        print("\n[边界测试] 跳过0天:")

        result = self._run_pipeline(engine_and_factor_manager, 0, False)

        rows = result["train"]["raw_data"]
        print(f"  训练数据量: {len(rows)} 条")
        print(f"  [通过] skip=0 时数据正常")
|
||||
|
||||
if __name__ == "__main__":
    # Allow running the diagnostics directly: `python test_tabm_nan_debug.py`.
    pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user