From 9eb78666207843272e07abd0f27c5e1b630230fa Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 10:49:56 +0100 Subject: [PATCH 01/11] parallel_assert now takes boolean argument and added basic tests --- .github/workflows/ci_pipeline.yml | 36 +++++++++++++++++ etc/test_environment_mpich.yml | 8 ++++ pytest_mpi/parallel_assert.py | 29 ++++++++++---- tests/test_parallel_assert.py | 64 +++++++++++++++++++++++++++++++ 4 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/ci_pipeline.yml create mode 100644 etc/test_environment_mpich.yml create mode 100644 tests/test_parallel_assert.py diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml new file mode 100644 index 0000000..64b8117 --- /dev/null +++ b/.github/workflows/ci_pipeline.yml @@ -0,0 +1,36 @@ +--- + +name: CI pipeline for mpi-pytest + +on: + push: + pull_request: + schedule: + - cron: '1 5 * * 1' + +jobs: + + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python: ['3.9', '3.10', '3.11', '3.12', '3.13'] + defaults: + run: + shell: bash -l {0} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Conda environment with Micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: "etc/test_environment_mpich.yml" + create-args: >- + python=${{ matrix.python }} + - name: Install mpi-pytest as a package in the current environment + run: | + pip install --no-deps -e . + - name: Run tests + run: | + pytest --continue-on-collection-errors -v tests diff --git a/etc/test_environment_mpich.yml b/etc/test_environment_mpich.yml new file mode 100644 index 0000000..60eb0b5 --- /dev/null +++ b/etc/test_environment_mpich.yml @@ -0,0 +1,8 @@ +--- + +name: MPICH_test_env +channels: + - conda-forge +dependencies: + - pytest + - mpich diff --git a/pytest_mpi/parallel_assert.py b/pytest_mpi/parallel_assert.py index b89339c..9952496 100644 --- a/pytest_mpi/parallel_assert.py +++ b/pytest_mpi/parallel_assert.py @@ -1,16 +1,16 @@ -from collections.abc import Callable +import warnings from mpi4py import MPI -def parallel_assert(assertion: Callable, participating: bool = True, msg: str = "") -> None: +def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") -> None: """Make an assertion across ``COMM_WORLD``. Parameters ---------- assertion : - Callable that will be tested for truthyness (usually evaluates some assertion). - This should be the same across all ranks. + If this is `False` on any participating task, an `AssertionError` will + be raised. participating : Whether the given rank should evaluate the assertion. msg : @@ -19,24 +19,37 @@ def parallel_assert(assertion: Callable, participating: bool = True, msg: str = Notes ----- It is very important that ``parallel_assert`` is called collectively on all - ranks simulataneously. + ranks simultaneously. + This function allows passing a callable instead of a boolean for the + `assertion` argument. This is useful in rare circumstances, such as when + the assertion is not defined on all tasks, but is not recommended. + Example ------- Where in serial code one would have previously written: ```python x = f() - assert x < 5 + assert x < 5, "x is too large" ``` Now write: ```python x = f() - parallel_assert(lambda: x < 5) + parallel_assert(x < 5, "x is too large") ``` """ - result = assertion() if participating else True + if participating: + if callable(assertion): + warnings.warn('Passing callables to parallel_assert is no longer recommended.' + 'Please pass booleans instead.') + result = assertion() + else: + result = assertion + else: + result = True + all_results = MPI.COMM_WORLD.allgather(result) if not min(all_results): raise AssertionError( diff --git a/tests/test_parallel_assert.py b/tests/test_parallel_assert.py new file mode 100644 index 0000000..02a6b96 --- /dev/null +++ b/tests/test_parallel_assert.py @@ -0,0 +1,64 @@ +import pytest +from pytest_mpi.parallel_assert import parallel_assert + + +@pytest.mark.parametrize('expression', [True, False]) +def test_parallel_assert_equivalent_to_assert_in_serial(expression): + raised_exception = True + + try: + parallel_assert(expression) + raised_exception = False + except AssertionError: + try: + assert expression + except AssertionError: + pass + + if not raised_exception: + assert expression + + +@pytest.mark.parallel([1, 2, 3]) +def test_parallel_assert_all_tasks(): + from mpi4py import MPI + comm = MPI.COMM_WORLD + expression = comm.rank < comm.size // 2 + raised_exception = False + + try: + parallel_assert(expression) + except AssertionError: + raised_exception = True + + assert raised_exception, f'No exception raised on rank {comm.rank}!' + + +@pytest.mark.parallel([1, 2, 3]) +def test_parallel_assert_participating_tasks_only(): + from mpi4py import MPI + comm = MPI.COMM_WORLD + expression = comm.rank < comm.size // 2 + raised_exception = False + + try: + parallel_assert(expression, participating=expression) + except AssertionError: + raised_exception = True + + assert not raised_exception, f'Exception raised on rank {comm.rank}!' + + +@pytest.mark.parallel([1, 2, 3]) +def test_legacy_parallel_assert(): + from mpi4py import MPI + comm = MPI.COMM_WORLD + expression = comm.rank < comm.size // 2 + raised_exception = False + + try: + parallel_assert(lambda: expression, participating=expression) + except AssertionError: + raised_exception = True + + assert not raised_exception, f'Exception raised on rank {comm.rank}!' From cc3326cb44a21429aa4a29950bf751490a64ada3 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 10:54:25 +0100 Subject: [PATCH 02/11] Moved test environment file --- {etc => .github/etc}/test_environment_mpich.yml | 0 .github/workflows/ci_pipeline.yml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename {etc => .github/etc}/test_environment_mpich.yml (100%) diff --git a/etc/test_environment_mpich.yml b/.github/etc/test_environment_mpich.yml similarity index 100% rename from etc/test_environment_mpich.yml rename to .github/etc/test_environment_mpich.yml diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml index 64b8117..68efd89 100644 --- a/.github/workflows/ci_pipeline.yml +++ b/.github/workflows/ci_pipeline.yml @@ -25,7 +25,7 @@ jobs: - name: Install Conda environment with Micromamba uses: mamba-org/setup-micromamba@v1 with: - environment-file: "etc/test_environment_mpich.yml" + environment-file: ".github/etc/test_environment_mpich.yml" create-args: >- python=${{ matrix.python }} - name: Install mpi-pytest as a package in the current environment From f771aff69918b3a0d64c590966a23fee03b22a0b Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 10:55:45 +0100 Subject: [PATCH 03/11] Added missing mpi4py dependency to test environment --- .github/etc/test_environment_mpich.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/etc/test_environment_mpich.yml b/.github/etc/test_environment_mpich.yml index 60eb0b5..4d4cf10 100644 --- a/.github/etc/test_environment_mpich.yml +++ b/.github/etc/test_environment_mpich.yml @@ -6,3 +6,4 @@ channels: dependencies: - pytest - mpich + - mpi4py From c4edcf81f024039ec75a479f40e213ad8b408df3 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:07:12 +0100 Subject: [PATCH 04/11] Forgot to adapt README to new syntax --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7093049..7995d86 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ from pytest_mpi import parallel_assert @pytest.mark.parallel(2) def test_something(): # this will fail on *all* ranks - parallel_assert(lambda: COMM_WORLD.rank == 0) + parallel_assert(COMM_WORLD.rank == 0) ... ``` From 8c7e0fb36a913b5719556d080cc51d093eedf5e2 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:39:17 +0100 Subject: [PATCH 05/11] Update tests/test_parallel_assert.py Co-authored-by: Connor Ward --- tests/test_parallel_assert.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_parallel_assert.py b/tests/test_parallel_assert.py index 02a6b96..a736eb7 100644 --- a/tests/test_parallel_assert.py +++ b/tests/test_parallel_assert.py @@ -4,19 +4,18 @@ @pytest.mark.parametrize('expression', [True, False]) def test_parallel_assert_equivalent_to_assert_in_serial(expression): - raised_exception = True - try: parallel_assert(expression) - raised_exception = False + parallel_raised_exception = False except AssertionError: - try: - assert expression - except AssertionError: - pass - - if not raised_exception: + parallel_raised_exception = True + try: assert expression + serial_raised_exception = False + except AssertionError: + serial_raised_exception = True + + assert serial_raised_exception == parallel_raised_exception @pytest.mark.parallel([1, 2, 3]) From 4cc1914d1a127dbd8b27b8eac24ac46ad57d1cfa Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:39:44 +0100 Subject: [PATCH 06/11] Update tests/test_parallel_assert.py Co-authored-by: Connor Ward --- tests/test_parallel_assert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_parallel_assert.py b/tests/test_parallel_assert.py index a736eb7..eb61aaa 100644 --- a/tests/test_parallel_assert.py +++ b/tests/test_parallel_assert.py @@ -23,10 +23,10 @@ def test_parallel_assert_all_tasks(): from mpi4py import MPI comm = MPI.COMM_WORLD expression = comm.rank < comm.size // 2 - raised_exception = False try: parallel_assert(expression) + raised_exception = False except AssertionError: raised_exception = True From 3612e7171ba72aaccb0f3c37b4e16562e42f0cf7 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:40:40 +0100 Subject: [PATCH 07/11] Update pytest_mpi/parallel_assert.py Co-authored-by: Connor Ward --- pytest_mpi/parallel_assert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytest_mpi/parallel_assert.py b/pytest_mpi/parallel_assert.py index 9952496..af0495d 100644 --- a/pytest_mpi/parallel_assert.py +++ b/pytest_mpi/parallel_assert.py @@ -9,8 +9,8 @@ def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") Parameters ---------- assertion : - If this is `False` on any participating task, an `AssertionError` will - be raised. + The assertion to check. If this is `False` on any participating task, an + `AssertionError` will be raised. participating : Whether the given rank should evaluate the assertion. msg : From dac946d47ee55e1ac0f703d87f70ea1da8f577a2 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:45:09 +0100 Subject: [PATCH 08/11] Update pytest_mpi/parallel_assert.py Co-authored-by: Connor Ward --- pytest_mpi/parallel_assert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytest_mpi/parallel_assert.py b/pytest_mpi/parallel_assert.py index af0495d..4c2605e 100644 --- a/pytest_mpi/parallel_assert.py +++ b/pytest_mpi/parallel_assert.py @@ -42,8 +42,8 @@ def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") """ if participating: if callable(assertion): - warnings.warn('Passing callables to parallel_assert is no longer recommended.' - 'Please pass booleans instead.') + warnings.warn("Passing callables to parallel_assert is no longer recommended." + "Please pass booleans instead.", FutureWarning) result = assertion() else: result = assertion From a83d0d7b72f9134efbe66e396836722f962a3869 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:49:36 +0100 Subject: [PATCH 09/11] Implemented @connorjward's comments --- pytest_mpi/parallel_assert.py | 7 ++----- tests/test_parallel_assert.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pytest_mpi/parallel_assert.py b/pytest_mpi/parallel_assert.py index af0495d..408556d 100644 --- a/pytest_mpi/parallel_assert.py +++ b/pytest_mpi/parallel_assert.py @@ -3,14 +3,14 @@ from mpi4py import MPI -def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") -> None: +def parallel_assert(assertion: bool, msg: str = "", *, participating: bool = True) -> None: """Make an assertion across ``COMM_WORLD``. Parameters ---------- assertion : The assertion to check. If this is `False` on any participating task, an - `AssertionError` will be raised. + `AssertionError` will be raised. This argument can also be callable. participating : Whether the given rank should evaluate the assertion. msg : @@ -20,9 +20,6 @@ def parallel_assert(assertion: bool, participating: bool = True, msg: str = "") ----- It is very important that ``parallel_assert`` is called collectively on all ranks simultaneously. - This function allows passing a callable instead of a boolean for the - `assertion` argument. This is useful in rare circumstances, such as when - the assertion is not defined on all tasks, but is not recommended. Example diff --git a/tests/test_parallel_assert.py b/tests/test_parallel_assert.py index eb61aaa..a194504 100644 --- a/tests/test_parallel_assert.py +++ b/tests/test_parallel_assert.py @@ -1,4 +1,5 @@ import pytest +from mpi4py import MPI from pytest_mpi.parallel_assert import parallel_assert @@ -20,12 +21,11 @@ def test_parallel_assert_equivalent_to_assert_in_serial(expression): @pytest.mark.parallel([1, 2, 3]) def test_parallel_assert_all_tasks(): - from mpi4py import MPI comm = MPI.COMM_WORLD - expression = comm.rank < comm.size // 2 + expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others try: - parallel_assert(expression) + parallel_assert(expression, 'Failed') raised_exception = False except AssertionError: raised_exception = True @@ -35,13 +35,12 @@ def test_parallel_assert_all_tasks(): @pytest.mark.parallel([1, 2, 3]) def test_parallel_assert_participating_tasks_only(): - from mpi4py import MPI comm = MPI.COMM_WORLD - expression = comm.rank < comm.size // 2 - raised_exception = False + expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others try: parallel_assert(expression, participating=expression) + raised_exception = False except AssertionError: raised_exception = True @@ -50,13 +49,12 @@ def test_parallel_assert_participating_tasks_only(): @pytest.mark.parallel([1, 2, 3]) def test_legacy_parallel_assert(): - from mpi4py import MPI comm = MPI.COMM_WORLD - expression = comm.rank < comm.size // 2 - raised_exception = False + expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others try: parallel_assert(lambda: expression, participating=expression) + raised_exception = False except AssertionError: raised_exception = True From 13553f7cb157b905ebb4f4eedb566d4bf88c93b7 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 12:09:43 +0100 Subject: [PATCH 10/11] Fixed documentation --- pytest_mpi/parallel_assert.py | 12 +++++++----- tests/test_parallel_assert.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytest_mpi/parallel_assert.py b/pytest_mpi/parallel_assert.py index 43ff46b..f566854 100644 --- a/pytest_mpi/parallel_assert.py +++ b/pytest_mpi/parallel_assert.py @@ -10,11 +10,12 @@ def parallel_assert(assertion: bool, msg: str = "", *, participating: bool = Tru ---------- assertion : The assertion to check. If this is `False` on any participating task, an - `AssertionError` will be raised. This argument can also be callable. - participating : - Whether the given rank should evaluate the assertion. + `AssertionError` will be raised. This argument can also be a callable + that returns a `bool` (deprecated). msg : Optional error message to print out on failure. + participating : + Whether the given rank should evaluate the assertion. Notes ----- @@ -39,8 +40,9 @@ def parallel_assert(assertion: bool, msg: str = "", *, participating: bool = Tru """ if participating: if callable(assertion): - warnings.warn("Passing callables to parallel_assert is no longer recommended." - "Please pass booleans instead.", FutureWarning) + warnings.warn("Passing callables to parallel_assert is no longer" + "recommended. Please pass booleans instead.", + FutureWarning) result = assertion() else: result = assertion diff --git a/tests/test_parallel_assert.py b/tests/test_parallel_assert.py index a194504..26dad17 100644 --- a/tests/test_parallel_assert.py +++ b/tests/test_parallel_assert.py @@ -51,9 +51,11 @@ def test_parallel_assert_participating_tasks_only(): def test_legacy_parallel_assert(): comm = MPI.COMM_WORLD expression = comm.rank < comm.size // 2 # will be True on some tasks but False on others + if expression: + local_expression = expression # This variable is undefined on non-participating tasks try: - parallel_assert(lambda: expression, participating=expression) + parallel_assert(lambda: local_expression, participating=expression) raised_exception = False except AssertionError: raised_exception = True From cb0ca7fb9acc312de7b78336aba024b64aba9bcf Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 21 Feb 2025 12:15:50 +0100 Subject: [PATCH 11/11] Added CI badge to README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 7995d86..82c7c8a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![badge-ga](https://github.com/firedrakeproject/mpi-pytest/actions/workflows/ci_pipeline.yml/badge.svg?branch=master)](https://github.com/firedrakeproject/mpi-pytest/actions/workflows/ci_pipeline.yml) + # mpi-pytest Pytest plugin that lets you run tests in parallel with MPI.