From f606d57b0a9fcacf9324c95d2ed7053ee48102c2 Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Wed, 1 Nov 2023 16:42:06 +0100 Subject: [PATCH] Change instantiation of prediction heads to make them configurable via Hydra --- .../cardinal/multimodal-xformer.yaml | 2 ++ .../model/prediction_head/ft-prediction.yaml | 2 ++ .../model/prediction_head/unimodal-logits.yaml | 5 +++++ .../tasks/cardiac_multimodal_representation.py | 18 +++++++----------- 4 files changed, 16 insertions(+), 11 deletions(-) create mode 100644 didactic/config/task/model/prediction_head/ft-prediction.yaml create mode 100644 didactic/config/task/model/prediction_head/unimodal-logits.yaml diff --git a/didactic/config/experiment/cardinal/multimodal-xformer.yaml b/didactic/config/experiment/cardinal/multimodal-xformer.yaml index a14a93ec..f9af13b6 100644 --- a/didactic/config/experiment/cardinal/multimodal-xformer.yaml +++ b/didactic/config/experiment/cardinal/multimodal-xformer.yaml @@ -4,6 +4,8 @@ defaults: - /task/img_tokenizer/model: linear-embedding - /task/model/encoder: ??? - /task/model/contrastive_head: mlp + - /task/model/prediction_head: ft-prediction + - /task/model/prediction_head@task.model.ordinal_head: unimodal-logits - override /task/model: null # Set this to null because we specify multiple submodels instead of a singleton model - override /task/optim: adamw - override /data: cardinal diff --git a/didactic/config/task/model/prediction_head/ft-prediction.yaml b/didactic/config/task/model/prediction_head/ft-prediction.yaml new file mode 100644 index 00000000..89b1ff58 --- /dev/null +++ b/didactic/config/task/model/prediction_head/ft-prediction.yaml @@ -0,0 +1,2 @@ +_target_: didactic.models.layers.FTPredictionHead +in_features: ${task.embed_dim} diff --git a/didactic/config/task/model/prediction_head/unimodal-logits.yaml b/didactic/config/task/model/prediction_head/unimodal-logits.yaml new file mode 100644 index 00000000..b747ac42 --- /dev/null +++ b/didactic/config/task/model/prediction_head/unimodal-logits.yaml @@ -0,0 +1,5 @@ +_target_: didactic.models.layers.UnimodalLogitsHead +in_features: ${task.embed_dim} +backbone_distribution: binomial +tau: 1 +tau_mode: learn_fn diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index 01027441..787907ba 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -20,7 +20,7 @@ from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data -from didactic.models.layers import FTPredictionHead, PositionalEncoding, SequentialPooling, UnimodalLogitsHead +from didactic.models.layers import PositionalEncoding, SequentialPooling CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] @@ -36,7 +36,6 @@ def __init__( views: Sequence[ViewEnum] = tuple(ViewEnum), predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, ordinal_mode: bool = True, - unimodal_head_kwargs: Dict[str, Any] | DictConfig = None, contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None, contrastive_loss_weight: float = 0, clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, @@ -57,7 +56,6 @@ def __init__( ordinal_mode: Whether to consider applicable targets as ordinal variables, which means: - Applying a constraint to enforce an unimodal softmax output from the prediction heads; - Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax. - unimodal_head_kwargs: Keyword arguments to forward to the initialization of the unimodal prediction heads. contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections of feature vectors, in a contrastive learning step that follows the SCARF pretraining. (see ref: https://arxiv.org/abs/2106.15147) @@ -89,10 +87,6 @@ def __init__( if not isinstance(mtr_p, (int, float)): mtr_p = tuple(mtr_p) - # If kwargs are null, set them to empty dict - if unimodal_head_kwargs is None: - unimodal_head_kwargs = {} - if contrastive_loss is None and predict_losses is None: raise ValueError( "You should provide at least one of `contrastive_loss` or `predict_losses`. Providing only " @@ -314,12 +308,14 @@ def configure_model( output_size = 1 if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs(): - # For ordinal targets, use a custom prediction head to constraint the distribution of logits - prediction_heads[target_clinical_attr] = UnimodalLogitsHead( - self.hparams.embed_dim, output_size, **self.hparams.unimodal_head_kwargs + # For ordinal targets, use a separate prediction head config + prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + self.hparams.model.ordinal_head, num_logits=output_size ) else: - prediction_heads[target_clinical_attr] = FTPredictionHead(self.hparams.embed_dim, output_size) + prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + self.hparams.model.prediction_head, out_features=output_size + ) return encoder, contrastive_head, prediction_heads