Skip to content

Commit

Permalink
WIP: Add ordinal_mode option to enforce unimodal distribution for p…
Browse files Browse the repository at this point in the history
…redictions on ordinal classes
  • Loading branch information
nathanpainchaud committed Oct 30, 2023
1 parent 6b57cf6 commit 1d1dc1b
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 43 deletions.
20 changes: 17 additions & 3 deletions didactic/data/cardinal/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,17 @@ def _write_features_plots(
predictions: Sequences of encoder output features and predicted clinical attribute for each patient. There
is one sublist for each prediction dataloader provided.
"""
prediction_example = predictions[0][0] # 1st: subset, 2nd: batch
# Pre-compute the list of attributes for which we have a unimodal parameter, since this output might be None
# and we don't want to access it in that case
attrs_w_unimodal_param = list(prediction_example[2]) if prediction_example[2] else []
features = {
(subset, patient.id, *[patient.attrs.get(attr) for attr in self._hue_attrs]): patient_prediction[0]
(
subset,
patient.id,
*[patient.attrs.get(attr) for attr in self._hue_attrs],
*[patient_prediction[2][attr].item() for attr in attrs_w_unimodal_param],
): patient_prediction[0]
.flatten()
.cpu()
.numpy()
Expand All @@ -233,7 +242,12 @@ def _write_features_plots(
features.values(),
index=pd.MultiIndex.from_tuples(
features.keys(),
names=["subset", "patient", *self._hue_attrs],
names=[
"subset",
"patient",
*self._hue_attrs,
*[f"{attr}_unimodal_param" for attr in attrs_w_unimodal_param],
],
),
)

Expand Down Expand Up @@ -283,7 +297,7 @@ def _write_prediction_scores(

# Compute the loss on the predictions for all the patients of the subset
subset_categorical_data, subset_numerical_data = [], []
for (patient_id, patient), (_, patient_predictions) in zip(subset_patients.items(), subset_predictions):
for (patient_id, patient), (_, patient_predictions, _) in zip(subset_patients.items(), subset_predictions):
if target_categorical_attrs:
patient_categorical_data = {"patient": patient_id}
for attr in target_categorical_attrs:
Expand Down
140 changes: 100 additions & 40 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes
from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
from vital.models.classification.mlp import MLP
from vital.models.layers import UnimodalLogits
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

Expand All @@ -37,6 +38,7 @@ def __init__(
img_attrs: Sequence[ImageAttribute],
views: Sequence[ViewEnum] = tuple(ViewEnum),
predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None,
ordinal_mode: bool = True,
contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None,
contrastive_loss_weight: float = 0,
clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None,
Expand All @@ -55,6 +57,9 @@ def __init__(
embed_dim: Size of the tokens/embedding for all the modalities.
predict_losses: Supervised criteria to measure the error between the predicted attributes and their real
value.
ordinal_mode: Whether to consider applicable targets as ordinal variables, which means:
- Applying a constraint to enforce an unimodal softmax output from the prediction heads;
- Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax.
contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections
of feature vectors, in a contrastive learning step that follows the SCARF pretraining.
(see ref: https://arxiv.org/abs/2106.15147)
Expand Down Expand Up @@ -205,7 +210,12 @@ def __init__(
)

# Initialize transformer encoder and self-supervised + prediction heads
self.encoder, self.contrastive_head, self.prediction_heads = self.configure_model()
(
self.encoder,
self.contrastive_head,
self.prediction_heads,
self.unimodal_parametrization_heads,
) = self.configure_model()

# Check compatibility between config options and the encoder's architecture
if (self.train_attrs_dropout or self.test_attrs_dropout) and not isinstance(
Expand Down Expand Up @@ -300,7 +310,7 @@ def example_input_array(

def configure_model(
self,
) -> Tuple[nn.Module, Optional[nn.Module], Optional[nn.ModuleDict]]:
) -> Tuple[nn.Module, Optional[nn.Module], Optional[nn.ModuleDict], Optional[nn.ModuleDict]]:
"""Build the model, which must return a transformer encoder, and self-supervised or prediction heads."""
# Build the transformer encoder
encoder = hydra.utils.instantiate(self.hparams.model.get("encoder"))
Expand All @@ -319,26 +329,34 @@ def configure_model(

# Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in
# https://arxiv.org/pdf/2106.11959
prediction_heads = None
prediction_heads, unimodal_parametrization_heads = None, None
if self.predict_losses:
prediction_heads = {}
prediction_heads = nn.ModuleDict()
if self.hparams.ordinal_mode:
unimodal_parametrization_heads = nn.ModuleDict()
for target_clinical_attr in self.predict_losses:
if target_clinical_attr in ClinicalAttribute.categorical_attrs():
if target_clinical_attr in ClinicalAttribute.binary_attrs():
# Binary classification target
output_size = 1
else:
# Multi-class classification target
output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
if (
target_clinical_attr in ClinicalAttribute.categorical_attrs()
and target_clinical_attr not in ClinicalAttribute.binary_attrs()
):
# Multi-class classification target
output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
else:
# Regression target
# Binary classification or regression target
output_size = 1
prediction_heads[target_clinical_attr] = nn.Sequential(
nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
)
prediction_heads = nn.ModuleDict(prediction_heads)

return encoder, contrastive_head, prediction_heads
if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs():
unimodal_parametrization_heads[target_clinical_attr] = nn.Sequential(
nn.Linear(num_features, 1), nn.Softplus()
)
prediction_heads[target_clinical_attr] = UnimodalLogits(output_size)

else:
prediction_heads[target_clinical_attr] = nn.Sequential(
nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
)

return encoder, contrastive_head, prediction_heads, unimodal_parametrization_heads

def configure_optimizers(self) -> Dict[Literal["optimizer", "lr_scheduler"], Any]:
"""Configure optimizer to ignore parameters that should remain frozen (e.g. image tokenizer)."""
Expand Down Expand Up @@ -492,7 +510,7 @@ def forward(
self,
clinical_attrs: Dict[ClinicalAttribute, Tensor],
img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor],
task: Literal["encode", "predict"] = "encode",
task: Literal["encode", "unimodal_param", "predict"] = "encode",
) -> Tensor | Dict[ClinicalAttribute, Tensor]:
"""Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head.
Expand All @@ -507,21 +525,49 @@ def forward(
Returns:
if `task` == 'encode':
(N, E) | (N, S * E), Batch of features extracted by the encoder.
if `task` == 'unimodal_param`:
? * (M), Positive scalar that parameterizes the unimodal softmax for ordinal targets.
if `task` == 'predict' (and the model includes prediction heads):
? * (N), Prediction for each target in `losses`.
"""
if task in ["encode", "unimodal_param"] and not self.prediction_heads:
raise ValueError(
"You requested to perform a prediction task, but the model does not include any prediction heads."
)
if task == "unimodal_param" and not self.unimodal_parametrization_heads:
raise ValueError(
"You requested to obtain the parametrization of the unimodal softmax for ordinal attributes, but the "
"model is not configured to predict unimodal ordinal targets. Either set `ordinal_mode` to `True` or "
"change the requested inference task."
)

in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S)
out = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)
out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)

if task in ["predict"]:
if not self.prediction_heads:
raise ValueError(
"You requested to perform a prediction task, but the model does not include any prediction heads."
)
# Early return if requested task requires no prediction heads
if task == "encode":
return out_features

out = {attr: prediction_head(out).squeeze(dim=1) for attr, prediction_head in self.prediction_heads.items()}
predictions = {}
if self.unimodal_parametrization_heads:
# For ordinal targets, the unimodal parametrization heads are used to obtain the input of the prediction
# heads, i.e. the parametrization of the unimodal softmax
predictions = {
attr: unimodal_param_head(out_features)
for attr, unimodal_param_head in self.unimodal_parametrization_heads.items()
}

return out
if task == "predict":
# If the task is to obtain the predictions, forward pass through the prediction heads
# Depending on the target, the input is either the encoder's output or the unimodal parameterization
predictions = {
attr: prediction_head(predictions.get(attr, out_features))
for attr, prediction_head in self.prediction_heads.items()
}

# Squeeze out the singleton dimension for the predictions' features (since the predictions are always scalar)
predictions = {attr: prediction.squeeze(dim=1) for attr, prediction in predictions.items()}
return predictions

def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]:
# Extract clinical and image attributes from the batch
Expand All @@ -547,9 +593,20 @@ def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]:
def _prediction_shared_step(
self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
) -> Dict[str, Tensor]:
predictions = {}

if self.unimodal_parametrization_heads:
# For ordinal targets, the unimodal parametrization heads are used to obtain the input of the prediction
# heads, i.e. the parametrization of the unimodal softmax
predictions = {
attr: unimodal_param_head(out_features)
for attr, unimodal_param_head in self.unimodal_parametrization_heads.items()
}

# Forward pass through each target's prediction head
# Depending on the target, the input is either the encoder's output or the unimodal parameterization
predictions = {
attr: prediction_head(out_features).squeeze(dim=1)
attr: prediction_head(predictions.get(attr, out_features)).squeeze(dim=1)
for attr, prediction_head in self.prediction_heads.items()
}

Expand Down Expand Up @@ -597,7 +654,7 @@ def _contrastive_shared_step(
@torch.inference_mode()
def predict_step( # noqa: D102
self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0
) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]]]:
) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]], Optional[Dict[ClinicalAttribute, Tensor]]]:
# Extract clinical and image attributes from the patient and add batch dimension
clinical_attrs = {
attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs
Expand All @@ -609,22 +666,25 @@ def predict_step( # noqa: D102
).items()
}

# Forward pass through the encoder
in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S)
out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)
# Encoder's output
out_features = self(clinical_attrs, img_attrs)

# If the network includes prediction heads, forward pass through them
# If the model has targets to predict, output the predictions
predictions = None
if self.prediction_heads:
predictions = {
attr: prediction_head(out_features).squeeze(dim=1)
for attr, prediction_head in self.prediction_heads.items()
}
predictions = self(clinical_attrs, img_attrs, task="predict")

# If the model enforces unimodal constraint on ordinal targets, output the unimodal parametrization
unimodal_params = None
if self.hparams.ordinal_mode:
unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param")

# Remove unnecessary batch dimension from the different outputs (after all potential downstream predictions have
# been performed)
# Remove unnecessary batch dimension from the different outputs
# (only do this once all downstream inferences have been performed)
out_features = out_features.squeeze(dim=0)
if predictions:
if predictions is not None:
predictions = {attr: prediction.squeeze(dim=0) for attr, prediction in predictions.items()}
if unimodal_params is not None:
unimodal_params = {attr: unimodal_param.squeeze(dim=0) for attr, unimodal_param in unimodal_params.items()}

return out_features, predictions
return out_features, predictions, unimodal_params

0 comments on commit 1d1dc1b

Please sign in to comment.