Skip to content

Commit

Permalink
Fix conversion of logits to softmax for the computation of AUROC in `PredictionWriter`
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanpainchaud committed Jul 2, 2024
1 parent a918f45 commit 1360ef0
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions didactic/data/cardinal/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,14 @@ def _write_prediction_scores(
# Compute metrics on the predictions for all the patients of the subset +
# collect and structure necessary predictions to compute these metrics
subset_categorical_data, subset_numerical_data = [], []
classification_logits = {attr: [] for attr in target_categorical_attrs}
classification_out = {attr: [] for attr in target_categorical_attrs}
for (patient_id, patient), patient_predictions in zip(subset_patients.items(), subset_predictions):
attr_predictions = patient_predictions[1]
if target_categorical_attrs:
patient_categorical_data = {"patient": patient_id}
for attr in target_categorical_attrs:
# Collect the classification logits
classification_logits.setdefault(attr, []).append(attr_predictions[attr].detach().cpu().numpy())
# Collect the classification logits/probabilities
classification_out.setdefault(attr, []).append(attr_predictions[attr].detach().cpu().numpy())
# Add the hard prediction and target labels
patient_categorical_data.update(
{
Expand All @@ -327,8 +327,20 @@ def _write_prediction_scores(
)
subset_numerical_data.append(patient_numerical_data)

# Convert the classification logits to numpy arrays
classification_logits = {attr: np.array(probs) for attr, probs in classification_logits.items()}
# Convert the classification logits/probabilities to numpy arrays
for attr, attr_pred in classification_out.items():
attr_pred = np.array(attr_pred)
if (attr_pred < 0).any() or (attr_pred > 1).any():
# Values outside [0, 1] cannot be probabilities, so treat the outputs as logits and convert them to probabilities
attr_pred = softmax(attr_pred, axis=1)
# Rescale the output of the softmax to make sure that each row sums to 1.
# In theory, this should be handled by the softmax function itself, but scipy's implementation
# returned values summing to 0.9995, which was not close enough to 1 for some downstream cases.
# The rescaling is wrapped in a while loop to handle numerical instability, since a one-time
# rescaling was not always enough to bring the sums close enough to 1.
while not np.allclose(1, attr_pred.sum(axis=1)):
attr_pred /= attr_pred.sum(axis=1, keepdims=True)
classification_out[attr] = attr_pred

if subset_categorical_data:
subset_categorical_df = pd.DataFrame.from_records(subset_categorical_data, index="patient")
Expand All @@ -345,7 +357,7 @@ def _write_prediction_scores(
subset_categorical_stats.loc["auroc"] = {
f"{attr}_prediction": roc_auc_score(
subset_categorical_df[f"{attr}_target"][notna_mask[f"{attr}_target"]],
softmax(classification_logits[attr][notna_mask[f"{attr}_target"]], axis=1),
classification_out[attr][notna_mask[f"{attr}_target"]],
multi_class="ovr",
)
for attr in target_categorical_attrs
Expand Down

0 comments on commit 1360ef0

Please sign in to comment.