# - 添加 TabMModel、TabPFNModel 深度学习模型实现
# - 新增 DataQualityAnalyzer 进行训练前数据质量诊断
# - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
# - 支持 train_skip_days 参数跳过训练初期数据不足期
# - Pipeline 自动清理标签为 NaN 的样本
"""TabM NaN debugging tests + train_skip_days feature verification.

Diagnoses the root cause of `loss == nan`:

1. Whether the labels contain NaN or extreme values.
2. Whether NaN appears after standardization.
3. Whether any columns have zero (or near-zero) variance.
4. Whether the ``train_skip_days`` parameter works as intended.
"""
|
||
|
||
import numpy as np
import polars as pl
import pytest

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    TabMRegressionTask,
    NullFiller,
    Winsorizer,
    StandardScaler,
)
from src.training.components.filters import STFilter
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
)
|
||
|
||
|
||
# TabM model hyper-parameters (reduced for fast tests).
MODEL_PARAMS = {
    "n_blocks": 2,
    "d_block": 128,
    "dropout": 0.1,
    "ensemble_size": 4,  # reduced ensemble for speed
    "batch_size": 256,
    "learning_rate": 1e-4,  # lowered learning rate
    "weight_decay": 1e-5,
    "epochs": 3,
    "early_stopping_patience": 5,
}

# Date ranges the tests run over (YYYYMMDD strings).
TEST_DATE_RANGE = {
    "train": ("20200101", "20200630"),
    "val": ("20200701", "20200731"),
    "test": ("20200801", "20200831"),
}

# Factors excluded to keep the test scope small: GTJA_alpha001 .. GTJA_alpha010.
EXCLUDED_FACTORS = [f"GTJA_alpha{i:03d}" for i in range(1, 11)]
|
||
|
||
|
||
class TestTabMNanDebug:
    """TabM NaN diagnostic tests, driven through DataPipeline.

    Verifies that labels and processed features are free of NaN/Inf, that
    ``train_skip_days`` trims only the training split, and that a few TabM
    training batches produce finite losses.
    """

    def _make_pipeline(self, engine, factor_manager, skip_days):
        """Build the standard test pipeline: mean-fill -> winsorize -> scale."""
        return DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[
                (NullFiller, {"strategy": "mean"}),
                (Winsorizer, {"lower": 0.01, "upper": 0.99}),
                (StandardScaler, {}),
            ],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Prepare the FactorEngine and FactorManager shared by all tests."""
        engine = FactorEngine()

        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )

        return {
            "engine": engine,
            "factor_manager": factor_manager,
        }

    @pytest.fixture(scope="class")
    def pipeline_data(self, engine_and_factor_manager):
        """Prepare data via DataPipeline and drop all-NaN feature columns.

        Fix vs. the previous version: all-NaN columns are first collected
        across *all* splits, then dropped from *every* split, so each split
        ends up with the same column set and ``feature_cols`` matches every
        split's ``X``.  Previously a column that was all-NaN only in a later
        split was removed from the shared ``feature_cols`` but left inside
        the earlier splits' ``X``, breaking the
        ``X.shape[1] == len(feature_cols)`` invariant checked in
        ``test_pipeline_data_structure``; ``data[split]["feature_cols"]``
        was also only refreshed for splits that themselves had all-NaN
        columns.
        """
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        # Baseline pipeline: do not skip any leading trading days here.
        pipeline = self._make_pipeline(engine, factor_manager, skip_days=0)

        data = pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        feature_cols = factor_manager.register_to_engine(engine, verbose=False)

        print("\n[DataPipeline] 检查并过滤全为NaN的特征列...")

        # Pass 1: per split, drop rows whose label is NaN, then record the
        # union of columns that are entirely NaN (factors whose computation
        # failed — there is nothing meaningful to impute for them).
        all_nan_union: set = set()
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"]
            y_series = data[split_name]["y"]
            raw_data = data[split_name]["raw_data"]

            y_nan_count = y_series.null_count()
            if y_nan_count > 0:
                print(f" {split_name}: 发现 {y_nan_count} 个标签为NaN的行,将被删除")
                # Mask of rows with a valid label; filter X/y/raw in lockstep.
                valid_mask = y_series.is_not_null()
                X_df = X_df.filter(valid_mask)
                y_series = y_series.filter(valid_mask)
                raw_data = raw_data.filter(valid_mask)
                data[split_name]["X"] = X_df
                data[split_name]["y"] = y_series
                data[split_name]["raw_data"] = raw_data

            total_rows = len(X_df)
            split_all_nan = [
                col for col in X_df.columns if X_df[col].null_count() == total_rows
            ]
            if split_all_nan:
                print(
                    f" {split_name}: 发现 {len(split_all_nan)} 个全为NaN的列,将被删除"
                )
                print(
                    f" 列名: {split_all_nan[:5]}{'...' if len(split_all_nan) > 5 else ''}"
                )
                all_nan_union.update(split_all_nan)

        # Pass 2: apply one consistent column set to every split.
        feature_cols = [c for c in feature_cols if c not in all_nan_union]
        for split_name in ["train", "val", "test"]:
            X_df = data[split_name]["X"].select(feature_cols)
            raw_data = data[split_name]["raw_data"]
            # raw_data keeps its non-feature columns (trade_date, ts_code, ...).
            raw_data = raw_data.select(
                [c for c in raw_data.columns if c not in all_nan_union]
            )

            data[split_name]["X"] = X_df
            data[split_name]["raw_data"] = raw_data
            data[split_name]["feature_cols"] = feature_cols

            # After NullFiller + dropping failed columns there must be no NaN.
            nan_count = np.isnan(X_df.to_numpy()).sum()
            assert nan_count == 0, f"{split_name} 中仍有 {nan_count} 个NaN"

            print(f" 过滤后特征数: {len(feature_cols)}")
            print(" [通过] 所有特征列均无NaN")

        return {
            "pipeline": pipeline,
            "data": data,
            "feature_cols": feature_cols,
            "engine": engine,
        }

    def test_pipeline_data_structure(self, pipeline_data):
        """Diagnostic 0: structure of the data returned by DataPipeline."""
        data = pipeline_data["data"]
        feature_cols = pipeline_data["feature_cols"]

        print("\n[诊断0] DataPipeline 数据结构检查:")

        # All three splits must be present.
        assert "train" in data, "缺少 train 数据"
        assert "val" in data, "缺少 val 数据"
        assert "test" in data, "缺少 test 数据"

        for split_name in ["train", "val", "test"]:
            split_data = data[split_name]
            assert "X" in split_data, f"{split_name} 缺少 X"
            assert "y" in split_data, f"{split_name} 缺少 y"
            assert "raw_data" in split_data, f"{split_name} 缺少 raw_data"
            assert "feature_cols" in split_data, f"{split_name} 缺少 feature_cols"

            X = split_data["X"]
            y = split_data["y"]

            print(f" {split_name}:")
            print(f" X 形状: {X.shape}")
            print(f" y 形状: {len(y)}")
            print(f" 特征数: {len(split_data['feature_cols'])}")

            # Rows of X and y must match; columns of X must match feature_cols.
            assert X.shape[0] == len(y), f"{split_name} X 和 y 行数不匹配"
            assert X.shape[1] == len(feature_cols), f"{split_name} 特征数不匹配"

        print(" [通过] 数据结构正确")

    def test_label_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 1: label quality after the DataPipeline (no NaN/Inf)."""
        data = pipeline_data["data"]

        # Check label quality on every split.
        for split_name in ["train", "val", "test"]:
            y = data[split_name]["y"]
            y_np = y.to_numpy()

            print(f"\n[诊断1-{split_name}] 标签数据质量:")
            print(f" 总数: {len(y)}")
            print(f" NaN数量: {y.null_count()}")
            print(f" 均值: {y.mean():.6f}")
            print(f" 标准差: {y.std():.6f}")
            print(f" 最小值: {y.min():.6f}")
            print(f" 最大值: {y.max():.6f}")

            inf_count = np.isinf(y_np).sum()
            nan_count = np.isnan(y_np).sum()
            print(f" inf数量: {inf_count}")
            print(f" nan数量: {nan_count}")

            assert inf_count == 0, f"{split_name} 标签含 inf: {inf_count}"
            assert nan_count == 0, f"{split_name} 标签含 nan: {nan_count}"

    def test_processed_data_quality_with_pipeline(self, pipeline_data):
        """Diagnostic 2: processed feature quality after the DataPipeline."""
        data = pipeline_data["data"]

        for split_name in ["train", "val", "test"]:
            X = data[split_name]["X"]
            y = data[split_name]["y"]

            # Cast to float32, the dtype the model trains with.
            X_np = X.to_numpy().astype(np.float32)
            y_np = y.to_numpy().astype(np.float32)

            print(f"\n[诊断2-{split_name}] 处理后数据质量:")
            print(f" X 形状: {X_np.shape}, dtype: {X_np.dtype}")
            print(f" y 形状: {y_np.shape}, dtype: {y_np.dtype}")
            print(f" X中NaN: {np.isnan(X_np).sum()}")
            print(f" X中Inf: {np.isinf(X_np).sum()}")
            print(f" y中NaN: {np.isnan(y_np).sum()}")
            print(f" y中Inf: {np.isinf(y_np).sum()}")

            assert np.isnan(X_np).sum() == 0, f"{split_name} X含NaN"
            assert np.isnan(y_np).sum() == 0, f"{split_name} y含NaN"
            assert np.isinf(X_np).sum() == 0, f"{split_name} X含Inf"
            assert np.isinf(y_np).sum() == 0, f"{split_name} y含Inf"

    def test_train_skip_days_functionality(self, engine_and_factor_manager):
        """Diagnostic 3: train_skip_days must trim only the training split."""
        engine = engine_and_factor_manager["engine"]
        factor_manager = engine_and_factor_manager["factor_manager"]

        print("\n[诊断3] train_skip_days 功能验证:")

        # Pipeline that skips the first 50 trading days of the train window.
        skip_days = 50
        data_with_skip = self._make_pipeline(
            engine, factor_manager, skip_days=skip_days
        ).prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=True,
        )

        # Identical pipeline without skipping, used as the comparison baseline.
        data_no_skip = self._make_pipeline(
            engine, factor_manager, skip_days=0
        ).prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=False,
        )

        train_with_skip = data_with_skip["train"]["raw_data"]
        train_no_skip = data_no_skip["train"]["raw_data"]

        print(f"\n 对比结果:")
        print(f" 不跳过时训练数据: {len(train_no_skip)} 条")
        print(f" 跳过{skip_days}天后: {len(train_with_skip)} 条")
        print(f" 减少: {len(train_no_skip) - len(train_with_skip)} 条")

        # val / test must be unaffected by train_skip_days.
        assert len(data_with_skip["val"]["raw_data"]) == len(
            data_no_skip["val"]["raw_data"]
        ), "val 数据不应该受 train_skip_days 影响"
        assert len(data_with_skip["test"]["raw_data"]) == len(
            data_no_skip["test"]["raw_data"]
        ), "test 数据不应该受 train_skip_days 影响"

        # Verify that the leading dates were actually skipped.
        if len(train_no_skip) > 0:
            dates_no_skip = sorted(train_no_skip["trade_date"].unique())
            dates_with_skip = sorted(train_with_skip["trade_date"].unique())

            print(f"\n 日期对比:")
            print(
                f" 不跳过 - 最早日期: {dates_no_skip[0]}, 共 {len(dates_no_skip)} 个交易日"
            )
            print(
                f" 跳过 - 最早日期: {dates_with_skip[0]}, 共 {len(dates_with_skip)} 个交易日"
            )

            # The skipped run must start exactly skip_days trading days later.
            if len(dates_no_skip) > skip_days:
                expected_start_date = dates_no_skip[skip_days]
                assert dates_with_skip[0] == expected_start_date, (
                    f"预期从 {expected_start_date} 开始,实际从 {dates_with_skip[0]} 开始"
                )
                print(f" [通过] 正确跳过前 {skip_days} 个交易日")

    def test_training_with_pipeline(self, pipeline_data):
        """Diagnostic 4: a few TabM training batches must yield finite losses."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from tabm import TabM

        data = pipeline_data["data"]

        X_train_df = data["train"]["X"]
        y_train_series = data["train"]["y"]

        # Defensive: the fixture already drops NaN labels, but re-check so
        # this test stands on its own if the fixture changes.
        y_nan_count = y_train_series.null_count()
        if y_nan_count > 0:
            print(f" 发现 {y_nan_count} 个标签为NaN的行,将被删除")
            valid_mask = y_train_series.is_not_null()
            X_train_df = X_train_df.filter(valid_mask)
            y_train_series = y_train_series.filter(valid_mask)

        X_train = X_train_df.to_numpy().astype(np.float32)
        y_train = y_train_series.to_numpy().astype(np.float32)

        # Only the first 1000 rows, to keep the test fast.
        X_train = X_train[:1000]
        y_train = y_train[:1000]

        print(f"\n[诊断4] DataPipeline 数据训练测试:")
        print(f" X 形状: {X_train.shape}")
        print(f" y 形状: {y_train.shape}")
        print(f" X中NaN: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}")
        print(f" y中NaN: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}")

        # Build a small TabM model (k = ensemble size).
        n_features = X_train.shape[1]
        model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=2,
            d_block=128,
            dropout=0.1,
            k=4,
        )

        # Train for a few batches only.
        dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
        loader = DataLoader(dataset, batch_size=256, shuffle=True)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        criterion = torch.nn.MSELoss()

        model.train()

        for batch_idx, (bx, by) in enumerate(loader):
            optimizer.zero_grad()
            outputs = model(bx)  # [B, k, 1]
            outputs_squeezed = outputs.squeeze(-1)  # [B, k]
            # Broadcast the target across the ensemble dimension; derive k
            # from the model output instead of hard-coding it.
            by_expanded = by.unsqueeze(-1).expand(-1, outputs_squeezed.shape[1])
            loss = criterion(outputs_squeezed, by_expanded)
            loss.backward()
            optimizer.step()

            print(f" 批次 {batch_idx + 1} loss: {loss.item():.6f}")

            assert not torch.isnan(loss), f"批次 {batch_idx + 1} loss 为 nan!"

            if batch_idx >= 2:  # only the first 3 batches
                break

        print(f" [通过] 所有批次 loss 正常,无 NaN")
|
||
|
||
|
||
class TestTrainSkipDaysEdgeCases:
    """Edge-case tests for the train_skip_days parameter."""

    @pytest.fixture(scope="class")
    def engine_and_factor_manager(self):
        """Build the FactorEngine / FactorManager pair shared by these tests."""
        shared_engine = FactorEngine()

        manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )

        return {"engine": shared_engine, "factor_manager": manager}

    def _prepare(self, engine, factor_manager, skip_days, verbose):
        """Run a minimal pipeline (mean-fill only) with the given skip_days."""
        pipeline = DataPipeline(
            factor_manager=factor_manager,
            processor_configs=[(NullFiller, {"strategy": "mean"})],
            filters=[STFilter(data_router=engine.router)],
            stock_pool_filter_func=stock_pool_filter,
            stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            train_skip_days=skip_days,
        )
        return pipeline.prepare_data(
            engine=engine,
            date_range=TEST_DATE_RANGE,
            label_name=LABEL_NAME,
            verbose=verbose,
        )

    def test_skip_more_than_available_days(self, engine_and_factor_manager):
        """A skip count larger than the window must not crash the pipeline."""
        print("\n[边界测试] 跳过天数超过可用天数:")

        # 1000 exceeds the number of trading days in the test window; the
        # pipeline is expected to warn and keep the data rather than fail.
        prepared = self._prepare(
            engine_and_factor_manager["engine"],
            engine_and_factor_manager["factor_manager"],
            skip_days=1000,
            verbose=True,
        )

        train_rows = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_rows)} 条")
        print(" [通过] 程序未崩溃")

    def test_skip_zero_days(self, engine_and_factor_manager):
        """skip_days == 0 (i.e. no skipping) must behave normally."""
        print("\n[边界测试] 跳过0天:")

        prepared = self._prepare(
            engine_and_factor_manager["engine"],
            engine_and_factor_manager["factor_manager"],
            skip_days=0,
            verbose=False,
        )

        train_rows = prepared["train"]["raw_data"]
        print(f" 训练数据量: {len(train_rows)} 条")
        print(" [通过] skip=0 时数据正常")
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file as a script
    # reports failures to the shell/CI (previously the code was discarded).
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|