Skip to content

Commit b94617b

Browse files
Changing to boolean argument in parallel_assert (#11)
* parallel_assert now takes boolean argument and added basic tests * Add some tests * Update README to new syntax * Added CI badge to README --------- Co-authored-by: Connor Ward <[email protected]>
1 parent 48f3092 commit b94617b

File tree

5 files changed

+133
-11
lines changed

5 files changed

+133
-11
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
3+
name: MPICH_test_env
4+
channels:
5+
- conda-forge
6+
dependencies:
7+
- pytest
8+
- mpich
9+
- mpi4py

.github/workflows/ci_pipeline.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
---
2+
3+
name: CI pipeline for mpi-pytest
4+
5+
on:
6+
push:
7+
pull_request:
8+
schedule:
9+
- cron: '1 5 * * 1'
10+
11+
jobs:
12+
13+
tests:
14+
runs-on: ubuntu-latest
15+
strategy:
16+
fail-fast: false
17+
matrix:
18+
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
19+
defaults:
20+
run:
21+
shell: bash -l {0}
22+
steps:
23+
- name: Checkout
24+
uses: actions/checkout@v4
25+
- name: Install Conda environment with Micromamba
26+
uses: mamba-org/setup-micromamba@v1
27+
with:
28+
environment-file: ".github/etc/test_environment_mpich.yml"
29+
create-args: >-
30+
python=${{ matrix.python }}
31+
- name: Install mpi-pytest as a package in the current environment
32+
run: |
33+
pip install --no-deps -e .
34+
- name: Run tests
35+
run: |
36+
pytest --continue-on-collection-errors -v tests

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
[![badge-ga](https://github.com/firedrakeproject/mpi-pytest/actions/workflows/ci_pipeline.yml/badge.svg?branch=master)](https://github.com/firedrakeproject/mpi-pytest/actions/workflows/ci_pipeline.yml)
2+
13
# mpi-pytest
24

35
Pytest plugin that lets you run tests in parallel with MPI.
@@ -129,7 +131,7 @@ from pytest_mpi import parallel_assert
129131
@pytest.mark.parallel(2)
130132
def test_something():
131133
# this will fail on *all* ranks
132-
parallel_assert(lambda: COMM_WORLD.rank == 0)
134+
parallel_assert(COMM_WORLD.rank == 0)
133135
...
134136
```
135137

pytest_mpi/parallel_assert.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,54 @@
1-
from collections.abc import Callable
1+
import warnings
22

33
from mpi4py import MPI
44

55

6-
def parallel_assert(assertion: Callable, participating: bool = True, msg: str = "") -> None:
6+
def parallel_assert(assertion: bool, msg: str = "", *, participating: bool = True) -> None:
77
"""Make an assertion across ``COMM_WORLD``.
88
99
Parameters
1010
----------
1111
assertion :
12-
Callable that will be tested for truthyness (usually evaluates some assertion).
13-
This should be the same across all ranks.
14-
participating :
15-
Whether the given rank should evaluate the assertion.
12+
The assertion to check. If this is `False` on any participating task, an
13+
`AssertionError` will be raised. This argument can also be a callable
14+
that returns a `bool` (deprecated).
1615
msg :
1716
Optional error message to print out on failure.
17+
participating :
18+
Whether the given rank should evaluate the assertion.
1819
1920
Notes
2021
-----
2122
It is very important that ``parallel_assert`` is called collectively on all
22-
ranks simulataneously.
23+
ranks simultaneously.
24+
2325
2426
Example
2527
-------
2628
Where in serial code one would have previously written:
2729
```python
2830
x = f()
29-
assert x < 5
31+
assert x < 5, "x is too large"
3032
```
3133
3234
Now write:
3335
```python
3436
x = f()
35-
parallel_assert(lambda: x < 5)
37+
parallel_assert(x < 5, "x is too large")
3638
```
3739
3840
"""
39-
result = assertion() if participating else True
41+
if participating:
42+
if callable(assertion):
43+
warnings.warn("Passing callables to parallel_assert is no longer"
44+
"recommended. Please pass booleans instead.",
45+
FutureWarning)
46+
result = assertion()
47+
else:
48+
result = assertion
49+
else:
50+
result = True
51+
4052
all_results = MPI.COMM_WORLD.allgather(result)
4153
if not min(all_results):
4254
raise AssertionError(

tests/test_parallel_assert.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import pytest
2+
from mpi4py import MPI
3+
from pytest_mpi.parallel_assert import parallel_assert
4+
5+
6+
@pytest.mark.parametrize('expression', [True, False])
7+
def test_parallel_assert_equivalent_to_assert_in_serial(expression):
8+
try:
9+
parallel_assert(expression)
10+
parallel_raised_exception = False
11+
except AssertionError:
12+
parallel_raised_exception = True
13+
try:
14+
assert expression
15+
serial_raised_exception = False
16+
except AssertionError:
17+
serial_raised_exception = True
18+
19+
assert serial_raised_exception == parallel_raised_exception
20+
21+
22+
@pytest.mark.parallel([1, 2, 3])
23+
def test_parallel_assert_all_tasks():
24+
comm = MPI.COMM_WORLD
25+
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others
26+
27+
try:
28+
parallel_assert(expression, 'Failed')
29+
raised_exception = False
30+
except AssertionError:
31+
raised_exception = True
32+
33+
assert raised_exception, f'No exception raised on rank {comm.rank}!'
34+
35+
36+
@pytest.mark.parallel([1, 2, 3])
37+
def test_parallel_assert_participating_tasks_only():
38+
comm = MPI.COMM_WORLD
39+
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others
40+
41+
try:
42+
parallel_assert(expression, participating=expression)
43+
raised_exception = False
44+
except AssertionError:
45+
raised_exception = True
46+
47+
assert not raised_exception, f'Exception raised on rank {comm.rank}!'
48+
49+
50+
@pytest.mark.parallel([1, 2, 3])
51+
def test_legacy_parallel_assert():
52+
comm = MPI.COMM_WORLD
53+
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others
54+
if expression:
55+
local_expression = expression # This variable is undefined on non-participating tasks
56+
57+
try:
58+
parallel_assert(lambda: local_expression, participating=expression)
59+
raised_exception = False
60+
except AssertionError:
61+
raised_exception = True
62+
63+
assert not raised_exception, f'Exception raised on rank {comm.rank}!'

0 commit comments

Comments
 (0)