From 5dc727951b99c4ed8904facbc54e29162ef57276 Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Fri, 8 Jul 2022 17:02:19 +0200 Subject: [PATCH] Fix interoperability with CustomCVForce (#80) * Add a test with CustomCVForce * Test all the platforms * Add an iteroperability test for TorchANI and NNPOps * Add a missing dependencies * Skip for MacOS * Move imports * Fix import * Retain the primary context * Switch properly the contexts * Set the oldest CUDA to 11.0 * Fix nvcc version * Enable an extra check * Clean up a temporary file * Add more checks * Add comments * Remove a sync and clean up * Move the primary context activation --- .github/workflows/CI.yml | 7 ++- devtools/conda-envs/build-ubuntu-18.04.yml | 4 +- platforms/cuda/src/CudaTorchKernels.cpp | 43 ++++++++++--- platforms/cuda/src/CudaTorchKernels.h | 6 +- python/tests/TestInteroperability.py | 71 ++++++++++++++++++++++ python/tests/TestTorchForce.py | 20 ++++-- 6 files changed, 132 insertions(+), 19 deletions(-) create mode 100644 python/tests/TestInteroperability.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5f8a4aae..fecf6b02 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,11 +23,12 @@ jobs: matrix: include: # Oldest supported versions - - name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11) + # NOTE: renable CUDA 10.2 when it supported by NNPOps (https://github.com/conda-forge/nnpops-feedstock/pull/8) + - name: Linux (CUDA 11.0, Python 3.7, PyTorch 1.11) os: ubuntu-18.04 - cuda-version: "10.2.89" + cuda-version: "11.0.3" gcc-version: "8.5.*" - nvcc-version: "10.2" + nvcc-version: "11.0" python-version: "3.7" pytorch-version: "1.11.*" diff --git a/devtools/conda-envs/build-ubuntu-18.04.yml b/devtools/conda-envs/build-ubuntu-18.04.yml index a7f608a0..40271720 100644 --- a/devtools/conda-envs/build-ubuntu-18.04.yml +++ b/devtools/conda-envs/build-ubuntu-18.04.yml @@ -6,6 +6,7 @@ dependencies: - cudatoolkit @CUDATOOLKIT_VERSION@ - gxx_linux-64 @GCC_VERSION@ - make + - nnpops - nvcc_linux-64 @NVCC_VERSION@ - ocl-icd - openmm >=7.7 @@ -15,4 +16,5 @@ dependencies: - python - pytorch-gpu @PYTORCH_VERSION@ - swig - - sysroot_linux-64 2.17 \ No newline at end of file + - sysroot_linux-64 2.17 + - torchani \ No newline at end of file diff --git a/platforms/cuda/src/CudaTorchKernels.cpp b/platforms/cuda/src/CudaTorchKernels.cpp index 31274997..42e23d2c 100644 --- a/platforms/cuda/src/CudaTorchKernels.cpp +++ b/platforms/cuda/src/CudaTorchKernels.cpp @@ -49,7 +49,14 @@ if (result != CUDA_SUCCESS) { \ throw OpenMMException(m.str());\ } +CudaCalcTorchForceKernel::CudaCalcTorchForceKernel(string name, const Platform& platform, CudaContext& cu) : + CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) { + // Explicitly activate the primary context + CHECK_RESULT(cuDevicePrimaryCtxRetain(&primaryContext, cu.getDevice()), "Failed to retain the primary context"); +} + CudaCalcTorchForceKernel::~CudaCalcTorchForceKernel() { + cuDevicePrimaryCtxRelease(cu.getDevice()); } void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce& force, torch::jit::script::Module& module) { @@ -60,6 +67,11 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce globalNames.push_back(force.getGlobalParameterName(i)); int numParticles = system.getNumParticles(); + // Push the PyTorch context + // NOTE: Pytorch is always using the primary context. + // It makes the primary context current, if it is not a case. + CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); + // Initialize CUDA objects for PyTorch const torch::Device device(torch::kCUDA, cu.getDeviceIndex()); // This implicitly initialize PyTorch module.to(device); @@ -69,8 +81,13 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce posTensor = torch::empty({numParticles, 3}, options.requires_grad(!outputsForces)); boxTensor = torch::empty({3, 3}, options); + // Pop the PyToch context + CUcontext ctx; + CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); + assert(primaryContext == ctx); // Check that PyTorch haven't messed up the context stack + // Initialize CUDA objects for OpenMM-Torch - ContextSelector selector(cu); + ContextSelector selector(cu); // Switch to the OpenMM context map defines; CUmodule program = cu.createModule(CudaTorchKernelSources::torchForce, defines); copyInputsKernel = cu.getKernel(program, "copyInputs"); @@ -80,6 +97,9 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { int numParticles = cu.getNumAtoms(); + // Push to the PyTorch context + CHECK_RESULT(cuCtxPushCurrent(primaryContext), "Failed to push the CUDA context"); + // Get pointers to the atomic positions and simulation box void* posData; void* boxData; @@ -94,11 +114,11 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce // Copy the atomic positions and simulation box to PyTorch tensors { - ContextSelector selector(cu); + ContextSelector selector(cu); // Switch to the OpenMM context void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()}; cu.executeKernel(copyInputsKernel, inputArgs, numParticles); - CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context + CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context } // Prepare the input of the PyTorch model @@ -138,21 +158,30 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce forceTensor = forceTensor.to(torch::kFloat32); forceData = forceTensor.data_ptr(); } - CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the OpenMM context + CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the OpenMM context // Add the computed forces to the total atomic forces { - ContextSelector selector(cu); + ContextSelector selector(cu); // Switch to the OpenMM context int paddedNumAtoms = cu.getPaddedNumAtoms(); int forceSign = (outputsForces ? 1 : -1); void* forceArgs[] = {&forceData, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms, &forceSign}; cu.executeKernel(addForcesKernel, forceArgs, numParticles); - CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context"); // Synchronize before switching to the PyTorch context + CHECK_RESULT(cuCtxSynchronize(), "Failed to synchronize the CUDA context"); // Synchronize before switching to the PyTorch context } // Reset the forces if (!outputsForces) posTensor.grad().zero_(); } - return energyTensor.item(); // This implicitly synchronize the PyTorch context + + // Get energy + const double energy = energyTensor.item(); // This implicitly synchronizes the PyTorch context + + // Pop to the PyTorch context + CUcontext ctx; + CHECK_RESULT(cuCtxPopCurrent(&ctx), "Failed to pop the CUDA context"); + assert(primaryContext == ctx); // Check that the correct context was popped + + return energy; } diff --git a/platforms/cuda/src/CudaTorchKernels.h b/platforms/cuda/src/CudaTorchKernels.h index 62d0cee2..23a6238c 100644 --- a/platforms/cuda/src/CudaTorchKernels.h +++ b/platforms/cuda/src/CudaTorchKernels.h @@ -34,7 +34,6 @@ #include "TorchKernels.h" #include "openmm/cuda/CudaContext.h" -#include "openmm/cuda/CudaArray.h" namespace TorchPlugin { @@ -43,9 +42,7 @@ namespace TorchPlugin { */ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { public: - CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu) : - CalcTorchForceKernel(name, platform), hasInitializedKernel(false), cu(cu) { - } + CudaCalcTorchForceKernel(std::string name, const OpenMM::Platform& platform, OpenMM::CudaContext& cu); ~CudaCalcTorchForceKernel(); /** * Initialize the kernel. @@ -72,6 +69,7 @@ class CudaCalcTorchForceKernel : public CalcTorchForceKernel { std::vector globalNames; bool usePeriodic, outputsForces; CUfunction copyInputsKernel, addForcesKernel; + CUcontext primaryContext; }; } // namespace TorchPlugin diff --git a/python/tests/TestInteroperability.py b/python/tests/TestInteroperability.py new file mode 100644 index 00000000..1c67c8a1 --- /dev/null +++ b/python/tests/TestInteroperability.py @@ -0,0 +1,71 @@ +import openmm as mm +import openmm.unit as unit +import openmmtorch as ot +import platform +import pytest +from tempfile import NamedTemporaryFile +import torch as pt + + +@pytest.mark.skipif(platform.system() == 'Darwin', reason='There is no NNPOps package for MacOS') +@pytest.mark.parametrize('use_cv_force', [True, False]) +@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL']) +def testTorchANI(use_cv_force, platform): + + if pt.cuda.device_count() < 1 and platform == 'CUDA': + pytest.skip('A CUDA device is not available') + + import NNPOps # There is no NNPOps package for MacOS + import torchani + + class Model(pt.nn.Module): + + def __init__(self): + super().__init__() + self.register_buffer('atomic_numbers', pt.tensor([[1, 1]])) + self.model = torchani.models.ANI2x(periodic_table_index=True) + self.model = NNPOps.OptimizedTorchANI(self.model, self.atomic_numbers) + + def forward(self, positions): + positions = positions.float().unsqueeze(0) * 10 # nm --> Ang + return self.model((self.atomic_numbers, positions)).energies[0] * 2625.5 # Hartree --> kJ/mol + + # Create a system + system = mm.System() + for _ in range(2): + system.addParticle(1.0) + positions = pt.tensor([[-5, 0.0, 0.0], [5, 0.0, 0.0]], requires_grad=True) + + with NamedTemporaryFile() as model_file: + + # Save the model + pt.jit.script(Model()).save(model_file.name) + + # Compute reference energy and forces + model = pt.jit.load(model_file) + ref_energy = model(positions) + ref_energy.backward() + ref_forces = positions.grad + + # Create a force + force = ot.TorchForce(model_file.name) + if use_cv_force: + # Wrap TorchForce into CustomCVForce + cv_force = mm.CustomCVForce('force') + cv_force.addCollectiveVariable('force', force) + system.addForce(cv_force) + else: + system.addForce(force) + + # Compute energy and forces + integ = mm.VerletIntegrator(1.0) + platform = mm.Platform.getPlatformByName(platform) + context = mm.Context(system, integ, platform) + context.setPositions(positions.detach().numpy()) + state = context.getState(getEnergy=True, getForces=True) + energy = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole) + forces = state.getForces(asNumpy=True).value_in_unit(unit.kilojoules_per_mole/unit.nanometers) + + # Check energy and forces + assert pt.allclose(ref_energy, pt.tensor(energy, dtype=ref_energy.dtype)) + assert pt.allclose(ref_forces, pt.tensor(forces, dtype=ref_forces.dtype)) \ No newline at end of file diff --git a/python/tests/TestTorchForce.py b/python/tests/TestTorchForce.py index 380ca260..c7e7708e 100644 --- a/python/tests/TestTorchForce.py +++ b/python/tests/TestTorchForce.py @@ -9,13 +9,18 @@ @pytest.mark.parametrize('model_file, output_forces,', [('../../tests/central.pt', False), ('../../tests/forces.pt', True)]) -def testForce(model_file, output_forces): +@pytest.mark.parametrize('use_cv_force', [True, False]) +@pytest.mark.parametrize('platform', ['Reference', 'CPU', 'CUDA', 'OpenCL']) +def testForce(model_file, output_forces, use_cv_force, platform): + + if pt.cuda.device_count() < 1 and platform == 'CUDA': + pytest.skip('A CUDA device is not available') # Create a random cloud of particles. numParticles = 10 system = mm.System() positions = np.random.rand(numParticles, 3) - for i in range(numParticles): + for _ in range(numParticles): system.addParticle(1.0) # Create a force @@ -23,11 +28,18 @@ def testForce(model_file, output_forces): assert not force.getOutputsForces() # Check the default force.setOutputsForces(output_forces) assert force.getOutputsForces() == output_forces - system.addForce(force) + if use_cv_force: + # Wrap TorchForce into CustomCVForce + cv_force = mm.CustomCVForce('force') + cv_force.addCollectiveVariable('force', force) + system.addForce(cv_force) + else: + system.addForce(force) # Compute the forces and energy. integ = mm.VerletIntegrator(1.0) - context = mm.Context(system, integ, mm.Platform.getPlatformByName('Reference')) + platform = mm.Platform.getPlatformByName(platform) + context = mm.Context(system, integ, platform) context.setPositions(positions) state = context.getState(getEnergy=True, getForces=True)