"""截面数据采样器单元测试""" import numpy as np import pytest import torch from torch.utils.data import TensorDataset from src.training.components.models.cross_section_sampler import CrossSectionSampler class TestCrossSectionSampler: """截面采样器单元测试""" def test_basic_functionality(self): """测试基本功能:按日期分组""" dates = np.array(["20240101", "20240101", "20240102", "20240102", "20240103"]) sampler = CrossSectionSampler(dates, shuffle_days=False) # 应该有3个日期批次 assert len(sampler) == 3 # 获取所有批次 batches = list(sampler) # 验证每个批次包含同一天的数据 for batch in batches: batch_dates = [dates[i] for i in batch] assert len(set(batch_dates)) == 1, "批次内日期不一致" def test_shuffle_days(self): """测试日期打乱功能""" np.random.seed(42) dates = np.array(["20240101"] * 5 + ["20240102"] * 5 + ["20240103"] * 5) # 多次采样,验证日期顺序会变化 orders = [] for _ in range(10): batches = list(CrossSectionSampler(dates, shuffle_days=True)) date_order = [dates[batch[0]] for batch in batches] orders.append(tuple(date_order)) # 应该有不同的顺序出现 assert len(set(orders)) > 1, "日期顺序未被打乱" def test_internal_shuffle(self): """测试截面内股票顺序打乱""" np.random.seed(42) dates = np.array(["20240101"] * 10) # 多次获取同一批次 indices_list = [] for _ in range(5): sampler = CrossSectionSampler(dates, shuffle_days=False) batch = next(iter(sampler)) indices_list.append(list(batch)) # 应该有不同顺序 assert len(set(tuple(x) for x in indices_list)) > 1, "截面内顺序未被打乱" def test_with_dataloader(self): """测试与 DataLoader 集成""" dates = np.array(["20240101", "20240101", "20240102", "20240102"]) X = torch.randn(4, 5) y = torch.randn(4) dataset = TensorDataset(X, y) sampler = CrossSectionSampler(dates, shuffle_days=False) loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler) batches = list(loader) assert len(batches) == 2 # 2个日期 for bx, by in batches: assert bx.shape[0] == 2 # 每个日期2个样本 assert by.shape[0] == 2 if __name__ == "__main__": pytest.main([__file__, "-v"])