Skip to content

Commit

Permalink
Move cuda padding routines to healpix.py, check for cuda before attem…
Browse files Browse the repository at this point in the history
…pting cuda install
  • Loading branch information
daviddpruitt committed Aug 23, 2024
1 parent f5ad041 commit 271b5a1
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 119 deletions.
87 changes: 81 additions & 6 deletions earth2grid/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"""

import math
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Union
Expand All @@ -48,8 +49,17 @@
except ImportError:
pv = None

from earth2grid import base, healpixpad
from earth2grid.third_party.zephyr.healpix import healpix_pad
try:
import healpixpad_cuda

healpixpad_cuda_avail = True
except ImportError:
healpixpad_cuda_avail = False
warnings.warn("healpixpad_cuda module not available, reverting to CPU for all padding routines")


from earth2grid import base
from earth2grid.third_party.zephyr.healpix import healpix_pad as heapixpad_cpu

__all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"]

Expand All @@ -59,7 +69,7 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
Pad each face consistently with its according neighbors in the HEALPix
Args:
x: The input tensor of shape [N, F, H, W]
x: The input tensor of shape [N, F, H, W] or [N, F, C, H, W]
padding: the amount of padding
Returns:
Expand All @@ -80,10 +90,12 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor:
torch.Size([1, 12, 18, 18])
"""
if x.device.type != 'cuda':
return healpix_pad(x, padding)
if x.device.type != 'cuda' or not healpixpad_cuda_avail:
return heapixpad_cpu(x, padding)
elif x.ndim == 5:
return HEALPixPadFunction.apply(x, padding)
else:
return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)
return HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2)


class PixelOrder(Enum):
Expand Down Expand Up @@ -293,6 +305,69 @@ def to_image(self, x: torch.Tensor, fill_value=torch.nan) -> torch.Tensor:
return output


class HEALPixPadFunction(torch.autograd.Function):
"""
A torch autograd class that pads a healpixpad xy tensor
"""

@staticmethod
def forward(ctx, input, pad):
"""
The forward pass of the padding class
Parameters
----------
input: torch.tensor
The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
where F == 12 and H == W
pad: int
The amount to pad each face of the tensor
Returns
-------
torch.tensor: The padded tensor
"""
ctx.pad = pad
if input.ndim != 5:
raise ValueError(
f"Input tensor must be have 5 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead"
)
if input.shape[1] != 12:
raise ValueError(
f"Input tensor must be have 5 dimensions (B, F, C, H, W), with F == 12, got {input.shape[1]}"
)
if input.shape[3] != input.shape[4]:
raise ValueError(
f"Input tensor must be have 5 dimensions (B, F, C, H, W), with H == @, got {input.shape[3]}, {input.shape[4]}"
)
# make contiguous
input = input.contiguous()
out = healpixpad_cuda.forward(input, pad)[0]
return out

@staticmethod
def backward(ctx, grad):
"""
The forward pass of the padding class
Parameters
----------
input: torch.tensor
The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format
where F == 12 and H == W
pad: int
The amount to pad each face of the tensor
Returns
-------
torch.tensor: The padded tensor
"""
pad = ctx.pad
grad = grad.contiguous()
out = healpixpad_cuda.backward(grad, pad)[0]
return out, None


# nside = 2^ZOOM_LEVELS
ZOOM_LEVELS = 20

Expand Down
91 changes: 0 additions & 91 deletions earth2grid/healpixpad.py

This file was deleted.

11 changes: 0 additions & 11 deletions earth2grid/third_party/zephyr/healpix.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,10 @@
"""

import sys

import torch
import torch as th

sys.path.append('/home/disk/quicksilver/nacc/dlesm/HealPixPad')
have_healpixpad = False
try:
from healpixpad import HEALPixPad # noqa

have_healpixpad = True
except ImportError:
print("Warning, cannot find healpixpad module")
have_healpixpad = False


def healpix_pad(x: torch.Tensor, padding: int, enable_nhwc: bool = False) -> torch.Tensor:
"""
Expand Down
35 changes: 24 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import os
import subprocess
import warnings
from typing import List

from setuptools import setup
Expand Down Expand Up @@ -59,20 +60,32 @@ def get_compiler():
"earth2grid/csrc/healpixpad/healpixpad_cuda_fwd.cu",
"earth2grid/csrc/healpixpad/healpixpad_cuda_bwd.cu",
]
setup(
name='earth2grid',
ext_modules=[
cpp_extension.CppExtension(
'earth2grid._healpix_bare',
src_files,
extra_compile_args=extra_compile_args,
include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")],
),
cpp_extension.CUDAExtension(

ext_modules = [
cpp_extension.CppExtension(
'earth2grid._healpix_bare',
src_files,
extra_compile_args=extra_compile_args,
include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")],
),
]

try:
from torch.utils.cpp_extension import CUDAExtension

ext_modules.append(
CUDAExtension(
name='healpixpad_cuda',
sources=cuda_src_files,
extra_compile_args={'nvcc': ['-O2']},
),
],
)
except ImportError:
warnings.warn("Cuda extensions for torch not found, skipping cuda healpix padding module")


setup(
name='earth2grid',
ext_modules=ext_modules,
cmdclass={'build_ext': cpp_extension.BuildExtension},
)

0 comments on commit 271b5a1

Please sign in to comment.