diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index aad4bab3eb1..7246ee74b74 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -146,7 +146,15 @@ def quantize_value(self, x: torch.Tensor | float) -> Tensor: x = x.to(torch.float32) if self.per_channel: q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default - args = (x, self.scale, self.zp, self.axis, self.qmin, self.qmax, self.dtype) + args = ( + x, + torch.tensor(self.scale), + torch.tensor(self.zp), + self.axis, + self.qmin, + self.qmax, + self.dtype, + ) else: q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment] @@ -162,8 +170,8 @@ def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default args = ( qx, - self.scale, - self.zp, + torch.tensor(self.scale), + torch.tensor(self.zp), self.axis, self.qmin, self.qmax,