Files
ProStock/tests/test_tabm_nan_debug.py

520 lines
19 KiB
Python
Raw Normal View History

"""TabM NaN diagnosis tests + `train_skip_days` feature verification.

Diagnoses the root causes of NaN losses:
1. Whether labels contain NaN or extreme values
2. Whether NaN remains after standardization
3. Whether any column has zero or near-zero variance
4. Whether the train_skip_days feature works correctly
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
TabMRegressionTask,
NullFiller,
Winsorizer,
StandardScaler,
)
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
)
# TabM model parameters (scaled down for fast testing)
MODEL_PARAMS = {
    "n_blocks": 2,
    "d_block": 128,
    "dropout": 0.1,
    "ensemble_size": 4,  # reduced ensemble size for speed
    "batch_size": 256,
    "learning_rate": 1e-4,  # lowered learning rate
    "weight_decay": 1e-5,
    "epochs": 3,
    "early_stopping_patience": 5,
}
# Date ranges used by the tests
TEST_DATE_RANGE = {
    "train": ("20200101", "20200630"),
    "val": ("20200701", "20200731"),
    "test": ("20200801", "20200831"),
}
# Factors excluded for the small-scale test run
EXCLUDED_FACTORS = [
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
]
class TestTabMNanDebug:
    """TabM NaN diagnosis tests driven through DataPipeline."""

    @staticmethod
    def _build_full_pipeline(engine, factor_manager, skip_days):
        """Build the standard three-processor DataPipeline used by these tests."""
        return DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Build the FactorEngine and FactorManager shared by this class."""
        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    @pytest.fixture(scope="class")
    def pipeline_data(self, engine_and_factor_manager):
        """Prepare data via DataPipeline, drop NaN-label rows and all-NaN columns.

        Fix: column filtering now runs in two passes.  Pass 1 removes
        NaN-label rows per split and collects the *union* of all-NaN columns
        across all splits; pass 2 drops that union from every split.  The
        previous single-pass version shrank ``feature_cols`` while iterating,
        so a column removed while processing val/test stayed in the already
        processed train X, breaking the downstream shape assertion
        ``X.shape[1] == len(feature_cols)``.  It also only refreshed
        ``data[split]["feature_cols"]`` for splits that happened to have
        all-NaN columns; now every split gets the final list.
        """
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        # Baseline pipeline: train_skip_days=0 (skipping is tested separately).
        pipeline = self._build_full_pipeline(engine, factor_manager, 0)
        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        feature_cols = factor_manager.register_to_engine(engine, verbose=False)
        print("\n[DataPipeline] 检查并过滤全为NaN的特征列...")
        # Pass 1: drop NaN-label rows; collect the union of all-NaN columns.
        all_nan_union = set()
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            y_series = data[split_name]["y"]
            raw_data = data[split_name]["raw_data"]
            y_nan_count = y_series.null_count()
            if y_nan_count > 0:
                print(f" {split_name}: 发现 {y_nan_count} 个标签为NaN的行将被删除")
                valid_mask = y_series.is_not_null()
                X_df = X_df.filter(valid_mask)
                y_series = y_series.filter(valid_mask)
                raw_data = raw_data.filter(valid_mask)
                data[split_name]["X"] = X_df
                data[split_name]["y"] = y_series
                data[split_name]["raw_data"] = raw_data
            # Columns that are entirely NaN come from factors that failed to
            # compute; they carry no information and cannot be imputed.
            nan_counts = {col: X_df[col].null_count() for col in X_df.columns}
            total_rows = len(X_df)
            all_nan_cols = [
                col for col, count in nan_counts.items() if count == total_rows
            ]
            if all_nan_cols:
                print(
                    f" {split_name}: 发现 {len(all_nan_cols)} 个全为NaN的列将被删除"
                )
                print(
                    f" 列名: {all_nan_cols[:5]}{'...' if len(all_nan_cols) > 5 else ''}"
                )
                all_nan_union.update(all_nan_cols)
        # Pass 2: drop the union from every split so all splits agree with
        # the final feature_cols list.
        feature_cols = [c for c in feature_cols if c not in all_nan_union]
        for split_name in ["train", "val", "test"]:
            raw_data = data[split_name]["raw_data"]
            X_df = data[split_name]["X"].select(feature_cols)
            # Keep non-feature raw columns (e.g. trade_date, ts_code).
            raw_cols_to_keep = [
                c for c in raw_data.columns if c not in all_nan_union
            ]
            data[split_name]["X"] = X_df
            data[split_name]["raw_data"] = raw_data.select(raw_cols_to_keep)
            data[split_name]["feature_cols"] = feature_cols
            # Sanity check: no NaN may remain among the retained features.
            X_np = X_df.to_numpy()
            nan_count = np.isnan(X_np).sum()
            assert nan_count == 0, f"{split_name} 中仍有 {nan_count} 个NaN"
            print(f" 过滤后特征数: {len(feature_cols)}")
        print(" [通过] 所有特征列均无NaN")
        return {
            "pipeline": pipeline,
            "data": data,
            "feature_cols": feature_cols,
            "engine": engine,
        }

    def test_pipeline_data_structure(self, pipeline_data):
        """Diagnosis 0: validate the structure returned by DataPipeline."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]
        print("\n[诊断0] DataPipeline 数据结构检查:")
        assert "train" in data, "缺少 train 数据"
        assert "val" in data, "缺少 val 数据"
        assert "test" in data, "缺少 test 数据"
        for split_name in ["train", "val", "test"]:
            split_data = data[split_name]
            assert "X" in split_data, f"{split_name} 缺少 X"
            assert "y" in split_data, f"{split_name} 缺少 y"
            assert "raw_data" in split_data, f"{split_name} 缺少 raw_data"
            assert "feature_cols" in split_data, f"{split_name} 缺少 feature_cols"
            X = split_data["X"]
            y = split_data["y"]
            print(f" {split_name}:")
            print(f" X 形状: {X.shape}")
            print(f" y 形状: {len(y)}")
            print(f" 特征数: {len(split_data['feature_cols'])}")
            # X/y row counts and the shared feature list must line up.
            assert X.shape[0] == len(y), f"{split_name} X 和 y 行数不匹配"
            assert X.shape[1] == len(feature_cols), f"{split_name} 特征数不匹配"
        print(" [通过] 数据结构正确")

    def test_label_quality_with_pipeline(self, pipeline_data):
        """Diagnosis 1: label quality after DataPipeline processing."""
        data = pipeline_data["data"]
        # Check label quality for every split.
        for split_name in ["train", "val", "test"]:
            y = data[split_name]["y"]
            y_np = y.to_numpy()
            print(f"\n[诊断1-{split_name}] 标签数据质量:")
            print(f" 总数: {len(y)}")
            print(f" NaN数量: {y.null_count()}")
            print(f" 均值: {y.mean():.6f}")
            print(f" 标准差: {y.std():.6f}")
            print(f" 最小值: {y.min():.6f}")
            print(f" 最大值: {y.max():.6f}")
            inf_count = np.isinf(y_np).sum()
            nan_count = np.isnan(y_np).sum()
            print(f" inf数量: {inf_count}")
            print(f" nan数量: {nan_count}")
            assert inf_count == 0, f"{split_name} 标签含 inf: {inf_count}"
            assert nan_count == 0, f"{split_name} 标签含 nan: {nan_count}"

    def test_processed_data_quality_with_pipeline(self, pipeline_data):
        """Diagnosis 2: feature quality after DataPipeline processing."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]
        for split_name in ["train", "val", "test"]:
            X = data[split_name]["X"]
            y = data[split_name]["y"]
            X_np = X.to_numpy().astype(np.float32)
            y_np = y.to_numpy().astype(np.float32)
            print(f"\n[诊断2-{split_name}] 处理后数据质量:")
            print(f" X 形状: {X_np.shape}, dtype: {X_np.dtype}")
            print(f" y 形状: {y_np.shape}, dtype: {y_np.dtype}")
            print(f" X中NaN: {np.isnan(X_np).sum()}")
            print(f" X中Inf: {np.isinf(X_np).sum()}")
            print(f" y中NaN: {np.isnan(y_np).sum()}")
            print(f" y中Inf: {np.isinf(y_np).sum()}")
            assert np.isnan(X_np).sum() == 0, f"{split_name} X含NaN"
            assert np.isnan(y_np).sum() == 0, f"{split_name} y含NaN"
            assert np.isinf(X_np).sum() == 0, f"{split_name} X含Inf"
            assert np.isinf(y_np).sum() == 0, f"{split_name} y含Inf"

    def test_train_skip_days_functionality(self, engine_and_factor_manager):
        """Diagnosis 3: verify that train_skip_days trims the training window."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        print("\n[诊断3] train_skip_days 功能验证:")
        skip_days = 50
        # Pipeline that skips the first `skip_days` trading days of train.
        pipeline_with_skip = self._build_full_pipeline(
            engine, factor_manager, skip_days
        )
        data_with_skip = pipeline_with_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        # Reference pipeline with no skipping, for comparison.
        pipeline_no_skip = self._build_full_pipeline(engine, factor_manager, 0)
        data_no_skip = pipeline_no_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )
        train_with_skip = data_with_skip["train"]["raw_data"]
        train_no_skip = data_no_skip["train"]["raw_data"]
        print("\n 对比结果:")
        print(f" 不跳过时训练数据: {len(train_no_skip)}")
        print(f" 跳过{skip_days}天后: {len(train_with_skip)}")
        print(f" 减少: {len(train_no_skip) - len(train_with_skip)}")
        # val / test must be untouched by train_skip_days.
        assert len(data_with_skip["val"]["raw_data"]) == len(
            data_no_skip["val"]["raw_data"]
        ), "val 数据不应该受 train_skip_days 影响"
        assert len(data_with_skip["test"]["raw_data"]) == len(
            data_no_skip["test"]["raw_data"]
        ), "test 数据不应该受 train_skip_days 影响"
        # Confirm the skipped run really starts on a later trading day.
        if len(train_no_skip) > 0:
            dates_no_skip = sorted(train_no_skip["trade_date"].unique())
            dates_with_skip = sorted(train_with_skip["trade_date"].unique())
            print("\n 日期对比:")
            print(
                f" 不跳过 - 最早日期: {dates_no_skip[0]}, 共 {len(dates_no_skip)} 个交易日"
            )
            print(
                f" 跳过 - 最早日期: {dates_with_skip[0]}, 共 {len(dates_with_skip)} 个交易日"
            )
            # The skipped run must start exactly skip_days trading days later.
            if len(dates_no_skip) > skip_days:
                expected_start_date = dates_no_skip[skip_days]
                assert dates_with_skip[0] == expected_start_date, (
                    f"预期从 {expected_start_date} 开始,实际从 {dates_with_skip[0]} 开始"
                )
                print(f" [通过] 正确跳过前 {skip_days} 个交易日")

    def test_training_with_pipeline(self, pipeline_data):
        """Diagnosis 4: run a short TabM training loop on pipeline data and
        assert that no batch loss turns NaN."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from tabm import TabM

        data = pipeline_data["data"]
        X_train_df = data["train"]["X"]
        y_train_series = data["train"]["y"]
        # Defensive: drop any NaN-label rows that might remain.
        valid_mask = y_train_series.is_not_null()
        y_nan_count = y_train_series.null_count()
        if y_nan_count > 0:
            print(f" 发现 {y_nan_count} 个标签为NaN的行将被删除")
            X_train_df = X_train_df.filter(valid_mask)
            y_train_series = y_train_series.filter(valid_mask)
        X_train = X_train_df.to_numpy().astype(np.float32)
        y_train = y_train_series.to_numpy().astype(np.float32)
        # Keep only the first 1000 rows to speed up the test.
        X_train = X_train[:1000]
        y_train = y_train[:1000]
        print("\n[诊断4] DataPipeline 数据训练测试:")
        print(f" X 形状: {X_train.shape}")
        print(f" y 形状: {y_train.shape}")
        print(f" X中NaN: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}")
        print(f" y中NaN: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}")
        # Build a small TabM model; ensemble_k ties together the model's k
        # and the target expansion below (was two hard-coded 4s).
        n_features = X_train.shape[1]
        ensemble_k = 4
        model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=2,
            d_block=128,
            dropout=0.1,
            k=ensemble_k,
        )
        # Train a few batches and watch the loss.
        dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        loader = DataLoader(dataset, batch_size=256, shuffle=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        criterion = torch.nn.MSELoss()
        model.train()
        losses = []
        for batch_idx, (bx, by) in enumerate(loader):
            optimizer.zero_grad()
            outputs = model(bx)  # [B, E, 1]
            outputs_squeezed = outputs.squeeze(-1)  # [B, E]
            # Replicate the target across the ensemble dimension.
            by_expanded = by.unsqueeze(-1).expand(-1, ensemble_k)  # [B, E]
            loss = criterion(outputs_squeezed, by_expanded)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            print(f" 批次 {batch_idx + 1} loss: {loss.item():.6f}")
            assert not torch.isnan(loss), f"批次 {batch_idx + 1} loss 为 nan!"
            if batch_idx >= 2:  # only the first 3 batches
                break
        print(" [通过] 所有批次 loss 正常,无 NaN")
class TestTrainSkipDaysEdgeCases:
    """Edge-case checks for the ``train_skip_days`` option."""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Provide a shared FactorEngine / FactorManager pair."""
        engine = FactorEngine()
        manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {"engine": engine, "factor_manager": manager}

    def test_skip_more_than_available_days(self, engine_and_factor_manager):
        """Skipping more days than the window contains must not crash."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        print("\n[边界测试] 跳过天数超过可用天数:")
        # Use a skip count far beyond the window's trading-day count.
        oversized_pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=1000,
        )
        prepared = oversized_pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )
        # With too few trading days the pipeline should keep the data and
        # just warn, so some training rows should still come back.
        train_rows = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_rows)}")
        print(f" [通过] 程序未崩溃")

    def test_skip_zero_days(self, engine_and_factor_manager):
        """A skip of zero days behaves like no skipping at all."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]
        print("\n[边界测试] 跳过0天:")
        zero_skip_pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=0,
        )
        prepared = zero_skip_pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )
        train_rows = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_rows)}")
        print(f" [通过] skip=0 时数据正常")
# Allow running this diagnostic suite directly (verbose, with stdout shown).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])