Skip to content

Commit

Permalink
Extract default FT-Transformer prediction head into its own module
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Nov 1, 2023
1 parent 46330b1 commit c35ce65
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
25 changes: 25 additions & 0 deletions didactic/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,31 @@ def forward(self, x: Tensor) -> Tensor:
return pooled_x


class FTPredictionHead(nn.Module):
    """Default prediction head of the Feature Tokenizer transformer (FT-Transformer) paper.

    The head chains, in order, a layer normalization, a ReLU activation and a linear projection.
    """

    def __init__(self, in_features: int, out_features: int):
        """Builds the layers making up the prediction head.

        Args:
            in_features: Size of the feature vectors fed to the head.
            out_features: Size of the vectors produced by the head.
        """
        super().__init__()

        # Layers are stored in a single `nn.Sequential` under `head`, in this exact order, so that the parameters
        # keep the names `head.0.*` (normalization) and `head.2.*` (linear projection).
        normalization = nn.LayerNorm(in_features)
        activation = nn.ReLU()
        projection = nn.Linear(in_features, out_features)
        self.head = nn.Sequential(normalization, activation, projection)

    def forward(self, x: Tensor) -> Tensor:
        """Maps a batch of feature vectors to unnormalized output features.

        Args:
            x: (N, `in_features`), Batch of feature vectors.

        Returns:
            (N, `out_features`), Batch of output features.
        """
        prediction = self.head(x)
        return prediction


class UnimodalLogitsHead(nn.Module):
"""Layer to output (enforced) unimodal logits from an input feature vector.
Expand Down
6 changes: 2 additions & 4 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

from didactic.models.layers import PositionalEncoding, SequentialPooling, UnimodalLogitsHead
from didactic.models.layers import FTPredictionHead, PositionalEncoding, SequentialPooling, UnimodalLogitsHead

CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute]

Expand Down Expand Up @@ -319,9 +319,7 @@ def configure_model(
self.hparams.embed_dim, output_size, **self.hparams.unimodal_head_kwargs
)
else:
prediction_heads[target_clinical_attr] = nn.Sequential(
nn.LayerNorm(self.hparams.embed_dim), nn.ReLU(), nn.Linear(self.hparams.embed_dim, output_size)
)
prediction_heads[target_clinical_attr] = FTPredictionHead(self.hparams.embed_dim, output_size)

return encoder, contrastive_head, prediction_heads

Expand Down

0 comments on commit c35ce65

Please sign in to comment.