Files
ProStock/tests/training/test_cross_section_sampler.py
liaozhaorun c143815443 feat(training): TabM模型量化交易优化
- 新增 CrossSectionSampler 支持截面数据采样(按交易日批处理)
- 新增 EnsembleQuantLoss (Huber + IC) 替代 MSE 作为损失函数
- 重构 TabMModel 支持量化场景:Rank IC 作为验证指标、CosineAnnealingLR学习率调度、梯度裁剪
- 支持 date_col 参数和特征对齐
- 更新实验配置 batch_size 2048 和 weight_decay 等超参数
2026-04-01 00:20:05 +08:00

80 lines
2.6 KiB
Python

"""截面数据采样器单元测试"""
import numpy as np
import pytest
import torch
from torch.utils.data import TensorDataset
from src.training.components.models.cross_section_sampler import CrossSectionSampler
class TestCrossSectionSampler:
    """Unit tests for ``CrossSectionSampler``."""

    def test_basic_functionality(self):
        """Check that sample indices are grouped into one batch per date."""
        dates = np.array(["20240101", "20240101", "20240102", "20240102", "20240103"])
        sampler = CrossSectionSampler(dates, shuffle_days=False)

        # The input holds three distinct dates, so three batches are expected.
        assert len(sampler) == 3

        # Drain the sampler, reducing every batch to the set of dates it touches.
        dates_per_batch = [{dates[idx] for idx in batch} for batch in sampler]

        # Each batch must draw its samples from exactly one date.
        for distinct_dates in dates_per_batch:
            assert len(distinct_dates) == 1, "批次内日期不一致"

    def test_shuffle_days(self):
        """Check that the order of dates varies when ``shuffle_days=True``."""
        np.random.seed(42)
        dates = np.array(["20240101"] * 5 + ["20240102"] * 5 + ["20240103"] * 5)

        # Sample repeatedly; the date of each batch's first index identifies
        # the order in which the dates were emitted on that pass.
        seen_orders = set()
        for _ in range(10):
            sampler = CrossSectionSampler(dates, shuffle_days=True)
            seen_orders.add(tuple(dates[batch[0]] for batch in sampler))

        # More than one distinct ordering should have shown up.
        assert len(seen_orders) > 1, "日期顺序未被打乱"

    def test_internal_shuffle(self):
        """Check that the index order inside a single-date batch varies."""
        np.random.seed(42)
        dates = np.array(["20240101"] * 10)

        # Build a fresh sampler each round and keep only its first batch.
        first_batches = set()
        for _ in range(5):
            sampler = CrossSectionSampler(dates, shuffle_days=False)
            first_batch = next(iter(sampler))
            first_batches.add(tuple(first_batch))

        # More than one distinct index ordering should have shown up.
        assert len(first_batches) > 1, "截面内顺序未被打乱"

    def test_with_dataloader(self):
        """Check that the sampler can be used as a ``DataLoader`` batch sampler."""
        dates = np.array(["20240101", "20240101", "20240102", "20240102"])
        features = torch.randn(4, 5)
        targets = torch.randn(4)
        sampler = CrossSectionSampler(dates, shuffle_days=False)
        loader = torch.utils.data.DataLoader(
            TensorDataset(features, targets),
            batch_sampler=sampler,
        )

        batches = list(loader)

        # Two distinct dates, hence two batches.
        assert len(batches) == 2

        # Each date contributes two samples to its batch.
        for batch_x, batch_y in batches:
            assert batch_x.shape[0] == 2
            assert batch_y.shape[0] == 2
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the session's
    # exit code; raise SystemExit with it so a failing run produces a
    # non-zero process status instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v"]))