diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 05c8b176..342a570c 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -201,11 +201,17 @@ def inverse_softplus(x, beta=1, threshold=20): # If x*beta is above threshold, returns linear function # for numerical stability under_lim = x * beta <= threshold - x[under_lim] = torch.log(torch.expm1(x[under_lim] * beta)) / beta + x[under_lim] = ( + torch.log( + torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6) + ) + / beta + ) return x def inverse_sigmoid(x): - return torch.log(x / (1 - x)) + x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6) + return torch.log(x_clamped / (1 - x_clamped)) self.inverse_clamp_lower_upper = lambda x: ( sigmoid_center