Files
ProStock/tests/test_tabm_nan_debug.py
liaozhaorun 36a3ccbcc8 feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
2026-03-31 23:11:21 +08:00

520 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""TabM NaN 问题诊断测试 + train_skip_days 功能验证
诊断 loss 为 nan 的根因:
1. 标签中是否有 NaN 或极端值
2. 标准化后是否有 NaN
3. 是否有方差为0或接近0的列
4. train_skip_days 功能是否正常工作
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
TabMRegressionTask,
NullFiller,
Winsorizer,
StandardScaler,
)
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
)
# TabM model hyperparameters (simplified for fast testing).
MODEL_PARAMS = {
    "n_blocks": 2,
    "d_block": 128,
    "dropout": 0.1,
    "ensemble_size": 4,  # reduced ensemble size to keep the test quick
    "batch_size": 256,
    "learning_rate": 1e-4,  # lowered learning rate
    "weight_decay": 1e-5,
    "epochs": 3,
    "early_stopping_patience": 5,
}
# Date ranges used by the tests (YYYYMMDD strings, inclusive).
TEST_DATE_RANGE = {
    "train": ("20200101", "20200630"),
    "val": ("20200701", "20200731"),
    "test": ("20200801", "20200831"),
}
# Factors excluded for the small-scope test run.
EXCLUDED_FACTORS = [
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
]
class TestTabMNanDebug:
    """TabM NaN diagnostic test class (uses DataPipeline).

    Diagnoses the root causes of a ``nan`` training loss:
    labels containing NaN/extreme values, NaN after standardization,
    zero-variance columns, and the ``train_skip_days`` behaviour.
    """

    @staticmethod
    def _build_pipeline(engine, factor_manager, skip_days):
        """Build the standard test DataPipeline with the given train_skip_days."""
        return DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Provide a FactorEngine and a configured FactorManager (class-scoped)."""
        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    @pytest.fixture(scope="class")
    def pipeline_data(self, engine_and_factor_manager):
        """Prepare data via DataPipeline and drop all-NaN feature columns.

        Fix over the previous version: columns that are entirely NaN are
        first collected across *all* splits and then removed from *every*
        split, so train/val/test always share one identical feature set.
        (Previously a column dropped only for val/test survived in train,
        and per-split ``feature_cols`` diverged.)
        """
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        # Baseline pipeline: train_skip_days=0 (no skipping) for these tests.
        pipeline = self._build_pipeline(engine, factor_manager, 0)
        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        feature_cols = factor_manager.register_to_engine(engine, verbose=False)

        print("\n[DataPipeline] 检查并过滤全为NaN的特征列...")
        # Pass 1: per split, drop rows whose label is NaN, and collect the
        # union of columns that are entirely NaN in at least one split.
        cols_to_drop = set()
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            y_series = data[split_name]["y"]
            raw_data = data[split_name]["raw_data"]
            y_nan_count = y_series.null_count()
            if y_nan_count > 0:
                print(f" {split_name}: 发现 {y_nan_count} 个标签为NaN的行将被删除")
                # Filter X, y and raw_data with the same row mask to keep
                # them aligned.
                valid_mask = y_series.is_not_null()
                X_df = X_df.filter(valid_mask)
                y_series = y_series.filter(valid_mask)
                raw_data = raw_data.filter(valid_mask)
                data[split_name]["X"] = X_df
                data[split_name]["y"] = y_series
                data[split_name]["raw_data"] = raw_data
            total_rows = len(X_df)
            # Columns where every value is null (factor computation failed;
            # not worth repairing).
            all_nan_cols = [
                col for col in X_df.columns if X_df[col].null_count() == total_rows
            ]
            if all_nan_cols:
                print(
                    f" {split_name}: 发现 {len(all_nan_cols)} 个全为NaN的列将被删除"
                )
                print(
                    f" 列名: {all_nan_cols[:5]}{'...' if len(all_nan_cols) > 5 else ''}"
                )
                cols_to_drop.update(all_nan_cols)

        # Pass 2: apply one consistent feature list to every split.
        feature_cols = [c for c in feature_cols if c not in cols_to_drop]
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            raw_data = data[split_name]["raw_data"]
            if cols_to_drop:
                X_df = X_df.select(feature_cols)
                # raw_data keeps its non-feature columns (e.g. trade_date,
                # ts_code) and only loses the all-NaN feature columns.
                raw_cols_to_keep = [
                    c for c in raw_data.columns if c not in cols_to_drop
                ]
                raw_data = raw_data.select(raw_cols_to_keep)
                data[split_name]["X"] = X_df
                data[split_name]["raw_data"] = raw_data
            data[split_name]["feature_cols"] = feature_cols
            # Sanity check: no NaN may remain in the feature matrix.
            X_np = X_df.to_numpy()
            nan_count = np.isnan(X_np).sum()
            assert nan_count == 0, f"{split_name} 中仍有 {nan_count} 个NaN"
            print(f" 过滤后特征数: {len(feature_cols)}")
        print(" [通过] 所有特征列均无NaN")
        return {
            "pipeline": pipeline,
            "data": data,
            "feature_cols": feature_cols,
            "engine": engine,
        }

    def test_pipeline_data_structure(self, pipeline_data):
        """Diagnostic 0: structure checks on what DataPipeline returns."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]
        print("\n[诊断0] DataPipeline 数据结构检查:")
        assert "train" in data, "缺少 train 数据"
        assert "val" in data, "缺少 val 数据"
        assert "test" in data, "缺少 test 数据"
        for split_name in ["train", "val", "test"]:
            split_data = data[split_name]
            assert "X" in split_data, f"{split_name} 缺少 X"
            assert "y" in split_data, f"{split_name} 缺少 y"
            assert "raw_data" in split_data, f"{split_name} 缺少 raw_data"
            assert "feature_cols" in split_data, f"{split_name} 缺少 feature_cols"
            X = split_data["X"]
            y = split_data["y"]
            print(f" {split_name}:")
            print(f" X 形状: {X.shape}")
            print(f" y 形状: {len(y)}")
            print(f" 特征数: {len(split_data['feature_cols'])}")
            # Holds for every split because the fixture drops the union of
            # all-NaN columns from all splits consistently.
            assert X.shape[0] == len(y), f"{split_name} X 和 y 行数不匹配"
            assert X.shape[1] == len(feature_cols), f"{split_name} 特征数不匹配"
        print(" [通过] 数据结构正确")

    def test_label_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 1: label quality for every split after the pipeline."""
        data = pipeline_data["data"]
        for split_name in ["train", "val", "test"]:
            y = data[split_name]["y"]
            y_np = y.to_numpy()
            print(f"\n[诊断1-{split_name}] 标签数据质量:")
            print(f" 总数: {len(y)}")
            print(f" NaN数量: {y.null_count()}")
            print(f" 均值: {y.mean():.6f}")
            print(f" 标准差: {y.std():.6f}")
            print(f" 最小值: {y.min():.6f}")
            print(f" 最大值: {y.max():.6f}")
            inf_count = np.isinf(y_np).sum()
            nan_count = np.isnan(y_np).sum()
            print(f" inf数量: {inf_count}")
            print(f" nan数量: {nan_count}")
            assert inf_count == 0, f"{split_name} 标签含 inf: {inf_count}"
            assert nan_count == 0, f"{split_name} 标签含 nan: {nan_count}"

    def test_processed_data_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 2: processed features must be free of NaN/Inf."""
        data = pipeline_data["data"]
        for split_name in ["train", "val", "test"]:
            X = data[split_name]["X"]
            y = data[split_name]["y"]
            # float32 conversion mirrors what the training code feeds TabM.
            X_np = X.to_numpy().astype(np.float32)
            y_np = y.to_numpy().astype(np.float32)
            print(f"\n[诊断2-{split_name}] 处理后数据质量:")
            print(f" X 形状: {X_np.shape}, dtype: {X_np.dtype}")
            print(f" y 形状: {y_np.shape}, dtype: {y_np.dtype}")
            print(f" X中NaN: {np.isnan(X_np).sum()}")
            print(f" X中Inf: {np.isinf(X_np).sum()}")
            print(f" y中NaN: {np.isnan(y_np).sum()}")
            print(f" y中Inf: {np.isinf(y_np).sum()}")
            assert np.isnan(X_np).sum() == 0, f"{split_name} X含NaN"
            assert np.isnan(y_np).sum() == 0, f"{split_name} y含NaN"
            assert np.isinf(X_np).sum() == 0, f"{split_name} X含Inf"
            assert np.isinf(y_np).sum() == 0, f"{split_name} y含Inf"

    def test_train_skip_days_functionality(self, engine_and_factor_manager):
        """Diagnostic 3: verify ``train_skip_days`` skips leading trading days."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        print("\n[诊断3] train_skip_days 功能验证:")
        # Pipeline skipping the first 50 trading days of the train window.
        skip_days = 50
        pipeline_with_skip = self._build_pipeline(engine, factor_manager, skip_days)
        data_with_skip = pipeline_with_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        # Reference pipeline without skipping, for comparison.
        pipeline_no_skip = self._build_pipeline(engine, factor_manager, 0)
        data_no_skip = pipeline_no_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )
        train_with_skip = data_with_skip["train"]["raw_data"]
        train_no_skip = data_no_skip["train"]["raw_data"]
        print(f"\n 对比结果:")
        print(f" 不跳过时训练数据: {len(train_no_skip)}")
        print(f" 跳过{skip_days}天后: {len(train_with_skip)}")
        print(f" 减少: {len(train_no_skip) - len(train_with_skip)}")
        # val and test must be unaffected by train_skip_days.
        assert len(data_with_skip["val"]["raw_data"]) == len(
            data_no_skip["val"]["raw_data"]
        ), "val 数据不应该受 train_skip_days 影响"
        assert len(data_with_skip["test"]["raw_data"]) == len(
            data_no_skip["test"]["raw_data"]
        ), "test 数据不应该受 train_skip_days 影响"
        # Check the skipped run really starts at a later trading day.
        if len(train_no_skip) > 0:
            dates_no_skip = sorted(train_no_skip["trade_date"].unique())
            dates_with_skip = sorted(train_with_skip["trade_date"].unique())
            print(f"\n 日期对比:")
            print(
                f" 不跳过 - 最早日期: {dates_no_skip[0]}, 共 {len(dates_no_skip)} 个交易日"
            )
            print(
                f" 跳过 - 最早日期: {dates_with_skip[0]}, 共 {len(dates_with_skip)} 个交易日"
            )
            if len(dates_no_skip) > skip_days:
                # The first kept date must be exactly the (skip_days+1)-th
                # trading day of the unskipped run.
                expected_start_date = dates_no_skip[skip_days]
                assert dates_with_skip[0] == expected_start_date, (
                    f"预期从 {expected_start_date} 开始,实际从 {dates_with_skip[0]} 开始"
                )
                print(f" [通过] 正确跳过前 {skip_days} 个交易日")

    def test_training_with_pipeline(self, pipeline_data):
        """Diagnostic 4: a few TabM training steps on pipeline data stay NaN-free."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from tabm import TabM

        data = pipeline_data["data"]
        X_train_df = data["train"]["X"]
        y_train_series = data["train"]["y"]
        # Defensive re-check: the fixture already removes label-NaN rows.
        valid_mask = y_train_series.is_not_null()
        y_nan_count = y_train_series.null_count()
        if y_nan_count > 0:
            print(f" 发现 {y_nan_count} 个标签为NaN的行将被删除")
            X_train_df = X_train_df.filter(valid_mask)
            y_train_series = y_train_series.filter(valid_mask)
        X_train = X_train_df.to_numpy().astype(np.float32)
        y_train = y_train_series.to_numpy().astype(np.float32)
        # Keep only the first 1000 rows to speed up the test.
        X_train = X_train[:1000]
        y_train = y_train[:1000]
        print(f"\n[诊断4] DataPipeline 数据训练测试:")
        print(f" X 形状: {X_train.shape}")
        print(f" y 形状: {y_train.shape}")
        print(f" X中NaN: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}")
        print(f" y中NaN: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}")
        # Single source of truth for the ensemble size (used for both the
        # model's k and the target broadcast below).
        n_features = X_train.shape[1]
        ensemble_k = 4
        model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=2,
            d_block=128,
            dropout=0.1,
            k=ensemble_k,
        )
        # Train a few batches and assert the loss never becomes NaN.
        dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        loader = DataLoader(dataset, batch_size=256, shuffle=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        criterion = torch.nn.MSELoss()
        model.train()
        losses = []
        for batch_idx, (bx, by) in enumerate(loader):
            optimizer.zero_grad()
            outputs = model(bx)  # [B, E, 1]
            outputs_squeezed = outputs.squeeze(-1)  # [B, E]
            # Broadcast the target across the ensemble dimension.
            by_expanded = by.unsqueeze(-1).expand(-1, ensemble_k)  # [B, E]
            loss = criterion(outputs_squeezed, by_expanded)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            print(f" 批次 {batch_idx + 1} loss: {loss.item():.6f}")
            assert not torch.isnan(loss), f"批次 {batch_idx + 1} loss 为 nan!"
            if batch_idx >= 2:  # only exercise the first 3 batches
                break
        print(f" [通过] 所有批次 loss 正常,无 NaN")
class TestTrainSkipDaysEdgeCases:
    """Edge-case coverage for the ``train_skip_days`` option."""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Provide one FactorEngine/FactorManager pair for the whole class."""
        manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {"engine": FactorEngine(), "factor_manager": manager}

    def test_skip_more_than_available_days(self, engine_and_factor_manager):
        """Skipping more days than the window contains must not crash."""
        engine = engine_and_factor_manager["engine"]
        print("\n[边界测试] 跳过天数超过可用天数:")
        # 1000 is far beyond the number of trading days in the test range.
        pipeline = DataPipeline(
            factor_manager=engine_and_factor_manager["factor_manager"],
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=1000,
        )
        prepared = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        # Expected behaviour: the pipeline warns and keeps all data rather
        # than raising when the skip count exceeds the trading days.
        train_data = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_data)}")
        print(f" [通过] 程序未崩溃")

    def test_skip_zero_days(self, engine_and_factor_manager):
        """A skip count of zero behaves exactly like no skipping."""
        engine = engine_and_factor_manager["engine"]
        print("\n[边界测试] 跳过0天:")
        pipeline = DataPipeline(
            factor_manager=engine_and_factor_manager["factor_manager"],
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=0,
        )
        prepared = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )
        train_data = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_data)}")
        print(f" [通过] skip=0 时数据正常")
if __name__ == "__main__":
    # Allow running this module directly: verbose output, capture disabled.
    cli_args = [__file__, "-v", "-s"]
    pytest.main(cli_args)