Skip to content

Commit b63fc70

Browse files
authored
Support for pytorch 2.0 (#94)
* Update environments to support torch 2.0 * Make latest CUDA version in CI be 11.7, since the CUDA workflow provider does not include 11.8 * Update ci * Update ci * Update ci * Update ci * Add latest torchani, compatible with pytorch 2 * update ci * update ci * Update ci * Address raimis comments * Fix cuda * Remove = in environment.yml
1 parent b27ec97 commit b63fc70

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
gcc: "8.5.*"
3030
nvcc: "10.2"
3131
python: "3.8.*"
32+
torchani: "2.2.*"
3233
pytorch: "1.11.*"
3334

3435
# Older supported versions
@@ -38,33 +39,36 @@ jobs:
3839
gcc: "10.3.*"
3940
nvcc: "11.2"
4041
python: "3.9.*"
42+
torchani: "2.2.*"
4143
pytorch: "1.12.*"
4244

4345
# Latest supported versions (with CUDA)
44-
- name: Linux (CUDA 11.7, Python 3.10, PyTorch 1.13)
46+
- name: Linux (CUDA 11.8, Python 3.10, PyTorch 2.0)
4547
enable_cuda: true
46-
cuda: "11.7.0"
48+
cuda: "11.8.0"
4749
gcc: "10.3.*"
48-
nvcc: "11.7"
50+
nvcc: "11.8"
4951
python: "3.10.*"
50-
pytorch: "1.13.*"
52+
torchani: "2.2.*"
53+
pytorch: "2.0.*"
5154

5255
# Latest supported versions (without CUDA)
53-
- name: Linux (no CUDA, Python 3.10, PyTorch 1.13)
56+
- name: Linux (no CUDA, Python 3.10, PyTorch 2.0)
5457
enable_cuda: false
5558
gcc: "10.3.*"
5659
python: "3.10.*"
57-
pytorch: "1.13.*"
60+
pytorch: "2.0.*"
61+
torchani: "2.2.*"
5862

5963
steps:
6064
- name: Check out
6165
uses: actions/checkout@v2
6266

6367
- name: Install CUDA Toolkit
64-
uses: Jimver/[email protected].8
68+
uses: Jimver/[email protected].10
6569
with:
6670
cuda: ${{ matrix.cuda }}
67-
linux-local-args: '["--toolkit", "--override"]' # Need to install CUDA 10.2
71+
linux-local-args: '["--toolkit", "--override"]'
6872
if: ${{ matrix.enable_cuda }}
6973

7074
- name: Install Miniconda
@@ -79,6 +83,7 @@ jobs:
7983
run: |
8084
sed -i -e "/cudatoolkit/c\ - cudatoolkit ${{ matrix.cuda }}" \
8185
-e "/gxx_linux-64/c\ - gxx_linux-64 ${{ matrix.gcc }}" \
86+
-e "/torchani/c\ - torchani ${{ matrix.torchani }}" \
8287
-e "/nvcc_linux-64/c\ - nvcc_linux-64 ${{ matrix.nvcc }}" \
8388
-e "/python/c\ - python ${{ matrix.python }}" \
8489
-e "/pytorch-gpu/c\ - pytorch-gpu ${{ matrix.pytorch }}" \
@@ -89,6 +94,7 @@ jobs:
8994
run: |
9095
sed -i -e "/cudatoolkit/c\ # - cudatoolkit" \
9196
-e "/gxx_linux-64/c\ - gxx_linux-64 ${{ matrix.gcc }}" \
97+
-e "/torchani/c\ - torchani ${{ matrix.torchani }}" \
9298
-e "/nvcc_linux-64/c\ # - nvcc_linux-64" \
9399
-e "/python/c\ - python ${{ matrix.python }}" \
94100
-e "/pytorch-gpu/c\ - pytorch-cpu ${{ matrix.pytorch }}" \

environment.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ channels:
22
- conda-forge
33
dependencies:
44
- cmake >=3.20
5-
- cudatoolkit 11.2.2
5+
- cudatoolkit 11.8.*
66
- gxx_linux-64 10.3.*
77
- make
88
- mdtraj
9-
- nvcc_linux-64 11.2
10-
- torchani 2.2.2
9+
- nvcc_linux-64 11.8
10+
- torchani 2.2.*
1111
- pytest
1212
- python 3.10.*
13-
- pytorch-gpu 1.12.*
14-
- sysroot_linux-64 2.17
13+
- pytorch-gpu 2.0.*
14+
- sysroot_linux-64 2.17

0 commit comments

Comments
 (0)