Skip to content

Commit

Permalink
Change instantiation of prediction heads to make them configurable vi…
Browse files Browse the repository at this point in the history
…a Hydra
  • Loading branch information
nathanpainchaud committed Nov 1, 2023
1 parent 1520968 commit 8841a46
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
2 changes: 2 additions & 0 deletions didactic/config/experiment/cardinal/multimodal-xformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ defaults:
- /task/img_tokenizer/model: linear-embedding
- /task/model/encoder: ???
- /task/model/contrastive_head: mlp
- /task/model/prediction_head: ft-prediction
- /task/model/[email protected]_head: unimodal-logits
- override /task/model: null # Set this to null because we specify multiple submodels instead of a singleton model
- override /task/optim: adamw
- override /data: cardinal
Expand Down
2 changes: 2 additions & 0 deletions didactic/config/task/model/prediction_head/ft-prediction.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: didactic.models.layers.FTPredictionHead
in_features: ${task.embed_dim}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: didactic.models.layers.UnimodalLogitsHead
in_features: ${task.embed_dim}
backbone_distribution: binomial
tau: 1
tau_mode: learn_fn
18 changes: 7 additions & 11 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

from didactic.models.layers import FTPredictionHead, PositionalEncoding, SequentialPooling, UnimodalLogitsHead
from didactic.models.layers import PositionalEncoding, SequentialPooling

CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute]

Expand All @@ -36,7 +36,6 @@ def __init__(
views: Sequence[ViewEnum] = tuple(ViewEnum),
predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None,
ordinal_mode: bool = True,
unimodal_head_kwargs: Dict[str, Any] | DictConfig = None,
contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None,
contrastive_loss_weight: float = 0,
clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None,
Expand All @@ -57,7 +56,6 @@ def __init__(
ordinal_mode: Whether to consider applicable targets as ordinal variables, which means:
- Applying a constraint to enforce an unimodal softmax output from the prediction heads;
- Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax.
unimodal_head_kwargs: Keyword arguments to forward to the initialization of the unimodal prediction heads.
contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections
of feature vectors, in a contrastive learning step that follows the SCARF pretraining.
(see ref: https://arxiv.org/abs/2106.15147)
Expand Down Expand Up @@ -89,10 +87,6 @@ def __init__(
if not isinstance(mtr_p, (int, float)):
mtr_p = tuple(mtr_p)

# If kwargs are null, set them to empty dict
if unimodal_head_kwargs is None:
unimodal_head_kwargs = {}

if contrastive_loss is None and predict_losses is None:
raise ValueError(
"You should provide at least one of `contrastive_loss` or `predict_losses`. Providing only "
Expand Down Expand Up @@ -314,12 +308,14 @@ def configure_model(
output_size = 1

if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs():
# For ordinal targets, use a custom prediction head to constraint the distribution of logits
prediction_heads[target_clinical_attr] = UnimodalLogitsHead(
self.hparams.embed_dim, output_size, **self.hparams.unimodal_head_kwargs
# For ordinal targets, use a separate prediction head config
prediction_heads[target_clinical_attr] = hydra.utils.instantiate(
self.hparams.model.unimodal_head, num_logits=output_size
)
else:
prediction_heads[target_clinical_attr] = FTPredictionHead(self.hparams.embed_dim, output_size)
prediction_heads[target_clinical_attr] = hydra.utils.instantiate(
self.hparams.model.prediction_head, out_features=output_size
)

return encoder, contrastive_head, prediction_heads

Expand Down

0 comments on commit 8841a46

Please sign in to comment.