From da1480c30403a3fe58c0597e997fb52ff9f25105 Mon Sep 17 00:00:00 2001 From: Simon Kamuk Christiansen Date: Wed, 4 Dec 2024 10:40:03 +0000 Subject: [PATCH] prevent inverse sigmoid and softplus from returning +/- inf --- neural_lam/models/base_graph_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 05c8b176..342a570c 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -201,11 +201,17 @@ def inverse_softplus(x, beta=1, threshold=20): # If x*beta is above threshold, returns linear function # for numerical stability under_lim = x * beta <= threshold - x[under_lim] = torch.log(torch.expm1(x[under_lim] * beta)) / beta + x[under_lim] = ( + torch.log( + torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6) + ) + / beta + ) return x def inverse_sigmoid(x): - return torch.log(x / (1 - x)) + x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) + return torch.log(x_clamped / (1 - x_clamped)) self.inverse_clamp_lower_upper = lambda x: ( sigmoid_center