Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing to boolean argument in parallel_assert #11

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions .github/etc/test_environment_mpich.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---

name: MPICH_test_env
channels:
- conda-forge
dependencies:
- pytest
- mpich
- mpi4py
36 changes: 36 additions & 0 deletions .github/workflows/ci_pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
---

name: CI pipeline for mpi-pytest

on:
push:
pull_request:
schedule:
- cron: '1 5 * * 1'

jobs:

tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python: ['3.9', '3.10', '3.11', '3.12', '3.13']
defaults:
run:
shell: bash -l {0}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Conda environment with Micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: ".github/etc/test_environment_mpich.yml"
create-args: >-
python=${{ matrix.python }}
- name: Install mpi-pytest as a package in the current environment
run: |
pip install --no-deps -e .
- name: Run tests
run: |
pytest --continue-on-collection-errors -v tests
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ from pytest_mpi import parallel_assert
@pytest.mark.parallel(2)
def test_something():
# this will fail on *all* ranks
parallel_assert(lambda: COMM_WORLD.rank == 0)
parallel_assert(COMM_WORLD.rank == 0)
...
```

Expand Down
26 changes: 18 additions & 8 deletions pytest_mpi/parallel_assert.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from collections.abc import Callable
import warnings

from mpi4py import MPI


def parallel_assert(assertion: Callable, participating: bool = True, msg: str = "") -> None:
def parallel_assert(assertion: bool, msg: str = "", *, participating: bool = True) -> None:
"""Make an assertion across ``COMM_WORLD``.

Parameters
----------
assertion :
Callable that will be tested for truthyness (usually evaluates some assertion).
This should be the same across all ranks.
The assertion to check. If this is `False` on any participating task, an
`AssertionError` will be raised. This argument can also be callable.
participating :
Whether the given rank should evaluate the assertion.
msg :
Expand All @@ -19,24 +19,34 @@ def parallel_assert(assertion: Callable, participating: bool = True, msg: str =
Notes
-----
It is very important that ``parallel_assert`` is called collectively on all
ranks simulataneously.
ranks simultaneously.


Example
-------
Where in serial code one would have previously written:
```python
x = f()
assert x < 5
assert x < 5, "x is too large"
```

Now write:
```python
x = f()
parallel_assert(lambda: x < 5)
parallel_assert(x < 5, "x is too large")
```

"""
result = assertion() if participating else True
if participating:
if callable(assertion):
warnings.warn("Passing callables to parallel_assert is no longer recommended."
"Please pass booleans instead.", FutureWarning)
result = assertion()
else:
result = assertion
else:
result = True

all_results = MPI.COMM_WORLD.allgather(result)
if not min(all_results):
raise AssertionError(
Expand Down
61 changes: 61 additions & 0 deletions tests/test_parallel_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
from mpi4py import MPI
from pytest_mpi.parallel_assert import parallel_assert


@pytest.mark.parametrize('expression', [True, False])
def test_parallel_assert_equivalent_to_assert_in_serial(expression):
try:
parallel_assert(expression)
parallel_raised_exception = False
except AssertionError:
parallel_raised_exception = True
try:
assert expression
serial_raised_exception = False
except AssertionError:
serial_raised_exception = True

assert serial_raised_exception == parallel_raised_exception


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_all_tasks():
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others

try:
parallel_assert(expression, 'Failed')
raised_exception = False
except AssertionError:
raised_exception = True

assert raised_exception, f'No exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_parallel_assert_participating_tasks_only():
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others

try:
parallel_assert(expression, participating=expression)
raised_exception = False
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'


@pytest.mark.parallel([1, 2, 3])
def test_legacy_parallel_assert():
comm = MPI.COMM_WORLD
expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others

try:
parallel_assert(lambda: expression, participating=expression)
raised_exception = False
except AssertionError:
raised_exception = True

assert not raised_exception, f'Exception raised on rank {comm.rank}!'