"""测试 DateSplitter 数据划分器 验证一次性日期划分功能。 """ import pytest import polars as pl from src.training.components.splitters import DateSplitter class TestDateSplitter: """DateSplitter 测试类""" def test_initialization_success(self): """测试正常初始化""" splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) assert splitter.train_start == "20200101" assert splitter.train_end == "20221231" assert splitter.test_start == "20230101" assert splitter.test_end == "20231231" def test_invalid_date_format(self): """测试无效的日期格式""" with pytest.raises(ValueError, match="必须是格式为 'YYYYMMDD' 的8位字符串"): DateSplitter( train_start="2020-01-01", # 错误格式 train_end="20221231", test_start="20230101", test_end="20231231", ) def test_train_start_after_train_end(self): """测试训练集开始日期晚于结束日期""" with pytest.raises(ValueError, match="train_start.*必须早于或等于 train_end"): DateSplitter( train_start="20231231", train_end="20200101", test_start="20230101", test_end="20231231", ) def test_test_start_after_test_end(self): """测试测试集开始日期晚于结束日期""" with pytest.raises(ValueError, match="test_start.*必须早于或等于 test_end"): DateSplitter( train_start="20200101", train_end="20221231", test_start="20231231", test_end="20230101", ) def test_overlapping_dates(self): """测试训练集和测试集日期重叠""" with pytest.raises(ValueError, match="必须晚于训练集结束日期"): DateSplitter( train_start="20200101", train_end="20221231", test_start="20220601", # 在训练集范围内 test_end="20231231", ) def test_split_success(self): """测试正常划分数据""" # 创建测试数据 data = pl.DataFrame( { "ts_code": [ "000001.SZ", "000002.SZ", "000003.SZ", "000004.SZ", "000005.SZ", "000006.SZ", ], "trade_date": [ "20200101", "20211231", "20221231", "20230101", "20230601", "20231231", ], "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) train_data, test_data = splitter.split(data) # 验证训练集 assert len(train_data) == 3 assert train_data["trade_date"].to_list() == [ "20200101", "20211231", "20221231", ] # 验证测试集 assert len(test_data) == 3 assert test_data["trade_date"].to_list() == ["20230101", "20230601", "20231231"] def test_split_no_matching_train_data(self): """测试训练集无匹配数据""" data = pl.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ"], "trade_date": ["20230101", "20231231"], "value": [1.0, 2.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) train_data, test_data = splitter.split(data) # 训练集应该为空 assert len(train_data) == 0 # 测试集应该有数据 assert len(test_data) == 2 def test_split_no_matching_test_data(self): """测试测试集无匹配数据""" data = pl.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ"], "trade_date": ["20200101", "20211231"], "value": [1.0, 2.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) train_data, test_data = splitter.split(data) # 训练集应该有数据 assert len(train_data) == 2 # 测试集应该为空 assert len(test_data) == 0 def test_split_with_custom_date_col(self): """测试使用自定义日期列名""" data = pl.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"], "date": ["20200101", "20211231", "20230101"], "value": [1.0, 2.0, 3.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) train_data, test_data = splitter.split(data, date_col="date") assert len(train_data) == 2 assert len(test_data) == 1 def test_split_missing_date_column(self): """测试数据缺少日期列""" data = pl.DataFrame( { "ts_code": ["000001.SZ"], "value": [1.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"): splitter.split(data) def test_repr(self): """测试 __repr__ 方法""" splitter = DateSplitter( train_start="20200101", train_end="20221231", test_start="20230101", test_end="20231231", ) repr_str = repr(splitter) assert "DateSplitter" in repr_str assert "20200101" in repr_str assert "20221231" in repr_str assert "20230101" in repr_str assert "20231231" in repr_str def test_edge_case_same_day_train(self): """测试训练集为单日""" data = pl.DataFrame( { "ts_code": ["000001.SZ"], "trade_date": ["20200101"], "value": [1.0], } ) splitter = DateSplitter( train_start="20200101", train_end="20200101", test_start="20200102", test_end="20200102", ) train_data, test_data = splitter.split(data) assert len(train_data) == 1 assert len(test_data) == 0 if __name__ == "__main__": pytest.main([__file__, "-v"])