From f48b307ad2accb829d3c743e5cfcd6053da11ac5 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Tue, 3 Mar 2026 22:07:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(training):=20=E5=AE=9E=E7=8E=B0=20DateSpli?= =?UTF-8?q?tter=20=E6=95=B0=E6=8D=AE=E5=88=92=E5=88=86=E5=99=A8=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20DateSplitter=20=E7=B1=BB=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=9F=BA=E4=BA=8E=E6=97=A5=E6=9C=9F=E8=8C=83=E5=9B=B4?= =?UTF-8?q?=E7=9A=84=E4=B8=80=E6=AC=A1=E6=80=A7=E8=AE=AD=E7=BB=83/?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=88=92=E5=88=86=20-=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=97=A5=E6=9C=9F=E6=A0=BC=E5=BC=8F=E9=AA=8C=E8=AF=81=E5=92=8C?= =?UTF-8?q?=E6=97=A5=E6=9C=9F=E8=8C=83=E5=9B=B4=E9=80=BB=E8=BE=91=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=20-=20=E6=94=AF=E6=8C=81=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E6=97=A5=E6=9C=9F=E5=88=97=E5=90=8D=E5=8F=82=E6=95=B0=20-=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AE=8C=E6=95=B4=E7=9A=84=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=EF=BC=8812=E4=B8=AA=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=EF=BC=89=20-=20=E5=9C=A8=20components=20?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E5=AF=BC=E5=87=BA=20DateSplitter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/training/components/__init__.py | 4 + src/training/components/splitters.py | 122 ++++++++++++++ tests/training/test_splitters.py | 244 +++++++++++++++++++++++++++ 3 files changed, 370 insertions(+) create mode 100644 src/training/components/splitters.py create mode 100644 tests/training/test_splitters.py diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py index 96f3916..762d8d2 100644 --- a/src/training/components/__init__.py +++ b/src/training/components/__init__.py @@ -6,7 +6,11 @@ # 基础抽象类 from src.training.components.base import BaseModel, BaseProcessor +# 数据划分器 +from src.training.components.splitters import DateSplitter + __all__ = [ "BaseModel", "BaseProcessor", + "DateSplitter", ] diff --git a/src/training/components/splitters.py b/src/training/components/splitters.py new file mode 100644 index 0000000..5d29969 --- /dev/null +++ b/src/training/components/splitters.py @@ -0,0 +1,122 @@ +"""数据划分器 + +提供基于日期范围的数据划分功能,支持一次性训练/测试划分。 +""" + +from typing import Tuple + +import polars as pl + + +class DateSplitter: + """基于日期范围的一次性划分 + + 将数据按日期划分为训练集和测试集,不滚动。 + + 示例: + train_start: "20200101", train_end: "20221231" (训练集:3年) + test_start: "20230101", test_end: "20231231" (测试集:1年) + + 特点: + - 一次性划分,不滚动 + - 训练集和测试集互不重叠 + - 基于实际日期范围,而非行数 + + Attributes: + train_start: 训练期开始日期,格式 "YYYYMMDD" + train_end: 训练期结束日期,格式 "YYYYMMDD" + test_start: 测试期开始日期,格式 "YYYYMMDD" + test_end: 测试期结束日期,格式 "YYYYMMDD" + """ + + def __init__( + self, + train_start: str, + train_end: str, + test_start: str, + test_end: str, + ): + """初始化日期划分器 + + Args: + train_start: 训练期开始日期 "YYYYMMDD" + train_end: 训练期结束日期 "YYYYMMDD" + test_start: 测试期开始日期 "YYYYMMDD" + test_end: 测试期结束日期 "YYYYMMDD" + + Raises: + ValueError: 日期格式错误或日期范围无效 + """ + # 验证日期格式(简单的长度检查) + for name, value in [ + ("train_start", train_start), + ("train_end", train_end), + ("test_start", test_start), + ("test_end", test_end), + ]: + if not isinstance(value, str) or len(value) != 8: + raise ValueError( + f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串,得到: {value}" + ) + + # 验证日期范围逻辑 + if train_start > train_end: + raise ValueError( + f"train_start ({train_start}) 必须早于或等于 train_end ({train_end})" + ) + if test_start > test_end: + raise ValueError( + f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})" + ) + if test_start <= train_end: + raise ValueError( + f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end})," + "以确保训练集和测试集不重叠" + ) + + self.train_start = train_start + self.train_end = train_end + self.test_start = test_start + self.test_end = test_end + + def split( + self, data: pl.DataFrame, date_col: str = "trade_date" + ) -> Tuple[pl.DataFrame, pl.DataFrame]: + """划分数据为训练集和测试集 + + Args: + data: 输入数据,必须包含日期列 + date_col: 日期列名,默认为 "trade_date" + + Returns: + (train_data, test_data) 元组 + + Raises: + ValueError: 数据中不包含指定的日期列 + """ + if date_col not in data.columns: + raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}") + + # 筛选训练集数据 + train_data = data.filter( + (pl.col(date_col) >= self.train_start) + & (pl.col(date_col) <= self.train_end) + ) + + # 筛选测试集数据 + test_data = data.filter( + (pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end) + ) + + return train_data, test_data + + def __repr__(self) -> str: + """返回划分器的字符串表示""" + return ( + f"DateSplitter(" + f"train_start='{self.train_start}', " + f"train_end='{self.train_end}', " + f"test_start='{self.test_start}', " + f"test_end='{self.test_end}'" + f")" + ) diff --git a/tests/training/test_splitters.py b/tests/training/test_splitters.py new file mode 100644 index 0000000..09c72c0 --- /dev/null +++ b/tests/training/test_splitters.py @@ -0,0 +1,244 @@ +"""测试 DateSplitter 数据划分器 + +验证一次性日期划分功能。 +""" + +import pytest +import polars as pl + +from src.training.components.splitters import DateSplitter + + +class TestDateSplitter: + """DateSplitter 测试类""" + + def test_initialization_success(self): + """测试正常初始化""" + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + assert splitter.train_start == "20200101" + assert splitter.train_end == "20221231" + assert splitter.test_start == "20230101" + assert splitter.test_end == "20231231" + + def test_invalid_date_format(self): + """测试无效的日期格式""" + with pytest.raises(ValueError, match="必须是格式为 'YYYYMMDD' 的8位字符串"): + DateSplitter( + train_start="2020-01-01", # 错误格式 + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + def test_train_start_after_train_end(self): + """测试训练集开始日期晚于结束日期""" + with pytest.raises(ValueError, match="train_start.*必须早于或等于 train_end"): + DateSplitter( + train_start="20231231", + train_end="20200101", + test_start="20230101", + test_end="20231231", + ) + + def test_test_start_after_test_end(self): + """测试测试集开始日期晚于结束日期""" + with pytest.raises(ValueError, match="test_start.*必须早于或等于 test_end"): + DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20231231", + test_end="20230101", + ) + + def test_overlapping_dates(self): + """测试训练集和测试集日期重叠""" + with pytest.raises(ValueError, match="必须晚于训练集结束日期"): + DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20220601", # 在训练集范围内 + test_end="20231231", + ) + + def test_split_success(self): + """测试正常划分数据""" + # 创建测试数据 + data = pl.DataFrame( + { + "ts_code": [ + "000001.SZ", + "000002.SZ", + "000003.SZ", + "000004.SZ", + "000005.SZ", + "000006.SZ", + ], + "trade_date": [ + "20200101", + "20211231", + "20221231", + "20230101", + "20230601", + "20231231", + ], + "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + train_data, test_data = splitter.split(data) + + # 验证训练集 + assert len(train_data) == 3 + assert train_data["trade_date"].to_list() == [ + "20200101", + "20211231", + "20221231", + ] + + # 验证测试集 + assert len(test_data) == 3 + assert test_data["trade_date"].to_list() == ["20230101", "20230601", "20231231"] + + def test_split_no_matching_train_data(self): + """测试训练集无匹配数据""" + data = pl.DataFrame( + { + "ts_code": ["000001.SZ", "000002.SZ"], + "trade_date": ["20230101", "20231231"], + "value": [1.0, 2.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + train_data, test_data = splitter.split(data) + + # 训练集应该为空 + assert len(train_data) == 0 + # 测试集应该有数据 + assert len(test_data) == 2 + + def test_split_no_matching_test_data(self): + """测试测试集无匹配数据""" + data = pl.DataFrame( + { + "ts_code": ["000001.SZ", "000002.SZ"], + "trade_date": ["20200101", "20211231"], + "value": [1.0, 2.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + train_data, test_data = splitter.split(data) + + # 训练集应该有数据 + assert len(train_data) == 2 + # 测试集应该为空 + assert len(test_data) == 0 + + def test_split_with_custom_date_col(self): + """测试使用自定义日期列名""" + data = pl.DataFrame( + { + "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"], + "date": ["20200101", "20211231", "20230101"], + "value": [1.0, 2.0, 3.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + train_data, test_data = splitter.split(data, date_col="date") + + assert len(train_data) == 2 + assert len(test_data) == 1 + + def test_split_missing_date_column(self): + """测试数据缺少日期列""" + data = pl.DataFrame( + { + "ts_code": ["000001.SZ"], + "value": [1.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"): + splitter.split(data) + + def test_repr(self): + """测试 __repr__ 方法""" + splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + ) + + repr_str = repr(splitter) + assert "DateSplitter" in repr_str + assert "20200101" in repr_str + assert "20221231" in repr_str + assert "20230101" in repr_str + assert "20231231" in repr_str + + def test_edge_case_same_day_train(self): + """测试训练集为单日""" + data = pl.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20200101"], + "value": [1.0], + } + ) + + splitter = DateSplitter( + train_start="20200101", + train_end="20200101", + test_start="20200102", + test_end="20200102", + ) + + train_data, test_data = splitter.split(data) + + assert len(train_data) == 1 + assert len(test_data) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])