Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add writing to zarr dataset for eval-mode of trained models #104

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 77 additions & 3 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import torch
import wandb
import xarray as xr
from loguru import logger

# Local
from .. import metrics, vis
from ..config import NeuralLAMConfig
from ..datastore import BaseDatastore
from ..datastore.base import BaseRegularGridDatastore
from ..loss_weighting import get_state_feature_weighting
from ..weather_dataset import WeatherDataset

Expand Down Expand Up @@ -184,7 +186,7 @@ def _create_dataarray_from_tensor(
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
time = np.array(time.cpu(), dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
tensor=tensor.cpu().numpy(), time=time, category=category
tensor=tensor, time=time, category=category
)
return da

Expand Down Expand Up @@ -371,15 +373,79 @@ def on_validation_epoch_end(self):
for metric_list in self.val_metrics.values():
metric_list.clear()

def _save_predictions_to_zarr(
    self,
    batch_times: torch.Tensor,
    batch_predictions: torch.Tensor,
    batch_idx: int,
    zarr_output_path: str,
):
    """
    Save state predictions for a single batch to a zarr dataset.

    The dataset is created (overwriting anything already stored at
    `zarr_output_path`) when `batch_idx == 0`, and appended to along
    `start_time` for every other batch. The resulting dataset contains a
    variable named `state` with coordinates (start_time,
    elapsed_forecast_duration, grid_index, state_feature). For datastores
    deriving from `BaseRegularGridDatastore` the grid coordinates are
    unstacked with `unstack_grid_coords` before writing.

    Parameters
    ----------
    batch_times : torch.Tensor[int]
        The times for the batch, given as epoch time in nanoseconds. Shape
        is (B, args.pred_steps) where B is the batch size and
        args.pred_steps is the number of prediction steps.
    batch_predictions : torch.Tensor[float]
        The predictions for the batch, given as (B, args.pred_steps,
        num_grid_nodes, d_f) where B is the batch size, args.pred_steps is
        the number of prediction steps, num_grid_nodes is the number of
        grid nodes, and d_f is the number of state features.
    batch_idx : int
        The index of the batch in the current epoch.
    zarr_output_path : str
        Path of the zarr dataset to write to.

    Raises
    ------
    ValueError
        If `batch_times` and `batch_predictions` do not contain the same
        number of samples along their first dimension.
    """
    batch_size = batch_predictions.shape[0]
    if len(batch_times) != batch_size:
        # Fail loudly rather than silently dropping samples or writing a
        # chunk size that does not match the number of stored forecasts
        raise ValueError(
            "batch_times and batch_predictions must have the same batch "
            f"size, got {len(batch_times)} and {batch_size}"
        )

    # Convert each forecast in the batch to a DataArray using
    # _create_dataarray_from_tensor
    das_pred = []
    for sample_times, sample_prediction in zip(
        batch_times, batch_predictions
    ):
        da_pred = self._create_dataarray_from_tensor(
            tensor=sample_prediction,
            time=sample_times,
            split="test",
            category="state",
        )
        # Unstack grid coords if necessary, this also avoids the need to
        # try to store a MultiIndex zarr dataset which is not supported by
        # xarray
        if isinstance(self._datastore, BaseRegularGridDatastore):
            da_pred = self._datastore.unstack_grid_coords(da_pred)

        # Re-express absolute valid times as (forecast start time, lead
        # time) so that forecasts can be concatenated along start_time
        t0 = da_pred.coords["time"].values[0]
        da_pred.coords["start_time"] = t0
        da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
        da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})
        da_pred.name = "state"
        das_pred.append(da_pred)

    da_pred_batch = xr.concat(das_pred, dim="start_time")

    # Apply chunking along start_time so that each batch is saved as a
    # separate chunk
    da_pred_batch = da_pred_batch.chunk({"start_time": batch_size})

    if batch_idx == 0:
        logger.info(f"Saving predictions to {zarr_output_path}")
        da_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
    else:
        da_pred_batch.to_zarr(
            zarr_output_path, mode="a", append_dim="start_time"
        )

# pylint: disable-next=unused-argument
def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

time_step_loss = torch.mean(
self.loss(
Expand Down Expand Up @@ -435,6 +501,14 @@ def test_step(self, batch, batch_idx):
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)

if self.args.save_eval_to_zarr_path:
self._save_predictions_to_zarr(
batch_times=batch_times,
batch_predictions=prediction,
batch_idx=batch_idx,
zarr_output_path=self.args.save_eval_to_zarr_path,
)

# Plot example predictions (on rank 0 only)
if (
self.trainer.is_global_zero
Expand Down
5 changes: 5 additions & 0 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ def main(input_args=None):
help="Eval model on given data split (val/test) "
"(default: None (train model))",
)
parser.add_argument(
"--save-eval-to-zarr-path",
type=str,
help="Save evaluation results to zarr dataset at given path ",
)
parser.add_argument(
"--ar_steps_eval",
type=int,
Expand Down
Loading