Visualization utilities #13

Open
nbren12 opened this issue Sep 11, 2024 · 0 comments

nbren12 (Collaborator) commented Sep 11, 2024

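earth2grid can be used to accelerate plotting workflows: for example, a HEALPix field can be regridded onto a regular grid in a Cartopy projection and drawn with `pcolormesh`. It might be worth adding some visualization utilities.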
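Part of why regridding helps with plotting is likely that the expensive geometry (finding each target point's neighbours and interpolation weights) depends only on the two grids, so it can be computed once and reused for every field drawn on the same map. The sketch below illustrates that precompute-then-apply pattern in NumPy. It assumes a regular lat-lon source grid (evenly spaced, periodic in longitude) instead of HEALPix, and `precompute_bilinear` / `apply_bilinear` are hypothetical names; this is not earth2grid's implementation or API.

```python
import numpy as np


def precompute_bilinear(src_lat, src_lon, tgt_lat, tgt_lon):
    """Precompute bilinear neighbour indices and weights (hypothetical helper).

    Assumes src_lat is 1-D ascending (degrees) and src_lon is 1-D, evenly
    spaced, covering [src_lon[0], src_lon[0] + 360) so it wraps periodically.
    tgt_lat, tgt_lon: 1-D arrays of target points in degrees.
    Returns idx (4, n) of flat source indices and w (4, n) of weights.
    """
    nlat, nlon = len(src_lat), len(src_lon)
    dlon = src_lon[1] - src_lon[0]

    # Latitude: find the bracketing rows, clamping targets to the grid's range.
    tgt_lat = np.clip(tgt_lat, src_lat[0], src_lat[-1])
    i0 = np.clip(np.searchsorted(src_lat, tgt_lat) - 1, 0, nlat - 2)
    i1 = i0 + 1
    wy = (tgt_lat - src_lat[i0]) / (src_lat[i1] - src_lat[i0])

    # Longitude: periodic, so the right-hand neighbour wraps past 360 degrees.
    pos = np.mod(tgt_lon - src_lon[0], 360.0) / dlon
    j0 = np.floor(pos).astype(int) % nlon
    j1 = (j0 + 1) % nlon
    wx = pos - np.floor(pos)

    idx = np.stack([i0 * nlon + j0, i0 * nlon + j1, i1 * nlon + j0, i1 * nlon + j1])
    w = np.stack([(1 - wy) * (1 - wx), (1 - wy) * wx, wy * (1 - wx), wy * wx])
    return idx, w


def apply_bilinear(field, idx, w):
    """Interpolate field of shape (..., nlat, nlon) at the precomputed points."""
    flat = field.reshape(*field.shape[:-2], -1)
    # Gather the 4 neighbours of each target point, then take the weighted sum.
    return (flat[..., idx] * w).sum(axis=-2)
```

After the one-off setup, each new frame costs one gather plus a weighted sum over four neighbours, and leading axes such as time are handled by broadcasting. The `regrid` object in the prototype is presumably meant to be reused the same way (it is built once and moved to the data's device with `.to(...)`), though whether earth2grid stores indices and weights exactly like this is an assumption.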

```python
import cartopy.crs
import matplotlib.pyplot as plt
import numpy as np
import torch
from earth2grid import healpix


def create_regular_grid_in_projection(projection, nx, ny):
    """
    Create a regular grid in a Cartopy projection and compute its lat-lon coordinates.

    Parameters:
    projection (cartopy.crs.Projection): The desired Cartopy projection
    nx (int): Number of grid points along the projection's x axis
    ny (int): Number of grid points along the projection's y axis

    Returns:
    tuple: (lats, lons, xx, yy), each a 2D array of shape (ny, nx). lats and
    lons are in degrees and are NaN outside the projection's valid domain;
    xx and yy are the projection coordinates of the grid.
    """
    # Get the projection's limits
    x_min, x_max, y_min, y_max = projection.x_limits + projection.y_limits

    # Create a regular grid in the projection coordinates
    x = np.linspace(x_min, x_max, nx)
    y = np.linspace(y_min, y_max, ny)
    xx, yy = np.meshgrid(x, y)

    # Transform the gridded coordinates back to lat-lon
    geodetic = cartopy.crs.Geodetic()
    transformed = geodetic.transform_points(projection, xx, yy)

    lons = transformed[..., 0]
    lats = transformed[..., 1]

    # Filter out invalid points (those outside the projection's valid domain)
    valid = np.isfinite(lons) & np.isfinite(lats)
    lons[~valid] = np.nan
    lats[~valid] = np.nan

    return lats, lons, xx, yy


def visualize(x):
    """Plot a HEALPix field x (torch.Tensor of shape (npix,)) on a Robinson map."""
    hpx = healpix.Grid(healpix.npix2level(x.shape[-1]))
    crs = cartopy.crs.Robinson()
    # Robinson is roughly twice as wide as it is tall, so use nx > ny
    lat, lon, xx, yy = create_regular_grid_in_projection(crs, nx=512, ny=256)
    mask = ~np.isnan(lat)

    # Only interpolate at points inside the map outline
    regrid = hpx.get_bilinear_regridder_to(lat[mask], lon[mask])
    regrid = regrid.to(x)  # match x's device and dtype

    # Points outside the map outline stay NaN, so pcolormesh leaves them blank
    out = torch.full(lat.shape, float("nan"), dtype=x.dtype, device=x.device)
    out[torch.as_tensor(mask, device=x.device)] = regrid(x)

    ax = plt.subplot(projection=crs)
    im = ax.pcolormesh(xx, yy, out.cpu().numpy(), transform=crs)
    ax.coastlines()
    plt.colorbar(im, orientation="horizontal")
    return ax
```
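`visualize` rebuilds the target grid and regridder on every call and handles one field at a time. When several fields share a map (e.g. successive time steps), the grid and regridder could be built once and reused; if the regridder returns an array of shape (..., n_valid) for batched input (an assumption worth checking against earth2grid), the NaN-fill step generalizes as in this NumPy sketch. `fill_masked` is a hypothetical helper name, not part of earth2grid.

```python
import numpy as np


def fill_masked(values, mask, fill=np.nan):
    """Scatter per-valid-point values back onto a full 2-D grid (hypothetical helper).

    values: array of shape (..., n_valid), ordered like mask's True entries
        in row-major order (the order that lat[mask] produces).
    mask: boolean array of shape (ny, nx).
    Returns an array of shape (..., ny, nx) holding `fill` where mask is False.
    """
    values = np.asarray(values)
    # Keep floating dtypes such as float32; promote integers so NaN fits.
    dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    out = np.full(values.shape[:-1] + mask.shape, fill, dtype=dtype)
    # Boolean indexing over the last two axes visits points in the same order.
    out[..., mask] = values
    return out
```

Each slice along the leading axes is then an (ny, nx) image that could be passed to `ax.pcolormesh(xx, yy, ...)` exactly as `visualize` does.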