Skip to content

Commit

Permalink
added extreme spatial value mae loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Jun 29, 2023
1 parent 98aa00b commit 24f903f
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,42 @@ def __call__(self, x1, x2):
return self.MSE_LOSS(x1_coarse, x2_coarse)


class SpatialExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
spatial domain"""

MAE_LOSS = MeanAbsoluteError()

def __call__(self, x1, x2):
"""Custom content loss that encourages temporal min/max accuracy
Parameters
----------
x1 : tf.tensor
synthetic generator output
(n_observations, spatial_1, spatial_2, features)
x2 : tf.tensor
high resolution data
(n_observations, spatial_1, spatial_2, features)
Returns
-------
tf.tensor
0D tensor with loss value
"""
x1_min = tf.reduce_min(x1, axis=(1, 2))
x2_min = tf.reduce_min(x2, axis=(1, 2))

x1_max = tf.reduce_max(x1, axis=(1, 2))
x2_max = tf.reduce_max(x2, axis=(1, 2))

mae = self.MAE_LOSS(x1, x2)
mae_min = self.MAE_LOSS(x1_min, x2_min)
mae_max = self.MAE_LOSS(x1_max, x2_max)

return mae + mae_min + mae_max


class TemporalExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
timeseries"""
Expand Down

0 comments on commit 24f903f

Please sign in to comment.