feat(training): TabM模型量化交易优化

- 新增 CrossSectionSampler 支持截面数据采样(按交易日批处理)
- 新增 EnsembleQuantLoss (Huber + IC) 替代 MSE 作为损失函数
- 重构 TabMModel 支持量化场景:Rank IC 作为验证指标、CosineAnnealingLR学习率调度、梯度裁剪
- 支持 date_col 参数和特征对齐
- 更新实验配置 batch_size 2048 和 weight_decay 等超参数
This commit is contained in:
2026-04-01 00:20:05 +08:00
parent 36a3ccbcc8
commit c143815443
9 changed files with 492 additions and 60 deletions

View File

@@ -0,0 +1,79 @@
"""截面数据采样器单元测试"""
import numpy as np
import pytest
import torch
from torch.utils.data import TensorDataset
from src.training.components.models.cross_section_sampler import CrossSectionSampler
class TestCrossSectionSampler:
"""截面采样器单元测试"""
def test_basic_functionality(self):
"""测试基本功能:按日期分组"""
dates = np.array(["20240101", "20240101", "20240102", "20240102", "20240103"])
sampler = CrossSectionSampler(dates, shuffle_days=False)
# 应该有3个日期批次
assert len(sampler) == 3
# 获取所有批次
batches = list(sampler)
# 验证每个批次包含同一天的数据
for batch in batches:
batch_dates = [dates[i] for i in batch]
assert len(set(batch_dates)) == 1, "批次内日期不一致"
def test_shuffle_days(self):
"""测试日期打乱功能"""
np.random.seed(42)
dates = np.array(["20240101"] * 5 + ["20240102"] * 5 + ["20240103"] * 5)
# 多次采样,验证日期顺序会变化
orders = []
for _ in range(10):
batches = list(CrossSectionSampler(dates, shuffle_days=True))
date_order = [dates[batch[0]] for batch in batches]
orders.append(tuple(date_order))
# 应该有不同的顺序出现
assert len(set(orders)) > 1, "日期顺序未被打乱"
def test_internal_shuffle(self):
"""测试截面内股票顺序打乱"""
np.random.seed(42)
dates = np.array(["20240101"] * 10)
# 多次获取同一批次
indices_list = []
for _ in range(5):
sampler = CrossSectionSampler(dates, shuffle_days=False)
batch = next(iter(sampler))
indices_list.append(list(batch))
# 应该有不同顺序
assert len(set(tuple(x) for x in indices_list)) > 1, "截面内顺序未被打乱"
def test_with_dataloader(self):
"""测试与 DataLoader 集成"""
dates = np.array(["20240101", "20240101", "20240102", "20240102"])
X = torch.randn(4, 5)
y = torch.randn(4)
dataset = TensorDataset(X, y)
sampler = CrossSectionSampler(dates, shuffle_days=False)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
batches = list(loader)
assert len(batches) == 2 # 2个日期
for bx, by in batches:
assert bx.shape[0] == 2 # 每个日期2个样本
assert by.shape[0] == 2
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,98 @@
"""EnsembleQuantLoss 单元测试"""
import numpy as np
import pytest
import torch
import torch.nn as nn
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
class TestEnsembleQuantLoss:
"""EnsembleQuantLoss 单元测试"""
def test_initialization(self):
"""测试初始化"""
loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16)
assert loss_fn.alpha == 0.7
assert loss_fn.ensemble_size == 16
assert isinstance(loss_fn.huber, nn.HuberLoss)
def test_output_shape(self):
"""测试输出形状和类型"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
# 创建模拟数据: 20只股票, 4个集成成员
preds = torch.randn(20, 4)
target = torch.randn(20)
loss = loss_fn(preds, target)
# 验证输出是标量
assert loss.shape == torch.Size([])
assert isinstance(loss.item(), float)
def test_small_batch_fallback(self):
"""测试小批次回退到 Huber"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
# 少于10只股票的批次
preds = torch.randn(5, 4)
target = torch.randn(5)
loss = loss_fn(preds, target)
# 应该正常返回loss
assert not torch.isnan(loss)
assert loss.item() > 0
def test_huber_component(self):
"""测试 Huber 损失组件"""
loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=4) # 纯 Huber
preds = torch.randn(50, 4)
target = torch.randn(50)
loss = loss_fn(preds, target)
# 手动计算期望的 Huber 损失
huber = nn.HuberLoss(reduction="mean")
expected_loss = 0
for i in range(4):
expected_loss += huber(preds[:, i], target)
expected_loss /= 4
assert torch.allclose(loss, expected_loss, rtol=1e-5)
def test_ic_component(self):
"""测试 IC 损失组件"""
loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1) # 纯 IC
# 创建完全相关的预测和目标
target = torch.randn(50)
preds = target.unsqueeze(1) # 完美相关
loss = loss_fn(preds, target)
# 完美相关时 IC=1所以 IC loss = 0
# 但由于 std 计算和数值精度可能不完全为0
assert loss.item() < 0.1
def test_gradient_flow(self):
"""测试梯度可以正常回传"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
preds = torch.randn(50, 4, requires_grad=True)
target = torch.randn(50)
loss = loss_fn(preds, target)
loss.backward()
# 验证梯度存在且非零
assert preds.grad is not None
assert not torch.all(preds.grad == 0)
if __name__ == "__main__":
pytest.main([__file__, "-v"])