From 1d1dc1b79b067f004e8f4ef32f8ce0c4cb1f296a Mon Sep 17 00:00:00 2001
From: Nathan Painchaud
Date: Fri, 27 Oct 2023 23:59:59 +0200
Subject: [PATCH] WIP: Add `ordinal_mode` option to enforce unimodal distribution for predictions on ordinal classes

---
 didactic/data/cardinal/predict.py             |  20 ++-
 .../cardiac_multimodal_representation.py      | 140 +++++++++++++-----
 2 files changed, 117 insertions(+), 43 deletions(-)

diff --git a/didactic/data/cardinal/predict.py b/didactic/data/cardinal/predict.py
index bd9aded5..585ff003 100644
--- a/didactic/data/cardinal/predict.py
+++ b/didactic/data/cardinal/predict.py
@@ -217,8 +217,17 @@ def _write_features_plots(
             predictions: Sequences of encoder output features and predicted clinical attribute for each patient.
                 There is one sublist for each prediction dataloader provided.
         """
+        prediction_example = predictions[0][0]  # 1st: subset, 2nd: batch
+        # Pre-compute the list of attributes for which we have a unimodal parameter, since this output might be None
+        # and we don't want to access it in that case
+        attrs_w_unimodal_param = list(prediction_example[2]) if prediction_example[2] else []
         features = {
-            (subset, patient.id, *[patient.attrs.get(attr) for attr in self._hue_attrs]): patient_prediction[0]
+            (
+                subset,
+                patient.id,
+                *[patient.attrs.get(attr) for attr in self._hue_attrs],
+                *[patient_prediction[2][attr].item() for attr in attrs_w_unimodal_param],
+            ): patient_prediction[0]
             .flatten()
             .cpu()
             .numpy()
@@ -233,7 +242,12 @@
             features.values(),
             index=pd.MultiIndex.from_tuples(
                 features.keys(),
-                names=["subset", "patient", *self._hue_attrs],
+                names=[
+                    "subset",
+                    "patient",
+                    *self._hue_attrs,
+                    *[f"{attr}_unimodal_param" for attr in attrs_w_unimodal_param],
+                ],
             ),
         )
@@ -283,7 +297,7 @@ def _write_prediction_scores(
         # Compute the loss on the predictions for all the patients of the subset
         subset_categorical_data, subset_numerical_data = [], []
-        for (patient_id, patient), (_, patient_predictions) in zip(subset_patients.items(), subset_predictions):
+        for (patient_id, patient), (_, patient_predictions, _) in zip(subset_patients.items(), subset_predictions):
             if target_categorical_attrs:
                 patient_categorical_data = {"patient": patient_id}
                 for attr in target_categorical_attrs:
diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py
index 788c7c12..2f9c71f3 100644
--- a/didactic/tasks/cardiac_multimodal_representation.py
+++ b/didactic/tasks/cardiac_multimodal_representation.py
@@ -19,6 +19,7 @@
 from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes
 from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
 from vital.models.classification.mlp import MLP
+from vital.models.layers import UnimodalLogits
 from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
@@ -37,6 +38,7 @@ def __init__(
         img_attrs: Sequence[ImageAttribute],
         views: Sequence[ViewEnum] = tuple(ViewEnum),
         predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None,
+        ordinal_mode: bool = True,
         contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None,
         contrastive_loss_weight: float = 0,
         clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None,
@@ -55,6 +57,9 @@ def __init__(
             embed_dim: Size of the tokens/embedding for all the modalities.
             predict_losses: Supervised criteria to measure the error between the predicted attributes and their real
                 value.
+            ordinal_mode: Whether to consider applicable targets as ordinal variables, which means:
+                - Applying a constraint to enforce a unimodal softmax output from the prediction heads;
+                - Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax.
             contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections
                 of feature vectors, in a contrastive learning step that follows the SCARF pretraining.
                 (see ref: https://arxiv.org/abs/2106.15147)
@@ -205,7 +210,12 @@ def __init__(
         )

         # Initialize transformer encoder and self-supervised + prediction heads
-        self.encoder, self.contrastive_head, self.prediction_heads = self.configure_model()
+        (
+            self.encoder,
+            self.contrastive_head,
+            self.prediction_heads,
+            self.unimodal_parametrization_heads,
+        ) = self.configure_model()

         # Check compatibility between config options and the encoder's architecture
         if (self.train_attrs_dropout or self.test_attrs_dropout) and not isinstance(
@@ -300,7 +310,7 @@ def example_input_array(

     def configure_model(
         self,
-    ) -> Tuple[nn.Module, Optional[nn.Module], Optional[nn.ModuleDict]]:
+    ) -> Tuple[nn.Module, Optional[nn.Module], Optional[nn.ModuleDict], Optional[nn.ModuleDict]]:
         """Build the model, which must return a transformer encoder, and self-supervised or prediction heads."""
         # Build the transformer encoder
         encoder = hydra.utils.instantiate(self.hparams.model.get("encoder"))
@@ -319,26 +329,34 @@ def configure_model(

         # Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in
         # https://arxiv.org/pdf/2106.11959
-        prediction_heads = None
+        prediction_heads, unimodal_parametrization_heads = None, None
         if self.predict_losses:
-            prediction_heads = {}
+            prediction_heads = nn.ModuleDict()
+            if self.hparams.ordinal_mode:
+                unimodal_parametrization_heads = nn.ModuleDict()
             for target_clinical_attr in self.predict_losses:
-                if target_clinical_attr in ClinicalAttribute.categorical_attrs():
-                    if target_clinical_attr in ClinicalAttribute.binary_attrs():
-                        # Binary classification target
-                        output_size = 1
-                    else:
-                        # Multi-class classification target
-                        output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
+                if (
+                    target_clinical_attr in ClinicalAttribute.categorical_attrs()
+                    and target_clinical_attr not in ClinicalAttribute.binary_attrs()
+                ):
+                    # Multi-class classification target
+                    output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
                 else:
-                    # Regression target
+                    # Binary classification or regression target
                     output_size = 1
-                prediction_heads[target_clinical_attr] = nn.Sequential(
-                    nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
-                )
-            prediction_heads = nn.ModuleDict(prediction_heads)
-        return encoder, contrastive_head, prediction_heads
+
+                if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs():
+                    unimodal_parametrization_heads[target_clinical_attr] = nn.Sequential(
+                        nn.Linear(num_features, 1), nn.Softplus()
+                    )
+                    prediction_heads[target_clinical_attr] = UnimodalLogits(output_size)
+
+                else:
+                    prediction_heads[target_clinical_attr] = nn.Sequential(
+                        nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
+                    )
+
+        return encoder, contrastive_head, prediction_heads, unimodal_parametrization_heads

     def configure_optimizers(self) -> Dict[Literal["optimizer", "lr_scheduler"], Any]:
         """Configure optimizer to ignore parameters that should remain frozen (e.g. image tokenizer)."""
@@ -492,7 +510,7 @@ def forward(
         self,
         clinical_attrs: Dict[ClinicalAttribute, Tensor],
         img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor],
-        task: Literal["encode", "predict"] = "encode",
+        task: Literal["encode", "unimodal_param", "predict"] = "encode",
     ) -> Tensor | Dict[ClinicalAttribute, Tensor]:
         """Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head.

@@ -507,21 +525,49 @@ def forward(
         Returns:
             if `task` == 'encode':
                 (N, E) | (N, S * E), Batch of features extracted by the encoder.
+            if `task` == 'unimodal_param':
+                ? * (N), Positive scalar that parametrizes the unimodal softmax for ordinal targets.
             if `task` == 'predict' (and the model includes prediction heads):
                 ? * (N), Prediction for each target in `losses`.
         """
+        if task in ["predict", "unimodal_param"] and not self.prediction_heads:
+            raise ValueError(
+                "You requested to perform a prediction task, but the model does not include any prediction heads."
+            )
+        if task == "unimodal_param" and not self.unimodal_parametrization_heads:
+            raise ValueError(
+                "You requested to obtain the parametrization of the unimodal softmax for ordinal attributes, but the "
+                "model is not configured to predict unimodal ordinal targets. Either set `ordinal_mode` to `True` or "
+                "change the requested inference task."
+            )
+
         in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs)  # (N, S, E), (N, S)
-        out = self.encode(in_tokens, avail_mask)  # (N, S, E) -> (N, E) | (N, S * E)
+        out_features = self.encode(in_tokens, avail_mask)  # (N, S, E) -> (N, E) | (N, S * E)

-        if task in ["predict"]:
-            if not self.prediction_heads:
-                raise ValueError(
-                    "You requested to perform a prediction task, but the model does not include any prediction heads."
-                )
+        # Early return if the requested task requires no prediction heads
+        if task == "encode":
+            return out_features

-            out = {attr: prediction_head(out).squeeze(dim=1) for attr, prediction_head in self.prediction_heads.items()}
+        predictions = {}
+        if self.unimodal_parametrization_heads:
+            # For ordinal targets, the unimodal parametrization heads are used to obtain the input of the prediction
+            # heads, i.e. the parametrization of the unimodal softmax
+            predictions = {
+                attr: unimodal_param_head(out_features)
+                for attr, unimodal_param_head in self.unimodal_parametrization_heads.items()
+            }

-        return out
+        if task == "predict":
+            # If the task is to obtain the predictions, forward pass through the prediction heads
+            # Depending on the target, the input is either the encoder's output or the unimodal parametrization
+            predictions = {
+                attr: prediction_head(predictions.get(attr, out_features))
+                for attr, prediction_head in self.prediction_heads.items()
+            }
+
+        # Squeeze the singleton feature dimension out of scalar predictions (no-op for multi-class logits)
+        predictions = {attr: prediction.squeeze(dim=1) for attr, prediction in predictions.items()}
+        return predictions

     def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]:
         # Extract clinical and image attributes from the batch
@@ -547,9 +593,20 @@ def _prediction_shared_step(
         self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
     ) -> Dict[str, Tensor]:
+        predictions = {}
+
+        if self.unimodal_parametrization_heads:
+            # For ordinal targets, the unimodal parametrization heads are used to obtain the input of the prediction
+            # heads, i.e. the parametrization of the unimodal softmax
+            predictions = {
+                attr: unimodal_param_head(out_features)
+                for attr, unimodal_param_head in self.unimodal_parametrization_heads.items()
+            }
+
         # Forward pass through each target's prediction head
+        # Depending on the target, the input is either the encoder's output or the unimodal parametrization
         predictions = {
-            attr: prediction_head(out_features).squeeze(dim=1)
+            attr: prediction_head(predictions.get(attr, out_features)).squeeze(dim=1)
             for attr, prediction_head in self.prediction_heads.items()
         }
@@ -597,7 +654,7 @@ def _contrastive_shared_step(
     @torch.inference_mode()
     def predict_step(  # noqa: D102
         self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0
-    ) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]]]:
+    ) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]], Optional[Dict[ClinicalAttribute, Tensor]]]:
         # Extract clinical and image attributes from the patient and add batch dimension
         clinical_attrs = {
             attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs
         }
         img_attrs = {
             attr: attr_data[None, ...]
             for attr, attr_data in filter_image_attributes(
@@ -609,22 +666,25 @@ def predict_step(  # noqa: D102
             ).items()
         }

-        # Forward pass through the encoder
-        in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs)  # (N, S, E), (N, S)
-        out_features = self.encode(in_tokens, avail_mask)  # (N, S, E) -> (N, E) | (N, S * E)
+        # Encoder's output
+        out_features = self(clinical_attrs, img_attrs)

-        # If the network includes prediction heads, forward pass through them
+        # If the model has targets to predict, output the predictions
         predictions = None
         if self.prediction_heads:
-            predictions = {
-                attr: prediction_head(out_features).squeeze(dim=1)
-                for attr, prediction_head in self.prediction_heads.items()
-            }
+            predictions = self(clinical_attrs, img_attrs, task="predict")
+
+        # If the model enforces a unimodal constraint on ordinal targets, output the unimodal parametrization
+        unimodal_params = None
+        if self.unimodal_parametrization_heads:
+            unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param")

-        # Remove unnecessary batch dimension from the different outputs (after all potential downstream predictions have
-        # been performed)
+        # Remove unnecessary batch dimension from the different outputs
+        # (only do this once all downstream inferences have been performed)
         out_features = out_features.squeeze(dim=0)
-        if predictions:
+        if predictions is not None:
             predictions = {attr: prediction.squeeze(dim=0) for attr, prediction in predictions.items()}
+        if unimodal_params is not None:
+            unimodal_params = {attr: unimodal_param.squeeze(dim=0) for attr, unimodal_param in unimodal_params.items()}

-        return out_features, predictions
+        return out_features, predictions, unimodal_params
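
Note on the unimodal constraint: `UnimodalLogits` is imported from `vital.models.layers`, so its implementation is not part of this patch. As a rough sketch of the kind of layer the `ordinal_mode` path assumes, the hypothetical module below maps the positive scalar produced by the `nn.Linear(num_features, 1)` + `nn.Softplus()` parametrization head to logits whose softmax is a unimodal distribution over the ordinal classes (a truncated Poisson parametrization in the spirit of Beckham & Pal, ICML 2017). The real `UnimodalLogits` layer may differ in both signature and distribution.

import math

import torch
from torch import Tensor, nn


class PoissonUnimodalLogits(nn.Module):
    """Hypothetical stand-in for `vital.models.layers.UnimodalLogits`, not the actual implementation.

    Maps a positive scalar to `num_logits` logits whose softmax is a unimodal (truncated Poisson) distribution.
    """

    def __init__(self, num_logits: int, tau: float = 1.0):
        super().__init__()
        self.num_logits, self.tau = num_logits, tau
        # log(k!) for k = 0..K-1, pre-computed once and registered so it follows the module across devices
        self.register_buffer("log_factorial", torch.tensor([math.lgamma(k + 1) for k in range(num_logits)]))

    def forward(self, unimodal_param: Tensor) -> Tensor:
        """Computes (N, K) logits from a (N, 1) tensor of positive scalars."""
        k = torch.arange(self.num_logits, device=unimodal_param.device)  # (K)
        # Log-pmf of a Poisson(unimodal_param) at each class index k; the softmax over the K classes
        # renormalizes the truncated distribution, which remains unimodal
        logits = k * torch.log(unimodal_param) - unimodal_param - self.log_factorial  # (N, K)
        return logits / self.tau


if __name__ == "__main__":
    head = PoissonUnimodalLogits(num_logits=4)
    unimodal_param = nn.functional.softplus(torch.randn(8, 1))  # positive scalar per sample, as in the patch
    probs = head(unimodal_param).softmax(dim=-1)  # (8, 4): each row is a unimodal distribution over 4 classes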
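
For reviewers tracing the two-head wiring in `configure_model`: under `ordinal_mode`, the prediction head for an ordinal target no longer consumes the encoder's features directly; it consumes the positive scalar emitted by the corresponding unimodal parametrization head (see `predictions.get(attr, out_features)` in `forward`). A minimal illustration of that data flow, reusing the hypothetical `PoissonUnimodalLogits` stand-in sketched alongside this patch and made-up sizes:

import torch
from torch import nn

# Assumes the PoissonUnimodalLogits sketch above is in scope; sizes are illustrative, not from the project's configs
num_features, num_classes = 192, 4
unimodal_param_head = nn.Sequential(nn.Linear(num_features, 1), nn.Softplus())  # as built in `configure_model`
prediction_head = PoissonUnimodalLogits(num_classes)  # stand-in for `UnimodalLogits(output_size)`

out_features = torch.randn(8, num_features)         # encoder output for a batch of 8 patients
unimodal_param = unimodal_param_head(out_features)   # (8, 1) positive scalar per patient
logits = prediction_head(unimodal_param)             # (8, 4) logits whose softmax is unimodal
predicted_class = logits.argmax(dim=1)               # ordinary argmax decision over the ordinal classes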