Rename backbone_distribution arg in unimodal constraint layer

nathanpainchaud committed Nov 1, 2023
1 parent 0928a4b commit d019bf4
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions didactic/models/layers.py
@@ -102,7 +102,7 @@ def __init__(
         self,
         in_features: int,
         num_logits: int,
-        backbone_distribution: Literal["poisson", "binomial"] = "binomial",
+        distribution: Literal["poisson", "binomial"] = "binomial",
         tau: float = 1.0,
         tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_fn",
         eps: float = 1e-6,
@@ -112,8 +112,8 @@ def __init__(
         Args:
             in_features: Number of features in the input feature vector.
             num_logits: Number of logits to output.
-            backbone_distribution: Distribution whose probability mass function (PMF) is used to enforce an unimodal
-                distribution of the logits.
+            distribution: Distribution whose probability mass function (PMF) is used to enforce an unimodal distribution
+                of the logits.
             tau: Temperature parameter to control the sharpness of the distribution.
                 - If `tau_mode` is 'fixed', this is the fixed value of tau.
                 - If `tau_mode` is 'learn' or 'learn_sigm', this is the initial value of tau.
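For reference, a minimal usage sketch of the layer after the rename. Assumptions: the class is named UnimodalLogitsHead and is importable from didactic/models/layers.py (the class name sits above the hunks shown here), and the output tuple follows the forward() docstring further down.

import torch

from didactic.models.layers import UnimodalLogitsHead  # class name assumed, not shown in this diff

# Instantiate with the renamed `distribution` argument
head = UnimodalLogitsHead(in_features=8, num_logits=5, distribution="binomial")
x = torch.randn(4, 8)  # (N, in_features) batch of feature vectors
logits, param, tau = head(x)  # unimodal logits, distribution param, temperature
print(logits.shape, param.shape, tau.shape)  # expected per the docstring: (4, 5), (4, 1), (4, 1)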
@@ -128,18 +128,18 @@ def __init__(
         """
         super().__init__()
         self.num_logits = num_logits
-        self.backbone_distribution = backbone_distribution
+        self.distribution = distribution
         self.tau_mode = tau_mode
         self.eps = eps
 
         self.register_buffer("logits_idx", torch.arange(self.num_logits))
-        match self.backbone_distribution:
+        match self.distribution:
             case "poisson":
                 self.register_buffer("logits_factorial", torch.from_numpy(factorial(self.logits_idx)))
             case "binomial":
                 self.register_buffer("binom_coef", binom(self.num_logits - 1, self.logits_idx))
             case _:
-                raise ValueError(f"Unsupported distribution '{backbone_distribution}'.")
+                raise ValueError(f"Unsupported distribution '{distribution}'.")
 
         self.param_head = nn.Sequential(nn.Linear(in_features, 1), nn.Sigmoid())
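The binomial branch only registers its coefficients here; the matching log-PMF computation sits below the shown hunk. A standalone sketch of the computation these buffers support, assuming the standard Binomial(num_logits - 1, p) form implied by binom_coef and logits_idx:

import torch
from scipy.special import binom

num_logits, eps = 5, 1e-6
logits_idx = torch.arange(num_logits)  # k = 0, ..., num_logits - 1
binom_coef = torch.from_numpy(binom(num_logits - 1, logits_idx.numpy()))

p = torch.tensor([[0.3]])  # param_head output, in [0, 1]
# log PMF of Binomial(num_logits - 1, p): log C(n, k) + k*log(p) + (n - k)*log(1 - p)
log_f = (
    torch.log(binom_coef)
    + logits_idx * torch.log(p + eps)
    + (num_logits - 1 - logits_idx) * torch.log(1 - p + eps)
)
print(log_f.exp())  # unimodal PMF over the logits, mode near round(p * (num_logits - 1))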

@@ -156,7 +156,7 @@ def __init__(
     def __repr__(self):
         """Overrides the default repr to display the important parameters of the layer."""
         vars = {"in_features": self.param_head[0].in_features}
-        vars.update({var: getattr(self, var) for var in ["num_logits", "backbone_distribution", "tau_mode"]})
+        vars.update({var: getattr(self, var) for var in ["num_logits", "distribution", "tau_mode"]})
         vars_str = [f"{var}={val}" for var, val in vars.items()]
         return f"{self.__class__.__name__}({', '.join(vars_str)})"

@@ -169,15 +169,15 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
         Returns:
             - (N, `num_logits`), Output tensor of unimodal logits. The logits are unnormalized, but the temperature
               to control the sharpness of the distribution has already been applied.
-            - (N, 1), Predicted parameter of the backbone distribution, in the range [0, 1].
+            - (N, 1), Predicted parameter of the distribution, in the range [0, 1].
             - (N, 1), Temperature parameter tau, in the range [0, inf) or [0, 1], depending on `tau_mode`.
         """
-        # Forward through the linear layer to get a scalar param in [0,1] for the backbone distribution
+        # Forward through the linear layer to get a scalar param in [0,1] for the distribution
         f_x = param = self.param_head(x)
 
-        # Compute the probability mass function (PMF) of the backbone distribution
+        # Compute the probability mass function (PMF) of the distribution
         # For technical reasons, use the log instead of the direct value
-        match self.backbone_distribution:
+        match self.distribution:
             case "poisson":
                 f_x = (self.num_logits + 1) * f_x  # Rescale f(x) to [0, num_logits+1]
                 log_f = (self.logits_idx * torch.log(f_x + self.eps)) - f_x - torch.log(self.logits_factorial)
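Replaying the Poisson branch above as standalone code. One assumption: the division of log_f by tau that yields the final temperature-scaled logits happens below the shown hunk, per the docstring note that the temperature has already been applied.

import torch
from scipy.special import factorial

num_logits, eps, tau = 5, 1e-6, 1.0
logits_idx = torch.arange(num_logits)
logits_factorial = torch.from_numpy(factorial(logits_idx.numpy()))

f_x = torch.tensor([[0.4]])  # param_head output, in [0, 1]
f_x = (num_logits + 1) * f_x  # rescale f(x) to [0, num_logits + 1], i.e. lambda = 2.4
# log PMF of Poisson(lambda): k*log(lambda) - lambda - log(k!)
log_f = (logits_idx * torch.log(f_x + eps)) - f_x - torch.log(logits_factorial)
unimodal_logits = log_f / tau  # temperature scaling (assumed placement)
print(torch.softmax(unimodal_logits, dim=-1))  # unimodal over 5 bins, mode at k = 2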
