Skip to content

Commit

Permalink
streamlined multi-zarr workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Jun 1, 2024
1 parent 4e457ed commit adc592f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 108 deletions.
110 changes: 37 additions & 73 deletions neural_lam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,11 @@


class Config:
"""
Class for loading configuration files.
This class loads a configuration file and provides a way to access its
values as attributes.
"""

def __init__(self, values):
self.values = values

@classmethod
def from_file(cls, filepath):
"""Load a configuration file."""
if filepath.endswith(".yaml"):
with open(filepath, encoding="utf-8", mode="r") as file:
return cls(values=yaml.safe_load(file))
Expand Down Expand Up @@ -53,7 +45,6 @@ def __contains__(self, key):

@functools.cached_property
def coords_projection(self):
"""Return the projection."""
proj_config = self.values["projection"]
proj_class_name = proj_config["class"]
proj_class = getattr(ccrs, proj_class_name)
Expand All @@ -62,7 +53,6 @@ def coords_projection(self):

@functools.cached_property
def param_names(self):
"""Return parameter names."""
surface_vars_names = self.values["state"]["surface_vars"]
atmosphere_vars_names = [
f"{var}_{level}"
Expand All @@ -73,18 +63,16 @@ def param_names(self):

@functools.cached_property
def param_units(self):
"""Return parameter units."""
surface_vars_units = self.values["state"]["surface_vars_units"]
surface_vars_units = self.values["state"]["surface_units"]
atmosphere_vars_units = [
unit
for unit in self.values["state"]["atmosphere_vars_units"]
for unit in self.values["state"]["atmosphere_units"]
for _ in self.values["state"]["levels"]
]
return surface_vars_units + atmosphere_vars_units

@functools.lru_cache()
def num_data_vars(self, category):
"""Return the number of data variables for a given category."""
surface_vars = self.values[category].get("surface_vars", [])
atmosphere_vars = self.values[category].get("atmosphere_vars", [])
levels = self.values[category].get("levels", [])
Expand All @@ -101,31 +89,20 @@ def num_data_vars(self, category):

@functools.lru_cache(maxsize=None)
def open_zarr(self, category):
"""Open a dataset specified by the category."""
zarr_config = self.zarrs[category]

if isinstance(zarr_config, list):
try:
datasets = []
for config in zarr_config:
dataset_path = config["path"]
dataset = xr.open_zarr(dataset_path, consolidated=True)
datasets.append(dataset)
return xr.merge(datasets)
except Exception:
print(f"Invalid zarr configuration for category: {category}")
return None

else:
try:
dataset_path = zarr_config["path"]
return xr.open_zarr(dataset_path, consolidated=True)
except Exception:
print(f"Invalid zarr configuration for category: {category}")
return None
zarr_configs = self.values[category]["zarrs"]

try:
datasets = []
for config in zarr_configs:
dataset_path = config["path"]
dataset = xr.open_zarr(dataset_path, consolidated=True)
datasets.append(dataset)
return xr.merge(datasets)
except Exception:
print(f"Invalid zarr configuration for category: {category}")
return None

def stack_grid(self, dataset):
"""Stack grid dimensions."""
dims = dataset.to_array().dims

if "grid" not in dims and "x" in dims and "y" in dims:
Expand All @@ -145,10 +122,9 @@ def stack_grid(self, dataset):

@functools.lru_cache()
def get_nwp_xy(self, category):
"""Get the x and y coordinates for the NWP grid."""
dataset = self.open_zarr(category)
lon_name = self.zarrs[category].lat_lon_names.lon
lat_name = self.zarrs[category].lat_lon_names.lat
lon_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lon"]
lat_name = self.values[category]["zarrs"][0]["lat_lon_names"]["lat"]
if lon_name in dataset and lat_name in dataset:
lon = dataset[lon_name].values
lat = dataset[lat_name].values
Expand All @@ -158,46 +134,42 @@ def get_nwp_xy(self, category):
)
if lon.ndim == 1:
lon, lat = np.meshgrid(lat, lon)
lonlat = np.stack((lon, lat), axis=0)
lonlat = np.stack((lon.T, lat.T), axis=0)

return lonlat

@functools.cached_property
def load_normalization_stats(self):
"""Load normalization statistics from Zarr archive."""
normalization_path = self.normalization.zarr
if not os.path.exists(normalization_path):
print(
f"Normalization statistics not found at "
f"path: {normalization_path}"
)
return None
normalization_stats = xr.open_zarr(
normalization_path, consolidated=True
)
normalization_stats = {}
for zarr_config in self.values["normalization"]["zarrs"]:
normalization_path = zarr_config["path"]
if not os.path.exists(normalization_path):
print(
f"Normalization statistics not found at path: "
f"{normalization_path}"
)
return None
stats = xr.open_zarr(normalization_path, consolidated=True)
for var_name, var_path in zarr_config["stats_vars"].items():
normalization_stats[var_name] = stats[var_path]
return normalization_stats

@functools.lru_cache(maxsize=None)
def process_dataset(self, category, split="train"):
"""Process a single dataset specified by the dataset name."""

dataset = self.open_zarr(category)
if dataset is None:
return None

start, end = (
self.splits[split].start,
self.splits[split].end,
self.values["splits"][split]["start"],
self.values["splits"][split]["end"],
)
dataset = dataset.sel(time=slice(start, end))

dims_mapping = {}
zarr_configs = self.zarrs[category]
if isinstance(zarr_configs, list):
for zarr_config in zarr_configs:
dims_mapping.update(zarr_config["dims"])
else:
dims_mapping.update(zarr_configs["dims"].values)
zarr_configs = self.values[category]["zarrs"]
for zarr_config in zarr_configs:
dims_mapping.update(zarr_config["dims"])

dataset = dataset.rename_dims(
{
Expand Down Expand Up @@ -236,18 +208,14 @@ def process_dataset(self, category, split="train"):
print(f"No variables found in dataset {category}")
return None

zarr_configs = self.zarrs[category]
lat_lon_names = {}
if isinstance(self.zarrs[category], list):
for zarr_configs in self.zarrs[category]:
lat_lon_names.update(zarr_configs["lat_lon_names"])
else:
lat_lon_names.update(self.zarrs[category]["lat_lon_names"].values)
for zarr_config in self.values[category]["zarrs"]:
lat_lon_names.update(zarr_config["lat_lon_names"])

if not all(
lat_lon in lat_lon_names.values() for lat_lon in lat_lon_names
):
lat_name, lon_name = lat_lon_names[:2]
lat_name, lon_name = list(lat_lon_names.values())[:2]
if dataset[lat_name].ndim == 2:
dataset[lat_name] = dataset[lat_name].isel(x=0, drop=True)
if dataset[lon_name].ndim == 2:
Expand All @@ -262,9 +230,5 @@ def process_dataset(self, category, split="train"):
dataset = self.stack_grid(dataset)
return dataset

dataset = self.stack_grid(dataset)

return dataset


config = Config.from_file("neural_lam/data_config.yaml")
71 changes: 36 additions & 35 deletions neural_lam/data_config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: danra
zarrs:
state:
state:
zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
Expand All @@ -21,28 +21,6 @@ zarrs:
lat_lon_names:
lon: lon
lat: lat
static:
path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
level: null
x: x
y: y
grid: null
lat_lon_names:
lon: lon
lat: lat
forcing:
path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
level: null
x: x
y: y
grid: null
lat_lon_names:
lon: lon
lat: lat
state:
surface_vars:
- u10m
- v10m
Expand All @@ -62,16 +40,50 @@ state:
levels:
- 100
static:
zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
level: null
x: x
y: y
grid: null
lat_lon_names:
lon: lon
lat: lat
surface_vars:
- pres0m # just as a technical test
atmosphere_vars: null
levels: null
forcing:
zarrs:
- path: "https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr"
dims:
time: time
level: null
x: x
y: y
grid: null
lat_lon_names:
lon: lon
lat: lat
surface_vars:
- cape_column # just as a technical test
atmosphere_vars: null
levels: null
window: 3 # Number of time steps to use for forcing (odd)
normalization:
zarrs:
- path: "normalization.zarr"
stats_vars:
data_mean: data_mean
data_std: data_std
forcing_mean: forcing_mean
forcing_std: forcing_std
boundary_mean: boundary_mean
boundary_std: boundary_std
diff_mean: diff_mean
diff_std: diff_std

grid_shape_state:
x: 582
y: 390
Expand All @@ -91,14 +103,3 @@ projection:
central_longitude: 6.22
central_latitude: 56.0
standard_parallels: [47.6, 64.4]
normalization:
zarr: normalization.zarr
vars:
data_mean: data_mean
data_std: data_std
forcing_mean: forcing_mean
forcing_std: forcing_std
boundary_mean: boundary_mean
boundary_std: boundary_std
diff_mean: diff_mean
diff_std: diff_std

0 comments on commit adc592f

Please sign in to comment.