feat(training): 实现 DateSplitter 数据划分器
- 新增 DateSplitter 类,支持基于日期范围的一次性训练/测试划分 - 实现日期格式验证和日期范围逻辑检查 - 支持自定义日期列名参数 - 添加完整的单元测试(12个测试用例) - 在 components 模块导出 DateSplitter
This commit is contained in:
@@ -6,7 +6,11 @@
|
|||||||
# 基础抽象类
|
# 基础抽象类
|
||||||
from src.training.components.base import BaseModel, BaseProcessor
|
from src.training.components.base import BaseModel, BaseProcessor
|
||||||
|
|
||||||
|
# 数据划分器
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
"BaseProcessor",
|
"BaseProcessor",
|
||||||
|
"DateSplitter",
|
||||||
]
|
]
|
||||||
|
|||||||
122
src/training/components/splitters.py
Normal file
122
src/training/components/splitters.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""数据划分器
|
||||||
|
|
||||||
|
提供基于日期范围的数据划分功能,支持一次性训练/测试划分。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
class DateSplitter:
|
||||||
|
"""基于日期范围的一次性划分
|
||||||
|
|
||||||
|
将数据按日期划分为训练集和测试集,不滚动。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
train_start: "20200101", train_end: "20221231" (训练集:3年)
|
||||||
|
test_start: "20230101", test_end: "20231231" (测试集:1年)
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 一次性划分,不滚动
|
||||||
|
- 训练集和测试集互不重叠
|
||||||
|
- 基于实际日期范围,而非行数
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
train_start: 训练期开始日期,格式 "YYYYMMDD"
|
||||||
|
train_end: 训练期结束日期,格式 "YYYYMMDD"
|
||||||
|
test_start: 测试期开始日期,格式 "YYYYMMDD"
|
||||||
|
test_end: 测试期结束日期,格式 "YYYYMMDD"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
train_start: str,
|
||||||
|
train_end: str,
|
||||||
|
test_start: str,
|
||||||
|
test_end: str,
|
||||||
|
):
|
||||||
|
"""初始化日期划分器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_start: 训练期开始日期 "YYYYMMDD"
|
||||||
|
train_end: 训练期结束日期 "YYYYMMDD"
|
||||||
|
test_start: 测试期开始日期 "YYYYMMDD"
|
||||||
|
test_end: 测试期结束日期 "YYYYMMDD"
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 日期格式错误或日期范围无效
|
||||||
|
"""
|
||||||
|
# 验证日期格式(简单的长度检查)
|
||||||
|
for name, value in [
|
||||||
|
("train_start", train_start),
|
||||||
|
("train_end", train_end),
|
||||||
|
("test_start", test_start),
|
||||||
|
("test_end", test_end),
|
||||||
|
]:
|
||||||
|
if not isinstance(value, str) or len(value) != 8:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串,得到: {value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证日期范围逻辑
|
||||||
|
if train_start > train_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"train_start ({train_start}) 必须早于或等于 train_end ({train_end})"
|
||||||
|
)
|
||||||
|
if test_start > test_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})"
|
||||||
|
)
|
||||||
|
if test_start <= train_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end}),"
|
||||||
|
"以确保训练集和测试集不重叠"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train_start = train_start
|
||||||
|
self.train_end = train_end
|
||||||
|
self.test_start = test_start
|
||||||
|
self.test_end = test_end
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||||
|
"""划分数据为训练集和测试集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据,必须包含日期列
|
||||||
|
date_col: 日期列名,默认为 "trade_date"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(train_data, test_data) 元组
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 数据中不包含指定的日期列
|
||||||
|
"""
|
||||||
|
if date_col not in data.columns:
|
||||||
|
raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}")
|
||||||
|
|
||||||
|
# 筛选训练集数据
|
||||||
|
train_data = data.filter(
|
||||||
|
(pl.col(date_col) >= self.train_start)
|
||||||
|
& (pl.col(date_col) <= self.train_end)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 筛选测试集数据
|
||||||
|
test_data = data.filter(
|
||||||
|
(pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end)
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_data, test_data
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""返回划分器的字符串表示"""
|
||||||
|
return (
|
||||||
|
f"DateSplitter("
|
||||||
|
f"train_start='{self.train_start}', "
|
||||||
|
f"train_end='{self.train_end}', "
|
||||||
|
f"test_start='{self.test_start}', "
|
||||||
|
f"test_end='{self.test_end}'"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
244
tests/training/test_splitters.py
Normal file
244
tests/training/test_splitters.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""测试 DateSplitter 数据划分器
|
||||||
|
|
||||||
|
验证一次性日期划分功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
|
|
||||||
|
class TestDateSplitter:
|
||||||
|
"""DateSplitter 测试类"""
|
||||||
|
|
||||||
|
def test_initialization_success(self):
|
||||||
|
"""测试正常初始化"""
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert splitter.train_start == "20200101"
|
||||||
|
assert splitter.train_end == "20221231"
|
||||||
|
assert splitter.test_start == "20230101"
|
||||||
|
assert splitter.test_end == "20231231"
|
||||||
|
|
||||||
|
def test_invalid_date_format(self):
|
||||||
|
"""测试无效的日期格式"""
|
||||||
|
with pytest.raises(ValueError, match="必须是格式为 'YYYYMMDD' 的8位字符串"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="2020-01-01", # 错误格式
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_train_start_after_train_end(self):
|
||||||
|
"""测试训练集开始日期晚于结束日期"""
|
||||||
|
with pytest.raises(ValueError, match="train_start.*必须早于或等于 train_end"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20231231",
|
||||||
|
train_end="20200101",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_test_start_after_test_end(self):
|
||||||
|
"""测试测试集开始日期晚于结束日期"""
|
||||||
|
with pytest.raises(ValueError, match="test_start.*必须早于或等于 test_end"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20231231",
|
||||||
|
test_end="20230101",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_overlapping_dates(self):
|
||||||
|
"""测试训练集和测试集日期重叠"""
|
||||||
|
with pytest.raises(ValueError, match="必须晚于训练集结束日期"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20220601", # 在训练集范围内
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_split_success(self):
|
||||||
|
"""测试正常划分数据"""
|
||||||
|
# 创建测试数据
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ",
|
||||||
|
"000003.SZ",
|
||||||
|
"000004.SZ",
|
||||||
|
"000005.SZ",
|
||||||
|
"000006.SZ",
|
||||||
|
],
|
||||||
|
"trade_date": [
|
||||||
|
"20200101",
|
||||||
|
"20211231",
|
||||||
|
"20221231",
|
||||||
|
"20230101",
|
||||||
|
"20230601",
|
||||||
|
"20231231",
|
||||||
|
],
|
||||||
|
"value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 验证训练集
|
||||||
|
assert len(train_data) == 3
|
||||||
|
assert train_data["trade_date"].to_list() == [
|
||||||
|
"20200101",
|
||||||
|
"20211231",
|
||||||
|
"20221231",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 验证测试集
|
||||||
|
assert len(test_data) == 3
|
||||||
|
assert test_data["trade_date"].to_list() == ["20230101", "20230601", "20231231"]
|
||||||
|
|
||||||
|
def test_split_no_matching_train_data(self):
|
||||||
|
"""测试训练集无匹配数据"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20230101", "20231231"],
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 训练集应该为空
|
||||||
|
assert len(train_data) == 0
|
||||||
|
# 测试集应该有数据
|
||||||
|
assert len(test_data) == 2
|
||||||
|
|
||||||
|
def test_split_no_matching_test_data(self):
|
||||||
|
"""测试测试集无匹配数据"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20200101", "20211231"],
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 训练集应该有数据
|
||||||
|
assert len(train_data) == 2
|
||||||
|
# 测试集应该为空
|
||||||
|
assert len(test_data) == 0
|
||||||
|
|
||||||
|
def test_split_with_custom_date_col(self):
|
||||||
|
"""测试使用自定义日期列名"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||||
|
"date": ["20200101", "20211231", "20230101"],
|
||||||
|
"value": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data, date_col="date")
|
||||||
|
|
||||||
|
assert len(train_data) == 2
|
||||||
|
assert len(test_data) == 1
|
||||||
|
|
||||||
|
def test_split_missing_date_column(self):
|
||||||
|
"""测试数据缺少日期列"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"value": [1.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"):
|
||||||
|
splitter.split(data)
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
"""测试 __repr__ 方法"""
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(splitter)
|
||||||
|
assert "DateSplitter" in repr_str
|
||||||
|
assert "20200101" in repr_str
|
||||||
|
assert "20221231" in repr_str
|
||||||
|
assert "20230101" in repr_str
|
||||||
|
assert "20231231" in repr_str
|
||||||
|
|
||||||
|
def test_edge_case_same_day_train(self):
|
||||||
|
"""测试训练集为单日"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"trade_date": ["20200101"],
|
||||||
|
"value": [1.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20200101",
|
||||||
|
test_start="20200102",
|
||||||
|
test_end="20200102",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
assert len(train_data) == 1
|
||||||
|
assert len(test_data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user