feat(training): 实现 train/val/test 三分法并添加训练指标可视化

- DateSplitter 支持三分法划分,修复 test 数据泄露问题
- 添加训练指标曲线绘制和100轮早停
This commit is contained in:
2026-03-08 01:09:47 +08:00
parent 85044a74c6
commit 592126c376
2 changed files with 551 additions and 226 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,30 +1,39 @@
"""数据划分器 """数据划分器
提供基于日期范围的数据划分功能,支持一次性训练/测试划分 提供基于日期范围的数据划分功能,支持 train/val/test 三分法
""" """
from typing import Tuple from typing import Tuple, Optional
import polars as pl import polars as pl
class DateSplitter: class DateSplitter:
"""基于日期范围的一次性划分 """基于日期范围的一次性划分(支持 train/val/test 三分法)
将数据按日期划分为训练集和测试集,不滚动。 将数据按日期划分为训练集、验证集和测试集,不滚动。
正确的三分法:
- Train: 用于训练模型参数
- Val: 用于验证/早停/调参(从训练时间后切出)
- Test: 仅用于最终评估,完全独立于训练过程
示例: 示例:
train_start: "20200101", train_end: "20221231" (训练集:3年) train_start: "20200101", train_end: "20211231" (训练集:2年)
val_start: "20220101", val_end: "20221231" (验证集1年)
test_start: "20230101", test_end: "20231231" (测试集1年) test_start: "20230101", test_end: "20231231" (测试集1年)
特点: 特点:
- 一次性划分,不滚动 - 一次性划分,不滚动
- 训练集测试集互不重叠 - 训练集、验证集、测试集三者互不重叠
- 验证集和测试集按时间顺序位于训练集之后
- 基于实际日期范围,而非行数 - 基于实际日期范围,而非行数
Attributes: Attributes:
train_start: 训练期开始日期,格式 "YYYYMMDD" train_start: 训练期开始日期,格式 "YYYYMMDD"
train_end: 训练期结束日期,格式 "YYYYMMDD" train_end: 训练期结束日期,格式 "YYYYMMDD"
val_start: 验证期开始日期,格式 "YYYYMMDD"(可选)
val_end: 验证期结束日期,格式 "YYYYMMDD"(可选)
test_start: 测试期开始日期,格式 "YYYYMMDD" test_start: 测试期开始日期,格式 "YYYYMMDD"
test_end: 测试期结束日期,格式 "YYYYMMDD" test_end: 测试期结束日期,格式 "YYYYMMDD"
""" """
@@ -35,6 +44,8 @@ class DateSplitter:
train_end: str, train_end: str,
test_start: str, test_start: str,
test_end: str, test_end: str,
val_start: Optional[str] = None,
val_end: Optional[str] = None,
): ):
"""初始化日期划分器 """初始化日期划分器
@@ -43,17 +54,31 @@ class DateSplitter:
train_end: 训练期结束日期 "YYYYMMDD" train_end: 训练期结束日期 "YYYYMMDD"
test_start: 测试期开始日期 "YYYYMMDD" test_start: 测试期开始日期 "YYYYMMDD"
test_end: 测试期结束日期 "YYYYMMDD" test_end: 测试期结束日期 "YYYYMMDD"
val_start: 验证期开始日期 "YYYYMMDD"(可选,如果不提供则从 train 中划分)
val_end: 验证期结束日期 "YYYYMMDD"(可选,如果不提供则从 train 中划分)
Raises: Raises:
ValueError: 日期格式错误或日期范围无效 ValueError: 日期格式错误或日期范围无效
Note:
正确的三分法:
- Train: 用于训练模型参数
- Val: 用于验证/早停/调参(必须位于 train 之后、test 之前)
- Test: 仅用于最终评估,完全独立于训练过程
""" """
# 验证日期格式(简单的长度检查) # 验证日期格式(简单的长度检查)
for name, value in [ dates_to_check = [
("train_start", train_start), ("train_start", train_start),
("train_end", train_end), ("train_end", train_end),
("test_start", test_start), ("test_start", test_start),
("test_end", test_end), ("test_end", test_end),
]: ]
if val_start is not None:
dates_to_check.append(("val_start", val_start))
if val_end is not None:
dates_to_check.append(("val_end", val_end))
for name, value in dates_to_check:
if not isinstance(value, str) or len(value) != 8: if not isinstance(value, str) or len(value) != 8:
raise ValueError( raise ValueError(
f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串得到: {value}" f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串得到: {value}"
@@ -68,31 +93,83 @@ class DateSplitter:
raise ValueError( raise ValueError(
f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})" f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})"
) )
if test_start <= train_end:
# 验证 val 日期(如果提供了)
if val_start is not None and val_end is not None:
if val_start > val_end:
raise ValueError( raise ValueError(
f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end})" f"val_start ({val_start}) 必须早于或等于 val_end ({val_end})"
"以确保训练集和测试集不重叠"
) )
if val_start <= train_end:
raise ValueError(
f"验证集开始日期 ({val_start}) 必须晚于训练集结束日期 ({train_end})"
"以确保验证集在训练集之后"
)
if test_start <= val_end:
raise ValueError(
f"测试集开始日期 ({test_start}) 必须晚于验证集结束日期 ({val_end})"
"以确保测试集在验证集之后"
)
elif val_start is not None or val_end is not None:
raise ValueError("val_start 和 val_end 必须同时提供或同时省略")
# 如果没有提供 val 日期,自动从 train 后划分一段作为 val
# 默认取 train 结束后的 20% 时间作为 val但必须确保在 test 之前
if val_start is None:
# 计算 train 时间跨度(天数近似)
from datetime import datetime
train_start_dt = datetime.strptime(train_start, "%Y%m%d")
train_end_dt = datetime.strptime(train_end, "%Y%m%d")
test_start_dt = datetime.strptime(test_start, "%Y%m%d")
train_days = (train_end_dt - train_start_dt).days
val_duration = max(int(train_days * 0.2), 30) # 至少30天
val_start_dt = train_end_dt + __import__("datetime").timedelta(days=1)
val_end_dt = val_start_dt + __import__("datetime").timedelta(
days=val_duration
)
# 确保 val 在 test 之前
if val_end_dt >= test_start_dt:
# 取 train 和 test 之间的中点
gap_days = (test_start_dt - train_end_dt).days
val_end_dt = train_end_dt + __import__("datetime").timedelta(
days=gap_days // 2
)
val_start_dt = train_end_dt + __import__("datetime").timedelta(days=1)
val_start = val_start_dt.strftime("%Y%m%d")
val_end = min(val_end_dt.strftime("%Y%m%d"), test_start)
self.train_start = train_start self.train_start = train_start
self.train_end = train_end self.train_end = train_end
self.val_start = val_start
self.val_end = val_end
self.test_start = test_start self.test_start = test_start
self.test_end = test_end self.test_end = test_end
def split( def split(
self, data: pl.DataFrame, date_col: str = "trade_date" self, data: pl.DataFrame, date_col: str = "trade_date"
) -> Tuple[pl.DataFrame, pl.DataFrame]: ) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
"""划分数据为训练集和测试集 """划分数据为训练集、验证集和测试集
Args: Args:
data: 输入数据,必须包含日期列 data: 输入数据,必须包含日期列
date_col: 日期列名,默认为 "trade_date" date_col: 日期列名,默认为 "trade_date"
Returns: Returns:
(train_data, test_data) 元组 (train_data, val_data, test_data) 元组
Raises: Raises:
ValueError: 数据中不包含指定的日期列 ValueError: 数据中不包含指定的日期列
Note:
正确的三分法:
- train_data: 用于训练模型参数
- val_data: 用于验证/早停/调参
- test_data: 仅用于最终评估,完全独立于训练过程
""" """
if date_col not in data.columns: if date_col not in data.columns:
raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}") raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}")
@@ -103,20 +180,43 @@ class DateSplitter:
& (pl.col(date_col) <= self.train_end) & (pl.col(date_col) <= self.train_end)
) )
# 筛选验证集数据
val_data = data.filter(
(pl.col(date_col) >= self.val_start) & (pl.col(date_col) <= self.val_end)
)
# 筛选测试集数据 # 筛选测试集数据
test_data = data.filter( test_data = data.filter(
(pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end) (pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end)
) )
return train_data, test_data return train_data, val_data, test_data
def split_train_test(
self, data: pl.DataFrame, date_col: str = "trade_date"
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""划分数据为训练集和测试集(验证集合并到训练集)
适用于不需要验证集的场景,或者使用交叉验证的场景。
Args:
data: 输入数据,必须包含日期列
date_col: 日期列名,默认为 "trade_date"
Returns:
(train_val_data, test_data) 元组,其中 train_val_data 包含 train + val
"""
train_data, val_data, test_data = self.split(data, date_col)
# 合并 train 和 val
train_val_data = pl.concat([train_data, val_data])
return train_val_data, test_data
def __repr__(self) -> str: def __repr__(self) -> str:
"""返回划分器的字符串表示""" """返回划分器的字符串表示"""
return ( return (
f"DateSplitter(" f"DateSplitter("
f"train_start='{self.train_start}', " f"train='{self.train_start}-{self.train_end}', "
f"train_end='{self.train_end}', " f"val='{self.val_start}-{self.val_end}', "
f"test_start='{self.test_start}', " f"test='{self.test_start}-{self.test_end}'"
f"test_end='{self.test_end}'"
f")" f")"
) )