diff --git a/didactic/models/layers.py b/didactic/models/layers.py
index 91bedd19..68c56210 100644
--- a/didactic/models/layers.py
+++ b/didactic/models/layers.py
@@ -77,9 +77,9 @@ def __init__(
         self,
         in_features: int,
         out_features: int,
-        backbone_distribution: Literal["poisson", "binomial"] = "poisson",
+        backbone_distribution: Literal["poisson", "binomial"] = "binomial",
         tau: float = 1.0,
-        tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_sigm",
+        tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_fn",
         eps: float = 1e-6,
     ):
         """Initializes class instance.
@@ -148,13 +148,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
             - (N, 1), Temperature parameter tau, in the range [0, inf) or [0, 1], depending on `tau_mode`.
         """
         # Forward through the linear layer to get a scalar param in [0,1] for the backbone distribution
-        param = self.param_head(x)
-        f_x = (self.num_logits + 1) * param  # Rescale the parameter to [0, num_logits+1]
+        f_x = param = self.param_head(x)
 
         # Compute the probability mass function (PMF) of the backbone distribution
         # For technical reasons, use the log instead of the direct value
         match self.backbone_distribution:
             case "poisson":
+                f_x = (self.num_logits + 1) * f_x  # Rescale f(x) to [0, num_logits+1]
                 log_f = (self.logits_idx * torch.log(f_x + self.eps)) - f_x - torch.log(self.logits_factorial)
             case "binomial":
                 log_f = (
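
Note on the restructuring in `forward`: the linear head outputs a scalar in [0, 1], which is already a valid probability for the binomial backbone; only the Poisson backbone needs it rescaled to [0, num_logits+1] to serve as a rate, which is why the rescale moves inside the `"poisson"` branch.

As a minimal sketch (not part of this patch), the Poisson branch's log-PMF expression can be cross-checked against `torch.distributions.Poisson`; the names `rate`, `k`, and `log_k_factorial` below are illustrative stand-ins for `f_x`, `self.logits_idx`, and `torch.log(self.logits_factorial)`:

```python
import torch

num_logits = 10
rate = torch.tensor(4.2)                           # stand-in for f_x after rescaling
k = torch.arange(num_logits, dtype=torch.float32)  # stand-in for self.logits_idx
log_k_factorial = torch.lgamma(k + 1)              # log(k!) == lgamma(k + 1)

# log f(k; rate) = k * log(rate) - rate - log(k!), i.e. the "poisson" branch
# (the patch adds self.eps inside the log to guard against f_x == 0)
log_f = k * torch.log(rate) - rate - log_k_factorial

# Cross-check against PyTorch's built-in Poisson log-PMF
assert torch.allclose(log_f, torch.distributions.Poisson(rate).log_prob(k), atol=1e-5)
```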