From 14f6903f206f0c9aae1bda17cd52041651d1e8cc Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Wed, 1 Nov 2023 00:44:50 +0100 Subject: [PATCH] Add optional custom prediction head for ordinal targets to enforce unimodal logits distribution Custom prediction heads also output parameter of the unimodal distribution and softmax temperature, on top of the logits. Thus, adapted downstream prediction logging pipeline to also log these parameters. --- didactic/data/cardinal/predict.py | 27 +++- .../cardiac_multimodal_representation.py | 139 ++++++++++++------ 2 files changed, 118 insertions(+), 48 deletions(-) diff --git a/didactic/data/cardinal/predict.py b/didactic/data/cardinal/predict.py index bd9aded5..204bb4c7 100644 --- a/didactic/data/cardinal/predict.py +++ b/didactic/data/cardinal/predict.py @@ -217,8 +217,17 @@ def _write_features_plots( predictions: Sequences of encoder output features and predicted clinical attribute for each patient. There is one sublist for each prediction dataloader provided. """ + prediction_example = predictions[0][0] # 1st: subset, 2nd: batch + # Pre-compute the list of attributes for which we have a unimodal parameter, since this output might be None + # and we don't want to access it in that case + ordinal_attrs = list(prediction_example[2]) if prediction_example[2] else [] features = { - (subset, patient.id, *[patient.attrs.get(attr) for attr in self._hue_attrs]): patient_prediction[0] + ( + subset, + patient.id, + *[patient.attrs.get(attr) for attr in self._hue_attrs], + *[patient_prediction[output_idx][attr].item() for attr in ordinal_attrs for output_idx in (2, 3)], + ): patient_prediction[0] .flatten() .cpu() .numpy() @@ -233,7 +242,12 @@ def _write_features_plots( features.values(), index=pd.MultiIndex.from_tuples( features.keys(), - names=["subset", "patient", *self._hue_attrs], + names=[ + "subset", + "patient", + *self._hue_attrs, + *[f"{attr}_unimodal_{pred_desc}" for attr in ordinal_attrs for pred_desc in ("param", "tau")], + ], ), ) @@ -283,15 +297,14 @@ def _write_prediction_scores( # Compute the loss on the predictions for all the patients of the subset subset_categorical_data, subset_numerical_data = [], [] - for (patient_id, patient), (_, patient_predictions) in zip(subset_patients.items(), subset_predictions): + for (patient_id, patient), patient_predictions in zip(subset_patients.items(), subset_predictions): + attr_predictions = patient_predictions[1] if target_categorical_attrs: patient_categorical_data = {"patient": patient_id} for attr in target_categorical_attrs: patient_categorical_data.update( { - f"{attr}_prediction": CLINICAL_CAT_ATTR_LABELS[attr][ - patient_predictions[attr].argmax() - ], + f"{attr}_prediction": CLINICAL_CAT_ATTR_LABELS[attr][attr_predictions[attr].argmax()], f"{attr}_target": patient.attrs.get(attr, np.nan), } ) @@ -302,7 +315,7 @@ def _write_prediction_scores( for attr in target_numerical_attrs: patient_numerical_data.update( { - f"{attr}_prediction": patient_predictions[attr].item(), + f"{attr}_prediction": attr_predictions[attr].item(), f"{attr}_target": patient.attrs.get(attr, np.nan), } ) diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index 5a7d8127..b3cee855 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -21,7 +21,7 @@ from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data -from didactic.models.layers import PositionalEncoding, SequentialPooling +from didactic.models.layers import PositionalEncoding, SequentialPooling, UnimodalLogitsHead CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] @@ -36,6 +36,8 @@ def __init__( img_attrs: Sequence[ImageAttribute], views: Sequence[ViewEnum] = tuple(ViewEnum), predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, + ordinal_mode: bool = True, + unimodal_head_kwargs: Dict[str, Any] | DictConfig = None, contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None, contrastive_loss_weight: float = 0, clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, @@ -53,6 +55,10 @@ def __init__( embed_dim: Size of the tokens/embedding for all the modalities. predict_losses: Supervised criteria to measure the error between the predicted attributes and their real value. + ordinal_mode: Whether to consider applicable targets as ordinal variables, which means: + - Applying a constraint to enforce an unimodal softmax output from the prediction heads; + - Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax. + unimodal_head_kwargs: Keyword arguments to forward to the initialization of the unimodal prediction heads. contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections of feature vectors, in a contrastive learning step that follows the SCARF pretraining. (see ref: https://arxiv.org/abs/2106.15147) @@ -84,6 +90,10 @@ def __init__( if not isinstance(mtr_p, (int, float)): mtr_p = tuple(mtr_p) + # If kwargs are null, set them to empty dict + if unimodal_head_kwargs is None: + unimodal_head_kwargs = {} + if contrastive_loss is None and predict_losses is None: raise ValueError( "You should provide at least one of `contrastive_loss` or `predict_losses`. Providing only " @@ -299,22 +309,27 @@ def configure_model( # https://arxiv.org/pdf/2106.11959 prediction_heads = None if self.predict_losses: - prediction_heads = {} + prediction_heads = nn.ModuleDict() for target_clinical_attr in self.predict_losses: - if target_clinical_attr in ClinicalAttribute.categorical_attrs(): - if target_clinical_attr in ClinicalAttribute.binary_attrs(): - # Binary classification target - output_size = 1 - else: - # Multi-class classification target - output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr]) + if ( + target_clinical_attr in ClinicalAttribute.categorical_attrs() + and target_clinical_attr not in ClinicalAttribute.binary_attrs() + ): + # Multi-class classification target + output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr]) else: - # Regression target + # Binary classification or regression target output_size = 1 - prediction_heads[target_clinical_attr] = nn.Sequential( - nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size) - ) - prediction_heads = nn.ModuleDict(prediction_heads) + + if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs(): + # For ordinal targets, use a custom prediction head to constraint the distribution of logits + prediction_heads[target_clinical_attr] = UnimodalLogitsHead( + num_features, output_size, **self.hparams.unimodal_head_kwargs + ) + else: + prediction_heads[target_clinical_attr] = nn.Sequential( + nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size) + ) return encoder, contrastive_head, prediction_heads @@ -442,7 +457,7 @@ def forward( self, clinical_attrs: Dict[ClinicalAttribute, Tensor], img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor], - task: Literal["encode", "predict"] = "encode", + task: Literal["encode", "predict", "unimodal_param", "unimodal_tau"] = "encode", ) -> Tensor | Dict[ClinicalAttribute, Tensor]: """Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head. @@ -457,21 +472,50 @@ def forward( Returns: if `task` == 'encode': (N, E) | (N, S * E), Batch of features extracted by the encoder. + if `task` == 'unimodal_param`: + ? * (M), Parameter of the unimodal logits distribution for ordinal targets. + if `task` == 'unimodal_tau`: + ? * (M), Temperature used to control the sharpness of the unimodal logits distribution for ordinal + targets. if `task` == 'predict' (and the model includes prediction heads): ? * (N), Prediction for each target in `losses`. """ - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) - out = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E) + if task != "encode" and not self.prediction_heads: + raise ValueError( + "You requested to perform a prediction task, but the model does not include any prediction heads." + ) + if task in ["unimodal_param", "unimodal_tau"] and not self.hparams.ordinal_mode: + raise ValueError( + "You requested to obtain some parameters of the unimodal softmax for ordinal attributes, but the model " + "is not configured to predict unimodal ordinal targets. Either set `ordinal_mode` to `True` or change " + "the requested inference task." + ) - if task in ["predict"]: - if not self.prediction_heads: - raise ValueError( - "You requested to perform a prediction task, but the model does not include any prediction heads." - ) + in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) + out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E) - out = {attr: prediction_head(out).squeeze(dim=1) for attr, prediction_head in self.prediction_heads.items()} + # Early return if requested task requires no prediction heads + if task == "encode": + return out_features - return out + # Forward pass through each target's prediction head + predictions = {attr: prediction_head(out_features) for attr, prediction_head in self.prediction_heads.items()} + + # Based on the requested task, extract and format the appropriate output of the prediction heads + match task: + case "predict": + if self.hparams.ordinal_mode: + predictions = {attr: pred[0] for attr, pred in predictions.items()} + case "unimodal_param": + predictions = {attr: pred[1] for attr, pred in predictions.items()} + case "unimodal_tau": + predictions = {attr: pred[2] for attr, pred in predictions.items()} + case _: + raise ValueError(f"Unknown task '{task}'.") + + # Squeeze out the singleton dimension from the predictions' features (only relevant for scalar predictions) + predictions = {attr: prediction.squeeze(dim=1) for attr, prediction in predictions.items()} + return predictions def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: # Extract clinical and image attributes from the batch @@ -498,10 +542,13 @@ def _prediction_shared_step( self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor ) -> Dict[str, Tensor]: # Forward pass through each target's prediction head - predictions = { - attr: prediction_head(out_features).squeeze(dim=1) - for attr, prediction_head in self.prediction_heads.items() - } + predictions = {} + for attr, prediction_head in self.prediction_heads.items(): + pred = prediction_head(out_features) + if self.hparams.ordinal_mode and attr in ClinicalAttribute.ordinal_attrs(): + # For ordinal targets, extract the logits from the multiple outputs of unimodal logits head + pred = pred[0] + predictions[attr] = pred.squeeze(dim=1) # Compute the loss/metrics for each target attribute, ignoring items for which targets are missing losses, metrics = {}, {} @@ -547,7 +594,12 @@ def _contrastive_shared_step( @torch.inference_mode() def predict_step( # noqa: D102 self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0 - ) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]]]: + ) -> Tuple[ + Tensor, + Optional[Dict[ClinicalAttribute, Tensor]], + Optional[Dict[ClinicalAttribute, Tensor]], + Optional[Dict[ClinicalAttribute, Tensor]], + ]: # Extract clinical and image attributes from the patient and add batch dimension clinical_attrs = { attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs @@ -559,22 +611,27 @@ def predict_step( # noqa: D102 ).items() } - # Forward pass through the encoder - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) - out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E) + # Encoder's output + out_features = self(clinical_attrs, img_attrs) - # If the network includes prediction heads, forward pass through them + # If the model has targets to predict, output the predictions predictions = None if self.prediction_heads: - predictions = { - attr: prediction_head(out_features).squeeze(dim=1) - for attr, prediction_head in self.prediction_heads.items() - } + predictions = self(clinical_attrs, img_attrs, task="predict") + + # If the model enforces unimodal constraint on ordinal targets, output the unimodal parametrization + unimodal_params, unimodal_taus = None, None + if self.hparams.ordinal_mode: + unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param") + unimodal_taus = self(clinical_attrs, img_attrs, task="unimodal_tau") - # Remove unnecessary batch dimension from the different outputs (after all potential downstream predictions have - # been performed) + # Remove unnecessary batch dimension from the different outputs + # (only do this once all downstream inferences have been performed) out_features = out_features.squeeze(dim=0) - if predictions: + if predictions is not None: predictions = {attr: prediction.squeeze(dim=0) for attr, prediction in predictions.items()} + if self.hparams.ordinal_mode: + unimodal_params = {attr: unimodal_param.squeeze(dim=0) for attr, unimodal_param in unimodal_params.items()} + unimodal_taus = {attr: unimodal_tau.squeeze(dim=0) for attr, unimodal_tau in unimodal_taus.items()} - return out_features, predictions + return out_features, predictions, unimodal_params, unimodal_taus