From f5c368d9614519dd094808a1f3ef571aa70aa254 Mon Sep 17 00:00:00 2001
From: chavlin <chris.havlin@gmail.com>
Date: Fri, 11 Oct 2024 15:38:42 -0500
Subject: [PATCH] further improve typing

---
 pyproject.toml                                |  2 +-
 .../tiled_grid/tests/test_tiled_grid.py       | 11 ++++
 yt_experiments/tiled_grid/tiled_grid.py       | 60 +++++++++++++------
 3 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 00665b9..1b1f2bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,4 +119,4 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
 warn_unreachable = true
 disallow_untyped_defs = false
 disallow_incomplete_defs = false
-disable_error_code = ["import-untyped", "import-not-found", "no-untyped-call"]
+disable_error_code = ["import-untyped", "import-not-found"]
diff --git a/yt_experiments/tiled_grid/tests/test_tiled_grid.py b/yt_experiments/tiled_grid/tests/test_tiled_grid.py
index 08895a1..39a9226 100644
--- a/yt_experiments/tiled_grid/tests/test_tiled_grid.py
+++ b/yt_experiments/tiled_grid/tests/test_tiled_grid.py
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 import unyt
 from numpy.testing import assert_equal
 from yt.testing import fake_amr_ds, requires_module
@@ -78,6 +79,16 @@ def test_arbitrary_grid_oct():
         assert level_arrays[ilev].shape == expected_levels[ilev]
 
 
+def test_missing_ds():
+    with pytest.raises(ValueError, match="Please provide a dataset"):
+        _ = YTTiledArbitraryGrid(
+            unyt.unyt_array([0, 0, 0], "m"),
+            unyt.unyt_array([1, 1, 1], "m"),
+            (20, 20, 20),
+            5,
+        )
+
+
 @requires_module("xarray")
 def test_arbitrary_grid_to_xarray():
     import xarray as xr
diff --git a/yt_experiments/tiled_grid/tiled_grid.py b/yt_experiments/tiled_grid/tiled_grid.py
index 7729a7e..77a6afd 100644
--- a/yt_experiments/tiled_grid/tiled_grid.py
+++ b/yt_experiments/tiled_grid/tiled_grid.py
@@ -8,6 +8,10 @@
 from yt.data_objects.construction_data_containers import YTArbitraryGrid
 from yt.data_objects.static_output import Dataset
 
+_GridInfo = tuple[
+    npt.NDArray, npt.NDArray, unyt.unyt_array, unyt.unyt_array, Any, npt.NDArray
+]
+
 
 def _validate_edge(edge: npt.ArrayLike, ds: Dataset):
     if not isinstance(edge, unyt.unyt_array):
@@ -62,6 +66,9 @@ def __init__(
 
         """
 
+        if ds is None:
+            raise ValueError("Please provide a dataset via the ds keyword argument")
+
         self.ds = ds
         self.left_edge = _validate_edge(left_edge, ds)
         self.right_edge = _validate_edge(right_edge, ds)
@@ -86,7 +93,7 @@ def __init__(
         self._left_cell_center = self.left_edge + self.dds / 2.0
         self._right_cell_center = self.right_edge - self.dds / 2.0
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         nm = self.__class__.__name__
         shape = tuple(self.dims)
         n_chunks = tuple(self.nchunks)
@@ -97,13 +104,13 @@ def __repr__(self):
         )
         return msg
 
-    def _get_grid_by_ijk(self, ijk_grid):
+    def _get_grid_by_ijk(self, ijk_grid: npt.NDArray[int]) -> _GridInfo:
         chunksizes = self.chunks
 
         le_index = []
         re_index = []
-        le_val = self.ds.domain_left_edge.copy()
-        re_val = self.ds.domain_right_edge.copy()
+        le_val: unyt.unyt_array = self.ds.domain_left_edge.copy()
+        re_val: unyt.unyt_array = self.ds.domain_right_edge.copy()
 
         for idim in range(self._ndim):
             chunk_i = ijk_grid[idim]
@@ -122,29 +129,29 @@ def _get_grid_by_ijk(self, ijk_grid):
             le_index[2] : re_index[2],
         ]
 
-        le_index = np.array(le_index, dtype=int)
-        re_index = np.array(re_index, dtype=int)
+        le_index_ = np.array(le_index, dtype=int)
+        re_index_ = np.array(re_index, dtype=int)
         shape = chunksizes
 
-        return le_index, re_index, le_val, re_val, slc, shape
+        return le_index_, re_index_, le_val, re_val, slc, shape
 
-    def _get_grid(self, igrid: int):
+    def _get_grid(self, igrid: int) -> _GridInfo:
         # get grid extent of a **single** grid
         ijk_grid = np.unravel_index(igrid, self.nchunks)
         return self._get_grid_by_ijk(ijk_grid)
 
-    def _coord_array(self, idim):
+    def _coord_array(self, idim: int) -> npt.NDArray:
         LE = self._left_cell_center[idim]
         RE = self._right_cell_center[idim]
         N = self.dims[idim]
         return np.mgrid[LE : RE : N * 1j]
 
-    def to_xarray(self, field, *, output_array=None):
+    def to_xarray(
+        self, field: tuple[str, str], *, output_array: npt.ArrayLike | None = None
+    ) -> Any:
 
         import xarray as xr
 
-        # ToDo: import from on_demand_imports
-
         vals = self.to_array(field, output_array=output_array)
 
         dims = self.ds.coordinates.axis_order
@@ -162,7 +169,13 @@ def to_xarray(self, field, *, output_array=None):
         )
         return xr_ds
 
-    def single_grid_values(self, igrid, field, *, ops=None):
+    def single_grid_values(
+        self,
+        igrid: int,
+        field: tuple[str, str],
+        *,
+        ops: list[Callable[[npt.NDArray], npt.NDArray]] | None = None,
+    ) -> tuple[npt.NDArray, Any]:
         """
         Get the values for a field for a single grid chunk as in-memory array.
 
@@ -308,7 +321,9 @@ def __init__(
 
         self.levels: list[YTTiledArbitraryGrid] = levels
 
-    def _validate_levels(self, levels):
+    def _validate_levels(
+        self, levels: Sequence[int | tuple[int, int, int] | npt.ArrayLike]
+    ):
 
         for ilev in range(1, self.n_levels):
             res = np.prod(levels[ilev])
@@ -321,7 +336,7 @@ def _validate_levels(self, levels):
                 )
                 raise ValueError(msg)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__} with {self.n_levels} levels and base resolution "
             f"{self.base_resolution}"
@@ -330,7 +345,11 @@ def __repr__(self):
     def base_resolution(self) -> tuple[int, int, int]:
         return tuple(self[0].dims)
 
-    def to_arrays(self, field, output_arrays=None):
+    def to_arrays(
+        self,
+        field: tuple[str, str],
+        output_arrays: list[npt.ArrayLike | None] | None = None,
+    ) -> list[npt.ArrayLike]:
         if output_arrays is None:
             output_arrays = [None for _ in range(len(self.levels))]
 
@@ -390,7 +409,14 @@ def _validate_factor(
         return np.asarray(input_factor, dtype=int)
 
 
-def _get_filled_grid(le, re, shp, field, ds, field_parameters):
+def _get_filled_grid(
+    le: npt.NDArray,
+    re: npt.NDArray,
+    shp: npt.NDArray,
+    field: tuple[str, str],
+    ds: Dataset,
+    field_parameters: Any,
+) -> npt.NDArray:
     grid = YTArbitraryGrid(le, re, shp, ds=ds, field_parameters=field_parameters)
     vals = grid[field]
     return vals