feat(training): 实现 DateSplitter 数据划分器

- 新增 DateSplitter 类,支持基于日期范围的一次性训练/测试划分
- 实现日期格式验证和日期范围逻辑检查
- 支持自定义日期列名参数
- 添加完整的单元测试(12个测试用例)
- 在 components 模块导出 DateSplitter
This commit is contained in:
2026-03-03 22:07:45 +08:00
parent 317ecd87e7
commit f48b307ad2
3 changed files with 370 additions and 0 deletions

View File

@@ -0,0 +1,244 @@
"""测试 DateSplitter 数据划分器
验证一次性日期划分功能。
"""
import pytest
import polars as pl
from src.training.components.splitters import DateSplitter
class TestDateSplitter:
    """Unit tests for ``DateSplitter``, the one-shot date-range train/test splitter."""

    # Baseline, non-overlapping ranges used by most tests:
    # train = 20200101..20221231, test = 20230101..20231231.
    _BASE_RANGES = {
        "train_start": "20200101",
        "train_end": "20221231",
        "test_start": "20230101",
        "test_end": "20231231",
    }

    @classmethod
    def _make_splitter(cls, **overrides):
        """Build a ``DateSplitter`` from the baseline ranges.

        Any keyword given in ``overrides`` replaces the matching baseline
        bound, so each test only spells out the bound(s) it varies.
        """
        return DateSplitter(**{**cls._BASE_RANGES, **overrides})

    @staticmethod
    def _make_frame(dates, date_col="trade_date"):
        """Build a small frame with one row per entry of ``dates``.

        Columns are ``ts_code`` ("000001.SZ", "000002.SZ", ...), the date
        column (named ``date_col``) and ``value`` (1.0, 2.0, ...).
        """
        row_ids = range(1, len(dates) + 1)
        return pl.DataFrame(
            {
                "ts_code": [f"{i:06d}.SZ" for i in row_ids],
                date_col: list(dates),
                "value": [float(i) for i in row_ids],
            }
        )

    def test_initialization_success(self):
        """Valid ranges are accepted and exposed as same-named attributes."""
        splitter = self._make_splitter()
        for attr_name, expected in self._BASE_RANGES.items():
            assert getattr(splitter, attr_name) == expected

    def test_invalid_date_format(self):
        """A date that is not written as 'YYYYMMDD' raises ``ValueError``."""
        expected_msg = "必须是格式为 'YYYYMMDD' 的8位字符串"
        with pytest.raises(ValueError, match=expected_msg):
            # Dashed ISO form is the malformed input under test.
            self._make_splitter(train_start="2020-01-01")

    def test_train_start_after_train_end(self):
        """A training range whose start is after its end raises ``ValueError``."""
        expected_msg = "train_start.*必须早于或等于 train_end"
        with pytest.raises(ValueError, match=expected_msg):
            self._make_splitter(train_start="20231231", train_end="20200101")

    def test_test_start_after_test_end(self):
        """A test range whose start is after its end raises ``ValueError``."""
        expected_msg = "test_start.*必须早于或等于 test_end"
        with pytest.raises(ValueError, match=expected_msg):
            self._make_splitter(test_start="20231231", test_end="20230101")

    def test_overlapping_dates(self):
        """A test range starting inside the training range raises ``ValueError``."""
        with pytest.raises(ValueError, match="必须晚于训练集结束日期"):
            # 20220601 falls before the baseline train_end (20221231).
            self._make_splitter(test_start="20220601")

    def test_split_success(self):
        """Rows are routed to the partition whose date range contains them."""
        train_dates = ["20200101", "20211231", "20221231"]
        test_dates = ["20230101", "20230601", "20231231"]
        frame = self._make_frame(train_dates + test_dates)

        train_part, test_part = self._make_splitter().split(frame)

        # The first and last date of each range are expected in that partition.
        assert len(train_part) == 3
        assert train_part["trade_date"].to_list() == train_dates

        assert len(test_part) == 3
        assert test_part["trade_date"].to_list() == test_dates

    def test_split_no_matching_train_data(self):
        """When every row falls in the test range, the training part is empty."""
        frame = self._make_frame(["20230101", "20231231"])

        train_part, test_part = self._make_splitter().split(frame)

        assert len(train_part) == 0
        assert len(test_part) == 2

    def test_split_no_matching_test_data(self):
        """When every row falls in the training range, the test part is empty."""
        frame = self._make_frame(["20200101", "20211231"])

        train_part, test_part = self._make_splitter().split(frame)

        assert len(train_part) == 2
        assert len(test_part) == 0

    def test_split_with_custom_date_col(self):
        """``split`` reads dates from the column named by ``date_col``."""
        frame = self._make_frame(
            ["20200101", "20211231", "20230101"], date_col="date"
        )

        train_part, test_part = self._make_splitter().split(frame, date_col="date")

        assert len(train_part) == 2
        assert len(test_part) == 1

    def test_split_missing_date_column(self):
        """Splitting a frame without the default date column raises ``ValueError``."""
        frame = pl.DataFrame({"ts_code": ["000001.SZ"], "value": [1.0]})
        # Build the splitter outside the ``raises`` block so that only the
        # ``split`` call can satisfy the expectation.
        splitter = self._make_splitter()

        with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"):
            splitter.split(frame)

    def test_repr(self):
        """``repr`` mentions the class name and all four configured bounds."""
        text = repr(self._make_splitter())

        assert "DateSplitter" in text
        for bound in self._BASE_RANGES.values():
            assert bound in text

    def test_edge_case_same_day_train(self):
        """A single-day training range still captures a row dated on that day."""
        frame = self._make_frame(["20200101"])
        splitter = self._make_splitter(
            train_start="20200101",
            train_end="20200101",
            test_start="20200102",
            test_end="20200102",
        )

        train_part, test_part = splitter.split(frame)

        assert len(train_part) == 1
        assert len(test_part) == 0
# Running this file directly invokes pytest on this module in verbose mode.
if __name__ == "__main__":
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)