diff --git a/didactic/config/experiment/cardinal/multimodal-xformer.yaml b/didactic/config/experiment/cardinal/multimodal-xformer.yaml
index 0dfde958..a14a93ec 100644
--- a/didactic/config/experiment/cardinal/multimodal-xformer.yaml
+++ b/didactic/config/experiment/cardinal/multimodal-xformer.yaml
@@ -3,6 +3,7 @@
 defaults:
   - /task/img_tokenizer/model: linear-embedding
   - /task/model/encoder: ???
+  - /task/model/contrastive_head: mlp
   - override /task/model: null # Set this to null because we specify multiple submodels instead of a singleton model
   - override /task/optim: adamw
   - override /data: cardinal
diff --git a/didactic/config/task/model/contrastive_head/mlp.yaml b/didactic/config/task/model/contrastive_head/mlp.yaml
new file mode 100644
index 00000000..3e4dbe05
--- /dev/null
+++ b/didactic/config/task/model/contrastive_head/mlp.yaml
@@ -0,0 +1,7 @@
+# Build the projection head as an MLP with a single hidden layer and constant width, as proposed in
+# https://arxiv.org/abs/2106.15147
+_target_: vital.models.classification.mlp.MLP
+input_shape: [${task.embed_dim}]
+output_shape: [${task.embed_dim}]
+hidden: [${task.embed_dim}]
+dropout: 0
diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py
index 2ef3b810..58229a54 100644
--- a/didactic/tasks/cardiac_multimodal_representation.py
+++ b/didactic/tasks/cardiac_multimodal_representation.py
@@ -17,7 +17,6 @@
 from vital.data.cardinal.config import View as ViewEnum
 from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes
 from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
-from vital.models.classification.mlp import MLP
 from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 
@@ -293,13 +292,10 @@ def configure_model(
         # Build the transformer encoder
         encoder = hydra.utils.instantiate(self.hparams.model.encoder)
 
-        # Build the projection head as an MLP with a single hidden layer and constant width, as proposed in
-        # https://arxiv.org/abs/2106.15147
+        # Build the projection head for contrastive learning, if contrastive learning is enabled
         contrastive_head = None
         if self.contrastive_loss:
-            contrastive_head = MLP(
-                (self.hparams.embed_dim,), (self.hparams.embed_dim,), hidden=(self.hparams.embed_dim,), dropout=0
-            )
+            contrastive_head = hydra.utils.instantiate(self.hparams.model.contrastive_head)
 
         # Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in
         # https://arxiv.org/pdf/2106.11959
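
Note on the new config group: the sketch below is not part of the patch; it only illustrates how the `${task.embed_dim}` interpolations in `task/model/contrastive_head/mlp.yaml` are expected to resolve once the file is composed under the full task config, assuming standard OmegaConf interpolation syntax. The `embed_dim` value of 8 is a made-up example, and `_target_` is omitted so the snippet runs with OmegaConf alone, without the `vital` package installed.

# Minimal sketch (not from the PR): resolution of the contrastive_head config's interpolations.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "task": {
            "embed_dim": 8,  # hypothetical value; set by the experiment/task config in practice
            "model": {
                "contrastive_head": {
                    # Mirrors the new mlp.yaml, minus _target_ so no extra dependency is needed
                    "input_shape": ["${task.embed_dim}"],
                    "output_shape": ["${task.embed_dim}"],
                    "hidden": ["${task.embed_dim}"],
                    "dropout": 0,
                }
            },
        }
    }
)

# Interpolations resolve against the root config, so the head's widths track the encoder's embed_dim
print(OmegaConf.to_container(cfg.task.model.contrastive_head, resolve=True))
# -> {'input_shape': [8], 'output_shape': [8], 'hidden': [8], 'dropout': 0}

Factoring the projection head out into its own `contrastive_head` config group means alternative heads can be swapped in through Hydra overrides, without editing the hardcoded MLP call that the task previously built itself.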