Re-define RMSE metric to take sqrt after sample averaging (#10)
joeloskarsson authored Feb 29, 2024
1 parent 1cddf09 commit 0669ff4
Showing 2 changed files with 20 additions and 48 deletions.
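For context, the redefinition swaps the order of two reductions: previously the square root was applied right after spatial averaging, and the resulting per-forecast RMSEs were then averaged; now the MSE is averaged over both grid and eval samples, with the square root applied last. A minimal sketch of the two orders of operations, assuming pred and target shaped (B, N, d_state); the function names are illustrative, not from the repo:

    import torch

    def rmse_sqrt_then_average(pred, target):
        # old pipeline: sqrt right after the spatial mean, then mean over
        # samples, i.e. the average of per-forecast RMSEs
        per_forecast_rmse = torch.sqrt(((pred - target) ** 2).mean(dim=-2))
        return per_forecast_rmse.mean(dim=0)  # (d_state,)

    def rmse_average_then_sqrt(pred, target):
        # new pipeline: mean over grid and samples first, sqrt last
        per_forecast_mse = ((pred - target) ** 2).mean(dim=-2)  # (B, d_state)
        return torch.sqrt(per_forecast_mse.mean(dim=0))  # (d_state,)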
37 changes: 0 additions & 37 deletions neural_lam/metrics.py
@@ -108,42 +108,6 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


-def rmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
-    """
-    Root Mean Squared Error
-    Note: here the sqrt is taken only after spatial averaging, i.e. this
-    averages the RMSE of individual forecasts. This is consistent with
-    WeatherBench and others. Because of this, averaging over the grid
-    must be enabled.
-    (...,) is any number of batch dimensions, potentially different
-    but broadcastable
-    pred: (..., N, d_state), prediction
-    target: (..., N, d_state), target
-    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
-    mask: (N,), boolean mask describing which grid nodes to use in metric
-    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
-    sum_vars: boolean, if variable dimension -1 should be reduced (sum
-        over d_state)
-    Returns:
-    metric_val: One of (...,), (..., d_state), depending on reduction arguments
-    """
-    assert average_grid, "Can not compute RMSE without averaging grid"
-
-    # Spatially averaged mse, masking is also performed here
-    averaged_mse = mse(
-        pred, target, pred_std, mask, average_grid=True, sum_vars=False
-    )  # (..., d_state)
-    entry_rmse = torch.sqrt(averaged_mse)  # (..., d_state)
-
-    # Optionally sum over variables here manually
-    if sum_vars:
-        return torch.sum(entry_rmse, dim=-1)  # (...,)
-
-    return entry_rmse  # (..., d_state)


def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
"""
Weighted Mean Absolute Error
@@ -266,7 +230,6 @@ def crps_gauss(
DEFINED_METRICS = {
    "mse": mse,
    "mae": mae,
-    "rmse": rmse,
    "wmse": wmse,
    "wmae": wmae,
    "nll": nll,
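Numerically the two definitions differ whenever per-forecast MSEs vary: the square root is concave, so by Jensen's inequality the mean of per-forecast RMSEs is at most the square root of the mean MSE. A toy check of that bound:

    import torch

    mse_per_forecast = torch.tensor([1.0, 4.0])  # toy spatially averaged MSEs
    old = torch.sqrt(mse_per_forecast).mean()    # (1.0 + 2.0) / 2 = 1.5
    new = torch.sqrt(mse_per_forecast.mean())    # sqrt(2.5) ≈ 1.58
    assert old <= new  # Jensen: E[sqrt(X)] <= sqrt(E[X])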
31 changes: 20 additions & 11 deletions neural_lam/models/ar_model.py
@@ -75,10 +75,10 @@ def __init__(self, args):

        self.step_length = args.step_length  # Number of hours per pred. step
        self.val_metrics = {
-            "rmse": [],
+            "mse": [],
        }
        self.test_metrics = {
-            "rmse": [],
+            "mse": [],
            "mae": [],
        }
        if self.output_std:
@@ -238,7 +238,9 @@ def all_gather_cat(self, tensor_to_gather):
"""
return self.all_gather(tensor_to_gather).flatten(0, 1)

def validation_step(self, batch):
# newer lightning versions requires batch_idx argument, even if unused
# pylint: disable-next=unused-argument
def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
@@ -262,15 +264,15 @@ def validation_step(self, batch):
            val_log_dict, on_step=False, on_epoch=True, sync_dist=True
        )

-        # Store RMSEs
-        entry_rmses = metrics.rmse(
+        # Store MSEs
+        entry_mses = metrics.mse(
            prediction,
            target,
            pred_std,
            mask=self.interior_mask_bool,
            sum_vars=False,
        )  # (B, pred_steps, d_f)
-        self.val_metrics["rmse"].append(entry_rmses)
+        self.val_metrics["mse"].append(entry_mses)

    def on_validation_epoch_end(self):
        """
@@ -283,7 +285,8 @@ def on_validation_epoch_end(self):
        for metric_list in self.val_metrics.values():
            metric_list.clear()

-    def test_step(self, batch):
+    # pylint: disable-next=unused-argument
+    def test_step(self, batch, batch_idx):
        """
        Run test on single batch
        """
@@ -314,7 +317,7 @@ def test_step(self, batch):
        # Note: explicitly list metrics here, as test_metrics can contain
        # additional ones, computed differently, but that should be aggregated
        # on_test_epoch_end
-        for metric_name in ("rmse", "mae"):
+        for metric_name in ("mse", "mae"):
            metric_func = metrics.get_metric(metric_name)
            batch_metric_vals = metric_func(
                prediction,
@@ -508,10 +511,16 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
            )  # (N_eval, pred_steps, d_f)

            if self.trainer.is_global_zero:
+                metric_tensor_averaged = torch.mean(metric_tensor, dim=0)
+                # (pred_steps, d_f)
+
+                # Take square root after all averaging to change MSE to RMSE
+                if "mse" in metric_name:
+                    metric_tensor_averaged = torch.sqrt(metric_tensor_averaged)
+                    metric_name = metric_name.replace("mse", "rmse")
+
                # Note: we here assume rescaling for all metrics is linear
-                metric_rescaled = (
-                    torch.mean(metric_tensor, dim=0) * self.data_std
-                )
+                metric_rescaled = metric_tensor_averaged * self.data_std
                # (pred_steps, d_f)
                log_dict.update(
                    self.create_metric_log_dict(
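The net effect is that models now collect plain MSE during validation and testing, and the conversion to RMSE happens once, at aggregation time, after averaging over all eval samples. A self-contained sketch of that final step, mirroring the added code above (the toy shapes and the unit data_std are assumptions):

    import torch

    metric_name = "mse"
    metric_tensor = torch.rand(8, 4, 3)  # toy (N_eval, pred_steps, d_f) MSEs
    data_std = torch.ones(3)             # assumed per-variable std for rescaling

    metric_tensor_averaged = torch.mean(metric_tensor, dim=0)  # (pred_steps, d_f)
    if "mse" in metric_name:
        # square root only after averaging over all eval samples
        metric_tensor_averaged = torch.sqrt(metric_tensor_averaged)
        metric_name = metric_name.replace("mse", "rmse")
    metric_rescaled = metric_tensor_averaged * data_std  # linear rescaling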
