From 9d509362023bcf2aa76e9a7378baab25e7ddde4b Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Fri, 6 Sep 2024 14:10:25 -0700 Subject: [PATCH 01/14] Move BilinearInterpolator into earth2grid._regrid --- CHANGELOG.md | 2 +- earth2grid/__init__.py | 22 +++++++- earth2grid/_regrid.py | 119 ++++++++++++++++++++++++++++++++++------- earth2grid/latlon.py | 101 +--------------------------------- tests/test_regrid.py | 2 +- 5 files changed, 123 insertions(+), 123 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a82817f..b52e7ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Changelog ## latest - +- `earth2grid.latlon.BilinearInterpolator` moved to `earth2grid.BilinearInterpolator` ## 2024.8.1 diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 3143b70..d976a1f 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -12,7 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + from earth2grid import base, healpix, latlon -from earth2grid._regrid import get_regridder +from earth2grid._regrid import BilinearInterpolator, Identity + +__all__ = ["base", "healpix", "latlon", "get_regridder", "BilinearInterpolator"] + + +def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: + """Get a regridder from `src` to `dest`""" + if src == dest: + return Identity() + elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(src, healpix.Grid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(dest, healpix.Grid): + return src.get_healpix_regridder(dest) # type: ignore -__all__ = ["base", "healpix", "latlon", "get_regridder"] + raise ValueError(src, dest, "not supported.") diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 4230026..75970f5 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math + import einops import netCDF4 as nc import torch -from earth2grid import base, healpix -from earth2grid.latlon import LatLonGrid - class TempestRegridder(torch.nn.Module): def __init__(self, file_path): @@ -48,22 +47,104 @@ def forward(self, x): return y +class BilinearInterpolator(torch.nn.Module): + """Bilinear interpolation for a non-uniform grid""" + + def __init__( + self, + x_coords: torch.Tensor, + y_coords: torch.Tensor, + x_query: torch.Tensor, + y_query: torch.Tensor, + fill_value=math.nan, + ) -> None: + """ + + Args: + x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order. + y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order. + x_query (Tensor): X-coordinates for query points, shape [N]. + y_query (Tensor): Y-coordinates for query points, shape [N]. + """ + super().__init__() + self.fill_value = fill_value + + # Ensure input coordinates are float for interpolation + x_coords, y_coords = x_coords.double(), y_coords.double() + x_query = x_query.double() + y_query = y_query.double() + + if torch.any(x_coords[1:] < x_coords[:-1]): + raise ValueError("x_coords must be in non-decreasing order.") + + if torch.any(y_coords[1:] < y_coords[:-1]): + raise ValueError("y_coords must be in non-decreasing order.") + + # Find indices for the closest lower and upper bounds in x and y directions + x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1 + x_u_idx = x_l_idx + 1 + y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 + y_u_idx = y_l_idx + 1 + + # fill in nan outside mask + def isin(x, a, b): + return (x <= b) & (x >= a) + + mask = ( + isin(x_l_idx, 0, x_coords.size(0) - 2) + & isin(x_u_idx, 1, x_coords.size(0) - 1) + & isin(y_l_idx, 0, y_coords.size(0) - 2) + & isin(y_u_idx, 1, y_coords.size(0) - 1) + ) + x_u_idx = x_u_idx[mask] + x_l_idx = x_l_idx[mask] + y_u_idx = y_u_idx[mask] + y_l_idx = y_l_idx[mask] + x_query = x_query[mask] + y_query = y_query[mask] + + # Compute weights + x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) + x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx]) + y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx]) + y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx]) + weights = torch.stack( + [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 + ) + + stride = x_coords.size(-1) + index = torch.stack( + [ + x_l_idx + stride * y_l_idx, + x_u_idx + stride * y_l_idx, + x_l_idx + stride * y_u_idx, + x_u_idx + stride * y_u_idx, + ], + dim=-1, + ) + self.register_buffer("weights", weights) + self.register_buffer("mask", mask) + self.register_buffer("index", index) + + def forward(self, z: torch.Tensor): + """ + Interpolate the field + + Args: + z: shape [Y, X] + """ + *shape, y, x = z.shape + zrs = z.view(-1, y * x).T + # using embedding bag is 2x faster on cpu and 4x on gpu. + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') + interpolated = torch.full( + [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device + ) + interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) + interpolated = interpolated.T.view(*shape, self.mask.numel()) + return interpolated + + class Identity(torch.nn.Module): def forward(self, x): return x - - -def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: - """Get a regridder from `src` to `dest`""" - if src == dest: - return Identity() - elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(src, healpix.Grid): - return src.get_bilinear_regridder_to(dest.lat, dest.lon) - elif isinstance(dest, healpix.Grid): - return src.get_healpix_regridder(dest) # type: ignore - - raise ValueError(src, dest, "not supported.") diff --git a/earth2grid/latlon.py b/earth2grid/latlon.py index 5ccd45d..4b52669 100644 --- a/earth2grid/latlon.py +++ b/earth2grid/latlon.py @@ -12,12 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math - import numpy as np import torch from earth2grid import base +from earth2grid._regrid import BilinearInterpolator try: import pyvista as pv @@ -25,104 +24,6 @@ pv = None -class BilinearInterpolator(torch.nn.Module): - """Bilinear interpolation for a non-uniform grid""" - - def __init__( - self, - x_coords: torch.Tensor, - y_coords: torch.Tensor, - x_query: torch.Tensor, - y_query: torch.Tensor, - fill_value=math.nan, - ) -> None: - """ - - Args: - x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order. - y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order. - x_query (Tensor): X-coordinates for query points, shape [N]. - y_query (Tensor): Y-coordinates for query points, shape [N]. - """ - super().__init__() - self.fill_value = fill_value - - # Ensure input coordinates are float for interpolation - x_coords, y_coords = x_coords.double(), y_coords.double() - x_query = x_query.double() - y_query = y_query.double() - - if torch.any(x_coords[1:] < x_coords[:-1]): - raise ValueError("x_coords must be in non-decreasing order.") - - if torch.any(y_coords[1:] < y_coords[:-1]): - raise ValueError("y_coords must be in non-decreasing order.") - - # Find indices for the closest lower and upper bounds in x and y directions - x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1 - x_u_idx = x_l_idx + 1 - y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 - y_u_idx = y_l_idx + 1 - - # fill in nan outside mask - def isin(x, a, b): - return (x <= b) & (x >= a) - - mask = ( - isin(x_l_idx, 0, x_coords.size(0) - 2) - & isin(x_u_idx, 1, x_coords.size(0) - 1) - & isin(y_l_idx, 0, y_coords.size(0) - 2) - & isin(y_u_idx, 1, y_coords.size(0) - 1) - ) - x_u_idx = x_u_idx[mask] - x_l_idx = x_l_idx[mask] - y_u_idx = y_u_idx[mask] - y_l_idx = y_l_idx[mask] - x_query = x_query[mask] - y_query = y_query[mask] - - # Compute weights - x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) - x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx]) - y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx]) - y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx]) - weights = torch.stack( - [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 - ) - - stride = x_coords.size(-1) - index = torch.stack( - [ - x_l_idx + stride * y_l_idx, - x_u_idx + stride * y_l_idx, - x_l_idx + stride * y_u_idx, - x_u_idx + stride * y_u_idx, - ], - dim=-1, - ) - self.register_buffer("weights", weights) - self.register_buffer("mask", mask) - self.register_buffer("index", index) - - def forward(self, z: torch.Tensor): - """ - Interpolate the field - - Args: - z: shape [Y, X] - """ - *shape, y, x = z.shape - zrs = z.view(-1, y * x).T - # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') - interpolated = torch.full( - [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device - ) - interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) - interpolated = interpolated.T.view(*shape, self.mask.numel()) - return interpolated - - class LatLonGrid(base.Grid): def __init__(self, lat: list[float], lon: list[float]): """ diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 29bf264..18b28a2 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -20,7 +20,7 @@ import torch import earth2grid -from earth2grid.latlon import BilinearInterpolator +from earth2grid import BilinearInterpolator @pytest.mark.parametrize("with_channels", [True, False]) From a8815d554ce891563c17eb64a453766622545e31 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Fri, 6 Sep 2024 15:23:36 -0700 Subject: [PATCH 02/14] add earth2grid.NearestNeighborInterpolator --- earth2grid/__init__.py | 4 ++-- earth2grid/_regrid.py | 52 ++++++++++++++++++++++++++++++++++++++++++ earth2grid/spatial.py | 46 +++++++++++++++++++++++++++++++++++++ tests/test_regrid.py | 16 +++++++++++++ 4 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 earth2grid/spatial.py diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index d976a1f..b505ce9 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -15,9 +15,9 @@ import torch from earth2grid import base, healpix, latlon -from earth2grid._regrid import BilinearInterpolator, Identity +from earth2grid._regrid import BilinearInterpolator, Identity, S2NearestNeighborInterpolator -__all__ = ["base", "healpix", "latlon", "get_regridder", "BilinearInterpolator"] +__all__ = ["base", "healpix", "latlon", "get_regridder", "BilinearInterpolator", "S2NearestNeighborInterpolator"] def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 75970f5..25aed2b 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -17,6 +17,9 @@ import einops import netCDF4 as nc import torch +from scipy import spatial + +from earth2grid.spatial import ang2vec class TempestRegridder(torch.nn.Module): @@ -145,6 +148,55 @@ def forward(self, z: torch.Tensor): return interpolated +class S2NearestNeighborInterpolator(torch.nn.Module): + """Bilinear interpolation for a non-uniform grid""" + + def __init__( + self, + src_lon: torch.Tensor, + src_lat: torch.Tensor, + dest_lon: torch.Tensor, + dest_lat: torch.Tensor, + ) -> None: + """ + + Args: + src_lon: (m,) source longitude in degrees E + src_lat: (m,) source latitude in degrees N + dest_lon: (n,) output longitude in degrees E + dest_lat: (n,) output latitude in degrees N + + """ + super().__init__() + src_lon = torch.deg2rad(src_lon.cpu()) + src_lat = torch.deg2rad(src_lat.cpu()) + + dest_lon = torch.deg2rad(dest_lon.cpu()) + dest_lat = torch.deg2rad(dest_lat.cpu()) + + vec = torch.stack(ang2vec(src_lon, src_lat), -1) + + # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. + self.tree = spatial.KDTree(vec) + vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) + _, neighbors = self.tree.query(vec, k=1) + self.register_buffer("index", torch.as_tensor(neighbors).view(-1, 1)) + + def forward(self, z: torch.Tensor): + """ + Interpolate the field + + Args: + z: shape [*, X] + """ + *shape, x = z.shape + zrs = z.view(-1, x).T + # using embedding bag is 2x faster on cpu and 4x on gpu. + output = torch.nn.functional.embedding_bag(self.index, zrs, mode='sum') + output = output.T.view(*shape, -1) + return output + + class Identity(torch.nn.Module): def forward(self, x): return x diff --git a/earth2grid/spatial.py b/earth2grid/spatial.py new file mode 100644 index 0000000..974108e --- /dev/null +++ b/earth2grid/spatial.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def haversine_distance(lon1, lat1, lon2, lat2): + """ + Calculate the Haversine distance between two points on unit sphere + + Args: + lon1 (float): Longitude of the first point in radians. + lat1 (float): Latitude of the first point in radians. + lon2 (float): Longitude of the second point in radians. + lat2 (float): Latitude of the second point in radians. + + Returns: + float: Distance between the two points in kilometers. + """ + # Differences in coordinates + dlon = lon2 - lon1 + dlat = lat2 - lat1 + + # Haversine formula + a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2 + c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a)) + return c + + +def ang2vec(lon, lat): + """convert lon,lat in radians to cartesian coordinates""" + x = torch.cos(lat) * torch.cos(lon) + y = torch.cos(lat) * torch.sin(lon) + z = torch.sin(lat) + return (x, y, z) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 18b28a2..06d038e 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -195,3 +195,19 @@ def test_out_of_bounds(): output = regrid(data) assert torch.all(torch.isnan(output)) + + +def test_NearestNeighborInterpolator(): + n = 10000 + torch.manual_seed(0) + lon = torch.rand(n) * 360 + lat = torch.rand(n) * 180 - 90 + + lond = torch.rand(n) * 360 + latd = torch.rand(n) * 180 - 90 + + interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd) + out = interpolate(torch.cos(torch.deg2rad(lon))) + expected = torch.cos(torch.deg2rad(lond)) + mae = torch.mean(torch.abs(out - expected)) + assert mae.item() < 0.02 From 9ed6d48055f100028467cae8cf79ae034b20ab4e Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Fri, 6 Sep 2024 15:26:52 -0700 Subject: [PATCH 03/14] improve docstring --- earth2grid/_regrid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 25aed2b..09929da 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -187,7 +187,10 @@ def forward(self, z: torch.Tensor): Interpolate the field Args: - z: shape [*, X] + z: shape [*, m] + + Returns: + shape [*, n] """ *shape, x = z.shape zrs = z.view(-1, x).T From 5bd4f2cf4fac4e36a0269b51aa11943c71c60629 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Fri, 6 Sep 2024 15:40:39 -0700 Subject: [PATCH 04/14] Add Linear Barycentric interopltor --- earth2grid/__init__.py | 17 ++++++++++-- earth2grid/_regrid.py | 62 +++++++++++++++++++++++++++++++++++++++++- tests/test_regrid.py | 16 +++++++++++ 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index b505ce9..98e309e 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -15,9 +15,22 @@ import torch from earth2grid import base, healpix, latlon -from earth2grid._regrid import BilinearInterpolator, Identity, S2NearestNeighborInterpolator +from earth2grid._regrid import ( + BilinearInterpolator, + Identity, + S2LinearBarycentricInterpolator, + S2NearestNeighborInterpolator, +) -__all__ = ["base", "healpix", "latlon", "get_regridder", "BilinearInterpolator", "S2NearestNeighborInterpolator"] +__all__ = [ + "base", + "healpix", + "latlon", + "get_regridder", + "BilinearInterpolator", + "S2NearestNeighborInterpolator", + "S2LinearBarycentricInterpolator", +] def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 09929da..3bf60f4 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -19,7 +19,7 @@ import torch from scipy import spatial -from earth2grid.spatial import ang2vec +from earth2grid.spatial import ang2vec, haversine_distance class TempestRegridder(torch.nn.Module): @@ -200,6 +200,66 @@ def forward(self, z: torch.Tensor): return output +class S2LinearBarycentricInterpolator(torch.nn.Module): + """Linear Barycentric Interpolator for unstructured data + + This is equivalent to inverse square weighting + + """ + + def __init__( + self, + src_lon: torch.Tensor, + src_lat: torch.Tensor, + dest_lon: torch.Tensor, + dest_lat: torch.Tensor, + ) -> None: + """ + + Args: + src_lon: (m,) source longitude in degrees E + src_lat: (m,) source latitude in degrees N + dest_lon: (n,) output longitude in degrees E + dest_lat: (n,) output latitude in degrees N + + """ + super().__init__() + src_lon = torch.deg2rad(src_lon.cpu()) + src_lat = torch.deg2rad(src_lat.cpu()) + + dest_lon = torch.deg2rad(dest_lon.cpu()) + dest_lat = torch.deg2rad(dest_lat.cpu()) + + vec = torch.stack(ang2vec(src_lon, src_lat), -1) + + # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. + self.tree = spatial.KDTree(vec) + vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) + _, neighbors = self.tree.query(vec, k=3) + d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) + lam = 1 / d + lam = lam / lam.sum(-1, keepdim=True) + self.register_buffer("index", torch.as_tensor(neighbors).view(-1, 3)) + self.register_buffer("weight", lam) + + def forward(self, z: torch.Tensor): + """ + Interpolate the field + + Args: + z: shape [*, m] + + Returns: + shape [*, n] + """ + *shape, x = z.shape + zrs = z.view(-1, x).T + # using embedding bag is 2x faster on cpu and 4x on gpu. + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weight, mode='sum') + output = output.T.view(*shape, -1) + return output + + class Identity(torch.nn.Module): def forward(self, x): return x diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 06d038e..361ba4a 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -211,3 +211,19 @@ def test_NearestNeighborInterpolator(): expected = torch.cos(torch.deg2rad(lond)) mae = torch.mean(torch.abs(out - expected)) assert mae.item() < 0.02 + + +def test_BaryCentric(): + n = 10000 + torch.manual_seed(0) + lon = torch.rand(n) * 360 + lat = torch.rand(n) * 180 - 90 + + lond = torch.rand(n) * 360 + latd = torch.rand(n) * 180 - 90 + + interpolate = earth2grid.S2LinearBarycentricInterpolator(lon, lat, lond, latd) + out = interpolate(torch.cos(torch.deg2rad(lon))) + expected = torch.cos(torch.deg2rad(lond)) + mae = torch.mean(torch.abs(out - expected)) + assert mae.item() < 0.011 From fd9aa9f4ee9e23aedda148941422247a5c584b6c Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Fri, 6 Sep 2024 15:41:48 -0700 Subject: [PATCH 05/14] add scipy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f8bc513..0b54520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "netCDF4>=1.6.5", "numpy>=1.23.3", "torch>=2.0.1", + "scipy" ] [project.urls] From 97a3fda5c57986a97dd6184520fdad174fdf5720 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Sat, 7 Sep 2024 13:20:03 -0700 Subject: [PATCH 06/14] consolidate nearest neighbor implementation --- earth2grid/__init__.py | 1 - earth2grid/_regrid.py | 78 ++++++++---------------------------------- tests/test_regrid.py | 21 ++---------- 3 files changed, 18 insertions(+), 82 deletions(-) diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 98e309e..410c6a4 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -18,7 +18,6 @@ from earth2grid._regrid import ( BilinearInterpolator, Identity, - S2LinearBarycentricInterpolator, S2NearestNeighborInterpolator, ) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 3bf60f4..a1122b1 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -149,63 +149,7 @@ def forward(self, z: torch.Tensor): class S2NearestNeighborInterpolator(torch.nn.Module): - """Bilinear interpolation for a non-uniform grid""" - - def __init__( - self, - src_lon: torch.Tensor, - src_lat: torch.Tensor, - dest_lon: torch.Tensor, - dest_lat: torch.Tensor, - ) -> None: - """ - - Args: - src_lon: (m,) source longitude in degrees E - src_lat: (m,) source latitude in degrees N - dest_lon: (n,) output longitude in degrees E - dest_lat: (n,) output latitude in degrees N - - """ - super().__init__() - src_lon = torch.deg2rad(src_lon.cpu()) - src_lat = torch.deg2rad(src_lat.cpu()) - - dest_lon = torch.deg2rad(dest_lon.cpu()) - dest_lat = torch.deg2rad(dest_lat.cpu()) - - vec = torch.stack(ang2vec(src_lon, src_lat), -1) - - # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. - self.tree = spatial.KDTree(vec) - vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) - _, neighbors = self.tree.query(vec, k=1) - self.register_buffer("index", torch.as_tensor(neighbors).view(-1, 1)) - - def forward(self, z: torch.Tensor): - """ - Interpolate the field - - Args: - z: shape [*, m] - - Returns: - shape [*, n] - """ - *shape, x = z.shape - zrs = z.view(-1, x).T - # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, mode='sum') - output = output.T.view(*shape, -1) - return output - - -class S2LinearBarycentricInterpolator(torch.nn.Module): - """Linear Barycentric Interpolator for unstructured data - - This is equivalent to inverse square weighting - - """ + """K-nearest neighbor interpolator with inverse distance weighting""" def __init__( self, @@ -213,6 +157,7 @@ def __init__( src_lat: torch.Tensor, dest_lon: torch.Tensor, dest_lat: torch.Tensor, + k: int = 1, ) -> None: """ @@ -221,6 +166,7 @@ def __init__( src_lat: (m,) source latitude in degrees N dest_lon: (n,) output longitude in degrees E dest_lat: (n,) output latitude in degrees N + k: number of neighbors """ super().__init__() @@ -235,12 +181,18 @@ def __init__( # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. self.tree = spatial.KDTree(vec) vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) - _, neighbors = self.tree.query(vec, k=3) - d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) - lam = 1 / d - lam = lam / lam.sum(-1, keepdim=True) - self.register_buffer("index", torch.as_tensor(neighbors).view(-1, 3)) - self.register_buffer("weight", lam) + _, neighbors = self.tree.query(vec, k=k) + self.register_buffer("index", torch.as_tensor(neighbors).view(-1, k)) + + self.k = k + + if k > 1: + d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) + lam = 1 / d + lam = lam / lam.sum(-1, keepdim=True) + self.register_buffer("weight", lam) + else: + self.weight = None def forward(self, z: torch.Tensor): """ diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 361ba4a..b7d0776 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -197,7 +197,8 @@ def test_out_of_bounds(): assert torch.all(torch.isnan(output)) -def test_NearestNeighborInterpolator(): +@pytest.mark.parametrize("k", [1, 2, 3]) +def test_NearestNeighborInterpolator(k): n = 10000 torch.manual_seed(0) lon = torch.rand(n) * 360 @@ -206,24 +207,8 @@ def test_NearestNeighborInterpolator(): lond = torch.rand(n) * 360 latd = torch.rand(n) * 180 - 90 - interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd) + interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd, k=k) out = interpolate(torch.cos(torch.deg2rad(lon))) expected = torch.cos(torch.deg2rad(lond)) mae = torch.mean(torch.abs(out - expected)) assert mae.item() < 0.02 - - -def test_BaryCentric(): - n = 10000 - torch.manual_seed(0) - lon = torch.rand(n) * 360 - lat = torch.rand(n) * 180 - 90 - - lond = torch.rand(n) * 360 - latd = torch.rand(n) * 180 - 90 - - interpolate = earth2grid.S2LinearBarycentricInterpolator(lon, lat, lond, latd) - out = interpolate(torch.cos(torch.deg2rad(lon))) - expected = torch.cos(torch.deg2rad(lond)) - mae = torch.mean(torch.abs(out - expected)) - assert mae.item() < 0.011 From 6729561915400976da9dd284ed246901bf3305c7 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Sat, 7 Sep 2024 13:32:33 -0700 Subject: [PATCH 07/14] refactor to serializable design --- earth2grid/_regrid.py | 121 ++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 62 deletions(-) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index a1122b1..5f3f31c 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -22,6 +22,27 @@ from earth2grid.spatial import ang2vec, haversine_distance +class Regridder(torch.nn.Module): + """Regridder to n points, with p nonzero maps weights + + Forward: + (*, m) -> (*, n) + """ + + def __init__(self, n: int, p: int): + super().__init__() + self.register_buffer("index", torch.empty(n, p, dtype=torch.long)) + self.register_buffer("weight", torch.ones(n, p)) + + def forward(self, z): + *shape, x = z.shape + zrs = z.view(-1, x).T + # using embedding bag is 2x faster on cpu and 4x on gpu. + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weight, mode='sum') + output = output.T.view(*shape, -1) + return output + + class TempestRegridder(torch.nn.Module): def __init__(self, file_path): super().__init__() @@ -148,68 +169,44 @@ def forward(self, z: torch.Tensor): return interpolated -class S2NearestNeighborInterpolator(torch.nn.Module): - """K-nearest neighbor interpolator with inverse distance weighting""" - - def __init__( - self, - src_lon: torch.Tensor, - src_lat: torch.Tensor, - dest_lon: torch.Tensor, - dest_lat: torch.Tensor, - k: int = 1, - ) -> None: - """ - - Args: - src_lon: (m,) source longitude in degrees E - src_lat: (m,) source latitude in degrees N - dest_lon: (n,) output longitude in degrees E - dest_lat: (n,) output latitude in degrees N - k: number of neighbors - - """ - super().__init__() - src_lon = torch.deg2rad(src_lon.cpu()) - src_lat = torch.deg2rad(src_lat.cpu()) - - dest_lon = torch.deg2rad(dest_lon.cpu()) - dest_lat = torch.deg2rad(dest_lat.cpu()) - - vec = torch.stack(ang2vec(src_lon, src_lat), -1) - - # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. - self.tree = spatial.KDTree(vec) - vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) - _, neighbors = self.tree.query(vec, k=k) - self.register_buffer("index", torch.as_tensor(neighbors).view(-1, k)) - - self.k = k - - if k > 1: - d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) - lam = 1 / d - lam = lam / lam.sum(-1, keepdim=True) - self.register_buffer("weight", lam) - else: - self.weight = None - - def forward(self, z: torch.Tensor): - """ - Interpolate the field - - Args: - z: shape [*, m] - - Returns: - shape [*, n] - """ - *shape, x = z.shape - zrs = z.view(-1, x).T - # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weight, mode='sum') - output = output.T.view(*shape, -1) - return output +def S2NearestNeighborInterpolator( + src_lon: torch.Tensor, + src_lat: torch.Tensor, + dest_lon: torch.Tensor, + dest_lat: torch.Tensor, + k: int = 1, +) -> Regridder: + """K-nearest neighbor interpolator with inverse distance weighting + + Args: + src_lon: (m,) source longitude in degrees E + src_lat: (m,) source latitude in degrees N + dest_lon: (n,) output longitude in degrees E + dest_lat: (n,) output latitude in degrees N + k: number of neighbors + + """ + src_lon = torch.deg2rad(src_lon.cpu()) + src_lat = torch.deg2rad(src_lat.cpu()) + + dest_lon = torch.deg2rad(dest_lon.cpu()) + dest_lat = torch.deg2rad(dest_lat.cpu()) + + vec = torch.stack(ang2vec(src_lon, src_lat), -1) + + # havesign distance and euclidean are monotone for points on S2 so can use 3d lookups. + tree = spatial.KDTree(vec) + vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) + _, neighbors = tree.query(vec, k=k) + regridder = Regridder(dest_lon.shape[0], k) + regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k)) + if k > 1: + d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) + lam = 1 / d + lam = lam / lam.sum(-1, keepdim=True) + regridder.weight.copy_(lam) + + return regridder class Identity(torch.nn.Module): From c4c47a94768844c99a76616a3c0ff41a8d83c45b Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Sun, 8 Sep 2024 09:21:07 -0700 Subject: [PATCH 08/14] add from_state_dict to regridder --- earth2grid/__init__.py | 8 ++------ earth2grid/_regrid.py | 7 +++++++ tests/test_regrid.py | 3 +++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 410c6a4..2747948 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -15,11 +15,7 @@ import torch from earth2grid import base, healpix, latlon -from earth2grid._regrid import ( - BilinearInterpolator, - Identity, - S2NearestNeighborInterpolator, -) +from earth2grid._regrid import BilinearInterpolator, Identity, Regridder, S2NearestNeighborInterpolator __all__ = [ "base", @@ -28,7 +24,7 @@ "get_regridder", "BilinearInterpolator", "S2NearestNeighborInterpolator", - "S2LinearBarycentricInterpolator", + "Regridder", ] diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 5f3f31c..dd58deb 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -42,6 +42,13 @@ def forward(self, z): output = output.T.view(*shape, -1) return output + @staticmethod + def from_state_dict(d): + n, p = d["index"].shape + regridder = Regridder(n, p) + regridder.load_state_dict(d) + return regridder + class TempestRegridder(torch.nn.Module): def __init__(self, file_path): diff --git a/tests/test_regrid.py b/tests/test_regrid.py index b7d0776..6d3004a 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -212,3 +212,6 @@ def test_NearestNeighborInterpolator(k): expected = torch.cos(torch.deg2rad(lond)) mae = torch.mean(torch.abs(out - expected)) assert mae.item() < 0.02 + + # load-reload + earth2grid.Regridder.from_state_dict(interpolate.state_dict()) From 397a460890894735a3ec37700edca1b2fb94d0d7 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Sun, 8 Sep 2024 09:23:55 -0700 Subject: [PATCH 09/14] add type hints --- earth2grid/_regrid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index dd58deb..d78f3ce 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +from typing import Dict import einops import netCDF4 as nc @@ -43,7 +44,7 @@ def forward(self, z): return output @staticmethod - def from_state_dict(d): + def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder": n, p = d["index"].shape regridder = Regridder(n, p) regridder.load_state_dict(d) From 481e264e291bcb9ea1a25712daa323039eec8b2c Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 9 Sep 2024 09:08:36 -0700 Subject: [PATCH 10/14] support non flat output shapes --- earth2grid/_regrid.py | 23 ++++++++++++++--------- earth2grid/healpix.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index d78f3ce..184b801 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Dict +from typing import Dict, Sequence import einops import netCDF4 as nc @@ -27,26 +27,31 @@ class Regridder(torch.nn.Module): """Regridder to n points, with p nonzero maps weights Forward: - (*, m) -> (*, n) + (*, m) -> (*,) + shape """ - def __init__(self, n: int, p: int): + def __init__(self, shape: Sequence[int], p: int): super().__init__() - self.register_buffer("index", torch.empty(n, p, dtype=torch.long)) - self.register_buffer("weight", torch.ones(n, p)) + self.register_buffer("index", torch.empty(*shape, p, dtype=torch.long)) + self.register_buffer("weight", torch.ones(*shape, p)) def forward(self, z): *shape, x = z.shape zrs = z.view(-1, x).T + + *output_shape, p = self.index.shape + index = self.index.view(-1, p) + weight = self.weight.view(-1, p) + # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weight, mode='sum') + output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum') output = output.T.view(*shape, -1) - return output + return output.reshape(list(shape) + output_shape) @staticmethod def from_state_dict(d: Dict[str, torch.Tensor]) -> "Regridder": n, p = d["index"].shape - regridder = Regridder(n, p) + regridder = Regridder((n,), p) regridder.load_state_dict(d) return regridder @@ -206,7 +211,7 @@ def S2NearestNeighborInterpolator( tree = spatial.KDTree(vec) vec = torch.stack(ang2vec(dest_lon.cpu(), dest_lat.cpu()), -1) _, neighbors = tree.query(vec, k=k) - regridder = Regridder(dest_lon.shape[0], k) + regridder = Regridder(dest_lon.shape, k) regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k)) if k > 1: d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 6b93176..99cdd08 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -43,6 +43,7 @@ import torch from earth2grid import healpix_bare +from earth2grid._regrid import Regridder try: import pyvista as pv @@ -230,24 +231,19 @@ def _convert_xyindex(nside: int, src: XY, dest: XY, i): return i -class ApplyWeights(torch.nn.Module): - def __init__(self, pix: torch.Tensor, weight: torch.Tensor): - super().__init__() +def ApplyWeights(pix: torch.Tensor, weight: torch.Tensor): + # the first dim is the 4 point stencil + # TODO delete + p, *shape = pix.shape - # the first dim is the 4 point stencil - n, *self.shape = pix.shape + pix = pix.view(p, -1).T + weight = weight.view(p, -1).T - pix = pix.view(n, -1).T - weight = weight.view(n, -1).T - - self.register_buffer("index", pix) - self.register_buffer("weight", weight) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - *shape, npix = x.shape - x = x.view(-1, npix).T - interpolated = torch.nn.functional.embedding_bag(self.index, x, per_sample_weights=self.weight, mode="sum").T - return interpolated.view(shape + self.shape) + regridder = Regridder(pix.shape[:-1], p=pix.shape[1]) + regridder.to(weight) + regridder.index.copy_(pix) + regridder.weight.copy_(weight) + return regridder @dataclass @@ -345,7 +341,16 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat)) i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel()) i_me = self._nest2me(i_nest).reshape(i_ring.shape) - return ApplyWeights(i_me, weights) + + # reshape to (*, p) + weights = weights.movedim(0, -1) + index = i_me.movedim(0, -1) + + regridder = Regridder(weights.shape[:-1], p=weights.shape[-1]) + regridder.to(weights) + regridder.index.copy_(index) + regridder.weight.copy_(weights) + return regridder def approximate_grid_length_meters(self): return approx_grid_length_meters(self._nside()) From a9185bdb19eb7b1fd8715423246015e782866e1e Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 9 Sep 2024 12:25:00 -0700 Subject: [PATCH 11/14] regularize distances --- earth2grid/_regrid.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 184b801..8eb23fc 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -188,6 +188,7 @@ def S2NearestNeighborInterpolator( dest_lon: torch.Tensor, dest_lat: torch.Tensor, k: int = 1, + eps=1e-7, ) -> Regridder: """K-nearest neighbor interpolator with inverse distance weighting @@ -196,7 +197,9 @@ def S2NearestNeighborInterpolator( src_lat: (m,) source latitude in degrees N dest_lon: (n,) output longitude in degrees E dest_lat: (n,) output latitude in degrees N - k: number of neighbors + k: number of neighbors, default: 1 + eps: regularization factor for inverse distance weighting. Only used if + k > 1. """ src_lon = torch.deg2rad(src_lon.cpu()) @@ -215,7 +218,7 @@ def S2NearestNeighborInterpolator( regridder.index.copy_(torch.as_tensor(neighbors).view(-1, k)) if k > 1: d = haversine_distance(dest_lon[:, None], dest_lat[:, None], src_lon[neighbors], src_lat[neighbors]) - lam = 1 / d + lam = 1 / (d + eps) lam = lam / lam.sum(-1, keepdim=True) regridder.weight.copy_(lam) From 8248763f39045b24ee5d60a260b8e862a1d602bd Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 9 Sep 2024 13:52:26 -0700 Subject: [PATCH 12/14] Delete applyWeights --- earth2grid/healpix.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 99cdd08..fb8887c 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -231,21 +231,6 @@ def _convert_xyindex(nside: int, src: XY, dest: XY, i): return i -def ApplyWeights(pix: torch.Tensor, weight: torch.Tensor): - # the first dim is the 4 point stencil - # TODO delete - p, *shape = pix.shape - - pix = pix.view(p, -1).T - weight = weight.view(p, -1).T - - regridder = Regridder(pix.shape[:-1], p=pix.shape[1]) - regridder.to(weight) - regridder.index.copy_(pix) - regridder.weight.copy_(weight) - return regridder - - @dataclass class Grid(base.Grid): """A Healpix Grid From e2f2392f6c5d79001a0f5cbbc7cad3c6f374b0c2 Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 16 Sep 2024 12:43:28 -0700 Subject: [PATCH 13/14] respond to reviews --- earth2grid/__init__.py | 4 ++-- earth2grid/_regrid.py | 5 ++++- tests/test_regrid.py | 13 ++++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 2747948..6d52f27 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -15,7 +15,7 @@ import torch from earth2grid import base, healpix, latlon -from earth2grid._regrid import BilinearInterpolator, Identity, Regridder, S2NearestNeighborInterpolator +from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder __all__ = [ "base", @@ -23,7 +23,7 @@ "latlon", "get_regridder", "BilinearInterpolator", - "S2NearestNeighborInterpolator", + "KNNS2Interpolator", "Regridder", ] diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 8eb23fc..5e9e5e8 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -182,7 +182,7 @@ def forward(self, z: torch.Tensor): return interpolated -def S2NearestNeighborInterpolator( +def KNNS2Interpolator( src_lon: torch.Tensor, src_lat: torch.Tensor, dest_lon: torch.Tensor, @@ -202,6 +202,9 @@ def S2NearestNeighborInterpolator( k > 1. """ + if (src_lat.ndim != 1) or (src_lon.ndim != 1) or (dest_lat.ndim != 1) or (dest_lon.ndim != 1): + raise ValueError("All input coordinates must be 1 dimensional.") + src_lon = torch.deg2rad(src_lon.cpu()) src_lat = torch.deg2rad(src_lat.cpu()) diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 6d3004a..73a6b9e 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -200,14 +200,15 @@ def test_out_of_bounds(): @pytest.mark.parametrize("k", [1, 2, 3]) def test_NearestNeighborInterpolator(k): n = 10000 + m = 887 torch.manual_seed(0) lon = torch.rand(n) * 360 lat = torch.rand(n) * 180 - 90 - lond = torch.rand(n) * 360 - latd = torch.rand(n) * 180 - 90 + lond = torch.rand(m) * 360 + latd = torch.rand(m) * 180 - 90 - interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd, k=k) + interpolate = earth2grid.KNNS2Interpolator(lon, lat, lond, latd, k=k) out = interpolate(torch.cos(torch.deg2rad(lon))) expected = torch.cos(torch.deg2rad(lond)) mae = torch.mean(torch.abs(out - expected)) @@ -215,3 +216,9 @@ def test_NearestNeighborInterpolator(k): # load-reload earth2grid.Regridder.from_state_dict(interpolate.state_dict()) + + # try batched interpolation + x = torch.cos(torch.deg2rad(lon)) + x = x.unsqueeze(0) + out = interpolate(x) + assert out.shape == (1, m) From 9cb3cd7f4781d885c7b17658dfc1efbab46f8e7f Mon Sep 17 00:00:00 2001 From: "Noah D. Brenowitz" Date: Mon, 16 Sep 2024 12:49:56 -0700 Subject: [PATCH 14/14] update api docs --- docs/api.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index a272fc2..d57f5e8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,7 +19,12 @@ Regridding .. autofunction:: earth2grid.get_regridder +.. autofunction:: earth2grid.KNNS2Interpolator + +.. autofunction:: earth2grid.BilinearInterpolator + Other utilities --------------- +.. autofunction:: earth2grid.healpix.reorder .. autofunction:: earth2grid.healpix.pad