diff --git a/src/training/pipeline.py b/src/training/pipeline.py index cb23715..ec10e67 100644 --- a/src/training/pipeline.py +++ b/src/training/pipeline.py @@ -374,8 +374,8 @@ class DataPipeline: split_data[split_name]["X"] = split_df.select(feature_cols) split_data[split_name]["y"] = split_df[label_name] - # 删除标签为 NaN 的行 - for split_name in ["train", "val", "test"]: + # 删除标签为 NaN 的行(仅在 train/val 上执行,test 集保留用于预测) + for split_name in ["train", "val"]: if split_name in split_data: y_series = split_data[split_name]["y"] y_nan_count = y_series.null_count()