80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
|
|
"""截面数据采样器单元测试"""
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
from torch.utils.data import TensorDataset
|
||
|
|
|
||
|
|
from src.training.components.models.cross_section_sampler import CrossSectionSampler
|
||
|
|
|
||
|
|
|
||
|
|
class TestCrossSectionSampler:
    """Unit tests for ``CrossSectionSampler``."""

    def test_basic_functionality(self):
        """Each yielded batch must contain indices of a single date only."""
        trading_days = np.array(
            ["20240101", "20240101", "20240102", "20240102", "20240103"]
        )
        grouped = CrossSectionSampler(trading_days, shuffle_days=False)

        # Three distinct dates -> three batches.
        assert len(grouped) == 3

        # Materialize every batch, then check date homogeneity per batch.
        all_batches = list(grouped)
        for day_indices in all_batches:
            distinct_days = {trading_days[idx] for idx in day_indices}
            assert len(distinct_days) == 1, "批次内日期不一致"

    def test_shuffle_days(self):
        """With ``shuffle_days=True`` the date order should vary across runs."""
        np.random.seed(42)
        trading_days = np.array(
            [day for day in ("20240101", "20240102", "20240103") for _ in range(5)]
        )

        # Sample repeatedly and record the order in which dates come out;
        # the first index of each batch identifies that batch's date.
        observed_orders = set()
        for _ in range(10):
            day_batches = list(CrossSectionSampler(trading_days, shuffle_days=True))
            observed_orders.add(
                tuple(trading_days[indices[0]] for indices in day_batches)
            )

        # More than one distinct ordering must have been observed.
        assert len(observed_orders) > 1, "日期顺序未被打乱"

    def test_internal_shuffle(self):
        """Sample order inside a single cross-section should vary across runs."""
        np.random.seed(42)
        trading_days = np.array(["20240101" for _ in range(10)])

        # Fetch the (only) batch several times from fresh samplers.
        observed_permutations = set()
        for _ in range(5):
            fresh_sampler = CrossSectionSampler(trading_days, shuffle_days=False)
            first_batch = next(iter(fresh_sampler))
            observed_permutations.add(tuple(first_batch))

        # At least two different index orderings are expected.
        assert len(observed_permutations) > 1, "截面内顺序未被打乱"

    def test_with_dataloader(self):
        """The sampler must work as a ``batch_sampler`` for ``DataLoader``."""
        trading_days = np.array(["20240101", "20240101", "20240102", "20240102"])
        features = torch.randn(4, 5)
        targets = torch.randn(4)

        loader = torch.utils.data.DataLoader(
            TensorDataset(features, targets),
            batch_sampler=CrossSectionSampler(trading_days, shuffle_days=False),
        )

        loaded = list(loader)
        assert len(loaded) == 2  # one batch per date

        for batch_x, batch_y in loaded:
            # Two samples per date in the fixture above.
            assert batch_x.shape[0] == 2
            assert batch_y.shape[0] == 2
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script, and
    # propagate pytest's exit status so that a failing run produces a
    # non-zero process exit code instead of silently exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|