Compare commits
2 Commits
1fa4ff9544
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 161b7cc690 | |||
| ad8ba8f6ec |
@@ -191,9 +191,8 @@ class DataRouter:
|
|||||||
Returns:
|
Returns:
|
||||||
过滤后的 DataFrame
|
过滤后的 DataFrame
|
||||||
"""
|
"""
|
||||||
cache_key = (
|
cols_key = ",".join(sorted(spec.columns)) if spec.columns else "*"
|
||||||
f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}"
|
cache_key = f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}_{cols_key}"
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if cache_key in self._cache:
|
if cache_key in self._cache:
|
||||||
@@ -259,7 +258,8 @@ class DataRouter:
|
|||||||
Returns:
|
Returns:
|
||||||
过滤后的 DataFrame
|
过滤后的 DataFrame
|
||||||
"""
|
"""
|
||||||
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
|
cols_key = ",".join(sorted(columns)) if columns else "*"
|
||||||
|
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}_{cols_key}"
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if cache_key in self._cache:
|
if cache_key in self._cache:
|
||||||
|
|||||||
@@ -374,8 +374,8 @@ class DataPipeline:
|
|||||||
split_data[split_name]["X"] = split_df.select(feature_cols)
|
split_data[split_name]["X"] = split_df.select(feature_cols)
|
||||||
split_data[split_name]["y"] = split_df[label_name]
|
split_data[split_name]["y"] = split_df[label_name]
|
||||||
|
|
||||||
# 删除标签为 NaN 的行
|
# 删除标签为 NaN 的行(仅在 train/val 上执行,test 集保留用于预测)
|
||||||
for split_name in ["train", "val", "test"]:
|
for split_name in ["train", "val"]:
|
||||||
if split_name in split_data:
|
if split_name in split_data:
|
||||||
y_series = split_data[split_name]["y"]
|
y_series = split_data[split_name]["y"]
|
||||||
y_nan_count = y_series.null_count()
|
y_nan_count = y_series.null_count()
|
||||||
|
|||||||
Reference in New Issue
Block a user