refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护
- 修复 data_loader.py 财务数据日期过滤,支持按范围加载
- 优化 MADClipper 使用窗口函数替代 join,提升性能
- 修复训练日期边界问题,添加1天间隔避免数据泄露
- 新增 .gitignore 规则忽略训练输出目录
This commit is contained in:
@@ -30,10 +30,12 @@ def prepare_data(
|
||||
data_dir: str = "data",
|
||||
train_start: str = "20180101",
|
||||
train_end: str = "20230101",
|
||||
test_start: str = "20230101",
|
||||
val_start: str = "20230101",
|
||||
val_end: str = "20230601",
|
||||
test_start: str = "20230601",
|
||||
test_end: str = "20240101",
|
||||
) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||
"""准备训练和测试数据
|
||||
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
|
||||
"""准备训练、验证和测试数据
|
||||
|
||||
从DuckDB加载原始日线数据,计算所需因子并生成标签。
|
||||
|
||||
@@ -41,11 +43,13 @@ def prepare_data(
|
||||
data_dir: 数据目录
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
|
||||
Returns:
|
||||
(train_data, test_data): 训练集和测试集的DataFrame
|
||||
(train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
|
||||
"""
|
||||
from src.data.storage import Storage
|
||||
|
||||
@@ -56,47 +60,56 @@ def prepare_data(
|
||||
lookback_days = 20 # 足够计算MA10和5日收益率
|
||||
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
|
||||
|
||||
# 查询训练集数据
|
||||
# 查询全部数据(包含train、val、test),然后再拆分
|
||||
# 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换
|
||||
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
|
||||
end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
||||
train_query = f"""
|
||||
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||
|
||||
all_query = f"""
|
||||
SELECT ts_code, trade_date, close, pre_close
|
||||
FROM daily
|
||||
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
|
||||
ORDER BY ts_code, trade_date
|
||||
"""
|
||||
train_raw = storage._connection.sql(train_query).pl()
|
||||
all_raw = storage._connection.sql(all_query).pl()
|
||||
# 转换 trade_date 为字符串格式
|
||||
train_raw = train_raw.with_columns(
|
||||
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
||||
)
|
||||
|
||||
# 查询测试集数据(也需要历史数据计算因子)
|
||||
test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
|
||||
test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||
test_query = f"""
|
||||
SELECT ts_code, trade_date, close, pre_close
|
||||
FROM daily
|
||||
WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}'
|
||||
ORDER BY ts_code, trade_date
|
||||
"""
|
||||
test_raw = storage._connection.sql(test_query).pl()
|
||||
# 转换 trade_date 为字符串格式
|
||||
test_raw = test_raw.with_columns(
|
||||
all_raw = all_raw.with_columns(
|
||||
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
||||
)
|
||||
|
||||
# 过滤不符合条件的股票
|
||||
train_raw = _filter_invalid_stocks(train_raw)
|
||||
test_raw = _filter_invalid_stocks(test_raw)
|
||||
print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")
|
||||
all_raw = _filter_invalid_stocks(all_raw)
|
||||
print(f"[PrepareData] After filtering: total={len(all_raw)}")
|
||||
|
||||
# 计算因子和标签
|
||||
train_data = _compute_features_and_label(train_raw, train_start, train_end)
|
||||
test_data = _compute_features_and_label(test_raw, test_start, test_end)
|
||||
# 计算因子和标签(需要全局数据一次性计算)
|
||||
all_data = _compute_features_and_label(
|
||||
all_raw,
|
||||
start_date=train_start,
|
||||
end_date=test_end
|
||||
)
|
||||
|
||||
return train_data, test_data
|
||||
# 转换日期格式用于比较
|
||||
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
|
||||
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
||||
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
|
||||
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
|
||||
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
|
||||
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||
|
||||
# 拆分数据
|
||||
train_data = all_data.filter(
|
||||
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
|
||||
)
|
||||
val_data = all_data.filter(
|
||||
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
|
||||
)
|
||||
test_data = all_data.filter(
|
||||
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
|
||||
)
|
||||
|
||||
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
|
||||
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
||||
@@ -254,6 +267,7 @@ def create_pipeline() -> ProcessingPipeline:
|
||||
|
||||
def train_model(
|
||||
train_data: pl.DataFrame,
|
||||
val_data: Optional[pl.DataFrame],
|
||||
feature_cols: List[str],
|
||||
label_col: str = "label",
|
||||
model_params: Optional[dict] = None,
|
||||
@@ -262,6 +276,7 @@ def train_model(
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
val_data: 验证数据(用于早停)
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
model_params: 模型参数字典
|
||||
@@ -273,21 +288,39 @@ def train_model(
|
||||
pipeline = create_pipeline()
|
||||
print("[TrainModel] Pipeline created: FillNA(0)")
|
||||
|
||||
# 准备特征和标签
|
||||
# 准备训练特征和标签
|
||||
X_train = train_data.select(feature_cols)
|
||||
y_train = train_data[label_col]
|
||||
print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}")
|
||||
print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
|
||||
|
||||
# 处理数据
|
||||
# 处理训练数据
|
||||
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
|
||||
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
|
||||
|
||||
# 过滤有效标签(排除-1等无效值)
|
||||
# 过滤训练集有效标签(排除-1等无效值)
|
||||
valid_mask = y_train.is_in([0, 1])
|
||||
X_train_processed = X_train_processed.filter(valid_mask)
|
||||
y_train = y_train.filter(valid_mask)
|
||||
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
|
||||
print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
|
||||
print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
|
||||
|
||||
# 准备验证集
|
||||
X_val_processed = None
|
||||
y_val = None
|
||||
if val_data is not None and len(val_data) > 0:
|
||||
X_val = val_data.select(feature_cols)
|
||||
y_val = val_data[label_col]
|
||||
print(f"[TrainModel] Val samples: {len(X_val)}")
|
||||
|
||||
# 处理验证集数据(使用训练集的参数)
|
||||
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
|
||||
|
||||
# 过滤验证集有效标签
|
||||
val_valid_mask = y_val.is_in([0, 1])
|
||||
X_val_processed = X_val_processed.filter(val_valid_mask)
|
||||
y_val = y_val.filter(val_valid_mask)
|
||||
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
|
||||
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
|
||||
|
||||
# 创建模型
|
||||
params = model_params or {
|
||||
@@ -302,9 +335,13 @@ def train_model(
|
||||
params=params,
|
||||
)
|
||||
|
||||
# 训练模型
|
||||
# 训练模型(使用验证集早停)
|
||||
print("[TrainModel] Training LightGBM...")
|
||||
model.fit(X_train_processed, y_train)
|
||||
if X_val_processed is not None and y_val is not None:
|
||||
print("[TrainModel] Using validation set for early stopping")
|
||||
model.fit(X_train_processed, y_train, X_val_processed, y_val)
|
||||
else:
|
||||
model.fit(X_train_processed, y_train)
|
||||
print("[TrainModel] Training completed!")
|
||||
|
||||
return model, pipeline
|
||||
@@ -382,7 +419,9 @@ def run_training(
|
||||
output_path: str = "output/top_stocks.tsv",
|
||||
train_start: str = "20180101",
|
||||
train_end: str = "20230101",
|
||||
test_start: str = "20230101",
|
||||
val_start: str = "20230101",
|
||||
val_end: str = "20230601",
|
||||
test_start: str = "20230601",
|
||||
test_end: str = "20240101",
|
||||
top_n: int = 5,
|
||||
) -> pl.DataFrame:
|
||||
@@ -393,6 +432,8 @@ def run_training(
|
||||
output_path: 输出文件路径
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
top_n: 每日选股数量
|
||||
@@ -402,18 +443,22 @@ def run_training(
|
||||
"""
|
||||
print(f"[Training] Starting training pipeline...")
|
||||
print(f"[Training] Train period: {train_start} -> {train_end}")
|
||||
print(f"[Training] Val period: {val_start} -> {val_end}")
|
||||
print(f"[Training] Test period: {test_start} -> {test_end}")
|
||||
|
||||
# 1. 准备数据
|
||||
print("[Training] Preparing data...")
|
||||
train_data, test_data = prepare_data(
|
||||
train_data, val_data, test_data = prepare_data(
|
||||
data_dir=data_dir,
|
||||
train_start=train_start,
|
||||
train_end=train_end,
|
||||
val_start=val_start,
|
||||
val_end=val_end,
|
||||
test_start=test_start,
|
||||
test_end=test_end,
|
||||
)
|
||||
print(f"[Training] Train samples: {len(train_data)}")
|
||||
print(f"[Training] Val samples: {len(val_data)}")
|
||||
print(f"[Training] Test samples: {len(test_data)}")
|
||||
|
||||
# 2. 定义特征列
|
||||
@@ -424,6 +469,7 @@ def run_training(
|
||||
print("[Training] Training model...")
|
||||
model, pipeline = train_model(
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
feature_cols=feature_cols,
|
||||
label_col=label_col,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user