refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护

- 修复 data_loader.py 财务数据日期过滤,支持按范围加载
- 优化 MADClipper 使用窗口函数替代 join,提升性能
- 修复训练日期边界问题,添加1天间隔避免数据泄露
- 新增 .gitignore 规则忽略训练输出目录
This commit is contained in:
2026-02-25 21:11:19 +08:00
parent 593ec99466
commit a9e4746239
24 changed files with 3597 additions and 56 deletions

View File

@@ -30,10 +30,12 @@ def prepare_data(
data_dir: str = "data",
train_start: str = "20180101",
train_end: str = "20230101",
test_start: str = "20230101",
val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""准备训练和测试数据
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
"""准备训练、验证和测试数据
从DuckDB加载原始日线数据,计算所需因子并生成标签。
@@ -41,11 +43,13 @@ def prepare_data(
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
Returns:
(train_data, test_data): 训练集和测试集的DataFrame
(train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
"""
from src.data.storage import Storage
@@ -56,47 +60,56 @@ def prepare_data(
lookback_days = 20 # 足够计算MA10和5日收益率
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
# 查询训练集数据
# 查询全部数据(包含train、val、test),然后再拆分
# 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
train_query = f"""
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
all_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
ORDER BY ts_code, trade_date
"""
train_raw = storage._connection.sql(train_query).pl()
all_raw = storage._connection.sql(all_query).pl()
# 转换 trade_date 为字符串格式
train_raw = train_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 查询测试集数据(也需要历史数据计算因子)
test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
test_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}'
ORDER BY ts_code, trade_date
"""
test_raw = storage._connection.sql(test_query).pl()
# 转换 trade_date 为字符串格式
test_raw = test_raw.with_columns(
all_raw = all_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 过滤不符合条件的股票
train_raw = _filter_invalid_stocks(train_raw)
test_raw = _filter_invalid_stocks(test_raw)
print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")
all_raw = _filter_invalid_stocks(all_raw)
print(f"[PrepareData] After filtering: total={len(all_raw)}")
# 计算因子和标签
train_data = _compute_features_and_label(train_raw, train_start, train_end)
test_data = _compute_features_and_label(test_raw, test_start, test_end)
# 计算因子和标签(需要全局数据一次性计算)
all_data = _compute_features_and_label(
all_raw,
start_date=train_start,
end_date=test_end
)
return train_data, test_data
# 转换日期格式用于比较
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
# 拆分数据
train_data = all_data.filter(
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
)
val_data = all_data.filter(
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
)
test_data = all_data.filter(
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
)
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
return train_data, val_data, test_data
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
@@ -254,6 +267,7 @@ def create_pipeline() -> ProcessingPipeline:
def train_model(
train_data: pl.DataFrame,
val_data: Optional[pl.DataFrame],
feature_cols: List[str],
label_col: str = "label",
model_params: Optional[dict] = None,
@@ -262,6 +276,7 @@ def train_model(
Args:
train_data: 训练数据
val_data: 验证数据(用于早停)
feature_cols: 特征列名列表
label_col: 标签列名
model_params: 模型参数字典
@@ -273,21 +288,39 @@ def train_model(
pipeline = create_pipeline()
print("[TrainModel] Pipeline created: FillNA(0)")
# 准备特征和标签
# 准备训练特征和标签
X_train = train_data.select(feature_cols)
y_train = train_data[label_col]
print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}")
print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
# 处理数据
# 处理训练数据
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
# 过滤有效标签(排除-1等无效值)
# 过滤训练集有效标签(排除-1等无效值)
valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask)
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
# 准备验证集
X_val_processed = None
y_val = None
if val_data is not None and len(val_data) > 0:
X_val = val_data.select(feature_cols)
y_val = val_data[label_col]
print(f"[TrainModel] Val samples: {len(X_val)}")
# 处理验证集数据(使用训练集的参数)
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
# 过滤验证集有效标签
val_valid_mask = y_val.is_in([0, 1])
X_val_processed = X_val_processed.filter(val_valid_mask)
y_val = y_val.filter(val_valid_mask)
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
# 创建模型
params = model_params or {
@@ -302,9 +335,13 @@ def train_model(
params=params,
)
# 训练模型
# 训练模型(使用验证集早停)
print("[TrainModel] Training LightGBM...")
model.fit(X_train_processed, y_train)
if X_val_processed is not None and y_val is not None:
print("[TrainModel] Using validation set for early stopping")
model.fit(X_train_processed, y_train, X_val_processed, y_val)
else:
model.fit(X_train_processed, y_train)
print("[TrainModel] Training completed!")
return model, pipeline
@@ -382,7 +419,9 @@ def run_training(
output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101",
train_end: str = "20230101",
test_start: str = "20230101",
val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
top_n: int = 5,
) -> pl.DataFrame:
@@ -393,6 +432,8 @@ def run_training(
output_path: 输出文件路径
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
top_n: 每日选股数量
@@ -402,18 +443,22 @@ def run_training(
"""
print(f"[Training] Starting training pipeline...")
print(f"[Training] Train period: {train_start} -> {train_end}")
print(f"[Training] Val period: {val_start} -> {val_end}")
print(f"[Training] Test period: {test_start} -> {test_end}")
# 1. 准备数据
print("[Training] Preparing data...")
train_data, test_data = prepare_data(
train_data, val_data, test_data = prepare_data(
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
val_start=val_start,
val_end=val_end,
test_start=test_start,
test_end=test_end,
)
print(f"[Training] Train samples: {len(train_data)}")
print(f"[Training] Val samples: {len(val_data)}")
print(f"[Training] Test samples: {len(test_data)}")
# 2. 定义特征列
@@ -424,6 +469,7 @@ def run_training(
print("[Training] Training model...")
model, pipeline = train_model(
train_data=train_data,
val_data=val_data,
feature_cols=feature_cols,
label_col=label_col,
)