Skip to content

Commit

Permalink
Rename categories of attributes to be more accurate and generic
Browse files Browse the repository at this point in the history
Name changes:
- clinical -> tabular
- image -> time-series
  • Loading branch information
nathanpainchaud committed Nov 10, 2023
1 parent f0a4853 commit ff69c4e
Show file tree
Hide file tree
Showing 12 changed files with 381 additions and 363 deletions.
18 changes: 9 additions & 9 deletions didactic/data/cardinal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pandas as pd
from vital.data.cardinal.config import CardinalTag, ImageAttribute
from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute
from vital.data.cardinal.config import View as ViewEnum
from vital.data.cardinal.utils.attributes import build_attributes_dataframe
from vital.data.cardinal.utils.data_struct import Patient
Expand Down Expand Up @@ -107,25 +107,25 @@ def _to_num(attr_data: pd.Series) -> pd.Series:
return neigh_data


def build_img_attr_by_patient_group_dataframe(
def build_time_series_attr_by_patient_group_dataframe(
patients_groups: Mapping[Hashable, Iterable[Patient]],
attr: Tuple[ViewEnum, ImageAttribute],
attr: Tuple[ViewEnum, TimeSeriesAttribute],
group_desc: str = "group",
mask_tag: str = CardinalTag.mask,
resampling_rate: int = 128,
) -> pd.DataFrame:
"""Builds a dataframe with the average curve of an image attribute by patient group.
"""Builds a dataframe with the average curve of a time-series attribute by patient group.
Args:
patients_groups: Mapping between group ID and the patients in that group.
attr: A pair of view and image attribute to compute the average curve of.
attr: A pair of view and time-series attribute to compute the average curve of.
group_desc: Description of the semantic meaning of the groups.
mask_tag: Tag of the segmentation mask for which to extract the image attribute data.
resampling_rate: Number of points at which to resample the image attribute curves from each patient, so that
they can be easily compared and aggregated together.
mask_tag: Tag of the segmentation mask for which to extract the time-series attribute data.
resampling_rate: Number of points at which to resample the time-series attribute curves from each patient, so
that they can be easily compared and aggregated together.
Returns:
Dataframe with the average curve of an image attribute by patient group, in long format.
Dataframe with the average curve of a time-series attribute by patient group, in long format.
"""
resampling_fn = Interp1d(resampling_rate)

Expand Down
2 changes: 1 addition & 1 deletion didactic/scripts/analyse_clustering_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def main():
"hparams_files",
nargs="+",
type=Path,
help="Paths to YAML config file of the best clustering hyperparameters found for an encoder amodel",
help="Paths to YAML config file of the best clustering hyperparameters found for an encoder model",
)
parser.add_argument(
"--hparams",
Expand Down
2 changes: 1 addition & 1 deletion didactic/scripts/cardiac_multimodal_encoder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main():
"--mask_tag",
type=str,
default=CardinalTag.mask,
help="Tag of the segmentation mask for which to extract the image attributes",
help="Tag of the segmentation mask for which to extract the time-series attributes",
)
parser.add_argument(
"--use_attention_rollout",
Expand Down
42 changes: 21 additions & 21 deletions didactic/scripts/cardiac_multimodal_representation_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from vital.data.cardinal.config import CardinalTag, ClinicalAttribute
from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
from vital.data.cardinal.config import CardinalTag, TabularAttribute
from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS
from vital.data.cardinal.utils.data_struct import Patient
from vital.data.cardinal.utils.itertools import Patients
from vital.utils.plot import embedding_scatterplot
Expand All @@ -16,19 +16,19 @@
def plot_patients_embeddings(
model: CardiacMultimodalRepresentationTask,
patients: Patients,
plot_clinical_attrs: Sequence[ClinicalAttribute] = None,
plot_tabular_attrs: Sequence[TabularAttribute] = None,
categorical_attrs_lists: Dict[str, Dict[str, Sequence[Patient.Id]]] = None,
mask_tag: str = CardinalTag.mask,
progress_bar: bool = False,
**embedding_kwargs,
) -> Iterator[Tuple[ClinicalAttribute | str, Axes]]:
) -> Iterator[Tuple[TabularAttribute | str, Axes]]:
"""Generates 2D scatter plots of patients' encodings, labeled w.r.t. specific attributes.
Args:
model: Transformer encoder model to use for inference.
patients: (N) Patients to embed.
mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
plot_clinical_attrs: Patients' clinical attributes w.r.t. which to plot the embedding.
plot_tabular_attrs: Patients' tabular attributes w.r.t. which to plot the embedding.
categorical_attrs_lists: Nested mapping listing, for each additional categorical attribute, the patients
belonging to each of the attribute's labels.
progress_bar: If ``True``, enables progress bars detailing the progress of encoding patients.
Expand All @@ -38,18 +38,18 @@ def plot_patients_embeddings(
Returns:
An iterator over the attributes and associated scatter plots.
"""
if plot_clinical_attrs is None and categorical_attrs_lists is None:
if plot_tabular_attrs is None and categorical_attrs_lists is None:
raise ValueError(
"You have specified neither built-in attributes (`plot_attrs` is None) nor custom attributes "
"You have specified neither built-in attributes (`plot_tabular_attrs` is None) nor custom attributes "
"(`categorical_attrs_lists`) w.r.t. which to plot the embeddings. Specify at least one attribute of either "
"type to plot embeddings of the patients."
)

if plot_clinical_attrs is None:
plot_clinical_attrs = []
if plot_tabular_attrs is None:
plot_tabular_attrs = []
if categorical_attrs_lists is None:
categorical_attrs_lists = {}
plot_attrs = plot_clinical_attrs + list(categorical_attrs_lists)
plot_attrs = plot_tabular_attrs + list(categorical_attrs_lists)

# Encode the data using the model
patient_encodings = pd.DataFrame(
Expand All @@ -60,7 +60,7 @@ def plot_patients_embeddings(
patient_encodings = patient_encodings.join(
pd.DataFrame.from_dict(
{
patient.id: {attr: patient.attrs.get(attr) for attr in plot_clinical_attrs}
patient.id: {attr: patient.attrs.get(attr) for attr in plot_tabular_attrs}
for patient in patients.values()
},
orient="index",
Expand All @@ -84,9 +84,9 @@ def plot_patients_embeddings(
# Transfer the attributes data from the columns to the index, as required by the generic embedding function later
patient_encodings = patient_encodings.set_index(plot_attrs, append=True)

# Determine from the clinical attributes' predefined order or the natural ordering in the custom attributes the
# Determine from the tabular attributes' predefined order or the natural ordering in the custom attributes the
# hue order for the plots
plot_attrs_order = {attr: CLINICAL_CAT_ATTR_LABELS.get(attr) for attr in plot_clinical_attrs}
plot_attrs_order = {attr: TABULAR_CAT_ATTR_LABELS.get(attr) for attr in plot_tabular_attrs}
plot_attrs_order.update({attr: list(attr_lists) for attr, attr_lists in categorical_attrs_lists.items()})

# Plot data w.r.t. attributes
Expand Down Expand Up @@ -128,12 +128,12 @@ def main():
help="Tag of the segmentation mask for which to extract the image attributes",
)
parser.add_argument(
"--plot_clinical_attrs",
type=ClinicalAttribute,
"--plot_tabular_attrs",
type=TabularAttribute,
nargs="+",
choices=list(ClinicalAttribute),
default=list(ClinicalAttribute),
help="Patients' clinical attributes w.r.t. which to plot the embedding",
choices=list(TabularAttribute),
default=list(TabularAttribute),
help="Patients' tabular attributes w.r.t. which to plot the embedding",
)
parser.add_argument(
"--plot_categorical_attrs_dirs",
Expand All @@ -158,10 +158,10 @@ def main():
args = parser.parse_args()
kwargs = vars(args)

encoder_ckpt, mask_tag, plot_clinical_attrs, plot_categorical_attrs_dirs, embedding_kwargs, output_dir = (
encoder_ckpt, mask_tag, plot_tabular_attrs, plot_categorical_attrs_dirs, embedding_kwargs, output_dir = (
kwargs.pop("pretrained_encoder"),
kwargs.pop("mask_tag"),
kwargs.pop("plot_clinical_attrs"),
kwargs.pop("plot_tabular_attrs"),
kwargs.pop("plot_categorical_attrs_dirs"),
kwargs.pop("embedding_kwargs"),
kwargs.pop("output_dir"),
Expand Down Expand Up @@ -191,7 +191,7 @@ def main():
for attr, _ in plot_patients_embeddings(
encoder,
patients,
plot_clinical_attrs=plot_clinical_attrs,
plot_tabular_attrs=plot_tabular_attrs,
categorical_attrs_lists=categorical_attrs_lists,
mask_tag=mask_tag,
progress_bar=True,
Expand Down
86 changes: 43 additions & 43 deletions didactic/scripts/describe_patients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,75 +3,75 @@

import numpy as np
import pandas as pd
from vital.data.cardinal.config import ClinicalAttribute
from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS
from vital.data.cardinal.config import TabularAttribute
from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS
from vital.data.cardinal.utils.itertools import Patients

NUM_CLINICAL_ATTR_STATS = defaultdict(lambda: "mean")
NUM_CLINICAL_ATTR_STATS.update(
_NUM_ATTR_STATS = defaultdict(lambda: "mean")
_NUM_ATTR_STATS.update(
{
ClinicalAttribute.ddd: "quartile",
ClinicalAttribute.creat: "quartile",
ClinicalAttribute.gfr: "quartile",
ClinicalAttribute.nt_probnp: "quartile",
ClinicalAttribute.mv_dt: "quartile",
TabularAttribute.ddd: "quartile",
TabularAttribute.creat: "quartile",
TabularAttribute.gfr: "quartile",
TabularAttribute.nt_probnp: "quartile",
TabularAttribute.mv_dt: "quartile",
}
)
"Numerical attributes' statistics are summarized using `mean ± std`, unless otherwise specified here."

NUM_CLINICAL_ATTR_DECIMALS = defaultdict(int)
NUM_CLINICAL_ATTR_DECIMALS.update(
_NUM_ATTR_DECIMALS = defaultdict(int)
_NUM_ATTR_DECIMALS.update(
{
ClinicalAttribute.bmi: 1,
ClinicalAttribute.ddd: 1,
ClinicalAttribute.gfr: 1,
ClinicalAttribute.e_velocity: 1,
ClinicalAttribute.a_velocity: 1,
ClinicalAttribute.e_e_prime_ratio: 1,
ClinicalAttribute.la_volume: 1,
ClinicalAttribute.la_area: 1,
ClinicalAttribute.vmax_tr: 1,
ClinicalAttribute.ivs_d: 1,
ClinicalAttribute.lvid_d: 1,
ClinicalAttribute.pw_d: 1,
ClinicalAttribute.tapse: 1,
ClinicalAttribute.s_prime: 1,
TabularAttribute.bmi: 1,
TabularAttribute.ddd: 1,
TabularAttribute.gfr: 1,
TabularAttribute.e_velocity: 1,
TabularAttribute.a_velocity: 1,
TabularAttribute.e_e_prime_ratio: 1,
TabularAttribute.la_volume: 1,
TabularAttribute.la_area: 1,
TabularAttribute.vmax_tr: 1,
TabularAttribute.ivs_d: 1,
TabularAttribute.lvid_d: 1,
TabularAttribute.pw_d: 1,
TabularAttribute.tapse: 1,
TabularAttribute.s_prime: 1,
}
)
"""Attributes' statistics are rounded to the nearest integer, unless specified otherwise here."""


def describe_patients(
patients: Patients, clinical_attrs: Sequence[ClinicalAttribute] = None, format_summary: bool = False
patients: Patients, tabular_attrs: Sequence[TabularAttribute] = None, format_summary: bool = False
) -> pd.DataFrame:
"""Computes statistics over patients' clinical attributes, adapting statistics to numerical/categorical attributes.
"""Computes statistics over patients' tabular attributes, adapting statistics to numerical/categorical attributes.
Args:
patients: Patients over which to compute the statistics.
clinical_attrs: Subset of clinical attributes over which to compute the statistics.
tabular_attrs: Subset of tabular attributes over which to compute the statistics.
format_summary: Whether to add a column where a subset of the stats (depending on the attribute) are selected
and formatted, as a summary of that attribute's statistics.
Returns:
Statistics describing patients' clinical attributes.
Statistics describing patients' tabular attributes.
"""
patients_attrs = patients.to_dataframe(clinical_attrs=clinical_attrs)
if clinical_attrs is None:
clinical_attrs = patients_attrs.columns.tolist()
patients_attrs = patients.to_dataframe(tabular_attrs=tabular_attrs)
if tabular_attrs is None:
tabular_attrs = patients_attrs.columns.tolist()

# Get the descriptions for the numerical attributes, with attributes as rows and descriptions as columns
num_stats = ["mean", "std", "50%", "25%", "75%"]
num_attrs_desc = patients_attrs.describe(include=np.number).loc[num_stats]

# Manually compute the occurrences of label for boolean/categorical attributes
cat_stats = ["count", "%"]
cat_attrs = [attr for attr in ClinicalAttribute.categorical_attrs() if attr in patients_attrs.columns]
cat_attrs = [attr for attr in TabularAttribute.categorical_attrs() if attr in patients_attrs.columns]
cat_attrs_desc = {}
for attr in cat_attrs:
attr_data = patients_attrs[attr]
attr_data = attr_data[attr_data.notna()] # Discard missing data

label_counts = {label: (attr_data == label).sum() for label in CLINICAL_CAT_ATTR_LABELS[attr]}
label_counts = {label: (attr_data == label).sum() for label in TABULAR_CAT_ATTR_LABELS[attr]}
label_percentages = {label: round(count * 100 / len(attr_data), 1) for label, count in label_counts.items()}
cat_attrs_desc[attr] = {"count": label_counts, "%": label_percentages}
# Structure the boolean/categorical description as dataframe, with attributes as rows and descriptions as columns
Expand All @@ -80,7 +80,7 @@ def describe_patients(
# Join descriptions of numerical and categorical attributes
patients_attrs_desc = num_attrs_desc.join(cat_attrs_desc, how="outer")
# Index w.r.t. attributes and sort the attributes and statistics
patients_attrs_desc = patients_attrs_desc.T.reindex(clinical_attrs)[num_stats + cat_stats]
patients_attrs_desc = patients_attrs_desc.T.reindex(tabular_attrs)[num_stats + cat_stats]

# Cast numerical stats to float, since the transpose leads all columns to be of generic 'object' type
patients_attrs_desc[num_stats] = patients_attrs_desc[num_stats].astype(float)
Expand All @@ -89,19 +89,19 @@ def describe_patients(
summaries = {}

for attr in patients_attrs_desc.index:
if attr in ClinicalAttribute.boolean_attrs():
if attr in TabularAttribute.boolean_attrs():
attr_summary = (
f"{patients_attrs_desc.loc[attr, 'count'][True]} "
f"({patients_attrs_desc.loc[attr, '%'][True]:.0f})"
)
elif attr in ClinicalAttribute.categorical_attrs():
elif attr in TabularAttribute.categorical_attrs():
attr_summary = "\n".join(
f"{cat_count} ({patients_attrs_desc.loc[attr, '%'][cat]:.0f})"
for cat, cat_count in patients_attrs_desc.loc[attr, "count"].items()
)
else: # attr in ClinicalAttribute.numerical_attrs():
dec = NUM_CLINICAL_ATTR_DECIMALS[attr]
match summary_stat := NUM_CLINICAL_ATTR_STATS[attr]:
else: # attr in TabularAttribute.numerical_attrs():
dec = _NUM_ATTR_DECIMALS[attr]
match summary_stat := _NUM_ATTR_STATS[attr]:
case "mean":
mean, std = patients_attrs_desc[["mean", "std"]].loc[attr]
attr_summary = f"{mean:.{dec}f} ± {std:.{dec}f}"
Expand Down Expand Up @@ -132,9 +132,9 @@ def main():
parser = Patients.add_args(parser)
parser.add_argument(
"--attributes",
type=ClinicalAttribute,
type=TabularAttribute,
nargs="+",
choices=list(ClinicalAttribute),
choices=list(TabularAttribute),
help="Attributes to describe",
)
parser.add_argument(
Expand Down Expand Up @@ -176,7 +176,7 @@ def main():
]

patients_attrs_desc_by_subset[subset] = describe_patients(
Patients(**kwargs, include_patients=subset_patients), clinical_attrs=clinical_attrs, format_summary=True
Patients(**kwargs, include_patients=subset_patients), tabular_attrs=clinical_attrs, format_summary=True
)

# Save the description for each subset
Expand Down
Loading

0 comments on commit ff69c4e

Please sign in to comment.