"""EnsembleQuantLoss 单元测试""" import numpy as np import pytest import torch import torch.nn as nn from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss class TestEnsembleQuantLoss: """EnsembleQuantLoss 单元测试""" def test_initialization(self): """测试初始化""" loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16) assert loss_fn.alpha == 0.7 assert loss_fn.ensemble_size == 16 assert isinstance(loss_fn.huber, nn.HuberLoss) def test_output_shape(self): """测试输出形状和类型""" loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) # 创建模拟数据: 20只股票, 4个集成成员 preds = torch.randn(20, 4) target = torch.randn(20) loss = loss_fn(preds, target) # 验证输出是标量 assert loss.shape == torch.Size([]) assert isinstance(loss.item(), float) def test_small_batch_fallback(self): """测试小批次回退到 Huber""" loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) # 少于10只股票的批次 preds = torch.randn(5, 4) target = torch.randn(5) loss = loss_fn(preds, target) # 应该正常返回loss assert not torch.isnan(loss) assert loss.item() > 0 def test_huber_component(self): """测试 Huber 损失组件""" loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=4) # 纯 Huber preds = torch.randn(50, 4) target = torch.randn(50) loss = loss_fn(preds, target) # 手动计算期望的 Huber 损失 huber = nn.HuberLoss(reduction="mean") expected_loss = 0 for i in range(4): expected_loss += huber(preds[:, i], target) expected_loss /= 4 assert torch.allclose(loss, expected_loss, rtol=1e-5) def test_ic_component(self): """测试 IC 损失组件""" loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1) # 纯 IC # 创建完全相关的预测和目标 target = torch.randn(50) preds = target.unsqueeze(1) # 完美相关 loss = loss_fn(preds, target) # 完美相关时 IC=1,所以 IC loss = 0 # 但由于 std 计算和数值精度,可能不完全为0 assert loss.item() < 0.1 def test_gradient_flow(self): """测试梯度可以正常回传""" loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) preds = torch.randn(50, 4, requires_grad=True) target = torch.randn(50) loss = loss_fn(preds, target) loss.backward() # 验证梯度存在且非零 assert preds.grad is not None assert not torch.all(preds.grad == 0) if __name__ == "__main__": pytest.main([__file__, "-v"])