Skip to content

Commit 9cb6526

Browse files
connorjwardqbisi
andauthored
OpenMPI forking mode support (#21)
* pytest-plugin: support openmpi * Add parallel marker tests --------- Co-authored-by: qbisi <[email protected]>
1 parent 98d7609 commit 9cb6526

File tree

5 files changed

+115
-49
lines changed

5 files changed

+115
-49
lines changed

.github/workflows/ci.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
schedule:
9+
- cron: '1 5 * * 1'
10+
11+
jobs:
12+
tests:
13+
runs-on: ubuntu-latest
14+
strategy:
15+
fail-fast: false
16+
matrix:
17+
python: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
18+
mpi: ['mpich', 'openmpi']
19+
defaults:
20+
run:
21+
shell: bash -l {0}
22+
env:
23+
PYTEST_MPI_MAX_NPROCS: 3
24+
steps:
25+
- uses: actions/checkout@v4
26+
27+
- name: Install Conda environment with Micromamba
28+
uses: mamba-org/setup-micromamba@v1
29+
with:
30+
environment-file: ".github/etc/test_environment_${{ matrix.mpi }}.yml"
31+
create-args: >-
32+
python=${{ matrix.python }}
33+
34+
- name: Install mpi-pytest
35+
run: pip install --no-deps -e .
36+
37+
- name: Run tests (MPICH)
38+
if: matrix.mpi == 'mpich'
39+
run: |
40+
: # 'forking' mode
41+
pytest -v tests
42+
: # 'non-forking' mode
43+
mpiexec -n 1 pytest -v -m "not parallel or parallel[1]" tests
44+
mpiexec -n 2 pytest -v -m parallel[2] tests
45+
mpiexec -n 3 pytest -v -m parallel[3] tests
46+
47+
- name: Run tests (OpenMPI)
48+
if: matrix.mpi == 'openmpi'
49+
run: |
50+
: # 'forking' mode
51+
pytest -v tests
52+
: # 'non-forking' mode
53+
mpiexec --oversubscribe -n 1 pytest -v -m "not parallel or parallel[1]" tests
54+
mpiexec --oversubscribe -n 2 pytest -v -m parallel[2] tests
55+
mpiexec --oversubscribe -n 3 pytest -v -m parallel[3] tests

.github/workflows/ci_pipeline.yml

Lines changed: 0 additions & 44 deletions
This file was deleted.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
55
[project]
66
name = "mpi-pytest"
77
# <year.month.patch>
8-
version = "2025.4.0"
8+
version = "2025.5.0.dev0"
99
dependencies = ["mpi4py", "pytest"]
1010
authors = [
1111
{ name="Connor Ward", email="[email protected]" },
@@ -14,7 +14,7 @@ authors = [
1414
description = "A pytest plugin for executing tests in parallel with MPI"
1515
readme = "README.md"
1616
license = { file = "LICENSE" }
17-
requires-python = ">=3.7"
17+
requires-python = ">=3.8"
1818
classifiers = [
1919
"Programming Language :: Python :: 3",
2020
"Framework :: Pytest",

pytest_mpi/plugin.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import enum
23
import numbers
34
import os
45
import subprocess
@@ -230,9 +231,14 @@ def _set_parallel_callback(item):
230231
"--disable-warnings", "--show-capture=no"
231232
]
232233

233-
cmd = [
234-
"mpiexec", "-n", "1", "-genv", CHILD_PROCESS_FLAG, "1", *executable
235-
] + pytest_args + [
234+
impl = detect_mpi_implementation()
235+
if impl == MPIImplementation.OPENMPI:
236+
cmd = ["mpiexec", "-n", "1", "-x", f"{CHILD_PROCESS_FLAG}=1", *executable]
237+
else:
238+
assert impl == MPIImplementation.MPICH
239+
cmd = ["mpiexec", "-n", "1", "-genv", CHILD_PROCESS_FLAG, "1", *executable]
240+
241+
cmd += pytest_args + [
236242
":", "-n", f"{nprocs-1}", *executable
237243
] + quieter_pytest_args
238244

@@ -286,3 +292,35 @@ def _parse_marker_nprocs(marker):
286292

287293
def _as_tuple(arg):
288294
return tuple(arg) if isinstance(arg, collections.abc.Iterable) else (arg,)
295+
296+
297+
class MPIImplementation(enum.Enum):
298+
OPENMPI = enum.auto()
299+
MPICH = enum.auto()
300+
301+
302+
def detect_mpi_implementation() -> MPIImplementation:
303+
try:
304+
result = subprocess.run(
305+
["mpiexec", "--version"],
306+
stdout=subprocess.PIPE,
307+
stderr=subprocess.STDOUT,
308+
text=True,
309+
check=True
310+
)
311+
except FileNotFoundError:
312+
raise FileNotFoundError(
313+
"'mpiexec' not found on your PATH, please run in non-forking mode "
314+
"where you can specify a different MPI executable"
315+
)
316+
317+
output = result.stdout.lower()
318+
if "open mpi" in output or "open-rte" in output:
319+
return MPIImplementation.OPENMPI
320+
elif "mpich" in output:
321+
return MPIImplementation.MPICH
322+
else:
323+
raise RuntimeError(
324+
"MPI distribution is not recognised, please run in non-forking "
325+
"mode where you can specify your MPI executable"
326+
)

tests/test_parallel_marker.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
from mpi4py import MPI
3+
4+
5+
@pytest.mark.parallel
6+
def test_parallel_marker_no_args():
7+
assert MPI.COMM_WORLD.size == 3
8+
9+
10+
@pytest.mark.parallel(2)
11+
def test_parallel_marker_with_int():
12+
assert MPI.COMM_WORLD.size == 2
13+
14+
15+
@pytest.mark.parallel([2, 3])
16+
def test_parallel_marker_with_list():
17+
assert MPI.COMM_WORLD.size in {2, 3}

0 commit comments

Comments
 (0)