npyfiles datastore complete

leifdenby committed Jul 26, 2024
1 parent d1b6fc1 commit 6fe19ac
Showing 13 changed files with 172 additions and 265 deletions.
8 changes: 5 additions & 3 deletions neural_lam/create_graph.py
@@ -228,6 +228,8 @@ def create_graph(
"""
os.makedirs(graph_dir_path, exist_ok=True)

print(f"Writing graph components to {graph_dir_path}")

grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))

@@ -562,7 +564,7 @@ def cli(input_args=None):
help="kind of data store to use (default: multizarr)",
)
parser.add_argument(
"datastore-path",
"datastore_path",
type=str,
help="path to the data store",
)
@@ -594,11 +596,11 @@ def cli(input_args=None):
args = parser.parse_args(input_args)

DatastoreClass = DATASTORES[args.datastore]
datastore = DatastoreClass(args.datastore_path)
datastore = DatastoreClass(root_path=args.datastore_path)

create_graph_from_datastore(
datastore=datastore,
output_root_path=os.path.join(datastore.root_path, "graphs", args.name),
output_root_path=os.path.join(datastore.root_path, "graph", args.name),
n_max_levels=args.levels,
hierarchical=args.hierarchical,
create_plot=args.plot,
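For orientation, a minimal sketch of how the renamed positional argument and the root_path-based datastore construction fit together; the "npyfiles" datastore key, the example path and the --name flag are assumptions not confirmed by this diff:

    # Hypothetical invocation of the updated CLI (illustrative values only).
    from neural_lam.create_graph import cli

    cli(["--datastore", "npyfiles", "--name", "multiscale", "/data/example_datastore"])
    # which roughly resolves to:
    #   datastore = DATASTORES["npyfiles"](root_path="/data/example_datastore")
    #   create_graph_from_datastore(
    #       datastore=datastore,
    #       output_root_path=os.path.join(datastore.root_path, "graph", "multiscale"),
    #       ...
    #   )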
3 changes: 2 additions & 1 deletion neural_lam/datastore/base.py
@@ -266,7 +266,8 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
return [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
return [float(v) for v in extent]

def unstack_grid_coords(
self, da_or_ds: Union[xr.DataArray, xr.Dataset]
8 changes: 5 additions & 3 deletions neural_lam/datastore/mllam.py
@@ -15,7 +15,7 @@
class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""

def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
def __init__(self, root_path, n_boundary_points=30, reuse_existing=True):
"""Construct a new MLLAMDatastore from the configuration file at
`config_path`. A boundary mask is created with `n_boundary_points`
boundary points. If `reuse_existing` is True, the dataset is loaded
@@ -33,7 +33,9 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
reuse_existing : bool
Whether to reuse an existing dataset zarr file if it exists.
"""
self._config_path = Path(config_path)
config_filename = "data_config.yaml"
self._root_path = Path(root_path)
config_path = self._root_path / config_filename
self._config = mdp.Config.from_yaml_file(config_path)
fp_ds = self._config_path.parent / self._config_path.name.replace(
".yaml", ".zarr"
@@ -48,7 +50,7 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):

@property
def root_path(self) -> Path:
return Path(self._config_path.parent)
return self._root_path

@property
def step_length(self) -> int:
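A minimal usage sketch of the changed constructor, assuming a layout where data_config.yaml sits directly under the datastore root (the path is illustrative):

    from neural_lam.datastore.mllam import MLLAMDatastore

    # The datastore now takes the directory containing data_config.yaml,
    # not the config file itself; root_path is then reused elsewhere, e.g.
    # create_graph.py writes graphs under <root_path>/graph/<name>.
    datastore = MLLAMDatastore(root_path="/data/danra_example")
    print(datastore.root_path)  # -> /data/danra_example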
10 changes: 8 additions & 2 deletions neural_lam/datastore/multizarr/store.py
@@ -1,6 +1,7 @@
# Standard library
import functools
import os
from pathlib import Path

# Third-party
import cartopy.crs as ccrs
@@ -16,11 +17,16 @@
class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable"}

def __init__(self, config_path):
self.config_path = config_path
def __init__(self, root_path):
self._root_path = Path(root_path)
config_path = self._root_path / "data_config.yaml"
with open(config_path, encoding="utf-8", mode="r") as file:
self._config = yaml.safe_load(file)

@property
def root_path(self):
return self._root_path

def _normalize_path(self, path):
# try to parse path to see if it defines a protocol, e.g. s3://
if "://" in path or path.startswith("/"):
101 changes: 55 additions & 46 deletions neural_lam/datastore/npyfiles/config.py
@@ -1,10 +1,61 @@
# Standard library
import functools
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Dict, List

# Third-party
import cartopy.crs as ccrs
import yaml
import dataclass_wizard


@dataclass
class Projection:
"""Represents the projection information for a dataset, including the type
of projection and its parameters. Capable of creating a cartopy.crs
projection object.
Attributes:
class_name: The class name of the projection, this should be a valid
cartopy.crs class.
kwargs: A dictionary of keyword arguments specific to the projection type.
"""

class_name: str # = field(metadata={'data_key': 'class'})
kwargs: Dict[str, Any]


@dataclass
class Dataset:
"""Contains information about the dataset, including variable names, units,
and descriptions.
Attributes:
name: The name of the dataset.
var_names: A list of variable names in the dataset.
var_units: A list of units for each variable.
var_longnames: A list of long, descriptive names for each variable.
num_forcing_features: The number of forcing features in the dataset.
"""

name: str
var_names: List[str]
var_units: List[str]
var_longnames: List[str]
num_forcing_features: int


@dataclass
class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
"""Configuration for loading and processing a dataset, including dataset
details, grid shape, and projection information.
Attributes:
dataset: An instance of Dataset containing details about the dataset.
grid_shape_state: A list representing the shape of the grid state.
projection: An instance of Projection containing projection details.
"""

dataset: Dataset
grid_shape_state: List[int]
projection: Projection


class NpyConfig:
@@ -14,48 +65,6 @@ class NpyConfig:
its values as attributes.
"""

def __init__(self, values):
self.values = values

@classmethod
def from_file(cls, filepath):
"""Load a configuration file."""
if str(filepath).endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
else:
raise NotImplementedError(Path(filepath).suffix)

def __getattr__(self, name):
child, *children = name.split(".")

value = self.values[child]
if len(children) > 0:
return self.__class__(values=value).get(".".join(children))
else:
if isinstance(value, dict):
return self.__class__(values=value)
else:
return value

def __getitem__(self, key):
value = self.values[key]
if isinstance(value, dict):
return self.__class__(values=value)
return value

def __contains__(self, key):
return key in self.values

def num_data_vars(self):
"""Return the number of data variables for a given key."""
return len(self.dataset.var_names)

@functools.cached_property
def coords_projection(self):
"""Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
proj_params = proj_config.get("kwargs", {})
return proj_class(**proj_params)
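The dataclasses above imply a data_config.yaml layout roughly like the one sketched in the comments below; the field values and the spelling of the class_name key are assumptions (the old code read a "class" key), and only from_yaml_file is shown in this commit:

    # Sketch of the expected data_config.yaml (values are illustrative):
    #
    #   dataset:
    #     name: meps_example
    #     var_names: [t2m, u10m]
    #     var_units: [K, m/s]
    #     var_longnames: [2m_temperature, 10m_u_wind]
    #     num_forcing_features: 16
    #   grid_shape_state: [268, 238]
    #   projection:
    #     class_name: LambertConformal
    #     kwargs:
    #       central_longitude: 15.0
    #       central_latitude: 63.3
    #
    from neural_lam.datastore.npyfiles.config import NpyDatastoreConfig

    config = NpyDatastoreConfig.from_yaml_file("/data/example/data_config.yaml")
    print(config.dataset.var_names, config.projection.class_name)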
64 changes: 37 additions & 27 deletions neural_lam/datastore/npyfiles/store.py
@@ -1,9 +1,11 @@
# Standard library
import functools
import re
from pathlib import Path
from typing import List

# Third-party
import cartopy.crs as ccrs
import dask
import dask.array
import dask.delayed
@@ -15,7 +17,7 @@

# Local
from ..base import BaseCartesianDatastore, CartesianGridShape
from .config import NpyConfig
from .config import NpyDatastoreConfig

STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
TOA_SW_DOWN_FLUX_FILENAME_FORMAT = (
@@ -24,6 +26,13 @@
COLUMN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"


def _load_np(fp, add_feature_dim):
arr = np.load(fp)
if add_feature_dim:
arr = arr[..., np.newaxis]
return arr


class NpyFilesDatastore(BaseCartesianDatastore):
__doc__ = f"""
Represents a dataset stored as numpy files on disk. The dataset is assumed
@@ -133,7 +142,9 @@ def __init__(
self._num_ensemble_members = 2

self._root_path = Path(root_path)
self.config = NpyConfig.from_file(self.root_path / "data_config.yaml")
self.config = NpyDatastoreConfig.from_yaml_file(
self.root_path / "data_config.yaml"
)

@property
def root_path(self):
@@ -157,9 +168,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
xr.DataArray
The data array for the given category and split, with dimensions
per category:
state: `[elapsed_forecast_time, analysis_time, grid_index, feature,
state: `[elapsed_forecast_duration, analysis_time, grid_index, feature,
ensemble_member]`
forcing: `[elapsed_forecast_time, analysis_time, grid_index, feature]`
forcing: `[elapsed_forecast_duration, analysis_time, grid_index, feature]`
static: `[grid_index, feature]`
"""
if category == "state":
@@ -188,14 +199,14 @@ def get_dataarray(self, category: str, split: str) -> DataArray:

# add datetime forcing as a feature
# to do this we create a forecast time variable which has the
# dimensions of (analysis_time, elapsed_forecast_time) with values
# dimensions of (analysis_time, elapsed_forecast_duration) with values
# that are the actual forecast time of each time step. By calling
# .chunk({"elapsed_forecast_time": 1}) this time variable is turned
# .chunk({"elapsed_forecast_duration": 1}) this time variable is turned
# into a dask array and so execution of the calculation is delayed
# until the feature values are actually used.
da_forecast_time = (
da.analysis_time + da.elapsed_forecast_time
).chunk({"elapsed_forecast_time": 1})
da.analysis_time + da.elapsed_forecast_duration
).chunk({"elapsed_forecast_duration": 1})
da_datetime_forcing_features = self._calc_datetime_forcing_features(
da_time=da_forecast_time
)
@@ -262,7 +273,7 @@ def _get_single_timeseries_dataarray(
-------
xr.DataArray
The data array for the given category and split, with dimensions
`[elapsed_forecast_time, analysis_time, grid_index, feature]` for
`[elapsed_forecast_duration, analysis_time, grid_index, feature]` for
all categories of data
"""
assert split in ("train", "val", "test"), "Unknown dataset split"
@@ -284,12 +295,12 @@
features_vary_with_analysis_time = True
if features == self.get_vars_names(category="state"):
filename_format = STATE_FILENAME_FORMAT
file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
# only select one member for now
file_params["member_id"] = member
elif features == ["toa_downwelling_shortwave_flux"]:
filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
file_dims = ["elapsed_forecast_time", "y", "x", "feature"]
file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
add_feature_dim = True
elif features == ["column_water"]:
filename_format = COLUMN_WATER_FILENAME_FORMAT
@@ -329,7 +340,7 @@ def _get_single_timeseries_dataarray(
coords = {}
arr_shape = []
for d in dims:
if d == "elapsed_forecast_time":
if d == "elapsed_forecast_duration":
coord_values = (
self.step_length
* np.arange(self._num_timesteps)
@@ -346,16 +357,12 @@
else:
raise NotImplementedError(f"Dimension {d} not supported")

print(f"{d}: {len(coord_values)}")

coords[d] = coord_values
if d != "analysis_time":
# analysis_time varies across the different files, but not
# within a single file
arr_shape.append(len(coord_values))

print(f"{features}: {dims=} {file_dims=} {arr_shape=}")

if features_vary_with_analysis_time:
filepaths = [
fp_samples
@@ -369,16 +376,11 @@

# use dask.delayed to load the numpy files, so that loading isn't
# done until the data is actually needed
@dask.delayed
def _load_np(fp):
arr = np.load(fp)
if add_feature_dim:
arr = arr[..., np.newaxis]
return arr

arrays = [
dask.array.from_delayed(
_load_np(fp), shape=arr_shape, dtype=np.float32
dask.delayed(_load_np)(fp=fp, add_feature_dim=add_feature_dim),
shape=arr_shape,
dtype=np.float32,
)
for fp in filepaths
]
@@ -457,7 +459,7 @@ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):

def get_vars_units(self, category: str) -> torch.List[str]:
if category == "state":
return self.config["dataset"]["var_units"]
return self.config.dataset.var_units
elif category == "forcing":
return [
"W/m^2",
@@ -474,7 +476,7 @@ def get_vars_units(self, category: str) -> torch.List[str]:

def get_vars_names(self, category: str) -> torch.List[str]:
if category == "state":
return self.config["dataset"]["var_names"]
return self.config.dataset.var_names
elif category == "forcing":
# XXX: this really shouldn't be hard-coded here, this should be in
# the config
@@ -557,7 +559,7 @@ def boundary_mask(self):
da_mask = xr.DataArray(
values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
)
da_mask_stacked_xy = self.stack_grid_coords(da_mask)
da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
return da_mask_stacked_xy

def get_normalization_dataarray(self, category: str) -> xr.Dataset:
Expand Down Expand Up @@ -623,3 +625,11 @@ def load_pickled_tensor(fn):
)

return ds_norm

@functools.cached_property
def coords_projection(self):
"""Return the projection."""
proj_class_name = self.config.projection.class_name
ProjectionClass = getattr(ccrs, proj_class_name)
proj_params = self.config.projection.kwargs
return ProjectionClass(**proj_params)
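As a self-contained illustration of the lazy-loading pattern this commit factors out (moving _load_np to module level so it can be wrapped with dask.delayed), here is a sketch with illustrative paths and shapes:

    import dask
    import dask.array
    import numpy as np

    def _load_np(fp, add_feature_dim):
        arr = np.load(fp)
        if add_feature_dim:
            arr = arr[..., np.newaxis]
        return arr

    # Illustrative inputs: one file per analysis time, each holding an array of
    # shape (elapsed_forecast_duration, y, x, feature).
    filepaths = ["/data/samples/train/nwp_2024052100_mbr000.npy"]
    arr_shape = (65, 268, 238, 17)

    # Loading is deferred: each file becomes a dask chunk that is only read
    # when the values are actually computed.
    arrays = [
        dask.array.from_delayed(
            dask.delayed(_load_np)(fp=fp, add_feature_dim=False),
            shape=arr_shape,
            dtype=np.float32,
        )
        for fp in filepaths
    ]
    arr_all = dask.array.stack(arrays, axis=0)  # new leading analysis_time dim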