
Commit 189111a

Merge pull request #363 from libAtoms/autopara_GPU
make wfl pool autoparallelization torch GPU ID aware
2 parents 4d89ec5 + 353a22d commit 189111a

15 files changed: +302 −93 lines changed


docs/source/overview.parallelisation.rst

Lines changed: 4 additions & 1 deletion
@@ -68,7 +68,10 @@ created using `multiprocessing.pool.Pool
The number of threads is controlled by an integer, passed in to the
function as an optional ``num_python_subprocesses`` argument, or stored
in the env var ``WFL_NUM_PYTHON_SUBPROCESSES``. The script should be
-started with a normal run of the python executable.
+started with a normal run of the python executable. Setting
+the ``WFL_TORCH_N_GPUS`` env var to the number of GPUs
+causes ``wfl`` to call ``torch.cuda.set_device()`` for each subprocess
+ensuring that it gets a unique GPU from the other subprocesses.


========================================
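
To make the mechanism concrete, here is a minimal, hypothetical sketch of the idea (not wfl's actual implementation): each pool worker reads WFL_TORCH_N_GPUS, takes the next worker index from a shared counter, and pins itself to the corresponding GPU with torch.cuda.set_device(). The helper names _init_worker and current_device are illustrative only.

import multiprocessing
import os

import torch


def _init_worker(counter):
    # Hypothetical pool initializer: pin this subprocess to one GPU,
    # cycling worker indices through the WFL_TORCH_N_GPUS devices.
    n_gpus = int(os.environ["WFL_TORCH_N_GPUS"])
    with counter.get_lock():
        gpu_id = counter.value % n_gpus
        counter.value += 1
    torch.cuda.set_device(gpu_id)


def current_device(_):
    # Each worker reports the CUDA device it ended up pinned to.
    return torch.cuda.current_device()


if __name__ == "__main__":
    # Assumes at least one CUDA device is present.
    os.environ.setdefault("WFL_TORCH_N_GPUS", str(torch.cuda.device_count()))
    ctx = multiprocessing.get_context("spawn")  # spawn is safer than fork with CUDA
    counter = ctx.Value("i", 0)
    with ctx.Pool(2, initializer=_init_worker, initargs=(counter,)) as pool:
        print(pool.map(current_device, range(4)))

In wfl itself no such helper needs to be written: exporting WFL_TORCH_N_GPUS before starting the script, alongside WFL_NUM_PYTHON_SUBPROCESSES, is enough to enable the per-subprocess GPU assignment described above.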

tests/local_scripts/complete_pytest.tin

Lines changed: 28 additions & 3 deletions
@@ -2,17 +2,40 @@

module purge
# module load compiler/gnu python/system python_extras/quippy lapack/mkl
-module load compiler/gnu python python_extras/quippy lapack/mkl
+module load compiler/gnu python python_extras/structure python_extras/quippy lapack/mkl
# for wfl dependencies
module load python_extras/wif
module load python_extras/torch/cpu

+if [ ! -z "$WFL_PYTEST_POST_MODULE_COMMANDS" ]; then
+    echo "Using WFL_PYTEST_POST_MODULE_COMMANDS '$WFL_PYTEST_POST_MODULE_COMMANDS'" 1>&2
+    eval $WFL_PYTEST_POST_MODULE_COMMANDS
+else
+    echo "Using no WFL_PYTEST_POST_MODULE_COMMANDS" 1>&2
+fi
+
if [ -z "$WFL_PYTEST_EXPYRE_INFO" ]; then
    echo "To override partition used, set WFL_PYTEST_EXPYRE_INFO='{\"resources\" : {\"partitions\": \"DESIRED_PARTITION\"}}'" 1>&2
fi
+WFL_PYTEST_EXPYRE_INFO=$(
+cat << EOF | python3
+import json, os
+i = {"pre_cmds": ["module purge",
+                  "module load compiler/gnu lapack/mkl python python_extras/structure python_extras/quippy python_extras/wif dft/vasp dft/pwscf",
+                  "module list"]}
+ienv = json.loads(os.environ.get("WFL_PYTEST_EXPYRE_INFO", "{}"))
+i.update(ienv)
+print(json.dumps(i))
+EOF
+)
+export WFL_PYTEST_EXPYRE_INFO
+echo "Using WFL_PYTEST_EXPYRE_INFO '$WFL_PYTEST_EXPYRE_INFO'" 1>&2

if [ ! -z $WFL_PYTHONPATH_EXTRA ]; then
+    echo "Adding WFL_PYTHONPATH_EXTRA '$WFL_PYTHONPATH_EXTRA'" 1>&2
    export PYTHONPATH=${WFL_PYTHONPATH_EXTRA}:${PYTHONPATH}
+else
+    echo "Adding no WFL_PYTHONPATH_EXTRA" 1>&2
fi

export JULIA_PROJECT=${PWD}/tests/assets/julia
@@ -29,14 +52,16 @@ echo "" >> complete_pytest.tin.out
# buildcell
export WFL_PYTEST_BUILDCELL=$HOME/src/work/AIRSS/airss-0.9.1/src/buildcell/src/buildcell
# VASP
-module load dft/vasp
+module load dft/vasp/serial
export ASE_VASP_COMMAND=vasp.serial
export ASE_VASP_COMMAND_GAMMA=vasp.gamma.serial
export PYTEST_VASP_POTCAR_DIR=$VASP_PATH/pot/rev_54/PBE
# QE
module load dft/pwscf
# no ORCA

+module list
+
export OPENBLAS_NUM_THREADS=1
export MKL_NUM_THREADS=1
# required for descriptor calc to not hang
@@ -70,7 +95,7 @@ l=`egrep '^=.*(passed|failed|skipped|xfailed|error).* in ' complete_pytest.tin.o
echo "summary line $l"
lp=$( echo $l | sed -E -e 's/ in .*//' -e 's/\s*,\s*/\n/g' )

-declare -A expected_n=( ["passed"]="177" ["skipped"]="21" ["warnings"]=823 ["xfailed"]=2 ["xpassed"]=1 )
+declare -A expected_n=( ["passed"]="188" ["skipped"]="26" ["warnings"]=1068 ["xfailed"]=1 )
IFS=$'\n'
t_stat=0
for out in $lp; do

tests/test_autoparallelize.py

Lines changed: 61 additions & 4 deletions
@@ -1,4 +1,5 @@
import pytest
+import os
import time

import numpy as np
@@ -12,6 +13,12 @@
from wfl.calculators import generic
from wfl.autoparallelize import AutoparaInfo

+try:
+    import torch
+    from mace.calculators.foundations_models import mace_mp
+except ImportError:
+    torch = None
+

def test_empty_iterator(tmp_path):
    co = buildcell.buildcell(range(0), OutputSpec(tmp_path / 'dummy.xyz'), buildcell_cmd='dummy', buildcell_input='dummy')
@@ -35,21 +42,71 @@ def test_autopara_info_dict():
def test_pool_speedup():
    np.random.seed(5)

+    rng = np.random.default_rng(5)
    ats = []
    nconf = 60
+    at_prim = Atoms('Al', cell=[1, 1, 1], pbc=[True] * 3)
    for _ in range(nconf):
-        ats.append(Atoms(['Al'] * nconf, scaled_positions=np.random.uniform(size=(nconf, 3)), cell=[10, 10, 10], pbc=[True] * 3))
+        ats.append(at_prim * (4, 4, 4))
+        ats[-1].rattle(rng=rng)

    t0 = time.time()
-    co = generic.calculate(ConfigSet(ats), OutputSpec(), EMT(), output_prefix="_auto_", autopara_info=AutoparaInfo(num_python_subprocesses=1))
+    co = generic.calculate(ConfigSet(ats), OutputSpec(), EMT(), output_prefix="_auto_",
+                           autopara_info=AutoparaInfo(num_python_subprocesses=1,
+                                                      num_inputs_per_python_subprocess=30))
    dt_1 = time.time() - t0

    t0 = time.time()
-    co = generic.calculate(ConfigSet(ats), OutputSpec(), EMT(), output_prefix="_auto_", autopara_info=AutoparaInfo(num_python_subprocesses=2))
+    co = generic.calculate(ConfigSet(ats), OutputSpec(), EMT(), output_prefix="_auto_",
+                           autopara_info=AutoparaInfo(num_python_subprocesses=2,
+                                                      num_inputs_per_python_subprocess=30))
    dt_2 = time.time() - t0

    print("time ratio", dt_2 / dt_1)
-    assert dt_2 < dt_1 * (2/3)
+    assert dt_2 / dt_1 < 0.75
+
+
+@pytest.mark.skipif(torch is None or not torch.cuda.is_available() or os.environ.get("WFL_TORCH_N_GPUS") is None, reason="No torch CUDA devices available, or WFL_TORCH_N_GPUS isn't set")
+@pytest.mark.perf
+def test_pool_speedup_GPU(monkeypatch):
+    np.random.seed(5)
+
+    rng = np.random.default_rng(5)
+    ats = []
+    nconf = 60
+    at_prim = Atoms('Al', cell=[1, 1, 1], pbc=[True] * 3)
+    for _ in range(nconf):
+        ats.append(at_prim * (5, 5, 5))
+        ats[-1].rattle(rng=rng)
+
+    calc = (mace_mp, ["small-omat-0"], {"device": "cuda"})
+
+    req_n_gpus = os.environ["WFL_TORCH_N_GPUS"]
+    if len(req_n_gpus) == 0:
+        req_n_gpus = str(len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+
+    if "WFL_TORCH_N_GPUS" in os.environ:
+        monkeypatch.delenv("WFL_TORCH_N_GPUS")
+
+    t0 = time.time()
+    co = generic.calculate(ConfigSet(ats), OutputSpec(), calc, output_prefix="_auto_",
+                           autopara_info=AutoparaInfo(num_python_subprocesses=1,
+                                                      num_inputs_per_python_subprocess=30))
+    dt_1 = time.time() - t0
+
+    monkeypatch.setenv("WFL_TORCH_N_GPUS", req_n_gpus)
+
+    t0 = time.time()
+    co = generic.calculate(ConfigSet(ats), OutputSpec(), calc, output_prefix="_auto_",
+                           autopara_info=AutoparaInfo(num_python_subprocesses=2,
+                                                      num_inputs_per_python_subprocess=30))
+    dt_2 = time.time() - t0
+
+    monkeypatch.delenv("WFL_TORCH_N_GPUS")
+
+    print("time ratio", dt_2 / dt_1)
+    assert dt_2 / dt_1 < 0.75
+

def test_outputspec_overwrite(tmp_path):
    with open(tmp_path / "ats.xyz", "w") as fout:

tests/test_clean_dir.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+from wfl.calculators.utils import clean_rundir
+
+# def clean_rundir(rundir, keep_files, default_keep_files, calculation_succeeded):
+
+all_files = ["a", "aa", "b", "c", "d"]
+default_keep_files = ["a*", "b"]
+actual_default_keep_files = ["a", "aa", "b"]
+
+def create_files(dir):
+    for filename in all_files:
+        with open(dir / filename, "w") as fout:
+            fout.write("content\n")
+
+def check_dir(dir, files):
+    if files is None:
+        # even path doesn't exist
+        assert not dir.is_dir()
+        return
+
+    files = set(files)
+
+    # all expected files are present
+    for file in files:
+        assert (dir / file).is_file()
+    # all present files are expected
+    for file in dir.iterdir():
+        assert file.name in files
+
+def test_clean_rundir(tmp_path):
+    # keep True
+    # keep all files regardless of success
+    for succ, files in [(True, all_files), (False, all_files)]:
+        p = tmp_path / f"True_{succ}"
+        p.mkdir()
+        create_files(p)
+        clean_rundir(p, True, default_keep_files, calculation_succeeded=succ)
+        check_dir(p, files)
+
+    # keep False
+    # succeeded means keep nothing, failed means keep default
+    for succ, files in [(True, None), (False, actual_default_keep_files)]:
+        p = tmp_path / f"False_{succ}"
+        p.mkdir()
+        create_files(p)
+        clean_rundir(p, False, default_keep_files, calculation_succeeded=succ)
+        check_dir(p, files)
+
+    # keep subset of default
+    # succeeded means keep subset, failed means keep default
+    for succ, files in [(True, ["a"]), (False, actual_default_keep_files)]:
+        p = tmp_path / f"a_{succ}"
+        p.mkdir()
+        create_files(p)
+        clean_rundir(p, ["a"], default_keep_files, calculation_succeeded=succ)
+        check_dir(p, files)
+
+    # keep different set from default
+    # succeeded means keep set, failed means keep union of default and set
+    for succ, files in [(True, ["a", "c"]), (False, actual_default_keep_files + ["a", "c"])]:
+        p = tmp_path / f"ac_{succ}"
+        p.mkdir()
+        create_files(p)
+        clean_rundir(p, ["a", "c"], default_keep_files, calculation_succeeded=succ)
+        check_dir(p, files)

tests/test_md.py

Lines changed: 8 additions & 8 deletions
@@ -20,9 +20,9 @@
from wfl.generate.md.abort import AbortOnCollision, AbortOnLowEnergy

try:
-    from wif.Langevin_BAOAB import Langevin_BAOAB
+    from ase.md.langevinbaoab import LangevinBAOAB
except ImportError:
-    Langevin_BAOAB = None
+    LangevinBAOAB = None

def select_every_10_steps_for_tests_during(at):
    return at.info.get("MD_step", 1) % 10 == 0
@@ -153,14 +153,14 @@ def test_NPT_Berendsen(cu_slab):
    assert np.allclose(atoms_traj[0].cell, atoms_traj[-1].cell * cell_f)


-@pytest.mark.skipif(Langevin_BAOAB is None, reason="No Langevin_BAOAB available")
-def test_NPT_Langevin_BAOAB(cu_slab):
+@pytest.mark.skipif(LangevinBAOAB is None, reason="No LangevinBAOAB available")
+def test_NPT_LangevinBAOAB(cu_slab):
    calc = EMT()

    inputs = ConfigSet(cu_slab)
    outputs = OutputSpec()

-    atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin_BAOAB", steps=300, dt=1.0,
+    atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="LangevinBAOAB", steps=300, dt=1.0,
                       temperature=500.0, temperature_tau=100/fs, pressure=0.0,
                       rng=np.random.default_rng(1))

@@ -176,14 +176,14 @@ def test_NPT_Langevin_BAOAB(cu_slab):
    assert np.allclose(atoms_traj[0].cell, atoms_traj[-1].cell * cell_f)


-@pytest.mark.skipif(Langevin_BAOAB is None, reason="No Langevin_BAOAB available")
-def test_NPT_Langevin_BAOAB_hydro_F(cu_slab):
+@pytest.mark.skipif(LangevinBAOAB is None, reason="No LangevinBAOAB available")
+def test_NPT_LangevinBAOAB_hydro_F(cu_slab):
    calc = EMT()

    inputs = ConfigSet(cu_slab)
    outputs = OutputSpec()

-    atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="Langevin_BAOAB", steps=300, dt=1.0,
+    atoms_traj = md.md(inputs, outputs, calculator=calc, integrator="LangevinBAOAB", steps=300, dt=1.0,
                       temperature=500.0, temperature_tau=100/fs, pressure=0.0, hydrostatic=False,
                       rng=np.random.default_rng(1))

tests/test_phonopy.py

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,10 @@

from wfl.configset import ConfigSet, OutputSpec
from wfl.generate.phonopy import phonopy
+try:
+    import phono3py
+except ImportError:
+    phono3py = None


def test_phonopy(tmp_path):
@@ -33,6 +37,7 @@ def test_phonopy(tmp_path):
    for v in at.positions[1:]:
        assert min(np.linalg.norm(sc.positions[1:] - v, axis=1)) < 1.0e-7

+@pytest.mark.skipif(phono3py is None, reason="No phono3py module")
def test_phono3py(tmp_path):
    at0 = Atoms(numbers=[29], cell = [[0, 2, 2], [2, 0, 2], [2, 2, 0]], positions = [[0, 0, 0]], pbc = [True]*3)
    at1 = Atoms(numbers=[29], cell = [[0, 1.9, 1.9], [1.9, 0, 1.9], [1.9, 1.9, 0]], positions = [[0, 0, 0]], pbc = [True]*3)
@@ -62,6 +67,7 @@ def test_phono3py(tmp_path):
    assert sum([at.info["config_type"] == "phonon_cubic_1" for at in pert]) == 13*2


+@pytest.mark.skipif(phono3py is None, reason="No phono3py module")
def test_phono3py_same_supercell(tmp_path):
    at0 = Atoms(numbers=[29], cell = [[0, 2, 2], [2, 0, 2], [2, 2, 0]], positions = [[0, 0, 0]], pbc = [True]*3)
    at1 = Atoms(numbers=[29], cell = [[0, 1.9, 1.9], [1.9, 0, 1.9], [1.9, 1.9, 0]], positions = [[0, 0, 0]], pbc = [True]*3)

tests/test_remote_run.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def test_vasp_fail(tmp_path, expyre_systems, monkeypatch, remoteinfo_env):

def do_vasp_fail(tmp_path, sys_name, monkeypatch, remoteinfo_env):
    ri = {'sys_name': sys_name, 'job_name': 'pytest_vasp_'+sys_name,
-          'env_vars' : ['ASE_VASP_COMMAND=NONE', 'ASE_VASP_COMMAND_GAMMA=NONE'],
+          'env_vars' : ['ASE_VASP_COMMAND=NO_VASP_FAIL', 'ASE_VASP_COMMAND_GAMMA=NO_VASP_FAIL'],
          'input_files' : ['POTCARs'],
          'resources': {'max_time': '5m', 'num_nodes': 1},
          'num_inputs_per_queued_job': 1, 'check_interval': 10}

wfl/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.3.4"
+__version__ = "0.3.5"
