diff --git a/intel_npu_acceleration_library/nn/autograd.py b/intel_npu_acceleration_library/nn/autograd.py index 5f5f5ca..2211343 100644 --- a/intel_npu_acceleration_library/nn/autograd.py +++ b/intel_npu_acceleration_library/nn/autograd.py @@ -63,6 +63,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Iterable[Union[torch.Tensor, Non dl_dx = run_matmul(grad_output, torch.transpose(w, -1, -2)) dl_dw = run_matmul( - torch.transpose(grad_output, -1, -2), torch.transpose(x, -1, -2) + torch.transpose(grad_output, -1, -2), + torch.transpose(x, -1, -2).to(torch.float16), ) return dl_dx, dl_dw, None