Skip to content

Commit

Permalink
Extract contrastive head configuration from code into its own Hydra c…
Browse files Browse the repository at this point in the history
…onfig
  • Loading branch information
nathanpainchaud committed Nov 1, 2023
1 parent 47f2902 commit efb1330
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
defaults:
- /task/img_tokenizer/model: linear-embedding
- /task/model/encoder: ???
- /task/model/contrastive_head: mlp
- override /task/model: null # Set this to null because we specify multiple submodels instead of a singleton model
- override /task/optim: adamw
- override /data: cardinal
Expand Down
7 changes: 7 additions & 0 deletions didactic/config/task/model/contrastive_head/mlp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Build the projection head as an MLP with a single hidden layer and constant width, as proposed in
# https://arxiv.org/abs/2106.15147
_target_: vital.models.classification.mlp.MLP
input_shape: [task.embed_dim]
output_shape: [task.embed_dim]
hidden: [task.embed_dim]
dropout: 0
8 changes: 2 additions & 6 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from vital.data.cardinal.config import View as ViewEnum
from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes
from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
from vital.models.classification.mlp import MLP
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

Expand Down Expand Up @@ -293,13 +292,10 @@ def configure_model(
# Build the transformer encoder
encoder = hydra.utils.instantiate(self.hparams.model.encoder)

# Build the projection head as an MLP with a single hidden layer and constant width, as proposed in
# https://arxiv.org/abs/2106.15147
# Build the projection head for contrastive learning, if contrastive learning is enabled
contrastive_head = None
if self.contrastive_loss:
contrastive_head = MLP(
(self.hparams.embed_dim,), (self.hparams.embed_dim,), hidden=(self.hparams.embed_dim,), dropout=0
)
contrastive_head = hydra.utils.instantiate(self.hparams.model.contrastive_head)

# Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in
# https://arxiv.org/pdf/2106.11959
Expand Down

0 comments on commit efb1330

Please sign in to comment.