From ad8ba8f6ec15967883276e465502b85f4fe84614 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sun, 5 Apr 2026 23:24:22 +0800 Subject: [PATCH] =?UTF-8?q?fix(training):=20=E4=BF=9D=E7=95=99=20test=20?= =?UTF-8?q?=E9=9B=86=E4=B8=AD=E6=A0=87=E7=AD=BE=E4=B8=BA=20NaN=20=E7=9A=84?= =?UTF-8?q?=E6=A0=B7=E6=9C=AC=E7=94=A8=E4=BA=8E=E9=A2=84=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/training/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/pipeline.py b/src/training/pipeline.py index cb23715..ec10e67 100644 --- a/src/training/pipeline.py +++ b/src/training/pipeline.py @@ -374,8 +374,8 @@ class DataPipeline: split_data[split_name]["X"] = split_df.select(feature_cols) split_data[split_name]["y"] = split_df[label_name] - # 删除标签为 NaN 的行 - for split_name in ["train", "val", "test"]: + # 删除标签为 NaN 的行(仅在 train/val 上执行,test 集保留用于预测) + for split_name in ["train", "val"]: if split_name in split_data: y_series = split_data[split_name]["y"] y_nan_count = y_series.null_count()