Skip to content

Commit

Permalink
implement fill value for points outside of data bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
nbren12 committed Aug 24, 2024
1 parent 44297de commit e016620
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
47 changes: 36 additions & 11 deletions earth2grid/latlon.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch

Expand All @@ -27,7 +29,12 @@ class BilinearInterpolator(torch.nn.Module):
"""Bilinear interpolation for a non-uniform grid"""

def __init__(
self, x_coords: torch.Tensor, y_coords: torch.Tensor, x_query: torch.Tensor, y_query: torch.Tensor
self,
x_coords: torch.Tensor,
y_coords: torch.Tensor,
x_query: torch.Tensor,
y_query: torch.Tensor,
fill_value=math.nan,
) -> None:
"""
Expand All @@ -38,9 +45,12 @@ def __init__(
y_query (Tensor): Y-coordinates for query points, shape [N].
"""
super().__init__()
self.fill_value = fill_value

# Ensure input coordinates are float for interpolation
x_coords, y_coords = x_coords.float(), y_coords.float()
x_coords, y_coords = x_coords.double(), y_coords.double()
x_query = x_query.double()
y_query = y_query.double()

if torch.any(x_coords[1:] < x_coords[:-1]):
raise ValueError("x_coords must be in non-decreasing order.")
Expand All @@ -54,11 +64,22 @@ def __init__(
y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1
y_u_idx = y_l_idx + 1

# Clip indices to ensure they are within the bounds of the input grid
x_l_idx = x_l_idx.clamp(0, x_coords.size(0) - 2)
x_u_idx = x_u_idx.clamp(1, x_coords.size(0) - 1)
y_l_idx = y_l_idx.clamp(0, y_coords.size(0) - 2)
y_u_idx = y_u_idx.clamp(1, y_coords.size(0) - 1)
# fill in nan outside mask
def isin(x, a, b):
return (x <= b) & (x >= a)

mask = (
isin(x_l_idx, 0, x_coords.size(0) - 2)
& isin(x_u_idx, 1, x_coords.size(0) - 1)
& isin(y_l_idx, 0, y_coords.size(0) - 2)
& isin(y_u_idx, 1, y_coords.size(0) - 1)
)
x_u_idx = x_u_idx[mask]
x_l_idx = x_l_idx[mask]
y_u_idx = y_u_idx[mask]
y_l_idx = y_l_idx[mask]
x_query = x_query[mask]
y_query = y_query[mask]

# Compute weights
x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx])
Expand All @@ -69,8 +90,6 @@ def __init__(
[x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1
)

self.register_buffer("weights", weights)

stride = x_coords.size(-1)
index = torch.stack(
[
Expand All @@ -81,6 +100,8 @@ def __init__(
],
dim=-1,
)
self.register_buffer("weights", weights)
self.register_buffer("mask", mask)
self.register_buffer("index", index)

def forward(self, z: torch.Tensor):
Expand All @@ -93,8 +114,12 @@ def forward(self, z: torch.Tensor):
*shape, y, x = z.shape
zrs = z.view(-1, y * x).T
# using embedding bag is 2x faster on cpu and 4x on gpu.
interpolated = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
interpolated = interpolated.T.view(*shape, self.weights.size(0))
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
interpolated = torch.full(
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
)
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
interpolated = interpolated.T.view(*shape, self.mask.numel())
return interpolated


Expand Down
21 changes: 19 additions & 2 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_interpolation(self):

# Execute
interpolator = BilinearInterpolator(x_coords, y_coords, x_query, y_query)
interpolator.to(input_tensor)
result = interpolator(input_tensor)

# Verify
Expand All @@ -142,13 +143,13 @@ def test_raises_error_when_coordinates_not_increasing_x(self):
x_coords = torch.linspace(1, -1, steps=32) # Example non-uniform x-coordinates
y_coords = torch.linspace(-1, 1, steps=32) # Example non-uniform y-coordinates
with self.assertRaises(ValueError):
BilinearInterpolator(x_coords, y_coords, [0], [0])
BilinearInterpolator(x_coords, y_coords, torch.tensor([0]), torch.tensor([0]))

def test_raises_error_when_coordinates_not_increasing_y(self):
x_coords = torch.linspace(-1, 1, steps=32) # Example non-uniform x-coordinates
y_coords = torch.linspace(1, -1, steps=32) # Example non-uniform y-coordinates
with self.assertRaises(ValueError):
BilinearInterpolator(x_coords, y_coords, [0], [0])
BilinearInterpolator(x_coords, y_coords, torch.tensor([0]), torch.tensor([0]))

def test_interpolation_func(self):
# Setup
Expand All @@ -170,6 +171,7 @@ def func(x, y):

# Execute
interpolator = BilinearInterpolator(x_coords, y_coords, x_query, y_query)
interpolator.to(input_tensor)
result = interpolator(input_tensor)

# Verify
Expand All @@ -178,3 +180,18 @@ def func(x, y):
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
interpolator.cuda()
interpolator(input_tensor.cuda())


def test_out_of_bounds():
x_coords = torch.tensor([0, 1, 2]).float()
y_coords = torch.tensor([0, 1, 2]).float()

x_query = torch.tensor([-1, 3]).float()
y_query = torch.tensor([-1, 3]).float()
regrid = BilinearInterpolator(x_coords, y_coords, x_query, y_query)

data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).float()
regrid.to(data)
output = regrid(data)

assert torch.all(torch.isnan(output))

0 comments on commit e016620

Please sign in to comment.