Compare commits

...

2 Commits

2 changed files with 6 additions and 6 deletions

View File

@@ -191,9 +191,8 @@ class DataRouter:
Returns: Returns:
过滤后的 DataFrame 过滤后的 DataFrame
""" """
cache_key = ( cols_key = ",".join(sorted(spec.columns)) if spec.columns else "*"
f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}" cache_key = f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}_{cols_key}"
)
with self._lock: with self._lock:
if cache_key in self._cache: if cache_key in self._cache:
@@ -259,7 +258,8 @@ class DataRouter:
Returns: Returns:
过滤后的 DataFrame 过滤后的 DataFrame
""" """
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}" cols_key = ",".join(sorted(columns)) if columns else "*"
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}_{cols_key}"
with self._lock: with self._lock:
if cache_key in self._cache: if cache_key in self._cache:

View File

@@ -374,8 +374,8 @@ class DataPipeline:
split_data[split_name]["X"] = split_df.select(feature_cols) split_data[split_name]["X"] = split_df.select(feature_cols)
split_data[split_name]["y"] = split_df[label_name] split_data[split_name]["y"] = split_df[label_name]
# 删除标签为 NaN 的行 # 删除标签为 NaN 的行(仅在 train/val 上执行test 集保留用于预测)
for split_name in ["train", "val", "test"]: for split_name in ["train", "val"]:
if split_name in split_data: if split_name in split_data:
y_series = split_data[split_name]["y"] y_series = split_data[split_name]["y"]
y_nan_count = y_series.null_count() y_nan_count = y_series.null_count()