Files
ProStock/tests/test_nan_step_by_step.py
liaozhaorun 36a3ccbcc8 feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现
- 新增 DataQualityAnalyzer 进行训练前数据质量诊断
- 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性
- 支持 train_skip_days 参数跳过训练初期数据不足期
- Pipeline 自动清理标签为 NaN 的样本
2026-03-31 23:11:21 +08:00

493 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""NaN 问题逐步诊断测试 - 精确定位问题环节
此测试会逐步检查 DataPipeline 的每个处理步骤,精确定位 NaN 产生的位置。
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
StandardScaler,
)
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
)
# Test configuration
# Exclude GTJA_alpha001..GTJA_alpha049 to speed up the test
# (note: range(1, 50) stops before 50, so 49 factors are excluded, not 50).
EXCLUDED_FACTORS = [f"GTJA_alpha{i:03d}" for i in range(1, 50)]
TEST_DATE_RANGE = {
    "train": ("20200101", "20201231"),  # a full year of data
    "val": ("20210101", "20210331"),
    "test": ("20210401", "20210630"),
}
class TestNaNStepByStep:
    """Step-by-step diagnosis of where NaN values enter the DataPipeline.

    Each test replays the pipeline up to one stage and reports missing-value
    statistics so the stage that introduces NaN can be pinpointed.

    NOTE: polars distinguishes *null* from float *NaN* — ``Series.null_count()``
    does NOT count NaN, while ``np.isnan`` on the numpy export (step 8) does.
    ``_check_nan_in_df`` therefore counts BOTH so every step reports a metric
    consistent with the final numpy check.
    """

    @pytest.fixture(scope="class")
    def base_data(self):
        """Prepare raw, unprocessed base data shared by all steps (class-scoped)."""
        print("\n" + "=" * 80)
        print("[Fixture] 准备基础数据...")
        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )
        # Register factors with the engine; returns the feature column names.
        feature_cols = factor_manager.register_to_engine(engine, verbose=False)
        print(f" 特征数: {len(feature_cols)}")
        # Full date span covering train/val/test. Dates are "YYYYMMDD" strings,
        # so lexicographic min/max equals chronological min/max.
        all_start = min(
            TEST_DATE_RANGE["train"][0],
            TEST_DATE_RANGE["val"][0],
            TEST_DATE_RANGE["test"][0],
        )
        all_end = max(
            TEST_DATE_RANGE["train"][1],
            TEST_DATE_RANGE["val"][1],
            TEST_DATE_RANGE["test"][1],
        )
        # Compute factor values plus the label over the whole span.
        raw_data = engine.compute(
            factor_names=feature_cols + [LABEL_NAME],
            start_date=all_start,
            end_date=all_end,
        )
        print(f" 原始数据形状: {raw_data.shape}")
        return {
            "engine": engine,
            "factor_manager": factor_manager,
            "feature_cols": feature_cols,
            "raw_data": raw_data,
        }

    def test_step_0_raw_data(self, base_data):
        """Step 0: check missing values in the raw data."""
        print("\n" + "=" * 80)
        print("[步骤0] 检查原始数据中的 NaN")
        raw_data = base_data["raw_data"]
        feature_cols = base_data["feature_cols"]
        nan_stats = self._check_nan_in_df(raw_data, feature_cols, "原始数据")
        # Report which columns contain missing values.
        print(f" 含 NaN 的特征列数: {len(nan_stats['cols_with_nan'])}")
        if nan_stats["cols_with_nan"]:
            print(f" 示例: {nan_stats['cols_with_nan'][:5]}")
        # Returned so later steps can compare against the baseline.
        return nan_stats

    def test_step_1_after_st_filter(self, base_data):
        """Step 1: check missing values after STFilter."""
        print("\n" + "=" * 80)
        print("[步骤1] 检查 STFilter 后的 NaN")
        raw_data = base_data["raw_data"]
        feature_cols = base_data["feature_cols"]
        engine = base_data["engine"]
        st_filter = STFilter(data_router=engine.router)
        filtered_data = st_filter.filter(raw_data)
        print(f" 过滤后数据形状: {filtered_data.shape}")
        print(f" 删除记录数: {len(raw_data) - len(filtered_data)}")
        nan_stats = self._check_nan_in_df(filtered_data, feature_cols, "STFilter后")
        # Compare against step 0 to detect newly introduced missing values.
        step0_nan = self.test_step_0_raw_data(base_data)
        if nan_stats["total_nan"] != step0_nan["total_nan"]:
            print(
                f" [警告] NaN 数量变化: {step0_nan['total_nan']} -> {nan_stats['total_nan']}"
            )
        return nan_stats

    def test_step_2_after_stock_pool(self, base_data):
        """Step 2: check missing values after the stock-pool selection."""
        print("\n" + "=" * 80)
        print("[步骤2] 检查股票池筛选后的 NaN")
        raw_data = base_data["raw_data"]
        feature_cols = base_data["feature_cols"]
        engine = base_data["engine"]
        # Apply STFilter first, matching the pipeline order.
        st_filter = STFilter(data_router=engine.router)
        filtered_data = st_filter.filter(raw_data)
        # Then apply the stock-pool selection.
        from src.training.core.stock_pool_manager import StockPoolManager

        pool_manager = StockPoolManager(
            filter_func=stock_pool_filter,
            required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            data_router=engine.router,
        )
        pool_data = pool_manager.filter_and_select_daily(filtered_data)
        print(f" 筛选后数据形状: {pool_data.shape}")
        print(f" 删除记录数: {len(filtered_data) - len(pool_data)}")
        nan_stats = self._check_nan_in_df(pool_data, feature_cols, "股票池筛选后")
        return nan_stats

    def test_step_3_train_split_without_skip(self, base_data):
        """Step 3: check missing values after the train split (no skipped days)."""
        print("\n" + "=" * 80)
        print("[步骤3] 检查训练集划分后的 NaN不跳过天数")
        raw_data = base_data["raw_data"]
        feature_cols = base_data["feature_cols"]
        engine = base_data["engine"]
        # Apply the same filters as the pipeline.
        st_filter = STFilter(data_router=engine.router)
        filtered_data = st_filter.filter(raw_data)
        from src.training.core.stock_pool_manager import StockPoolManager

        pool_manager = StockPoolManager(
            filter_func=stock_pool_filter,
            required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            data_router=engine.router,
        )
        pool_data = pool_manager.filter_and_select_daily(filtered_data)
        # Slice out the training date range.
        train_start, train_end = TEST_DATE_RANGE["train"]
        train_mask = (pool_data["trade_date"] >= train_start) & (
            pool_data["trade_date"] <= train_end
        )
        train_df = pool_data.filter(train_mask)
        print(f" 训练集形状: {train_df.shape}")
        # Count the trading days in the split.
        unique_dates = train_df["trade_date"].unique().sort()
        print(f" 训练集交易日数量: {len(unique_dates)}")
        print(f" 日期范围: {unique_dates[0]} ~ {unique_dates[-1]}")
        nan_stats = self._check_nan_in_df(train_df, feature_cols, "训练集(不跳过)")
        # Return the split so later steps can reuse it.
        return {
            "nan_stats": nan_stats,
            "train_df": train_df,
            "unique_dates": unique_dates,
        }

    def test_step_4_train_split_with_skip(self, base_data):
        """Step 4: check missing values after skipping the first 252 trading days."""
        print("\n" + "=" * 80)
        print("[步骤4] 检查训练集划分后的 NaN跳过前252天")
        step3_result = self.test_step_3_train_split_without_skip(base_data)
        train_df = step3_result["train_df"]
        unique_dates = step3_result["unique_dates"]
        feature_cols = base_data["feature_cols"]
        # Skip the first 252 trading days (rolling-window warm-up period).
        skip_days = 252
        if len(unique_dates) > skip_days:
            start_date = unique_dates[skip_days]
            train_df_skipped = train_df.filter(pl.col("trade_date") >= start_date)
            print(f" 跳过前{skip_days}天后,从 {start_date} 开始")
            print(f" 跳过后训练集形状: {train_df_skipped.shape}")
            print(f" 跳过记录数: {len(train_df) - len(train_df_skipped)}")
        else:
            train_df_skipped = train_df
            print(
                f" [警告] 训练集交易日数({len(unique_dates)})少于跳过天数({skip_days}),未跳过"
            )
        nan_stats = self._check_nan_in_df(
            train_df_skipped, feature_cols, "训练集(跳过252天)"
        )
        return {
            "nan_stats": nan_stats,
            "train_df": train_df_skipped,
        }

    def test_step_5_after_null_filler(self, base_data):
        """Step 5: check missing values after NullFiller."""
        print("\n" + "=" * 80)
        print("[步骤5] 检查 NullFiller 后的 NaN")
        step4_result = self.test_step_4_train_split_with_skip(base_data)
        train_df = step4_result["train_df"]
        feature_cols = base_data["feature_cols"]
        print(f" 处理前数据形状: {train_df.shape}")
        # Apply NullFiller (per-date mean imputation).
        null_filler = NullFiller(
            feature_cols=feature_cols, strategy="mean", by_date=True
        )
        after_null = null_filler.fit_transform(train_df)
        print(f" 处理后数据形状: {after_null.shape}")
        nan_stats = self._check_nan_in_df(after_null, feature_cols, "NullFiller后")
        # NullFiller should have removed all missing values; report leftovers.
        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] NullFiller 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_null[col].null_count()
                dtype = after_null[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")
        return {
            "nan_stats": nan_stats,
            "after_null": after_null,
        }

    def test_step_6_after_winsorizer(self, base_data):
        """Step 6: check missing values after Winsorizer."""
        print("\n" + "=" * 80)
        print("[步骤6] 检查 Winsorizer 后的 NaN")
        step5_result = self.test_step_5_after_null_filler(base_data)
        after_null = step5_result["after_null"]
        feature_cols = base_data["feature_cols"]
        # Apply Winsorizer (clip to the 1%/99% quantiles, global not per-date).
        winsorizer = Winsorizer(
            feature_cols=feature_cols, lower=0.01, upper=0.99, by_date=False
        )
        after_winsor = winsorizer.fit_transform(after_null)
        nan_stats = self._check_nan_in_df(after_winsor, feature_cols, "Winsorizer后")
        # Report any columns that still contain missing values.
        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] Winsorizer 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_winsor[col].null_count()
                dtype = after_winsor[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")
        return {
            "nan_stats": nan_stats,
            "after_winsor": after_winsor,
        }

    def test_step_7_after_standard_scaler(self, base_data):
        """Step 7: check missing values after StandardScaler."""
        print("\n" + "=" * 80)
        print("[步骤7] 检查 StandardScaler 后的 NaN")
        step6_result = self.test_step_6_after_winsorizer(base_data)
        after_winsor = step6_result["after_winsor"]
        feature_cols = base_data["feature_cols"]
        # Pre-check: inspect columns previously observed to be problematic.
        print("\n [预检查] StandardScaler 前,检查关键列...")
        problematic_cols = [
            "GTJA_alpha062",
            "GTJA_alpha073",
            "GTJA_alpha085",
            "GTJA_alpha087",
            "GTJA_alpha092",
            "GTJA_alpha103",
            "GTJA_alpha104",
            "GTJA_alpha117",
            "GTJA_alpha124",
            "GTJA_alpha131",
        ]
        for col in problematic_cols:
            if col in after_winsor.columns:
                null_count = after_winsor[col].null_count()
                dtype = after_winsor[col].dtype
                min_val = after_winsor[col].min()
                max_val = after_winsor[col].max()
                print(
                    f" {col}: null={null_count}, dtype={dtype}, min={min_val}, max={max_val}"
                )
        # Apply StandardScaler.
        scaler = StandardScaler(feature_cols=feature_cols)
        after_scaler = scaler.fit_transform(after_winsor)
        # Inspect the statistics the scaler learned for the key columns.
        print("\n [统计量检查] StandardScaler 学到的统计量...")
        for col in problematic_cols:
            if col in scaler.mean_:
                print(f" {col}: mean={scaler.mean_[col]}, std={scaler.std_[col]}")
            else:
                print(f" {col}: [未学到统计量]")
        nan_stats = self._check_nan_in_df(
            after_scaler, feature_cols, "StandardScaler后"
        )
        # Report any columns that still contain missing values.
        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] StandardScaler 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_scaler[col].null_count()
                dtype = after_scaler[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")
                # Did the scaler even learn statistics for this column?
                if col in scaler.mean_:
                    print(
                        f" mean={scaler.mean_[col]:.4f}, std={scaler.std_[col]:.4f}"
                    )
                else:
                    print(f" [警告] 未学到统计量!")
        return {
            "nan_stats": nan_stats,
            "after_scaler": after_scaler,
            "scaler": scaler,
        }

    def test_step_8_extract_X(self, base_data):
        """Step 8: check NaN after extracting X and converting to numpy."""
        print("\n" + "=" * 80)
        print("[步骤8] 检查提取 X 后的 NaN")
        step7_result = self.test_step_7_after_standard_scaler(base_data)
        after_scaler = step7_result["after_scaler"]
        feature_cols = base_data["feature_cols"]
        # Extract the feature matrix.
        X_df = after_scaler.select(feature_cols)
        print(f" X DataFrame 形状: {X_df.shape}")
        # Sanity-check: select() must not change per-column null counts.
        print("\n [对比] DataFrame vs select 后的 null 数量:")
        mismatched = []
        for col in feature_cols[:20]:  # only the first 20, as a spot check
            null_in_df = after_scaler[col].null_count()
            null_in_x = X_df[col].null_count()
            if null_in_df != null_in_x:
                mismatched.append((col, null_in_df, null_in_x))
        if mismatched:
            print(f" [警告] 发现 {len(mismatched)} 列不匹配:")
            for col, df_null, x_null in mismatched[:10]:
                print(f" {col}: DataFrame={df_null}, X={x_null}")
        else:
            print(f" [通过] 所有列的 null 数量一致")
        # Convert to numpy; polars nulls become NaN here.
        X_np = X_df.to_numpy()
        print(f"\n X numpy 形状: {X_np.shape}")
        nan_count = np.isnan(X_np).sum()
        print(f" X 中 NaN 总数: {nan_count}")
        if nan_count > 0:
            # Identify which feature columns carry the NaN.
            nan_by_col = []
            for i, col in enumerate(feature_cols):
                col_nan = np.isnan(X_np[:, i]).sum()
                if col_nan > 0:
                    nan_by_col.append((col, col_nan))
            print(f"\n [错误] 含 NaN 的特征列数: {len(nan_by_col)}")
            for col, count in nan_by_col[:10]:
                # Cross-check against the source DataFrame.
                df_null = after_scaler[col].null_count()
                dtype = after_scaler[col].dtype
                # Check for +/- infinity as a possible source of NaN after scaling.
                inf_count_pos = (after_scaler[col] == float("inf")).sum()
                inf_count_neg = (after_scaler[col] == float("-inf")).sum()
                # Check the value range.
                col_min = after_scaler[col].min()
                col_max = after_scaler[col].max()
                print(
                    f" {col}: numpy中{count}个NaN, DataFrame中{df_null}个null, dtype={dtype}"
                )
                print(f" min={col_min}, max={col_max}")
                print(f" +inf={inf_count_pos}, -inf={inf_count_neg}")
                # Show sample values if infinities are present.
                if inf_count_pos > 0 or inf_count_neg > 0:
                    sample_vals = after_scaler[col].drop_nulls().tail(5).to_list()
                    print(f" 样本值: {sample_vals}")
            # Fail with a detailed message.
            assert False, f"X 中含 {nan_count} 个 NaN涉及 {len(nan_by_col)} 个特征列"
        else:
            print("\n [通过] X 中无 NaN")

    def _check_nan_in_df(
        self, df: pl.DataFrame, feature_cols: list, step_name: str
    ) -> dict:
        """Count missing values (null AND float NaN) per feature column.

        polars' ``null_count()`` counts only nulls — float NaN is a distinct
        value that it misses, while ``np.isnan`` on the numpy export (step 8)
        catches it. Counting both here keeps every step's metric consistent
        with the final numpy check and avoids the "0 null in DataFrame but
        NaN in numpy" confusion this suite exists to diagnose.

        Returns:
            dict: {
                'total_nan': total missing count,
                'cols_with_nan': list of column names with missing values,
                'nan_by_col': {column name: missing count}
            }
        """
        nan_by_col = {}
        total_nan = 0
        for col in feature_cols:
            series = df[col]
            missing = series.null_count()
            # Float columns may also carry NaN, which null_count() ignores.
            if series.dtype in (pl.Float32, pl.Float64):
                missing += int(series.is_nan().sum() or 0)
            if missing > 0:
                nan_by_col[col] = missing
                total_nan += missing
        cols_with_nan = list(nan_by_col.keys())
        print(f" {step_name}:")
        print(f" 总记录数: {len(df)}")
        print(f" 特征列数: {len(feature_cols)}")
        print(f" 总NaN数: {total_nan}")
        print(f" 含NaN的列数: {len(cols_with_nan)}")
        if cols_with_nan and len(cols_with_nan) <= 5:
            print(f" 含NaN的列: {cols_with_nan}")
        elif cols_with_nan:
            print(f" 含NaN的列(前5): {cols_with_nan[:5]}...")
        return {
            "total_nan": total_nan,
            "cols_with_nan": cols_with_nan,
            "nan_by_col": nan_by_col,
        }
if __name__ == "__main__":
    # Allow running this diagnostic file directly; -s keeps the
    # step-by-step print output visible instead of being captured.
    pytest.main([__file__, "-v", "-s"])