Files
ProStock/tests/training/test_ensemble_quant_loss.py
liaozhaorun c143815443 feat(training): TabM模型量化交易优化
- 新增 CrossSectionSampler 支持截面数据采样(按交易日批处理)
- 新增 EnsembleQuantLoss (Huber + IC) 替代 MSE 作为损失函数
- 重构 TabMModel 支持量化场景:Rank IC 作为验证指标、CosineAnnealingLR学习率调度、梯度裁剪
- 支持 date_col 参数和特征对齐
- 更新实验配置 batch_size 2048 和 weight_decay 等超参数
2026-04-01 00:20:05 +08:00

99 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""EnsembleQuantLoss 单元测试"""
import numpy as np
import pytest
import torch
import torch.nn as nn
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
class TestEnsembleQuantLoss:
    """Unit tests for EnsembleQuantLoss."""

    def setup_method(self):
        """Seed torch's RNG before each test.

        Every test below draws its inputs from ``torch.randn``; seeding makes
        the inputs identical across runs so any failure is reproducible.
        """
        torch.manual_seed(0)

    def test_initialization(self):
        """Constructor stores its hyper-parameters and builds a Huber loss."""
        loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16)
        assert loss_fn.alpha == 0.7
        assert loss_fn.ensemble_size == 16
        assert isinstance(loss_fn.huber, nn.HuberLoss)

    def test_output_shape(self):
        """The loss is a 0-dim tensor convertible to a Python float."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
        # Simulated batch: 20 stocks, 4 ensemble members.
        preds = torch.randn(20, 4)
        target = torch.randn(20)
        loss = loss_fn(preds, target)
        # The output must be a scalar.
        assert loss.shape == torch.Size([])
        assert isinstance(loss.item(), float)

    def test_small_batch_fallback(self):
        """Small batches (Huber fallback path) still yield a valid loss."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
        # Batch with fewer than 10 stocks.
        preds = torch.randn(5, 4)
        target = torch.randn(5)
        loss = loss_fn(preds, target)
        # The loss must be a finite positive number (rejects both NaN and inf).
        assert torch.isfinite(loss)
        assert loss.item() > 0

    def test_huber_component(self):
        """With alpha=1.0 the loss equals the mean per-member Huber loss."""
        n_members = 4
        loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=n_members)  # pure Huber
        preds = torch.randn(50, n_members)
        target = torch.randn(50)
        loss = loss_fn(preds, target)
        # Reference: Huber loss of each ensemble member against the target,
        # averaged over members. Reuse the module's own delta so the reference
        # does not silently assume torch's default delta=1.0.
        huber = nn.HuberLoss(reduction="mean", delta=loss_fn.huber.delta)
        expected_loss = torch.stack(
            [huber(preds[:, i], target) for i in range(n_members)]
        ).mean()
        assert torch.allclose(loss, expected_loss, rtol=1e-5)

    def test_ic_component(self):
        """With alpha=0.0 (pure IC) perfectly correlated predictions give ~0 loss."""
        loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1)  # pure IC
        # Predictions identical to the target: perfect correlation.
        target = torch.randn(50)
        preds = target.unsqueeze(1)
        loss = loss_fn(preds, target)
        # Perfect correlation means IC = 1, so the IC loss should be 0. The std
        # computation and floating-point precision may keep it from being
        # exactly 0, hence the tolerance.
        assert loss.item() < 0.1

    def test_gradient_flow(self):
        """Gradients propagate back to the predictions and are usable."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
        preds = torch.randn(50, 4, requires_grad=True)
        target = torch.randn(50)
        loss = loss_fn(preds, target)
        loss.backward()
        # Gradients must exist, be finite, and not be all zero. The explicit
        # finiteness check matters: `grad == 0` is False for NaN, so a NaN
        # gradient would otherwise pass the non-zero check.
        assert preds.grad is not None
        assert torch.isfinite(preds.grad).all()
        assert not torch.all(preds.grad == 0)
if __name__ == "__main__":
    # Propagate pytest's exit code so `python test_ensemble_quant_loss.py`
    # exits non-zero when any test fails (pytest.main only returns the code).
    raise SystemExit(pytest.main([__file__, "-v"]))