feat(training): 实现 DateSplitter 数据划分器

- 新增 DateSplitter 类,支持基于日期范围的一次性训练/测试划分
- 实现日期格式验证和日期范围逻辑检查
- 支持自定义日期列名参数
- 添加完整的单元测试(12个测试用例)
- 在 components 模块导出 DateSplitter
This commit is contained in:
2026-03-03 22:07:45 +08:00
parent 317ecd87e7
commit f48b307ad2
3 changed files with 370 additions and 0 deletions

View File

@@ -6,7 +6,11 @@
# 基础抽象类
from src.training.components.base import BaseModel, BaseProcessor
# 数据划分器
from src.training.components.splitters import DateSplitter
__all__ = [
"BaseModel",
"BaseProcessor",
"DateSplitter",
]

View File

@@ -0,0 +1,122 @@
"""数据划分器
提供基于日期范围的数据划分功能,支持一次性训练/测试划分。
"""
from typing import Tuple
import polars as pl
class DateSplitter:
"""基于日期范围的一次性划分
将数据按日期划分为训练集和测试集,不滚动。
示例:
train_start: "20200101", train_end: "20221231" (训练集3年)
test_start: "20230101", test_end: "20231231" (测试集1年)
特点:
- 一次性划分,不滚动
- 训练集和测试集互不重叠
- 基于实际日期范围,而非行数
Attributes:
train_start: 训练期开始日期,格式 "YYYYMMDD"
train_end: 训练期结束日期,格式 "YYYYMMDD"
test_start: 测试期开始日期,格式 "YYYYMMDD"
test_end: 测试期结束日期,格式 "YYYYMMDD"
"""
def __init__(
self,
train_start: str,
train_end: str,
test_start: str,
test_end: str,
):
"""初始化日期划分器
Args:
train_start: 训练期开始日期 "YYYYMMDD"
train_end: 训练期结束日期 "YYYYMMDD"
test_start: 测试期开始日期 "YYYYMMDD"
test_end: 测试期结束日期 "YYYYMMDD"
Raises:
ValueError: 日期格式错误或日期范围无效
"""
# 验证日期格式(简单的长度检查)
for name, value in [
("train_start", train_start),
("train_end", train_end),
("test_start", test_start),
("test_end", test_end),
]:
if not isinstance(value, str) or len(value) != 8:
raise ValueError(
f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串得到: {value}"
)
# 验证日期范围逻辑
if train_start > train_end:
raise ValueError(
f"train_start ({train_start}) 必须早于或等于 train_end ({train_end})"
)
if test_start > test_end:
raise ValueError(
f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})"
)
if test_start <= train_end:
raise ValueError(
f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end})"
"以确保训练集和测试集不重叠"
)
self.train_start = train_start
self.train_end = train_end
self.test_start = test_start
self.test_end = test_end
def split(
self, data: pl.DataFrame, date_col: str = "trade_date"
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""划分数据为训练集和测试集
Args:
data: 输入数据,必须包含日期列
date_col: 日期列名,默认为 "trade_date"
Returns:
(train_data, test_data) 元组
Raises:
ValueError: 数据中不包含指定的日期列
"""
if date_col not in data.columns:
raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}")
# 筛选训练集数据
train_data = data.filter(
(pl.col(date_col) >= self.train_start)
& (pl.col(date_col) <= self.train_end)
)
# 筛选测试集数据
test_data = data.filter(
(pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end)
)
return train_data, test_data
def __repr__(self) -> str:
"""返回划分器的字符串表示"""
return (
f"DateSplitter("
f"train_start='{self.train_start}', "
f"train_end='{self.train_end}', "
f"test_start='{self.test_start}', "
f"test_end='{self.test_end}'"
f")"
)