Skip to content

Commit

Permalink
✅ [Pass] Train, Model, Loss Test
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed May 29, 2024
1 parent 73207bd commit 033231b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions tests/test_utils/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@ def loss_function(cfg) -> YOLOLoss:
def data():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = torch.zeros(1, 20, 5, device=device)
predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
predicts = [torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]]
return predicts, targets


def test_yolo_loss(loss_function, data):
predicts, targets = data
loss, (loss_iou, loss_dfl, loss_cls) = loss_function(predicts, targets)
assert torch.isnan(loss)
loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
assert torch.isnan(loss_iou)
assert torch.isnan(loss_dfl)
assert torch.isinf(loss_cls)
2 changes: 1 addition & 1 deletion yolo/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self, cfg: Config) -> None:
self.strides = cfg.model.anchor.strides
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
self.scale_up = torch.tensor(self.image_size * 2, device=device)

self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
Expand Down

0 comments on commit 033231b

Please sign in to comment.