"""TabM NaN 问题诊断测试 + train_skip_days 功能验证 诊断 loss 为 nan 的根因: 1. 标签中是否有 NaN 或极端值 2. 标准化后是否有 NaN 3. 是否有方差为0或接近0的列 4. train_skip_days 功能是否正常工作 """ import numpy as np import polars as pl import pytest from src.factors import FactorEngine from src.training import ( FactorManager, DataPipeline, TabMRegressionTask, NullFiller, Winsorizer, StandardScaler, ) from src.training.components.filters import STFilter from src.experiment.common import ( SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_NAME, LABEL_FACTOR, stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, ) # TabM 模型参数(简化用于快速测试) MODEL_PARAMS = { "n_blocks": 2, "d_block": 128, "dropout": 0.1, "ensemble_size": 4, # 简化 "batch_size": 256, "learning_rate": 1e-4, # 降低学习率 "weight_decay": 1e-5, "epochs": 3, "early_stopping_patience": 5, } # 测试日期范围 TEST_DATE_RANGE = { "train": ("20200101", "20200630"), "val": ("20200701", "20200731"), "test": ("20200801", "20200831"), } # 小范围测试用的因子排除列表 EXCLUDED_FACTORS = [ "GTJA_alpha001", "GTJA_alpha002", "GTJA_alpha003", "GTJA_alpha004", "GTJA_alpha005", "GTJA_alpha006", "GTJA_alpha007", "GTJA_alpha008", "GTJA_alpha009", "GTJA_alpha010", ] class TestTabMNanDebug: """TabM NaN 问题诊断测试类(使用 DataPipeline)""" @pytest.fixture(scope="class") def engine_and_factor_manager(self): """准备 FactorEngine 和 FactorManager""" engine = FactorEngine() factor_manager = FactorManager( selected_factors=SELECTED_FACTORS, factor_definitions=FACTOR_DEFINITIONS, label_factor=LABEL_FACTOR, excluded_factors=EXCLUDED_FACTORS, ) return { "engine": engine, "factor_manager": factor_manager, } @pytest.fixture(scope="class") def pipeline_data(self, engine_and_factor_manager): """使用 DataPipeline 准备数据,并过滤掉全为NaN的列""" engine = engine_and_factor_manager["engine"] factor_manager = engine_and_factor_manager["factor_manager"] # 创建 DataPipeline(使用 train_skip_days=0 进行基础测试) pipeline = DataPipeline( factor_manager=factor_manager, processor_configs=[ (NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99}), (StandardScaler, {}), ], filters=[STFilter(data_router=engine.router)], stock_pool_filter_func=stock_pool_filter, stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, train_skip_days=0, # 先不跳过天数进行测试 ) # 准备数据 data = pipeline.prepare_data( engine=engine, date_range=TEST_DATE_RANGE, label_name=LABEL_NAME, verbose=True, ) # 获取特征列 feature_cols = factor_manager.register_to_engine(engine, verbose=False) # 过滤掉全为NaN的列(这些因子计算失败,没有修复意义) print("\n[DataPipeline] 检查并过滤全为NaN的特征列...") for split_name in ["train", "val", "test"]: X_df = data[split_name]["X"] y_series = data[split_name]["y"] raw_data = data[split_name]["raw_data"] # 删除标签为NaN的行 y_nan_count = y_series.null_count() if y_nan_count > 0: print(f" {split_name}: 发现 {y_nan_count} 个标签为NaN的行,将被删除") # 创建有效标签的mask valid_mask = y_series.is_not_null() # 过滤所有相关数据 X_df = X_df.filter(valid_mask) y_series = y_series.filter(valid_mask) raw_data = raw_data.filter(valid_mask) # 更新数据 data[split_name]["X"] = X_df data[split_name]["y"] = y_series data[split_name]["raw_data"] = raw_data # 检查每列的NaN数量 nan_counts = {col: X_df[col].null_count() for col in X_df.columns} total_rows = len(X_df) # 找出全为NaN的列 all_nan_cols = [ col for col, count in nan_counts.items() if count == total_rows ] if all_nan_cols: print( f" {split_name}: 发现 {len(all_nan_cols)} 个全为NaN的列,将被删除" ) print( f" 列名: {all_nan_cols[:5]}{'...' 


class TestTabMNanDebug:
    """TabM NaN diagnostic tests (using DataPipeline)"""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Prepare the FactorEngine and FactorManager."""
        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    @pytest.fixture(scope="class")
    def pipeline_data(self, engine_and_factor_manager):
        """Prepare data via DataPipeline and drop columns that are entirely NaN."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        # Build the DataPipeline (train_skip_days=0 for the baseline tests)
        pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=0,  # do not skip any days in the baseline
        )

        # Prepare the data
        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        # Get the feature columns
        feature_cols = factor_manager.register_to_engine(engine, verbose=False)

        # Drop all-NaN columns (those factors failed to compute; there is
        # nothing to repair)
        print("\n[DataPipeline] Checking for and dropping all-NaN feature columns...")
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            y_series = data[split_name]["y"]
            raw_data = data[split_name]["raw_data"]

            # Drop rows whose label is NaN
            y_nan_count = y_series.null_count()
            if y_nan_count > 0:
                print(f"  {split_name}: {y_nan_count} rows have NaN labels and will be dropped")
                # Mask of rows with a valid label
                valid_mask = y_series.is_not_null()
                # Filter all related data
                X_df = X_df.filter(valid_mask)
                y_series = y_series.filter(valid_mask)
                raw_data = raw_data.filter(valid_mask)
                # Write back
                data[split_name]["X"] = X_df
                data[split_name]["y"] = y_series
                data[split_name]["raw_data"] = raw_data

            # Count NaNs per column
            nan_counts = {col: X_df[col].null_count() for col in X_df.columns}
            total_rows = len(X_df)

            # Find columns that are entirely NaN
            all_nan_cols = [
                col for col, count in nan_counts.items() if count == total_rows
            ]
            if all_nan_cols:
                print(
                    f"  {split_name}: {len(all_nan_cols)} all-NaN columns will be dropped"
                )
                print(
                    f"    Columns: {all_nan_cols[:5]}"
                    f"{'...' if len(all_nan_cols) > 5 else ''}"
                )
                # Update feature_cols
                feature_cols = [c for c in feature_cols if c not in all_nan_cols]
                # Drop the columns from X
                X_df = X_df.select(feature_cols)
                # Drop them from raw_data too (keeping non-feature columns
                # such as trade_date and ts_code)
                raw_cols_to_keep = [
                    c for c in raw_data.columns if c not in all_nan_cols
                ]
                raw_data = raw_data.select(raw_cols_to_keep)
                # Write back
                data[split_name]["X"] = X_df
                data[split_name]["raw_data"] = raw_data

            data[split_name]["feature_cols"] = feature_cols

            # Verify no NaN is left
            X_np = X_df.to_numpy()
            nan_count = np.isnan(X_np).sum()
            assert nan_count == 0, f"{split_name} still contains {nan_count} NaNs"

        print(f"  Feature count after filtering: {len(feature_cols)}")
        print("  [PASS] No feature column contains NaN")

        return {
            "pipeline": pipeline,
            "data": data,
            "feature_cols": feature_cols,
            "engine": engine,
        }

    def test_pipeline_data_structure(self, pipeline_data):
        """Diagnostic 0: check the structure of the data returned by DataPipeline."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]

        print("\n[Diagnostic 0] DataPipeline data-structure check:")

        # Check the top-level structure
        assert "train" in data, "train data is missing"
        assert "val" in data, "val data is missing"
        assert "test" in data, "test data is missing"

        for split_name in ["train", "val", "test"]:
            split_data = data[split_name]
            assert "X" in split_data, f"{split_name} is missing X"
            assert "y" in split_data, f"{split_name} is missing y"
            assert "raw_data" in split_data, f"{split_name} is missing raw_data"
            assert "feature_cols" in split_data, f"{split_name} is missing feature_cols"

            X = split_data["X"]
            y = split_data["y"]
            print(f"  {split_name}:")
            print(f"    X shape: {X.shape}")
            print(f"    y length: {len(y)}")
            print(f"    feature count: {len(split_data['feature_cols'])}")

            # Check dimensional consistency
            assert X.shape[0] == len(y), f"{split_name}: X and y row counts differ"
            assert X.shape[1] == len(feature_cols), f"{split_name}: feature count mismatch"

        print("  [PASS] Data structure is correct")

    def test_label_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 1: label quality after the DataPipeline."""
        data = pipeline_data["data"]

        # Check label quality on every split
        for split_name in ["train", "val", "test"]:
            y = data[split_name]["y"]
            y_np = y.to_numpy()

            print(f"\n[Diagnostic 1-{split_name}] Label quality:")
            print(f"  count: {len(y)}")
            print(f"  NaN count: {y.null_count()}")
            print(f"  mean: {y.mean():.6f}")
            print(f"  std: {y.std():.6f}")
            print(f"  min: {y.min():.6f}")
            print(f"  max: {y.max():.6f}")

            inf_count = np.isinf(y_np).sum()
            nan_count = np.isnan(y_np).sum()
            print(f"  inf count: {inf_count}")
            print(f"  nan count: {nan_count}")

            assert inf_count == 0, f"{split_name} labels contain inf: {inf_count}"
            assert nan_count == 0, f"{split_name} labels contain nan: {nan_count}"

    def test_processed_data_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 2: feature quality after the DataPipeline."""
        data = pipeline_data["data"]

        for split_name in ["train", "val", "test"]:
            X = data[split_name]["X"]
            y = data[split_name]["y"]
            X_np = X.to_numpy().astype(np.float32)
            y_np = y.to_numpy().astype(np.float32)

            print(f"\n[Diagnostic 2-{split_name}] Processed data quality:")
            print(f"  X shape: {X_np.shape}, dtype: {X_np.dtype}")
            print(f"  y shape: {y_np.shape}, dtype: {y_np.dtype}")
            print(f"  NaN in X: {np.isnan(X_np).sum()}")
            print(f"  Inf in X: {np.isinf(X_np).sum()}")
            print(f"  NaN in y: {np.isnan(y_np).sum()}")
            print(f"  Inf in y: {np.isinf(y_np).sum()}")

            assert np.isnan(X_np).sum() == 0, f"{split_name}: X contains NaN"
            assert np.isnan(y_np).sum() == 0, f"{split_name}: y contains NaN"
            assert np.isinf(X_np).sum() == 0, f"{split_name}: X contains Inf"
            assert np.isinf(y_np).sum() == 0, f"{split_name}: y contains Inf"

    def test_train_skip_days_functionality(self, engine_and_factor_manager):
        """Diagnostic 3: verify train_skip_days."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        print("\n[Diagnostic 3] train_skip_days verification:")

        # Pipeline that skips the first 50 trading days of training data
        skip_days = 50
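        # Assumed semantics of train_skip_days, inferred from the assertions
        # below (DataPipeline's implementation is not shown in this file):
        # the first `skip_days` unique trading dates are removed from the
        # train split only, so the skipped run should start at
        # sorted(train_dates)[skip_days] while val/test remain untouched.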
        pipeline_with_skip = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )

        # Prepare the data
        data_with_skip = pipeline_with_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        # Baseline data (no skipped days) for comparison
        pipeline_no_skip = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=0,
        )
        data_no_skip = pipeline_no_skip.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )

        # The training data should shrink
        train_with_skip = data_with_skip["train"]["raw_data"]
        train_no_skip = data_no_skip["train"]["raw_data"]

        print("\n  Comparison:")
        print(f"    training rows without skipping: {len(train_no_skip)}")
        print(f"    training rows after skipping {skip_days} days: {len(train_with_skip)}")
        print(f"    reduction: {len(train_no_skip) - len(train_with_skip)}")

        # val and test must be unaffected
        assert len(data_with_skip["val"]["raw_data"]) == len(
            data_no_skip["val"]["raw_data"]
        ), "val data must not be affected by train_skip_days"
        assert len(data_with_skip["test"]["raw_data"]) == len(
            data_no_skip["test"]["raw_data"]
        ), "test data must not be affected by train_skip_days"

        # Verify the dates were actually skipped
        if len(train_no_skip) > 0:
            dates_no_skip = sorted(train_no_skip["trade_date"].unique())
            dates_with_skip = sorted(train_with_skip["trade_date"].unique())
            print("\n  Date comparison:")
            print(
                f"    no skip: earliest date {dates_no_skip[0]}, "
                f"{len(dates_no_skip)} trading days"
            )
            print(
                f"    skipped: earliest date {dates_with_skip[0]}, "
                f"{len(dates_with_skip)} trading days"
            )

            # The skipped run must start at a later date
            if len(dates_no_skip) > skip_days:
                expected_start_date = dates_no_skip[skip_days]
                assert dates_with_skip[0] == expected_start_date, (
                    f"expected start date {expected_start_date}, "
                    f"got {dates_with_skip[0]}"
                )
                print(f"    [PASS] The first {skip_days} trading days were skipped")

    def test_training_with_pipeline(self, pipeline_data):
        """Diagnostic 4: train on data produced by the DataPipeline."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from tabm import TabM

        data = pipeline_data["data"]

        # Training data
        X_train_df = data["train"]["X"]
        y_train_series = data["train"]["y"]

        # Drop rows whose label is NaN
        valid_mask = y_train_series.is_not_null()
        y_nan_count = y_train_series.null_count()
        if y_nan_count > 0:
            print(f"  {y_nan_count} rows have NaN labels and will be dropped")
            X_train_df = X_train_df.filter(valid_mask)
            y_train_series = y_train_series.filter(valid_mask)

        X_train = X_train_df.to_numpy().astype(np.float32)
        y_train = y_train_series.to_numpy().astype(np.float32)

        # Use only the first 1000 rows to speed up the test
        X_train = X_train[:1000]
        y_train = y_train[:1000]

        print("\n[Diagnostic 4] Training test on DataPipeline output:")
        print(f"  X shape: {X_train.shape}")
        print(f"  y shape: {y_train.shape}")
        print(f"  NaN in X: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}")
        print(f"  NaN in y: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}")

        # Build the TabM model
        n_features = X_train.shape[1]
        model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=2,
            d_block=128,
            dropout=0.1,
            k=4,
        )

        # Train a few batches
        dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        loader = DataLoader(dataset, batch_size=256, shuffle=True)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
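        # TabM trains k ensemble members in parallel: model(bx) returns one
        # prediction per member, shaped [batch, k, d_out] (here [B, 4, 1],
        # matching k=4 above and the shape comments in the loop below).
        # Squeezing the last axis and expanding the target to [B, 4] makes
        # MSELoss average over all members at once; with a shared target this
        # is equivalent to averaging each member's individual MSE.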
        criterion = torch.nn.MSELoss()

        model.train()
        losses = []
        for batch_idx, (bx, by) in enumerate(loader):
            optimizer.zero_grad()
            outputs = model(bx)  # [B, E, 1]
            outputs_squeezed = outputs.squeeze(-1)  # [B, E]
            by_expanded = by.unsqueeze(-1).expand(-1, 4)  # [B, E]; 4 matches k above
            loss = criterion(outputs_squeezed, by_expanded)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            print(f"  batch {batch_idx + 1} loss: {loss.item():.6f}")
            assert not torch.isnan(loss), f"batch {batch_idx + 1} loss is NaN!"
            if batch_idx >= 2:  # only test the first 3 batches
                break

        print("  [PASS] All batch losses are finite; no NaN")


class TestTrainSkipDaysEdgeCases:
    """Edge-case tests for train_skip_days"""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Prepare the FactorEngine and FactorManager."""
        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    def test_skip_more_than_available_days(self, engine_and_factor_manager):
        """Skip count exceeding the number of available trading days."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        print("\n[Edge case] Skip count exceeds available days:")

        # Use a very large skip count
        pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=1000,  # more than the trading days in the test window
        )

        # Preparing the data should still succeed (with a warning)
        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        # Data should still be returned even with a huge skip count.
        # When there are fewer trading days than skip days, all data is
        # expected to be kept (only a warning is emitted).
        train_data = data["train"]["raw_data"]
        print(f"  training rows: {len(train_data)}")
        print("  [PASS] No crash")

    def test_skip_zero_days(self, engine_and_factor_manager):
        """Skip 0 days (i.e. no skipping at all)."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        print("\n[Edge case] Skip 0 days:")

        pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=0,
        )

        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )

        train_data = data["train"]["raw_data"]
        print(f"  training rows: {len(train_data)}")
        print("  [PASS] Data is intact with skip=0")


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])