Skip to content

Commit

Permalink
Fix DINOProjectionHead BatchNorm Handling
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Nov 8, 2024
1 parent 94050b3 commit 1c560e0
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions lightly/models/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,12 +673,20 @@ def __init__(
norm_last_layer: bool = True,
):
"""Initializes the DINOProjectionHead with the specified dimensions."""
bn = nn.BatchNorm1d(hidden_dim) if batch_norm else None

super().__init__(
[
(input_dim, hidden_dim, bn, nn.GELU()),
(hidden_dim, hidden_dim, bn, nn.GELU()),
(
input_dim,
hidden_dim,
nn.BatchNorm1d(hidden_dim) if batch_norm else None,
nn.GELU(),
),
(
hidden_dim,
hidden_dim,
nn.BatchNorm1d(hidden_dim) if batch_norm else None,
nn.GELU(),
),
(hidden_dim, bottleneck_dim, None, None),
]
)
Expand Down

0 comments on commit 1c560e0

Please sign in to comment.