Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing to boolean argument in parallel_assert #11

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions .github/etc/test_environment_mpich.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---

name: MPICH_test_env
channels:
- conda-forge
dependencies:
- pytest
- mpich
- mpi4py
36 changes: 36 additions & 0 deletions .github/workflows/ci_pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---

name: CI pipeline for mpi-pytest

on:
push:
pull_request:
schedule:
- cron: '1 5 * * 1'

jobs:

tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Conda environment with Micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: ".github/etc/test_environment_mpich.yml"
create-args: >-
python=${{ matrix.python }}
- name: Install mpi-pytest as a package in the current environment
run: |
pip install --no-deps -e .
- name: Run tests
run: |
pytest --continue-on-collection-errors -v tests
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ from pytest_mpi import parallel_assert
@pytest.mark.parallel(2)
def test_something():
# this will fail on *all* ranks
parallel_assert(lambda: COMM_WORLD.rank == 0)
parallel_assert(COMM_WORLD.rank == 0)
...
```

Expand Down
29 changes: 21 additions & 8 deletions pytest_mpi/parallel_assert.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from collections.abc import Callable
import warnings

from mpi4py import MPI


def parallel_assert(assertion: Callable, participating: bool = True, msg: str = "") -> None:
def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") -> None:
"""Make an assertion across ``COMM_WORLD``.

Parameters
----------
assertion :
Callable that will be tested for truthyness (usually evaluates some assertion).
This should be the same across all ranks.
If this is `False` on any participating task, an `AssertionError` will
be raised.
participating :
Whether the given rank should evaluate the assertion.
msg :
Expand All @@ -19,24 +19,37 @@ def parallel_assert(assertion: Callable, participating: bool = True, msg: str =
Notes
-----
It is very important that ``parallel_assert`` is called collectively on all
ranks simulataneously.
ranks simultaneously.
This function allows passing a callable instead of a boolean for the
`assertion` argument. This is useful in rare circumstances, such as when
the assertion is not defined on all tasks, but is not recommended.


Example
-------
Where in serial code one would have previously written:
```python
x = f()
assert x < 5
assert x < 5, "x is too large"
```

Now write:
```python
x = f()
parallel_assert(lambda: x < 5)
parallel_assert(x < 5, "x is too large")
```

"""
result = assertion() if participating else True
if participating:
if callable(assertion):
warnings.warn('Passing callables to parallel_assert is no longer recommended.'
'Please pass booleans instead.')
result = assertion()
else:
result = assertion
else:
result = True

all_results = MPI.COMM_WORLD.allgather(result)
if not min(all_results):
raise AssertionError(
Expand Down
64 changes: 64 additions & 0 deletions tests/test_parallel_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from pytest_mpi.parallel_assert import parallel_assert


@pytest.mark.parametrize('expression', [True, False])
def test_parallel_assert_equivalent_to_assert_in_serial(expression):
raised_exception = True

try:
parallel_assert(expression)
raised_exception = False
except AssertionError:
try:
assert expression
except AssertionError:
pass

if not raised_exception:
assert expression


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_all_tasks():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(expression)
except AssertionError:
raised_exception = True

assert raised_exception, f'No exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_participating_tasks_only():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(expression, participating=expression)
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_legacy_parallel_assert():
from mpi4py import MPI
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2
raised_exception = False

try:
parallel_assert(lambda: expression, participating=expression)
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'