Fix binomial backbone distribution in UnimodalLogitsHead
nathanpainchaud committed Nov 1, 2023
1 parent 30b7dcf commit 7e63965
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions didactic/models/layers.py
@@ -77,9 +77,9 @@ def __init__(
         self,
         in_features: int,
         out_features: int,
-        backbone_distribution: Literal["poisson", "binomial"] = "poisson",
+        backbone_distribution: Literal["poisson", "binomial"] = "binomial",
         tau: float = 1.0,
-        tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_sigm",
+        tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_fn",
         eps: float = 1e-6,
     ):
         """Initializes class instance.
@@ -148,13 +148,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
             - (N, 1), Temperature parameter tau, in the range [0, inf) or [0, 1], depending on `tau_mode`.
         """
         # Forward through the linear layer to get a scalar param in [0,1] for the backbone distribution
-        param = self.param_head(x)
-        f_x = (self.num_logits + 1) * param  # Rescale the parameter to [0, num_logits+1]
+        f_x = param = self.param_head(x)
 
         # Compute the probability mass function (PMF) of the backbone distribution
         # For technical reasons, use the log instead of the direct value
         match self.backbone_distribution:
             case "poisson":
+                f_x = (self.num_logits + 1) * f_x  # Rescale f(x) to [0, num_logits+1]
                 log_f = (self.logits_idx * torch.log(f_x + self.eps)) - f_x - torch.log(self.logits_factorial)
             case "binomial":
                 log_f = (
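For context, a minimal, self-contained sketch of what this fix corrects. The Poisson backbone needs a rate in [0, num_logits + 1], so the head's [0, 1] output must be rescaled; the binomial backbone's probability parameter must stay in [0, 1], so the old unconditional rescaling broke the binomial case. The binomial log-PMF below, and the standalone names `p`, `n`, `binomial_coef`, are illustrative assumptions mirroring the attributes visible in the diff, not code from this repository.

import torch

num_logits = 5
logits_idx = torch.arange(num_logits, dtype=torch.float32)   # k = 0..K-1
logits_factorial = torch.exp(torch.lgamma(logits_idx + 1))   # k!
eps = 1e-6

p = torch.tensor([[0.3]])  # param_head(x) output: a scalar in [0, 1] per sample

# Poisson backbone: rescale p to a rate lambda in [0, num_logits + 1],
# then log f(k) = k * log(lambda) - lambda - log(k!)
lam = (num_logits + 1) * p
log_f_poisson = logits_idx * torch.log(lam + eps) - lam - torch.log(logits_factorial)

# Binomial backbone: use p directly, with n = num_logits - 1 trials,
# then log f(k) = log C(n, k) + k * log(p) + (n - k) * log(1 - p)
n = num_logits - 1
binomial_coef = torch.exp(
    torch.lgamma(torch.tensor(n + 1.0))
    - torch.lgamma(logits_idx + 1)
    - torch.lgamma(n - logits_idx + 1)
)  # C(n, k) via log-gamma, to avoid overflow for larger n
log_f_binomial = (
    torch.log(binomial_coef)
    + logits_idx * torch.log(p + eps)
    + (n - logits_idx) * torch.log(1 - p + eps)
)

# Both log-PMFs are unimodal in k; dividing by tau and applying softmax
# (as the rest of forward() presumably does) yields unimodal class probabilities
print(torch.softmax(log_f_poisson, dim=-1))
print(torch.softmax(log_f_binomial, dim=-1))

Had the binomial branch received the rescaled value, its "probability" could exceed 1, making log(1 - p) undefined, which is consistent with the commit moving the rescaling into the poisson branch only.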
