Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/etc/test_environment_mpich.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---

name: MPICH_test_env
channels:
- conda-forge
dependencies:
- pytest
- mpich
- mpi4py
36 changes: 36 additions & 0 deletions .github/workflows/ci_pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---

name: CI pipeline for mpi-pytest

on:
push:
pull_request:
schedule:
- cron: '1 5 * * 1'

jobs:

tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Conda environment with Micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: ".github/etc/test_environment_mpich.yml"
create-args: >-
python=${{ matrix.python }}
- name: Install mpi-pytest as a package in the current environment
run: |
pip install --no-deps -e .
- name: Run tests
run: |
pytest --continue-on-collection-errors -v tests
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ from pytest_mpi import parallel_assert
@pytest.mark.parallel(2)
def test_something():
# this will fail on *all* ranks
parallel_assert(lambda: COMM_WORLD.rank == 0)
parallel_assert(COMM_WORLD.rank == 0)
...
```

Expand Down
29 changes: 21 additions & 8 deletions pytest_mpi/parallel_assert.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from collections.abc import Callable
import warnings

from mpi4py import MPI


def parallel_assert(assertion: Callable, participating: bool = True, msg: str = "") -> None:
def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") -> None:
"""Make an assertion across ``COMM_WORLD``.

Parameters
----------
assertion :
Callable that will be tested for truthyness (usually evaluates some assertion).
This should be the same across all ranks.
If this is `False` on any participating task, an `AssertionError` will
be raised.
participating :
Whether the given rank should evaluate the assertion.
msg :
Expand All @@ -19,24 +19,37 @@ def parallel_assert(assertion: Callable, participating: bool = True, msg: str =
Notes
-----
It is very important that ``parallel_assert`` is called collectively on all
ranks simulataneously.
ranks simultaneously.
This function allows passing a callable instead of a boolean for the
`assertion` argument. This is useful in rare circumstances, such as when
the assertion is not defined on all tasks, but is not recommended.


Example
-------
Where in serial code one would have previously written:
```python
x = f()
assert x < 5
assert x < 5, "x is too large"
```

Now write:
```python
x = f()
parallel_assert(lambda: x < 5)
parallel_assert(x < 5, "x is too large")
```

"""
result = assertion() if participating else True
if participating:
if callable(assertion):
warnings.warn('Passing callables to parallel_assert is no longer recommended.'
'Please pass booleans instead.')
result = assertion()
else:
result = assertion
else:
result = True

all_results = MPI.COMM_WORLD.allgather(result)
if not min(all_results):
raise AssertionError(
Expand Down
64 changes: 64 additions & 0 deletions tests/test_parallel_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from pytest_mpi.parallel_assert import parallel_assert


@pytest.mark.parametrize('expression', [True, False])
def test_parallel_assert_equivalent_to_assert_in_serial(expression):
raised_exception = True

try:
parallel_assert(expression)
raised_exception = False
except AssertionError:
try:
assert expression
except AssertionError:
pass

if not raised_exception:
assert expression


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_all_tasks():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(expression)
except AssertionError:
raised_exception = True

assert raised_exception, f'No exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_participating_tasks_only():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(expression, participating=expression)
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_legacy_parallel_assert():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(lambda: expression, participating=expression)
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'