Skip to content

Commit

Permalink
Parallel netCDF testing
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Dec 18, 2024
1 parent 5e885cd commit ba92b94
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python -m pip install -e .
python -m pip install \
pytest-timeout pytest-xdist
- name: Gusto tests (new netcdf)
- name: Gusto tests
run: |
. /home/firedrake/firedrake/bin/activate
firedrake-clean
Expand All @@ -51,7 +51,7 @@ jobs:
-o faulthandler_timeout=3660 \
-v unit-tests integration-tests examples
timeout-minutes: 120
- name: Gusto tests (old netcdf)
- name: Test serial netCDF
run: |
. /home/firedrake/firedrake/bin/activate
python -m pip uninstall -y netCDF4
Expand All @@ -66,7 +66,7 @@ jobs:
--timeout=3600 \
--timeout-method=thread \
-o faulthandler_timeout=3660 \
-v unit-tests integration-tests examples
-v integration-tests/models/test_nc_outputting.py
timeout-minutes: 120
- name: Prepare logs
if: always()
Expand Down
35 changes: 29 additions & 6 deletions integration-tests/model/test_nc_outputting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,17 @@
ForwardEuler, OutputParameters, XComponent, YComponent,
ZComponent, MeridionalComponent, ZonalComponent,
RadialComponent, DGUpwind)
from mpi4py import MPI
from netCDF4 import Dataset, chartostring
import pytest
from pytest_mpi import parallel_assert


def make_dirname(test_name):
if MPI.COMM_WORLD.size > 1:
return f'pytest_{test_name}_parallel'
else:
return f'pytest_{test_name}'


@pytest.fixture
Expand Down Expand Up @@ -53,17 +62,19 @@ def domain_and_mesh_details(geometry):
return (domain, mesh_details)


# TODO: make parallel configurations of this test
@pytest.mark.parallel([1, 2])
@pytest.mark.parametrize("geometry", ["interval", "vertical_slice",
"plane", "extruded_plane",
"spherical_shell", "extruded_spherical_shell"])
def test_nc_outputting(tmpdir, geometry, domain_and_mesh_details):
def test_nc_outputting(geometry, domain_and_mesh_details):

# ------------------------------------------------------------------------ #
# Make model objects
# ------------------------------------------------------------------------ #

dirname = str(tmpdir)
# Make sure all ranks use the same file
dirname = make_dirname("nc_outputting")

domain, mesh_details = domain_and_mesh_details
V = domain.spaces('DG')
if geometry == "interval":
Expand Down Expand Up @@ -136,7 +147,15 @@ def test_nc_outputting(tmpdir, geometry, domain_and_mesh_details):
# ------------------------------------------------------------------------ #

# Check that metadata is correct
output_data = Dataset(f'{dirname}/field_output.nc', 'r')
try:
output_data = Dataset(f'results/{dirname}/field_output.nc', 'r', parallel=True)
except ValueError:
# serial netCDF4, do everything on rank 0
if MPI.COMM_WORLD.rank == 0:
output_data = Dataset(f'results/{dirname}/field_output.nc', 'r', parallel=False)
else:
output_data = None

for metadata_key, metadata_value in mesh_details.items():
# Convert None or booleans to strings
if metadata_value is None or isinstance(metadata_value, bool):
Expand All @@ -146,6 +165,10 @@ def test_nc_outputting(tmpdir, geometry, domain_and_mesh_details):

error_message = f'Metadata {metadata_key} for geometry {geometry} is incorrect'
if type(output_value) == float:
assert output_data[metadata_key][0] - output_value < 1e-14, error_message
def assertion():
return output_data[metadata_key][0] - output_value < 1e-14
else:
assert str(chartostring(output_data[metadata_key][0])) == output_value, error_message
def assertion():
return str(chartostring(output_data[metadata_key][0])) == output_value

parallel_assert(assertion, participating=output_data is not None, msg=error_message)

0 comments on commit ba92b94

Please sign in to comment.