Skip to content

Commit

Permalink
further improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Oct 11, 2024
1 parent 616c975 commit f5c368d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
disable_error_code = ["import-untyped", "import-not-found", "no-untyped-call"]
disable_error_code = ["import-untyped", "import-not-found"]
11 changes: 11 additions & 0 deletions yt_experiments/tiled_grid/tests/test_tiled_grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import unyt
from numpy.testing import assert_equal
from yt.testing import fake_amr_ds, requires_module
Expand Down Expand Up @@ -78,6 +79,16 @@ def test_arbitrary_grid_oct():
assert level_arrays[ilev].shape == expected_levels[ilev]


def test_missing_ds():
with pytest.raises(ValueError, match="Please provide a dataset"):
_ = YTTiledArbitraryGrid(
unyt.unyt_array([0, 0, 0], "m"),
unyt.unyt_array([1, 1, 1], "m"),
(20, 20, 20),
5,
)


@requires_module("xarray")
def test_arbitrary_grid_to_xarray():
import xarray as xr
Expand Down
60 changes: 43 additions & 17 deletions yt_experiments/tiled_grid/tiled_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from yt.data_objects.construction_data_containers import YTArbitraryGrid
from yt.data_objects.static_output import Dataset

_GridInfo = tuple[
npt.NDArray, npt.NDArray, unyt.unyt_array, unyt.unyt_array, Any, npt.NDArray
]


def _validate_edge(edge: npt.ArrayLike, ds: Dataset):
if not isinstance(edge, unyt.unyt_array):
Expand Down Expand Up @@ -62,6 +66,9 @@ def __init__(
"""

if ds is None:
raise ValueError("Please provide a dataset via the ds keyword argument")

self.ds = ds
self.left_edge = _validate_edge(left_edge, ds)
self.right_edge = _validate_edge(right_edge, ds)
Expand All @@ -86,7 +93,7 @@ def __init__(
self._left_cell_center = self.left_edge + self.dds / 2.0
self._right_cell_center = self.right_edge - self.dds / 2.0

def __repr__(self):
def __repr__(self) -> str:
nm = self.__class__.__name__
shape = tuple(self.dims)
n_chunks = tuple(self.nchunks)
Expand All @@ -97,13 +104,13 @@ def __repr__(self):
)
return msg

def _get_grid_by_ijk(self, ijk_grid):
def _get_grid_by_ijk(self, ijk_grid: npt.NDArray[int]) -> _GridInfo:
chunksizes = self.chunks

le_index = []
re_index = []
le_val = self.ds.domain_left_edge.copy()
re_val = self.ds.domain_right_edge.copy()
le_val: unyt.unyt_array = self.ds.domain_left_edge.copy()
re_val: unyt.unyt_array = self.ds.domain_right_edge.copy()

for idim in range(self._ndim):
chunk_i = ijk_grid[idim]
Expand All @@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid):
le_index[2] : re_index[2],
]

le_index = np.array(le_index, dtype=int)
re_index = np.array(re_index, dtype=int)
le_index_ = np.array(le_index, dtype=int)
re_index_ = np.array(re_index, dtype=int)
shape = chunksizes

return le_index, re_index, le_val, re_val, slc, shape
return le_index_, re_index_, le_val, re_val, slc, shape

def _get_grid(self, igrid: int):
def _get_grid(self, igrid: int) -> _GridInfo:
# get grid extent of a **single** grid
ijk_grid = np.unravel_index(igrid, self.nchunks)
return self._get_grid_by_ijk(ijk_grid)

def _coord_array(self, idim):
def _coord_array(self, idim: int) -> npt.NDArray:
LE = self._left_cell_center[idim]
RE = self._right_cell_center[idim]
N = self.dims[idim]
return np.mgrid[LE : RE : N * 1j]

def to_xarray(self, field, *, output_array=None):
def to_xarray(
self, field: tuple[str, str], *, output_array: npt.ArrayLike | None = None
) -> Any:

import xarray as xr

# ToDo: import from on_demand_imports

vals = self.to_array(field, output_array=output_array)

dims = self.ds.coordinates.axis_order
Expand All @@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None):
)
return xr_ds

def single_grid_values(self, igrid, field, *, ops=None):
def single_grid_values(
self,
igrid: int,
field: tuple[str, str],
*,
ops: list[Callable[[npt.NDArray], npt.NDArray]] | None = None,
) -> tuple[npt.NDArray, Any]:
"""
Get the values for a field for a single grid chunk as in-memory array.
Expand Down Expand Up @@ -308,7 +321,9 @@ def __init__(

self.levels: list[YTTiledArbitraryGrid] = levels

def _validate_levels(self, levels):
def _validate_levels(
self, levels: Sequence[int | tuple[int, int, int] | npt.ArrayLike]
):

for ilev in range(1, self.n_levels):
res = np.prod(levels[ilev])
Expand All @@ -321,7 +336,7 @@ def _validate_levels(self, levels):
)
raise ValueError(msg)

def __repr__(self):
def __repr__(self) -> str:
return (
f"{self.__class__.__name__} with {self.n_levels} levels and base resolution "
f"{self.base_resolution}"
Expand All @@ -330,7 +345,11 @@ def __repr__(self):
def base_resolution(self) -> tuple[int, int, int]:
return tuple(self[0].dims)

def to_arrays(self, field, output_arrays=None):
def to_arrays(
self,
field: tuple[str, str],
output_arrays: list[npt.ArrayLike | None] | None = None,
) -> list[npt.ArrayLike]:
if output_arrays is None:
output_arrays = [None for _ in range(len(self.levels))]

Expand Down Expand Up @@ -390,7 +409,14 @@ def _validate_factor(
return np.asarray(input_factor, dtype=int)


def _get_filled_grid(le, re, shp, field, ds, field_parameters):
def _get_filled_grid(
le: npt.NDArray,
re: npt.NDArray,
shp: npt.NDArray,
field: tuple[str, str],
ds: Dataset,
field_parameters: Any,
) -> npt.NDArray:
grid = YTArbitraryGrid(le, re, shp, ds=ds, field_parameters=field_parameters)
vals = grid[field]
return vals

0 comments on commit f5c368d

Please sign in to comment.