From eee4434f2d0f30bd0f91e74f4e665e303bfb44b2 Mon Sep 17 00:00:00 2001 From: chavlin Date: Thu, 19 Oct 2023 12:00:21 -0500 Subject: [PATCH 01/11] in progress... --- yt/utilities/linear_interpolators.py | 213 +++++++++++++-------------- 1 file changed, 106 insertions(+), 107 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index d37e00c9815..fe2ec16c55b 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -1,11 +1,72 @@ +import abc + import numpy as np import yt.utilities.lib.interpolators as lib from yt.funcs import mylog -class UnilinearFieldInterpolator: - def __init__(self, table, boundaries, field_names, truncate=False): +class _LinearInterpolator(abc.ABC): + _ndim: int + + def __init__(self, table, truncate=False, *, store_table=True): + if store_table: + self.table = table.astype("float64") + else: + self.table = None + self.table_shape = table.shape + self.truncate = truncate + + def _raise_truncation_error(self, data_object): + mylog.error( + "Sorry, but your values are outside " + "the table! Dunno what to do, so dying." + ) + mylog.error("Error was in: %s", data_object) + raise ValueError + + def _validate_table(self, table_override): + if table_override is None: + if self.table is None: + msg = ( + f"You must either store the table used when initializing " + f"{type(self).__name__} (set `store_table=True`) or you must provide a `table_override` when " + f"calling {type(self).__name__}" + ) + raise RuntimeError(msg) + return self.table + + if table_override.shape != self.table.shape: + msg = f"The table_override shape, {table_override.shape}, must match the base table shape, {self.table.shape}" + raise ValueError(msg) + + return table_override.astype("float64") + + def _get_digitized_arrays(self, data_object): + return_arrays = [] + for dim in "xyzw"[: self._ndim]: + dim_name = getattr(self, f"{dim}_name") + dim_bins = getattr(self, f"{dim}_bins") + + dim_vals = data_object[dim_name].ravel().astype("float64") + dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype("int32") + if np.any((dim_i == -1) | (dim_i == len(dim_bins) - 1)): + if not self.truncate: + self._raise_truncation_error(data_object) + else: + dim_i = np.minimum(np.maximum(dim_i, 0), len(dim_bins) - 2) + return_arrays.append(dim_i) + return_arrays.append(dim_vals) + + return return_arrays + + +class UnilinearFieldInterpolator(_LinearInterpolator): + _ndim = 1 + + def __init__( + self, table, boundaries, field_names, truncate=False, *, store_table=True + ): r"""Initialize a 1D interpolator for field data. table : array @@ -32,8 +93,7 @@ def __init__(self, table, boundaries, field_names, truncate=False): field_data = interp(ad) """ - self.table = table.astype("float64") - self.truncate = truncate + super().__init__(table, truncate=truncate, store_table=store_table) self.x_name = field_names if isinstance(boundaries, np.ndarray): if boundaries.size != table.shape[0]: @@ -44,30 +104,22 @@ def __init__(self, table, boundaries, field_names, truncate=False): x0, x1 = boundaries self.x_bins = np.linspace(x0, x1, table.shape[0], dtype="float64") - def __call__(self, data_object): + def __call__(self, data_object, *, table_override=None): + table = self._validate_table(table_override) orig_shape = data_object[self.x_name].shape - x_vals = data_object[self.x_name].ravel().astype("float64") - - x_i = (np.digitize(x_vals, self.x_bins) - 1).astype("int32") - if np.any((x_i == -1) | (x_i == len(self.x_bins) - 1)): - if not self.truncate: - mylog.error( - "Sorry, but your values are outside " - "the table! Dunno what to do, so dying." - ) - mylog.error("Error was in: %s", data_object) - raise ValueError - else: - x_i = np.minimum(np.maximum(x_i, 0), len(self.x_bins) - 2) - + x_vals, x_i = self._get_digitized_arrays(data_object) my_vals = np.zeros(x_vals.shape, dtype="float64") - lib.UnilinearlyInterpolate(self.table, x_vals, self.x_bins, x_i, my_vals) + lib.UnilinearlyInterpolate(table, x_vals, self.x_bins, x_i, my_vals) my_vals.shape = orig_shape return my_vals -class BilinearFieldInterpolator: - def __init__(self, table, boundaries, field_names, truncate=False): +class BilinearFieldInterpolator(_LinearInterpolator): + _ndim = 2 + + def __init__( + self, table, boundaries, field_names, truncate=False, *, store_table=True + ): r"""Initialize a 2D interpolator for field data. table : array @@ -94,8 +146,7 @@ def __init__(self, table, boundaries, field_names, truncate=False): field_data = interp(ad) """ - self.table = table.astype("float64") - self.truncate = truncate + super().__init__(table, truncate=truncate, store_table=store_table) self.x_name, self.y_name = field_names if len(boundaries) == 4: x0, x1, y0, y1 = boundaries @@ -116,37 +167,25 @@ def __init__(self, table, boundaries, field_names, truncate=False): ) raise ValueError - def __call__(self, data_object): - orig_shape = data_object[self.x_name].shape - x_vals = data_object[self.x_name].ravel().astype("float64") - y_vals = data_object[self.y_name].ravel().astype("float64") - - x_i = (np.digitize(x_vals, self.x_bins) - 1).astype("int32") - y_i = (np.digitize(y_vals, self.y_bins) - 1).astype("int32") - if np.any((x_i == -1) | (x_i == len(self.x_bins) - 1)) or np.any( - (y_i == -1) | (y_i == len(self.y_bins) - 1) - ): - if not self.truncate: - mylog.error( - "Sorry, but your values are outside " - "the table! Dunno what to do, so dying." - ) - mylog.error("Error was in: %s", data_object) - raise ValueError - else: - x_i = np.minimum(np.maximum(x_i, 0), len(self.x_bins) - 2) - y_i = np.minimum(np.maximum(y_i, 0), len(self.y_bins) - 2) + def __call__(self, data_object, *, table_override=None): + table = self._validate_table(table_override) + orig_shape = data_object[self.x_name].shape + x_vals, x_i, y_vals, y_i = self._get_digitized_arrays(data_object) my_vals = np.zeros(x_vals.shape, dtype="float64") lib.BilinearlyInterpolate( - self.table, x_vals, y_vals, self.x_bins, self.y_bins, x_i, y_i, my_vals + table, x_vals, y_vals, self.x_bins, self.y_bins, x_i, y_i, my_vals ) my_vals.shape = orig_shape return my_vals -class TrilinearFieldInterpolator: - def __init__(self, table, boundaries, field_names, truncate=False): +class TrilinearFieldInterpolator(_LinearInterpolator): + _ndim = 3 + + def __init__( + self, table, boundaries, field_names, truncate=False, *, store_table=True + ): r"""Initialize a 3D interpolator for field data. table : array @@ -174,8 +213,7 @@ def __init__(self, table, boundaries, field_names, truncate=False): field_data = interp(ad) """ - self.table = table.astype("float64") - self.truncate = truncate + super().__init__(table, truncate=truncate, store_table=store_table) self.x_name, self.y_name, self.z_name = field_names if len(boundaries) == 6: x0, x1, y0, y1, z0, z1 = boundaries @@ -202,35 +240,15 @@ def __init__(self, table, boundaries, field_names, truncate=False): ) raise ValueError - def __call__(self, data_object): + def __call__(self, data_object, *, table_override=None): + table = self._validate_table(table_override) + orig_shape = data_object[self.x_name].shape - x_vals = data_object[self.x_name].ravel().astype("float64") - y_vals = data_object[self.y_name].ravel().astype("float64") - z_vals = data_object[self.z_name].ravel().astype("float64") - - x_i = np.digitize(x_vals, self.x_bins).astype("int_") - 1 - y_i = np.digitize(y_vals, self.y_bins).astype("int_") - 1 - z_i = np.digitize(z_vals, self.z_bins).astype("int_") - 1 - if ( - np.any((x_i == -1) | (x_i == len(self.x_bins) - 1)) - or np.any((y_i == -1) | (y_i == len(self.y_bins) - 1)) - or np.any((z_i == -1) | (z_i == len(self.z_bins) - 1)) - ): - if not self.truncate: - mylog.error( - "Sorry, but your values are outside " - "the table! Dunno what to do, so dying." - ) - mylog.error("Error was in: %s", data_object) - raise ValueError - else: - x_i = np.minimum(np.maximum(x_i, 0), len(self.x_bins) - 2) - y_i = np.minimum(np.maximum(y_i, 0), len(self.y_bins) - 2) - z_i = np.minimum(np.maximum(z_i, 0), len(self.z_bins) - 2) + x_vals, x_i, y_vals, y_i, z_vals, z_i = self._get_digitized_arrays(data_object) my_vals = np.zeros(x_vals.shape, dtype="float64") lib.TrilinearlyInterpolate( - self.table, + table, x_vals, y_vals, z_vals, @@ -246,8 +264,12 @@ def __call__(self, data_object): return my_vals -class QuadrilinearFieldInterpolator: - def __init__(self, table, boundaries, field_names, truncate=False): +class QuadrilinearFieldInterpolator(_LinearInterpolator): + _ndim = 4 + + def __init__( + self, table, boundaries, field_names, truncate=False, *, store_table=True + ): r"""Initialize a 4D interpolator for field data. table : array @@ -274,8 +296,7 @@ def __init__(self, table, boundaries, field_names, truncate=False): field_data = interp(ad) """ - self.table = table.astype("float64") - self.truncate = truncate + super().__init__(table, truncate=truncate, store_table=store_table) self.x_name, self.y_name, self.z_name, self.w_name = field_names if len(boundaries) == 8: x0, x1, y0, y1, z0, z1, w0, w1 = boundaries @@ -307,39 +328,17 @@ def __init__(self, table, boundaries, field_names, truncate=False): ) raise ValueError - def __call__(self, data_object): + def __call__(self, data_object, *, table_override=None): + table = self._validate_table(table_override) + orig_shape = data_object[self.x_name].shape - x_vals = data_object[self.x_name].ravel().astype("float64") - y_vals = data_object[self.y_name].ravel().astype("float64") - z_vals = data_object[self.z_name].ravel().astype("float64") - w_vals = data_object[self.w_name].ravel().astype("float64") - - x_i = np.digitize(x_vals, self.x_bins).astype("int") - 1 - y_i = np.digitize(y_vals, self.y_bins).astype("int") - 1 - z_i = np.digitize(z_vals, self.z_bins).astype("int") - 1 - w_i = np.digitize(w_vals, self.w_bins).astype("int") - 1 - if ( - np.any((x_i == -1) | (x_i == len(self.x_bins) - 1)) - or np.any((y_i == -1) | (y_i == len(self.y_bins) - 1)) - or np.any((z_i == -1) | (z_i == len(self.z_bins) - 1)) - or np.any((w_i == -1) | (w_i == len(self.w_bins) - 1)) - ): - if not self.truncate: - mylog.error( - "Sorry, but your values are outside " - "the table! Dunno what to do, so dying." - ) - mylog.error("Error was in: %s", data_object) - raise ValueError - else: - x_i = np.minimum(np.maximum(x_i, 0), len(self.x_bins) - 2) - y_i = np.minimum(np.maximum(y_i, 0), len(self.y_bins) - 2) - z_i = np.minimum(np.maximum(z_i, 0), len(self.z_bins) - 2) - w_i = np.minimum(np.maximum(w_i, 0), len(self.w_bins) - 2) + x_vals, x_i, y_vals, y_i, z_vals, z_i, w_vals, w_i = self._get_digitized_arrays( + data_object + ) my_vals = np.zeros(x_vals.shape, dtype="float64") lib.QuadrilinearlyInterpolate( - self.table, + table, x_vals, y_vals, z_vals, From 833fcd33b9764a0500d5ee85ade0b4c9381cc2ac Mon Sep 17 00:00:00 2001 From: chavlin Date: Thu, 19 Oct 2023 12:06:29 -0500 Subject: [PATCH 02/11] fix type casting --- yt/utilities/linear_interpolators.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index fe2ec16c55b..02e22444711 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -8,6 +8,7 @@ class _LinearInterpolator(abc.ABC): _ndim: int + _dim_i_type = "int32" def __init__(self, table, truncate=False, *, store_table=True): if store_table: @@ -49,14 +50,14 @@ def _get_digitized_arrays(self, data_object): dim_bins = getattr(self, f"{dim}_bins") dim_vals = data_object[dim_name].ravel().astype("float64") - dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype("int32") + dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype(self._dim_i_type) if np.any((dim_i == -1) | (dim_i == len(dim_bins) - 1)): if not self.truncate: self._raise_truncation_error(data_object) else: dim_i = np.minimum(np.maximum(dim_i, 0), len(dim_bins) - 2) - return_arrays.append(dim_i) return_arrays.append(dim_vals) + return_arrays.append(dim_i) return return_arrays @@ -182,6 +183,7 @@ def __call__(self, data_object, *, table_override=None): class TrilinearFieldInterpolator(_LinearInterpolator): _ndim = 3 + _dim_i_type = "int_" def __init__( self, table, boundaries, field_names, truncate=False, *, store_table=True @@ -266,6 +268,7 @@ def __call__(self, data_object, *, table_override=None): class QuadrilinearFieldInterpolator(_LinearInterpolator): _ndim = 4 + _dim_i_type = "int_" def __init__( self, table, boundaries, field_names, truncate=False, *, store_table=True From d79994e34df5241c992c3a8e25a070fdb18813f1 Mon Sep 17 00:00:00 2001 From: chavlin Date: Fri, 20 Oct 2023 09:10:14 -0500 Subject: [PATCH 03/11] refactor boundary size check --- yt/utilities/linear_interpolators.py | 56 +++++++++--------------- yt/utilities/tests/test_interpolators.py | 28 ++++++++++++ 2 files changed, 49 insertions(+), 35 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index 02e22444711..b1113a7134c 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -10,13 +10,14 @@ class _LinearInterpolator(abc.ABC): _ndim: int _dim_i_type = "int32" - def __init__(self, table, truncate=False, *, store_table=True): + def __init__(self, table, field_names, truncate=False, *, store_table=True): if store_table: self.table = table.astype("float64") else: self.table = None self.table_shape = table.shape self.truncate = truncate + self._field_names = field_names def _raise_truncation_error(self, data_object): mylog.error( @@ -61,6 +62,13 @@ def _get_digitized_arrays(self, data_object): return return_arrays + def _validate_bin_boundaries(self, boundaries): + for idim in range(self._ndim): + if boundaries[idim].size != self.table_shape[idim]: + msg = f"{self._field_names[idim]} bins array not the same length as the data." + mylog.error(msg) + raise ValueError(msg) + class UnilinearFieldInterpolator(_LinearInterpolator): _ndim = 1 @@ -94,12 +102,14 @@ def __init__( field_data = interp(ad) """ - super().__init__(table, truncate=truncate, store_table=store_table) + super().__init__(table, field_names, truncate=truncate, store_table=store_table) self.x_name = field_names if isinstance(boundaries, np.ndarray): - if boundaries.size != table.shape[0]: - mylog.error("Bins array not the same length as the data.") - raise ValueError + self._validate_bin_boundaries( + [ + boundaries, + ] + ) self.x_bins = boundaries else: x0, x1 = boundaries @@ -147,19 +157,14 @@ def __init__( field_data = interp(ad) """ - super().__init__(table, truncate=truncate, store_table=store_table) + super().__init__(table, field_names, truncate=truncate, store_table=store_table) self.x_name, self.y_name = field_names if len(boundaries) == 4: x0, x1, y0, y1 = boundaries self.x_bins = np.linspace(x0, x1, table.shape[0], dtype="float64") self.y_bins = np.linspace(y0, y1, table.shape[1], dtype="float64") elif len(boundaries) == 2: - if boundaries[0].size != table.shape[0]: - mylog.error("X bins array not the same length as the data.") - raise ValueError - if boundaries[1].size != table.shape[1]: - mylog.error("Y bins array not the same length as the data.") - raise ValueError + self._validate_bin_boundaries(boundaries) self.x_bins = boundaries[0] self.y_bins = boundaries[1] else: @@ -215,7 +220,7 @@ def __init__( field_data = interp(ad) """ - super().__init__(table, truncate=truncate, store_table=store_table) + super().__init__(table, field_names, truncate=truncate, store_table=store_table) self.x_name, self.y_name, self.z_name = field_names if len(boundaries) == 6: x0, x1, y0, y1, z0, z1 = boundaries @@ -223,15 +228,7 @@ def __init__( self.y_bins = np.linspace(y0, y1, table.shape[1], dtype="float64") self.z_bins = np.linspace(z0, z1, table.shape[2], dtype="float64") elif len(boundaries) == 3: - if boundaries[0].size != table.shape[0]: - mylog.error("X bins array not the same length as the data.") - raise ValueError - if boundaries[1].size != table.shape[1]: - mylog.error("Y bins array not the same length as the data.") - raise ValueError - if boundaries[2].size != table.shape[2]: - mylog.error("Z bins array not the same length as the data.") - raise ValueError + self._validate_bin_boundaries(boundaries) self.x_bins = boundaries[0] self.y_bins = boundaries[1] self.z_bins = boundaries[2] @@ -299,7 +296,7 @@ def __init__( field_data = interp(ad) """ - super().__init__(table, truncate=truncate, store_table=store_table) + super().__init__(table, field_names, truncate=truncate, store_table=store_table) self.x_name, self.y_name, self.z_name, self.w_name = field_names if len(boundaries) == 8: x0, x1, y0, y1, z0, z1, w0, w1 = boundaries @@ -308,18 +305,7 @@ def __init__( self.z_bins = np.linspace(z0, z1, table.shape[2]).astype("float64") self.w_bins = np.linspace(w0, w1, table.shape[3]).astype("float64") elif len(boundaries) == 4: - if boundaries[0].size != table.shape[0]: - mylog.error("X bins array not the same length as the data.") - raise ValueError - if boundaries[1].size != table.shape[1]: - mylog.error("Y bins array not the same length as the data.") - raise ValueError - if boundaries[2].size != table.shape[2]: - mylog.error("Z bins array not the same length as the data.") - raise ValueError - if boundaries[3].size != table.shape[3]: - mylog.error("W bins array not the same length as the data.") - raise ValueError + self._validate_bin_boundaries(boundaries) self.x_bins = boundaries[0] self.y_bins = boundaries[1] self.z_bins = boundaries[2] diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index a83b2ca4e9d..152b0a2d6a1 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal import yt.utilities.linear_interpolators as lin @@ -152,3 +153,30 @@ def test_get_vertex_centered_data(): ds = fake_random_ds(16) g = ds.index.grids[0] g.get_vertex_centered_data([("gas", "density")], no_ghost=True) + + +_lin_interpolators_by_dim = { + 1: lin.UnilinearFieldInterpolator, + 2: lin.BilinearFieldInterpolator, + 3: lin.TrilinearFieldInterpolator, + 4: lin.QuadrilinearFieldInterpolator, +} + + +@pytest.mark.parametrize("ndim", list(range(1, 5))) +def test_table_override(ndim): + sz = 8 + + random_data = np.random.random((sz,) * ndim) + # evenly spaced bins + + field_names = "xyzw"[:ndim] + slc = slice(0.0, 1.0, complex(0, sz)) + fv = dict(zip(field_names, np.mgrid[(slc,) * ndim])) + boundaries = (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0)[: ndim * 2] + + interp_class = _lin_interpolators_by_dim[ndim] + tfi = interp_class(random_data, boundaries, field_names, True) + assert_array_almost_equal(tfi(fv), random_data) + table_2 = random_data * 2 + assert_array_almost_equal(tfi(fv, table_override=table_2), table_2) From b6fc8d96d61d8be3f73ba737447274f19d81f6bb Mon Sep 17 00:00:00 2001 From: chavlin Date: Fri, 20 Oct 2023 09:14:30 -0500 Subject: [PATCH 04/11] fix boundary check --- yt/utilities/linear_interpolators.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index b1113a7134c..e4d62d455b7 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -105,11 +105,7 @@ def __init__( super().__init__(table, field_names, truncate=truncate, store_table=store_table) self.x_name = field_names if isinstance(boundaries, np.ndarray): - self._validate_bin_boundaries( - [ - boundaries, - ] - ) + self._validate_bin_boundaries((boundaries,)) self.x_bins = boundaries else: x0, x1 = boundaries From 745e8fddefe86e800f28d67f9a2491cde2ebc9b5 Mon Sep 17 00:00:00 2001 From: chavlin Date: Fri, 20 Oct 2023 09:38:01 -0500 Subject: [PATCH 05/11] add bin validation test, add tests to nose ignore --- nose_ignores.txt | 1 + tests/tests.yaml | 1 + yt/utilities/linear_interpolators.py | 13 ++++------ yt/utilities/tests/test_interpolators.py | 30 ++++++++++++++++++++---- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/nose_ignores.txt b/nose_ignores.txt index eea1709491c..889b94f4252 100644 --- a/nose_ignores.txt +++ b/nose_ignores.txt @@ -36,3 +36,4 @@ --ignore-file=test_version\.py --ignore-file=test_gadget_pytest\.py --ignore-file=test_vr_orientation\.py +--ignore-file=test_interpolators\.py diff --git a/tests/tests.yaml b/tests/tests.yaml index 80145d63528..9e7cc615dfa 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -207,6 +207,7 @@ other_tests: - "--ignore-file=test_version\\.py" - "--ignore-file=test_gadget_pytest\\.py" - "--ignore-file=test_vr_orientation\\.py" + - "--ignore-file=test_interpolators\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles" diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index e4d62d455b7..f4facc1ea2a 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -63,6 +63,7 @@ def _get_digitized_arrays(self, data_object): return return_arrays def _validate_bin_boundaries(self, boundaries): + # boundaries: tuple of ndarrays for idim in range(self._ndim): if boundaries[idim].size != self.table_shape[idim]: msg = f"{self._field_names[idim]} bins array not the same length as the data." @@ -161,8 +162,7 @@ def __init__( self.y_bins = np.linspace(y0, y1, table.shape[1], dtype="float64") elif len(boundaries) == 2: self._validate_bin_boundaries(boundaries) - self.x_bins = boundaries[0] - self.y_bins = boundaries[1] + self.x_bins, self.y_bins = boundaries else: mylog.error( "Boundaries must be given as (x0, x1, y0, y1) or as (x_bins, y_bins)" @@ -225,9 +225,7 @@ def __init__( self.z_bins = np.linspace(z0, z1, table.shape[2], dtype="float64") elif len(boundaries) == 3: self._validate_bin_boundaries(boundaries) - self.x_bins = boundaries[0] - self.y_bins = boundaries[1] - self.z_bins = boundaries[2] + self.x_bins, self.y_bins, self.z_bins = boundaries else: mylog.error( "Boundaries must be given as (x0, x1, y0, y1, z0, z1) " @@ -302,10 +300,7 @@ def __init__( self.w_bins = np.linspace(w0, w1, table.shape[3]).astype("float64") elif len(boundaries) == 4: self._validate_bin_boundaries(boundaries) - self.x_bins = boundaries[0] - self.y_bins = boundaries[1] - self.z_bins = boundaries[2] - self.w_bins = boundaries[3] + self.x_bins, self.y_bins, self.z_bins, self.w_bins = boundaries else: mylog.error( "Boundaries must be given as (x0, x1, y0, y1, z0, z1, w0, w1) " diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index 152b0a2d6a1..a50ccb16e85 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -168,7 +168,6 @@ def test_table_override(ndim): sz = 8 random_data = np.random.random((sz,) * ndim) - # evenly spaced bins field_names = "xyzw"[:ndim] slc = slice(0.0, 1.0, complex(0, sz)) @@ -176,7 +175,30 @@ def test_table_override(ndim): boundaries = (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0)[: ndim * 2] interp_class = _lin_interpolators_by_dim[ndim] - tfi = interp_class(random_data, boundaries, field_names, True) - assert_array_almost_equal(tfi(fv), random_data) + interpolator = interp_class(random_data, boundaries, field_names, True) + assert_array_almost_equal(interpolator(fv), random_data) table_2 = random_data * 2 - assert_array_almost_equal(tfi(fv, table_override=table_2), table_2) + assert_array_almost_equal(interpolator(fv, table_override=table_2), table_2) + + +@pytest.mark.parametrize("ndim", list(range(1, 5))) +def test_bin_validation(ndim): + interp_class = _lin_interpolators_by_dim[ndim] + + sz = 8 + random_data = np.random.random((sz,) * ndim) + field_names = "xyzw"[:ndim] + + bad_bounds = np.linspace(0.0, 1.0, sz - 1) + good_bounds = np.linspace(0.0, 1.0, sz) + if ndim == 1: + bounds = bad_bounds + else: + bounds = [ + good_bounds, + ] * ndim + bounds[0] = bad_bounds + bounds = tuple(bounds) + + with pytest.raises(ValueError, match=f"{field_names[0]} bins array not"): + _ = interp_class(random_data, bounds, field_names) From 0b87b8c6cde82e1970c89765df2c38ab56cc9b97 Mon Sep 17 00:00:00 2001 From: chavlin Date: Fri, 20 Oct 2023 09:59:14 -0500 Subject: [PATCH 06/11] add tests for not storing table --- yt/utilities/linear_interpolators.py | 116 +++++++++++++++-------- yt/utilities/tests/test_interpolators.py | 13 ++- 2 files changed, 90 insertions(+), 39 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index f4facc1ea2a..2b77ae873bf 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -31,15 +31,15 @@ def _validate_table(self, table_override): if table_override is None: if self.table is None: msg = ( - f"You must either store the table used when initializing " - f"{type(self).__name__} (set `store_table=True`) or you must provide a `table_override` when " + f"You must either store the table used when initializing a new " + f"{type(self).__name__} (set `store_table=True`) or you must provide a `table` when " f"calling {type(self).__name__}" ) - raise RuntimeError(msg) + raise ValueError(msg) return self.table - if table_override.shape != self.table.shape: - msg = f"The table_override shape, {table_override.shape}, must match the base table shape, {self.table.shape}" + if table_override.shape != self.table_shape: + msg = f"The table_override shape, {table_override.shape}, must match the base table shape, {self.table_shape}" raise ValueError(msg) return table_override.astype("float64") @@ -92,15 +92,25 @@ def __init__( If False, an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. + store_table: bool + If False, only the shape of the input table is stored and + a full table must be provided when calling the interpolator. Examples -------- - ad = ds.all_data() - table_data = np.random.random(64) - interp = UnilinearFieldInterpolator(table_data, (0.0, 1.0), "x", - truncate=True) - field_data = interp(ad) + >>> ad = ds.all_data() + >>> table_data = np.random.random(64) + >>> interp = UnilinearFieldInterpolator(table_data, (0.0, 1.0), "x", + truncate=True) + >>> field_data = interp(ad) + + If you want to re-use the interpolator with table_data of the same shape + but different values, you can also supply the `table` keyword argument when + calling the interpolator: + + >>> new_table_data = np.random.random(64) + >>> field_data = interp(ad, table=new_table_data) """ super().__init__(table, field_names, truncate=truncate, store_table=store_table) @@ -112,8 +122,8 @@ def __init__( x0, x1 = boundaries self.x_bins = np.linspace(x0, x1, table.shape[0], dtype="float64") - def __call__(self, data_object, *, table_override=None): - table = self._validate_table(table_override) + def __call__(self, data_object, *, table=None): + table = self._validate_table(table) orig_shape = data_object[self.x_name].shape x_vals, x_i = self._get_digitized_arrays(data_object) my_vals = np.zeros(x_vals.shape, dtype="float64") @@ -142,16 +152,26 @@ def __init__( If False, an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. + store_table: bool + If False, only the shape of the input table is stored and + a full table must be provided when calling the interpolator. Examples -------- - ad = ds.all_data() - table_data = np.random.random((64, 64)) - interp = BilinearFieldInterpolator(table_data, (0.0, 1.0, 0.0, 1.0), - ["x", "y"], - truncate=True) - field_data = interp(ad) + >>> ad = ds.all_data() + >>> table_data = np.random.random((64, 64)) + >>> interp = BilinearFieldInterpolator(table_data, (0.0, 1.0, 0.0, 1.0), + ["x", "y"], + truncate=True) + >>> field_data = interp(ad) + + If you want to re-use the interpolator with table_data of the same shape + but different values, you can also supply the `table` keyword argument when + calling the interpolator: + + >>> new_table_data = np.random.random((64, 64)) + >>> field_data = interp(ad, table=new_table_data) """ super().__init__(table, field_names, truncate=truncate, store_table=store_table) @@ -169,8 +189,8 @@ def __init__( ) raise ValueError - def __call__(self, data_object, *, table_override=None): - table = self._validate_table(table_override) + def __call__(self, data_object, *, table=None): + table = self._validate_table(table) orig_shape = data_object[self.x_name].shape x_vals, x_i, y_vals, y_i = self._get_digitized_arrays(data_object) @@ -203,17 +223,27 @@ def __init__( If False, an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. + store_table: bool + If False, only the shape of the input table is stored and + a full table must be provided when calling the interpolator. Examples -------- - ad = ds.all_data() - table_data = np.random.random((64, 64, 64)) - interp = TrilinearFieldInterpolator(table_data, - (0.0, 1.0, 0.0, 1.0, 0.0, 1.0), - ["x", "y", "z"], - truncate=True) - field_data = interp(ad) + >>> ad = ds.all_data() + >>> table_data = np.random.random((64, 64, 64)) + >>> interp = TrilinearFieldInterpolator(table_data, + (0.0, 1.0, 0.0, 1.0, 0.0, 1.0), + ["x", "y", "z"], + truncate=True) + >>> field_data = interp(ad) + + If you want to re-use the interpolator with table_data of the same shape + but different values, you can also supply the `table` keyword argument when + calling the interpolator: + + >>> new_table_data = np.random.random((64, 64, 64)) + >>> field_data = interp(ad, table=new_table_data) """ super().__init__(table, field_names, truncate=truncate, store_table=store_table) @@ -233,8 +263,8 @@ def __init__( ) raise ValueError - def __call__(self, data_object, *, table_override=None): - table = self._validate_table(table_override) + def __call__(self, data_object, *, table=None): + table = self._validate_table(table) orig_shape = data_object[self.x_name].shape x_vals, x_i, y_vals, y_i, z_vals, z_i = self._get_digitized_arrays(data_object) @@ -278,16 +308,26 @@ def __init__( If False, an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. + store_table: bool + If False, only the shape of the input table is stored and + a full table must be provided when calling the interpolator. Examples -------- - ad = ds.all_data() - table_data = np.random.random((64, 64, 64, 64)) - interp = BilinearFieldInterpolator(table_data, - (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0), - ["x", "y", "z", "w"], - truncate=True) - field_data = interp(ad) + >>> ad = ds.all_data() + >>> table_data = np.random.random((64, 64, 64, 64)) + >>> interp = QuadrilinearFieldInterpolator(table_data, + (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0), + ["x", "y", "z", "w"], + truncate=True) + >>> field_data = interp(ad) + + If you want to re-use the interpolator with table_data of the same shape + but different values, you can also supply the `table` keyword argument when + calling the interpolator: + + >>> new_table_data = np.random.random((64, 64, 64, 64)) + >>> field_data = interp(ad, table=new_table_data) """ super().__init__(table, field_names, truncate=truncate, store_table=store_table) @@ -308,8 +348,8 @@ def __init__( ) raise ValueError - def __call__(self, data_object, *, table_override=None): - table = self._validate_table(table_override) + def __call__(self, data_object, *, table=None): + table = self._validate_table(table) orig_shape = data_object[self.x_name].shape x_vals, x_i, y_vals, y_i, z_vals, z_i, w_vals, w_i = self._get_digitized_arrays( diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index a50ccb16e85..1b726b34709 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -178,7 +178,18 @@ def test_table_override(ndim): interpolator = interp_class(random_data, boundaries, field_names, True) assert_array_almost_equal(interpolator(fv), random_data) table_2 = random_data * 2 - assert_array_almost_equal(interpolator(fv, table_override=table_2), table_2) + assert_array_almost_equal(interpolator(fv, table=table_2), table_2) + + # check that we can do it without storing the initial table + interpolator = interp_class( + random_data, boundaries, field_names, True, store_table=False + ) + assert_array_almost_equal(interpolator(fv, table=table_2), table_2) + + with pytest.raises( + ValueError, match="You must either store the table used when initializing" + ): + _ = interpolator(fv) @pytest.mark.parametrize("ndim", list(range(1, 5))) From c7b238db7be7c969b15e3267a30cb459a328626f Mon Sep 17 00:00:00 2001 From: Chris Havlin Date: Mon, 13 Nov 2023 09:31:49 -0600 Subject: [PATCH 07/11] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Robert --- yt/utilities/linear_interpolators.py | 4 ++-- yt/utilities/tests/test_interpolators.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index 2b77ae873bf..7f741a06877 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -32,7 +32,7 @@ def _validate_table(self, table_override): if self.table is None: msg = ( f"You must either store the table used when initializing a new " - f"{type(self).__name__} (set `store_table=True`) or you must provide a `table` when " + f"{type(self).__name__} (set `store_table=True`) or you provide a `table` when " f"calling {type(self).__name__}" ) raise ValueError(msg) @@ -50,7 +50,7 @@ def _get_digitized_arrays(self, data_object): dim_name = getattr(self, f"{dim}_name") dim_bins = getattr(self, f"{dim}_bins") - dim_vals = data_object[dim_name].ravel().astype("float64") + dim_vals = data_object[dim_name].astype("float64").ravel() dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype(self._dim_i_type) if np.any((dim_i == -1) | (dim_i == len(dim_bins) - 1)): if not self.truncate: diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index 1b726b34709..b35a31507df 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -163,7 +163,7 @@ def test_get_vertex_centered_data(): } -@pytest.mark.parametrize("ndim", list(range(1, 5))) +@pytest.mark.parametrize("ndim", list(_lin_interpolators_by_dim.keys())) def test_table_override(ndim): sz = 8 @@ -189,7 +189,7 @@ def test_table_override(ndim): with pytest.raises( ValueError, match="You must either store the table used when initializing" ): - _ = interpolator(fv) + interpolator(fv) @pytest.mark.parametrize("ndim", list(range(1, 5))) @@ -212,4 +212,4 @@ def test_bin_validation(ndim): bounds = tuple(bounds) with pytest.raises(ValueError, match=f"{field_names[0]} bins array not"): - _ = interp_class(random_data, bounds, field_names) + interp_class(random_data, bounds, field_names) From e14add4e1ce494533168bba551ead3655e49b4f3 Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 13 Nov 2023 09:52:17 -0600 Subject: [PATCH 08/11] avoid both logs and errors, note defaults in docstrings --- yt/utilities/linear_interpolators.py | 76 +++++++++++------------- yt/utilities/tests/test_interpolators.py | 9 ++- 2 files changed, 42 insertions(+), 43 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index 7f741a06877..fae41801036 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -3,7 +3,6 @@ import numpy as np import yt.utilities.lib.interpolators as lib -from yt.funcs import mylog class _LinearInterpolator(abc.ABC): @@ -19,14 +18,6 @@ def __init__(self, table, field_names, truncate=False, *, store_table=True): self.truncate = truncate self._field_names = field_names - def _raise_truncation_error(self, data_object): - mylog.error( - "Sorry, but your values are outside " - "the table! Dunno what to do, so dying." - ) - mylog.error("Error was in: %s", data_object) - raise ValueError - def _validate_table(self, table_override): if table_override is None: if self.table is None: @@ -54,7 +45,13 @@ def _get_digitized_arrays(self, data_object): dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype(self._dim_i_type) if np.any((dim_i == -1) | (dim_i == len(dim_bins) - 1)): if not self.truncate: - self._raise_truncation_error(data_object) + msg = ( + f"The dimension values for {dim_name} and data object {data_object} are outside the bounds " + f"of the table! You can avoid this error by providing truncate=True to the interpolator or " + f"by adjusting the data object to remain inside the table bounds. But for now, dunno what " + f"to do, so dying." + ) + raise ValueError(msg) else: dim_i = np.minimum(np.maximum(dim_i, 0), len(dim_bins) - 2) return_arrays.append(dim_vals) @@ -67,7 +64,6 @@ def _validate_bin_boundaries(self, boundaries): for idim in range(self._ndim): if boundaries[idim].size != self.table_shape[idim]: msg = f"{self._field_names[idim]} bins array not the same length as the data." - mylog.error(msg) raise ValueError(msg) @@ -88,13 +84,15 @@ def __init__( explicitly. field_names: str Name of the field to be used as input data for interpolation. - truncate : bool - If False, an exception is raised if the input values are + truncate: bool + If False (default), an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If False, only the shape of the input table is stored and - a full table must be provided when calling the interpolator. + If True (default), a copy of the full table is stored in + the interpolator. If False, only the shape of the input table + is stored and a full table must be provided when calling the + interpolator. Examples -------- @@ -148,13 +146,15 @@ def __init__( x and y bins. field_names: list Names of the fields to be used as input data for interpolation. - truncate : bool + truncate: bool If False, an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If False, only the shape of the input table is stored and - a full table must be provided when calling the interpolator. + If True (the default), a copy of the full table is stored in + the interpolator. If False, only the shape of the input table + is stored and a full table must be provided when calling the + interpolator. Examples -------- @@ -184,10 +184,8 @@ def __init__( self._validate_bin_boundaries(boundaries) self.x_bins, self.y_bins = boundaries else: - mylog.error( - "Boundaries must be given as (x0, x1, y0, y1) or as (x_bins, y_bins)" - ) - raise ValueError + msg = "Boundaries must be given as (x0, x1, y0, y1) or as (x_bins, y_bins)" + raise ValueError(msg) def __call__(self, data_object, *, table=None): table = self._validate_table(table) @@ -219,13 +217,15 @@ def __init__( containing the x, y, and z bins. field_names: list Names of the fields to be used as input data for interpolation. - truncate : bool - If False, an exception is raised if the input values are + truncate: bool + If False (default), an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If False, only the shape of the input table is stored and - a full table must be provided when calling the interpolator. + If True (default), a copy of the full table is stored in + the interpolator. If False, only the shape of the input table + is stored and a full table must be provided when calling the + interpolator. Examples -------- @@ -257,11 +257,8 @@ def __init__( self._validate_bin_boundaries(boundaries) self.x_bins, self.y_bins, self.z_bins = boundaries else: - mylog.error( - "Boundaries must be given as (x0, x1, y0, y1, z0, z1) " - "or as (x_bins, y_bins, z_bins)" - ) - raise ValueError + msg = "Boundaries must be given as (x0, x1, y0, y1, z0, z1) or as (x_bins, y_bins, z_bins)" + raise ValueError(msg) def __call__(self, data_object, *, table=None): table = self._validate_table(table) @@ -304,13 +301,15 @@ def __init__( containing the x, y, z, and w bins. field_names: list Names of the fields to be used as input data for interpolation. - truncate : bool - If False, an exception is raised if the input values are + truncate: bool + If False (default), an exception is raised if the input values are outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If False, only the shape of the input table is stored and - a full table must be provided when calling the interpolator. + If True (default), a copy of the full table is stored in + the interpolator. If False, only the shape of the input table + is stored and a full table must be provided when calling the + interpolator. Examples -------- @@ -342,11 +341,8 @@ def __init__( self._validate_bin_boundaries(boundaries) self.x_bins, self.y_bins, self.z_bins, self.w_bins = boundaries else: - mylog.error( - "Boundaries must be given as (x0, x1, y0, y1, z0, z1, w0, w1) " - "or as (x_bins, y_bins, z_bins, w_bins)" - ) - raise ValueError + msg = "Boundaries must be given as (x0, x1, y0, y1, z0, z1, w0, w1) or as (x_bins, y_bins, z_bins, w_bins)" + raise ValueError(msg) def __call__(self, data_object, *, table=None): table = self._validate_table(table) diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index b35a31507df..9a08df366b3 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -175,14 +175,17 @@ def test_table_override(ndim): boundaries = (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0)[: ndim * 2] interp_class = _lin_interpolators_by_dim[ndim] - interpolator = interp_class(random_data, boundaries, field_names, True) + # get interpolator with default behavior (a copy of the table is stored) + interpolator = interp_class(random_data, boundaries, field_names, truncate=True) + # evaluate at table grid nodes (should return the input table exactly) assert_array_almost_equal(interpolator(fv), random_data) + # check that we can change the table used after interpolator initialization table_2 = random_data * 2 assert_array_almost_equal(interpolator(fv, table=table_2), table_2) - # check that we can do it without storing the initial table + # check that we can use an interpolator without storing the initial table interpolator = interp_class( - random_data, boundaries, field_names, True, store_table=False + random_data, boundaries, field_names, truncate=True, store_table=False ) assert_array_almost_equal(interpolator(fv, table=table_2), table_2) From 9e876855d4f97c57023aa36aa1c6505367d20fdb Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 13 Nov 2023 10:31:03 -0600 Subject: [PATCH 09/11] avoid making a copy of the table if possible --- yt/utilities/linear_interpolators.py | 32 ++++++++++++---------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index fae41801036..4613c891782 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -11,7 +11,7 @@ class _LinearInterpolator(abc.ABC): def __init__(self, table, field_names, truncate=False, *, store_table=True): if store_table: - self.table = table.astype("float64") + self.table = table.astype("float64", copy=False) else: self.table = None self.table_shape = table.shape @@ -33,7 +33,7 @@ def _validate_table(self, table_override): msg = f"The table_override shape, {table_override.shape}, must match the base table shape, {self.table_shape}" raise ValueError(msg) - return table_override.astype("float64") + return table_override.astype("float64", copy=False) def _get_digitized_arrays(self, data_object): return_arrays = [] @@ -89,10 +89,9 @@ def __init__( outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If True (default), a copy of the full table is stored in - the interpolator. If False, only the shape of the input table - is stored and a full table must be provided when calling the - interpolator. + If True (default), the full table is stored in the interpolator. + If False, only the shape of the input table is stored and a full + table must be provided when calling the interpolator. Examples -------- @@ -151,10 +150,9 @@ def __init__( outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If True (the default), a copy of the full table is stored in - the interpolator. If False, only the shape of the input table - is stored and a full table must be provided when calling the - interpolator. + If True (default), the full table is stored in the interpolator. + If False, only the shape of the input table is stored and a full + table must be provided when calling the interpolator. Examples -------- @@ -222,10 +220,9 @@ def __init__( outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If True (default), a copy of the full table is stored in - the interpolator. If False, only the shape of the input table - is stored and a full table must be provided when calling the - interpolator. + If True (default), the full table is stored in the interpolator. + If False, only the shape of the input table is stored and a full + table must be provided when calling the interpolator. Examples -------- @@ -306,10 +303,9 @@ def __init__( outside the bounds of the table. If True, extrapolation is performed. store_table: bool - If True (default), a copy of the full table is stored in - the interpolator. If False, only the shape of the input table - is stored and a full table must be provided when calling the - interpolator. + If True (default), the full table is stored in the interpolator. + If False, only the shape of the input table is stored and a full + table must be provided when calling the interpolator. Examples -------- From 5f13122ed963508f65907a818bb8efe768416caa Mon Sep 17 00:00:00 2001 From: chavlin Date: Mon, 13 Nov 2023 10:38:24 -0600 Subject: [PATCH 10/11] add keywords for bool args, drop 4d interpolator test resolution --- yt/utilities/tests/test_interpolators.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index 9a08df366b3..f1ce2047b99 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -15,7 +15,7 @@ def test_linear_interpolator_1d(): random_data = np.random.random(64) fv = {"x": np.mgrid[0.0:1.0:64j]} # evenly spaced bins - ufi = lin.UnilinearFieldInterpolator(random_data, (0.0, 1.0), "x", True) + ufi = lin.UnilinearFieldInterpolator(random_data, (0.0, 1.0), "x", truncate=True) assert_array_equal(ufi(fv), random_data) # randomly spaced bins @@ -23,7 +23,7 @@ def test_linear_interpolator_1d(): shift = (1.0 / size) * np.random.random(size) - (0.5 / size) fv["x"] += shift ufi = lin.UnilinearFieldInterpolator( - random_data, np.linspace(0.0, 1.0, size) + shift, "x", True + random_data, np.linspace(0.0, 1.0, size) + shift, "x", truncate=True ) assert_array_almost_equal(ufi(fv), random_data, 15) @@ -32,7 +32,9 @@ def test_linear_interpolator_2d(): random_data = np.random.random((64, 64)) # evenly spaced bins fv = dict(zip("xyz", np.mgrid[0.0:1.0:64j, 0.0:1.0:64j])) - bfi = lin.BilinearFieldInterpolator(random_data, (0.0, 1.0, 0.0, 1.0), "xy", True) + bfi = lin.BilinearFieldInterpolator( + random_data, (0.0, 1.0, 0.0, 1.0), "xy", truncate=True + ) assert_array_equal(bfi(fv), random_data) # randomly spaced bins @@ -42,7 +44,7 @@ def test_linear_interpolator_2d(): fv["x"] += shifts["x"][:, np.newaxis] fv["y"] += shifts["y"] bfi = lin.BilinearFieldInterpolator( - random_data, (bins + shifts["x"], bins + shifts["y"]), "xy", True + random_data, (bins + shifts["x"], bins + shifts["y"]), "xy", truncate=True ) assert_array_almost_equal(bfi(fv), random_data, 15) @@ -52,7 +54,7 @@ def test_linear_interpolator_3d(): # evenly spaced bins fv = dict(zip("xyz", np.mgrid[0.0:1.0:64j, 0.0:1.0:64j, 0.0:1.0:64j])) tfi = lin.TrilinearFieldInterpolator( - random_data, (0.0, 1.0, 0.0, 1.0, 0.0, 1.0), "xyz", True + random_data, (0.0, 1.0, 0.0, 1.0, 0.0, 1.0), "xyz", truncate=True ) assert_array_almost_equal(tfi(fv), random_data) @@ -73,16 +75,19 @@ def test_linear_interpolator_3d(): def test_linear_interpolator_4d(): - random_data = np.random.random((64, 64, 64, 64)) + size = 32 + random_data = np.random.random((size,) * 4) # evenly spaced bins - fv = dict(zip("xyzw", np.mgrid[0.0:1.0:64j, 0.0:1.0:64j, 0.0:1.0:64j, 0.0:1.0:64j])) + step = complex(0, size) + fv = dict( + zip("xyzw", np.mgrid[0.0:1.0:step, 0.0:1.0:step, 0.0:1.0:step, 0.0:1.0:step]) + ) tfi = lin.QuadrilinearFieldInterpolator( - random_data, (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0), "xyzw", True + random_data, (0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0), "xyzw", truncate=True ) assert_array_almost_equal(tfi(fv), random_data) # randomly spaced bins - size = 64 bins = np.linspace(0.0, 1.0, size) shifts = {ax: (1.0 / size) * np.random.random(size) - (0.5 / size) for ax in "xyzw"} fv["x"] += shifts["x"][:, np.newaxis, np.newaxis, np.newaxis] @@ -98,7 +103,7 @@ def test_linear_interpolator_4d(): bins + shifts["w"], ), "xyzw", - True, + truncate=True, ) assert_array_almost_equal(tfi(fv), random_data, 15) From 0de67f9d20d5f4fd6cf870f93ff38c9d9fc1d02e Mon Sep 17 00:00:00 2001 From: chrishavlin Date: Mon, 13 Nov 2023 14:13:38 -0600 Subject: [PATCH 11/11] use same type in all interpolators --- yt/utilities/lib/interpolators.pyx | 6 +++--- yt/utilities/linear_interpolators.py | 5 +---- yt/utilities/tests/test_interpolators.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/yt/utilities/lib/interpolators.pyx b/yt/utilities/lib/interpolators.pyx index 5ddb0c6cbf8..72d25b94ef3 100644 --- a/yt/utilities/lib/interpolators.pyx +++ b/yt/utilities/lib/interpolators.pyx @@ -22,7 +22,7 @@ from yt.utilities.lib.fp_utils cimport iclip def UnilinearlyInterpolate(np.ndarray[np.float64_t, ndim=1] table, np.ndarray[np.float64_t, ndim=1] x_vals, np.ndarray[np.float64_t, ndim=1] x_bins, - np.ndarray[np.int32_t, ndim=1] x_is, + np.ndarray[np.int_t, ndim=1] x_is, np.ndarray[np.float64_t, ndim=1] output): cdef double x, xp, xm cdef int i, x_i @@ -43,8 +43,8 @@ def BilinearlyInterpolate(np.ndarray[np.float64_t, ndim=2] table, np.ndarray[np.float64_t, ndim=1] y_vals, np.ndarray[np.float64_t, ndim=1] x_bins, np.ndarray[np.float64_t, ndim=1] y_bins, - np.ndarray[np.int32_t, ndim=1] x_is, - np.ndarray[np.int32_t, ndim=1] y_is, + np.ndarray[np.int_t, ndim=1] x_is, + np.ndarray[np.int_t, ndim=1] y_is, np.ndarray[np.float64_t, ndim=1] output): cdef double x, xp, xm cdef double y, yp, ym diff --git a/yt/utilities/linear_interpolators.py b/yt/utilities/linear_interpolators.py index 4613c891782..469811b6f0d 100644 --- a/yt/utilities/linear_interpolators.py +++ b/yt/utilities/linear_interpolators.py @@ -7,7 +7,6 @@ class _LinearInterpolator(abc.ABC): _ndim: int - _dim_i_type = "int32" def __init__(self, table, field_names, truncate=False, *, store_table=True): if store_table: @@ -42,7 +41,7 @@ def _get_digitized_arrays(self, data_object): dim_bins = getattr(self, f"{dim}_bins") dim_vals = data_object[dim_name].astype("float64").ravel() - dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype(self._dim_i_type) + dim_i = (np.digitize(dim_vals, dim_bins) - 1).astype("int_") if np.any((dim_i == -1) | (dim_i == len(dim_bins) - 1)): if not self.truncate: msg = ( @@ -200,7 +199,6 @@ def __call__(self, data_object, *, table=None): class TrilinearFieldInterpolator(_LinearInterpolator): _ndim = 3 - _dim_i_type = "int_" def __init__( self, table, boundaries, field_names, truncate=False, *, store_table=True @@ -283,7 +281,6 @@ def __call__(self, data_object, *, table=None): class QuadrilinearFieldInterpolator(_LinearInterpolator): _ndim = 4 - _dim_i_type = "int_" def __init__( self, table, boundaries, field_names, truncate=False, *, store_table=True diff --git a/yt/utilities/tests/test_interpolators.py b/yt/utilities/tests/test_interpolators.py index f1ce2047b99..2413dc97a77 100644 --- a/yt/utilities/tests/test_interpolators.py +++ b/yt/utilities/tests/test_interpolators.py @@ -200,7 +200,7 @@ def test_table_override(ndim): interpolator(fv) -@pytest.mark.parametrize("ndim", list(range(1, 5))) +@pytest.mark.parametrize("ndim", list(_lin_interpolators_by_dim.keys())) def test_bin_validation(ndim): interp_class = _lin_interpolators_by_dim[ndim]