From b8f3a5ca977fdce575797767b904d16086561a3b Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Wed, 18 Sep 2024 16:30:37 +0100 Subject: [PATCH] fix: unify the torch.Tensor.cuda frontends --- ivy/functional/frontends/torch/tensor.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 3a28b5727c6b..63c08f44f76d 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -126,11 +126,6 @@ def itemsize(self): # Setters # # --------# - @device.setter - def cuda(self, device=None): - self.device = device - return self - @ivy_array.setter def ivy_array(self, array): self._ivy_array = array if isinstance(array, ivy.Array) else ivy.array(array) @@ -728,8 +723,8 @@ def detach_(self): def cpu(self): return ivy.to_device(self.ivy_array, "cpu") - def cuda(self): - return ivy.to_device(self.ivy_array, "gpu:0") + def cuda(self, device=None, non_blocking=False, memory_format=None): + return self.to("cuda" if device is None else device) @with_unsupported_dtypes({"2.2 and below": ("uint16",)}, "torch") @numpy_to_torch_style_args