diff --git a/earth2grid/healpix.py b/earth2grid/healpix.py index 8d80737..44a86c6 100644 --- a/earth2grid/healpix.py +++ b/earth2grid/healpix.py @@ -33,6 +33,7 @@ """ import math +import warnings from dataclasses import dataclass from enum import Enum from typing import Union @@ -48,8 +49,17 @@ except ImportError: pv = None -from earth2grid import base, healpixpad -from earth2grid.third_party.zephyr.healpix import healpix_pad +try: + import healpixpad_cuda + + healpixpad_cuda_avail = True +except ImportError: + healpixpad_cuda_avail = False + warnings.warn("healpixpad_cuda module not available, reverting to CPU for all padding routines") + + +from earth2grid import base +from earth2grid.third_party.zephyr.healpix import healpix_pad as heapixpad_cpu __all__ = ["pad", "PixelOrder", "XY", "Compass", "Grid", "HEALPIX_PAD_XY", "conv2d"] @@ -59,7 +69,7 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor: Pad each face consistently with its according neighbors in the HEALPix Args: - x: The input tensor of shape [N, F, H, W] + x: The input tensor of shape [N, F, H, W] or [N, F, C, H, W] padding: the amount of padding Returns: @@ -80,10 +90,12 @@ def pad(x: torch.Tensor, padding: int) -> torch.Tensor: torch.Size([1, 12, 18, 18]) """ - if x.device.type != 'cuda': - return healpix_pad(x, padding) + if x.device.type != 'cuda' or not healpixpad_cuda_avail: + return heapixpad_cpu(x, padding) + elif x.ndim == 5: + return HEALPixPadFunction.apply(x, padding) else: - return healpixpad.HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2) + return HEALPixPadFunction.apply(x.unsqueeze(2), padding).squeeze(2) class PixelOrder(Enum): @@ -293,6 +305,69 @@ def to_image(self, x: torch.Tensor, fill_value=torch.nan) -> torch.Tensor: return output +class HEALPixPadFunction(torch.autograd.Function): + """ + A torch autograd class that pads a healpixpad xy tensor + """ + + @staticmethod + def forward(ctx, input, pad): + """ + The forward pass of the padding class + + Parameters + ---------- + input: torch.tensor + The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format + where F == 12 and H == W + pad: int + The amount to pad each face of the tensor + + Returns + ------- + torch.tensor: The padded tensor + """ + ctx.pad = pad + if input.ndim != 5: + raise ValueError( + f"Input tensor must be have 5 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead" + ) + if input.shape[1] != 12: + raise ValueError( + f"Input tensor must be have 5 dimensions (B, F, C, H, W), with F == 12, got {input.shape[1]}" + ) + if input.shape[3] != input.shape[4]: + raise ValueError( + f"Input tensor must be have 5 dimensions (B, F, C, H, W), with H == @, got {input.shape[3]}, {input.shape[4]}" + ) + # make contiguous + input = input.contiguous() + out = healpixpad_cuda.forward(input, pad)[0] + return out + + @staticmethod + def backward(ctx, grad): + """ + The forward pass of the padding class + + Parameters + ---------- + input: torch.tensor + The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format + where F == 12 and H == W + pad: int + The amount to pad each face of the tensor + + Returns + ------- + torch.tensor: The padded tensor + """ + pad = ctx.pad + grad = grad.contiguous() + out = healpixpad_cuda.backward(grad, pad)[0] + return out, None + + # nside = 2^ZOOM_LEVELS ZOOM_LEVELS = 20 diff --git a/earth2grid/healpixpad.py b/earth2grid/healpixpad.py deleted file mode 100644 index 60df9b7..0000000 --- a/earth2grid/healpixpad.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License -# -# Written by Mauro Bisson and Thorsten Kurth . - - -import healpixpad_cuda -import torch - - -class HEALPixPadFunction(torch.autograd.Function): - """ - A torch autograd class that pads a healpixpad xy tensor - """ - - @staticmethod - def forward(ctx, input, pad): - """ - The forward pass of the padding class - - Parameters - ---------- - input: torch.tensor - The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format - where F == 12 and H == W - pad: int - The amount to pad each face of the tensor - - Returns - ------- - torch.tensor: The padded tensor - """ - ctx.pad = pad - if len(input.shape) != 5: - raise ValueError( - f"Input tensor must be have 4 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead" - ) - # make contiguous - input = input.contiguous() - out = healpixpad_cuda.forward(input, pad)[0] - return out - - @staticmethod - def backward(ctx, grad): - pad = ctx.pad - grad = grad.contiguous() - out = healpixpad_cuda.backward(grad, pad)[0] - return out, None - - -class HEALPixPad(torch.nn.Module): - """ - A torch module that handles padding of healpixpad xy tensors - - Paramaeters - ----------- - padding: int - The amount to pad the tensors - """ - - def __init__(self, padding: int): - super(HEALPixPad, self).__init__() - self.padding = padding - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - The forward pass of the padding class - - Parameters - ---------- - input: torch.tensor - The tensor to pad, must have 5 dimensions and be in (B, F, C, H, W) format - where F == 12 and H == W - - Returns - ------- - torch.tensor: The padded tensor - """ - return HEALPixPadFunction.apply(input, self.padding) diff --git a/earth2grid/third_party/zephyr/healpix.py b/earth2grid/third_party/zephyr/healpix.py index fe70810..fdcd739 100644 --- a/earth2grid/third_party/zephyr/healpix.py +++ b/earth2grid/third_party/zephyr/healpix.py @@ -23,21 +23,10 @@ """ -import sys import torch import torch as th -sys.path.append('/home/disk/quicksilver/nacc/dlesm/HealPixPad') -have_healpixpad = False -try: - from healpixpad import HEALPixPad # noqa - - have_healpixpad = True -except ImportError: - print("Warning, cannot find healpixpad module") - have_healpixpad = False - def healpix_pad(x: torch.Tensor, padding: int, enable_nhwc: bool = False) -> torch.Tensor: """ diff --git a/setup.py b/setup.py index 9bcfce0..dc5d496 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ # limitations under the License. import os import subprocess +import warnings from typing import List from setuptools import setup @@ -59,20 +60,32 @@ def get_compiler(): "earth2grid/csrc/healpixpad/healpixpad_cuda_fwd.cu", "earth2grid/csrc/healpixpad/healpixpad_cuda_bwd.cu", ] -setup( - name='earth2grid', - ext_modules=[ - cpp_extension.CppExtension( - 'earth2grid._healpix_bare', - src_files, - extra_compile_args=extra_compile_args, - include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")], - ), - cpp_extension.CUDAExtension( + +ext_modules = [ + cpp_extension.CppExtension( + 'earth2grid._healpix_bare', + src_files, + extra_compile_args=extra_compile_args, + include_dirs=[os.path.abspath("earth2grid/csrc"), os.path.abspath("earth2grid/third_party/healpix_bare")], + ), +] + +try: + from torch.utils.cpp_extension import CUDAExtension + + ext_modules.append( + CUDAExtension( name='healpixpad_cuda', sources=cuda_src_files, extra_compile_args={'nvcc': ['-O2']}, ), - ], + ) +except ImportError: + warnings.warn("Cuda extensions for torch not found, skipping cuda healpix padding module") + + +setup( + name='earth2grid', + ext_modules=ext_modules, cmdclass={'build_ext': cpp_extension.BuildExtension}, )