diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c539f8a8..5f8a4aae 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,22 +23,22 @@ jobs: matrix: include: # Oldest supported versions - - name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.7) + - name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11) os: ubuntu-18.04 cuda-version: "10.2.89" - gcc-version: "9.4.*" + gcc-version: "8.5.*" nvcc-version: "10.2" python-version: "3.7" - pytorch-version: "1.7.*" + pytorch-version: "1.11.*" # Latest supported versions - - name: Linux (CUDA 11.2, Python 3.9, PyTorch 1.10) + - name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.11) os: ubuntu-18.04 cuda-version: "11.2.2" - gcc-version: "11.2.*" + gcc-version: "10.3.*" nvcc-version: "11.2" - python-version: "3.9" - pytorch-version: "1.10.*" + python-version: "3.10" + pytorch-version: "1.11.*" - name: MacOS (Python 3.9, PyTorch 1.9) os: macos-11 diff --git a/tests/central.pt b/tests/central.pt index b190c8f4..f0fbcf3a 100644 Binary files a/tests/central.pt and b/tests/central.pt differ diff --git a/tests/forces.pt b/tests/forces.pt index 2a6866b4..325f4e6b 100644 Binary files a/tests/forces.pt and b/tests/forces.pt differ diff --git a/tests/generate.py b/tests/generate.py new file mode 100644 index 00000000..43e698b2 --- /dev/null +++ b/tests/generate.py @@ -0,0 +1,24 @@ +import torch as pt + +class Central(pt.nn.Module): + def forward(self, pos): + return pos.pow(2).sum() + +class Forces(pt.nn.Module): + def forward(self, pos): + return pos.pow(2).sum(), -2 * pos + +class Global(pt.nn.Module): + def forward(self, pos, k): + return k * pos.pow(2).sum() + +class Periodic(pt.nn.Module): + def forward(self, pos, box): + box = box.diagonal().unsqueeze(0) + pos = pos - (pos / box).floor() * box + return pos.pow(2).sum() + +pt.jit.script(Central()).save('central.pt') +pt.jit.script(Forces()).save('forces.pt') +pt.jit.script(Global()).save('global.pt') +pt.jit.script(Periodic()).save('periodic.pt') \ No newline at end of file diff --git a/tests/global.pt b/tests/global.pt index 3ed9d571..8cc15e63 100644 Binary files a/tests/global.pt and b/tests/global.pt differ diff --git a/tests/periodic.pt b/tests/periodic.pt index 4adc6cac..318ee5af 100644 Binary files a/tests/periodic.pt and b/tests/periodic.pt differ