Skip to content

Commit

Permalink
Add optional custom prediction head for ordinal targets to enforce un…
Browse files Browse the repository at this point in the history
…imodal logits distribution

Custom prediction heads also output parameter of the unimodal distribution and softmax temperature, on top of the logits.

Thus, adapted downstream prediction logging pipeline to also log these parameters.
  • Loading branch information
nathanpainchaud committed Nov 1, 2023
1 parent 563235b commit 3fa7854
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 48 deletions.
27 changes: 20 additions & 7 deletions didactic/data/cardinal/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,17 @@ def _write_features_plots(
predictions: Sequences of encoder output features and predicted clinical attribute for each patient. There
is one sublist for each prediction dataloader provided.
"""
prediction_example = predictions[0][0] # 1st: subset, 2nd: batch
# Pre-compute the list of attributes for which we have a unimodal parameter, since this output might be None
# and we don't want to access it in that case
ordinal_attrs = list(prediction_example[2]) if prediction_example[2] else []
features = {
(subset, patient.id, *[patient.attrs.get(attr) for attr in self._hue_attrs]): patient_prediction[0]
(
subset,
patient.id,
*[patient.attrs.get(attr) for attr in self._hue_attrs],
*[patient_prediction[output_idx][attr].item() for attr in ordinal_attrs for output_idx in (2, 3)],
): patient_prediction[0]
.flatten()
.cpu()
.numpy()
Expand All @@ -233,7 +242,12 @@ def _write_features_plots(
features.values(),
index=pd.MultiIndex.from_tuples(
features.keys(),
names=["subset", "patient", *self._hue_attrs],
names=[
"subset",
"patient",
*self._hue_attrs,
*[f"{attr}_unimodal_{pred_desc}" for attr in ordinal_attrs for pred_desc in ("param", "tau")],
],
),
)

Expand Down Expand Up @@ -283,15 +297,14 @@ def _write_prediction_scores(

# Compute the loss on the predictions for all the patients of the subset
subset_categorical_data, subset_numerical_data = [], []
for (patient_id, patient), (_, patient_predictions) in zip(subset_patients.items(), subset_predictions):
for (patient_id, patient), patient_predictions in zip(subset_patients.items(), subset_predictions):
attr_predictions = patient_predictions[1]
if target_categorical_attrs:
patient_categorical_data = {"patient": patient_id}
for attr in target_categorical_attrs:
patient_categorical_data.update(
{
f"{attr}_prediction": CLINICAL_CAT_ATTR_LABELS[attr][
patient_predictions[attr].argmax()
],
f"{attr}_prediction": CLINICAL_CAT_ATTR_LABELS[attr][attr_predictions[attr].argmax()],
f"{attr}_target": patient.attrs.get(attr, np.nan),
}
)
Expand All @@ -302,7 +315,7 @@ def _write_prediction_scores(
for attr in target_numerical_attrs:
patient_numerical_data.update(
{
f"{attr}_prediction": patient_predictions[attr].item(),
f"{attr}_prediction": attr_predictions[attr].item(),
f"{attr}_target": patient.attrs.get(attr, np.nan),
}
)
Expand Down
139 changes: 98 additions & 41 deletions didactic/tasks/cardiac_multimodal_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vital.tasks.generic import SharedStepsTask
from vital.utils.decorators import auto_move_data

from didactic.models.layers import PositionalEncoding, SequentialPooling
from didactic.models.layers import PositionalEncoding, SequentialPooling, UnimodalLogitsHead

CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute]

Expand All @@ -36,6 +36,8 @@ def __init__(
img_attrs: Sequence[ImageAttribute],
views: Sequence[ViewEnum] = tuple(ViewEnum),
predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None,
ordinal_mode: bool = True,
unimodal_head_kwargs: Dict[str, Any] | DictConfig = None,
contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None,
contrastive_loss_weight: float = 0,
clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None,
Expand All @@ -53,6 +55,10 @@ def __init__(
embed_dim: Size of the tokens/embedding for all the modalities.
predict_losses: Supervised criteria to measure the error between the predicted attributes and their real
value.
ordinal_mode: Whether to consider applicable targets as ordinal variables, which means:
- Applying a constraint to enforce an unimodal softmax output from the prediction heads;
- Predicting a new output for each ordinal target, namely the parameter of the unimodal softmax.
unimodal_head_kwargs: Keyword arguments to forward to the initialization of the unimodal prediction heads.
contrastive_loss: Self-supervised criterion to use as contrastive loss between pairs of (N, E) collections
of feature vectors, in a contrastive learning step that follows the SCARF pretraining.
(see ref: https://arxiv.org/abs/2106.15147)
Expand Down Expand Up @@ -84,6 +90,10 @@ def __init__(
if not isinstance(mtr_p, (int, float)):
mtr_p = tuple(mtr_p)

# If kwargs are null, set them to empty dict
if unimodal_head_kwargs is None:
unimodal_head_kwargs = {}

if contrastive_loss is None and predict_losses is None:
raise ValueError(
"You should provide at least one of `contrastive_loss` or `predict_losses`. Providing only "
Expand Down Expand Up @@ -299,22 +309,27 @@ def configure_model(
# https://arxiv.org/pdf/2106.11959
prediction_heads = None
if self.predict_losses:
prediction_heads = {}
prediction_heads = nn.ModuleDict()
for target_clinical_attr in self.predict_losses:
if target_clinical_attr in ClinicalAttribute.categorical_attrs():
if target_clinical_attr in ClinicalAttribute.binary_attrs():
# Binary classification target
output_size = 1
else:
# Multi-class classification target
output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
if (
target_clinical_attr in ClinicalAttribute.categorical_attrs()
and target_clinical_attr not in ClinicalAttribute.binary_attrs()
):
# Multi-class classification target
output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr])
else:
# Regression target
# Binary classification or regression target
output_size = 1
prediction_heads[target_clinical_attr] = nn.Sequential(
nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
)
prediction_heads = nn.ModuleDict(prediction_heads)

if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs():
# For ordinal targets, use a custom prediction head to constraint the distribution of logits
prediction_heads[target_clinical_attr] = UnimodalLogitsHead(
num_features, output_size, **self.hparams.unimodal_head_kwargs
)
else:
prediction_heads[target_clinical_attr] = nn.Sequential(
nn.LayerNorm(num_features), nn.ReLU(), nn.Linear(num_features, output_size)
)

return encoder, contrastive_head, prediction_heads

Expand Down Expand Up @@ -442,7 +457,7 @@ def forward(
self,
clinical_attrs: Dict[ClinicalAttribute, Tensor],
img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor],
task: Literal["encode", "predict"] = "encode",
task: Literal["encode", "predict", "unimodal_param", "unimodal_tau"] = "encode",
) -> Tensor | Dict[ClinicalAttribute, Tensor]:
"""Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head.
Expand All @@ -457,21 +472,50 @@ def forward(
Returns:
if `task` == 'encode':
(N, E) | (N, S * E), Batch of features extracted by the encoder.
if `task` == 'unimodal_param`:
? * (M), Parameter of the unimodal logits distribution for ordinal targets.
if `task` == 'unimodal_tau`:
? * (M), Temperature used to control the sharpness of the unimodal logits distribution for ordinal
targets.
if `task` == 'predict' (and the model includes prediction heads):
? * (N), Prediction for each target in `losses`.
"""
in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S)
out = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)
if task != "encode" and not self.prediction_heads:
raise ValueError(
"You requested to perform a prediction task, but the model does not include any prediction heads."
)
if task in ["unimodal_param", "unimodal_tau"] and not self.hparams.ordinal_mode:
raise ValueError(
"You requested to obtain some parameters of the unimodal softmax for ordinal attributes, but the model "
"is not configured to predict unimodal ordinal targets. Either set `ordinal_mode` to `True` or change "
"the requested inference task."
)

if task in ["predict"]:
if not self.prediction_heads:
raise ValueError(
"You requested to perform a prediction task, but the model does not include any prediction heads."
)
in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S)
out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)

out = {attr: prediction_head(out).squeeze(dim=1) for attr, prediction_head in self.prediction_heads.items()}
# Early return if requested task requires no prediction heads
if task == "encode":
return out_features

return out
# Forward pass through each target's prediction head
predictions = {attr: prediction_head(out_features) for attr, prediction_head in self.prediction_heads.items()}

# Based on the requested task, extract and format the appropriate output of the prediction heads
match task:
case "predict":
if self.hparams.ordinal_mode:
predictions = {attr: pred[0] for attr, pred in predictions.items()}
case "unimodal_param":
predictions = {attr: pred[1] for attr, pred in predictions.items()}
case "unimodal_tau":
predictions = {attr: pred[2] for attr, pred in predictions.items()}
case _:
raise ValueError(f"Unknown task '{task}'.")

# Squeeze out the singleton dimension from the predictions' features (only relevant for scalar predictions)
predictions = {attr: prediction.squeeze(dim=1) for attr, prediction in predictions.items()}
return predictions

def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]:
# Extract clinical and image attributes from the batch
Expand All @@ -498,10 +542,13 @@ def _prediction_shared_step(
self, batch: PatientData, batch_idx: int, in_tokens: Tensor, avail_mask: Tensor, out_features: Tensor
) -> Dict[str, Tensor]:
# Forward pass through each target's prediction head
predictions = {
attr: prediction_head(out_features).squeeze(dim=1)
for attr, prediction_head in self.prediction_heads.items()
}
predictions = {}
for attr, prediction_head in self.prediction_heads.items():
pred = prediction_head(out_features)
if self.hparams.ordinal_mode and attr in ClinicalAttribute.ordinal_attrs():
# For ordinal targets, extract the logits from the multiple outputs of unimodal logits head
pred = pred[0]
predictions[attr] = pred.squeeze(dim=1)

# Compute the loss/metrics for each target attribute, ignoring items for which targets are missing
losses, metrics = {}, {}
Expand Down Expand Up @@ -547,7 +594,12 @@ def _contrastive_shared_step(
@torch.inference_mode()
def predict_step( # noqa: D102
self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0
) -> Tuple[Tensor, Optional[Dict[ClinicalAttribute, Tensor]]]:
) -> Tuple[
Tensor,
Optional[Dict[ClinicalAttribute, Tensor]],
Optional[Dict[ClinicalAttribute, Tensor]],
Optional[Dict[ClinicalAttribute, Tensor]],
]:
# Extract clinical and image attributes from the patient and add batch dimension
clinical_attrs = {
attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs
Expand All @@ -559,22 +611,27 @@ def predict_step( # noqa: D102
).items()
}

# Forward pass through the encoder
in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S)
out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) | (N, S * E)
# Encoder's output
out_features = self(clinical_attrs, img_attrs)

# If the network includes prediction heads, forward pass through them
# If the model has targets to predict, output the predictions
predictions = None
if self.prediction_heads:
predictions = {
attr: prediction_head(out_features).squeeze(dim=1)
for attr, prediction_head in self.prediction_heads.items()
}
predictions = self(clinical_attrs, img_attrs, task="predict")

# If the model enforces unimodal constraint on ordinal targets, output the unimodal parametrization
unimodal_params, unimodal_taus = None, None
if self.hparams.ordinal_mode:
unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param")
unimodal_taus = self(clinical_attrs, img_attrs, task="unimodal_tau")

# Remove unnecessary batch dimension from the different outputs (after all potential downstream predictions have
# been performed)
# Remove unnecessary batch dimension from the different outputs
# (only do this once all downstream inferences have been performed)
out_features = out_features.squeeze(dim=0)
if predictions:
if predictions is not None:
predictions = {attr: prediction.squeeze(dim=0) for attr, prediction in predictions.items()}
if self.hparams.ordinal_mode:
unimodal_params = {attr: unimodal_param.squeeze(dim=0) for attr, unimodal_param in unimodal_params.items()}
unimodal_taus = {attr: unimodal_tau.squeeze(dim=0) for attr, unimodal_tau in unimodal_taus.items()}

return out_features, predictions
return out_features, predictions, unimodal_params, unimodal_taus

0 comments on commit 3fa7854

Please sign in to comment.