Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference step #45

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 0 additions & 64 deletions neural_lam/data_config.yaml

This file was deleted.

186 changes: 176 additions & 10 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Standard library
import glob
import os

# Third-party
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
import torch
import wandb

Expand Down Expand Up @@ -93,6 +96,20 @@ def __init__(self, args):
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []

self.inference_output = []
"Storage for the output of individual inference steps"

self.variable_indices = self.pre_compute_variable_indices()
"Index mapping of variable names to their levels in the array."
self.selected_vars_units = [
(var_name, var_unit)
for var_name, var_unit in zip(
self.config_loader.dataset.var_names,
self.config_loader.dataset.var_units,
)
if var_name in self.config_loader.dataset.eval_plot_vars
]

def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
Expand All @@ -106,14 +123,42 @@ def interior_mask_bool(self):
"""
return self.interior_mask[:, 0].to(torch.bool)

def pre_compute_variable_indices(self):
    """
    Pre-compute the flattened-feature index of every variable.

    Builds (var_name, level) tuples for all variables (2D variables use
    level 0), sorts them for a deterministic ordering, and assigns each a
    consecutive integer index.

    Returns:
        dict: mapping var_name -> sorted list of integer feature indices.
    """
    # NOTE(review): `var_is_3d` is read once, un-indexed, for every
    # variable. If it is a per-variable list this condition is
    # constant-truthy and ALL variables are treated as 3D — confirm it is
    # a single flag, otherwise zip it with var_names.
    all_vars = []
    for var_name in self.config_loader.dataset.var_names:
        if self.config_loader.dataset.var_is_3d:
            for level in self.config_loader.dataset.vertical_levels:
                all_vars.append((var_name, level))
        else:
            all_vars.append((var_name, 0))  # level 0 for 2D variables

    # Sort for a deterministic order, then number consecutively.
    # enumerate + setdefault replaces the manual `index` counter.
    variable_indices = {}
    for index, (var_name, _level) in enumerate(sorted(all_vars)):
        variable_indices.setdefault(var_name, []).append(index)

    return variable_indices

@staticmethod
def expand_to_batch(x, batch_size):
"""
Expand tensor with initial batch dimension
"""
return x.unsqueeze(0).expand(batch_size, -1, -1)

def predict_step(self, prev_state, prev_prev_state, forcing):
def single_prediction(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
Expand All @@ -122,6 +167,48 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
raise NotImplementedError("No prediction step implemented")

def predict_step(self, batch, batch_idx):
    """
    Run inference on one batch.

    Accumulates error-map metrics (mse/mae), optional per-variable output
    std., and per-sample spatial losses; plots examples on rank 0; and
    stores the raw prediction for aggregation in on_predict_epoch_end.

    Parameters:
    - batch: input batch from the predict dataloader
    - batch_idx (int): index of the batch being processed
    """
    prediction, target, pred_std = self.common_step(batch)

    # Compute all evaluation metrics for error maps
    # Note: explicitly list metrics here, as test_metrics can contain
    # additional ones, computed differently, but that should be aggregated
    # on_predict_epoch_end
    for metric_name in ("mse", "mae"):
        metric_func = metrics.get_metric(metric_name)
        batch_metric_vals = metric_func(
            prediction,
            target,
            pred_std,
            mask=self.interior_mask_bool,
            sum_vars=False,
        )  # (B, pred_steps, d_f)
        self.test_metrics[metric_name].append(batch_metric_vals)

    if self.output_std:
        # Store output std. per variable, spatially averaged
        mean_pred_std = torch.mean(
            pred_std[..., self.interior_mask_bool, :], dim=-2
        )  # (B, pred_steps, d_f)
        self.test_metrics["output_std"].append(mean_pred_std)

    # Save per-sample spatial loss for specific times
    spatial_loss = self.loss(
        prediction, target, pred_std, average_grid=False
    )  # (B, pred_steps, num_grid_nodes)
    log_spatial_losses = spatial_loss[
        :, [step - 1 for step in self.args.val_steps_to_log]
    ]
    self.spatial_loss_maps.append(log_spatial_losses)
    # (B, N_log, num_grid_nodes)

    if self.trainer.global_rank == 0:
        # BUG FIX: plot_examples takes (batch, n_examples, batch_idx, ...);
        # the old call passed batch_idx positionally into n_examples and
        # left the required batch_idx parameter unfilled -> TypeError.
        # Pass both explicitly; plot every sample in the batch.
        # TODO(review): confirm n_examples choice matches intended behavior.
        self.plot_examples(
            batch,
            n_examples=prediction.shape[0],
            batch_idx=batch_idx,
            prediction=prediction,
        )
    self.inference_output.append(prediction)

def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
Expand All @@ -139,7 +226,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
forcing = forcing_features[:, i]
border_state = true_states[:, i]

pred_state, pred_std = self.predict_step(
pred_state, pred_std = self.single_prediction(
prev_state, prev_prev_state, forcing
)
# state: (B, num_grid_nodes, d_f)
Expand Down Expand Up @@ -345,20 +432,50 @@ def test_step(self, batch, batch_idx):
batch, n_additional_examples, prediction=prediction
)

def plot_examples(self, batch, n_examples, prediction=None):
@rank_zero_only
def plot_examples(self, batch, n_examples, batch_idx: int, prediction=None):
"""
Plot the first n_examples forecasts from batch

batch: batch with data to plot corresponding forecasts for
n_examples: number of forecasts to plot
prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
Generate if None.
Plot the first n_examples forecasts from batch.

The function checks for the presence of test_dataset or
predict_dataset within the trainer's data module,
handles indexing within the batch for targeted analysis,
performs prediction rescaling, and plots results.

Parameters:
- batch: batch with data to plot corresponding forecasts for
- n_examples: number of forecasts to plot
- batch_idx (int): index of the batch being processed
- prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
Generate if None.
"""
if prediction is None:
prediction, target = self.common_step(batch)

target = batch[1]

# Determine the dataset to work with (test_dataset or predict_dataset)
dataset = None
if (
hasattr(self.trainer.datamodule, "test_dataset")
and self.trainer.datamodule.test_dataset
):
dataset = self.trainer.datamodule.test_dataset
plot_name = "test"
elif (
hasattr(self.trainer.datamodule, "predict_dataset")
and self.trainer.datamodule.predict_dataset
):
dataset = self.trainer.datamodule.predict_dataset
plot_name = "prediction"

if (
dataset
and self.trainer.global_rank == 0
and dataset.batch_index == batch_idx
):
index_within_batch = dataset.index_within_batch

# Rescale to original data scale
prediction_rescaled = prediction * self.data_std + self.data_mean
target_rescaled = target * self.data_std + self.data_mean
Expand Down Expand Up @@ -415,7 +532,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
example_i = self.plotted_examples
wandb.log(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
f"{var_name}_{plot_name}_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
self.config_loader.dataset.var_names, var_figs
)
Expand Down Expand Up @@ -573,6 +690,55 @@ def on_test_epoch_end(self):

self.spatial_loss_maps.clear()

@rank_zero_only
def on_predict_epoch_end(self):
    """
    Save stored inference outputs and build per-variable GIF animations at
    the end of the predict epoch.

    Runs on rank 0 only (via @rank_zero_only). Predictions collected in
    predict_step are rescaled to the original data scale and written as
    .npy files; previously-plotted per-timestep images are assembled into
    one GIF per variable and vertical level.
    """
    plot_dir_path = f"{wandb.run.dir}/media/images"
    value_dir_path = f"{wandb.run.dir}/results/inference"
    # Ensure the directories for saving images and numpy arrays exist
    os.makedirs(plot_dir_path, exist_ok=True)
    os.makedirs(value_dir_path, exist_ok=True)

    # Save each stored prediction, rescaled to original data scale
    for i, prediction in enumerate(self.inference_output):
        prediction_rescaled = prediction * self.data_std + self.data_mean
        prediction_array = prediction_rescaled.cpu().numpy()
        file_path = os.path.join(value_dir_path, f"prediction_{i}.npy")
        np.save(file_path, prediction_array)

    for var_name, _ in self.selected_vars_units:
        var_indices = self.variable_indices[var_name]
        for lvl_i, _ in enumerate(var_indices):
            # Map the index position back to a vertical level for naming
            lvl = self.config_loader.dataset.vertical_levels[lvl_i]

            # NOTE(review): the glob uses the "test" plot name, but
            # predict-time examples appear to be logged under
            # "prediction" — confirm this pattern matches the files
            # actually written during the predict run.
            images = sorted(
                glob.glob(
                    f"{plot_dir_path}/{var_name}_test_lvl_{lvl:02}_t_*.png"
                )
            )
            # Assemble the per-timestep frames into one GIF
            with imageio.get_writer(
                f"{plot_dir_path}/{var_name}_lvl_{lvl:02}.gif",
                mode="I",
                fps=1,
            ) as writer:
                for filename in images:
                    image = imageio.imread(filename)
                    writer.append_data(image)

    # Clear accumulated state so a subsequent predict run starts fresh.
    # BUG FIX: inference_output was never cleared, leaking memory and
    # re-saving stale predictions across repeated predict epochs.
    self.inference_output.clear()
    self.spatial_loss_maps.clear()

def on_load_checkpoint(self, checkpoint):
"""
Perform any changes to state dict before loading checkpoint
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def process_step(self, mesh_rep):
"""
raise NotImplementedError("process_step not implemented")

def predict_step(self, prev_state, prev_prev_state, forcing):
def single_prediction(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
prev_state: (B, num_grid_nodes, feature_dim), X_t
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
):
super().__init__()

assert split in ("train", "val", "test"), "Unknown dataset split"
assert split in ("train", "val", "test", "predict"), "Unknown dataset split"
self.sample_dir_path = os.path.join(
"data", dataset_name, "samples", split
)
Expand Down
Loading