Skip to content

Commit 6c8d739

Browse files
ajmssctomaarsen
andauthored
Fix bug where SetFitHead not moved to non-cuda devices on init (#518)
Co-authored-by: Tom Aarsen <[email protected]>
1 parent 6d4010d commit 6c8d739

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/setfit/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
self.temperature = temperature
9292
self.eps = eps
9393
self.bias = bias
94-
self._device = device or "cuda" if torch.cuda.is_available() else "cpu"
94+
self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
9595
self.multitarget = multitarget
9696

9797
self.to(self._device)

0 commit comments

Comments
 (0)