From b83f8828d3ecaf16f45a93f38a53c32a797f496c Mon Sep 17 00:00:00 2001
From: Frank Malatino
Date: Wed, 14 Feb 2024 12:43:02 -0500
Subject: [PATCH 01/12] Adding and updating __init__ files for modules

---
 ndsl/__init__.py                  |  6 ++++++
 ndsl/buffer.py                    |  3 ++-
 ndsl/comm/__init__.py             | 15 +++++++++++++++
 ndsl/dsl/__init__.py              |  5 +++--
 ndsl/dsl/caches/__init__.py       |  1 +
 ndsl/dsl/dace/__init__.py         |  4 ++--
 ndsl/halo/__init__.py             |  6 ++++++
 ndsl/performance/__init__.py      |  7 +++++++
 ndsl/stencils/__init__.py         |  4 ++++
 ndsl/stencils/testing/__init__.py |  1 +
 tests/checkpointer/__init__.py    |  0
 tests/dsl/__init__.py             |  2 ++
 tests/mpi/__init__.py             |  0
 tests/mpi/test_mpi_halo_update.py |  3 ++-
 tests/mpi/test_mpi_mock.py        |  3 ++-
 tests/quantity/__init__.py        |  0
 16 files changed, 53 insertions(+), 7 deletions(-)
 create mode 100644 ndsl/dsl/caches/__init__.py
 create mode 100644 tests/checkpointer/__init__.py
 create mode 100644 tests/dsl/__init__.py
 create mode 100644 tests/mpi/__init__.py
 create mode 100644 tests/quantity/__init__.py

diff --git a/ndsl/__init__.py b/ndsl/__init__.py
index e023b273..89d35b84 100644
--- a/ndsl/__init__.py
+++ b/ndsl/__init__.py
@@ -1 +1,7 @@
+from .constants import ConstantVersions
+from .exceptions import OutOfBoundsError
 from .logging import ndsl_log
+from .optional_imports import RaiseWhenAccessed
+from .types import Allocator, AsyncRequest, NumpyModule
+from .units import UnitsError
+from .utils import MetaEnumStr
diff --git a/ndsl/buffer.py b/ndsl/buffer.py
index 05cd6434..bc829e79 100644
--- a/ndsl/buffer.py
+++ b/ndsl/buffer.py
@@ -4,7 +4,6 @@
 import numpy as np
 from numpy.lib.index_tricks import IndexExpression
 
-from ndsl.performance.timer import NullTimer, Timer
 from ndsl.types import Allocator
 from ndsl.utils import (
     device_synchronize,
@@ -13,6 +12,8 @@
     safe_mpi_allocate,
 )
 
+from .performance.timer import NullTimer, Timer
+
 
 BufferKey = Tuple[Callable, Iterable[int], type]
 BUFFER_CACHE: Dict[BufferKey, List["Buffer"]] = {}
diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py
index e69de29b..92675c01 100644
--- a/ndsl/comm/__init__.py
+++ b/ndsl/comm/__init__.py
@@ -0,0 +1,15 @@
+from .boundary import SimpleBoundary
+from .caching_comm import (
+    CachingCommData,
+    CachingCommReader,
+    CachingCommWriter,
+    CachingRequestReader,
+    CachingRequestWriter,
+    NullRequest,
+)
+from .comm_abc import Comm, Request
+from .communicator import CubedSphereCommunicator, TileCommunicator
+from .local_comm import AsyncResult, ConcurrencyError, LocalComm
+from .mpi import MPIComm
+from .null_comm import NullAsyncResult, NullComm
+from .partitioner import CubedSpherePartitioner, TilePartitioner
diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py
index 5dd32b25..c013460e 100644
--- a/ndsl/dsl/__init__.py
+++ b/ndsl/dsl/__init__.py
@@ -6,12 +6,13 @@
 from .dace.dace_config import DaceConfig, DaCeOrchestration
 from .dace.orchestration import orchestrate, orchestrate_function
 from .stencil import (
-    CompilationConfig,
+    CompareToNumpyStencil,
     FrozenStencil,
     GridIndexing,
-    StencilConfig,
     StencilFactory,
+    TimingCollector,
 )
+from .stencil_config import CompilationConfig, RunMode, StencilConfig
 
 
 if MPI is not None:
diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py
new file mode 100644
index 00000000..3417ff03
--- /dev/null
+++ b/ndsl/dsl/caches/__init__.py
@@ -0,0 +1 @@
+from .codepath import FV3CodePath
\ No newline at end of file
diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py
index bcae0c46..d2c64b06 100644
--- a/ndsl/dsl/dace/__init__.py
+++ b/ndsl/dsl/dace/__init__.py
@@ -1,2 +1,2 @@
-from ndsl.dsl.dace.dace_config import DaceConfig
-from ndsl.dsl.dace.orchestration import orchestrate
+from .dace_config import DaceConfig
+from .orchestration import orchestrate
diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py
index e69de29b..b32a6937 100644
--- a/ndsl/halo/__init__.py
+++ b/ndsl/halo/__init__.py
@@ -0,0 +1,6 @@
+from .data_transformer import (
+    HaloDataTransformerCPU,
+    HaloDataTransformerGPU,
+    HaloExchangeSpec,
+)
+from .updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py
index 28e03bc6..fd79608f 100644
--- a/ndsl/performance/__init__.py
+++ b/ndsl/performance/__init__.py
@@ -1,2 +1,9 @@
+from .collector import (
+    AbstractPerformanceCollector,
+    NullPerformanceCollector,
+    PerformanceCollector,
+)
 from .config import PerformanceConfig
+from .profiler import NullProfiler, Profiler
+from .report import Experiment, Report, TimeReport
 from .timer import NullTimer, Timer
diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py
index d3ec452c..641e032a 100644
--- a/ndsl/stencils/__init__.py
+++ b/ndsl/stencils/__init__.py
@@ -1 +1,5 @@
+from .c2l_ord import CubedToLatLon
+from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid
+
+
 __version__ = "0.2.0"
diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py
index d676e871..d66176c6 100644
--- a/ndsl/stencils/testing/__init__.py
+++ b/ndsl/stencils/testing/__init__.py
@@ -4,6 +4,7 @@
     ParallelTranslate2Py,
     ParallelTranslate2PyState,
     ParallelTranslateBaseSlicing,
+    ParallelTranslateGrid,
 )
 from .savepoint import SavepointCase, Translate, dataset_to_dict
 from .temporaries import assert_same_temporaries, copy_temporaries
diff --git a/tests/checkpointer/__init__.py b/tests/checkpointer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py
new file mode 100644
index 00000000..10487c18
--- /dev/null
+++ b/tests/dsl/__init__.py
@@ -0,0 +1,2 @@
+from .test_stencil_wrapper import MockFieldInfo
+from .test_caches import OrchestratedProgam
\ No newline at end of file
diff --git a/tests/mpi/__init__.py b/tests/mpi/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py
index 4d5133d4..7343bf8d 100644
--- a/tests/mpi/test_mpi_halo_update.py
+++ b/tests/mpi/test_mpi_halo_update.py
@@ -1,7 +1,6 @@
 import copy
 
 import pytest
-from mpi_comm import MPI
 
 from ndsl.comm._boundary_utils import get_boundary_slice
 from ndsl.comm.communicator import CubedSphereCommunicator
@@ -25,6 +24,8 @@
 )
 from ndsl.quantity import Quantity
 
+from .mpi_comm import MPI
+
 
 @pytest.fixture
 def dtype(numpy):
diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py
index c9d3d610..4d4d24a0 100644
--- a/tests/mpi/test_mpi_mock.py
+++ b/tests/mpi/test_mpi_mock.py
@@ -1,10 +1,11 @@
 import numpy as np
 import pytest
-from mpi_comm import MPI
 
 from ndsl.comm.communicator import recv_buffer
 from ndsl.testing import ConcurrencyError, DummyComm
 
+from .mpi_comm import MPI
+
 
 worker_function_list = []
 
diff --git a/tests/quantity/__init__.py b/tests/quantity/__init__.py
new file mode 100644
index 00000000..e69de29b

From 2b1cb3ae3d0a21b972a5adaa8472e5c47d149318 Mon Sep 17 00:00:00 2001
From: Frank Malatino
Date: Mon, 26 Feb 2024 15:45:32 -0500
Subject: [PATCH 02/12] Changes suggested from PR and updates to what is exposed

---
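Note (illustrative only, not part of the patch): with the wider re-exports added to
ndsl/__init__.py in this series, downstream packages such as pyFV3 and pySHiELD can
import public names from the package root instead of from deep submodules. A minimal
sketch of that usage, assuming the exports land as listed in the diffs below; the
1x1 layout and rank count are example values, mirroring tests/test_null_comm.py:

    from ndsl import (
        CubedSphereCommunicator,
        CubedSpherePartitioner,
        NullComm,
        TilePartitioner,
    )

    # One tile split 1x1, six tiles total on the cubed sphere.
    partitioner = CubedSpherePartitioner(TilePartitioner((1, 1)))
    communicator = CubedSphereCommunicator(
        NullComm(rank=0, total_ranks=6), partitioner
    )
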
examples/mpi/global_timings.py | 2 +- examples/mpi/zarr_monitor.py | 11 ++++--- ndsl/__init__.py | 34 ++++++++++++++++++--- ndsl/buffer.py | 3 +- ndsl/checkpointer/__init__.py | 9 ------ ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 3 +- ndsl/checkpointer/thresholds.py | 4 +-- ndsl/checkpointer/validation.py | 10 ++++-- ndsl/comm/__init__.py | 16 ++-------- ndsl/comm/boundary.py | 4 +-- ndsl/dsl/__init__.py | 11 ++----- ndsl/dsl/caches/__init__.py | 2 +- ndsl/dsl/dace/__init__.py | 11 +++++-- ndsl/dsl/stencil.py | 2 +- ndsl/grid/__init__.py | 10 ------ ndsl/grid/generation.py | 25 ++++++++------- ndsl/grid/geometry.py | 5 ++- ndsl/grid/global_setup.py | 7 ++--- ndsl/grid/helper.py | 3 +- ndsl/halo/__init__.py | 8 ++--- ndsl/initialization/__init__.py | 2 +- ndsl/initialization/allocator.py | 8 ++--- ndsl/monitor/__init__.py | 1 - ndsl/monitor/convert.py | 2 +- ndsl/monitor/netcdf_monitor.py | 9 +++--- ndsl/performance/__init__.py | 8 ----- ndsl/performance/collector.py | 3 +- ndsl/performance/config.py | 5 ++- ndsl/stencils/__init__.py | 4 --- ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/__init__.py | 16 ++-------- ndsl/stencils/testing/conftest.py | 3 +- ndsl/stencils/testing/parallel_translate.py | 6 ++-- ndsl/stencils/testing/savepoint.py | 2 +- tests/checkpointer/test_snapshot.py | 2 +- tests/checkpointer/test_thresholds.py | 2 +- tests/checkpointer/test_validation.py | 7 +++-- tests/dsl/__init__.py | 4 +-- tests/dsl/test_caches.py | 9 +++--- tests/dsl/test_compilation_config.py | 10 ++++-- tests/dsl/test_dace_config.py | 10 ++---- tests/dsl/test_skip_passes.py | 6 ++-- tests/dsl/test_stencil.py | 7 +---- tests/dsl/test_stencil_config.py | 3 +- tests/dsl/test_stencil_factory.py | 13 ++++---- tests/dsl/test_stencil_wrapper.py | 17 +++++++---- tests/mpi/test_mpi_halo_update.py | 12 +++++--- tests/mpi/test_mpi_mock.py | 3 +- tests/quantity/test_boundary.py | 2 +- tests/quantity/test_deepcopy.py | 2 +- tests/quantity/test_quantity.py | 2 +- tests/quantity/test_storage.py | 2 +- tests/quantity/test_transpose.py | 2 +- tests/quantity/test_view.py | 2 +- tests/test_caching_comm.py | 16 ++++++---- tests/test_cube_scatter_gather.py | 13 +++++--- tests/test_decomposition.py | 5 ++- tests/test_dimension_sizer.py | 3 +- tests/test_g2g_communication.py | 13 +++++--- tests/test_halo_data_transformer.py | 3 +- tests/test_halo_update.py | 19 +++++++----- tests/test_halo_update_ranks.py | 13 +++++--- tests/test_legacy_restart.py | 11 ++++--- tests/test_local_comm.py | 2 +- tests/test_netcdf_monitor.py | 13 +++++--- tests/test_null_comm.py | 9 ++++-- tests/test_partitioner.py | 4 +-- tests/test_partitioner_boundaries.py | 7 ++--- tests/test_sync_shared_boundary.py | 13 +++++--- tests/test_tile_scatter.py | 5 +-- tests/test_tile_scatter_gather.py | 5 +-- tests/test_timer.py | 2 +- tests/test_zarr_monitor.py | 11 ++++--- 74 files changed, 264 insertions(+), 273 deletions(-) diff --git a/examples/mpi/global_timings.py b/examples/mpi/global_timings.py index 0921acd1..9e3ecdb1 100644 --- a/examples/mpi/global_timings.py +++ b/examples/mpi/global_timings.py @@ -3,7 +3,7 @@ import numpy as np from mpi4py import MPI -from ndsl.performance.timer import Timer +from ndsl import Timer @contextlib.contextmanager diff --git a/examples/mpi/zarr_monitor.py b/examples/mpi/zarr_monitor.py index 20d418f1..0c089af4 100644 --- a/examples/mpi/zarr_monitor.py +++ b/examples/mpi/zarr_monitor.py @@ -5,11 +5,14 @@ import zarr from mpi4py import MPI -from ndsl.comm.partitioner import 
CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSpherePartitioner, + QuantityFactory, + SubtileGridSizer, + TilePartitioner, + ZarrMonitor, +) from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer -from ndsl.monitor import ZarrMonitor OUTPUT_PATH = "output/zarr_monitor.zarr" diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 89d35b84..06a77ac2 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,7 +1,31 @@ -from .constants import ConstantVersions +from .checkpointer import SnapshotCheckpointer +from .comm import ( + CachingCommReader, + CachingCommWriter, + ConcurrencyError, + CubedSphereCommunicator, + CubedSpherePartitioner, + LocalComm, + NullComm, + TileCommunicator, + TilePartitioner, +) +from .dsl import ( + CompareToNumpyStencil, + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + GridIndexing, + RunMode, + StencilConfig, + StencilFactory, +) from .exceptions import OutOfBoundsError +from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater +from .initialization import QuantityFactory, SubtileGridSizer from .logging import ndsl_log -from .optional_imports import RaiseWhenAccessed -from .types import Allocator, AsyncRequest, NumpyModule -from .units import UnitsError -from .utils import MetaEnumStr +from .monitor import NetCDFMonitor, ZarrMonitor +from .performance import NullTimer, Timer +from .quantity import Quantity, QuantityHaloSpec +from .testing import DummyComm diff --git a/ndsl/buffer.py b/ndsl/buffer.py index bc829e79..05cd6434 100644 --- a/ndsl/buffer.py +++ b/ndsl/buffer.py @@ -4,6 +4,7 @@ import numpy as np from numpy.lib.index_tricks import IndexExpression +from ndsl.performance.timer import NullTimer, Timer from ndsl.types import Allocator from ndsl.utils import ( device_synchronize, @@ -12,8 +13,6 @@ safe_mpi_allocate, ) -from .performance.timer import NullTimer, Timer - BufferKey = Tuple[Callable, Iterable[int], type] BUFFER_CACHE: Dict[BufferKey, List["Buffer"]] = {} diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index a51a4d9e..46d32a6c 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,10 +1 @@ -from .base import Checkpointer -from .null import NullCheckpointer from .snapshots import SnapshotCheckpointer -from .thresholds import ( - InsufficientTrialsError, - SavepointThresholds, - Threshold, - ThresholdCalibrationCheckpointer, -) -from .validation import ValidationCheckpointer diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index e707d589..fbc78755 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from .base import Checkpointer +from ndsl.checkpointer.base import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 11b7b89d..aa806b21 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,11 +2,10 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr -from .base import Checkpointer - def make_dims(savepoint_dim, label, data_list): """ diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index 86133a81..ded73b39 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np -from 
..quantity import Quantity -from .base import Checkpointer +from ndsl.checkpointer.base import Checkpointer +from ndsl.quantity import Quantity try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 360c6a1b..8af11317 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,11 +5,15 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer +from ndsl.checkpointer.thresholds import ( + ArrayLike, + SavepointName, + SavepointThresholds, + cast_to_ndarray, +) from ndsl.optional_imports import xarray as xr -from .base import Checkpointer -from .thresholds import ArrayLike, SavepointName, SavepointThresholds, cast_to_ndarray - def _clip_pace_array_to_target( array: np.ndarray, target_shape: Tuple[int, ...] diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 92675c01..0e86fe02 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,15 +1,5 @@ -from .boundary import SimpleBoundary -from .caching_comm import ( - CachingCommData, - CachingCommReader, - CachingCommWriter, - CachingRequestReader, - CachingRequestWriter, - NullRequest, -) -from .comm_abc import Comm, Request +from .caching_comm import CachingCommReader, CachingCommWriter from .communicator import CubedSphereCommunicator, TileCommunicator -from .local_comm import AsyncResult, ConcurrencyError, LocalComm -from .mpi import MPIComm -from .null_comm import NullAsyncResult, NullComm +from .local_comm import ConcurrencyError, LocalComm +from .null_comm import NullComm from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/comm/boundary.py b/ndsl/comm/boundary.py index 540f0256..020798c6 100644 --- a/ndsl/comm/boundary.py +++ b/ndsl/comm/boundary.py @@ -1,8 +1,8 @@ import dataclasses from typing import Tuple -from ..quantity import Quantity, QuantityHaloSpec -from ._boundary_utils import get_boundary_slice +from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.quantity import Quantity, QuantityHaloSpec @dataclasses.dataclass diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index c013460e..1331294e 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -3,15 +3,8 @@ from ndsl.comm.mpi import MPI from . 
import dace -from .dace.dace_config import DaceConfig, DaCeOrchestration -from .dace.orchestration import orchestrate, orchestrate_function -from .stencil import ( - CompareToNumpyStencil, - FrozenStencil, - GridIndexing, - StencilFactory, - TimingCollector, -) +from .dace import DaceConfig, DaCeOrchestration, orchestrate, orchestrate_function +from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory from .stencil_config import CompilationConfig, RunMode, StencilConfig diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py index 3417ff03..4fbb20e9 100644 --- a/ndsl/dsl/caches/__init__.py +++ b/ndsl/dsl/caches/__init__.py @@ -1 +1 @@ -from .codepath import FV3CodePath \ No newline at end of file +from .codepath import FV3CodePath diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index d2c64b06..0f1edcbf 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,2 +1,9 @@ -from .dace_config import DaceConfig -from .orchestration import orchestrate +from .dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG +from .orchestration import ( + _LazyComputepathFunction, + _LazyComputepathMethod, + orchestrate, + orchestrate_function, +) +from .utils import ArrayReport, DaCeProgress, MaxBandwithBenchmarkProgram, StorageReport +from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a7a4f941..77efd672 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -31,7 +31,7 @@ from ndsl.dsl.dace.orchestration import SDFGConvertible from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from ndsl.dsl.typing import Float, Index3D, cast_to_index3d -from ndsl.initialization import GridSizer, SubtileGridSizer +from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index 5e488743..a7692a8f 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,7 +1,6 @@ # flake8: noqa: F401 from .eta import set_hybrid_pressure_coefficients -from .generation import GridDefinitions, MetricTerms from .gnomonic import ( great_circle_distance_along_axis, great_circle_distance_lon_lat, @@ -11,13 +10,4 @@ xyz_midpoint, xyz_to_lon_lat, ) -from .helper import ( - AngleGridData, - ContravariantGridData, - DampingCoefficients, - DriverGridData, - GridData, - HorizontalGridData, - VerticalGridData, -) from .stretch_transformation import direct_transform diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index b38dbf2f..12275d7d 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -21,17 +21,7 @@ from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import Float from ndsl.grid import eta -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer -from ndsl.quantity import Quantity -from ndsl.stencils.corners import ( - fill_corners_2d, - fill_corners_agrid, - fill_corners_cgrid, - fill_corners_dgrid, -) - -from .geometry import ( +from ndsl.grid.geometry import ( calc_unit_vector_south, calc_unit_vector_west, calculate_divg_del6, @@ -47,7 +37,7 @@ supergrid_corner_fix, unit_vector_lonlat, ) -from .gnomonic import ( +from ndsl.grid.gnomonic import ( get_area, great_circle_distance_along_axis, local_gnomonic_ed, @@ -59,7 +49,16 @@ set_tile_border_dxc, set_tile_border_dyc, ) -from .mirror import mirror_grid +from ndsl.grid.mirror import mirror_grid 
+from ndsl.initialization.allocator import QuantityFactory
+from ndsl.initialization.sizer import SubtileGridSizer
+from ndsl.quantity import Quantity
+from ndsl.stencils.corners import (
+    fill_corners_2d,
+    fill_corners_agrid,
+    fill_corners_cgrid,
+    fill_corners_dgrid,
+)
 
 
 # TODO: when every environment in python3.8, remove
diff --git a/ndsl/grid/geometry.py b/ndsl/grid/geometry.py
index 5b2ec028..804be0fe 100644
--- a/ndsl/grid/geometry.py
+++ b/ndsl/grid/geometry.py
@@ -1,7 +1,5 @@
 from ndsl.comm.partitioner import TilePartitioner
-from ndsl.quantity import Quantity
-
-from .gnomonic import (
+from ndsl.grid.gnomonic import (
     get_lonlat_vect,
     get_unit_vector_direction,
     great_circle_distance_lon_lat,
@@ -10,6 +8,7 @@
     spherical_cos,
     xyz_midpoint,
 )
+from ndsl.quantity import Quantity
 
 
 def get_center_vector(
diff --git a/ndsl/grid/global_setup.py b/ndsl/grid/global_setup.py
index 46c0c902..a0237ec6 100644
--- a/ndsl/grid/global_setup.py
+++ b/ndsl/grid/global_setup.py
@@ -1,16 +1,15 @@
 import math
 
 from ndsl.constants import PI, RADIUS
-
-from .generation import MetricTerms
-from .gnomonic import (
+from ndsl.grid.generation import MetricTerms
+from ndsl.grid.gnomonic import (
     _cart_to_latlon,
     _check_shapes,
     _latlon2xyz,
     _mirror_latlon,
     symm_ed,
 )
-from .mirror import _rot_3d
+from ndsl.grid.mirror import _rot_3d
 
 
 def gnomonic_grid(grid_type: int, lon, lat, np):
diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py
index 89a8c0e7..ee97a6b0 100644
--- a/ndsl/grid/helper.py
+++ b/ndsl/grid/helper.py
@@ -13,11 +13,10 @@
 import ndsl.constants as constants
 from ndsl.constants import Z_DIM, Z_INTERFACE_DIM
 from ndsl.filesystem import get_fs
+from ndsl.grid.generation import MetricTerms
 from ndsl.initialization import QuantityFactory
 from ndsl.quantity import Quantity
 
-from .generation import MetricTerms
-
 
 @dataclasses.dataclass(frozen=True)
 class DampingCoefficients:
diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py
index b32a6937..823bd226 100644
--- a/ndsl/halo/__init__.py
+++ b/ndsl/halo/__init__.py
@@ -1,6 +1,2 @@
-from .data_transformer import (
-    HaloDataTransformerCPU,
-    HaloDataTransformerGPU,
-    HaloExchangeSpec,
-)
-from .updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater
+from .data_transformer import HaloDataTransformer, HaloExchangeSpec
+from .updater import HaloUpdater
diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py
index fe15db8b..50fd2f84 100644
--- a/ndsl/initialization/__init__.py
+++ b/ndsl/initialization/__init__.py
@@ -1,2 +1,2 @@
 from .allocator import QuantityFactory
-from .sizer import GridSizer, SubtileGridSizer
+from .sizer import SubtileGridSizer
diff --git a/ndsl/initialization/allocator.py b/ndsl/initialization/allocator.py
index cbbd78de..5320e4c6 100644
--- a/ndsl/initialization/allocator.py
+++ b/ndsl/initialization/allocator.py
@@ -2,10 +2,10 @@
 
 import numpy as np
 
-from ..constants import SPATIAL_DIMS
-from ..optional_imports import gt4py
-from ..quantity import Quantity, QuantityHaloSpec
-from .sizer import GridSizer
+from ndsl.constants import SPATIAL_DIMS
+from ndsl.initialization.sizer import GridSizer
+from ndsl.optional_imports import gt4py
+from ndsl.quantity import Quantity, QuantityHaloSpec
 
 
 class StorageNumpy:
diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py
index a0c7e036..26b38cc6 100644
--- a/ndsl/monitor/__init__.py
+++ b/ndsl/monitor/__init__.py
@@ -1,3 +1,2 @@
 from .netcdf_monitor import NetCDFMonitor
-from .protocol import Monitor
 from .zarr_monitor import ZarrMonitor
diff --git a/ndsl/monitor/convert.py b/ndsl/monitor/convert.py index ad05b27d..a62af01a 100644 --- a/ndsl/monitor/convert.py +++ b/ndsl/monitor/convert.py @@ -1,6 +1,6 @@ import numpy as np -from ..optional_imports import cupy +from ndsl.optional_imports import cupy def to_numpy(array, dtype=None) -> np.ndarray: diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 3d950ae0..8a0b96fd 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -6,12 +6,11 @@ import numpy as np from ndsl.comm.communicator import Communicator +from ndsl.filesystem import get_fs +from ndsl.logging import ndsl_log +from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr - -from ..filesystem import get_fs -from ..logging import ndsl_log -from ..quantity import Quantity -from .convert import to_numpy +from ndsl.quantity import Quantity class _TimeChunkedVariable: diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index fd79608f..b0a65f32 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,9 +1 @@ -from .collector import ( - AbstractPerformanceCollector, - NullPerformanceCollector, - PerformanceCollector, -) -from .config import PerformanceConfig -from .profiler import NullProfiler, Profiler -from .report import Experiment, Report, TimeReport from .timer import NullTimer, Timer diff --git a/ndsl/performance/collector.py b/ndsl/performance/collector.py index 4df04407..8ec7a817 100644 --- a/ndsl/performance/collector.py +++ b/ndsl/performance/collector.py @@ -11,6 +11,7 @@ from ndsl.performance.report import ( Report, TimeReport, + collect_data_and_write_to_file, collect_keys_from_data, gather_hit_counts, get_experiment_info, @@ -19,8 +20,6 @@ from ndsl.performance.timer import NullTimer, Timer from ndsl.utils import GPU_AVAILABLE -from .report import collect_data_and_write_to_file - class AbstractPerformanceCollector(Protocol): total_timer: Timer diff --git a/ndsl/performance/config.py b/ndsl/performance/config.py index fa1ce8ed..99e6109f 100644 --- a/ndsl/performance/config.py +++ b/ndsl/performance/config.py @@ -1,13 +1,12 @@ import dataclasses from ndsl.comm.comm_abc import Comm -from ndsl.performance.profiler import NullProfiler, Profiler - -from .collector import ( +from ndsl.performance.collector import ( AbstractPerformanceCollector, NullPerformanceCollector, PerformanceCollector, ) +from ndsl.performance.profiler import NullProfiler, Profiler @dataclasses.dataclass diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 641e032a..d3ec452c 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,5 +1 @@ -from .c2l_ord import CubedToLatLon -from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid - - __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 4e18c1ff..67f2b5a1 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -13,7 +13,7 @@ from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ -from ndsl.grid import GridData +from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index d66176c6..61a64831 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,16 +1,4 @@ from . 
import parallel_translate, translate -from .parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .savepoint import SavepointCase, Translate, dataset_to_dict +from .savepoint import dataset_to_dict from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) +from .translate import pad_field_in_j, read_serialized_data diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 65ac8c03..d000e1fa 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -17,8 +17,9 @@ from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.namelist import Namelist -from ndsl.stencils.testing import ParallelTranslate, TranslateGrid +from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict +from ndsl.stencils.testing.translate import TranslateGrid @pytest.fixture() diff --git a/ndsl/stencils/testing/parallel_translate.py b/ndsl/stencils/testing/parallel_translate.py index 7481a41e..e0669994 100644 --- a/ndsl/stencils/testing/parallel_translate.py +++ b/ndsl/stencils/testing/parallel_translate.py @@ -8,8 +8,10 @@ from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS from ndsl.dsl import gt4py_utils as utils from ndsl.quantity import Quantity - -from .translate import TranslateFortranData2Py, read_serialized_data +from ndsl.stencils.testing.translate import ( + TranslateFortranData2Py, + read_serialized_data, +) class ParallelTranslate: diff --git a/ndsl/stencils/testing/savepoint.py b/ndsl/stencils/testing/savepoint.py index 04d01e21..77d71917 100644 --- a/ndsl/stencils/testing/savepoint.py +++ b/ndsl/stencils/testing/savepoint.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from .grid import Grid # type: ignore +from ndsl.stencils.testing.grid import Grid # type: ignore def dataset_to_dict(ds: xr.Dataset) -> Dict[str, Union[np.ndarray, float, int]]: diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index 89d368ec..797b4701 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer import SnapshotCheckpointer +from ndsl.checkpointer.snapshots import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 8bf70b00..851a9665 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer import ( +from ndsl.checkpointer.thresholds import ( InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer, diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index 0c08d52b..b696aca3 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,8 +4,11 @@ import numpy as np import pytest -from ndsl.checkpointer import SavepointThresholds, Threshold, ValidationCheckpointer -from ndsl.checkpointer.validation import _clip_pace_array_to_target +from ndsl.checkpointer.thresholds import SavepointThresholds, 
Threshold +from ndsl.checkpointer.validation import ( + ValidationCheckpointer, + _clip_pace_array_to_target, +) from ndsl.optional_imports import xarray as xr diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py index 10487c18..d5af6cf1 100644 --- a/tests/dsl/__init__.py +++ b/tests/dsl/__init__.py @@ -1,2 +1,2 @@ -from .test_stencil_wrapper import MockFieldInfo -from .test_caches import OrchestratedProgam \ No newline at end of file +# from .test_caches import OrchestratedProgam +# from .test_stencil_wrapper import MockFieldInfo diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index a7218b05..f90a1828 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -2,15 +2,16 @@ from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.storage import empty, ones -from ndsl.comm.mpi import MPI -from ndsl.dsl.dace import orchestrate -from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration -from ndsl.dsl.stencil import ( +from ndsl import ( CompilationConfig, + DaceConfig, + DaCeOrchestration, GridIndexing, StencilConfig, StencilFactory, ) +from ndsl.comm.mpi import MPI +from ndsl.dsl.dace import orchestrate def _make_storage( diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 14b240a0..62049d91 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -3,9 +3,13 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.dsl.stencil import CompilationConfig, RunMode +from ndsl import ( + CompilationConfig, + CubedSphereCommunicator, + CubedSpherePartitioner, + RunMode, + TilePartitioner, +) def test_safety_checks(): diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index 0aca1764..c044cb16 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -1,12 +1,8 @@ import unittest.mock -from ndsl.comm.communicator import CubedSpherePartitioner, TilePartitioner -from ndsl.dsl.dace.dace_config import DaceConfig, _determine_compiling_ranks -from ndsl.dsl.dace.orchestration import ( - DaCeOrchestration, - orchestrate, - orchestrate_function, -) +from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner +from ndsl.dsl.dace.dace_config import _determine_compiling_ranks +from ndsl.dsl.dace.orchestration import orchestrate, orchestrate_function """ diff --git a/tests/dsl/test_skip_passes.py b/tests/dsl/test_skip_passes.py index c1f3a712..e0173b7b 100644 --- a/tests/dsl/test_skip_passes.py +++ b/tests/dsl/test_skip_passes.py @@ -7,14 +7,14 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline from gt4py.cartesian.gtscript import PARALLEL, computation, interval -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.stencil import ( +from ndsl import ( CompilationConfig, + DaceConfig, GridIndexing, StencilConfig, StencilFactory, ) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 18cb99af..180b7ba2 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,12 +1,7 @@ from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.storage import empty, ones -from ndsl.dsl.stencil import ( - CompilationConfig, - GridIndexing, - StencilConfig, - 
StencilFactory, -) +from ndsl import CompilationConfig, GridIndexing, StencilConfig, StencilFactory def _make_storage( diff --git a/tests/dsl/test_stencil_config.py b/tests/dsl/test_stencil_config.py index 45891df4..7e6b4da3 100644 --- a/tests/dsl/test_stencil_config.py +++ b/tests/dsl/test_stencil_config.py @@ -1,7 +1,6 @@ import pytest -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.stencil import CompilationConfig, StencilConfig +from ndsl import CompilationConfig, DaceConfig, StencilConfig @pytest.mark.parametrize("validate_args", [True, False]) diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 364e5a3b..756de952 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -2,17 +2,18 @@ import pytest from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region -from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.dsl.dace.dace_config import DaceConfig -from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import ( +from ndsl import ( CompareToNumpyStencil, + CompilationConfig, + DaceConfig, FrozenStencil, GridIndexing, + StencilConfig, StencilFactory, - get_stencils_with_varied_bounds, ) -from ndsl.dsl.stencil_config import CompilationConfig, StencilConfig +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.gt4py_utils import make_storage_from_shape +from ndsl.dsl.stencil import get_stencils_with_varied_bounds from ndsl.dsl.typing import FloatField diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index ba3da536..cfe56ded 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -6,12 +6,17 @@ import pytest from gt4py.cartesian.gtscript import PARALLEL, computation, interval -from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + FrozenStencil, + Quantity, + StencilConfig, +) from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import FrozenStencil, _convert_quantities_to_storage -from ndsl.dsl.stencil_config import CompilationConfig, StencilConfig +from ndsl.dsl.stencil import _convert_quantities_to_storage from ndsl.dsl.typing import Float, FloatField -from ndsl.quantity import Quantity def get_stencil_config( @@ -280,14 +285,14 @@ def test_backend_options( "backend": "numpy", "rebuild": True, "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, "cuda": { "backend": "cuda", "rebuild": True, "device_sync": False, "format_source": False, - "name": "test_stencil_wrapper.copy_stencil", + "name": "tests.dsl.test_stencil_wrapper.copy_stencil", }, } diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 7343bf8d..ab11b16e 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -2,9 +2,13 @@ import pytest +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + Quantity, + TilePartitioner, +) from ndsl.comm._boundary_utils import get_boundary_slice -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -22,9 +26,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity - -from .mpi_comm import MPI +from tests.mpi.mpi_comm import MPI 
@pytest.fixture diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 4d4d24a0..d099d76a 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -3,8 +3,7 @@ from ndsl.comm.communicator import recv_buffer from ndsl.testing import ConcurrencyError, DummyComm - -from .mpi_comm import MPI +from tests.mpi.mpi_comm import MPI worker_function_list = [] diff --git a/tests/quantity/test_boundary.py b/tests/quantity/test_boundary.py index 42db16ab..a4f8e812 100644 --- a/tests/quantity/test_boundary.py +++ b/tests/quantity/test_boundary.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from ndsl import Quantity from ndsl.comm._boundary_utils import _shift_boundary_slice, get_boundary_slice from ndsl.constants import ( EAST, @@ -12,7 +13,6 @@ Y_DIM, Z_DIM, ) -from ndsl.quantity import Quantity def boundary_data(quantity, boundary_type, n_points, interior=True): diff --git a/tests/quantity/test_deepcopy.py b/tests/quantity/test_deepcopy.py index c44ea394..a7b1564c 100644 --- a/tests/quantity/test_deepcopy.py +++ b/tests/quantity/test_deepcopy.py @@ -3,7 +3,7 @@ import numpy as np -from ndsl.quantity import Quantity +from ndsl import Quantity def test_deepcopy_copy_is_editable_by_view(): diff --git a/tests/quantity/test_quantity.py b/tests/quantity/test_quantity.py index 7d0d75f0..a6de628b 100644 --- a/tests/quantity/test_quantity.py +++ b/tests/quantity/test_quantity.py @@ -2,7 +2,7 @@ import pytest import ndsl.quantity as qty -from ndsl.quantity import Quantity +from ndsl import Quantity try: diff --git a/tests/quantity/test_storage.py b/tests/quantity/test_storage.py index 172d78d6..2cdb8d49 100644 --- a/tests/quantity/test_storage.py +++ b/tests/quantity/test_storage.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.quantity import Quantity +from ndsl import Quantity try: diff --git a/tests/quantity/test_transpose.py b/tests/quantity/test_transpose.py index be1569a4..5e527279 100644 --- a/tests/quantity/test_transpose.py +++ b/tests/quantity/test_transpose.py @@ -1,5 +1,6 @@ import pytest +from ndsl import Quantity from ndsl.constants import ( X_DIM, X_DIMS, @@ -10,7 +11,6 @@ Z_DIM, Z_DIMS, ) -from ndsl.quantity import Quantity @pytest.fixture diff --git a/tests/quantity/test_view.py b/tests/quantity/test_view.py index a1ba5e57..73245093 100644 --- a/tests/quantity/test_view.py +++ b/tests/quantity/test_view.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from ndsl import Quantity from ndsl.constants import X_DIM, Y_DIM -from ndsl.quantity import Quantity @pytest.fixture diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index 6481315a..b28eba16 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -4,13 +4,17 @@ import numpy as np -from ndsl.comm.caching_comm import CachingCommReader, CachingCommWriter -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.local_comm import LocalComm -from ndsl.comm.null_comm import NullComm -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CachingCommReader, + CachingCommWriter, + CubedSphereCommunicator, + CubedSpherePartitioner, + LocalComm, + NullComm, + Quantity, + TilePartitioner, +) from ndsl.constants import X_DIM, Y_DIM -from ndsl.quantity import Quantity def test_halo_update_integration(): diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 855422bd..7966142d 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ 
-3,8 +3,14 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import ( HORIZONTAL_DIMS, TILE_DIM, @@ -15,9 +21,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_decomposition.py b/tests/test_decomposition.py index de4d40c2..bf7363e3 100644 --- a/tests/test_decomposition.py +++ b/tests/test_decomposition.py @@ -4,6 +4,7 @@ import pytest +from ndsl import CubedSpherePartitioner, TilePartitioner from ndsl.comm.decomposition import ( block_waiting_for_compilation, build_cache_path, @@ -11,9 +12,7 @@ determine_rank_is_compiling, unblock_waiting_tiles, ) -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner - -from .mpi.mpi_comm import MPI +from tests.mpi.mpi_comm import MPI @pytest.mark.parametrize( diff --git a/tests/test_dimension_sizer.py b/tests/test_dimension_sizer.py index 3f2cdded..a401e698 100644 --- a/tests/test_dimension_sizer.py +++ b/tests/test_dimension_sizer.py @@ -2,6 +2,7 @@ import pytest +from ndsl import QuantityFactory, SubtileGridSizer from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -11,8 +12,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.initialization.allocator import QuantityFactory -from ndsl.initialization.sizer import SubtileGridSizer @pytest.fixture(params=[48, 96]) diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index a05b0fb3..28f1af7c 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -8,12 +8,15 @@ import numpy as np import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import X_DIM, Y_DIM, Z_DIM -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 10d7f996..7e1b9f09 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from ndsl import HaloDataTransformer, HaloExchangeSpec, Quantity, QuantityHaloSpec from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils from ndsl.constants import ( @@ -22,9 +23,7 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.halo.data_transformer import HaloDataTransformer, HaloExchangeSpec from ndsl.halo.rotate import rotate_scalar_data, rotate_vector_data -from ndsl.quantity import Quantity, QuantityHaloSpec @pytest.fixture diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index c17903cd..d0536b24 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -3,10 +3,20 @@ import pytest +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + HaloUpdater, + OutOfBoundsError, + Quantity, + QuantityHaloSpec, + TileCommunicator, + TilePartitioner, + Timer, +) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice -from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator -from 
ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -24,11 +34,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.exceptions import OutOfBoundsError -from ndsl.halo.updater import HaloUpdater -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity, QuantityHaloSpec -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index e33f0d66..8ec77cc1 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -1,7 +1,13 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -10,9 +16,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_legacy_restart.py b/tests/test_legacy_restart.py index 3728b94c..2034c04c 100644 --- a/tests/test_legacy_restart.py +++ b/tests/test_legacy_restart.py @@ -12,17 +12,20 @@ import pytest import ndsl.io as io -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, +) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM -from ndsl.quantity import Quantity from ndsl.restart._legacy_restart import ( _apply_dims, get_rank_suffix, map_keys, open_restart, ) -from ndsl.testing import DummyComm requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") diff --git a/tests/test_local_comm.py b/tests/test_local_comm.py index 0b8072af..c549ee2a 100644 --- a/tests/test_local_comm.py +++ b/tests/test_local_comm.py @@ -1,7 +1,7 @@ import numpy import pytest -from ndsl.comm.local_comm import LocalComm +from ndsl import LocalComm @pytest.fixture diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 6e20537f..7a21dd78 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -6,12 +6,15 @@ import numpy as np import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.monitor import NetCDFMonitor +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + NetCDFMonitor, + Quantity, + TilePartitioner, +) from ndsl.optional_imports import xarray as xr -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm requires_xarray = pytest.mark.skipif(xr is None, reason="xarray is not installed") diff --git a/tests/test_null_comm.py b/tests/test_null_comm.py index 0a384767..74065f67 100644 --- a/tests/test_null_comm.py +++ b/tests/test_null_comm.py @@ -1,6 +1,9 @@ -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.null_comm import NullComm -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + NullComm, + TilePartitioner, +) def test_can_create_cube_communicator(): diff --git a/tests/test_partitioner.py b/tests/test_partitioner.py index 
99f1fb6b..6bd15eda 100644 --- a/tests/test_partitioner.py +++ b/tests/test_partitioner.py @@ -1,9 +1,8 @@ import numpy as np import pytest +from ndsl import CubedSpherePartitioner, Quantity, TilePartitioner from ndsl.comm.partitioner import ( - CubedSpherePartitioner, - TilePartitioner, _subtile_extents_from_tile_metadata, get_tile_index, get_tile_number, @@ -18,7 +17,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity rank_list = [] diff --git a/tests/test_partitioner_boundaries.py b/tests/test_partitioner_boundaries.py index 71574ee6..eff528fd 100644 --- a/tests/test_partitioner_boundaries.py +++ b/tests/test_partitioner_boundaries.py @@ -1,10 +1,7 @@ import pytest -from ndsl.comm.partitioner import ( - CubedSpherePartitioner, - TilePartitioner, - rotate_subtile_rank, -) +from ndsl import CubedSpherePartitioner, TilePartitioner +from ndsl.comm.partitioner import rotate_subtile_rank from ndsl.constants import ( BOUNDARY_TYPES, CORNER_BOUNDARY_TYPES, diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 37711009..3e0930a0 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -1,11 +1,14 @@ import pytest -from ndsl.comm.communicator import CubedSphereCommunicator -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + Timer, +) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm @pytest.fixture diff --git a/tests/test_tile_scatter.py b/tests/test_tile_scatter.py index 26aa3d0d..d768bb15 100644 --- a/tests/test_tile_scatter.py +++ b/tests/test_tile_scatter.py @@ -1,10 +1,7 @@ import pytest -from ndsl.comm.communicator import TileCommunicator -from ndsl.comm.partitioner import TilePartitioner +from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm def rank_scatter_results(communicator_list, quantity): diff --git a/tests/test_tile_scatter_gather.py b/tests/test_tile_scatter_gather.py index 2669f5fc..6d56dd6f 100644 --- a/tests/test_tile_scatter_gather.py +++ b/tests/test_tile_scatter_gather.py @@ -3,8 +3,7 @@ import pytest -from ndsl.comm.communicator import TileCommunicator -from ndsl.comm.partitioner import TilePartitioner +from ndsl import DummyComm, Quantity, TileCommunicator, TilePartitioner from ndsl.constants import ( HORIZONTAL_DIMS, X_DIM, @@ -14,8 +13,6 @@ Z_DIM, Z_INTERFACE_DIM, ) -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm try: diff --git a/tests/test_timer.py b/tests/test_timer.py index 0b0cd4bf..213a487a 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -2,7 +2,7 @@ import pytest -from ndsl.performance.timer import NullTimer, Timer +from ndsl import NullTimer, Timer @pytest.fixture diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index fbe90408..b608ec08 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -12,7 +12,13 @@ import cftime import pytest -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl import ( + CubedSpherePartitioner, + DummyComm, + Quantity, + TilePartitioner, + ZarrMonitor, +) from ndsl.constants import ( X_DIM, X_DIMS, @@ -22,11 
+28,8 @@ Y_INTERFACE_DIM, Z_DIM, ) -from ndsl.monitor import ZarrMonitor from ndsl.monitor.zarr_monitor import array_chunks, get_calendar from ndsl.optional_imports import xarray as xr -from ndsl.quantity import Quantity -from ndsl.testing import DummyComm requires_zarr = pytest.mark.skipif(zarr is None, reason="zarr is not installed") From 7d333b55d0dba54f2b5295f58c8b1608646c3d0d Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 27 Feb 2024 12:06:01 -0500 Subject: [PATCH 03/12] Changes to missed exposed modules and clean-up of comments --- ndsl/dsl/caches/__init__.py | 1 - ndsl/dsl/dace/__init__.py | 11 ++--------- tests/dsl/__init__.py | 2 -- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/ndsl/dsl/caches/__init__.py b/ndsl/dsl/caches/__init__.py index 4fbb20e9..e69de29b 100644 --- a/ndsl/dsl/caches/__init__.py +++ b/ndsl/dsl/caches/__init__.py @@ -1 +0,0 @@ -from .codepath import FV3CodePath diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index 0f1edcbf..aa19a3d9 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,9 +1,2 @@ -from .dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG -from .orchestration import ( - _LazyComputepathFunction, - _LazyComputepathMethod, - orchestrate, - orchestrate_function, -) -from .utils import ArrayReport, DaCeProgress, MaxBandwithBenchmarkProgram, StorageReport -from .wrapped_halo_exchange import WrappedHaloUpdater +from .dace_config import DaceConfig, DaCeOrchestration +from .orchestration import orchestrate, orchestrate_function diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py index d5af6cf1..e69de29b 100644 --- a/tests/dsl/__init__.py +++ b/tests/dsl/__init__.py @@ -1,2 +0,0 @@ -# from .test_caches import OrchestratedProgam -# from .test_stencil_wrapper import MockFieldInfo From fba0d95674fc9b1849bd91af5ea0e249bd3f8847 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 27 Feb 2024 15:47:23 -0500 Subject: [PATCH 04/12] Adding exposure for modules needed by external modules pyFV3 and pySHiELD --- ndsl/__init__.py | 22 +++++++++++++++++++--- ndsl/checkpointer/__init__.py | 2 ++ ndsl/comm/__init__.py | 4 +++- ndsl/dsl/__init__.py | 8 +++++++- ndsl/dsl/dace/__init__.py | 1 + ndsl/initialization/__init__.py | 2 +- ndsl/performance/__init__.py | 1 + ndsl/stencils/__init__.py | 13 +++++++++++++ ndsl/stencils/testing/__init__.py | 15 ++++++++++++++- ndsl/stencils/testing/grid.py | 4 ++-- tests/checkpointer/test_snapshot.py | 2 +- 11 files changed, 64 insertions(+), 10 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 06a77ac2..221cf1bb 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,11 +1,14 @@ -from .checkpointer import SnapshotCheckpointer +from .checkpointer import Checkpointer, NullCheckpointer, SnapshotCheckpointer from .comm import ( CachingCommReader, CachingCommWriter, + Comm, + Communicator, ConcurrencyError, CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, + MPIComm, NullComm, TileCommunicator, TilePartitioner, @@ -20,12 +23,25 @@ RunMode, StencilConfig, StencilFactory, + WrappedHaloUpdater, ) from .exceptions import OutOfBoundsError from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater -from .initialization import QuantityFactory, SubtileGridSizer +from .initialization import GridSizer, QuantityFactory, SubtileGridSizer from .logging import ndsl_log from .monitor import NetCDFMonitor, ZarrMonitor -from .performance import NullTimer, Timer +from .performance import NullTimer, PerformanceCollector, 
Timer from .quantity import Quantity, QuantityHaloSpec +from .stencils import ( + CubedToLatLon, + Grid, + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, + TranslateFortranData2Py, + TranslateGrid, +) from .testing import DummyComm +from .utils import MetaEnumStr diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index 46d32a6c..d24936cb 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1 +1,3 @@ +from .base import Checkpointer +from .null import NullCheckpointer from .snapshots import SnapshotCheckpointer diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 0e86fe02..31319c78 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,5 +1,7 @@ from .caching_comm import CachingCommReader, CachingCommWriter -from .communicator import CubedSphereCommunicator, TileCommunicator +from .comm_abc import Comm +from .communicator import Communicator, CubedSphereCommunicator, TileCommunicator from .local_comm import ConcurrencyError, LocalComm +from .mpi import MPIComm from .null_comm import NullComm from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 1331294e..269ae957 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -3,7 +3,13 @@ from ndsl.comm.mpi import MPI from . import dace -from .dace import DaceConfig, DaCeOrchestration, orchestrate, orchestrate_function +from .dace import ( + DaceConfig, + DaCeOrchestration, + WrappedHaloUpdater, + orchestrate, + orchestrate_function, +) from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory from .stencil_config import CompilationConfig, RunMode, StencilConfig diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index aa19a3d9..c1386ad9 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,2 +1,3 @@ from .dace_config import DaceConfig, DaCeOrchestration from .orchestration import orchestrate, orchestrate_function +from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index 50fd2f84..fe15db8b 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1,2 +1,2 @@ from .allocator import QuantityFactory -from .sizer import SubtileGridSizer +from .sizer import GridSizer, SubtileGridSizer diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index b0a65f32..51e9bf84 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1 +1,2 @@ +from .collector import PerformanceCollector from .timer import NullTimer, Timer diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index d3ec452c..6083272b 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,14 @@ +from .c2l_ord import CubedToLatLon +from .testing import ( + Grid, + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, + TranslateFortranData2Py, + TranslateGrid, +) + + __version__ = "0.2.0" diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index 61a64831..3ad9ef9c 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,4 +1,17 @@ from . 
import parallel_translate, translate +from .grid import Grid # type: ignore +from .parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) from .savepoint import dataset_to_dict from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import pad_field_in_j, read_serialized_data +from .translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index b6a55133..273a0f3d 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -8,13 +8,13 @@ from ndsl.dsl import gt4py_utils as utils from ndsl.dsl.stencil import GridIndexing from ndsl.dsl.typing import Float -from ndsl.grid import ( +from ndsl.grid.generation import GridDefinitions +from ndsl.grid.helper import ( AngleGridData, ContravariantGridData, DampingCoefficients, DriverGridData, GridData, - GridDefinitions, HorizontalGridData, MetricTerms, VerticalGridData, diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index 797b4701..a8dd5387 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer.snapshots import SnapshotCheckpointer +from ndsl import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr From 8400c83b55fe044fd7285f3b80407c584c127bd7 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 28 Feb 2024 16:00:39 -0500 Subject: [PATCH 05/12] Updated import method to mimic other package styles --- ndsl/__init__.py | 121 ++++++++++++++++++------ ndsl/checkpointer/__init__.py | 3 - ndsl/comm/__init__.py | 7 -- ndsl/dsl/__init__.py | 11 --- ndsl/dsl/dace/__init__.py | 3 - ndsl/dsl/stencil.py | 6 +- ndsl/grid/__init__.py | 13 --- ndsl/grid/helper.py | 2 +- ndsl/halo/__init__.py | 2 - ndsl/initialization/__init__.py | 2 - ndsl/monitor/__init__.py | 2 - ndsl/performance/__init__.py | 2 - ndsl/stencils/__init__.py | 13 --- ndsl/stencils/testing/__init__.py | 17 ---- ndsl/stencils/testing/test_translate.py | 5 +- ndsl/testing/__init__.py | 3 - ndsl/testing/dummy_comm.py | 1 - setup.py | 1 + tests/checkpointer/test_thresholds.py | 6 +- tests/checkpointer/test_validation.py | 7 +- tests/dsl/test_caches.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/test_halo_data_transformer.py | 9 +- 23 files changed, 113 insertions(+), 127 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 221cf1bb..6967d8dc 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,47 +1,110 @@ -from .checkpointer import Checkpointer, NullCheckpointer, SnapshotCheckpointer -from .comm import ( +from .buffer import Buffer +from .checkpointer.base import Checkpointer +from .checkpointer.null import NullCheckpointer +from .checkpointer.snapshots import SnapshotCheckpointer, _Snapshots +from .checkpointer.thresholds import ( + InsufficientTrialsError, + SavepointThresholds, + Threshold, + ThresholdCalibrationCheckpointer, +) +from .checkpointer.validation import ValidationCheckpointer +from .comm.boundary import Boundary, SimpleBoundary +from .comm.caching_comm import ( + CachingCommData, CachingCommReader, CachingCommWriter, - Comm, - Communicator, - ConcurrencyError, - CubedSphereCommunicator, - CubedSpherePartitioner, - LocalComm, - MPIComm, - NullComm, - TileCommunicator, - TilePartitioner, -) -from .dsl import ( + 
CachingRequestReader, + CachingRequestWriter, + NullRequest, +) +from .comm.comm_abc import Comm, Request +from .comm.communicator import Communicator, CubedSphereCommunicator, TileCommunicator +from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm +from .comm.mpi import MPIComm +from .comm.null_comm import NullAsyncResult, NullComm +from .comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from .constants import ConstantVersions +from .dsl.caches.codepath import FV3CodePath +from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG +from .dsl.dace.orchestration import orchestrate, orchestrate_function +from .dsl.dace.utils import ( + ArrayReport, + DaCeProgress, + MaxBandwithBenchmarkProgram, + StorageReport, +) +from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater +from .dsl.stencil import ( CompareToNumpyStencil, - CompilationConfig, - DaceConfig, - DaCeOrchestration, FrozenStencil, GridIndexing, - RunMode, - StencilConfig, StencilFactory, - WrappedHaloUpdater, + TimingCollector, ) +from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .halo import HaloDataTransformer, HaloExchangeSpec, HaloUpdater -from .initialization import GridSizer, QuantityFactory, SubtileGridSizer +from .grid.eta import HybridPressureCoefficients +from .grid.generation import GridDefinition, GridDefinitions, MetricTerms +from .grid.helper import ( + AngleGridData, + ContravariantGridData, + DampingCoefficients, + DriverGridData, + GridData, + HorizontalGridData, + VerticalGridData, +) +from .halo.data_transformer import ( + HaloDataTransformer, + HaloDataTransformerCPU, + HaloDataTransformerGPU, + HaloExchangeSpec, +) +from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from .initialization.allocator import QuantityFactory, StorageNumpy +from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log -from .monitor import NetCDFMonitor, ZarrMonitor -from .performance import NullTimer, PerformanceCollector, Timer -from .quantity import Quantity, QuantityHaloSpec -from .stencils import ( - CubedToLatLon, - Grid, +from .monitor.netcdf_monitor import NetCDFMonitor +from .monitor.protocol import Protocol +from .monitor.zarr_monitor import ZarrMonitor +from .namelist import Namelist +from .optional_imports import RaiseWhenAccessed +from .performance.collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) +from .performance.config import PerformanceConfig +from .performance.profiler import NullProfiler, Profiler +from .performance.report import Experiment, Report, TimeReport +from .performance.timer import NullTimer, Timer +from .quantity import ( + BoundaryArrayView, + BoundedArrayView, + Quantity, + QuantityHaloSpec, + QuantityMetadata, +) +from .stencils.c2l_ord import CubedToLatLon +from .stencils.corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .stencils.testing.grid import Grid # type: ignore +from .stencils.testing.parallel_translate import ( ParallelTranslate, ParallelTranslate2Py, ParallelTranslate2PyState, ParallelTranslateBaseSlicing, ParallelTranslateGrid, +) +from .stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict +from .stencils.testing.temporaries import assert_same_temporaries, copy_temporaries +from .stencils.testing.translate import ( TranslateFortranData2Py, TranslateGrid, + pad_field_in_j, + 
read_serialized_data, ) -from .testing import DummyComm +from .testing.dummy_comm import DummyComm +from .types import Allocator, AsyncRequest, NumpyModule +from .units import UnitsError from .utils import MetaEnumStr diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index d24936cb..e69de29b 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,3 +0,0 @@ -from .base import Checkpointer -from .null import NullCheckpointer -from .snapshots import SnapshotCheckpointer diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 31319c78..e69de29b 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -1,7 +0,0 @@ -from .caching_comm import CachingCommReader, CachingCommWriter -from .comm_abc import Comm -from .communicator import Communicator, CubedSphereCommunicator, TileCommunicator -from .local_comm import ConcurrencyError, LocalComm -from .mpi import MPIComm -from .null_comm import NullComm -from .partitioner import CubedSpherePartitioner, TilePartitioner diff --git a/ndsl/dsl/__init__.py b/ndsl/dsl/__init__.py index 269ae957..ed44420a 100644 --- a/ndsl/dsl/__init__.py +++ b/ndsl/dsl/__init__.py @@ -2,17 +2,6 @@ from ndsl.comm.mpi import MPI -from . import dace -from .dace import ( - DaceConfig, - DaCeOrchestration, - WrappedHaloUpdater, - orchestrate, - orchestrate_function, -) -from .stencil import CompareToNumpyStencil, FrozenStencil, GridIndexing, StencilFactory -from .stencil_config import CompilationConfig, RunMode, StencilConfig - if MPI is not None: import os diff --git a/ndsl/dsl/dace/__init__.py b/ndsl/dsl/dace/__init__.py index c1386ad9..e69de29b 100644 --- a/ndsl/dsl/dace/__init__.py +++ b/ndsl/dsl/dace/__init__.py @@ -1,3 +0,0 @@ -from .dace_config import DaceConfig, DaCeOrchestration -from .orchestration import orchestrate, orchestrate_function -from .wrapped_halo_exchange import WrappedHaloUpdater diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 77efd672..b8316727 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -22,7 +22,6 @@ from gt4py.cartesian import gtscript from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline -from ndsl import testing from ndsl.comm.comm_abc import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles @@ -34,6 +33,9 @@ from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity +# from ndsl import testing +from ndsl.testing import comparison + try: import cupy as cp @@ -68,7 +70,7 @@ def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str: - metric_err = testing.compare_arr(arg, numpy_arg) + metric_err = comparison.compare_arr(arg, numpy_arg) nans_match = np.logical_and(np.isnan(arg), np.isnan(numpy_arg)) n_points = np.product(arg.shape) failures_14 = n_points - np.sum( diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index a7692a8f..e69de29b 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,13 +0,0 @@ -# flake8: noqa: F401 - -from .eta import set_hybrid_pressure_coefficients -from .gnomonic import ( - great_circle_distance_along_axis, - great_circle_distance_lon_lat, - lon_lat_corner_to_cell_center, - lon_lat_midpoint, - lon_lat_to_xyz, - xyz_midpoint, - xyz_to_lon_lat, -) -from .stretch_transformation import direct_transform diff --git a/ndsl/grid/helper.py b/ndsl/grid/helper.py 
index ee97a6b0..fd62d771 100644 --- a/ndsl/grid/helper.py +++ b/ndsl/grid/helper.py @@ -14,7 +14,7 @@ from ndsl.constants import Z_DIM, Z_INTERFACE_DIM from ndsl.filesystem import get_fs from ndsl.grid.generation import MetricTerms -from ndsl.initialization import QuantityFactory +from ndsl.initialization.allocator import QuantityFactory from ndsl.quantity import Quantity diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index 823bd226..e69de29b 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -1,2 +0,0 @@ -from .data_transformer import HaloDataTransformer, HaloExchangeSpec -from .updater import HaloUpdater diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index fe15db8b..e69de29b 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -1,2 +0,0 @@ -from .allocator import QuantityFactory -from .sizer import GridSizer, SubtileGridSizer diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index 26b38cc6..e69de29b 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,2 +0,0 @@ -from .netcdf_monitor import NetCDFMonitor -from .zarr_monitor import ZarrMonitor diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index 51e9bf84..e69de29b 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -1,2 +0,0 @@ -from .collector import PerformanceCollector -from .timer import NullTimer, Timer diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 6083272b..d3ec452c 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,14 +1 @@ -from .c2l_ord import CubedToLatLon -from .testing import ( - Grid, - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, - TranslateFortranData2Py, - TranslateGrid, -) - - __version__ = "0.2.0" diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index 3ad9ef9c..e69de29b 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -1,17 +0,0 @@ -from . 
import parallel_translate, translate -from .grid import Grid # type: ignore -from .parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .savepoint import dataset_to_dict -from .temporaries import assert_same_temporaries, copy_temporaries -from .translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index 2e42e273..29c4ed65 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -14,8 +14,9 @@ from ndsl.dsl.stencil import CompilationConfig, StencilConfig from ndsl.quantity import Quantity from ndsl.restart._legacy_restart import RESTART_PROPERTIES -from ndsl.stencils.testing import SavepointCase, dataset_to_dict -from ndsl.testing import compare_scalar, perturb, success, success_array +from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict +from ndsl.testing.comparison import compare_scalar, success, success_array +from ndsl.testing.perturbation import perturb # this only matters for manually-added print statements diff --git a/ndsl/testing/__init__.py b/ndsl/testing/__init__.py index a1c927e9..e69de29b 100644 --- a/ndsl/testing/__init__.py +++ b/ndsl/testing/__init__.py @@ -1,3 +0,0 @@ -from .comparison import compare_arr, compare_scalar, success, success_array -from .dummy_comm import ConcurrencyError, DummyComm -from .perturbation import perturb diff --git a/ndsl/testing/dummy_comm.py b/ndsl/testing/dummy_comm.py index b4df2347..f3e93817 100644 --- a/ndsl/testing/dummy_comm.py +++ b/ndsl/testing/dummy_comm.py @@ -1,2 +1 @@ -from ndsl.comm.local_comm import ConcurrencyError # noqa from ndsl.comm.local_comm import LocalComm as DummyComm # noqa diff --git a/setup.py b/setup.py index 73ec210a..c0d71813 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ def local_pkg(name: str, relative_path: str) -> str: "mpi4py", "cftime", "xarray", + "f90nml>=1.1.0", "fsspec", "netcdf4", "scipy", # restart capacities only diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 851a9665..90d1f8fc 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,11 +1,7 @@ import numpy as np import pytest -from ndsl.checkpointer.thresholds import ( - InsufficientTrialsError, - Threshold, - ThresholdCalibrationCheckpointer, -) +from ndsl import InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer def test_thresholds_no_trials(): diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index b696aca3..091bb7c6 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,11 +4,8 @@ import numpy as np import pytest -from ndsl.checkpointer.thresholds import SavepointThresholds, Threshold -from ndsl.checkpointer.validation import ( - ValidationCheckpointer, - _clip_pace_array_to_target, -) +from ndsl import SavepointThresholds, Threshold, ValidationCheckpointer +from ndsl.checkpointer.validation import _clip_pace_array_to_target from ndsl.optional_imports import xarray as xr diff --git a/tests/dsl/test_caches.py b/tests/dsl/test_caches.py index f90a1828..893fb89d 100644 --- a/tests/dsl/test_caches.py +++ b/tests/dsl/test_caches.py @@ -11,7 +11,7 @@ StencilFactory, ) from ndsl.comm.mpi import MPI -from ndsl.dsl.dace import 
orchestrate +from ndsl.dsl.dace.orchestration import orchestrate def _make_storage( diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index d099d76a..def0d342 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,8 +1,8 @@ import numpy as np import pytest +from ndsl import ConcurrencyError, DummyComm from ndsl.comm.communicator import recv_buffer -from ndsl.testing import ConcurrencyError, DummyComm from tests.mpi.mpi_comm import MPI diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index 7e1b9f09..ec986f8c 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,8 +4,13 @@ import numpy as np import pytest -from ndsl import HaloDataTransformer, HaloExchangeSpec, Quantity, QuantityHaloSpec -from ndsl.buffer import Buffer +from ndsl import ( + Buffer, + HaloDataTransformer, + HaloExchangeSpec, + Quantity, + QuantityHaloSpec, +) from ndsl.comm import _boundary_utils from ndsl.constants import ( EAST, From 1c8b4b9c09fcfa98a83311358ddd0ea7e46f5552 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 5 Mar 2024 16:28:43 -0500 Subject: [PATCH 06/12] Exposure changes and creation of ndsl.typing module --- ndsl/__init__.py | 52 +- ndsl/checkpointer/__init__.py | 9 + ndsl/checkpointer/base.py | 7 - ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 2 +- ndsl/checkpointer/thresholds.py | 2 +- ndsl/checkpointer/validation.py | 2 +- ndsl/comm/__init__.py | 9 + ndsl/comm/communicator.py | 568 +--------------------- ndsl/comm/partitioner.py | 79 +-- ndsl/dsl/caches/cache_location.py | 2 +- ndsl/dsl/dace/dace_config.py | 2 +- ndsl/dsl/dace/wrapped_halo_exchange.py | 2 +- ndsl/dsl/stencil.py | 2 +- ndsl/dsl/stencil_config.py | 3 +- ndsl/grid/__init__.py | 11 + ndsl/grid/generation.py | 2 +- ndsl/halo/updater.py | 2 +- ndsl/monitor/netcdf_monitor.py | 2 +- ndsl/monitor/zarr_monitor.py | 3 +- ndsl/restart/_legacy_restart.py | 2 +- ndsl/stencils/__init__.py | 20 + ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/conftest.py | 7 +- ndsl/typing.py | 648 +++++++++++++++++++++++++ tests/checkpointer/test_snapshot.py | 2 +- tests/checkpointer/test_thresholds.py | 6 +- tests/checkpointer/test_validation.py | 2 +- tests/mpi/test_mpi_mock.py | 2 +- tests/test_caching_comm.py | 3 +- 30 files changed, 733 insertions(+), 724 deletions(-) delete mode 100644 ndsl/checkpointer/base.py create mode 100644 ndsl/typing.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 6967d8dc..a5073021 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,29 +1,10 @@ from .buffer import Buffer -from .checkpointer.base import Checkpointer -from .checkpointer.null import NullCheckpointer -from .checkpointer.snapshots import SnapshotCheckpointer, _Snapshots -from .checkpointer.thresholds import ( - InsufficientTrialsError, - SavepointThresholds, - Threshold, - ThresholdCalibrationCheckpointer, -) -from .checkpointer.validation import ValidationCheckpointer from .comm.boundary import Boundary, SimpleBoundary -from .comm.caching_comm import ( - CachingCommData, - CachingCommReader, - CachingCommWriter, - CachingRequestReader, - CachingRequestWriter, - NullRequest, -) -from .comm.comm_abc import Comm, Request -from .comm.communicator import Communicator, CubedSphereCommunicator, TileCommunicator +from .comm.communicator import CubedSphereCommunicator, TileCommunicator from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm from .comm.mpi import MPIComm from .comm.null_comm import 
NullAsyncResult, NullComm -from .comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from .comm.partitioner import CubedSpherePartitioner, TilePartitioner from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath from .dsl.dace.dace_config import DaceConfig, DaCeOrchestration, FrozenCompiledSDFG @@ -44,17 +25,6 @@ ) from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .grid.eta import HybridPressureCoefficients -from .grid.generation import GridDefinition, GridDefinitions, MetricTerms -from .grid.helper import ( - AngleGridData, - ContravariantGridData, - DampingCoefficients, - DriverGridData, - GridData, - HorizontalGridData, - VerticalGridData, -) from .halo.data_transformer import ( HaloDataTransformer, HaloDataTransformerCPU, @@ -86,24 +56,6 @@ QuantityHaloSpec, QuantityMetadata, ) -from .stencils.c2l_ord import CubedToLatLon -from .stencils.corners import CopyCorners, CopyCornersXY, FillCornersBGrid -from .stencils.testing.grid import Grid # type: ignore -from .stencils.testing.parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .stencils.testing.savepoint import SavepointCase, Translate, dataset_to_dict -from .stencils.testing.temporaries import assert_same_temporaries, copy_temporaries -from .stencils.testing.translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) from .testing.dummy_comm import DummyComm from .types import Allocator, AsyncRequest, NumpyModule from .units import UnitsError diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index e69de29b..6486d96c 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -0,0 +1,9 @@ +from .null import NullCheckpointer +from .snapshots import SnapshotCheckpointer, _Snapshots +from .thresholds import ( + InsufficientTrialsError, + SavepointThresholds, + Threshold, + ThresholdCalibrationCheckpointer, +) +from .validation import ValidationCheckpointer diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py deleted file mode 100644 index 8218bbfe..00000000 --- a/ndsl/checkpointer/base.py +++ /dev/null @@ -1,7 +0,0 @@ -import abc - - -class Checkpointer(abc.ABC): - @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... 
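(Illustrative aside, not part of the patch series: with [PATCH 06/12] the abstract Checkpointer, Communicator and Partitioner interfaces move into the new ndsl.typing module, so downstream packages such as pyFV3 and pySHiELD are expected to type-annotate against ndsl.typing instead of importing ndsl.checkpointer.base or ndsl.comm.communicator for the base classes. A minimal sketch of that intended usage follows; the function name advance_state, the state mapping and the "delp" key are hypothetical, while the Checkpointer call signature and Communicator.halo_update(quantity, n_points) come from the code shown in this patch.)

    from ndsl.typing import Checkpointer, Communicator

    def advance_state(comm: Communicator, checkpointer: Checkpointer, state) -> None:
        # Record a savepoint; the abstract Checkpointer.__call__ takes a
        # savepoint name plus arbitrary keyword arguments (see base.py above).
        checkpointer("before_halo_update", state=state)
        # Update 3 halo points of a hypothetical Quantity held in `state`,
        # using the Communicator.halo_update(quantity, n_points) method
        # relocated into ndsl.typing by this patch.
        comm.halo_update(state["delp"], n_points=3)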
diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index fbc78755..448b3a6e 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from ndsl.checkpointer.base import Checkpointer +from ndsl.typing import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index aa806b21..573701ae 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,9 +2,9 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr +from ndsl.typing import Checkpointer def make_dims(savepoint_dim, label, data_list): diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index ded73b39..2f1af55c 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.quantity import Quantity +from ndsl.typing import Checkpointer try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 8af11317..12146a53 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,7 +5,6 @@ import numpy as np -from ndsl.checkpointer.base import Checkpointer from ndsl.checkpointer.thresholds import ( ArrayLike, SavepointName, @@ -13,6 +12,7 @@ cast_to_ndarray, ) from ndsl.optional_imports import xarray as xr +from ndsl.typing import Checkpointer def _clip_pace_array_to_target( diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index e69de29b..289e6413 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -0,0 +1,9 @@ +from .caching_comm import ( + CachingCommData, + CachingCommReader, + CachingCommWriter, + CachingRequestReader, + CachingRequestWriter, + NullRequest, +) +from .comm_abc import Comm, Request diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 9149c1ed..3f21ee21 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,17 +1,11 @@ -import abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast - -import numpy as np +from typing import List, Optional, Sequence, Tuple, Union, cast import ndsl.constants as constants -from ndsl.buffer import array_buffer, recv_buffer, send_buffer -from ndsl.comm.boundary import Boundary -from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from ndsl.performance.timer import NullTimer, Timer -from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata -from ndsl.types import NumpyModule -from ndsl.utils import device_synchronize +from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest +from ndsl.performance.timer import Timer +from ndsl.quantity import Quantity, QuantityMetadata +from ndsl.typing import Communicator try: @@ -20,29 +14,6 @@ cupy = None -def to_numpy(array, dtype=None) -> np.ndarray: - """ - Input array can be a numpy array or a cupy array. Returns numpy array. 
- """ - try: - output = np.asarray(array) - except ValueError as err: - if err.args[0] == "object __array__ method not producing an array": - output = cupy.asnumpy(array) - else: - raise err - except TypeError as err: - if err.args[0].startswith( - "Implicit conversion to a NumPy array is not allowed." - ): - output = cupy.asnumpy(array) - else: - raise err - if dtype: - output = output.astype(dtype=dtype) - return output - - def bcast_metadata_list(comm, quantity_list): is_root = comm.Get_rank() == constants.ROOT_RANK if is_root: @@ -58,533 +29,6 @@ def bcast_metadata(comm, array): return bcast_metadata_list(comm, [array])[0] -class Communicator(abc.ABC): - def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None - ): - self.comm = comm - self.partitioner: Partitioner = partitioner - self._force_cpu = force_cpu - self._boundaries: Optional[Mapping[int, Boundary]] = None - self._last_halo_tag = 0 - self.timer: Timer = timer if timer is not None else NullTimer() - - @abc.abstractproperty - def tile(self) -> "TileCommunicator": - pass - - @classmethod - @abc.abstractmethod - def from_layout( - cls, - comm, - layout: Tuple[int, int], - force_cpu: bool = False, - timer: Optional[Timer] = None, - ): - pass - - @property - def rank(self) -> int: - """rank of the current process within this communicator""" - return self.comm.Get_rank() - - @property - def size(self) -> int: - """Total number of ranks in this communicator""" - return self.comm.Get_size() - - def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: - """ - Get a numpy-like module depending on configuration and - Quantity original allocator. - """ - if self._force_cpu: - return np - return module - - @staticmethod - def _device_synchronize(): - """Wait for all work that could be in-flight to finish.""" - # this is a method so we can profile it separately from other device syncs - device_synchronize() - - def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Scatter(send, recv, **kwargs) - - def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Gather(send, recv, **kwargs) - - def scatter( - self, - send_quantity: Optional[Quantity] = None, - recv_quantity: Optional[Quantity] = None, - ) -> Quantity: - """Transfer subtile regions of a full-tile quantity - from the tile root rank to all subtiles. - - Args: - send_quantity: quantity to send, only required/used on the tile root rank - recv_quantity: if provided, assign received data into this Quantity. 
- Returns: - recv_quantity - """ - if self.rank == constants.ROOT_RANK and send_quantity is None: - raise TypeError("send_quantity is a required argument on the root rank") - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) - else: - metadata = self.comm.bcast(None, root=constants.ROOT_RANK) - shape = self.partitioner.subtile_extent(metadata, self.rank) - if recv_quantity is None: - recv_quantity = self._get_scatter_recv_quantity(shape, metadata) - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - with array_buffer( - self._maybe_force_cpu(metadata.np).zeros, - (self.partitioner.total_ranks,) + shape, - dtype=metadata.dtype, - ) as sendbuf: - for rank in range(0, self.partitioner.total_ranks): - subtile_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=metadata.dims, - global_extent=metadata.extent, - overlap=True, - ) - sendbuf.assign_from( - send_quantity.view[subtile_slice], - buffer_slice=np.index_exp[rank, :], - ) - self._Scatter( - metadata.np, - sendbuf.array, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - else: - self._Scatter( - metadata.np, - None, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - return recv_quantity - - def _get_gather_recv_quantity( - self, global_extent: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving global data during gather""" - recv_quantity = Quantity( - send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - origin=tuple([0 for dim in send_metadata.dims]), - extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def _get_scatter_recv_quantity( - self, shape: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving subtile data during scatter""" - recv_quantity = Quantity( - send_metadata.np.zeros(shape, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def gather( - self, send_quantity: Quantity, recv_quantity: Quantity = None - ) -> Optional[Quantity]: - """Transfer subtile regions of a full-tile quantity - from each rank to the tile root rank. 
- - Args: - send_quantity: quantity to send - recv_quantity: if provided, assign received data into this Quantity (only - used on the tile root rank) - Returns: - recv_quantity: quantity if on root rank, otherwise None - """ - result: Optional[Quantity] - if self.rank == constants.ROOT_RANK: - with array_buffer( - send_quantity.np.zeros, - (self.partitioner.total_ranks,) + tuple(send_quantity.extent), - dtype=send_quantity.data.dtype, - ) as recvbuf: - self._Gather( - send_quantity.np, - send_quantity.view[:], - recvbuf.array, - root=constants.ROOT_RANK, - ) - if recv_quantity is None: - global_extent = self.partitioner.global_extent( - send_quantity.metadata - ) - recv_quantity = self._get_gather_recv_quantity( - global_extent, send_quantity.metadata - ) - for rank in range(self.partitioner.total_ranks): - to_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=recv_quantity.dims, - global_extent=recv_quantity.extent, - overlap=True, - ) - recvbuf.assign_to( - recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] - ) - result = recv_quantity - else: - self._Gather( - send_quantity.np, - send_quantity.view[:], - None, - root=constants.ROOT_RANK, - ) - result = None - return result - - def gather_state(self, send_state=None, recv_state=None, transfer_type=None): - """Transfer a state dictionary from subtile ranks to the tile root rank. - - 'time' is assumed to be the same on all ranks, and its value will be set - to the value from the root rank. - - Args: - send_state: the model state to be sent containing the subtile data - recv_state: the pre-allocated state in which to recieve the full tile - state. Only variables which are scattered will be written to. - Returns: - recv_state: on the root rank, the state containing the entire tile - """ - if self.rank == constants.ROOT_RANK and recv_state is None: - recv_state = {} - for name, quantity in send_state.items(): - if name == "time": - if self.rank == constants.ROOT_RANK: - recv_state["time"] = send_state["time"] - else: - gather_value = to_numpy(quantity.view[:], dtype=transfer_type) - gather_quantity = Quantity( - data=gather_value, - dims=quantity.dims, - units=quantity.units, - allow_mismatch_float_precision=True, - ) - if recv_state is not None and name in recv_state: - tile_quantity = self.gather( - gather_quantity, recv_quantity=recv_state[name] - ) - else: - tile_quantity = self.gather(gather_quantity) - if self.rank == constants.ROOT_RANK: - recv_state[name] = tile_quantity - del gather_quantity - return recv_state - - def scatter_state(self, send_state=None, recv_state=None): - """Transfer a state dictionary from the tile root rank to all subtiles. - - Args: - send_state: the model state to be sent containing the entire tile, - required only from the root rank - recv_state: the pre-allocated state in which to recieve the scattered - state. Only variables which are scattered will be written to. 
- Returns: - rank_state: the state corresponding to this rank's subdomain - """ - - def scatter_root(): - if send_state is None: - raise TypeError("send_state is a required argument on the root rank") - name_list = list(send_state.keys()) - while "time" in name_list: - name_list.remove("time") - name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) - array_list = [send_state[name] for name in name_list] - for name, array in zip(name_list, array_list): - if name in recv_state: - self.scatter(send_quantity=array, recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter(send_quantity=array) - recv_state["time"] = self.comm.bcast( - send_state.get("time", None), root=constants.ROOT_RANK - ) - - def scatter_client(): - name_list = self.comm.bcast(None, root=constants.ROOT_RANK) - for name in name_list: - if name in recv_state: - self.scatter(recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter() - recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) - - if recv_state is None: - recv_state = {} - if self.rank == constants.ROOT_RANK: - scatter_root() - else: - scatter_client() - if recv_state["time"] is None: - recv_state.pop("time") - return recv_state - - def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): - """Perform a halo update on a quantity or quantities - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - halo_updater = self.start_halo_update(quantities, n_points) - halo_updater.wait() - - def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int - ) -> HaloUpdater: - """Start an asynchronous halo update on a quantity. - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - specifications = [] - for quantity in quantities: - specification = QuantityHaloSpec( - n_points=n_points, - shape=quantity.data.shape, - strides=quantity.data.strides, - itemsize=quantity.data.itemsize, - origin=quantity.origin, - extent=quantity.extent, - dims=quantity.dims, - numpy_module=self._maybe_force_cpu(quantity.np), - dtype=quantity.metadata.dtype, - ) - specifications.append(specification) - - halo_updater = self.get_scalar_halo_updater(specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) - return halo_updater - - def vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ): - """Perform a halo update of a horizontal vector quantity or quantities. - - Assumes the x and y dimension indices are the same between the two quantities. 
- - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - halo_updater = self.start_vector_halo_update( - x_quantities, y_quantities, n_points - ) - halo_updater.wait() - - def start_vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ) -> HaloUpdater: - """Start an asynchronous halo update of a horizontal vector quantity. - - Assumes the x and y dimension indices are the same between the two quantities. - - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - x_specifications = [] - y_specifications = [] - for x_quantity, y_quantity in zip(x_quantities, y_quantities): - x_specification = QuantityHaloSpec( - n_points=n_points, - shape=x_quantity.data.shape, - strides=x_quantity.data.strides, - itemsize=x_quantity.data.itemsize, - origin=x_quantity.metadata.origin, - extent=x_quantity.metadata.extent, - dims=x_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(x_quantity.np), - dtype=x_quantity.metadata.dtype, - ) - x_specifications.append(x_specification) - y_specification = QuantityHaloSpec( - n_points=n_points, - shape=y_quantity.data.shape, - strides=y_quantity.data.strides, - itemsize=y_quantity.data.itemsize, - origin=y_quantity.metadata.origin, - extent=y_quantity.metadata.extent, - dims=y_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(y_quantity.np), - dtype=y_quantity.metadata.dtype, - ) - y_specifications.append(y_specification) - - halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) - return halo_updater - - def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): - """ - Synchronize shared points at the edges of a vector interface variable. - - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - """ - req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) - req.wait() - - def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity - ) -> HaloUpdateRequest: - """ - Synchronize shared points at the edges of a vector interface variable. 
- - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - - Returns: - request: an asynchronous request object with a .wait() method - """ - halo_updater = VectorInterfaceHaloUpdater( - comm=self.comm, - boundaries=self.boundaries, - force_cpu=self._force_cpu, - timer=self.timer, - ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) - return req - - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): - if len(specifications) == 0: - raise RuntimeError("Cannot create updater with specifications list") - if specifications[0].n_points == 0: - raise ValueError("cannot perform a halo update on zero halo points") - return HaloUpdater.from_scalar_specifications( - self, - self._maybe_force_cpu(specifications[0].numpy_module), - specifications, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def get_vector_halo_updater( - self, - specifications_x: List[QuantityHaloSpec], - specifications_y: List[QuantityHaloSpec], - ): - if len(specifications_x) == 0 and len(specifications_y) == 0: - raise RuntimeError("Cannot create updater with empty specifications list") - if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: - raise ValueError("Cannot perform a halo update on zero halo points") - return HaloUpdater.from_vector_specifications( - self, - self._maybe_force_cpu(specifications_x[0].numpy_module), - specifications_x, - specifications_y, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def _get_halo_tag(self) -> int: - self._last_halo_tag += 1 - return self._last_halo_tag - - @property - def boundaries(self) -> Mapping[int, Boundary]: - """boundaries of this tile with neighboring tiles""" - if self._boundaries is None: - self._boundaries = {} - for boundary_type in constants.BOUNDARY_TYPES: - boundary = self.partitioner.boundary(boundary_type, self.rank) - if boundary is not None: - self._boundaries[boundary_type] = boundary - return self._boundaries - - class TileCommunicator(Communicator): """Performs communications within a single tile or region of a tile""" diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index 6b8750a1..e3b2e02b 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -1,4 +1,3 @@ -import abc import copy import functools from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast @@ -18,6 +17,7 @@ WEST, ) from ndsl.quantity import Quantity, QuantityMetadata +from ndsl.typing import Partitioner from ndsl.utils import list_by_dims @@ -54,83 +54,6 @@ def get_tile_number(tile_rank: int, total_ranks: int) -> int: return tile_rank // ranks_per_tile + 1 -class Partitioner(abc.ABC): - @abc.abstractmethod - def __init__(self): - self.tile = None - self.layout = None - - @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... 
- - @abc.abstractmethod - def tile_index(self, rank: int): - pass - - @abc.abstractmethod - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: - """Return the shape of a full tile representation for the given dimensions. - - Args: - metadata: quantity metadata - - Returns: - extent: shape of full tile representation - """ - pass - - @abc.abstractmethod - def subtile_slice( - self, - rank: int, - global_dims: Sequence[str], - global_extent: Sequence[int], - overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: - """Return the subtile slice of a given rank on an array. - - Global refers to the domain being partitioned. For example, for a partitioning - of a tile, the tile would be the "global" domain. - - Args: - rank: the rank of the process - global_dims: dimensions of the global quantity being partitioned - global_extent: extent of the global quantity being partitioned - overlap (optional): if True, for interface variables include the part - of the array shared by adjacent ranks in both ranks. If False, ensure - only one of those ranks (the greater rank) is assigned the overlapping - section. Default is False. - - Returns: - subtile_slice: the slice of the global compute domain corresponding - to the subtile compute domain - """ - pass - - @abc.abstractmethod - def subtile_extent( - self, - global_metadata: QuantityMetadata, - rank: int, - ) -> Tuple[int, ...]: - """Return the shape of a single rank representation for the given dimensions. - - Args: - global_metadata: quantity metadata. - rank: rank of the process. - - Returns: - extent: shape of a single rank representation for the given dimensions. - """ - pass - - @property - @abc.abstractmethod - def total_ranks(self) -> int: - pass - - class TilePartitioner(Partitioner): def __init__( self, diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index edf563b7..2d973f7a 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -1,5 +1,5 @@ -from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.codepath import FV3CodePath +from ndsl.typing import Partitioner def identify_code_path( diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 7671b464..f93d2baf 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -6,12 +6,12 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.frontend.python.parser import DaceProgram -from ndsl.comm.communicator import Communicator, Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import floating_point_precision from ndsl.optional_imports import cupy as cp +from ndsl.typing import Communicator, Partitioner # This can be turned on to revert compilation for orchestration diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index 78a68fa4..ca36f3a0 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ import dataclasses from typing import Any, List, Optional -from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor from ndsl.halo.updater import HaloUpdater +from ndsl.typing import Communicator class WrappedHaloUpdater: diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index b8316727..f57c139a 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ 
-23,7 +23,6 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.comm_abc import Comm -from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI from ndsl.constants import X_DIM, X_DIMS, Y_DIM, Y_DIMS, Z_DIM, Z_DIMS @@ -35,6 +34,7 @@ # from ndsl import testing from ndsl.testing import comparison +from ndsl.typing import Communicator try: diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index 6b8f75eb..e1e233b7 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -5,11 +5,10 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline -from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches -from ndsl.comm.partitioner import Partitioner from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.dsl.gt4py_utils import is_gpu_backend +from ndsl.typing import Communicator, Partitioner class RunMode(enum.Enum): diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index e69de29b..49eccf05 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -0,0 +1,11 @@ +from .eta import HybridPressureCoefficients +from .generation import GridDefinition, GridDefinitions, MetricTerms +from .helper import ( + AngleGridData, + ContravariantGridData, + DampingCoefficients, + DriverGridData, + GridData, + HorizontalGridData, + VerticalGridData, +) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 12275d7d..2d6450a9 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -5,7 +5,6 @@ import numpy as np -from ndsl.comm.communicator import Communicator from ndsl.constants import ( N_HALO_DEFAULT, PI, @@ -59,6 +58,7 @@ fill_corners_cgrid, fill_corners_dgrid, ) +from ndsl.typing import Communicator # TODO: when every environment in python3.8, remove diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 665d0b95..7684c564 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from ndsl.comm.communicator import Communicator + from ndsl.typing import Communicator _HaloSendTuple = Tuple[AsyncRequest, Buffer] _HaloRecvTuple = Tuple[AsyncRequest, Buffer, np.ndarray] diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 8a0b96fd..30731095 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -5,12 +5,12 @@ import fsspec import numpy as np -from ndsl.comm.communicator import Communicator from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity +from ndsl.typing import Communicator class _TimeChunkedVariable: diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 214171be..85e37222 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -4,12 +4,13 @@ import cftime import ndsl.constants as constants -from ndsl.comm.partitioner import Partitioner, subtile_slice +from ndsl.comm.partitioner import subtile_slice from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import cupy from ndsl.optional_imports import xarray as xr from ndsl.optional_imports import zarr +from ndsl.typing import Partitioner from ndsl.utils import 
list_by_dims diff --git a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index 01f9bdb8..afa4d523 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -5,11 +5,11 @@ import ndsl.constants as constants import ndsl.filesystem as filesystem import ndsl.io as io -from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity from ndsl.restart._properties import RESTART_PROPERTIES, RestartProperties +from ndsl.typing import Communicator __all__ = ["open_restart"] diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index d3ec452c..0fe16725 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1 +1,21 @@ +from .c2l_ord import CubedToLatLon +from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid +from .testing.grid import Grid # type: ignore +from .testing.parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) +from .testing.savepoint import SavepointCase, Translate, dataset_to_dict +from .testing.temporaries import assert_same_temporaries, copy_temporaries +from .testing.translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) + + __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 67f2b5a1..87d59f27 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -8,13 +8,13 @@ ) import ndsl.dsl.gt4py_utils as utils -from ndsl.comm.communicator import Communicator from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory +from ndsl.typing import Communicator A1 = 0.5625 diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index d000e1fa..b3a3a7e8 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -8,11 +8,7 @@ import yaml import ndsl.dsl -from ndsl.comm.communicator import ( - Communicator, - CubedSphereCommunicator, - TileCommunicator, -) +from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator from ndsl.comm.mpi import MPI from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig @@ -20,6 +16,7 @@ from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict from ndsl.stencils.testing.translate import TranslateGrid +from ndsl.typing import Communicator @pytest.fixture() diff --git a/ndsl/typing.py b/ndsl/typing.py new file mode 100644 index 00000000..2f815d0d --- /dev/null +++ b/ndsl/typing.py @@ -0,0 +1,648 @@ +import abc +from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast + +import numpy as np + +import ndsl.constants as constants +from ndsl.buffer import array_buffer, recv_buffer, send_buffer +from ndsl.comm.boundary import Boundary +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from ndsl.performance.timer import NullTimer, Timer +from ndsl.quantity import Quantity, QuantityHaloSpec, 
QuantityMetadata +from ndsl.types import NumpyModule +from ndsl.utils import device_synchronize +from ndsl.comm import boundary as bd + +try: + import cupy +except ImportError: + cupy = None + +def to_numpy(array, dtype=None) -> np.ndarray: + """ + Input array can be a numpy array or a cupy array. Returns numpy array. + """ + try: + output = np.asarray(array) + except ValueError as err: + if err.args[0] == "object __array__ method not producing an array": + output = cupy.asnumpy(array) + else: + raise err + except TypeError as err: + if err.args[0].startswith( + "Implicit conversion to a NumPy array is not allowed." + ): + output = cupy.asnumpy(array) + else: + raise err + if dtype: + output = output.astype(dtype=dtype) + return output + +class Checkpointer(abc.ABC): + @abc.abstractmethod + def __call__(self, savepoint_name, **kwargs): + ... + +class Communicator(abc.ABC): + def __init__( + self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + ): + self.comm = comm + self.partitioner: Partitioner = partitioner + self._force_cpu = force_cpu + self._boundaries: Optional[Mapping[int, Boundary]] = None + self._last_halo_tag = 0 + self.timer: Timer = timer if timer is not None else NullTimer() + + @abc.abstractproperty + def tile(self): + pass + + @classmethod + @abc.abstractmethod + def from_layout( + cls, + comm, + layout: Tuple[int, int], + force_cpu: bool = False, + timer: Optional[Timer] = None, + ): + pass + + @property + def rank(self) -> int: + """rank of the current process within this communicator""" + return self.comm.Get_rank() + + @property + def size(self) -> int: + """Total number of ranks in this communicator""" + return self.comm.Get_size() + + def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: + """ + Get a numpy-like module depending on configuration and + Quantity original allocator. + """ + if self._force_cpu: + return np + return module + + @staticmethod + def _device_synchronize(): + """Wait for all work that could be in-flight to finish.""" + # this is a method so we can profile it separately from other device syncs + device_synchronize() + + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Scatter(send, recv, **kwargs) + + def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Gather(send, recv, **kwargs) + + def scatter( + self, + send_quantity: Optional[Quantity] = None, + recv_quantity: Optional[Quantity] = None, + ) -> Quantity: + """Transfer subtile regions of a full-tile quantity + from the tile root rank to all subtiles. + + Args: + send_quantity: quantity to send, only required/used on the tile root rank + recv_quantity: if provided, assign received data into this Quantity. 
+ Returns: + recv_quantity + """ + if self.rank == constants.ROOT_RANK and send_quantity is None: + raise TypeError("send_quantity is a required argument on the root rank") + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) + else: + metadata = self.comm.bcast(None, root=constants.ROOT_RANK) + shape = self.partitioner.subtile_extent(metadata, self.rank) + if recv_quantity is None: + recv_quantity = self._get_scatter_recv_quantity(shape, metadata) + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + with array_buffer( + self._maybe_force_cpu(metadata.np).zeros, + (self.partitioner.total_ranks,) + shape, + dtype=metadata.dtype, + ) as sendbuf: + for rank in range(0, self.partitioner.total_ranks): + subtile_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=metadata.dims, + global_extent=metadata.extent, + overlap=True, + ) + sendbuf.assign_from( + send_quantity.view[subtile_slice], + buffer_slice=np.index_exp[rank, :], + ) + self._Scatter( + metadata.np, + sendbuf.array, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + else: + self._Scatter( + metadata.np, + None, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + return recv_quantity + + def _get_gather_recv_quantity( + self, global_extent: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving global data during gather""" + recv_quantity = Quantity( + send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + origin=tuple([0 for dim in send_metadata.dims]), + extent=global_extent, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def _get_scatter_recv_quantity( + self, shape: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving subtile data during scatter""" + recv_quantity = Quantity( + send_metadata.np.zeros(shape, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def gather( + self, send_quantity: Quantity, recv_quantity: Quantity = None + ) -> Optional[Quantity]: + """Transfer subtile regions of a full-tile quantity + from each rank to the tile root rank. 
+ + Args: + send_quantity: quantity to send + recv_quantity: if provided, assign received data into this Quantity (only + used on the tile root rank) + Returns: + recv_quantity: quantity if on root rank, otherwise None + """ + result: Optional[Quantity] + if self.rank == constants.ROOT_RANK: + with array_buffer( + send_quantity.np.zeros, + (self.partitioner.total_ranks,) + tuple(send_quantity.extent), + dtype=send_quantity.data.dtype, + ) as recvbuf: + self._Gather( + send_quantity.np, + send_quantity.view[:], + recvbuf.array, + root=constants.ROOT_RANK, + ) + if recv_quantity is None: + global_extent = self.partitioner.global_extent( + send_quantity.metadata + ) + recv_quantity = self._get_gather_recv_quantity( + global_extent, send_quantity.metadata + ) + for rank in range(self.partitioner.total_ranks): + to_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=recv_quantity.dims, + global_extent=recv_quantity.extent, + overlap=True, + ) + recvbuf.assign_to( + recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] + ) + result = recv_quantity + else: + self._Gather( + send_quantity.np, + send_quantity.view[:], + None, + root=constants.ROOT_RANK, + ) + result = None + return result + + def gather_state(self, send_state=None, recv_state=None, transfer_type=None): + """Transfer a state dictionary from subtile ranks to the tile root rank. + + 'time' is assumed to be the same on all ranks, and its value will be set + to the value from the root rank. + + Args: + send_state: the model state to be sent containing the subtile data + recv_state: the pre-allocated state in which to recieve the full tile + state. Only variables which are scattered will be written to. + Returns: + recv_state: on the root rank, the state containing the entire tile + """ + if self.rank == constants.ROOT_RANK and recv_state is None: + recv_state = {} + for name, quantity in send_state.items(): + if name == "time": + if self.rank == constants.ROOT_RANK: + recv_state["time"] = send_state["time"] + else: + gather_value = to_numpy(quantity.view[:], dtype=transfer_type) + gather_quantity = Quantity( + data=gather_value, + dims=quantity.dims, + units=quantity.units, + allow_mismatch_float_precision=True, + ) + if recv_state is not None and name in recv_state: + tile_quantity = self.gather( + gather_quantity, recv_quantity=recv_state[name] + ) + else: + tile_quantity = self.gather(gather_quantity) + if self.rank == constants.ROOT_RANK: + recv_state[name] = tile_quantity + del gather_quantity + return recv_state + + def scatter_state(self, send_state=None, recv_state=None): + """Transfer a state dictionary from the tile root rank to all subtiles. + + Args: + send_state: the model state to be sent containing the entire tile, + required only from the root rank + recv_state: the pre-allocated state in which to recieve the scattered + state. Only variables which are scattered will be written to. 
+ Returns: + rank_state: the state corresponding to this rank's subdomain + """ + + def scatter_root(): + if send_state is None: + raise TypeError("send_state is a required argument on the root rank") + name_list = list(send_state.keys()) + while "time" in name_list: + name_list.remove("time") + name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) + array_list = [send_state[name] for name in name_list] + for name, array in zip(name_list, array_list): + if name in recv_state: + self.scatter(send_quantity=array, recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter(send_quantity=array) + recv_state["time"] = self.comm.bcast( + send_state.get("time", None), root=constants.ROOT_RANK + ) + + def scatter_client(): + name_list = self.comm.bcast(None, root=constants.ROOT_RANK) + for name in name_list: + if name in recv_state: + self.scatter(recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter() + recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) + + if recv_state is None: + recv_state = {} + if self.rank == constants.ROOT_RANK: + scatter_root() + else: + scatter_client() + if recv_state["time"] is None: + recv_state.pop("time") + return recv_state + + def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): + """Perform a halo update on a quantity or quantities + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + halo_updater = self.start_halo_update(quantities, n_points) + halo_updater.wait() + + def start_halo_update( + self, quantity: Union[Quantity, List[Quantity]], n_points: int + ) -> HaloUpdater: + """Start an asynchronous halo update on a quantity. + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + specifications = [] + for quantity in quantities: + specification = QuantityHaloSpec( + n_points=n_points, + shape=quantity.data.shape, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=self._maybe_force_cpu(quantity.np), + dtype=quantity.metadata.dtype, + ) + specifications.append(specification) + + halo_updater = self.get_scalar_halo_updater(specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(quantities) + return halo_updater + + def vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ): + """Perform a halo update of a horizontal vector quantity or quantities. + + Assumes the x and y dimension indices are the same between the two quantities. 
+ + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + halo_updater = self.start_vector_halo_update( + x_quantities, y_quantities, n_points + ) + halo_updater.wait() + + def start_vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ) -> HaloUpdater: + """Start an asynchronous halo update of a horizontal vector quantity. + + Assumes the x and y dimension indices are the same between the two quantities. + + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + x_specifications = [] + y_specifications = [] + for x_quantity, y_quantity in zip(x_quantities, y_quantities): + x_specification = QuantityHaloSpec( + n_points=n_points, + shape=x_quantity.data.shape, + strides=x_quantity.data.strides, + itemsize=x_quantity.data.itemsize, + origin=x_quantity.metadata.origin, + extent=x_quantity.metadata.extent, + dims=x_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(x_quantity.np), + dtype=x_quantity.metadata.dtype, + ) + x_specifications.append(x_specification) + y_specification = QuantityHaloSpec( + n_points=n_points, + shape=y_quantity.data.shape, + strides=y_quantity.data.strides, + itemsize=y_quantity.data.itemsize, + origin=y_quantity.metadata.origin, + extent=y_quantity.metadata.extent, + dims=y_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(y_quantity.np), + dtype=y_quantity.metadata.dtype, + ) + y_specifications.append(y_specification) + + halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(x_quantities, y_quantities) + return halo_updater + + def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): + """ + Synchronize shared points at the edges of a vector interface variable. + + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + """ + req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req.wait() + + def start_synchronize_vector_interfaces( + self, x_quantity: Quantity, y_quantity: Quantity + ) -> HaloUpdateRequest: + """ + Synchronize shared points at the edges of a vector interface variable. 
+ + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + + Returns: + request: an asynchronous request object with a .wait() method + """ + halo_updater = VectorInterfaceHaloUpdater( + comm=self.comm, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + return req + + def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + if len(specifications) == 0: + raise RuntimeError("Cannot create updater with specifications list") + if specifications[0].n_points == 0: + raise ValueError("cannot perform a halo update on zero halo points") + return HaloUpdater.from_scalar_specifications( + self, + self._maybe_force_cpu(specifications[0].numpy_module), + specifications, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def get_vector_halo_updater( + self, + specifications_x: List[QuantityHaloSpec], + specifications_y: List[QuantityHaloSpec], + ): + if len(specifications_x) == 0 and len(specifications_y) == 0: + raise RuntimeError("Cannot create updater with empty specifications list") + if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: + raise ValueError("Cannot perform a halo update on zero halo points") + return HaloUpdater.from_vector_specifications( + self, + self._maybe_force_cpu(specifications_x[0].numpy_module), + specifications_x, + specifications_y, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def _get_halo_tag(self) -> int: + self._last_halo_tag += 1 + return self._last_halo_tag + + @property + def boundaries(self) -> Mapping[int, Boundary]: + """boundaries of this tile with neighboring tiles""" + if self._boundaries is None: + self._boundaries = {} + for boundary_type in constants.BOUNDARY_TYPES: + boundary = self.partitioner.boundary(boundary_type, self.rank) + if boundary is not None: + self._boundaries[boundary_type] = boundary + return self._boundaries + +class Partitioner(abc.ABC): + @abc.abstractmethod + def __init__(self): + self.tile = None + self.layout = None + + @abc.abstractmethod + def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + ... + + @abc.abstractmethod + def tile_index(self, rank: int): + pass + + @abc.abstractmethod + def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + """Return the shape of a full tile representation for the given dimensions. + + Args: + metadata: quantity metadata + + Returns: + extent: shape of full tile representation + """ + pass + + @abc.abstractmethod + def subtile_slice( + self, + rank: int, + global_dims: Sequence[str], + global_extent: Sequence[int], + overlap: bool = False, + ) -> Tuple[Union[int, slice], ...]: + """Return the subtile slice of a given rank on an array. + + Global refers to the domain being partitioned. For example, for a partitioning + of a tile, the tile would be the "global" domain. 
+ + Args: + rank: the rank of the process + global_dims: dimensions of the global quantity being partitioned + global_extent: extent of the global quantity being partitioned + overlap (optional): if True, for interface variables include the part + of the array shared by adjacent ranks in both ranks. If False, ensure + only one of those ranks (the greater rank) is assigned the overlapping + section. Default is False. + + Returns: + subtile_slice: the slice of the global compute domain corresponding + to the subtile compute domain + """ + pass + + @abc.abstractmethod + def subtile_extent( + self, + global_metadata: QuantityMetadata, + rank: int, + ) -> Tuple[int, ...]: + """Return the shape of a single rank representation for the given dimensions. + + Args: + global_metadata: quantity metadata. + rank: rank of the process. + + Returns: + extent: shape of a single rank representation for the given dimensions. + """ + pass + + @property + @abc.abstractmethod + def total_ranks(self) -> int: + pass \ No newline at end of file diff --git a/tests/checkpointer/test_snapshot.py b/tests/checkpointer/test_snapshot.py index a8dd5387..89d368ec 100644 --- a/tests/checkpointer/test_snapshot.py +++ b/tests/checkpointer/test_snapshot.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from ndsl import SnapshotCheckpointer +from ndsl.checkpointer import SnapshotCheckpointer from ndsl.optional_imports import xarray as xr diff --git a/tests/checkpointer/test_thresholds.py b/tests/checkpointer/test_thresholds.py index 90d1f8fc..8bf70b00 100644 --- a/tests/checkpointer/test_thresholds.py +++ b/tests/checkpointer/test_thresholds.py @@ -1,7 +1,11 @@ import numpy as np import pytest -from ndsl import InsufficientTrialsError, Threshold, ThresholdCalibrationCheckpointer +from ndsl.checkpointer import ( + InsufficientTrialsError, + Threshold, + ThresholdCalibrationCheckpointer, +) def test_thresholds_no_trials(): diff --git a/tests/checkpointer/test_validation.py b/tests/checkpointer/test_validation.py index 091bb7c6..0c08d52b 100644 --- a/tests/checkpointer/test_validation.py +++ b/tests/checkpointer/test_validation.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from ndsl import SavepointThresholds, Threshold, ValidationCheckpointer +from ndsl.checkpointer import SavepointThresholds, Threshold, ValidationCheckpointer from ndsl.checkpointer.validation import _clip_pace_array_to_target from ndsl.optional_imports import xarray as xr diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index def0d342..79e71cf7 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -2,7 +2,7 @@ import pytest from ndsl import ConcurrencyError, DummyComm -from ndsl.comm.communicator import recv_buffer +from ndsl.buffer import recv_buffer from tests.mpi.mpi_comm import MPI diff --git a/tests/test_caching_comm.py b/tests/test_caching_comm.py index b28eba16..5674bfc9 100644 --- a/tests/test_caching_comm.py +++ b/tests/test_caching_comm.py @@ -5,8 +5,6 @@ import numpy as np from ndsl import ( - CachingCommReader, - CachingCommWriter, CubedSphereCommunicator, CubedSpherePartitioner, LocalComm, @@ -14,6 +12,7 @@ Quantity, TilePartitioner, ) +from ndsl.comm import CachingCommReader, CachingCommWriter from ndsl.constants import X_DIM, Y_DIM From a59736e18a9fa90f65eaae40e5379bab430ba4e5 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 5 Mar 2024 16:33:07 -0500 Subject: [PATCH 07/12] Linting --- ndsl/typing.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git 
a/ndsl/typing.py b/ndsl/typing.py index 2f815d0d..aa80c386 100644 --- a/ndsl/typing.py +++ b/ndsl/typing.py @@ -5,19 +5,21 @@ import ndsl.constants as constants from ndsl.buffer import array_buffer, recv_buffer, send_buffer +from ndsl.comm import boundary as bd from ndsl.comm.boundary import Boundary from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from ndsl.performance.timer import NullTimer, Timer from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata from ndsl.types import NumpyModule from ndsl.utils import device_synchronize -from ndsl.comm import boundary as bd + try: import cupy except ImportError: cupy = None + def to_numpy(array, dtype=None) -> np.ndarray: """ Input array can be a numpy array or a cupy array. Returns numpy array. @@ -40,11 +42,13 @@ def to_numpy(array, dtype=None) -> np.ndarray: output = output.astype(dtype=dtype) return output + class Checkpointer(abc.ABC): @abc.abstractmethod def __call__(self, savepoint_name, **kwargs): ... + class Communicator(abc.ABC): def __init__( self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None @@ -570,7 +574,8 @@ def boundaries(self) -> Mapping[int, Boundary]: if boundary is not None: self._boundaries[boundary_type] = boundary return self._boundaries - + + class Partitioner(abc.ABC): @abc.abstractmethod def __init__(self): @@ -645,4 +650,4 @@ def subtile_extent( @property @abc.abstractmethod def total_ranks(self) -> int: - pass \ No newline at end of file + pass From 9f0477a6a8c5f752e076c833685d37a6f0c944dc Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 6 Mar 2024 15:20:38 -0500 Subject: [PATCH 08/12] Changes as of comments from 6 Mar 2024, from PR 14 --- ndsl/__init__.py | 38 +- ndsl/checkpointer/base.py | 7 + ndsl/checkpointer/null.py | 2 +- ndsl/checkpointer/snapshots.py | 2 +- ndsl/checkpointer/thresholds.py | 2 +- ndsl/checkpointer/validation.py | 2 +- ndsl/comm/communicator.py | 567 ++++++++++++++++++++- ndsl/comm/partitioner.py | 79 ++- ndsl/dsl/caches/cache_location.py | 2 +- ndsl/dsl/dace/dace_config.py | 3 +- ndsl/dsl/dace/wrapped_halo_exchange.py | 2 +- ndsl/dsl/stencil.py | 2 +- ndsl/dsl/stencil_config.py | 3 +- ndsl/grid/__init__.py | 2 +- ndsl/grid/generation.py | 2 +- ndsl/halo/__init__.py | 5 + ndsl/halo/updater.py | 2 +- ndsl/monitor/netcdf_monitor.py | 2 +- ndsl/monitor/zarr_monitor.py | 3 +- ndsl/restart/_legacy_restart.py | 2 +- ndsl/stencils/__init__.py | 16 - ndsl/stencils/c2l_ord.py | 2 +- ndsl/stencils/testing/__init__.py | 16 + ndsl/stencils/testing/conftest.py | 7 +- ndsl/typing.py | 662 +------------------------ tests/dsl/test_stencil_factory.py | 3 +- tests/mpi/test_mpi_mock.py | 3 +- tests/test_halo_data_transformer.py | 11 +- tests/test_halo_update.py | 2 +- 29 files changed, 714 insertions(+), 737 deletions(-) create mode 100644 ndsl/checkpointer/base.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index a5073021..f8c64614 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -1,9 +1,7 @@ -from .buffer import Buffer -from .comm.boundary import Boundary, SimpleBoundary from .comm.communicator import CubedSphereCommunicator, TileCommunicator -from .comm.local_comm import AsyncResult, ConcurrencyError, LocalComm +from .comm.local_comm import LocalComm from .comm.mpi import MPIComm -from .comm.null_comm import NullAsyncResult, NullComm +from .comm.null_comm import NullComm from .comm.partitioner import CubedSpherePartitioner, TilePartitioner from .constants import ConstantVersions from .dsl.caches.codepath 
import FV3CodePath @@ -16,46 +14,24 @@ StorageReport, ) from .dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater -from .dsl.stencil import ( - CompareToNumpyStencil, - FrozenStencil, - GridIndexing, - StencilFactory, - TimingCollector, -) +from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig from .exceptions import OutOfBoundsError -from .halo.data_transformer import ( - HaloDataTransformer, - HaloDataTransformerCPU, - HaloDataTransformerGPU, - HaloExchangeSpec, -) +from .halo.data_transformer import HaloExchangeSpec from .halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from .initialization.allocator import QuantityFactory, StorageNumpy +from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log from .monitor.netcdf_monitor import NetCDFMonitor from .monitor.protocol import Protocol from .monitor.zarr_monitor import ZarrMonitor from .namelist import Namelist -from .optional_imports import RaiseWhenAccessed -from .performance.collector import ( - AbstractPerformanceCollector, - NullPerformanceCollector, - PerformanceCollector, -) +from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.config import PerformanceConfig from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport from .performance.timer import NullTimer, Timer -from .quantity import ( - BoundaryArrayView, - BoundedArrayView, - Quantity, - QuantityHaloSpec, - QuantityMetadata, -) +from .quantity import Quantity from .testing.dummy_comm import DummyComm from .types import Allocator, AsyncRequest, NumpyModule from .units import UnitsError diff --git a/ndsl/checkpointer/base.py b/ndsl/checkpointer/base.py new file mode 100644 index 00000000..8218bbfe --- /dev/null +++ b/ndsl/checkpointer/base.py @@ -0,0 +1,7 @@ +import abc + + +class Checkpointer(abc.ABC): + @abc.abstractmethod + def __call__(self, savepoint_name, **kwargs): + ... 
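For reference, the new ndsl/checkpointer/base.py above only pins down the Checkpointer interface (an abstract __call__(savepoint_name, **kwargs)); the hunks that follow switch the null, snapshot, threshold and validation checkpointers over to this base class. Below is a minimal sketch of how such a subclass is meant to look, assuming only the import path and the __call__ signature from the hunk above (and that the patched package is installed); the subclass name, its bookkeeping, and the call-site arguments are illustrative and are not part of NDSL.

# Illustrative sketch only, not part of the patch: a hypothetical concrete
# checkpointer built on the abstract base added in ndsl/checkpointer/base.py.
from ndsl.checkpointer.base import Checkpointer


class RecordingCheckpointer(Checkpointer):
    """Hypothetical subclass that only records which savepoints were reached."""

    def __init__(self):
        self.savepoint_names = []

    def __call__(self, savepoint_name, **kwargs):
        # The base class fixes only this signature; what a checkpointer does
        # with the field data passed via **kwargs is left to the subclass.
        self.savepoint_names.append(savepoint_name)


checkpointer = RecordingCheckpointer()
checkpointer("example-savepoint", u=None, v=None)  # argument names are made up

Because the call sites only rely on this one signature, any of the concrete checkpointers touched in the following hunks can presumably be passed wherever a Checkpointer is expected.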
diff --git a/ndsl/checkpointer/null.py b/ndsl/checkpointer/null.py index 448b3a6e..fbc78755 100644 --- a/ndsl/checkpointer/null.py +++ b/ndsl/checkpointer/null.py @@ -1,4 +1,4 @@ -from ndsl.typing import Checkpointer +from ndsl.checkpointer.base import Checkpointer class NullCheckpointer(Checkpointer): diff --git a/ndsl/checkpointer/snapshots.py b/ndsl/checkpointer/snapshots.py index 573701ae..aa806b21 100644 --- a/ndsl/checkpointer/snapshots.py +++ b/ndsl/checkpointer/snapshots.py @@ -2,9 +2,9 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.optional_imports import cupy as cp from ndsl.optional_imports import xarray as xr -from ndsl.typing import Checkpointer def make_dims(savepoint_dim, label, data_list): diff --git a/ndsl/checkpointer/thresholds.py b/ndsl/checkpointer/thresholds.py index 2f1af55c..ded73b39 100644 --- a/ndsl/checkpointer/thresholds.py +++ b/ndsl/checkpointer/thresholds.py @@ -5,8 +5,8 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.quantity import Quantity -from ndsl.typing import Checkpointer try: diff --git a/ndsl/checkpointer/validation.py b/ndsl/checkpointer/validation.py index 12146a53..8af11317 100644 --- a/ndsl/checkpointer/validation.py +++ b/ndsl/checkpointer/validation.py @@ -5,6 +5,7 @@ import numpy as np +from ndsl.checkpointer.base import Checkpointer from ndsl.checkpointer.thresholds import ( ArrayLike, SavepointName, @@ -12,7 +13,6 @@ cast_to_ndarray, ) from ndsl.optional_imports import xarray as xr -from ndsl.typing import Checkpointer def _clip_pace_array_to_target( diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 3f21ee21..f1b97c87 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -1,11 +1,16 @@ -from typing import List, Optional, Sequence, Tuple, Union, cast +import abc +from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast + +import numpy as np import ndsl.constants as constants -from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest -from ndsl.performance.timer import Timer -from ndsl.quantity import Quantity, QuantityMetadata -from ndsl.typing import Communicator +from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer +from ndsl.comm.boundary import Boundary +from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner +from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater +from ndsl.performance.timer import NullTimer, Timer +from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata +from ndsl.types import NumpyModule try: @@ -14,6 +19,556 @@ cupy = None +def to_numpy(array, dtype=None) -> np.ndarray: + """ + Input array can be a numpy array or a cupy array. Returns numpy array. + """ + try: + output = np.asarray(array) + except ValueError as err: + if err.args[0] == "object __array__ method not producing an array": + output = cupy.asnumpy(array) + else: + raise err + except TypeError as err: + if err.args[0].startswith( + "Implicit conversion to a NumPy array is not allowed." 
+ ): + output = cupy.asnumpy(array) + else: + raise err + if dtype: + output = output.astype(dtype=dtype) + return output + + +class Communicator(abc.ABC): + def __init__( + self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + ): + self.comm = comm + self.partitioner: Partitioner = partitioner + self._force_cpu = force_cpu + self._boundaries: Optional[Mapping[int, Boundary]] = None + self._last_halo_tag = 0 + self.timer: Timer = timer if timer is not None else NullTimer() + + @abc.abstractproperty + def tile(self): + pass + + @classmethod + @abc.abstractmethod + def from_layout( + cls, + comm, + layout: Tuple[int, int], + force_cpu: bool = False, + timer: Optional[Timer] = None, + ): + pass + + @property + def rank(self) -> int: + """rank of the current process within this communicator""" + return self.comm.Get_rank() + + @property + def size(self) -> int: + """Total number of ranks in this communicator""" + return self.comm.Get_size() + + def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: + """ + Get a numpy-like module depending on configuration and + Quantity original allocator. + """ + if self._force_cpu: + return np + return module + + @staticmethod + def _device_synchronize(): + """Wait for all work that could be in-flight to finish.""" + # this is a method so we can profile it separately from other device syncs + device_synchronize() + + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Scatter(send, recv, **kwargs) + + def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): + with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( + numpy_module.zeros, recvbuf + ) as recv: + self.comm.Gather(send, recv, **kwargs) + + def scatter( + self, + send_quantity: Optional[Quantity] = None, + recv_quantity: Optional[Quantity] = None, + ) -> Quantity: + """Transfer subtile regions of a full-tile quantity + from the tile root rank to all subtiles. + + Args: + send_quantity: quantity to send, only required/used on the tile root rank + recv_quantity: if provided, assign received data into this Quantity. 
+ Returns: + recv_quantity + """ + if self.rank == constants.ROOT_RANK and send_quantity is None: + raise TypeError("send_quantity is a required argument on the root rank") + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) + else: + metadata = self.comm.bcast(None, root=constants.ROOT_RANK) + shape = self.partitioner.subtile_extent(metadata, self.rank) + if recv_quantity is None: + recv_quantity = self._get_scatter_recv_quantity(shape, metadata) + if self.rank == constants.ROOT_RANK: + send_quantity = cast(Quantity, send_quantity) + with array_buffer( + self._maybe_force_cpu(metadata.np).zeros, + (self.partitioner.total_ranks,) + shape, + dtype=metadata.dtype, + ) as sendbuf: + for rank in range(0, self.partitioner.total_ranks): + subtile_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=metadata.dims, + global_extent=metadata.extent, + overlap=True, + ) + sendbuf.assign_from( + send_quantity.view[subtile_slice], + buffer_slice=np.index_exp[rank, :], + ) + self._Scatter( + metadata.np, + sendbuf.array, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + else: + self._Scatter( + metadata.np, + None, + recv_quantity.view[:], + root=constants.ROOT_RANK, + ) + return recv_quantity + + def _get_gather_recv_quantity( + self, global_extent: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving global data during gather""" + recv_quantity = Quantity( + send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + origin=tuple([0 for dim in send_metadata.dims]), + extent=global_extent, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def _get_scatter_recv_quantity( + self, shape: Sequence[int], send_metadata: QuantityMetadata + ) -> Quantity: + """Initialize a Quantity for use when receiving subtile data during scatter""" + recv_quantity = Quantity( + send_metadata.np.zeros(shape, dtype=send_metadata.dtype), + dims=send_metadata.dims, + units=send_metadata.units, + gt4py_backend=send_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return recv_quantity + + def gather( + self, send_quantity: Quantity, recv_quantity: Quantity = None + ) -> Optional[Quantity]: + """Transfer subtile regions of a full-tile quantity + from each rank to the tile root rank. 
+ + Args: + send_quantity: quantity to send + recv_quantity: if provided, assign received data into this Quantity (only + used on the tile root rank) + Returns: + recv_quantity: quantity if on root rank, otherwise None + """ + result: Optional[Quantity] + if self.rank == constants.ROOT_RANK: + with array_buffer( + send_quantity.np.zeros, + (self.partitioner.total_ranks,) + tuple(send_quantity.extent), + dtype=send_quantity.data.dtype, + ) as recvbuf: + self._Gather( + send_quantity.np, + send_quantity.view[:], + recvbuf.array, + root=constants.ROOT_RANK, + ) + if recv_quantity is None: + global_extent = self.partitioner.global_extent( + send_quantity.metadata + ) + recv_quantity = self._get_gather_recv_quantity( + global_extent, send_quantity.metadata + ) + for rank in range(self.partitioner.total_ranks): + to_slice = self.partitioner.subtile_slice( + rank=rank, + global_dims=recv_quantity.dims, + global_extent=recv_quantity.extent, + overlap=True, + ) + recvbuf.assign_to( + recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] + ) + result = recv_quantity + else: + self._Gather( + send_quantity.np, + send_quantity.view[:], + None, + root=constants.ROOT_RANK, + ) + result = None + return result + + def gather_state(self, send_state=None, recv_state=None, transfer_type=None): + """Transfer a state dictionary from subtile ranks to the tile root rank. + + 'time' is assumed to be the same on all ranks, and its value will be set + to the value from the root rank. + + Args: + send_state: the model state to be sent containing the subtile data + recv_state: the pre-allocated state in which to receive the full tile + state. Only variables which are scattered will be written to. + Returns: + recv_state: on the root rank, the state containing the entire tile + """ + if self.rank == constants.ROOT_RANK and recv_state is None: + recv_state = {} + for name, quantity in send_state.items(): + if name == "time": + if self.rank == constants.ROOT_RANK: + recv_state["time"] = send_state["time"] + else: + gather_value = to_numpy(quantity.view[:], dtype=transfer_type) + gather_quantity = Quantity( + data=gather_value, + dims=quantity.dims, + units=quantity.units, + allow_mismatch_float_precision=True, + ) + if recv_state is not None and name in recv_state: + tile_quantity = self.gather( + gather_quantity, recv_quantity=recv_state[name] + ) + else: + tile_quantity = self.gather(gather_quantity) + if self.rank == constants.ROOT_RANK: + recv_state[name] = tile_quantity + del gather_quantity + return recv_state + + def scatter_state(self, send_state=None, recv_state=None): + """Transfer a state dictionary from the tile root rank to all subtiles. + + Args: + send_state: the model state to be sent containing the entire tile, + required only from the root rank + recv_state: the pre-allocated state in which to receive the scattered + state. Only variables which are scattered will be written to.
+ Returns: + rank_state: the state corresponding to this rank's subdomain + """ + + def scatter_root(): + if send_state is None: + raise TypeError("send_state is a required argument on the root rank") + name_list = list(send_state.keys()) + while "time" in name_list: + name_list.remove("time") + name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) + array_list = [send_state[name] for name in name_list] + for name, array in zip(name_list, array_list): + if name in recv_state: + self.scatter(send_quantity=array, recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter(send_quantity=array) + recv_state["time"] = self.comm.bcast( + send_state.get("time", None), root=constants.ROOT_RANK + ) + + def scatter_client(): + name_list = self.comm.bcast(None, root=constants.ROOT_RANK) + for name in name_list: + if name in recv_state: + self.scatter(recv_quantity=recv_state[name]) + else: + recv_state[name] = self.scatter() + recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) + + if recv_state is None: + recv_state = {} + if self.rank == constants.ROOT_RANK: + scatter_root() + else: + scatter_client() + if recv_state["time"] is None: + recv_state.pop("time") + return recv_state + + def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): + """Perform a halo update on a quantity or quantities + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + halo_updater = self.start_halo_update(quantities, n_points) + halo_updater.wait() + + def start_halo_update( + self, quantity: Union[Quantity, List[Quantity]], n_points: int + ) -> HaloUpdater: + """Start an asynchronous halo update on a quantity. + + Args: + quantity: the quantity to be updated + n_points: how many halo points to update, starting from the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(quantity, Quantity): + quantities = [quantity] + else: + quantities = quantity + + specifications = [] + for quantity in quantities: + specification = QuantityHaloSpec( + n_points=n_points, + shape=quantity.data.shape, + strides=quantity.data.strides, + itemsize=quantity.data.itemsize, + origin=quantity.origin, + extent=quantity.extent, + dims=quantity.dims, + numpy_module=self._maybe_force_cpu(quantity.np), + dtype=quantity.metadata.dtype, + ) + specifications.append(specification) + + halo_updater = self.get_scalar_halo_updater(specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(quantities) + return halo_updater + + def vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ): + """Perform a halo update of a horizontal vector quantity or quantities. + + Assumes the x and y dimension indices are the same between the two quantities. 
+ + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + halo_updater = self.start_vector_halo_update( + x_quantities, y_quantities, n_points + ) + halo_updater.wait() + + def start_vector_halo_update( + self, + x_quantity: Union[Quantity, List[Quantity]], + y_quantity: Union[Quantity, List[Quantity]], + n_points: int, + ) -> HaloUpdater: + """Start an asynchronous halo update of a horizontal vector quantity. + + Assumes the x and y dimension indices are the same between the two quantities. + + Args: + x_quantity: the x-component quantity to be halo updated + y_quantity: the y-component quantity to be halo updated + n_points: how many halo points to update, starting at the interior + + Returns: + request: an asynchronous request object with a .wait() method + """ + if isinstance(x_quantity, Quantity): + x_quantities = [x_quantity] + else: + x_quantities = x_quantity + if isinstance(y_quantity, Quantity): + y_quantities = [y_quantity] + else: + y_quantities = y_quantity + + x_specifications = [] + y_specifications = [] + for x_quantity, y_quantity in zip(x_quantities, y_quantities): + x_specification = QuantityHaloSpec( + n_points=n_points, + shape=x_quantity.data.shape, + strides=x_quantity.data.strides, + itemsize=x_quantity.data.itemsize, + origin=x_quantity.metadata.origin, + extent=x_quantity.metadata.extent, + dims=x_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(x_quantity.np), + dtype=x_quantity.metadata.dtype, + ) + x_specifications.append(x_specification) + y_specification = QuantityHaloSpec( + n_points=n_points, + shape=y_quantity.data.shape, + strides=y_quantity.data.strides, + itemsize=y_quantity.data.itemsize, + origin=y_quantity.metadata.origin, + extent=y_quantity.metadata.extent, + dims=y_quantity.metadata.dims, + numpy_module=self._maybe_force_cpu(y_quantity.np), + dtype=y_quantity.metadata.dtype, + ) + y_specifications.append(y_specification) + + halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) + halo_updater.force_finalize_on_wait() + halo_updater.start(x_quantities, y_quantities) + return halo_updater + + def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): + """ + Synchronize shared points at the edges of a vector interface variable. + + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + """ + req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) + req.wait() + + def start_synchronize_vector_interfaces( + self, x_quantity: Quantity, y_quantity: Quantity + ) -> HaloUpdateRequest: + """ + Synchronize shared points at the edges of a vector interface variable. 
+ + Sends the values on the south and west edges to overwrite the values on adjacent + subtiles. Vector must be defined on the Arakawa C grid. + + For interface variables, the edges of the tile are computed on both ranks + bordering that edge. This routine copies values across those shared edges + so that both ranks have the same value for that edge. It also handles any + rotation of vector quantities needed to move data across the edge. + + Args: + x_quantity: the x-component quantity to be synchronized + y_quantity: the y-component quantity to be synchronized + + Returns: + request: an asynchronous request object with a .wait() method + """ + halo_updater = VectorInterfaceHaloUpdater( + comm=self.comm, + boundaries=self.boundaries, + force_cpu=self._force_cpu, + timer=self.timer, + ) + req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) + return req + + def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): + if len(specifications) == 0: + raise RuntimeError("Cannot create updater with empty specifications list") + if specifications[0].n_points == 0: + raise ValueError("cannot perform a halo update on zero halo points") + return HaloUpdater.from_scalar_specifications( + self, + self._maybe_force_cpu(specifications[0].numpy_module), + specifications, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def get_vector_halo_updater( + self, + specifications_x: List[QuantityHaloSpec], + specifications_y: List[QuantityHaloSpec], + ): + if len(specifications_x) == 0 and len(specifications_y) == 0: + raise RuntimeError("Cannot create updater with empty specifications list") + if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: + raise ValueError("Cannot perform a halo update on zero halo points") + return HaloUpdater.from_vector_specifications( + self, + self._maybe_force_cpu(specifications_x[0].numpy_module), + specifications_x, + specifications_y, + self.boundaries.values(), + self._get_halo_tag(), + self.timer, + ) + + def _get_halo_tag(self) -> int: + self._last_halo_tag += 1 + return self._last_halo_tag + + @property + def boundaries(self) -> Mapping[int, Boundary]: + """boundaries of this tile with neighboring tiles""" + if self._boundaries is None: + self._boundaries = {} + for boundary_type in constants.BOUNDARY_TYPES: + boundary = self.partitioner.boundary(boundary_type, self.rank) + if boundary is not None: + self._boundaries[boundary_type] = boundary + return self._boundaries + + + def bcast_metadata_list(comm, quantity_list): is_root = comm.Get_rank() == constants.ROOT_RANK if is_root: diff --git a/ndsl/comm/partitioner.py b/ndsl/comm/partitioner.py index e3b2e02b..6b8750a1 100644 --- a/ndsl/comm/partitioner.py +++ b/ndsl/comm/partitioner.py @@ -1,3 +1,4 @@ +import abc import copy import functools from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union, cast @@ -17,7 +18,6 @@ WEST, ) from ndsl.quantity import Quantity, QuantityMetadata -from ndsl.typing import Partitioner from ndsl.utils import list_by_dims @@ -54,6 +54,83 @@ def get_tile_number(tile_rank: int, total_ranks: int) -> int: return tile_rank // ranks_per_tile + 1 +class Partitioner(abc.ABC): + @abc.abstractmethod + def __init__(self): + self.tile = None + self.layout = None + + @abc.abstractmethod + def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: + ...
+ + @abc.abstractmethod + def tile_index(self, rank: int): + pass + + @abc.abstractmethod + def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: + """Return the shape of a full tile representation for the given dimensions. + + Args: + metadata: quantity metadata + + Returns: + extent: shape of full tile representation + """ + pass + + @abc.abstractmethod + def subtile_slice( + self, + rank: int, + global_dims: Sequence[str], + global_extent: Sequence[int], + overlap: bool = False, + ) -> Tuple[Union[int, slice], ...]: + """Return the subtile slice of a given rank on an array. + + Global refers to the domain being partitioned. For example, for a partitioning + of a tile, the tile would be the "global" domain. + + Args: + rank: the rank of the process + global_dims: dimensions of the global quantity being partitioned + global_extent: extent of the global quantity being partitioned + overlap (optional): if True, for interface variables include the part + of the array shared by adjacent ranks in both ranks. If False, ensure + only one of those ranks (the greater rank) is assigned the overlapping + section. Default is False. + + Returns: + subtile_slice: the slice of the global compute domain corresponding + to the subtile compute domain + """ + pass + + @abc.abstractmethod + def subtile_extent( + self, + global_metadata: QuantityMetadata, + rank: int, + ) -> Tuple[int, ...]: + """Return the shape of a single rank representation for the given dimensions. + + Args: + global_metadata: quantity metadata. + rank: rank of the process. + + Returns: + extent: shape of a single rank representation for the given dimensions. + """ + pass + + @property + @abc.abstractmethod + def total_ranks(self) -> int: + pass + + class TilePartitioner(Partitioner): def __init__( self, diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 2d973f7a..edf563b7 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -1,5 +1,5 @@ +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.codepath import FV3CodePath -from ndsl.typing import Partitioner def identify_code_path( diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index f93d2baf..7f1c1477 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -6,12 +6,13 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.frontend.python.parser import DaceProgram +from ndsl.comm.communicator import Communicator +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.dsl.gt4py_utils import is_gpu_backend from ndsl.dsl.typing import floating_point_precision from ndsl.optional_imports import cupy as cp -from ndsl.typing import Communicator, Partitioner # This can be turned on to revert compilation for orchestration diff --git a/ndsl/dsl/dace/wrapped_halo_exchange.py b/ndsl/dsl/dace/wrapped_halo_exchange.py index ca36f3a0..78a68fa4 100644 --- a/ndsl/dsl/dace/wrapped_halo_exchange.py +++ b/ndsl/dsl/dace/wrapped_halo_exchange.py @@ -1,9 +1,9 @@ import dataclasses from typing import Any, List, Optional +from ndsl.comm.communicator import Communicator from ndsl.dsl.dace.orchestration import dace_inhibitor from ndsl.halo.updater import HaloUpdater -from ndsl.typing import Communicator class WrappedHaloUpdater: diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index f57c139a..b8316727 100644 --- a/ndsl/dsl/stencil.py 
+++ b/ndsl/dsl/stencil.py @@ -23,6 +23,7 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline from ndsl.comm.comm_abc import Comm +from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import block_waiting_for_compilation, unblock_waiting_tiles from ndsl.comm.mpi import MPI from ndsl.constants import X_DIM, X_DIMS, Y_DIM, Y_DIMS, Z_DIM, Z_DIMS @@ -34,7 +35,6 @@ # from ndsl import testing from ndsl.testing import comparison -from ndsl.typing import Communicator try: diff --git a/ndsl/dsl/stencil_config.py b/ndsl/dsl/stencil_config.py index e1e233b7..6b8f75eb 100644 --- a/ndsl/dsl/stencil_config.py +++ b/ndsl/dsl/stencil_config.py @@ -5,10 +5,11 @@ from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline +from ndsl.comm.communicator import Communicator from ndsl.comm.decomposition import determine_rank_is_compiling, set_distributed_caches +from ndsl.comm.partitioner import Partitioner from ndsl.dsl.dace.dace_config import DaceConfig, DaCeOrchestration from ndsl.dsl.gt4py_utils import is_gpu_backend -from ndsl.typing import Communicator, Partitioner class RunMode(enum.Enum): diff --git a/ndsl/grid/__init__.py b/ndsl/grid/__init__.py index 49eccf05..fabe72bf 100644 --- a/ndsl/grid/__init__.py +++ b/ndsl/grid/__init__.py @@ -1,5 +1,5 @@ from .eta import HybridPressureCoefficients -from .generation import GridDefinition, GridDefinitions, MetricTerms +from .generation import GridDefinitions, MetricTerms from .helper import ( AngleGridData, ContravariantGridData, diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 2d6450a9..12275d7d 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -5,6 +5,7 @@ import numpy as np +from ndsl.comm.communicator import Communicator from ndsl.constants import ( N_HALO_DEFAULT, PI, @@ -58,7 +59,6 @@ fill_corners_cgrid, fill_corners_dgrid, ) -from ndsl.typing import Communicator # TODO: when every environment in python3.8, remove diff --git a/ndsl/halo/__init__.py b/ndsl/halo/__init__.py index e69de29b..e16177d5 100644 --- a/ndsl/halo/__init__.py +++ b/ndsl/halo/__init__.py @@ -0,0 +1,5 @@ +from .data_transformer import ( + HaloDataTransformer, + HaloDataTransformerCPU, + HaloDataTransformerGPU, +) diff --git a/ndsl/halo/updater.py b/ndsl/halo/updater.py index 7684c564..665d0b95 100644 --- a/ndsl/halo/updater.py +++ b/ndsl/halo/updater.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: - from ndsl.typing import Communicator + from ndsl.comm.communicator import Communicator _HaloSendTuple = Tuple[AsyncRequest, Buffer] _HaloRecvTuple = Tuple[AsyncRequest, Buffer, np.ndarray] diff --git a/ndsl/monitor/netcdf_monitor.py b/ndsl/monitor/netcdf_monitor.py index 30731095..8a0b96fd 100644 --- a/ndsl/monitor/netcdf_monitor.py +++ b/ndsl/monitor/netcdf_monitor.py @@ -5,12 +5,12 @@ import fsspec import numpy as np +from ndsl.comm.communicator import Communicator from ndsl.filesystem import get_fs from ndsl.logging import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity -from ndsl.typing import Communicator class _TimeChunkedVariable: diff --git a/ndsl/monitor/zarr_monitor.py b/ndsl/monitor/zarr_monitor.py index 85e37222..214171be 100644 --- a/ndsl/monitor/zarr_monitor.py +++ b/ndsl/monitor/zarr_monitor.py @@ -4,13 +4,12 @@ import cftime import ndsl.constants as constants -from ndsl.comm.partitioner import subtile_slice +from ndsl.comm.partitioner import Partitioner, subtile_slice from ndsl.logging 
import ndsl_log from ndsl.monitor.convert import to_numpy from ndsl.optional_imports import cupy from ndsl.optional_imports import xarray as xr from ndsl.optional_imports import zarr -from ndsl.typing import Partitioner from ndsl.utils import list_by_dims diff --git a/ndsl/restart/_legacy_restart.py b/ndsl/restart/_legacy_restart.py index afa4d523..01f9bdb8 100644 --- a/ndsl/restart/_legacy_restart.py +++ b/ndsl/restart/_legacy_restart.py @@ -5,11 +5,11 @@ import ndsl.constants as constants import ndsl.filesystem as filesystem import ndsl.io as io +from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import get_tile_index from ndsl.optional_imports import xarray as xr from ndsl.quantity import Quantity from ndsl.restart._properties import RESTART_PROPERTIES, RestartProperties -from ndsl.typing import Communicator __all__ = ["open_restart"] diff --git a/ndsl/stencils/__init__.py b/ndsl/stencils/__init__.py index 0fe16725..641e032a 100644 --- a/ndsl/stencils/__init__.py +++ b/ndsl/stencils/__init__.py @@ -1,21 +1,5 @@ from .c2l_ord import CubedToLatLon from .corners import CopyCorners, CopyCornersXY, FillCornersBGrid -from .testing.grid import Grid # type: ignore -from .testing.parallel_translate import ( - ParallelTranslate, - ParallelTranslate2Py, - ParallelTranslate2PyState, - ParallelTranslateBaseSlicing, - ParallelTranslateGrid, -) -from .testing.savepoint import SavepointCase, Translate, dataset_to_dict -from .testing.temporaries import assert_same_temporaries, copy_temporaries -from .testing.translate import ( - TranslateFortranData2Py, - TranslateGrid, - pad_field_in_j, - read_serialized_data, -) __version__ = "0.2.0" diff --git a/ndsl/stencils/c2l_ord.py b/ndsl/stencils/c2l_ord.py index 87d59f27..67f2b5a1 100644 --- a/ndsl/stencils/c2l_ord.py +++ b/ndsl/stencils/c2l_ord.py @@ -8,13 +8,13 @@ ) import ndsl.dsl.gt4py_utils as utils +from ndsl.comm.communicator import Communicator from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.grid.helper import GridData from ndsl.initialization.allocator import QuantityFactory -from ndsl.typing import Communicator A1 = 0.5625 diff --git a/ndsl/stencils/testing/__init__.py b/ndsl/stencils/testing/__init__.py index e69de29b..4be2c60a 100644 --- a/ndsl/stencils/testing/__init__.py +++ b/ndsl/stencils/testing/__init__.py @@ -0,0 +1,16 @@ +from .grid import Grid # type: ignore +from .parallel_translate import ( + ParallelTranslate, + ParallelTranslate2Py, + ParallelTranslate2PyState, + ParallelTranslateBaseSlicing, + ParallelTranslateGrid, +) +from .savepoint import SavepointCase, Translate, dataset_to_dict +from .temporaries import assert_same_temporaries, copy_temporaries +from .translate import ( + TranslateFortranData2Py, + TranslateGrid, + pad_field_in_j, + read_serialized_data, +) diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index b3a3a7e8..d000e1fa 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -8,7 +8,11 @@ import yaml import ndsl.dsl -from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator +from ndsl.comm.communicator import ( + Communicator, + CubedSphereCommunicator, + TileCommunicator, +) from ndsl.comm.mpi import MPI from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from 
ndsl.dsl.dace.dace_config import DaceConfig @@ -16,7 +20,6 @@ from ndsl.stencils.testing.parallel_translate import ParallelTranslate from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict from ndsl.stencils.testing.translate import TranslateGrid -from ndsl.typing import Communicator @pytest.fixture() diff --git a/ndsl/typing.py b/ndsl/typing.py index aa80c386..03f9624b 100644 --- a/ndsl/typing.py +++ b/ndsl/typing.py @@ -1,653 +1,9 @@ -import abc -from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast - -import numpy as np - -import ndsl.constants as constants -from ndsl.buffer import array_buffer, recv_buffer, send_buffer -from ndsl.comm import boundary as bd -from ndsl.comm.boundary import Boundary -from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater -from ndsl.performance.timer import NullTimer, Timer -from ndsl.quantity import Quantity, QuantityHaloSpec, QuantityMetadata -from ndsl.types import NumpyModule -from ndsl.utils import device_synchronize - - -try: - import cupy -except ImportError: - cupy = None - - -def to_numpy(array, dtype=None) -> np.ndarray: - """ - Input array can be a numpy array or a cupy array. Returns numpy array. - """ - try: - output = np.asarray(array) - except ValueError as err: - if err.args[0] == "object __array__ method not producing an array": - output = cupy.asnumpy(array) - else: - raise err - except TypeError as err: - if err.args[0].startswith( - "Implicit conversion to a NumPy array is not allowed." - ): - output = cupy.asnumpy(array) - else: - raise err - if dtype: - output = output.astype(dtype=dtype) - return output - - -class Checkpointer(abc.ABC): - @abc.abstractmethod - def __call__(self, savepoint_name, **kwargs): - ... - - -class Communicator(abc.ABC): - def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None - ): - self.comm = comm - self.partitioner: Partitioner = partitioner - self._force_cpu = force_cpu - self._boundaries: Optional[Mapping[int, Boundary]] = None - self._last_halo_tag = 0 - self.timer: Timer = timer if timer is not None else NullTimer() - - @abc.abstractproperty - def tile(self): - pass - - @classmethod - @abc.abstractmethod - def from_layout( - cls, - comm, - layout: Tuple[int, int], - force_cpu: bool = False, - timer: Optional[Timer] = None, - ): - pass - - @property - def rank(self) -> int: - """rank of the current process within this communicator""" - return self.comm.Get_rank() - - @property - def size(self) -> int: - """Total number of ranks in this communicator""" - return self.comm.Get_size() - - def _maybe_force_cpu(self, module: NumpyModule) -> NumpyModule: - """ - Get a numpy-like module depending on configuration and - Quantity original allocator. 
- """ - if self._force_cpu: - return np - return module - - @staticmethod - def _device_synchronize(): - """Wait for all work that could be in-flight to finish.""" - # this is a method so we can profile it separately from other device syncs - device_synchronize() - - def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Scatter(send, recv, **kwargs) - - def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: - self.comm.Gather(send, recv, **kwargs) - - def scatter( - self, - send_quantity: Optional[Quantity] = None, - recv_quantity: Optional[Quantity] = None, - ) -> Quantity: - """Transfer subtile regions of a full-tile quantity - from the tile root rank to all subtiles. - - Args: - send_quantity: quantity to send, only required/used on the tile root rank - recv_quantity: if provided, assign received data into this Quantity. - Returns: - recv_quantity - """ - if self.rank == constants.ROOT_RANK and send_quantity is None: - raise TypeError("send_quantity is a required argument on the root rank") - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - metadata = self.comm.bcast(send_quantity.metadata, root=constants.ROOT_RANK) - else: - metadata = self.comm.bcast(None, root=constants.ROOT_RANK) - shape = self.partitioner.subtile_extent(metadata, self.rank) - if recv_quantity is None: - recv_quantity = self._get_scatter_recv_quantity(shape, metadata) - if self.rank == constants.ROOT_RANK: - send_quantity = cast(Quantity, send_quantity) - with array_buffer( - self._maybe_force_cpu(metadata.np).zeros, - (self.partitioner.total_ranks,) + shape, - dtype=metadata.dtype, - ) as sendbuf: - for rank in range(0, self.partitioner.total_ranks): - subtile_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=metadata.dims, - global_extent=metadata.extent, - overlap=True, - ) - sendbuf.assign_from( - send_quantity.view[subtile_slice], - buffer_slice=np.index_exp[rank, :], - ) - self._Scatter( - metadata.np, - sendbuf.array, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - else: - self._Scatter( - metadata.np, - None, - recv_quantity.view[:], - root=constants.ROOT_RANK, - ) - return recv_quantity - - def _get_gather_recv_quantity( - self, global_extent: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving global data during gather""" - recv_quantity = Quantity( - send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - origin=tuple([0 for dim in send_metadata.dims]), - extent=global_extent, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def _get_scatter_recv_quantity( - self, shape: Sequence[int], send_metadata: QuantityMetadata - ) -> Quantity: - """Initialize a Quantity for use when receiving subtile data during scatter""" - recv_quantity = Quantity( - send_metadata.np.zeros(shape, dtype=send_metadata.dtype), - dims=send_metadata.dims, - units=send_metadata.units, - gt4py_backend=send_metadata.gt4py_backend, - allow_mismatch_float_precision=True, - ) - return recv_quantity - - def gather( - self, send_quantity: Quantity, recv_quantity: Quantity = None - ) -> Optional[Quantity]: - """Transfer subtile regions of 
a full-tile quantity - from each rank to the tile root rank. - - Args: - send_quantity: quantity to send - recv_quantity: if provided, assign received data into this Quantity (only - used on the tile root rank) - Returns: - recv_quantity: quantity if on root rank, otherwise None - """ - result: Optional[Quantity] - if self.rank == constants.ROOT_RANK: - with array_buffer( - send_quantity.np.zeros, - (self.partitioner.total_ranks,) + tuple(send_quantity.extent), - dtype=send_quantity.data.dtype, - ) as recvbuf: - self._Gather( - send_quantity.np, - send_quantity.view[:], - recvbuf.array, - root=constants.ROOT_RANK, - ) - if recv_quantity is None: - global_extent = self.partitioner.global_extent( - send_quantity.metadata - ) - recv_quantity = self._get_gather_recv_quantity( - global_extent, send_quantity.metadata - ) - for rank in range(self.partitioner.total_ranks): - to_slice = self.partitioner.subtile_slice( - rank=rank, - global_dims=recv_quantity.dims, - global_extent=recv_quantity.extent, - overlap=True, - ) - recvbuf.assign_to( - recv_quantity.view[to_slice], buffer_slice=np.index_exp[rank, :] - ) - result = recv_quantity - else: - self._Gather( - send_quantity.np, - send_quantity.view[:], - None, - root=constants.ROOT_RANK, - ) - result = None - return result - - def gather_state(self, send_state=None, recv_state=None, transfer_type=None): - """Transfer a state dictionary from subtile ranks to the tile root rank. - - 'time' is assumed to be the same on all ranks, and its value will be set - to the value from the root rank. - - Args: - send_state: the model state to be sent containing the subtile data - recv_state: the pre-allocated state in which to recieve the full tile - state. Only variables which are scattered will be written to. - Returns: - recv_state: on the root rank, the state containing the entire tile - """ - if self.rank == constants.ROOT_RANK and recv_state is None: - recv_state = {} - for name, quantity in send_state.items(): - if name == "time": - if self.rank == constants.ROOT_RANK: - recv_state["time"] = send_state["time"] - else: - gather_value = to_numpy(quantity.view[:], dtype=transfer_type) - gather_quantity = Quantity( - data=gather_value, - dims=quantity.dims, - units=quantity.units, - allow_mismatch_float_precision=True, - ) - if recv_state is not None and name in recv_state: - tile_quantity = self.gather( - gather_quantity, recv_quantity=recv_state[name] - ) - else: - tile_quantity = self.gather(gather_quantity) - if self.rank == constants.ROOT_RANK: - recv_state[name] = tile_quantity - del gather_quantity - return recv_state - - def scatter_state(self, send_state=None, recv_state=None): - """Transfer a state dictionary from the tile root rank to all subtiles. - - Args: - send_state: the model state to be sent containing the entire tile, - required only from the root rank - recv_state: the pre-allocated state in which to recieve the scattered - state. Only variables which are scattered will be written to. 
- Returns: - rank_state: the state corresponding to this rank's subdomain - """ - - def scatter_root(): - if send_state is None: - raise TypeError("send_state is a required argument on the root rank") - name_list = list(send_state.keys()) - while "time" in name_list: - name_list.remove("time") - name_list = self.comm.bcast(name_list, root=constants.ROOT_RANK) - array_list = [send_state[name] for name in name_list] - for name, array in zip(name_list, array_list): - if name in recv_state: - self.scatter(send_quantity=array, recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter(send_quantity=array) - recv_state["time"] = self.comm.bcast( - send_state.get("time", None), root=constants.ROOT_RANK - ) - - def scatter_client(): - name_list = self.comm.bcast(None, root=constants.ROOT_RANK) - for name in name_list: - if name in recv_state: - self.scatter(recv_quantity=recv_state[name]) - else: - recv_state[name] = self.scatter() - recv_state["time"] = self.comm.bcast(None, root=constants.ROOT_RANK) - - if recv_state is None: - recv_state = {} - if self.rank == constants.ROOT_RANK: - scatter_root() - else: - scatter_client() - if recv_state["time"] is None: - recv_state.pop("time") - return recv_state - - def halo_update(self, quantity: Union[Quantity, List[Quantity]], n_points: int): - """Perform a halo update on a quantity or quantities - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - halo_updater = self.start_halo_update(quantities, n_points) - halo_updater.wait() - - def start_halo_update( - self, quantity: Union[Quantity, List[Quantity]], n_points: int - ) -> HaloUpdater: - """Start an asynchronous halo update on a quantity. - - Args: - quantity: the quantity to be updated - n_points: how many halo points to update, starting from the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(quantity, Quantity): - quantities = [quantity] - else: - quantities = quantity - - specifications = [] - for quantity in quantities: - specification = QuantityHaloSpec( - n_points=n_points, - shape=quantity.data.shape, - strides=quantity.data.strides, - itemsize=quantity.data.itemsize, - origin=quantity.origin, - extent=quantity.extent, - dims=quantity.dims, - numpy_module=self._maybe_force_cpu(quantity.np), - dtype=quantity.metadata.dtype, - ) - specifications.append(specification) - - halo_updater = self.get_scalar_halo_updater(specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(quantities) - return halo_updater - - def vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ): - """Perform a halo update of a horizontal vector quantity or quantities. - - Assumes the x and y dimension indices are the same between the two quantities. 
- - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - halo_updater = self.start_vector_halo_update( - x_quantities, y_quantities, n_points - ) - halo_updater.wait() - - def start_vector_halo_update( - self, - x_quantity: Union[Quantity, List[Quantity]], - y_quantity: Union[Quantity, List[Quantity]], - n_points: int, - ) -> HaloUpdater: - """Start an asynchronous halo update of a horizontal vector quantity. - - Assumes the x and y dimension indices are the same between the two quantities. - - Args: - x_quantity: the x-component quantity to be halo updated - y_quantity: the y-component quantity to be halo updated - n_points: how many halo points to update, starting at the interior - - Returns: - request: an asynchronous request object with a .wait() method - """ - if isinstance(x_quantity, Quantity): - x_quantities = [x_quantity] - else: - x_quantities = x_quantity - if isinstance(y_quantity, Quantity): - y_quantities = [y_quantity] - else: - y_quantities = y_quantity - - x_specifications = [] - y_specifications = [] - for x_quantity, y_quantity in zip(x_quantities, y_quantities): - x_specification = QuantityHaloSpec( - n_points=n_points, - shape=x_quantity.data.shape, - strides=x_quantity.data.strides, - itemsize=x_quantity.data.itemsize, - origin=x_quantity.metadata.origin, - extent=x_quantity.metadata.extent, - dims=x_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(x_quantity.np), - dtype=x_quantity.metadata.dtype, - ) - x_specifications.append(x_specification) - y_specification = QuantityHaloSpec( - n_points=n_points, - shape=y_quantity.data.shape, - strides=y_quantity.data.strides, - itemsize=y_quantity.data.itemsize, - origin=y_quantity.metadata.origin, - extent=y_quantity.metadata.extent, - dims=y_quantity.metadata.dims, - numpy_module=self._maybe_force_cpu(y_quantity.np), - dtype=y_quantity.metadata.dtype, - ) - y_specifications.append(y_specification) - - halo_updater = self.get_vector_halo_updater(x_specifications, y_specifications) - halo_updater.force_finalize_on_wait() - halo_updater.start(x_quantities, y_quantities) - return halo_updater - - def synchronize_vector_interfaces(self, x_quantity: Quantity, y_quantity: Quantity): - """ - Synchronize shared points at the edges of a vector interface variable. - - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - """ - req = self.start_synchronize_vector_interfaces(x_quantity, y_quantity) - req.wait() - - def start_synchronize_vector_interfaces( - self, x_quantity: Quantity, y_quantity: Quantity - ) -> HaloUpdateRequest: - """ - Synchronize shared points at the edges of a vector interface variable. 
- - Sends the values on the south and west edges to overwrite the values on adjacent - subtiles. Vector must be defined on the Arakawa C grid. - - For interface variables, the edges of the tile are computed on both ranks - bordering that edge. This routine copies values across those shared edges - so that both ranks have the same value for that edge. It also handles any - rotation of vector quantities needed to move data across the edge. - - Args: - x_quantity: the x-component quantity to be synchronized - y_quantity: the y-component quantity to be synchronized - - Returns: - request: an asynchronous request object with a .wait() method - """ - halo_updater = VectorInterfaceHaloUpdater( - comm=self.comm, - boundaries=self.boundaries, - force_cpu=self._force_cpu, - timer=self.timer, - ) - req = halo_updater.start_synchronize_vector_interfaces(x_quantity, y_quantity) - return req - - def get_scalar_halo_updater(self, specifications: List[QuantityHaloSpec]): - if len(specifications) == 0: - raise RuntimeError("Cannot create updater with specifications list") - if specifications[0].n_points == 0: - raise ValueError("cannot perform a halo update on zero halo points") - return HaloUpdater.from_scalar_specifications( - self, - self._maybe_force_cpu(specifications[0].numpy_module), - specifications, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def get_vector_halo_updater( - self, - specifications_x: List[QuantityHaloSpec], - specifications_y: List[QuantityHaloSpec], - ): - if len(specifications_x) == 0 and len(specifications_y) == 0: - raise RuntimeError("Cannot create updater with empty specifications list") - if specifications_x[0].n_points == 0 and specifications_y[0].n_points == 0: - raise ValueError("Cannot perform a halo update on zero halo points") - return HaloUpdater.from_vector_specifications( - self, - self._maybe_force_cpu(specifications_x[0].numpy_module), - specifications_x, - specifications_y, - self.boundaries.values(), - self._get_halo_tag(), - self.timer, - ) - - def _get_halo_tag(self) -> int: - self._last_halo_tag += 1 - return self._last_halo_tag - - @property - def boundaries(self) -> Mapping[int, Boundary]: - """boundaries of this tile with neighboring tiles""" - if self._boundaries is None: - self._boundaries = {} - for boundary_type in constants.BOUNDARY_TYPES: - boundary = self.partitioner.boundary(boundary_type, self.rank) - if boundary is not None: - self._boundaries[boundary_type] = boundary - return self._boundaries - - -class Partitioner(abc.ABC): - @abc.abstractmethod - def __init__(self): - self.tile = None - self.layout = None - - @abc.abstractmethod - def boundary(self, boundary_type: int, rank: int) -> Optional[bd.SimpleBoundary]: - ... - - @abc.abstractmethod - def tile_index(self, rank: int): - pass - - @abc.abstractmethod - def global_extent(self, rank_metadata: QuantityMetadata) -> Tuple[int, ...]: - """Return the shape of a full tile representation for the given dimensions. - - Args: - metadata: quantity metadata - - Returns: - extent: shape of full tile representation - """ - pass - - @abc.abstractmethod - def subtile_slice( - self, - rank: int, - global_dims: Sequence[str], - global_extent: Sequence[int], - overlap: bool = False, - ) -> Tuple[Union[int, slice], ...]: - """Return the subtile slice of a given rank on an array. - - Global refers to the domain being partitioned. For example, for a partitioning - of a tile, the tile would be the "global" domain. 
- - Args: - rank: the rank of the process - global_dims: dimensions of the global quantity being partitioned - global_extent: extent of the global quantity being partitioned - overlap (optional): if True, for interface variables include the part - of the array shared by adjacent ranks in both ranks. If False, ensure - only one of those ranks (the greater rank) is assigned the overlapping - section. Default is False. - - Returns: - subtile_slice: the slice of the global compute domain corresponding - to the subtile compute domain - """ - pass - - @abc.abstractmethod - def subtile_extent( - self, - global_metadata: QuantityMetadata, - rank: int, - ) -> Tuple[int, ...]: - """Return the shape of a single rank representation for the given dimensions. - - Args: - global_metadata: quantity metadata. - rank: rank of the process. - - Returns: - extent: shape of a single rank representation for the given dimensions. - """ - pass - - @property - @abc.abstractmethod - def total_ranks(self) -> int: - pass +# flake8: noqa +from ndsl.checkpointer.base import Checkpointer +from ndsl.comm.communicator import Communicator +from ndsl.comm.local_comm import AsyncResult, ConcurrencyError +from ndsl.comm.null_comm import NullAsyncResult +from ndsl.comm.partitioner import Partitioner +from ndsl.performance.collector import AbstractPerformanceCollector +from ndsl.types import AsyncRequest, NumpyModule +from ndsl.units import UnitsError diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index 756de952..ac189ad8 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -3,7 +3,6 @@ from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region from ndsl import ( - CompareToNumpyStencil, CompilationConfig, DaceConfig, FrozenStencil, @@ -13,7 +12,7 @@ ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py_utils import make_storage_from_shape -from ndsl.dsl.stencil import get_stencils_with_varied_bounds +from ndsl.dsl.stencil import CompareToNumpyStencil, get_stencils_with_varied_bounds from ndsl.dsl.typing import FloatField diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py index 79e71cf7..42fdcbec 100644 --- a/tests/mpi/test_mpi_mock.py +++ b/tests/mpi/test_mpi_mock.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from ndsl import ConcurrencyError, DummyComm +from ndsl import DummyComm from ndsl.buffer import recv_buffer +from ndsl.typing import ConcurrencyError from tests.mpi.mpi_comm import MPI diff --git a/tests/test_halo_data_transformer.py b/tests/test_halo_data_transformer.py index ec986f8c..e3f6d851 100644 --- a/tests/test_halo_data_transformer.py +++ b/tests/test_halo_data_transformer.py @@ -4,13 +4,8 @@ import numpy as np import pytest -from ndsl import ( - Buffer, - HaloDataTransformer, - HaloExchangeSpec, - Quantity, - QuantityHaloSpec, -) +from ndsl import HaloExchangeSpec, Quantity +from ndsl.buffer import Buffer from ndsl.comm import _boundary_utils from ndsl.constants import ( EAST, @@ -28,7 +23,9 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.halo import HaloDataTransformer from ndsl.halo.rotate import rotate_scalar_data, rotate_vector_data +from ndsl.quantity import QuantityHaloSpec @pytest.fixture diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index d0536b24..afff2bd1 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -10,7 +10,6 @@ HaloUpdater, OutOfBoundsError, Quantity, - QuantityHaloSpec, TileCommunicator, TilePartitioner, Timer, @@ 
-34,6 +33,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.quantity import QuantityHaloSpec @pytest.fixture From a40a026d1743cca1005f02f46d2da58a5997ccda Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 7 Mar 2024 10:39:48 -0500 Subject: [PATCH 09/12] Changes requested 1000 7 Mar 2024, PR 14 from Florian --- ndsl/checkpointer/__init__.py | 2 +- ndsl/comm/__init__.py | 1 - ndsl/comm/communicator.py | 2 +- ndsl/dsl/stencil.py | 2 -- tests/checkpointer/__init__.py | 0 tests/dsl/__init__.py | 0 tests/dsl/test_stencil_wrapper.py | 4 ++-- tests/quantity/__init__.py | 0 8 files changed, 4 insertions(+), 7 deletions(-) delete mode 100644 tests/checkpointer/__init__.py delete mode 100644 tests/dsl/__init__.py delete mode 100644 tests/quantity/__init__.py diff --git a/ndsl/checkpointer/__init__.py b/ndsl/checkpointer/__init__.py index 6486d96c..8fee4dc1 100644 --- a/ndsl/checkpointer/__init__.py +++ b/ndsl/checkpointer/__init__.py @@ -1,5 +1,5 @@ from .null import NullCheckpointer -from .snapshots import SnapshotCheckpointer, _Snapshots +from .snapshots import SnapshotCheckpointer from .thresholds import ( InsufficientTrialsError, SavepointThresholds, diff --git a/ndsl/comm/__init__.py b/ndsl/comm/__init__.py index 289e6413..c4a58658 100644 --- a/ndsl/comm/__init__.py +++ b/ndsl/comm/__init__.py @@ -4,6 +4,5 @@ CachingCommWriter, CachingRequestReader, CachingRequestWriter, - NullRequest, ) from .comm_abc import Comm, Request diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index f1b97c87..ff270df5 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -54,7 +54,7 @@ def __init__( self.timer: Timer = timer if timer is not None else NullTimer() @abc.abstractproperty - def tile(self): + def tile(self) -> "TileCommunicator": pass @classmethod diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index b8316727..75ef28ea 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -32,8 +32,6 @@ from ndsl.dsl.typing import Float, Index3D, cast_to_index3d from ndsl.initialization.sizer import GridSizer, SubtileGridSizer from ndsl.quantity import Quantity - -# from ndsl import testing from ndsl.testing import comparison diff --git a/tests/checkpointer/__init__.py b/tests/checkpointer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/dsl/__init__.py b/tests/dsl/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/dsl/test_stencil_wrapper.py b/tests/dsl/test_stencil_wrapper.py index cfe56ded..986883dc 100644 --- a/tests/dsl/test_stencil_wrapper.py +++ b/tests/dsl/test_stencil_wrapper.py @@ -285,14 +285,14 @@ def test_backend_options( "backend": "numpy", "rebuild": True, "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "name": "test_stencil_wrapper.copy_stencil", }, "cuda": { "backend": "cuda", "rebuild": True, "device_sync": False, "format_source": False, - "name": "tests.dsl.test_stencil_wrapper.copy_stencil", + "name": "test_stencil_wrapper.copy_stencil", }, } diff --git a/tests/quantity/__init__.py b/tests/quantity/__init__.py deleted file mode 100644 index e69de29b..00000000 From 13f4f1ab86160765bc430ec7522348cafe404815 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 7 Mar 2024 14:15:28 -0500 Subject: [PATCH 10/12] Changes from comments (pokes) --- ndsl/__init__.py | 8 +------- ndsl/initialization/__init__.py | 1 + ndsl/monitor/__init__.py | 3 +++ ndsl/performance/__init__.py | 2 ++ tests/test_cube_scatter_gather.py | 2 +- tests/test_g2g_communication.py | 2 +- 
tests/test_halo_update.py | 2 +- tests/test_halo_update_ranks.py | 2 +- tests/test_netcdf_monitor.py | 2 +- tests/test_sync_shared_boundary.py | 2 +- tests/test_timer.py | 2 +- tests/test_zarr_monitor.py | 10 ++-------- 12 files changed, 16 insertions(+), 22 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index f8c64614..5ceae724 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -22,17 +22,11 @@ from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log -from .monitor.netcdf_monitor import NetCDFMonitor -from .monitor.protocol import Protocol -from .monitor.zarr_monitor import ZarrMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector -from .performance.config import PerformanceConfig from .performance.profiler import NullProfiler, Profiler from .performance.report import Experiment, Report, TimeReport -from .performance.timer import NullTimer, Timer from .quantity import Quantity from .testing.dummy_comm import DummyComm -from .types import Allocator, AsyncRequest, NumpyModule -from .units import UnitsError +from .types import Allocator from .utils import MetaEnumStr diff --git a/ndsl/initialization/__init__.py b/ndsl/initialization/__init__.py index e69de29b..8f40c7af 100644 --- a/ndsl/initialization/__init__.py +++ b/ndsl/initialization/__init__.py @@ -0,0 +1 @@ +from .sizer import GridSizer diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index e69de29b..a0c7e036 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -0,0 +1,3 @@ +from .netcdf_monitor import NetCDFMonitor +from .protocol import Monitor +from .zarr_monitor import ZarrMonitor diff --git a/ndsl/performance/__init__.py b/ndsl/performance/__init__.py index e69de29b..28e03bc6 100644 --- a/ndsl/performance/__init__.py +++ b/ndsl/performance/__init__.py @@ -0,0 +1,2 @@ +from .config import PerformanceConfig +from .timer import NullTimer, Timer diff --git a/tests/test_cube_scatter_gather.py b/tests/test_cube_scatter_gather.py index 7966142d..aee22533 100644 --- a/tests/test_cube_scatter_gather.py +++ b/tests/test_cube_scatter_gather.py @@ -9,7 +9,6 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import ( HORIZONTAL_DIMS, @@ -21,6 +20,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.performance import Timer try: diff --git a/tests/test_g2g_communication.py b/tests/test_g2g_communication.py index 28f1af7c..17a58785 100644 --- a/tests/test_g2g_communication.py +++ b/tests/test_g2g_communication.py @@ -14,9 +14,9 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.performance import Timer try: diff --git a/tests/test_halo_update.py b/tests/test_halo_update.py index afff2bd1..3d3bf501 100644 --- a/tests/test_halo_update.py +++ b/tests/test_halo_update.py @@ -12,7 +12,6 @@ Quantity, TileCommunicator, TilePartitioner, - Timer, ) from ndsl.buffer import BUFFER_CACHE from ndsl.comm._boundary_utils import get_boundary_slice @@ -33,6 +32,7 @@ Z_DIM, Z_INTERFACE_DIM, ) +from ndsl.performance import Timer from ndsl.quantity import QuantityHaloSpec diff --git a/tests/test_halo_update_ranks.py b/tests/test_halo_update_ranks.py index 8ec77cc1..6ceb4886 100644 --- a/tests/test_halo_update_ranks.py +++ b/tests/test_halo_update_ranks.py @@ -6,7 +6,6 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import ( X_DIM, @@ -16,6 +15,7 @@ Z_DIM, 
Z_INTERFACE_DIM, ) +from ndsl.performance import Timer @pytest.fixture diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 7a21dd78..326739b0 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,10 +10,10 @@ CubedSphereCommunicator, CubedSpherePartitioner, DummyComm, - NetCDFMonitor, Quantity, TilePartitioner, ) +from ndsl.monitor import NetCDFMonitor from ndsl.optional_imports import xarray as xr diff --git a/tests/test_sync_shared_boundary.py b/tests/test_sync_shared_boundary.py index 3e0930a0..7db5a621 100644 --- a/tests/test_sync_shared_boundary.py +++ b/tests/test_sync_shared_boundary.py @@ -6,9 +6,9 @@ DummyComm, Quantity, TilePartitioner, - Timer, ) from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM +from ndsl.performance import Timer @pytest.fixture diff --git a/tests/test_timer.py b/tests/test_timer.py index 213a487a..bb8ec3a6 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -2,7 +2,7 @@ import pytest -from ndsl import NullTimer, Timer +from ndsl.performance import NullTimer, Timer @pytest.fixture diff --git a/tests/test_zarr_monitor.py b/tests/test_zarr_monitor.py index b608ec08..e40d5210 100644 --- a/tests/test_zarr_monitor.py +++ b/tests/test_zarr_monitor.py @@ -12,13 +12,7 @@ import cftime import pytest -from ndsl import ( - CubedSpherePartitioner, - DummyComm, - Quantity, - TilePartitioner, - ZarrMonitor, -) +from ndsl import CubedSpherePartitioner, DummyComm, Quantity, TilePartitioner from ndsl.constants import ( X_DIM, X_DIMS, @@ -28,7 +22,7 @@ Y_INTERFACE_DIM, Z_DIM, ) -from ndsl.monitor.zarr_monitor import array_chunks, get_calendar +from ndsl.monitor.zarr_monitor import ZarrMonitor, array_chunks, get_calendar from ndsl.optional_imports import xarray as xr From 14dc28b2f7452e46aad02dccdc54428ecd96bd30 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 8 Mar 2024 11:26:40 -0500 Subject: [PATCH 11/12] Poke changes --- ndsl/__init__.py | 1 + ndsl/monitor/__init__.py | 1 - tests/test_netcdf_monitor.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ndsl/__init__.py b/ndsl/__init__.py index 5ceae724..a2f771cd 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -22,6 +22,7 @@ from .initialization.allocator import QuantityFactory from .initialization.sizer import GridSizer, SubtileGridSizer from .logging import ndsl_log +from .monitor.netcdf_monitor import NetCDFMonitor from .namelist import Namelist from .performance.collector import NullPerformanceCollector, PerformanceCollector from .performance.profiler import NullProfiler, Profiler diff --git a/ndsl/monitor/__init__.py b/ndsl/monitor/__init__.py index a0c7e036..5d732315 100644 --- a/ndsl/monitor/__init__.py +++ b/ndsl/monitor/__init__.py @@ -1,3 +1,2 @@ -from .netcdf_monitor import NetCDFMonitor from .protocol import Monitor from .zarr_monitor import ZarrMonitor diff --git a/tests/test_netcdf_monitor.py b/tests/test_netcdf_monitor.py index 326739b0..7a21dd78 100644 --- a/tests/test_netcdf_monitor.py +++ b/tests/test_netcdf_monitor.py @@ -10,10 +10,10 @@ CubedSphereCommunicator, CubedSpherePartitioner, DummyComm, + NetCDFMonitor, Quantity, TilePartitioner, ) -from ndsl.monitor import NetCDFMonitor from ndsl.optional_imports import xarray as xr From f813fb162335186b3d1e4c85750f3c3416a48403 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Mon, 11 Mar 2024 11:01:59 -0400 Subject: [PATCH 12/12] Imported UnitsError and ConcurrencyError to exceptions, moved AsyncResult and NullAsyncResult out of typing 
---
 ndsl/exceptions.py         | 5 +++++
 ndsl/typing.py             | 3 ---
 tests/mpi/test_mpi_mock.py | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/ndsl/exceptions.py b/ndsl/exceptions.py
index 329b1b54..fa5d118a 100644
--- a/ndsl/exceptions.py
+++ b/ndsl/exceptions.py
@@ -1,2 +1,7 @@
+# flake8: noqa
+from ndsl.comm.local_comm import ConcurrencyError
+from ndsl.units import UnitsError
+
+
 class OutOfBoundsError(ValueError):
     pass
diff --git a/ndsl/typing.py b/ndsl/typing.py
index 03f9624b..ddbf1681 100644
--- a/ndsl/typing.py
+++ b/ndsl/typing.py
@@ -1,9 +1,6 @@
 # flake8: noqa
 from ndsl.checkpointer.base import Checkpointer
 from ndsl.comm.communicator import Communicator
-from ndsl.comm.local_comm import AsyncResult, ConcurrencyError
-from ndsl.comm.null_comm import NullAsyncResult
 from ndsl.comm.partitioner import Partitioner
 from ndsl.performance.collector import AbstractPerformanceCollector
 from ndsl.types import AsyncRequest, NumpyModule
-from ndsl.units import UnitsError
diff --git a/tests/mpi/test_mpi_mock.py b/tests/mpi/test_mpi_mock.py
index 42fdcbec..b8202995 100644
--- a/tests/mpi/test_mpi_mock.py
+++ b/tests/mpi/test_mpi_mock.py
@@ -3,7 +3,7 @@

 from ndsl import DummyComm
 from ndsl.buffer import recv_buffer
-from ndsl.typing import ConcurrencyError
+from ndsl.exceptions import ConcurrencyError
 from tests.mpi.mpi_comm import MPI
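
Illustrative sketch (not part of the patch series): downstream code that imported these names from their old locations would update to the module paths the hunks above switch to. The helper function below is hypothetical; only the import locations and the halo_update(quantity, n_points) signature are taken from this document.

    # Former locations such as ndsl.typing.Communicator, ndsl.typing.Partitioner,
    # ndsl.typing.ConcurrencyError and the top-level ndsl.Timer re-export are
    # replaced by the module paths used throughout the patches above.
    from ndsl.comm.communicator import Communicator
    from ndsl.comm.partitioner import Partitioner  # noqa: F401  (was ndsl.typing.Partitioner)
    from ndsl.exceptions import ConcurrencyError, OutOfBoundsError  # noqa: F401
    from ndsl.performance import NullTimer, Timer  # noqa: F401
    from ndsl.quantity import Quantity


    def exchange_halos(comm: Communicator, quantity: Quantity, n_points: int = 3) -> None:
        """Fill n_points halo cells of quantity from neighboring ranks and wait for completion."""
        comm.halo_update(quantity, n_points)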