diff --git a/pyproject.toml b/pyproject.toml index 00665b9..1b1f2bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,4 +119,4 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] warn_unreachable = true disallow_untyped_defs = false disallow_incomplete_defs = false -disable_error_code = ["import-untyped", "import-not-found", "no-untyped-call"] +disable_error_code = ["import-untyped", "import-not-found"] diff --git a/yt_experiments/tiled_grid/tests/test_tiled_grid.py b/yt_experiments/tiled_grid/tests/test_tiled_grid.py index 08895a1..39a9226 100644 --- a/yt_experiments/tiled_grid/tests/test_tiled_grid.py +++ b/yt_experiments/tiled_grid/tests/test_tiled_grid.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import unyt from numpy.testing import assert_equal from yt.testing import fake_amr_ds, requires_module @@ -78,6 +79,16 @@ def test_arbitrary_grid_oct(): assert level_arrays[ilev].shape == expected_levels[ilev] +def test_missing_ds(): + with pytest.raises(ValueError, match="Please provide a dataset"): + _ = YTTiledArbitraryGrid( + unyt.unyt_array([0, 0, 0], "m"), + unyt.unyt_array([1, 1, 1], "m"), + (20, 20, 20), + 5, + ) + + @requires_module("xarray") def test_arbitrary_grid_to_xarray(): import xarray as xr diff --git a/yt_experiments/tiled_grid/tiled_grid.py b/yt_experiments/tiled_grid/tiled_grid.py index 7729a7e..77a6afd 100644 --- a/yt_experiments/tiled_grid/tiled_grid.py +++ b/yt_experiments/tiled_grid/tiled_grid.py @@ -8,6 +8,10 @@ from yt.data_objects.construction_data_containers import YTArbitraryGrid from yt.data_objects.static_output import Dataset +_GridInfo = tuple[ + npt.NDArray, npt.NDArray, unyt.unyt_array, unyt.unyt_array, Any, npt.NDArray +] + def _validate_edge(edge: npt.ArrayLike, ds: Dataset): if not isinstance(edge, unyt.unyt_array): @@ -62,6 +66,9 @@ def __init__( """ + if ds is None: + raise ValueError("Please provide a dataset via the ds keyword argument") + self.ds = ds self.left_edge = _validate_edge(left_edge, ds) self.right_edge = _validate_edge(right_edge, ds) @@ -86,7 +93,7 @@ def __init__( self._left_cell_center = self.left_edge + self.dds / 2.0 self._right_cell_center = self.right_edge - self.dds / 2.0 - def __repr__(self): + def __repr__(self) -> str: nm = self.__class__.__name__ shape = tuple(self.dims) n_chunks = tuple(self.nchunks) @@ -97,13 +104,13 @@ def __repr__(self): ) return msg - def _get_grid_by_ijk(self, ijk_grid): + def _get_grid_by_ijk(self, ijk_grid: npt.NDArray[int]) -> _GridInfo: chunksizes = self.chunks le_index = [] re_index = [] - le_val = self.ds.domain_left_edge.copy() - re_val = self.ds.domain_right_edge.copy() + le_val: unyt.unyt_array = self.ds.domain_left_edge.copy() + re_val: unyt.unyt_array = self.ds.domain_right_edge.copy() for idim in range(self._ndim): chunk_i = ijk_grid[idim] @@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid): le_index[2] : re_index[2], ] - le_index = np.array(le_index, dtype=int) - re_index = np.array(re_index, dtype=int) + le_index_ = np.array(le_index, dtype=int) + re_index_ = np.array(re_index, dtype=int) shape = chunksizes - return le_index, re_index, le_val, re_val, slc, shape + return le_index_, re_index_, le_val, re_val, slc, shape - def _get_grid(self, igrid: int): + def _get_grid(self, igrid: int) -> _GridInfo: # get grid extent of a **single** grid ijk_grid = np.unravel_index(igrid, self.nchunks) return self._get_grid_by_ijk(ijk_grid) - def _coord_array(self, idim): + def _coord_array(self, idim: int) -> npt.NDArray: LE = self._left_cell_center[idim] RE = self._right_cell_center[idim] N = self.dims[idim] return np.mgrid[LE : RE : N * 1j] - def to_xarray(self, field, *, output_array=None): + def to_xarray( + self, field: tuple[str, str], *, output_array: npt.ArrayLike | None = None + ) -> Any: import xarray as xr - # ToDo: import from on_demand_imports - vals = self.to_array(field, output_array=output_array) dims = self.ds.coordinates.axis_order @@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None): ) return xr_ds - def single_grid_values(self, igrid, field, *, ops=None): + def single_grid_values( + self, + igrid: int, + field: tuple[str, str], + *, + ops: list[Callable[[npt.NDArray], npt.NDArray]] | None = None, + ) -> tuple[npt.NDArray, Any]: """ Get the values for a field for a single grid chunk as in-memory array. @@ -308,7 +321,9 @@ def __init__( self.levels: list[YTTiledArbitraryGrid] = levels - def _validate_levels(self, levels): + def _validate_levels( + self, levels: Sequence[int | tuple[int, int, int] | npt.ArrayLike] + ): for ilev in range(1, self.n_levels): res = np.prod(levels[ilev]) @@ -321,7 +336,7 @@ def _validate_levels(self, levels): ) raise ValueError(msg) - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.__class__.__name__} with {self.n_levels} levels and base resolution " f"{self.base_resolution}" @@ -330,7 +345,11 @@ def __repr__(self): def base_resolution(self) -> tuple[int, int, int]: return tuple(self[0].dims) - def to_arrays(self, field, output_arrays=None): + def to_arrays( + self, + field: tuple[str, str], + output_arrays: list[npt.ArrayLike | None] | None = None, + ) -> list[npt.ArrayLike]: if output_arrays is None: output_arrays = [None for _ in range(len(self.levels))] @@ -390,7 +409,14 @@ def _validate_factor( return np.asarray(input_factor, dtype=int) -def _get_filled_grid(le, re, shp, field, ds, field_parameters): +def _get_filled_grid( + le: npt.NDArray, + re: npt.NDArray, + shp: npt.NDArray, + field: tuple[str, str], + ds: Dataset, + field_parameters: Any, +) -> npt.NDArray: grid = YTArbitraryGrid(le, re, shp, ds=ds, field_parameters=field_parameters) vals = grid[field] return vals