Skip to content

Commit

Permalink
Merge pull request #9 from NVlabs/fix-lat-long-bug
Browse files Browse the repository at this point in the history
Fix gpu -> tensor bug
  • Loading branch information
nbren12 authored Aug 21, 2024
2 parents 3abe3ba + 3edcc6e commit 2e7af1b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
15 changes: 7 additions & 8 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _npix(self):

def _nest_ipix(self):
"""convert to nested index number"""
i = torch.arange(self._npix())
i = torch.arange(self._npix(), device="cpu")
if isinstance(self.pixel_order, XY):
i_xy = _convert_xyindex(nside=self._nside(), src=self.pixel_order, dest=XY(), i=i)
i = xy2nest(self._nside(), i_xy)
Expand All @@ -195,29 +195,28 @@ def _nest_ipix(self):
pass
else:
raise ValueError(self.pixel_order)
return i.numpy()
return i

def _nest2me(self, ipix: np.ndarray) -> np.ndarray:
def _nest2me(self, ipix: torch.Tensor) -> torch.Tensor:
"""return the index in my PIXELORDER corresponding to ipix in NEST ordering"""
if isinstance(self.pixel_order, XY):
i_xy = nest2xy(self._nside(), ipix)
i_me = _convert_xyindex(nside=self._nside(), src=XY(), dest=self.pixel_order, i=i_xy)
elif self.pixel_order == PixelOrder.RING:
ipix_t = torch.from_numpy(ipix)
i_me = healpix_bare.nest2ring(self._nside(), ipix_t).numpy()
i_me = healpix_bare.nest2ring(self._nside(), ipix)
elif self.pixel_order == PixelOrder.NEST:
i_me = ipix
return i_me

@property
def lat(self):
ipix = torch.from_numpy(self._nest_ipix())
ipix = self._nest_ipix()
_, lat = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True)
return lat.numpy()

@property
def lon(self):
ipix = torch.from_numpy(self._nest_ipix())
ipix = self._nest_ipix()
lon, _ = healpix_bare.pix2ang(self._nside(), ipix, lonlat=True, nest=True)
return lon.numpy()

Expand Down Expand Up @@ -256,7 +255,7 @@ def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray):
lat, lon = np.broadcast_arrays(lat, lon)
i_ring, weights = healpix_bare.get_interp_weights(self._nside(), torch.tensor(lon), torch.tensor(lat))
i_nest = healpix_bare.ring2nest(self._nside(), i_ring.ravel())
i_me = torch.from_numpy(self._nest2me(i_nest.numpy())).view(i_ring.shape)
i_me = self._nest2me(i_nest).reshape(i_ring.shape)
return ApplyWeights(i_me, weights)

def approximate_grid_length_meters(self):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,18 @@ def test_conv2d():
weight = torch.zeros(cout, cin, 3, 3)
out = healpix.conv2d(x, weight, padding=(1, 1))
assert out.shape == (n, cout, 1, npix)


def test_latlon_cuda_set_device_regression():
"""See https://github.com/NVlabs/earth2grid/issues/6"""

if torch.cuda.device_count() == 0:
pytest.skip()

default = torch.get_default_device()
try:
torch.set_default_device("cuda")
grid = healpix.Grid(4)
grid.lat
finally:
torch.set_default_device(default)

0 comments on commit 2e7af1b

Please sign in to comment.