Skip to content

Commit 1a2f51b

Browse files
committed
Fix autopara GPU test
1 parent 0c8222e commit 1a2f51b

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

tests/local_scripts/complete_pytest.tin

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ module load compiler/gnu python python_extras/structure python_extras/quippy lap
77
module load python_extras/wif
88
module load python_extras/torch/cpu
99

10+
if [ ! -z "$WFL_PYTEST_POST_MODULE_COMMANDS" ]; then
11+
echo "Using WFL_PYTEST_POST_MODULE_COMMANDS '$WFL_PYTEST_POST_MODULE_COMMANDS'" 1>&2
12+
eval $WFL_PYTEST_POST_MODULE_COMMANDS
13+
else
14+
echo "Using no WFL_PYTEST_POST_MODULE_COMMANDS" 1>&2
15+
fi
16+
1017
if [ -z "$WFL_PYTEST_EXPYRE_INFO" ]; then
1118
echo "To override partition used, set WFL_PYTEST_EXPYRE_INFO='{\"resources\" : {\"partitions\": \"DESIRED_PARTITION\"}}'" 1>&2
1219
fi
@@ -22,11 +29,13 @@ print(json.dumps(i))
2229
EOF
2330
)
2431
export WFL_PYTEST_EXPYRE_INFO
25-
echo Using WFL_PYTEST_EXPYRE_INFO \'$WFL_PYTEST_EXPYRE_INFO\' 1>&2
32+
echo "Using WFL_PYTEST_EXPYRE_INFO '$WFL_PYTEST_EXPYRE_INFO'" 1>&2
2633

2734
if [ ! -z $WFL_PYTHONPATH_EXTRA ]; then
2835
echo "Adding WFL_PYTHONPATH_EXTRA '$WFL_PYTHONPATH_EXTRA'" 1>&2
2936
export PYTHONPATH=${WFL_PYTHONPATH_EXTRA}:${PYTHONPATH}
37+
else
38+
echo "Adding no WFL_PYTHONPATH_EXTRA" 1>&2
3039
fi
3140

3241
export JULIA_PROJECT=${PWD}/tests/assets/julia
@@ -51,6 +60,8 @@ export PYTEST_VASP_POTCAR_DIR=$VASP_PATH/pot/rev_54/PBE
5160
module load dft/pwscf
5261
# no ORCA
5362

63+
module list
64+
5465
export OPENBLAS_NUM_THREADS=1
5566
export MKL_NUM_THREADS=1
5667
# required for descriptor calc to not hang

tests/test_autoparallelize.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_pool_speedup():
6666
assert dt_2 / dt_1 < 0.75
6767

6868

69-
@pytest.mark.skipif(torch is None or not torch.cuda.is_available(), reason="No torch CUDA devices available")
69+
@pytest.mark.skipif(torch is None or not torch.cuda.is_available() or os.environ.get("WFL_TORCH_N_GPUS") is None, reason="No torch CUDA devices available, or WFL_TORCH_N_GPUS isn't set")
7070
@pytest.mark.perf
7171
def test_pool_speedup_GPU(monkeypatch):
7272
np.random.seed(5)
@@ -81,20 +81,29 @@ def test_pool_speedup_GPU(monkeypatch):
8181

8282
calc = (mace_mp, ["small-omat-0"], {"device": "cuda"})
8383

84+
req_n_gpus = os.environ["WFL_TORCH_N_GPUS"]
85+
if len(req_n_gpus) == 0:
86+
req_n_gpus = str(len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
87+
88+
if "WFL_TORCH_N_GPUS" in os.environ:
89+
monkeypatch.delenv("WFL_TORCH_N_GPUS")
90+
8491
t0 = time.time()
8592
co = generic.calculate(ConfigSet(ats), OutputSpec(), calc, output_prefix="_auto_",
8693
autopara_info=AutoparaInfo(num_python_subprocesses=1,
8794
num_inputs_per_python_subprocess=30))
8895
dt_1 = time.time() - t0
8996

97+
monkeypatch.setenv("WFL_TORCH_N_GPUS", req_n_gpus)
98+
9099
t0 = time.time()
91-
monkeypatch.setenv("WFL_TORCH_N_GPUS", str(len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))))
92100
co = generic.calculate(ConfigSet(ats), OutputSpec(), calc, output_prefix="_auto_",
93101
autopara_info=AutoparaInfo(num_python_subprocesses=2,
94102
num_inputs_per_python_subprocess=30))
95-
monkeypatch.delenv("WFL_TORCH_N_GPUS")
96103
dt_2 = time.time() - t0
97104

105+
monkeypatch.delenv("WFL_TORCH_N_GPUS")
106+
98107
print("time ratio", dt_2 / dt_1)
99108
assert dt_2 / dt_1 < 0.75
100109

wfl/autoparallelize/pool.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,19 @@
1717
# pass
1818

1919
# https://docs.pytorch.org/docs/stable/notes/multiprocessing.html#poison-fork-in-multiprocessing
20+
# But, only use forkserver if needed because it has a lot of overhead
21+
try:
22+
import torch
23+
except:
24+
torch = None
2025
if os.environ.get("WFL_TORCH_N_GPUS") is not None:
26+
if not torch:
27+
        raise RuntimeError(f"Got WFL_TORCH_N_GPUS '{os.environ['WFL_TORCH_N_GPUS']}' but torch module is not available")
2128
try:
22-
import torch
2329
import multiprocessing
2430
multiprocessing.set_start_method('forkserver')
25-
except (ImportError, RuntimeError) as exc:
31+
except RuntimeError as exc:
32+
        # ignore complaints about setting start method more than once
2633
pass
2734
from multiprocessing.pool import Pool
2835
# to help keep track of distinct GPU for each python subprocess

0 commit comments

Comments
 (0)