Add custom layer to enforce unimodal distribution of output logits
Implements a 2017 ICML paper by Beckham and Pal: https://proceedings.mlr.press/v70/beckham17a.html
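
For reference, here is the construction as implemented in this commit: a linear layer followed by a sigmoid predicts a scalar in [0, 1] from the input, that scalar parameterizes the PMF of the chosen distribution over the K = `num_logits` classes, and the tempered log-PMF serves as the logit vector. In the summary below (this commit's parameterization, not necessarily the paper's exact one), L is the learned linear layer, \sigma the sigmoid and \tau the temperature:

\[
f_{\text{Poisson}}(k) = \frac{\lambda^{k} e^{-\lambda}}{k!}, \qquad \lambda = (K + 1)\,\sigma(L(x)),
\]
\[
f_{\text{Binomial}}(k) = \binom{K - 1}{k}\, p^{k} (1 - p)^{K - 1 - k}, \qquad p = \sigma(L(x)),
\]
\[
\text{logit}_{k} = \frac{\log f(k)}{\tau}, \qquad k = 0, \dots, K - 1.
\]

Since both PMFs are unimodal in k and dividing by a positive temperature preserves the ordering of the logits, a softmax applied downstream yields a unimodal distribution over the classes.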
nathanpainchaud committed Oct 31, 2023
1 parent 8f7f6f7 commit 5fd28b4
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions didactic/models/layers.py
@@ -1,4 +1,7 @@
from typing import Literal, Tuple

import torch
from scipy.special import binom, factorial
from torch import Tensor, nn
from torch.nn import Parameter
from torch.nn import functional as F
@@ -57,3 +60,123 @@ def forward(self, x: Tensor) -> Tensor:
broadcast_attn_vector = attn_vector.transpose(2, 1) # (N, S, 1) -> (N, 1, S)
pooled_x = (broadcast_attn_vector @ x).squeeze(1) # (N, 1, S) @ (N, S, E) -> (N, E)
return pooled_x


class UnimodalLogitsHead(nn.Module):
"""Layer to output (enforced) unimodal logits from an input feature vector.
This is a re-implementation of a 2017 ICML paper by Beckham and Pal, which proposes to use either a Poisson or
binomial distribution to output unimodal logits (because they are constrained as such by the distribution) from a
scalar value.
References:
- ICML 2017 paper: https://proceedings.mlr.press/v70/beckham17a.html
"""

def __init__(
self,
in_features: int,
num_logits: int,
backbone_distribution: Literal["poisson", "binomial"] = "poisson",
tau: float = 1.0,
tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_sigm",
eps: float = 1e-6,
):
"""Initializes class instance.
Args:
in_features: Number of features in the input feature vector.
num_logits: Number of (tempered) logits to output.
            backbone_distribution: Distribution whose probability mass function (PMF) is used to enforce a unimodal
distribution of the logits.
tau: Temperature parameter to control the sharpness of the distribution.
- If `tau_mode` is 'fixed', this is the fixed value of tau.
- If `tau_mode` is 'learn' or 'learn_sigm', this is the initial value of tau.
- If `tau_mode` is 'learn_fn', this argument is ignored.
tau_mode: Method to use to set or learn the temperature parameter:
- 'fixed': Use a fixed value of tau.
- 'learn': Learn tau.
- 'learn_sigm': Learn tau through a sigmoid function.
- 'learn_fn': Learn tau through a function of the input, i.e. a tau that varies for each input.
                The function is 1 / (1 + g(L(x))), where g is the softplus function and L is a linear layer.
eps: Epsilon value to use in probabilities' log to avoid numerical instability.
"""
super().__init__()
self.num_logits = num_logits
self.backbone_distribution = backbone_distribution
self.tau_mode = tau_mode
self.eps = eps

self.register_buffer("logits_idx", torch.arange(self.num_logits))
match self.backbone_distribution:
case "poisson":
self.register_buffer("logits_factorial", torch.from_numpy(factorial(self.logits_idx)))
case "binomial":
self.register_buffer("binom_coef", binom(self.num_logits - 1, self.logits_idx))
case _:
raise ValueError(f"Unsupported distribution '{backbone_distribution}'.")

self.param_head = nn.Sequential(nn.Linear(in_features, 1), nn.Sigmoid())

match self.tau_mode:
case "fixed":
self.tau = tau
case "learn" | "learn_sigm":
self.tau = nn.Parameter(torch.tensor(float(tau)))
case "learn_fn":
self.tau_head = nn.Sequential(nn.Linear(in_features, 1), nn.Softplus())
case _:
raise ValueError(f"Unsupported tau mode '{tau_mode}'.")

def __repr__(self):
"""Overrides the default repr to display the important parameters of the layer."""
vars = {"in_features": self.param_head[0].in_features}
vars.update({var: getattr(self, var) for var in ["num_logits", "backbone_distribution", "tau_mode"]})
vars_str = [f"{var}={val}" for var, val in vars.items()]
return f"{self.__class__.__name__}({', '.join(vars_str)})"

def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Predicts unnormalized, unimodal logits from a feature vector input.
Args:
x: (N, `in_features`), Batch of feature vectors.
Returns:
- (N, `num_logits`), Output tensor of unimodal logits, where `C` is the number of classes. The logits are
unnormalized, but the temperature to control the sharpness of the distribution as already been applied.
- (N, 1), Predicted parameter of the backbone distribution, in the range [0, 1].
- (N, 1), Temperature parameter tau, in the range [0, inf) or [0, 1], depending on `tau_mode`.
"""
        # Forward through the linear layer to get a scalar param in [0,1] for the backbone distribution
        param = self.param_head(x)

        # Compute the probability mass function (PMF) of the backbone distribution
        # For numerical stability, compute the log of the PMF rather than the PMF itself
        match self.backbone_distribution:
            case "poisson":
                f_x = (self.num_logits + 1) * param  # Rescale the rate parameter to [0, num_logits+1]
                log_f = (self.logits_idx * torch.log(f_x + self.eps)) - f_x - torch.log(self.logits_factorial)
            case "binomial":
                f_x = param  # Keep the success probability in [0,1], so that the log of (1 - f_x) stays defined
                log_f = (
                    torch.log(self.binom_coef)
                    + (self.logits_idx * torch.log(f_x + self.eps))
                    + ((self.num_logits - 1 - self.logits_idx) * torch.log(1 - f_x + self.eps))
                )

# Compute the temperature parameter tau
        # When tau is a scalar, manually broadcast it to a tensor with one value for each item in the batch
        # This keeps the API consistent across tau modes, since tau already varies per item when `tau_mode` is
        # 'learn_fn'
match self.tau_mode:
case "fixed":
tau = torch.full_like(param, self.tau) # Manual broadcast
case "learn":
tau = self.tau.expand(*param.shape) # Manual broadcast
case "learn_sigm":
tau = torch.sigmoid(self.tau).expand(*param.shape) # Sigmoid + manual broadcast
case "learn_fn":
tau = 1 / (1 + self.tau_head(x))
case _:
raise ValueError(f"Unsupported 'tau_mode': '{self.tau_mode}'.")

return log_f / tau, param, tau
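
A minimal usage sketch (not part of the commit; the feature size, class count and training wiring below are illustrative assumptions): the head maps a batch of feature vectors to tempered, unimodal, unnormalized logits, which plug directly into a standard cross-entropy loss.

import torch
from torch.nn import functional as F

from didactic.models.layers import UnimodalLogitsHead

head = UnimodalLogitsHead(in_features=32, num_logits=5, backbone_distribution="poisson")
x = torch.randn(8, 32)  # (N, in_features) batch of feature vectors
logits, param, tau = head(x)  # (N, 5) unimodal logits, (N, 1) distribution parameter, (N, 1) temperature

# The logits are unnormalized, so cross-entropy (which applies its own softmax) consumes them directly
target = torch.randint(0, 5, (8,))
loss = F.cross_entropy(logits, target)

# Softmax preserves ordering, so the resulting class probabilities rise to a single peak, then fall
probs = logits.softmax(dim=-1)  # (N, 5)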
