Skip to content

Commit

Permalink
train speed trial 1
Browse files Browse the repository at this point in the history
  • Loading branch information
YousefMetwally committed Jun 11, 2024
1 parent 616300c commit f08c3fb
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tomotwin/modules/training/torchtrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,14 @@ def classification_f1_score(self, test_loader: DataLoader) -> float:
anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
positive_vol = batch["positive"].to(self.device, non_blocking=True)
negative_vol = batch["negative"].to(self.device, non_blocking=True)
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
filenames = batch["filenames"]
with autocast():
# TODO: Probably concat anchor, positive and vol into one batch and run only one forward pass is enough.
anchor_out = self.model.forward(anchor_vol)
positive_out = self.model.forward(positive_vol)
negative_out = self.model.forward(negative_vol)
out = self.model.forward(full_input)
out = torch.split(out, anchor_vol.shape[0], dim=0)
anchor_out = out[0]
positive_out = out[1]
negative_out = out[2]

anchor_out_np = anchor_out.cpu().detach().numpy()
for i, anchor_filename in enumerate(filenames[0]):
Expand Down Expand Up @@ -258,16 +260,14 @@ def run_batch(self, batch: Dict):
anchor_vol = batch["anchor"].to(self.device, non_blocking=True)
positive_vol = batch["positive"].to(self.device, non_blocking=True)
negative_vol = batch["negative"].to(self.device, non_blocking=True)
full_input = torch.cat((anchor_vol,positive_vol,negative_vol), dim=0)
with autocast():
# TODO: Probably concat anchor, positive and vol into one batch and run only on forward pass is enough.
anchor_out = self.model.forward(anchor_vol)
positive_out = self.model.forward(positive_vol)
negative_out = self.model.forward(negative_vol)

out = self.model.forward(full_input)
out = torch.split(out, anchor_vol.shape[0], dim=0)
loss = self.criterion(
anchor_out,
positive_out,
negative_out,
out[0],
out[1],
out[2],
label_anchor=batch["label_anchor"],
label_positive=batch["label_positive"],
label_negative=batch["label_negative"],
Expand Down

0 comments on commit f08c3fb

Please sign in to comment.