feat(training): 新增 TabM 模型支持及数据质量优化
- 添加 TabMModel、TabPFNModel 深度学习模型实现 - 新增 DataQualityAnalyzer 进行训练前数据质量诊断 - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性 - 支持 train_skip_days 参数跳过训练初期数据不足期 - Pipeline 自动清理标签为 NaN 的样本
This commit is contained in:
492
tests/test_nan_step_by_step.py
Normal file
492
tests/test_nan_step_by_step.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""NaN 问题逐步诊断测试 - 精确定位问题环节
|
||||
|
||||
此测试会逐步检查 DataPipeline 的每个处理步骤,精确定位 NaN 产生的位置。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
FactorManager,
|
||||
DataPipeline,
|
||||
NullFiller,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
)
|
||||
from src.training.components.filters import STFilter
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
LABEL_NAME,
|
||||
LABEL_FACTOR,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
)
|
||||
|
||||
# Test configuration.
# Exclude GTJA_alpha001..GTJA_alpha049 to keep the diagnostic run fast.
EXCLUDED_FACTORS = ["GTJA_alpha{:03d}".format(n) for n in range(1, 50)]

# Inclusive yyyymmdd date windows for each data split.
TEST_DATE_RANGE = {
    "train": ("20200101", "20201231"),  # one full year of training data
    "val": ("20210101", "20210331"),
    "test": ("20210401", "20210630"),
}
||||
class TestNaNStepByStep:
    """Step-by-step diagnosis of where NaN values enter the data pipeline.

    Each ``test_step_N`` method replays the pipeline up to stage N
    (raw data -> STFilter -> stock pool -> train split -> skip warm-up ->
    NullFiller -> Winsorizer -> StandardScaler -> numpy extraction) and
    reports the missing-value situation after that stage, so the first
    stage that introduces NaN can be pinpointed.

    The actual work lives in private ``_step_N`` helpers that return their
    intermediate results so later steps can chain on them; the ``test_*``
    wrappers return ``None``, avoiding pytest's return-not-None warning and
    keeping test methods from calling other test methods for their values.
    """

    @pytest.fixture(scope="class")
    def base_data(self):
        """Compute the raw factor frame once and share it across all steps.

        Returns:
            dict: the factor engine, the factor manager, the feature column
            names, and the raw (unprocessed) DataFrame.
        """
        print("\n" + "=" * 80)
        print("[Fixture] 准备基础数据...")

        engine = FactorEngine()
        factor_manager = FactorManager(
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            label_factor=LABEL_FACTOR,
            excluded_factors=EXCLUDED_FACTORS,
        )

        # Register the selected factors with the engine; the returned list
        # is the feature-column set used throughout the diagnosis.
        feature_cols = factor_manager.register_to_engine(engine, verbose=False)
        print(f" 特征数: {len(feature_cols)}")

        # Overall date range spanning the train/val/test windows.
        all_start = min(
            TEST_DATE_RANGE["train"][0],
            TEST_DATE_RANGE["val"][0],
            TEST_DATE_RANGE["test"][0],
        )
        all_end = max(
            TEST_DATE_RANGE["train"][1],
            TEST_DATE_RANGE["val"][1],
            TEST_DATE_RANGE["test"][1],
        )

        # Compute all features plus the label in a single pass.
        raw_data = engine.compute(
            factor_names=feature_cols + [LABEL_NAME],
            start_date=all_start,
            end_date=all_end,
        )
        print(f" 原始数据形状: {raw_data.shape}")

        return {
            "engine": engine,
            "factor_manager": factor_manager,
            "feature_cols": feature_cols,
            "raw_data": raw_data,
        }

    # ------------------------------------------------------------------
    # shared pipeline pieces (deduplicated from steps 1-3)
    # ------------------------------------------------------------------

    def _apply_st_filter(self, base_data) -> pl.DataFrame:
        """Apply the ST-stock filter to the raw data."""
        st_filter = STFilter(data_router=base_data["engine"].router)
        return st_filter.filter(base_data["raw_data"])

    def _apply_stock_pool(self, base_data, filtered_data: pl.DataFrame) -> pl.DataFrame:
        """Apply the daily stock-pool selection on top of ST-filtered data."""
        # Imported lazily, mirroring the original test, to avoid a hard
        # module-level dependency on the pool manager.
        from src.training.core.stock_pool_manager import StockPoolManager

        pool_manager = StockPoolManager(
            filter_func=stock_pool_filter,
            required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
            data_router=base_data["engine"].router,
        )
        return pool_manager.filter_and_select_daily(filtered_data)

    # ------------------------------------------------------------------
    # diagnostic steps (each returns its intermediate results)
    # ------------------------------------------------------------------

    def _step_0(self, base_data) -> dict:
        """Step 0: missing-value situation in the raw, unprocessed data."""
        print("\n" + "=" * 80)
        print("[步骤0] 检查原始数据中的 NaN")

        nan_stats = self._check_nan_in_df(
            base_data["raw_data"], base_data["feature_cols"], "原始数据"
        )

        print(f" 含 NaN 的特征列数: {len(nan_stats['cols_with_nan'])}")
        if nan_stats["cols_with_nan"]:
            print(f" 示例: {nan_stats['cols_with_nan'][:5]}")

        return nan_stats

    def _step_1(self, base_data) -> dict:
        """Step 1: missing-value situation after STFilter."""
        print("\n" + "=" * 80)
        print("[步骤1] 检查 STFilter 后的 NaN")

        raw_data = base_data["raw_data"]
        filtered_data = self._apply_st_filter(base_data)

        print(f" 过滤后数据形状: {filtered_data.shape}")
        print(f" 删除记录数: {len(raw_data) - len(filtered_data)}")

        nan_stats = self._check_nan_in_df(
            filtered_data, base_data["feature_cols"], "STFilter后"
        )

        # Compare against step 0 to spot NaN added (or removed) by the filter.
        step0_nan = self._step_0(base_data)
        if nan_stats["total_nan"] != step0_nan["total_nan"]:
            print(
                f" [警告] NaN 数量变化: {step0_nan['total_nan']} -> {nan_stats['total_nan']}"
            )

        return nan_stats

    def _step_2(self, base_data) -> dict:
        """Step 2: missing-value situation after the stock-pool selection."""
        print("\n" + "=" * 80)
        print("[步骤2] 检查股票池筛选后的 NaN")

        filtered_data = self._apply_st_filter(base_data)
        pool_data = self._apply_stock_pool(base_data, filtered_data)

        print(f" 筛选后数据形状: {pool_data.shape}")
        print(f" 删除记录数: {len(filtered_data) - len(pool_data)}")

        return self._check_nan_in_df(
            pool_data, base_data["feature_cols"], "股票池筛选后"
        )

    def _step_3(self, base_data) -> dict:
        """Step 3: missing values after carving out the training split (no skip).

        Returns:
            dict with ``nan_stats``, the training ``train_df`` and its sorted
            ``unique_dates`` (consumed by step 4).
        """
        print("\n" + "=" * 80)
        print("[步骤3] 检查训练集划分后的 NaN(不跳过天数)")

        pool_data = self._apply_stock_pool(base_data, self._apply_st_filter(base_data))

        # Slice the training window (inclusive comparison on yyyymmdd strings).
        train_start, train_end = TEST_DATE_RANGE["train"]
        train_mask = (pool_data["trade_date"] >= train_start) & (
            pool_data["trade_date"] <= train_end
        )
        train_df = pool_data.filter(train_mask)

        print(f" 训练集形状: {train_df.shape}")

        # Count the distinct trading days in the split.
        unique_dates = train_df["trade_date"].unique().sort()
        print(f" 训练集交易日数量: {len(unique_dates)}")
        print(f" 日期范围: {unique_dates[0]} ~ {unique_dates[-1]}")

        nan_stats = self._check_nan_in_df(
            train_df, base_data["feature_cols"], "训练集(不跳过)"
        )

        return {
            "nan_stats": nan_stats,
            "train_df": train_df,
            "unique_dates": unique_dates,
        }

    def _step_4(self, base_data) -> dict:
        """Step 4: missing values after skipping the first 252 training days."""
        print("\n" + "=" * 80)
        print("[步骤4] 检查训练集划分后的 NaN(跳过前252天)")

        step3 = self._step_3(base_data)
        train_df = step3["train_df"]
        unique_dates = step3["unique_dates"]

        # Skip the warm-up period (presumably where long-window rolling
        # factors are not yet computable — matches the train_skip_days idea).
        skip_days = 252
        if len(unique_dates) > skip_days:
            start_date = unique_dates[skip_days]
            train_df_skipped = train_df.filter(pl.col("trade_date") >= start_date)
            print(f" 跳过前{skip_days}天后,从 {start_date} 开始")
            print(f" 跳过后训练集形状: {train_df_skipped.shape}")
            print(f" 跳过记录数: {len(train_df) - len(train_df_skipped)}")
        else:
            train_df_skipped = train_df
            print(
                f" [警告] 训练集交易日数({len(unique_dates)})少于跳过天数({skip_days}),未跳过"
            )

        nan_stats = self._check_nan_in_df(
            train_df_skipped, base_data["feature_cols"], "训练集(跳过252天)"
        )

        return {
            "nan_stats": nan_stats,
            "train_df": train_df_skipped,
        }

    def _step_5(self, base_data) -> dict:
        """Step 5: missing values after NullFiller (per-date mean fill)."""
        print("\n" + "=" * 80)
        print("[步骤5] 检查 NullFiller 后的 NaN")

        feature_cols = base_data["feature_cols"]
        train_df = self._step_4(base_data)["train_df"]

        print(f" 处理前数据形状: {train_df.shape}")

        null_filler = NullFiller(
            feature_cols=feature_cols, strategy="mean", by_date=True
        )
        after_null = null_filler.fit_transform(train_df)

        print(f" 处理后数据形状: {after_null.shape}")

        nan_stats = self._check_nan_in_df(after_null, feature_cols, "NullFiller后")

        # Anything still missing after the filler is suspicious — dump details.
        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] NullFiller 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_null[col].null_count()
                dtype = after_null[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")

        return {
            "nan_stats": nan_stats,
            "after_null": after_null,
        }

    def _step_6(self, base_data) -> dict:
        """Step 6: missing values after Winsorizer (1%/99% global clipping)."""
        print("\n" + "=" * 80)
        print("[步骤6] 检查 Winsorizer 后的 NaN")

        feature_cols = base_data["feature_cols"]
        after_null = self._step_5(base_data)["after_null"]

        winsorizer = Winsorizer(
            feature_cols=feature_cols, lower=0.01, upper=0.99, by_date=False
        )
        after_winsor = winsorizer.fit_transform(after_null)

        nan_stats = self._check_nan_in_df(after_winsor, feature_cols, "Winsorizer后")

        # Winsorizing should never create missing values — report if it did.
        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] Winsorizer 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_winsor[col].null_count()
                dtype = after_winsor[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")

        return {
            "nan_stats": nan_stats,
            "after_winsor": after_winsor,
        }

    def _step_7(self, base_data) -> dict:
        """Step 7: missing values after StandardScaler, with extra probing.

        Probes a fixed list of historically problematic columns before and
        after scaling, and dumps the statistics the scaler learned for them.
        """
        print("\n" + "=" * 80)
        print("[步骤7] 检查 StandardScaler 后的 NaN")

        feature_cols = base_data["feature_cols"]
        after_winsor = self._step_6(base_data)["after_winsor"]

        # Columns previously observed to come out of the scaler with NaN.
        print("\n [预检查] StandardScaler 前,检查关键列...")
        problematic_cols = [
            "GTJA_alpha062",
            "GTJA_alpha073",
            "GTJA_alpha085",
            "GTJA_alpha087",
            "GTJA_alpha092",
            "GTJA_alpha103",
            "GTJA_alpha104",
            "GTJA_alpha117",
            "GTJA_alpha124",
            "GTJA_alpha131",
        ]
        for col in problematic_cols:
            if col in after_winsor.columns:
                null_count = after_winsor[col].null_count()
                dtype = after_winsor[col].dtype
                min_val = after_winsor[col].min()
                max_val = after_winsor[col].max()
                print(
                    f" {col}: null={null_count}, dtype={dtype}, min={min_val}, max={max_val}"
                )

        scaler = StandardScaler(feature_cols=feature_cols)
        after_scaler = scaler.fit_transform(after_winsor)

        # Inspect the statistics the scaler actually learned.
        print("\n [统计量检查] StandardScaler 学到的统计量...")
        for col in problematic_cols:
            if col in scaler.mean_:
                print(f" {col}: mean={scaler.mean_[col]}, std={scaler.std_[col]}")
            else:
                print(f" {col}: [未学到统计量]")

        nan_stats = self._check_nan_in_df(
            after_scaler, feature_cols, "StandardScaler后"
        )

        if nan_stats["cols_with_nan"]:
            print(
                f" [错误] StandardScaler 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:"
            )
            for col in nan_stats["cols_with_nan"][:10]:
                count = after_scaler[col].null_count()
                dtype = after_scaler[col].dtype
                print(f" {col}: {count} 个 NaN, dtype={dtype}")

                # Did the scaler learn statistics for this column at all?
                if col in scaler.mean_:
                    print(
                        f" mean={scaler.mean_[col]:.4f}, std={scaler.std_[col]:.4f}"
                    )
                else:
                    print(f" [警告] 未学到统计量!")

        return {
            "nan_stats": nan_stats,
            "after_scaler": after_scaler,
            "scaler": scaler,
        }

    def _step_8(self, base_data) -> None:
        """Step 8: extract X as numpy and fail loudly if any NaN remains.

        Raises:
            AssertionError: when the extracted feature matrix contains NaN.
        """
        print("\n" + "=" * 80)
        print("[步骤8] 检查提取 X 后的 NaN")

        feature_cols = base_data["feature_cols"]
        after_scaler = self._step_7(base_data)["after_scaler"]

        X_df = after_scaler.select(feature_cols)
        print(f" X DataFrame 形状: {X_df.shape}")

        # Sanity check: select() must not change per-column null counts.
        print("\n [对比] DataFrame vs select 后的 null 数量:")
        mismatched = []
        for col in feature_cols[:20]:  # spot-check the first 20 columns only
            null_in_df = after_scaler[col].null_count()
            null_in_x = X_df[col].null_count()
            if null_in_df != null_in_x:
                mismatched.append((col, null_in_df, null_in_x))

        if mismatched:
            print(f" [警告] 发现 {len(mismatched)} 列不匹配:")
            for col, df_null, x_null in mismatched[:10]:
                print(f" {col}: DataFrame={df_null}, X={x_null}")
        else:
            print(f" [通过] 所有列的 null 数量一致")

        X_np = X_df.to_numpy()
        print(f"\n X numpy 形状: {X_np.shape}")

        nan_count = np.isnan(X_np).sum()
        print(f" X 中 NaN 总数: {nan_count}")

        if nan_count > 0:
            # Attribute the NaNs to columns and dump their DataFrame-side state.
            nan_by_col = []
            for i, col in enumerate(feature_cols):
                col_nan = np.isnan(X_np[:, i]).sum()
                if col_nan > 0:
                    nan_by_col.append((col, col_nan))

            print(f"\n [错误] 含 NaN 的特征列数: {len(nan_by_col)}")
            for col, count in nan_by_col[:10]:
                df_null = after_scaler[col].null_count()
                dtype = after_scaler[col].dtype

                # Check for +/-inf as a possible source of the numpy NaNs.
                inf_count_pos = (after_scaler[col] == float("inf")).sum()
                inf_count_neg = (after_scaler[col] == float("-inf")).sum()

                col_min = after_scaler[col].min()
                col_max = after_scaler[col].max()

                print(
                    f" {col}: numpy中{count}个NaN, DataFrame中{df_null}个null, dtype={dtype}"
                )
                print(f" min={col_min}, max={col_max}")
                print(f" +inf={inf_count_pos}, -inf={inf_count_neg}")

                # Show a few sample values when infinities are present.
                if inf_count_pos > 0 or inf_count_neg > 0:
                    sample_vals = after_scaler[col].drop_nulls().tail(5).to_list()
                    print(f" 样本值: {sample_vals}")

            # Fail with a summary so the first offending stage is obvious.
            assert False, f"X 中含 {nan_count} 个 NaN,涉及 {len(nan_by_col)} 个特征列"
        else:
            print("\n [通过] X 中无 NaN!")

    # ------------------------------------------------------------------
    # pytest entry points (return None; same names/signatures as before)
    # ------------------------------------------------------------------

    def test_step_0_raw_data(self, base_data):
        """Step 0: check NaN in the raw data."""
        self._step_0(base_data)

    def test_step_1_after_st_filter(self, base_data):
        """Step 1: check NaN after STFilter."""
        self._step_1(base_data)

    def test_step_2_after_stock_pool(self, base_data):
        """Step 2: check NaN after the stock-pool selection."""
        self._step_2(base_data)

    def test_step_3_train_split_without_skip(self, base_data):
        """Step 3: check NaN after the training split (no days skipped)."""
        self._step_3(base_data)

    def test_step_4_train_split_with_skip(self, base_data):
        """Step 4: check NaN after skipping the first 252 training days."""
        self._step_4(base_data)

    def test_step_5_after_null_filler(self, base_data):
        """Step 5: check NaN after NullFiller."""
        self._step_5(base_data)

    def test_step_6_after_winsorizer(self, base_data):
        """Step 6: check NaN after Winsorizer."""
        self._step_6(base_data)

    def test_step_7_after_standard_scaler(self, base_data):
        """Step 7: check NaN after StandardScaler."""
        self._step_7(base_data)

    def test_step_8_extract_X(self, base_data):
        """Step 8: check NaN after extracting X to numpy (asserts none remain)."""
        self._step_8(base_data)

    # ------------------------------------------------------------------
    # shared reporting helper
    # ------------------------------------------------------------------

    def _check_nan_in_df(
        self, df: pl.DataFrame, feature_cols: list, step_name: str
    ) -> dict:
        """Report missing values per feature column.

        Counts polars nulls and, for float columns, genuine NaN values as
        well — polars keeps null and NaN distinct, so a NaN-carrying column
        would otherwise look clean here yet still surface as NaN after
        ``to_numpy()`` (exactly the mismatch step 8 hunts for).

        Returns:
            dict: {
                'total_nan': total missing count,
                'cols_with_nan': names of columns with missing values,
                'nan_by_col': {column name: missing count},
            }
        """
        nan_by_col = {}
        total_nan = 0

        for col in feature_cols:
            series = df[col]
            missing = series.null_count()
            if series.dtype in (pl.Float32, pl.Float64):
                # is_nan() yields null for null entries; sum() skips those,
                # so null and NaN are never double-counted.
                missing += int(series.is_nan().sum() or 0)
            if missing > 0:
                nan_by_col[col] = missing
                total_nan += missing

        cols_with_nan = list(nan_by_col.keys())

        print(f" {step_name}:")
        print(f" 总记录数: {len(df)}")
        print(f" 特征列数: {len(feature_cols)}")
        print(f" 总NaN数: {total_nan}")
        print(f" 含NaN的列数: {len(cols_with_nan)}")

        if cols_with_nan and len(cols_with_nan) <= 5:
            print(f" 含NaN的列: {cols_with_nan}")
        elif cols_with_nan:
            print(f" 含NaN的列(前5): {cols_with_nan[:5]}...")

        return {
            "total_nan": total_nan,
            "cols_with_nan": cols_with_nan,
            "nan_by_col": nan_by_col,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this diagnostic directly: verbose, with prints shown.
    cli_args = [__file__, "-v", "-s"]
    pytest.main(cli_args)
|
||||
Reference in New Issue
Block a user