Skip to content

Commit

Permalink
respond to reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
nbren12 committed Sep 16, 2024
1 parent 8248763 commit e2f2392
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
4 changes: 2 additions & 2 deletions earth2grid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import torch

from earth2grid import base, healpix, latlon
from earth2grid._regrid import BilinearInterpolator, Identity, Regridder, S2NearestNeighborInterpolator
from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder

__all__ = [
"base",
"healpix",
"latlon",
"get_regridder",
"BilinearInterpolator",
"S2NearestNeighborInterpolator",
"KNNS2Interpolator",
"Regridder",
]

Expand Down
5 changes: 4 additions & 1 deletion earth2grid/_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def forward(self, z: torch.Tensor):
return interpolated


def S2NearestNeighborInterpolator(
def KNNS2Interpolator(
src_lon: torch.Tensor,
src_lat: torch.Tensor,
dest_lon: torch.Tensor,
Expand All @@ -202,6 +202,9 @@ def S2NearestNeighborInterpolator(
k > 1.
"""
if (src_lat.ndim != 1) or (src_lon.ndim != 1) or (dest_lat.ndim != 1) or (dest_lon.ndim != 1):
raise ValueError("All input coordinates must be 1 dimensional.")

src_lon = torch.deg2rad(src_lon.cpu())
src_lat = torch.deg2rad(src_lat.cpu())

Expand Down
13 changes: 10 additions & 3 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,25 @@ def test_out_of_bounds():
@pytest.mark.parametrize("k", [1, 2, 3])
def test_NearestNeighborInterpolator(k):
n = 10000
m = 887
torch.manual_seed(0)
lon = torch.rand(n) * 360
lat = torch.rand(n) * 180 - 90

lond = torch.rand(n) * 360
latd = torch.rand(n) * 180 - 90
lond = torch.rand(m) * 360
latd = torch.rand(m) * 180 - 90

interpolate = earth2grid.S2NearestNeighborInterpolator(lon, lat, lond, latd, k=k)
interpolate = earth2grid.KNNS2Interpolator(lon, lat, lond, latd, k=k)
out = interpolate(torch.cos(torch.deg2rad(lon)))
expected = torch.cos(torch.deg2rad(lond))
mae = torch.mean(torch.abs(out - expected))
assert mae.item() < 0.02

# load-reload
earth2grid.Regridder.from_state_dict(interpolate.state_dict())

# try batched interpolation
x = torch.cos(torch.deg2rad(lon))
x = x.unsqueeze(0)
out = interpolate(x)
assert out.shape == (1, m)

0 comments on commit e2f2392

Please sign in to comment.