Skip to content

Commit

Permalink
✅ [Pass] loss function test, with new return shape
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed May 28, 2024
1 parent 12dfccf commit da24bd9
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/test_utils/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ def loss_function(cfg) -> YOLOLoss:
@pytest.fixture
def data():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = torch.zeros(20, 6, device=device)
targets = torch.zeros(1, 20, 5, device=device)
predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
return predicts, targets


def test_yolo_loss(loss_function, data):
predicts, targets = data
loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
loss, (loss_iou, loss_dfl, loss_cls) = loss_function(predicts, targets)
assert torch.isnan(loss)
assert torch.isnan(loss_iou)
assert torch.isnan(loss_dfl)
assert torch.isinf(loss_cls)

0 comments on commit da24bd9

Please sign in to comment.