Skip to content

Commit

Permalink
Standardize name of out_features in UnimodalLogitsHead
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Nov 1, 2023
1 parent efb1330 commit 30b7dcf
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions didactic/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class UnimodalLogitsHead(nn.Module):
def __init__(
self,
in_features: int,
num_logits: int,
out_features: int,
backbone_distribution: Literal["poisson", "binomial"] = "poisson",
tau: float = 1.0,
tau_mode: Literal["fixed", "learn", "learn_sigm", "learn_fn"] = "learn_sigm",
Expand All @@ -86,7 +86,7 @@ def __init__(
Args:
in_features: Number of features in the input feature vector.
num_logits: Number of (tempered) logits to output.
out_features: Number of logits to output.
backbone_distribution: Distribution whose probability mass function (PMF) is used to enforce an unimodal
distribution of the logits.
tau: Temperature parameter to control the sharpness of the distribution.
Expand All @@ -102,7 +102,7 @@ def __init__(
eps: Epsilon value to use in probabilities' log to avoid numerical instability.
"""
super().__init__()
self.num_logits = num_logits
self.num_logits = out_features
self.backbone_distribution = backbone_distribution
self.tau_mode = tau_mode
self.eps = eps
Expand Down Expand Up @@ -142,7 +142,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
x: (N, `in_features`), Batch of feature vectors.
Returns:
- (N, `num_logits`), Output tensor of unimodal logits, where `C` is the number of classes. The logits are
- (N, `out_features`), Output tensor of unimodal logits, where `C` is the number of classes. The logits are
unnormalized, but the temperature to control the sharpness of the distribution as already been applied.
- (N, 1), Predicted parameter of the backbone distribution, in the range [0, 1].
- (N, 1), Temperature parameter tau, in the range [0, inf) or [0, 1], depending on `tau_mode`.
Expand Down

0 comments on commit 30b7dcf

Please sign in to comment.