99 lines
2.8 KiB
Python
99 lines
2.8 KiB
Python
|
|
"""EnsembleQuantLoss 单元测试"""
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import pytest
|
|||
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
|
|||
|
|
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestEnsembleQuantLoss:
    """Unit tests for EnsembleQuantLoss."""

    def setup_method(self):
        """Seed torch's global RNG before every test.

        The tests below draw their inputs with ``torch.randn``; seeding makes
        each run use the same data so that a failure can be reproduced.
        """
        torch.manual_seed(0)

    def test_initialization(self):
        """Constructor arguments are exposed as attributes and a Huber module exists."""
        loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16)

        assert loss_fn.alpha == 0.7
        assert loss_fn.ensemble_size == 16
        assert isinstance(loss_fn.huber, nn.HuberLoss)

    def test_output_shape(self):
        """The returned loss is a 0-dim tensor holding a Python float."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)

        # Simulated data: 20 stocks, 4 ensemble members.
        preds = torch.randn(20, 4)
        target = torch.randn(20)

        loss = loss_fn(preds, target)

        # The output must be a scalar.
        assert loss.shape == torch.Size([])
        assert isinstance(loss.item(), float)

    def test_small_batch_fallback(self):
        """A small batch still yields a finite, positive loss (Huber fallback)."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)

        # A batch with fewer than 10 stocks.
        # NOTE(review): the threshold of 10 comes from the original test
        # comment; confirm it against the EnsembleQuantLoss implementation.
        preds = torch.randn(5, 4)
        target = torch.randn(5)

        loss = loss_fn(preds, target)

        # A loss value should be returned normally.
        assert not torch.isnan(loss)
        assert loss.item() > 0

    def test_huber_component(self):
        """With alpha=1.0 the loss equals the mean per-member Huber loss."""
        ensemble_size = 4
        loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=ensemble_size)  # pure Huber

        preds = torch.randn(50, ensemble_size)
        target = torch.randn(50)

        loss = loss_fn(preds, target)

        # Compute the expected Huber loss by hand: average the loss of each
        # ensemble member (one column of ``preds``) against the target.
        huber = nn.HuberLoss(reduction="mean")
        expected_loss = (
            sum(huber(preds[:, i], target) for i in range(ensemble_size))
            / ensemble_size
        )

        assert torch.allclose(loss, expected_loss, rtol=1e-5)

    def test_ic_component(self):
        """With alpha=0.0 a perfectly correlated prediction gives a near-zero loss."""
        loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1)  # pure IC

        # Build predictions that are perfectly correlated with the target.
        target = torch.randn(50)
        preds = target.unsqueeze(1)  # perfect correlation

        loss = loss_fn(preds, target)

        # With perfect correlation IC = 1, so the IC loss should be 0; the
        # std computation and floating-point precision may leave a small
        # residual, hence the loose bound.
        assert loss.item() < 0.1

    def test_gradient_flow(self):
        """Gradients propagate back to the predictions."""
        loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)

        preds = torch.randn(50, 4, requires_grad=True)
        target = torch.randn(50)

        loss = loss_fn(preds, target)
        loss.backward()

        # The gradient must exist and must not be identically zero.
        assert preds.grad is not None
        assert torch.any(preds.grad != 0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and propagate pytest's exit code, so
    # that invoking the module as a script fails when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|