From 2270256fed21c5ed045cc21b560931b171f5c863 Mon Sep 17 00:00:00 2001
From: Raul <raulppelaez@gmail.com>
Date: Wed, 20 Sep 2023 15:46:14 +0200
Subject: [PATCH] Fix automatic CUDA graphing not working when requiring
 backwards (#120)

* Pass posTensor as the inputs argument to energyTensor.backward().
This instructs PyTorch to compute gradients only with respect to the positions.

* Add comments
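
For reference, the call pattern looks like this in a minimal standalone
libtorch sketch (not taken from this patch; the tensor names and shapes are
illustrative stand-ins for the plugin's posTensor and model parameters):

    #include <torch/torch.h>

    int main() {
        torch::Tensor pos = torch::randn({10, 3}, torch::requires_grad());
        torch::Tensor params = torch::randn({3}, torch::requires_grad());
        torch::Tensor energy = (pos * params).sum();

        // An undefined tensor as the gradient argument means an implicit
        // gradient of 1 for the scalar energy; passing pos as the inputs
        // argument makes autograd accumulate gradients only into pos.grad(),
        // leaving params.grad() untouched.
        torch::Tensor none;
        energy.backward(none, /*retain_graph=*/false,
                        /*create_graph=*/false, /*inputs=*/pos);

        torch::Tensor minusForces = pos.grad().clone(); // equals -dE/dpos
        pos.grad().zero_(); // avoid accumulating gradients across calls
        return 0;
    }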
---
 platforms/cuda/src/CudaTorchKernels.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp
index d2d869ae..2cdb1c44 100644
--- a/platforms/cuda/src/CudaTorchKernels.cpp
+++ b/platforms/cuda/src/CudaTorchKernels.cpp
@@ -188,7 +188,11 @@ static void executeGraph(bool outputsForces, bool includeForces, torch::jit::scr
         energyTensor = module.forward(inputs).toTensor();
         // Compute force by backpropagating the PyTorch model
         if (includeForces) {
-            energyTensor.backward();
+            // CUDA graph capture sometimes fails if backward is not explicitly requested w.r.t. the positions.
+            // See https://github.com/openmm/openmm-torch/pull/120/
+            auto none = torch::Tensor();
+            energyTensor.backward(none, false, false, posTensor);
+            // This is minus the forces; the sign is changed later on.
             forceTensor = posTensor.grad().clone();
             // Zero the gradient to avoid accumulating it
             posTensor.grad().zero_();