From 1520968e645ab614d37baddb5fa854403a515e9f Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Wed, 1 Nov 2023 16:24:15 +0100 Subject: [PATCH] Extract default FT-Transformer prediction head into its own module --- didactic/models/layers.py | 25 +++++++++++++++++++ .../cardiac_multimodal_representation.py | 6 ++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/didactic/models/layers.py b/didactic/models/layers.py index a5e2e64f..17bd2868 100644 --- a/didactic/models/layers.py +++ b/didactic/models/layers.py @@ -62,6 +62,31 @@ def forward(self, x: Tensor) -> Tensor: return pooled_x +class FTPredictionHead(nn.Module): + """Prediction head architecture described in the Feature Tokenizer transformer (FT-Transformer) paper.""" + + def __init__(self, in_features: int, out_features: int): + """Initializes class instance. + + Args: + in_features: Number of features in the input feature vector. + out_features: Number of features to output. + """ + super().__init__() + self.head = nn.Sequential(nn.LayerNorm(in_features), nn.ReLU(), nn.Linear(in_features, out_features)) + + def forward(self, x: Tensor) -> Tensor: + """Predicts unnormalized features from a feature vector input. + + Args: + x: (N, `in_features`), Batch of feature vectors. + + Returns: + - (N, `out_features`), Batch of output features. + """ + return self.head(x) + + class UnimodalLogitsHead(nn.Module): """Layer to output (enforced) unimodal logits from an input feature vector. diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index ea43017b..ffa01696 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -20,7 +20,7 @@ from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data -from didactic.models.layers import PositionalEncoding, SequentialPooling, UnimodalLogitsHead +from didactic.models.layers import FTPredictionHead, PositionalEncoding, SequentialPooling, UnimodalLogitsHead CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] @@ -319,9 +319,7 @@ def configure_model( self.hparams.embed_dim, output_size, **self.hparams.unimodal_head_kwargs ) else: - prediction_heads[target_clinical_attr] = nn.Sequential( - nn.LayerNorm(self.hparams.embed_dim), nn.ReLU(), nn.Linear(self.hparams.embed_dim, output_size) - ) + prediction_heads[target_clinical_attr] = FTPredictionHead(self.hparams.embed_dim, output_size) return encoder, contrastive_head, prediction_heads