Skip to content

Commit

Permalink
BaseCartesianDatastore -> BaseRegularGridDatastore
Browse files Browse the repository at this point in the history
Rename base class for datastores representating data on a regular grid. Also introduce DummyDatastore in tests that represent data on an irregular grid
  • Loading branch information
leifdenby committed Oct 3, 2024
1 parent e0300fb commit d1b4ca7
Show file tree
Hide file tree
Showing 11 changed files with 557 additions and 91 deletions.
12 changes: 9 additions & 3 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Local
from .config import load_config_and_datastore
from .datastore.base import BaseCartesianDatastore
from .datastore.base import BaseRegularGridDatastore


def plot_graph(graph, title=None):
Expand Down Expand Up @@ -532,13 +532,19 @@ def create_graph(


def create_graph_from_datastore(
datastore: BaseCartesianDatastore,
datastore: BaseRegularGridDatastore,
output_root_path: str,
n_max_levels: int = None,
hierarchical: bool = False,
create_plot: bool = False,
):
xy = datastore.get_xy(category="state", stacked=False)
if isinstance(datastore, BaseRegularGridDatastore):
xy = datastore.get_xy(category="state", stacked=False)
else:
raise NotImplementedError(
"Only graph creation for BaseRegularGridDatastore is supported"
)

create_graph(
graph_dir_path=output_root_path,
xy=xy,
Expand Down
159 changes: 109 additions & 50 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class BaseDatastore(abc.ABC):
"""
Base class for weather data used in the neural- lam package. A datastore
Base class for weather data used in the neural-lam package. A datastore
defines the interface for accessing weather data by providing methods to
access the data in a processed format that can be used for training and
evaluation of neural networks.
Expand All @@ -37,6 +37,13 @@ class BaseDatastore(abc.ABC):
attribute should be set to True, and returned data from `get_dataarray` is
assumed to have an `ensemble_member` dimension.
# Grid index
All methods that return data specific to a grid point (like
`get_dataarray`) should have a single dimension named `grid_index` that
represents the spatial grid index of the data. The actual x, y coordinates
of the grid points should be stored in the `x` and `y` coordinates of the
dataarray or dataset with the `grid_index` dimension as the coordinate for
each of the `x` and `y` coordinates.
"""

is_ensemble: bool = False
Expand Down Expand Up @@ -237,34 +244,22 @@ def boundary_mask(self) -> xr.DataArray:
"""
pass

@abc.abstractmethod
def get_xy(self, category: str) -> np.ndarray:
"""
Return the x, y coordinates of the dataset as a numpy arrays for a
given category of data.
@dataclasses.dataclass
class CartesianGridShape:
"""Dataclass to store the shape of a grid."""

x: int
y: int


class BaseCartesianDatastore(BaseDatastore):
"""
Base class for weather data stored on a Cartesian grid. In addition to the
methods and attributes required for weather data in general (see
`BaseDatastore`) for Cartesian gridded source data each `grid_index`
coordinate value is assume to have an associated `x` and `y`-value so that
the processed data-arrays can be reshaped back into into 2D xy-gridded
arrays.
In addition the following attributes and methods are required:
- `coords_projection` (property): Projection object for the coordinates.
- `grid_shape_state` (property): Shape of the grid for the state variables.
- `get_xy_extent` (method): Return the extent of the x, y coordinates for a
given category of data.
- `get_xy` (method): Return the x, y coordinates of the dataset.
"""
Parameters
----------
category : str
The category of the dataset (state/forcing/static).
CARTESIAN_COORDS = ["x", "y"]
Returns
-------
np.ndarray
The x, y coordinates of the dataset with shape `[2, n_grid_points]`.
"""

@property
@abc.abstractmethod
Expand All @@ -281,6 +276,78 @@ def coords_projection(self) -> ccrs.Projection:
"""
pass

@functools.lru_cache
def get_xy_extent(self, category: str) -> List[float]:
"""
Return the extent of the x, y coordinates for a given category of data.
The extent should be returned as a list of 4 floats with `[xmin, xmax,
ymin, ymax]` which can then be used to set the extent of a plot.
Parameters
----------
category : str
The category of the dataset (state/forcing/static).
Returns
-------
List[float]
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
return [float(v) for v in extent]

@property
@abc.abstractmethod
def num_grid_points(self) -> int:
"""Return the number of grid points in the dataset.
Returns
-------
int
The number of grid points in the dataset.
"""
pass


@dataclasses.dataclass
class CartesianGridShape:
"""Dataclass to store the shape of a grid."""

x: int
y: int


class BaseRegularGridDatastore(BaseDatastore):
"""
Base class for weather data stored on a regular grid (like a chess-board,
as opposed to a irregular grid where each cell cannot be indexed by just
two integers, see https://en.wikipedia.org/wiki/Regular_grid). In addition
to the methods and attributes required for weather data in general (see
`BaseDatastore`) for regular-gridded source data each `grid_index`
coordinate value is assumed to be associated with `x` and `y`-values that
allow the processed data-arrays can be reshaped back into into 2D
xy-gridded arrays.
The following methods and attributes must be implemented for datastore that
represents regular-gridded data:
- `grid_shape_state` (property): 2D shape of the grid for the state
variables.
- `get_xy` (method): Return the x, y coordinates of the dataset, with the
option to not stack the coordinates (so that they are returned as a 2D
grid).
The operation of going from (x,y)-indexed regular grid
to `grid_index`-indexed data-array is called "stacking" and the reverse
operation is called "unstacking". This class provides methods to stack and
unstack the spatial grid coordinates of the data-arrays (called
`stack_grid_coords` and `unstack_grid_coords` respectively).
"""

CARTESIAN_COORDS = ["x", "y"]

@property
@abc.abstractmethod
def grid_shape_state(self) -> CartesianGridShape:
Expand Down Expand Up @@ -314,32 +381,11 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
- `stacked==True`: shape `(2, n_grid_points)` where
n_grid_points=N_x*N_y.
- `stacked==False`: shape `(2, N_y, N_x)`
The values for the x-coordinates are in the first row and the
values for the y-coordinates are in the second row.
"""
pass

@functools.lru_cache
def get_xy_extent(self, category: str) -> List[float]:
"""
Return the extent of the x, y coordinates for a given category of data.
The extent should be returned as a list of 4 floats with `[xmin, xmax,
ymin, ymax]` which can then be used to set the extent of a plot.
Parameters
----------
category : str
The category of the dataset (state/forcing/static).
Returns
-------
List[float]
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
return [float(v) for v in extent]

def unstack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
) -> Union[xr.DataArray, xr.Dataset]:
Expand Down Expand Up @@ -394,3 +440,16 @@ def stack_grid_coords(
"""
return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)

@property
@functools.lru_cache
def num_grid_points(self) -> int:
"""Return the number of grid points in the dataset.
Returns
-------
int
The number of grid points in the dataset.
"""
return self.grid_shape_state.x * self.grid_shape_state.y
4 changes: 2 additions & 2 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from numpy import ndarray

# Local
from .base import BaseCartesianDatastore, CartesianGridShape
from .base import BaseRegularGridDatastore, CartesianGridShape


class MDPDatastore(BaseCartesianDatastore):
class MDPDatastore(BaseRegularGridDatastore):
"""
Datastore class for datasets made with the mllam_data_prep library
(https://github.com/mllam/mllam-data-prep). This class wraps the
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from xarray.core.dataarray import DataArray

# Local
from ..base import BaseCartesianDatastore, CartesianGridShape
from ..base import BaseRegularGridDatastore, CartesianGridShape
from .config import NpyDatastoreConfig

STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
Expand All @@ -37,7 +37,7 @@ def _load_np(fp, add_feature_dim):
return arr


class NpyFilesDatastoreMEPS(BaseCartesianDatastore):
class NpyFilesDatastoreMEPS(BaseRegularGridDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
to be stored in a directory structure where each sample is stored in a
Expand Down
8 changes: 4 additions & 4 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

# Local
from . import utils
from .datastore.base import BaseCartesianDatastore
from .datastore.base import BaseRegularGridDatastore


@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_map(errors, datastore: BaseCartesianDatastore, title=None):
def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
Expand Down Expand Up @@ -67,7 +67,7 @@ def plot_error_map(errors, datastore: BaseCartesianDatastore, title=None):
def plot_prediction(
pred,
target,
datastore: BaseCartesianDatastore,
datastore: BaseRegularGridDatastore,
title=None,
vrange=None,
):
Expand Down Expand Up @@ -132,7 +132,7 @@ def plot_prediction(

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_spatial_error(
error, datastore: BaseCartesianDatastore, title=None, vrange=None
error, datastore: BaseRegularGridDatastore, title=None, vrange=None
):
"""
Plot errors over spatial map
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# Third-party
import pooch
import yaml
from dummy_datastore import DummyDatastore

# First-party
from neural_lam.datastore import init_datastore
from neural_lam.datastore import DATASTORES, init_datastore

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -61,8 +62,11 @@ def download_meps_example_reduced_dataset():
DATASTORES_EXAMPLES = dict(
mdp=(DATASTORE_EXAMPLES_ROOT_PATH / "mdp" / "danra.example.yaml"),
npyfilesmeps=download_meps_example_reduced_dataset(),
dummydata=None,
)

DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore


def init_datastore_example(datastore_kind):
datastore = init_datastore(
Expand Down
Loading

0 comments on commit d1b4ca7

Please sign in to comment.