From 86e4e1e8a9ac393a2c084e03d8095149690f48a1 Mon Sep 17 00:00:00 2001
From: SarahByrneIntel <135850186+SarahByrneIntel@users.noreply.github.com>
Date: Tue, 23 Jul 2024 10:07:01 +0100
Subject: [PATCH] Dtype fix for backprop matmul (#104)

Co-authored-by: SarahByrneIntel
---
 intel_npu_acceleration_library/nn/autograd.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/intel_npu_acceleration_library/nn/autograd.py b/intel_npu_acceleration_library/nn/autograd.py
index 5f5f5ca..2211343 100644
--- a/intel_npu_acceleration_library/nn/autograd.py
+++ b/intel_npu_acceleration_library/nn/autograd.py
@@ -63,6 +63,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Iterable[Union[torch.Tensor, Non
         dl_dx = run_matmul(grad_output, torch.transpose(w, -1, -2))
 
         dl_dw = run_matmul(
-            torch.transpose(grad_output, -1, -2), torch.transpose(x, -1, -2)
+            torch.transpose(grad_output, -1, -2),
+            torch.transpose(x, -1, -2).to(torch.float16),
         )
         return dl_dx, dl_dw, None