From ff69c4e33cb4c3872d61e9e4c3b262a922a770c1 Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Fri, 10 Nov 2023 16:21:45 +0100 Subject: [PATCH] Rename categories of attributes to be more accurate and generic Name changes: - clinical -> tabular - image - > time-series --- didactic/data/cardinal/utils.py | 18 +- .../scripts/analyse_clustering_hparams.py | 2 +- .../cardiac_multimodal_encoder_attn.py | 2 +- .../cardiac_multimodal_representation_plot.py | 42 +-- didactic/scripts/describe_patients.py | 86 ++--- .../describe_representation_clustering.py | 120 +++---- .../scripts/list_patients_wrt_attributes.py | 8 +- .../cardiac_multimodal_representation.py | 313 +++++++++--------- .../cardiac_representation_clustering.py | 10 +- didactic/tasks/cardiac_sequence_attrs_ae.py | 88 ++--- didactic/tasks/cardiac_sequence_attrs_pca.py | 21 +- didactic/tasks/utils.py | 34 +- 12 files changed, 381 insertions(+), 363 deletions(-) diff --git a/didactic/data/cardinal/utils.py b/didactic/data/cardinal/utils.py index 67c95b01..1db71b70 100644 --- a/didactic/data/cardinal/utils.py +++ b/didactic/data/cardinal/utils.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd -from vital.data.cardinal.config import CardinalTag, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.cardinal.utils.attributes import build_attributes_dataframe from vital.data.cardinal.utils.data_struct import Patient @@ -107,25 +107,25 @@ def _to_num(attr_data: pd.Series) -> pd.Series: return neigh_data -def build_img_attr_by_patient_group_dataframe( +def build_time_series_attr_by_patient_group_dataframe( patients_groups: Mapping[Hashable, Iterable[Patient]], - attr: Tuple[ViewEnum, ImageAttribute], + attr: Tuple[ViewEnum, TimeSeriesAttribute], group_desc: str = "group", mask_tag: str = CardinalTag.mask, resampling_rate: int = 128, ) -> pd.DataFrame: - """Builds a dataframe with the average curve of an image attribute by patient group. + """Builds a dataframe with the average curve of a time-series attribute by patient group. Args: patients_groups: Mapping between group ID and the patients in that group. - attr: A pair of view and image attribute to compute the average curve of. + attr: A pair of view and time-series attribute to compute the average curve of. group_desc: Description of the semantic meaning of the groups. - mask_tag: Tag of the segmentation mask for which to extract the image attribute data. - resampling_rate: Number of points at which to resample the image attribute curves from each patient, so that - they can be easily compared and aggregated together. + mask_tag: Tag of the segmentation mask for which to extract the time-series attribute data. + resampling_rate: Number of points at which to resample the time-series attribute curves from each patient, so + that they can be easily compared and aggregated together. Returns: - Dataframe with the average curve of an image attribute by patient group, in long format. + Dataframe with the average curve of an time-series attribute by patient group, in long format. """ resampling_fn = Interp1d(resampling_rate) diff --git a/didactic/scripts/analyse_clustering_hparams.py b/didactic/scripts/analyse_clustering_hparams.py index 979c3616..039c71f7 100644 --- a/didactic/scripts/analyse_clustering_hparams.py +++ b/didactic/scripts/analyse_clustering_hparams.py @@ -20,7 +20,7 @@ def main(): "hparams_files", nargs="+", type=Path, - help="Paths to YAML config file of the best clustering hyperparameters found for an encoder amodel", + help="Paths to YAML config file of the best clustering hyperparameters found for an encoder model", ) parser.add_argument( "--hparams", diff --git a/didactic/scripts/cardiac_multimodal_encoder_attn.py b/didactic/scripts/cardiac_multimodal_encoder_attn.py index e87d1148..f6bee188 100644 --- a/didactic/scripts/cardiac_multimodal_encoder_attn.py +++ b/didactic/scripts/cardiac_multimodal_encoder_attn.py @@ -53,7 +53,7 @@ def main(): "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--use_attention_rollout", diff --git a/didactic/scripts/cardiac_multimodal_representation_plot.py b/didactic/scripts/cardiac_multimodal_representation_plot.py index 8d32f7c7..7d57c7d6 100644 --- a/didactic/scripts/cardiac_multimodal_representation_plot.py +++ b/didactic/scripts/cardiac_multimodal_representation_plot.py @@ -3,8 +3,8 @@ import pandas as pd from matplotlib import pyplot as plt from matplotlib.axes import Axes -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute -from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS +from vital.data.cardinal.config import CardinalTag, TabularAttribute +from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS from vital.data.cardinal.utils.data_struct import Patient from vital.data.cardinal.utils.itertools import Patients from vital.utils.plot import embedding_scatterplot @@ -16,19 +16,19 @@ def plot_patients_embeddings( model: CardiacMultimodalRepresentationTask, patients: Patients, - plot_clinical_attrs: Sequence[ClinicalAttribute] = None, + plot_tabular_attrs: Sequence[TabularAttribute] = None, categorical_attrs_lists: Dict[str, Dict[str, Sequence[Patient.Id]]] = None, mask_tag: str = CardinalTag.mask, progress_bar: bool = False, **embedding_kwargs, -) -> Iterator[Tuple[ClinicalAttribute | str, Axes]]: +) -> Iterator[Tuple[TabularAttribute | str, Axes]]: """Generates 2D scatter plots of patients' encodings, labeled w.r.t. specific attributes. Args: model: Transformer encoder model to use for inference. patients: (N) Patients to embed. mask_tag: Tag of the segmentation mask for which to extract the image attributes. - plot_clinical_attrs: Patients' clinical attributes w.r.t. which to plot the embedding. + plot_tabular_attrs: Patients' tabular attributes w.r.t. which to plot the embedding. categorical_attrs_lists: Nested mapping listing, for each additional categorical attribute, the patients belonging to each of the attribute's labels. progress_bar: If ``True``, enables progress bars detailing the progress of encoding patients. @@ -38,18 +38,18 @@ def plot_patients_embeddings( Returns: An iterator over the attributes and associated scatter plots. """ - if plot_clinical_attrs is None and categorical_attrs_lists is None: + if plot_tabular_attrs is None and categorical_attrs_lists is None: raise ValueError( - "You have specified neither built-in attributes (`plot_attrs` is None) nor custom attributes " + "You have specified neither built-in attributes (`plot_tabular_attrs` is None) nor custom attributes " "(`categorical_attrs_lists`) w.r.t. which to plot the embeddings. Specify at least one attribute of either " "type to plot embeddings of the patients." ) - if plot_clinical_attrs is None: - plot_clinical_attrs = [] + if plot_tabular_attrs is None: + plot_tabular_attrs = [] if categorical_attrs_lists is None: categorical_attrs_lists = {} - plot_attrs = plot_clinical_attrs + list(categorical_attrs_lists) + plot_attrs = plot_tabular_attrs + list(categorical_attrs_lists) # Encode the data using the model patient_encodings = pd.DataFrame( @@ -60,7 +60,7 @@ def plot_patients_embeddings( patient_encodings = patient_encodings.join( pd.DataFrame.from_dict( { - patient.id: {attr: patient.attrs.get(attr) for attr in plot_clinical_attrs} + patient.id: {attr: patient.attrs.get(attr) for attr in plot_tabular_attrs} for patient in patients.values() }, orient="index", @@ -84,9 +84,9 @@ def plot_patients_embeddings( # Transfer the attributes data from the columns to the index, as required by the generic embedding function later patient_encodings = patient_encodings.set_index(plot_attrs, append=True) - # Determine from the clinical attributes' predefined order or the natural ordering in the custom attributes the + # Determine from the tabular attributes' predefined order or the natural ordering in the custom attributes the # hue order for the plots - plot_attrs_order = {attr: CLINICAL_CAT_ATTR_LABELS.get(attr) for attr in plot_clinical_attrs} + plot_attrs_order = {attr: TABULAR_CAT_ATTR_LABELS.get(attr) for attr in plot_tabular_attrs} plot_attrs_order.update({attr: list(attr_lists) for attr, attr_lists in categorical_attrs_lists.items()}) # Plot data w.r.t. attributes @@ -128,12 +128,12 @@ def main(): help="Tag of the segmentation mask for which to extract the image attributes", ) parser.add_argument( - "--plot_clinical_attrs", - type=ClinicalAttribute, + "--plot_tabular_attrs", + type=TabularAttribute, nargs="+", - choices=list(ClinicalAttribute), - default=list(ClinicalAttribute), - help="Patients' clinical attributes w.r.t. which to plot the embedding", + choices=list(TabularAttribute), + default=list(TabularAttribute), + help="Patients' tabular attributes w.r.t. which to plot the embedding", ) parser.add_argument( "--plot_categorical_attrs_dirs", @@ -158,10 +158,10 @@ def main(): args = parser.parse_args() kwargs = vars(args) - encoder_ckpt, mask_tag, plot_clinical_attrs, plot_categorical_attrs_dirs, embedding_kwargs, output_dir = ( + encoder_ckpt, mask_tag, plot_tabular_attrs, plot_categorical_attrs_dirs, embedding_kwargs, output_dir = ( kwargs.pop("pretrained_encoder"), kwargs.pop("mask_tag"), - kwargs.pop("plot_clinical_attrs"), + kwargs.pop("plot_tabular_attrs"), kwargs.pop("plot_categorical_attrs_dirs"), kwargs.pop("embedding_kwargs"), kwargs.pop("output_dir"), @@ -191,7 +191,7 @@ def main(): for attr, _ in plot_patients_embeddings( encoder, patients, - plot_clinical_attrs=plot_clinical_attrs, + plot_tabular_attrs=plot_tabular_attrs, categorical_attrs_lists=categorical_attrs_lists, mask_tag=mask_tag, progress_bar=True, diff --git a/didactic/scripts/describe_patients.py b/didactic/scripts/describe_patients.py index a6060860..03d1e47e 100644 --- a/didactic/scripts/describe_patients.py +++ b/didactic/scripts/describe_patients.py @@ -3,61 +3,61 @@ import numpy as np import pandas as pd -from vital.data.cardinal.config import ClinicalAttribute -from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS +from vital.data.cardinal.config import TabularAttribute +from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS from vital.data.cardinal.utils.itertools import Patients -NUM_CLINICAL_ATTR_STATS = defaultdict(lambda: "mean") -NUM_CLINICAL_ATTR_STATS.update( +_NUM_ATTR_STATS = defaultdict(lambda: "mean") +_NUM_ATTR_STATS.update( { - ClinicalAttribute.ddd: "quartile", - ClinicalAttribute.creat: "quartile", - ClinicalAttribute.gfr: "quartile", - ClinicalAttribute.nt_probnp: "quartile", - ClinicalAttribute.mv_dt: "quartile", + TabularAttribute.ddd: "quartile", + TabularAttribute.creat: "quartile", + TabularAttribute.gfr: "quartile", + TabularAttribute.nt_probnp: "quartile", + TabularAttribute.mv_dt: "quartile", } ) "Numerical attributes' statistics are summarized using `mean ± std, unless other specified here." -NUM_CLINICAL_ATTR_DECIMALS = defaultdict(int) -NUM_CLINICAL_ATTR_DECIMALS.update( +_NUM_ATTR_DECIMALS = defaultdict(int) +_NUM_ATTR_DECIMALS.update( { - ClinicalAttribute.bmi: 1, - ClinicalAttribute.ddd: 1, - ClinicalAttribute.gfr: 1, - ClinicalAttribute.e_velocity: 1, - ClinicalAttribute.a_velocity: 1, - ClinicalAttribute.e_e_prime_ratio: 1, - ClinicalAttribute.la_volume: 1, - ClinicalAttribute.la_area: 1, - ClinicalAttribute.vmax_tr: 1, - ClinicalAttribute.ivs_d: 1, - ClinicalAttribute.lvid_d: 1, - ClinicalAttribute.pw_d: 1, - ClinicalAttribute.tapse: 1, - ClinicalAttribute.s_prime: 1, + TabularAttribute.bmi: 1, + TabularAttribute.ddd: 1, + TabularAttribute.gfr: 1, + TabularAttribute.e_velocity: 1, + TabularAttribute.a_velocity: 1, + TabularAttribute.e_e_prime_ratio: 1, + TabularAttribute.la_volume: 1, + TabularAttribute.la_area: 1, + TabularAttribute.vmax_tr: 1, + TabularAttribute.ivs_d: 1, + TabularAttribute.lvid_d: 1, + TabularAttribute.pw_d: 1, + TabularAttribute.tapse: 1, + TabularAttribute.s_prime: 1, } ) """Attributes' statistics are rounded to the nearest integer, unless specified otherwise here.""" def describe_patients( - patients: Patients, clinical_attrs: Sequence[ClinicalAttribute] = None, format_summary: bool = False + patients: Patients, tabular_attrs: Sequence[TabularAttribute] = None, format_summary: bool = False ) -> pd.DataFrame: - """Computes statistics over patients' clinical attributes, adapting statistics to numerical/categorical attributes. + """Computes statistics over patients' tabular attributes, adapting statistics to numerical/categorical attributes. Args: patients: Patients over which to compute the statistics. - clinical_attrs: Subset of clinical attributes over which to compute the statistics. + tabular_attrs: Subset of tabular attributes over which to compute the statistics. format_summary: Whether to add a column where a subset of the stats (depending on the attribute) are selected and formatted, as a summary of that attribute's statistics. Returns: - Statistics describing patients' clinical attributes. + Statistics describing patients' tabular attributes. """ - patients_attrs = patients.to_dataframe(clinical_attrs=clinical_attrs) - if clinical_attrs is None: - clinical_attrs = patients_attrs.columns.tolist() + patients_attrs = patients.to_dataframe(tabular_attrs=tabular_attrs) + if tabular_attrs is None: + tabular_attrs = patients_attrs.columns.tolist() # Get the descriptions for the numerical attributes, with attributes as rows and descriptions as columns num_stats = ["mean", "std", "50%", "25%", "75%"] @@ -65,13 +65,13 @@ def describe_patients( # Manually compute the occurrences of label for boolean/categorical attributes cat_stats = ["count", "%"] - cat_attrs = [attr for attr in ClinicalAttribute.categorical_attrs() if attr in patients_attrs.columns] + cat_attrs = [attr for attr in TabularAttribute.categorical_attrs() if attr in patients_attrs.columns] cat_attrs_desc = {} for attr in cat_attrs: attr_data = patients_attrs[attr] attr_data = attr_data[attr_data.notna()] # Discard missing data - label_counts = {label: (attr_data == label).sum() for label in CLINICAL_CAT_ATTR_LABELS[attr]} + label_counts = {label: (attr_data == label).sum() for label in TABULAR_CAT_ATTR_LABELS[attr]} label_percentages = {label: round(count * 100 / len(attr_data), 1) for label, count in label_counts.items()} cat_attrs_desc[attr] = {"count": label_counts, "%": label_percentages} # Structure the boolean/categorical description as dataframe, with attributes as rows and descriptions as columns @@ -80,7 +80,7 @@ def describe_patients( # Join descriptions of numerical and categorical attributes patients_attrs_desc = num_attrs_desc.join(cat_attrs_desc, how="outer") # Index w.r.t. attributes and sort the attributes and statistics - patients_attrs_desc = patients_attrs_desc.T.reindex(clinical_attrs)[num_stats + cat_stats] + patients_attrs_desc = patients_attrs_desc.T.reindex(tabular_attrs)[num_stats + cat_stats] # Cast numerical stats to float, since the transpose leads all columns to be of generic 'object' type patients_attrs_desc[num_stats] = patients_attrs_desc[num_stats].astype(float) @@ -89,19 +89,19 @@ def describe_patients( summaries = {} for attr in patients_attrs_desc.index: - if attr in ClinicalAttribute.boolean_attrs(): + if attr in TabularAttribute.boolean_attrs(): attr_summary = ( f"{patients_attrs_desc.loc[attr, 'count'][True]} " f"({patients_attrs_desc.loc[attr, '%'][True]:.0f})" ) - elif attr in ClinicalAttribute.categorical_attrs(): + elif attr in TabularAttribute.categorical_attrs(): attr_summary = "\n".join( f"{cat_count} ({patients_attrs_desc.loc[attr, '%'][cat]:.0f})" for cat, cat_count in patients_attrs_desc.loc[attr, "count"].items() ) - else: # attr in ClinicalAttribute.numerical_attrs(): - dec = NUM_CLINICAL_ATTR_DECIMALS[attr] - match summary_stat := NUM_CLINICAL_ATTR_STATS[attr]: + else: # attr in TabularAttribute.numerical_attrs(): + dec = _NUM_ATTR_DECIMALS[attr] + match summary_stat := _NUM_ATTR_STATS[attr]: case "mean": mean, std = patients_attrs_desc[["mean", "std"]].loc[attr] attr_summary = f"{mean:.{dec}f} ± {std:.{dec}f}" @@ -132,9 +132,9 @@ def main(): parser = Patients.add_args(parser) parser.add_argument( "--attributes", - type=ClinicalAttribute, + type=TabularAttribute, nargs="+", - choices=list(ClinicalAttribute), + choices=list(TabularAttribute), help="Attributes to describe", ) parser.add_argument( @@ -176,7 +176,7 @@ def main(): ] patients_attrs_desc_by_subset[subset] = describe_patients( - Patients(**kwargs, include_patients=subset_patients), clinical_attrs=clinical_attrs, format_summary=True + Patients(**kwargs, include_patients=subset_patients), tabular_attrs=clinical_attrs, format_summary=True ) # Save the description for each subset diff --git a/didactic/scripts/describe_representation_clustering.py b/didactic/scripts/describe_representation_clustering.py index decbfc9e..d31134e2 100644 --- a/didactic/scripts/describe_representation_clustering.py +++ b/didactic/scripts/describe_representation_clustering.py @@ -8,32 +8,32 @@ import seaborn as sns import seaborn.objects as so from matplotlib.axes import Axes -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.utils.attributes import CLINICAL_ATTR_UNITS, CLINICAL_CAT_ATTR_LABELS, IMAGE_ATTR_LABELS +from vital.data.cardinal.utils.attributes import CLINICAL_ATTR_UNITS, TABULAR_CAT_ATTR_LABELS, TIME_SERIES_ATTR_LABELS from vital.data.cardinal.utils.data_struct import Patient from vital.data.cardinal.utils.itertools import Patients -from didactic.data.cardinal.utils import build_clusterings_dataframe, build_img_attr_by_patient_group_dataframe +from didactic.data.cardinal.utils import build_clusterings_dataframe, build_time_series_attr_by_patient_group_dataframe logger = logging.getLogger(__name__) -def plot_clinical_attrs_variability_figures( +def plot_tabular_attrs_variability_figures( patients: Patients, clusterings: Mapping[str, Mapping[Patient.Id, str]], - clinical_attrs: Sequence[ClinicalAttribute] = None, + tabular_attrs: Sequence[TabularAttribute] = None, num_clustering_agg: str = None, cat_plot_kwargs: dict = None, num_plot_kwargs: dict = None, ) -> Iterator[Tuple[str, Axes]]: - """Plots the variability of cluster-aggregated clinical attrs across multiple clusterings w.r.t. clusters. + """Plots the variability of cluster-aggregated tabular attributes across multiple clusterings w.r.t. clusters. Args: patients: Collection of patients data from which to extract the attributes. clusterings: Instances of clustering of the patients population, representation as mappings between patient IDs and cluster labels. - clinical_attrs: Subset of clinical attributes on which to compile the results. If not provided, will default to + tabular_attrs: Subset of tabular attributes on which to compile the results. If not provided, will default to all available attributes. num_clustering_agg: Aggregation function to use to aggregate the numerical attributes by clusters, before the aggregation across clusterings. If not provided, the attributes are aggregated by clusters and across @@ -42,8 +42,8 @@ def plot_clinical_attrs_variability_figures( num_plot_kwargs: Parameters to forward to the call to `seaborn.boxplot` for numerical attributes. Returns: - Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated clinical - attrs across multiple clusterings w.r.t. clusters. + Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated tabular + attributes across multiple clusterings w.r.t. clusters. """ if cat_plot_kwargs is None: cat_plot_kwargs = {} @@ -52,8 +52,8 @@ def plot_clinical_attrs_variability_figures( # Gather the data of the patients in each cluster for each clustering clusterings_data = build_clusterings_dataframe(patients, clusterings) - if clinical_attrs is not None: - clusterings_data = clusterings_data[clinical_attrs] + if tabular_attrs is not None: + clusterings_data = clusterings_data[tabular_attrs] # Ignore `matplotlib.category` logger 'INFO' level logs to avoid repeated logs about categorical units parsable # as floats @@ -65,7 +65,7 @@ def plot_clinical_attrs_variability_figures( attr_data = clusterings_data[attr] # Based on whether the attribute is categorical or numerical, define different types of plots - if attr in ClinicalAttribute.categorical_attrs(): + if attr in TabularAttribute.categorical_attrs(): # Compute the occurrence of each category for each cluster (including NA), across all clusterings attr_stats = attr_data.groupby(["model", "cluster"]).value_counts(normalize=True, dropna=False) * 100 # After the NA values have been taken into account for the count, drop them @@ -78,13 +78,13 @@ def plot_clinical_attrs_variability_figures( # For boolean attributes, convert the values to string so that seaborn can properly pick up label names # Avoids the following error: 'bool' object has no attribute 'startswith' # At the same time, assign relevant labels/hues/etc. for either boolean or categorical attributes - if attr in ClinicalAttribute.boolean_attrs(): + if attr in TabularAttribute.boolean_attrs(): attr_stats = attr_stats.astype({attr: str}) ylabel = "(% true)" - hue_order = [str(val) for val in CLINICAL_CAT_ATTR_LABELS[attr]] + hue_order = [str(val) for val in TABULAR_CAT_ATTR_LABELS[attr]] else: ylabel = "(% by label)" - hue_order = CLINICAL_CAT_ATTR_LABELS[attr] + hue_order = TABULAR_CAT_ATTR_LABELS[attr] # Use dodged barplots for categorical attributes with sns.axes_style("darkgrid"): @@ -101,7 +101,7 @@ def plot_clinical_attrs_variability_figures( plot.set(title=title, ylabel=ylabel) - else: # attr in ClinicalAttribute.numerical_attrs() + else: # attr in TabularAttribute.numerical_attrs() if num_clustering_agg is not None: # Aggregate the numerical attributes by clusters, before the aggregation across clusterings attr_data = attr_data.groupby(["model", "cluster"]).agg(num_clustering_agg) @@ -116,25 +116,25 @@ def plot_clinical_attrs_variability_figures( yield title, plot -def plot_img_attrs_variability_figures( +def plot_time_series_attrs_variability_figures( patients: Patients, clusterings: Mapping[str, Mapping[Patient.Id, str]], - image_attrs: Sequence[Tuple[ViewEnum, ImageAttribute]], + time_series_attrs: Sequence[Tuple[ViewEnum, TimeSeriesAttribute]], mask_tag: str = CardinalTag.mask, ) -> Iterator[Tuple[str, Axes]]: - """Plots the variability of cluster-aggregated image attrs across multiple clusterings w.r.t. clusters. + """Plots the variability of cluster-aggregated time-series attributes across multiple clusterings w.r.t. clusters. Args: patients: Collection of patients data from which to extract the attributes. clusterings: Instances of clustering of the patients population, representation as mappings between patient IDs and cluster labels. - image_attrs: Subset of image-based attributes derived from segmentations (identified by view/attribute pairs) - for which to plot the variability between bins of the reference clinical attribute. - mask_tag: Tag of the segmentation mask for which to extract the image attributes. + time_series_attrs: Subset of time-series attributes derived from segmentations (identified by view/attribute + pairs) for which to plot the variability between bins of the reference tabular attribute. + mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. Returns: - Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated image - attrs across multiple clusterings w.r.t. clusters. + Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated + time-series attrs across multiple clusterings w.r.t. clusters. """ # Convert clusterings from mapping between item IDs and cluster IDs to lists of patient IDs by cluster clusterings = { @@ -162,18 +162,18 @@ def plot_img_attrs_variability_figures( for cluster_label in cluster_labels } - # For each image attribute, build the dataframe of the mean curve for each bin and plot the curves for each bin - for img_attr in image_attrs: - neigh_agg_img_attrs_data = build_img_attr_by_patient_group_dataframe( - patients_by_cluster, img_attr, group_desc="cluster", mask_tag=mask_tag + # For each time-series attr, build the dataframe of the mean curve for each bin and plot the curves for each bin + for time_series_attr in time_series_attrs: + time_series_attr_data = build_time_series_attr_by_patient_group_dataframe( + patients_by_cluster, time_series_attr, group_desc="cluster", mask_tag=mask_tag ) with sns.axes_style("darkgrid"): plot = sns.lineplot( - data=neigh_agg_img_attrs_data, x="time", y="val", hue="cluster", hue_order=sorted(cluster_labels) + data=time_series_attr_data, x="time", y="val", hue="cluster", hue_order=sorted(cluster_labels) ) - title = f"{'/'.join(img_attr)}_wrt_clusters" - plot.set(title=title, ylabel=IMAGE_ATTR_LABELS[img_attr[1]]) + title = f"{'/'.join(time_series_attr)}_wrt_clusters" + plot.set(title=title, ylabel=TIME_SERIES_ATTR_LABELS[time_series_attr[1]]) yield title, plot @@ -210,27 +210,27 @@ def main(): ) parser = Patients.add_args(parser) parser.add_argument( - "--clinical_attrs", - type=ClinicalAttribute, + "--tabular_attrs", + type=TabularAttribute, nargs="*", - choices=list(ClinicalAttribute), - help="Subset of clinical attributes on which to compile the results. If not provided, will default to all " + choices=list(TabularAttribute), + help="Subset of tabular attributes on which to compile the results. If not provided, will default to all " "available attributes", ) parser.add_argument( - "--image_attrs", - type=ImageAttribute, - choices=list(ImageAttribute), + "--time_series_attrs", + type=TimeSeriesAttribute, + choices=list(TimeSeriesAttribute), nargs="*", - default=list(ImageAttribute), - help="Subset of image-based attributes derived from segmentations for which to plot the intra/inter-cluster " + default=list(TimeSeriesAttribute), + help="Subset of time-series attributes derived from segmentations for which to plot the intra/inter-cluster " "variability", ) parser.add_argument( "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--num_clustering_agg", @@ -240,16 +240,16 @@ def main(): "same time, leading to a higher reported variability.", ) parser.add_argument( - "--clinical_cat_plot_kwargs", + "--tabular_cat_plot_kwargs", type=yaml_flow_collection, metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.heatmap` for categorical clinical attributes figures", + help="Parameters to forward to the call to `seaborn.heatmap` for categorical tabular attributes figures", ) parser.add_argument( - "--clinical_num_plot_kwargs", + "--tabular_num_plot_kwargs", type=yaml_flow_collection, metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.boxplot` for categorical clinical attributes figures", + help="Parameters to forward to the call to `seaborn.boxplot` for categorical tabular attributes figures", ) parser.add_argument( "--output_dir", @@ -263,8 +263,8 @@ def main(): ( clustering_paths, clusterings_fmt, - clinical_attrs, - image_attrs, + tabular_attrs, + time_series_attrs, mask_tag, num_clustering_agg, cat_plot_kwargs, @@ -276,17 +276,19 @@ def main(): [ "clusterings", "clusterings_format", - "clinical_attrs", - "image_attrs", + "tabular_attrs", + "time_series_attrs", "mask_tag", "num_clustering_agg", - "clinical_cat_plot_kwargs", - "clinical_num_plot_kwargs", + "tabular_cat_plot_kwargs", + "tabular_num_plot_kwargs", "output_dir", ], ) ) - image_attrs_keys = [(view, image_attr) for view, image_attr in itertools.product(args.views, image_attrs)] + time_series_attrs_keys = [ + (view, time_series_attr) for view, time_series_attr in itertools.product(args.views, time_series_attrs) + ] # Load the dataset patients = Patients(**kwargs) @@ -312,21 +314,23 @@ def main(): case _: raise ValueError(f"Unknown `clusterings_format`: {clusterings_fmt}") - clinical_attrs_plots = plot_clinical_attrs_variability_figures( + tabular_attrs_plots = plot_tabular_attrs_variability_figures( patients, clusterings, - clinical_attrs=clinical_attrs, + tabular_attrs=tabular_attrs, num_clustering_agg=num_clustering_agg, cat_plot_kwargs=cat_plot_kwargs, num_plot_kwargs=num_plot_kwargs, ) - image_attrs_plots = plot_img_attrs_variability_figures(patients, clusterings, image_attrs_keys, mask_tag=mask_tag) + time_series_attrs_plots = plot_time_series_attrs_variability_figures( + patients, clusterings, time_series_attrs_keys, mask_tag=mask_tag + ) - # Plot the variability of the clinical and image attributes + # Plot the variability of the tabular and time-series attributes output_dir.mkdir(parents=True, exist_ok=True) # Prepare the output folder for the method - n_plots = (len(clinical_attrs) if clinical_attrs else len(ClinicalAttribute)) + len(image_attrs_keys) + n_plots = (len(tabular_attrs) if tabular_attrs else len(TabularAttribute)) + len(time_series_attrs_keys) for title, plot in tqdm( - itertools.chain(clinical_attrs_plots, image_attrs_plots), + itertools.chain(tabular_attrs_plots, time_series_attrs_plots), desc="Plotting the variability of the attributes w.r.t. clusters", unit="attr", total=n_plots, diff --git a/didactic/scripts/list_patients_wrt_attributes.py b/didactic/scripts/list_patients_wrt_attributes.py index 26e9df68..0143d80f 100644 --- a/didactic/scripts/list_patients_wrt_attributes.py +++ b/didactic/scripts/list_patients_wrt_attributes.py @@ -5,17 +5,17 @@ def main(): import pandas as pd from tqdm.auto import tqdm - from vital.data.cardinal.config import ClinicalAttribute + from vital.data.cardinal.config import TabularAttribute from vital.data.cardinal.utils.itertools import Patients parser = ArgumentParser() parser = Patients.add_args(parser) parser.add_argument( "--attributes", - type=ClinicalAttribute, + type=TabularAttribute, nargs="+", - choices=ClinicalAttribute.categorical_attrs(), - default=ClinicalAttribute.categorical_attrs(), + choices=TabularAttribute.categorical_attrs(), + default=TabularAttribute.categorical_attrs(), help="Attributes w.r.t. which to list the patients belonging to each class", ) parser.add_argument( diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index fa929167..7fdb53e2 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -14,10 +14,10 @@ from torch.nn import Parameter, ParameterDict, init from torchmetrics.functional import accuracy, mean_absolute_error from vital.data.augmentation.base import mask_tokens, random_masking -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes -from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS +from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_time_series_attributes +from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data @@ -25,7 +25,7 @@ from didactic.models.time_series import TimeSeriesEmbedding logger = logging.getLogger(__name__) -CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] +CardiacAttribute = TabularAttribute | Tuple[ViewEnum, TimeSeriesAttribute] class CardiacMultimodalRepresentationTask(SharedStepsTask): @@ -34,15 +34,15 @@ class CardiacMultimodalRepresentationTask(SharedStepsTask): def __init__( self, embed_dim: int, - clinical_attrs: Sequence[ClinicalAttribute | str], - img_attrs: Sequence[ImageAttribute], + tabular_attrs: Sequence[TabularAttribute | str], + time_series_attrs: Sequence[TimeSeriesAttribute], views: Sequence[ViewEnum] = tuple(ViewEnum), - predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, + predict_losses: Dict[TabularAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, ordinal_mode: bool = True, contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None, contrastive_loss_weight: float = 0, - clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, - img_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None, + tabular_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, + time_series_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None, cls_token: bool = True, sequential_pooling: bool = False, mtr_p: float | Tuple[float, float] = 0, @@ -54,6 +54,9 @@ def __init__( Args: embed_dim: Size of the tokens/embedding for all the modalities. + tabular_attrs: Tabular attributes to provide to the model. + time_series_attrs: Time-series attributes to provide to the model. + views: Views from which to include time-series attributes. predict_losses: Supervised criteria to measure the error between the predicted attributes and their real value. ordinal_mode: Whether to consider applicable targets as ordinal variables, which means: @@ -63,11 +66,8 @@ def __init__( of feature vectors, in a contrastive learning step that follows the SCARF pretraining. (see ref: https://arxiv.org/abs/2106.15147) contrastive_loss_weight: Factor by which to weight the `contrastive_loss` in the overall loss. - clinical_attrs: Clinical attributes to provide to the model. - img_attrs: Image attributes to provide to the model. - views: Views from which to include image attributes. - clinical_tokenizer: Tokenizer that can process clinical, i.e. patient records, data. - img_tokenizer: Tokenizer that can process imaging data. + tabular_tokenizer: Tokenizer that can process tabular, i.e. patient records, data. + time_series_tokenizer: Tokenizer that can process time-series data. cls_token: Whether to add a CLS token to use as the encoder's output token. sequential_pooling: Whether to perform sequential pooling on the encoder's output tokens. Otherwise, the full sequence of tokens is concatenated before being fed to the prediction head. @@ -82,9 +82,9 @@ def __init__( """ # Ensure string tags are converted to their appropriate enum types # And to it before call to the parent's `init` so that the converted values are saved in `hparams` - clinical_attrs = tuple(ClinicalAttribute[e] for e in clinical_attrs) + tabular_attrs = tuple(TabularAttribute[e] for e in tabular_attrs) views = tuple(ViewEnum[e] for e in views) - img_attrs = tuple(ImageAttribute[e] for e in img_attrs) + time_series_attrs = tuple(TimeSeriesAttribute[e] for e in time_series_attrs) # If dropout/masking are not single numbers, make sure they are tuples (and not another container type) if not isinstance(mtr_p, (int, float)): @@ -104,66 +104,71 @@ def __init__( "dimensionality of the encoder's output from a sequence of tokens to only one token." ) - if not clinical_tokenizer and clinical_attrs: + if not tabular_tokenizer and tabular_attrs: raise ValueError( - f"You have requested the following attributes derived from clinical data: " - f"{[str(attr) for attr in clinical_attrs]}, but have not configured a tokenizer for clinical-based " - f"attributes. Either provide this tokenizer (through the `clinical_tokenizer` parameter) or remove any " - f"clinical-based attributes (by setting the `clinical_attrs` to be an empty list)." + f"You have requested the following tabular attributes: " + f"{[str(attr) for attr in tabular_attrs]}, but have not configured a tokenizer for tabular attributes. " + f"Either provide this tokenizer (through the `tabular_tokenizer` parameter) or remove any tabular " + f"attributes (by setting the `tabular_attrs` to be an empty list)." ) - if img_attrs: - if not img_tokenizer: + if time_series_attrs: + if not time_series_tokenizer: raise ValueError( - f"You have requested the following attributes derived from imaging data: " - f"{[str(attr) for attr in img_attrs]}, but have not configured a tokenizer for image-based " - f"attributes. Either provide this tokenizer (through the `img_tokenizer` parameter) or remove any " - f"image-based attributes (by setting the `img_attrs` to be an empty list)." + f"You have requested the following time-series attributes: " + f"{[str(attr) for attr in time_series_attrs]}, but have not configured a tokenizer for time-series " + f"attributes. Either provide this tokenizer (through the `time_series_tokenizer` parameter) or " + f"remove any time-series attributes (by setting the `time_series_attrs` to be an empty list)." ) if ( - img_tokenizer.model if isinstance(img_tokenizer, TimeSeriesEmbedding) else img_tokenizer.get("model") + time_series_tokenizer.model + if isinstance(time_series_tokenizer, TimeSeriesEmbedding) + else time_series_tokenizer.get("model") ) is None: logger.warning( - f"You have requested the following attributes derived from imaging data: " - f"{[str(attr) for attr in img_attrs]}, but have not configured a model for the tokenizer for " - f"image-based attributes. The tokenizer's model is optional, but highly recommended, so this is " - f"likely an oversight. You can provide this model through the `img_tokenizer.model` parameter." + f"You have requested the following time-series attributes: " + f"{[str(attr) for attr in time_series_attrs]}, but have not configured a model for the tokenizer " + f"for time-series attributes. The tokenizer's model is optional, but highly recommended, so this " + f"is likely an oversight. You can provide this model through the `time_series_tokenizer.model` " + f"parameter." ) - if not (clinical_attrs or img_attrs): + if not (tabular_attrs or time_series_attrs): raise ValueError( - "You configured neither clinical attributes nor image attributes as input variables to the model, but " - "the model requires at least one input. Set non-empty values for either or both `clinical_attrs` and " - "`img_attrs`." + "You configured neither tabular attributes nor time-series attributes as input variables to the model, " + "but the model requires at least one input. Set non-empty values for either or both `tabular_attrs` " + "and `time_series_attrs`." ) super().__init__(*args, **kwargs) - # TOFIX: Hack to log image tokenizer model's hparams when it's a config for a `torch.nn.Sequential` object + # TOFIX: Hack to log time-series tokenizer model's hparams when it's a config for a `torch.nn.Sequential` object # In that case, we have to use a `ListConfig` for the reserved `_args_` key. However, the automatic # serialization of `_args_` fails (w/ a '`DictConfig' not JSON serializable' error). Therefore, we fix it by # manually unpacking and logging the first and only element in the `_args_` `ListConfig` - if isinstance(img_tokenizer, DictConfig): - if img_tokenizer.get("model", {}).get("_target_") == "torch.nn.Sequential": - self.save_hyperparameters({"img_tokenizer/model/_args_/0": img_tokenizer.model._args_[0]}) + if isinstance(time_series_tokenizer, DictConfig): + if time_series_tokenizer.get("model", {}).get("_target_") == "torch.nn.Sequential": + self.save_hyperparameters( + {"time_series_tokenizer/model/_args_/0": time_series_tokenizer.model._args_[0]} + ) # Add shortcut to lr to work with Lightning's learning rate finder self.hparams.lr = None # Add shortcut to token labels to avoid downstream applications having to determine them from hyperparameters - self.token_tags = clinical_attrs + tuple( - "/".join([view, attr]) for view, attr in itertools.product(views, img_attrs) + self.token_tags = tabular_attrs + tuple( + "/".join([view, attr]) for view, attr in itertools.product(views, time_series_attrs) ) if cls_token: self.token_tags = self.token_tags + ("LAT",) - # Categorise the clinical attributes (tabular data) in terms of their type (numerical vs categorical) - self.clinical_num_attrs = [ - attr for attr in self.hparams.clinical_attrs if attr in ClinicalAttribute.numerical_attrs() + # Categorise the tabular attributes in terms of their type (numerical vs categorical) + self.tabular_num_attrs = [ + attr for attr in self.hparams.tabular_attrs if attr in TabularAttribute.numerical_attrs() ] - self.clinical_cat_attrs = [ - attr for attr in self.hparams.clinical_attrs if attr in ClinicalAttribute.categorical_attrs() + self.tabular_cat_attrs = [ + attr for attr in self.hparams.tabular_attrs if attr in TabularAttribute.categorical_attrs() ] - self.clinical_cat_attrs_cardinalities = [ - len(CLINICAL_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.clinical_cat_attrs + self.tabular_cat_attrs_cardinalities = [ + len(TABULAR_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.tabular_cat_attrs ] # Extract train/test masking probabilities from their configs @@ -180,23 +185,23 @@ def __init__( self.predict_losses = {} if predict_losses: self.predict_losses = { - ClinicalAttribute[attr]: hydra.utils.instantiate(attr_loss) + TabularAttribute[attr]: hydra.utils.instantiate(attr_loss) if isinstance(attr_loss, DictConfig) else attr_loss for attr, attr_loss in predict_losses.items() } - self.hparams.target_clinical_attrs = tuple( + self.hparams.target_tabular_attrs = tuple( self.predict_losses ) # Hyperparameter to easily access target attributes for attr in self.predict_losses: - if attr in ClinicalAttribute.numerical_attrs(): + if attr in TabularAttribute.numerical_attrs(): self.metrics[attr] = {"mae": mean_absolute_error} - elif attr in ClinicalAttribute.binary_attrs(): + elif attr in TabularAttribute.binary_attrs(): self.metrics[attr] = {"acc": functools.partial(accuracy, task="binary")} - else: # attr in ClinicalAttribute.categorical_attrs() + else: # attr in TabularAttribute.categorical_attrs() self.metrics[attr] = { "acc": functools.partial( - accuracy, task="multiclass", num_classes=len(CLINICAL_CAT_ATTR_LABELS[attr]) + accuracy, task="multiclass", num_classes=len(TABULAR_CAT_ATTR_LABELS[attr]) ) } @@ -211,8 +216,8 @@ def __init__( # Compute shapes relevant for defining the models' architectures self.sequence_length = ( - len(self.hparams.clinical_attrs) - + (len(self.hparams.img_attrs) * len(self.hparams.views)) + len(self.hparams.tabular_attrs) + + (len(self.hparams.time_series_attrs) * len(self.hparams.views)) + self.hparams.cls_token ) @@ -232,25 +237,25 @@ def __init__( f"of the number of attention heads for your configuration above this warning." ) - if clinical_attrs: - if isinstance(clinical_tokenizer, DictConfig): - clinical_tokenizer = hydra.utils.instantiate( - clinical_tokenizer, - n_num_features=len(self.clinical_num_attrs), - cat_cardinalities=self.clinical_cat_attrs_cardinalities, + if tabular_attrs: + if isinstance(tabular_tokenizer, DictConfig): + tabular_tokenizer = hydra.utils.instantiate( + tabular_tokenizer, + n_num_features=len(self.tabular_num_attrs), + cat_cardinalities=self.tabular_cat_attrs_cardinalities, ) else: # Set tokenizer to `None` if it's not going to be used - clinical_tokenizer = None - self.clinical_tokenizer = clinical_tokenizer + tabular_tokenizer = None + self.tabular_tokenizer = tabular_tokenizer - if img_attrs: - if isinstance(img_tokenizer, DictConfig): - img_tokenizer = hydra.utils.instantiate(img_tokenizer) + if time_series_attrs: + if isinstance(time_series_tokenizer, DictConfig): + time_series_tokenizer = hydra.utils.instantiate(time_series_tokenizer) else: # Set tokenizer to `None` if it's not going to be used - img_tokenizer = None - self.img_tokenizer = img_tokenizer + time_series_tokenizer = None + self.time_series_tokenizer = time_series_tokenizer # Initialize modules/parameters dependent on the encoder's configuration @@ -287,18 +292,18 @@ def _init_mask_token() -> Parameter: @property def example_input_array( self, - ) -> Tuple[Dict[ClinicalAttribute, Tensor], Dict[Tuple[ViewEnum, ImageAttribute], Tensor]]: + ) -> Tuple[Dict[TabularAttribute, Tensor], Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]]: """Redefine example input array based on the cardiac attributes provided to the model.""" # 2 is the size of the batch in the example - clinical_attrs = {attr: torch.randn(2) for attr in self.clinical_num_attrs} + tab_attrs = {attr: torch.randn(2) for attr in self.tabular_num_attrs} # Only generate 0/1 labels, to avoid generating labels bigger than the number of classes, which would lead to # an index out of range error when looking up the embedding of the class in the categorical feature tokenizer - clinical_attrs.update({attr: torch.randint(2, (2,)) for attr in self.clinical_cat_attrs}) - img_attrs = { - (view, attr): torch.randn(2, self.hparams.data_params.in_shape[CardinalTag.image_attrs][1]) - for view, attr in itertools.product(self.hparams.views, self.hparams.img_attrs) + tab_attrs.update({attr: torch.randint(2, (2,)) for attr in self.tabular_cat_attrs}) + time_series_attrs = { + (view, attr): torch.randn(2, self.hparams.data_params.in_shape[CardinalTag.time_series_attrs][1]) + for view, attr in itertools.product(self.hparams.views, self.hparams.time_series_attrs) } - return clinical_attrs, img_attrs + return tab_attrs, time_series_attrs def configure_model( self, @@ -312,50 +317,52 @@ def configure_model( if self.contrastive_loss: contrastive_head = hydra.utils.instantiate(self.hparams.model.contrastive_head) - # Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in + # Build the prediction heads (one by tabular attribute to predict) following the architecture proposed in # https://arxiv.org/pdf/2106.11959 prediction_heads = None if self.predict_losses: prediction_heads = nn.ModuleDict() - for target_clinical_attr in self.predict_losses: + for target_tab_attr in self.predict_losses: if ( - target_clinical_attr in ClinicalAttribute.categorical_attrs() - and target_clinical_attr not in ClinicalAttribute.binary_attrs() + target_tab_attr in TabularAttribute.categorical_attrs() + and target_tab_attr not in TabularAttribute.binary_attrs() ): # Multi-class classification target - output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr]) + output_size = len(TABULAR_CAT_ATTR_LABELS[target_tab_attr]) else: # Binary classification or regression target output_size = 1 - if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs(): + if self.hparams.ordinal_mode and target_tab_attr in TabularAttribute.ordinal_attrs(): # For ordinal targets, use a separate prediction head config - prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + prediction_heads[target_tab_attr] = hydra.utils.instantiate( self.hparams.model.ordinal_head, num_logits=output_size ) else: - prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + prediction_heads[target_tab_attr] = hydra.utils.instantiate( self.hparams.model.prediction_head, out_features=output_size ) return encoder, contrastive_head, prediction_heads def configure_optimizers(self) -> Dict[Literal["optimizer", "lr_scheduler"], Any]: - """Configure optimizer to ignore parameters that should remain frozen (e.g. image tokenizer).""" + """Configure optimizer to ignore parameters that should remain frozen (e.g. tokenizers).""" return super().configure_optimizers(params=filter(lambda p: p.requires_grad, self.parameters())) @auto_move_data def tokenize( - self, clinical_attrs: Dict[ClinicalAttribute, Tensor], img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] + self, + tabular_attrs: Dict[TabularAttribute, Tensor], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor], ) -> Tuple[Tensor, Tensor]: - """Tokenizes the input clinical and image attributes, providing a mask of non-missing attributes. + """Tokenizes the input tabular and time-series attributes, providing a mask of non-missing attributes. Args: - clinical_attrs: (K: S, V: N), Sequence of batches of clinical attributes. To indicate an item is missing an + tabular_attrs: (K: S, V: N), Sequence of batches of tabular attributes. To indicate an item is missing an attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical attributes, respectively. - img_attrs: (K: S, V: (N, ?)), Sequence of batches of image attributes, where the dimensionality of each - attribute can vary. + time_series_attrs: (K: S, V: (N, ?)), Sequence of batches of time-series attributes, where the + dimensionality of each attribute can vary. Returns: Batch of i) (N, S, E) tokens for each attribute, and ii) (N, S) mask of non-missing attributes. @@ -364,17 +371,17 @@ def tokenize( tokens, notna_mask = [], [] # Tokenize the attributes - if clinical_attrs: - clinical_num_attrs, clinical_cat_attrs = None, None - if self.clinical_num_attrs: - # Group the numerical attributes from the `clinical_attrs` input in a single tensor - clinical_num_attrs = torch.hstack( - [clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_num_attrs] + if tabular_attrs: + num_attrs, cat_attrs = None, None + if self.tabular_num_attrs: + # Group the numerical attributes from the `tabular_attrs` input in a single tensor + num_attrs = torch.hstack( + [tabular_attrs[attr].unsqueeze(1) for attr in self.tabular_num_attrs] ) # (N, S_num) - if self.clinical_cat_attrs: - # Group the categorical attributes from the `clinical_attrs` input in a single tensor - clinical_cat_attrs = torch.hstack( - [clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_cat_attrs] + if self.tabular_cat_attrs: + # Group the categorical attributes from the `tabular_attrs` input in a single tensor + cat_attrs = torch.hstack( + [tabular_attrs[attr].unsqueeze(1) for attr in self.tabular_cat_attrs] ) # (N, S_cat) # Use "sanitized" version of the inputs, where invalid values are replaced by null/default values, for the # tokenization process. This is done to avoid propagating NaNs to available/valid values. @@ -383,30 +390,32 @@ def tokenize( # instead of their current null/default values. # 1) Convert missing numerical attributes (NaNs) to numbers to avoid propagating NaNs # 2) Clip categorical labels to convert indicators of missing data (-1) into valid indices (0) - clinical_attrs_tokens = self.clinical_tokenizer( - x_num=torch.nan_to_num(clinical_num_attrs) if clinical_num_attrs is not None else None, - x_cat=clinical_cat_attrs.clip(0) if clinical_cat_attrs is not None else None, - ) # (N, S_clinical, E) - tokens.append(clinical_attrs_tokens) - - # Identify missing data in clinical attributes - if self.clinical_num_attrs: - notna_mask.append(~(clinical_num_attrs.isnan())) - if self.clinical_cat_attrs: - notna_mask.append(clinical_cat_attrs != MISSING_CAT_ATTR) - - if img_attrs: - img_attrs_tokens = self.img_tokenizer(img_attrs) # S * (N, ?) -> (N, S_img, E) - tokens.append(img_attrs_tokens) - - # Indicate that, when image tokens are requested, they are always available - image_notna_mask = torch.full(img_attrs_tokens.shape[:2], True, device=img_attrs_tokens.device) - notna_mask.append(image_notna_mask) + tab_attrs_tokens = self.tabular_tokenizer( + x_num=torch.nan_to_num(num_attrs) if num_attrs is not None else None, + x_cat=cat_attrs.clip(0) if cat_attrs is not None else None, + ) # (N, S_tab, E) + tokens.append(tab_attrs_tokens) + + # Identify missing data in tabular attributes + if self.tabular_num_attrs: + notna_mask.append(~(num_attrs.isnan())) + if self.tabular_cat_attrs: + notna_mask.append(cat_attrs != MISSING_CAT_ATTR) + + if time_series_attrs: + time_series_attrs_tokens = self.time_series_tokenizer(time_series_attrs) # S * (N, ?) -> (N, S_ts, E) + tokens.append(time_series_attrs_tokens) + + # Indicate that, when time-series tokens are requested, they are always available + time_series_notna_mask = torch.full( + time_series_attrs_tokens.shape[:2], True, device=time_series_attrs_tokens.device + ) + notna_mask.append(time_series_notna_mask) # Cast to float to make sure tokens are not represented using double - tokens = torch.cat(tokens, dim=1).float() # (N, S_clinical + S_img, E) + tokens = torch.cat(tokens, dim=1).float() # (N, S_tab + S_ts, E) # Cast to bool to make sure attention mask is represented by bool - notna_mask = torch.cat(notna_mask, dim=1).bool() # (N, S_clinical + S_img) + notna_mask = torch.cat(notna_mask, dim=1).bool() # (N, S_tab + S_ts) return tokens, notna_mask @@ -464,18 +473,18 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = Fa @auto_move_data def forward( self, - clinical_attrs: Dict[ClinicalAttribute, Tensor], - img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor], + tabular_attrs: Dict[TabularAttribute, Tensor], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor], task: Literal["encode", "predict", "unimodal_param", "unimodal_tau"] = "encode", - ) -> Tensor | Dict[ClinicalAttribute, Tensor]: + ) -> Tensor | Dict[TabularAttribute, Tensor]: """Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head. Args: - clinical_attrs: (K: S, V: N) Sequence of batches of clinical attributes. To indicate an item is missing an + tabular_attrs: (K: S, V: N) Sequence of batches of tabular attributes. To indicate an item is missing an attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical attributes, respectively. - img_attrs: (K: S, V: (N, ?)), Sequence of batches of image attributes, where the dimensionality of each - attribute can vary. + time_series_attrs: (K: S, V: (N, ?)), Sequence of batches of time-series attributes, where the + dimensionality of each attribute can vary. task: Flag indicating which type of inference task to perform. Returns: @@ -500,7 +509,7 @@ def forward( "the requested inference task." ) - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) + in_tokens, avail_mask = self.tokenize(tabular_attrs, time_series_attrs) # (N, S, E), (N, S) out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) # Early return if requested task requires no prediction heads @@ -515,7 +524,7 @@ def forward( case "predict": if self.hparams.ordinal_mode: predictions = { - attr: pred[0] if attr in ClinicalAttribute.ordinal_attrs() else pred + attr: pred[0] if attr in TabularAttribute.ordinal_attrs() else pred for attr, pred in predictions.items() } case "unimodal_param": @@ -530,11 +539,13 @@ def forward( return predictions def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: - # Extract clinical and image attributes from the batch - clinical_attrs = {attr: attr_data for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs} - img_attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.img_attrs) + # Extract tabular and time-series attributes from the batch + tabular_attrs = {attr: attr_data for attr, attr_data in batch.items() if attr in self.hparams.tabular_attrs} + time_series_attrs = filter_time_series_attributes( + batch, views=self.hparams.views, attributes=self.hparams.time_series_attrs + ) - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) + in_tokens, avail_mask = self.tokenize(tabular_attrs, time_series_attrs) # (N, S, E), (N, S) out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) metrics = {} @@ -557,7 +568,7 @@ def _prediction_shared_step( predictions = {} for attr, prediction_head in self.prediction_heads.items(): pred = prediction_head(out_features) - if self.hparams.ordinal_mode and attr in ClinicalAttribute.ordinal_attrs(): + if self.hparams.ordinal_mode and attr in TabularAttribute.ordinal_attrs(): # For ordinal targets, extract the logits from the multiple outputs of unimodal logits head pred = pred[0] predictions[attr] = pred.squeeze(dim=1) @@ -567,16 +578,16 @@ def _prediction_shared_step( for attr, loss in self.predict_losses.items(): target, y_hat = batch[attr], predictions[attr] - if attr in ClinicalAttribute.categorical_attrs(): + if attr in TabularAttribute.categorical_attrs(): notna_mask = target != MISSING_CAT_ATTR - else: # attr in ClinicalAttribute.numerical_attrs(): + else: # attr in TabularAttribute.numerical_attrs(): notna_mask = ~target.isnan() losses[f"{loss.__class__.__name__.lower().replace('loss', '')}/{attr}"] = loss( y_hat[notna_mask], # For BCE losses (e.g. `BCELoss`, BCEWithLogitsLoss`, etc.), the targets have to be floats, # so convert them from long to float - target[notna_mask] if attr not in ClinicalAttribute.binary_attrs() else target[notna_mask].float(), + target[notna_mask] if attr not in TabularAttribute.binary_attrs() else target[notna_mask].float(), ) for metric_tag, metric in self.metrics[attr].items(): @@ -608,34 +619,34 @@ def predict_step( # noqa: D102 self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0 ) -> Tuple[ Tensor, - Optional[Dict[ClinicalAttribute, Tensor]], - Optional[Dict[ClinicalAttribute, Tensor]], - Optional[Dict[ClinicalAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], ]: - # Extract clinical and image attributes from the patient and add batch dimension - clinical_attrs = { - attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs + # Extract tabular and time-series attributes from the patient and add batch dimension + tabular_attrs = { + attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.tabular_attrs } - img_attrs = { + time_series_attrs = { attr: attr_data[None, ...] - for attr, attr_data in filter_image_attributes( - batch, views=self.hparams.views, attributes=self.hparams.img_attrs + for attr, attr_data in filter_time_series_attributes( + batch, views=self.hparams.views, attributes=self.hparams.time_series_attrs ).items() } # Encoder's output - out_features = self(clinical_attrs, img_attrs) + out_features = self(tabular_attrs, time_series_attrs) # If the model has targets to predict, output the predictions predictions = None if self.prediction_heads: - predictions = self(clinical_attrs, img_attrs, task="predict") + predictions = self(tabular_attrs, time_series_attrs, task="predict") # If the model enforces unimodal constraint on ordinal targets, output the unimodal parametrization unimodal_params, unimodal_taus = None, None if self.hparams.ordinal_mode: - unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param") - unimodal_taus = self(clinical_attrs, img_attrs, task="unimodal_tau") + unimodal_params = self(tabular_attrs, time_series_attrs, task="unimodal_param") + unimodal_taus = self(tabular_attrs, time_series_attrs, task="unimodal_tau") # Remove unnecessary batch dimension from the different outputs # (only do this once all downstream inferences have been performed) diff --git a/didactic/tasks/cardiac_representation_clustering.py b/didactic/tasks/cardiac_representation_clustering.py index d9f052eb..e6e55091 100644 --- a/didactic/tasks/cardiac_representation_clustering.py +++ b/didactic/tasks/cardiac_representation_clustering.py @@ -12,7 +12,7 @@ from sklearn.mixture import GaussianMixture from sklearn.model_selection import GridSearchCV from tqdm.auto import tqdm -from vital.data.cardinal.config import ClinicalAttribute +from vital.data.cardinal.config import TabularAttribute class GridSearchEnsembleClustering: @@ -266,7 +266,7 @@ def main(): "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--output_dir", @@ -311,9 +311,9 @@ def main(): ) parser.add_argument( "--order_clusters_by", - type=ClinicalAttribute, - choices=list(ClinicalAttribute), - default=ClinicalAttribute.ht_grade, + type=TabularAttribute, + choices=list(TabularAttribute), + default=TabularAttribute.ht_grade, help="Attribute used to order the clusters by ascending mean value by cluster", ) args = parser.parse_args() diff --git a/didactic/tasks/cardiac_sequence_attrs_ae.py b/didactic/tasks/cardiac_sequence_attrs_ae.py index 828ecead..7b2c1347 100644 --- a/didactic/tasks/cardiac_sequence_attrs_ae.py +++ b/didactic/tasks/cardiac_sequence_attrs_ae.py @@ -10,9 +10,9 @@ from strenum import SnakeCaseStrEnum from torch import Tensor, nn from torch.nn import functional as F -from vital.data.cardinal.config import CardinalTag, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.datapipes import PatientData, filter_image_attributes +from vital.data.cardinal.datapipes import PatientData, filter_time_series_attributes from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data from vital.utils.norm import minmax_scaling, scale @@ -23,29 +23,29 @@ @unique class _AttributeNormalization(SnakeCaseStrEnum): - """Names of the available strategies for normalizing image attributes values.""" + """Names of the available strategies for normalizing time-series attributes values.""" data = auto() """Normalize the attributes' values w.r.t. statistics computed on the whole training dataset.""" metrics = auto() - """Normalize the metrics computed on the image attributes w.r.t. statistics computed on each mini-batch.""" + """Normalize the metrics computed on the time-series attributes w.r.t. statistics computed on each mini-batch.""" @unique class _AttributeStatistic(SnakeCaseStrEnum): - """Statistics about the image attributes that are computed on the dataset and stored inside the model.""" + """Statistics about the time-series attributes that are computed on the dataset and stored inside the model.""" min = auto() max = auto() class CardiacSequenceAttributesAutoencoder(SharedStepsTask): - """Autoencoder pipeline specialized for cardiac sequences image attributes.""" + """Autoencoder pipeline specialized for cardiac sequences time-series attributes.""" def __init__( self, views: Sequence[ViewEnum] = tuple(ViewEnum), - attrs: Sequence[ImageAttribute] = tuple(ImageAttribute), + attrs: Sequence[TimeSeriesAttribute] = tuple(TimeSeriesAttribute), normalization: _AttributeNormalization = _AttributeNormalization.data, reconstruction_loss: nn.Module | DictConfig = nn.L1Loss(), *args, @@ -55,10 +55,10 @@ def __init__( Args: views: Views to train the model on. - attrs: Image attributes to train the model on. - normalization: Strategy to use to normalize image attributes values. - reconstruction_loss: Criterion to measure the reconstruction error on the image attribute curves, or Hydra - config object describing how to instantiate such criterion. + attrs: Time-series attributes to train the model on. + normalization: Strategy to use to normalize time-series attributes values. + reconstruction_loss: Criterion to measure the reconstruction error on the time-series attribute curves, or + Hydra config object describing how to instantiate such criterion. *args: Positional arguments to pass to the parent's constructor. **kwargs: Keyword arguments to pass to the parent's constructor. """ @@ -76,7 +76,7 @@ def __init__( self.reconstruction_loss = reconstruction_loss self._reconstruction_loss_name = self.reconstruction_loss.__class__.__name__.lower().replace("loss", "") - # Register buffers for image attributes (needs to be in `__init__`) + # Register buffers for time-series attributes (needs to be in `__init__`) attrs_stats_defaults = { _AttributeStatistic.min: torch.finfo().max, _AttributeStatistic.max: torch.finfo().min, @@ -88,9 +88,9 @@ def __init__( @property def example_input_array(self) -> Tensor: - """Redefine example input array based only on the image attributes modality.""" - img_attrs_shape = self.hparams.data_params.in_shape[CardinalTag.image_attrs] - return torch.randn((2, 1, img_attrs_shape[1])) + """Redefine example input array based only on the time-series attributes modality.""" + attrs_shape = self.hparams.data_params.in_shape[CardinalTag.time_series_attrs] + return torch.randn((2, 1, attrs_shape[1])) @property def latent_dim(self) -> int: @@ -117,31 +117,33 @@ def reconstruction_loss_scale(self) -> int: ) return reconstruction_loss_scale - def _get_attr_bounds(self, attr: Tuple[ViewEnum, ImageAttribute]) -> Tuple[Tensor, Tensor]: - """Access the stored min/max bounds related to an image attribute. + def _get_attr_bounds(self, attr: Tuple[ViewEnum, TimeSeriesAttribute]) -> Tuple[Tensor, Tensor]: + """Access the stored min/max bounds related to a time-series attribute. Args: attr: Key identifying the attribute for which to look up the bounds. Returns: - Min/max bounds for the requested image attribute. + Min/max bounds for the requested time-series attribute. """ return self._get_attr_stat(attr, _AttributeStatistic.min), self._get_attr_stat(attr, _AttributeStatistic.max) - def _get_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _AttributeStatistic) -> Tensor: - """Access a statistic related to an image attribute, saved as a torch buffer inside the model. + def _get_attr_stat(self, attr: Tuple[ViewEnum, TimeSeriesAttribute], stat: _AttributeStatistic) -> Tensor: + """Access a statistic related to a time-series attribute, saved as a torch buffer inside the model. Args: attr: Key identifying the attribute for which to look up the statistic. stat: Statistic to look up. Returns: - Statistic for the requested image attribute. + Statistic for the requested time-series attribute. """ return getattr(self, "_".join((*attr, stat))) - def _set_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _AttributeStatistic, val: Tensor) -> None: - """Sets the value of a statistic related to an image attribute, saved as a torch buffer inside the model. + def _set_attr_stat( + self, attr: Tuple[ViewEnum, TimeSeriesAttribute], stat: _AttributeStatistic, val: Tensor + ) -> None: + """Sets the value of a statistic related to a time-series attribute, saved as a torch buffer inside the model. Args: attr: Key identifying the attribute for which to set the statistic. @@ -152,12 +154,12 @@ def _set_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _Attribute def configure_model(self) -> nn.Module: """Configure the network architecture used by the system.""" - attrs_shape = self.hparams.data_params.in_shape[CardinalTag.image_attrs] + attrs_shape = self.hparams.data_params.in_shape[CardinalTag.time_series_attrs] model = hydra.utils.instantiate(self.hparams.model, input_shape=(1, attrs_shape[-1])) return model def on_fit_start(self) -> None: - """Computes global statistics for the image attributes on the training subset. + """Computes global statistics for the time-series attributes on the training subset. These stats will be used during training and inference to normalize attributes values or metrics """ @@ -169,7 +171,7 @@ def on_fit_start(self) -> None: for batch in train_dl: for (stat, update_fn), (attr, attr_data) in itertools.product( attrs_stats_update_fn.items(), - filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs).items(), + filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs).items(), ): self._set_attr_stat(attr, stat, update_fn(self._get_attr_stat(attr, stat), attr_data)) @@ -222,7 +224,7 @@ def forward( self, x: Tensor, task: Literal["encode", "decode", "reconstruct"] = "reconstruct", - attr: Tuple[ViewEnum, ImageAttribute] = None, + attr: Tuple[ViewEnum, TimeSeriesAttribute] = None, out_shape: Tuple[int, ...] = None, ) -> Tensor: """Performs test-time inference on the input. @@ -280,12 +282,12 @@ def forward( return x def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: # noqa: D102 - attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) + attrs = filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) if self.hparams.normalization == _AttributeNormalization.data: attrs = self._normalize_attrs(attrs) - # Forward on image attributes + # Forward on time-series attributes attrs_x_hat, attrs_z = {}, {} for attr_key, attr_data in attrs.items(): attrs_z[attr_key] = self.encoder(attr_data.unsqueeze(1)) @@ -312,8 +314,8 @@ def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: return metrics def _normalize_attrs( - self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tensor]: + self, attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor] + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]: """Normalizes attributes with different range of values. Args: @@ -328,8 +330,8 @@ def _normalize_attrs( } def _normalize_attrs_reconstruction_metrics( - self, attrs_metrics: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tensor]: + self, attrs_metrics: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor] + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]: """Normalizes reconstruction metrics computed on attributes with different range of values. Args: @@ -353,9 +355,9 @@ def _normalize_attrs_reconstruction_metrics( @torch.inference_mode() def predict_step( # noqa: D102 self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0 - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tuple[Tensor, Tensor]]: - # Reconstruct the image attributes - attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tuple[Tensor, Tensor]]: + # Reconstruct the time-series attributes + attrs = filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) prediction = {} for attr_key, attr_data in attrs.items(): if not (is_batch := attr_data.ndim == 2): @@ -383,12 +385,12 @@ def __init__(self, cardiac_sequence_attrs_model: str | Path | CardiacSequenceAtt """Initializes class instance. Args: - cardiac_sequence_attrs_model: Pretrained image attributes autoencoder model used to compress the attributes - into tokens. Mutually exclusive parameter with `embed_dim`. + cardiac_sequence_attrs_model: Pretrained time-series attributes autoencoder model used to compress the + attributes into tokens. Mutually exclusive parameter with `embed_dim`. """ super().__init__() - # If the image attributes encoder is a checkpoint rather than an instantiated network, load the model from + # If the time-series attributes encoder is a checkpoint rather than an instantiated network, load the model from # the checkpoint if isinstance(cardiac_sequence_attrs_model, (str, Path)): cardiac_sequence_attrs_model = load_from_checkpoint(cardiac_sequence_attrs_model) @@ -396,11 +398,11 @@ def __init__(self, cardiac_sequence_attrs_model: str | Path | CardiacSequenceAtt # Also, the backend model needs to be saved as a class member even if it's not necessary so that the # tokenizer as a whole can behave as expected of a module # (e.g. moving it across devices is applied recursively to the backend model, etc.) - self.img_attrs_ae = cardiac_sequence_attrs_model.eval().requires_grad_(False) + self.autoencoder = cardiac_sequence_attrs_model.eval().requires_grad_(False) @torch.inference_mode() - def forward(self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor]) -> Tensor: - """Encodes image attributes using the autoencoder. + def forward(self, attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]) -> Tensor: + """Encodes time-series attributes using the autoencoder. Args: attrs: (K: S, V: (N, ?)): Attributes to tokenize, where the dimensionality of each attribute can vary. @@ -408,4 +410,4 @@ def forward(self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor]) -> Tenso Returns: (N, S, E), Tokenized version of the attributes. """ - return torch.stack([self.img_attrs_ae(x, task="encode", attr=attr) for attr, x in attrs.items()], dim=1) + return torch.stack([self.autoencoder(x, task="encode", attr=attr) for attr, x in attrs.items()], dim=1) diff --git a/didactic/tasks/cardiac_sequence_attrs_pca.py b/didactic/tasks/cardiac_sequence_attrs_pca.py index 8d0a6159..76c34c58 100644 --- a/didactic/tasks/cardiac_sequence_attrs_pca.py +++ b/didactic/tasks/cardiac_sequence_attrs_pca.py @@ -5,7 +5,7 @@ import numpy as np from sklearn.decomposition import PCA -from vital.data.cardinal.config import ImageAttribute +from vital.data.cardinal.config import TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.transforms import Interp1d from vital.utils.decorators import auto_cast_data @@ -13,7 +13,7 @@ class CardiacSequenceAttributesPCA: - """PCA model specialized for cardiac sequences image attributes, where attributes have different range of values.""" + """PCA model specialized for cardiac sequences time-series attrs, where attrs have different range of values.""" def __init__( self, @@ -26,7 +26,7 @@ def __init__( Args: n_components: Dimensionality of the PCA model's latent space. pca_kwargs: Parameters that will be passed along to the `PCA`'s init. - strategy: Strategies available for handling multi-domain image attributes values. + strategy: Strategies available for handling multi-domain time-series attributes values. 'global_pca': one global PCA model is trained on the normalized attributes. 'attr_pca': one PCA model is fitted to each attribute, side-stepping the need for normalization. """ @@ -117,7 +117,7 @@ def __call__( self, x: np.ndarray, task: Literal["encode", "decode", "reconstruct"] = "reconstruct", - attr: Tuple[ViewEnum, ImageAttribute] = None, + attr: Tuple[ViewEnum, TimeSeriesAttribute] = None, out_shape: Tuple[int, ...] = None, ) -> np.ndarray: """Performs test-time inference on the input. @@ -170,11 +170,12 @@ def __call__( x = x.squeeze(axis=1) return x - def fit(self, samples: Dict[Tuple[ViewEnum, ImageAttribute], np.ndarray]) -> "CardiacSequenceAttributesPCA": + def fit(self, samples: Dict[Tuple[ViewEnum, TimeSeriesAttribute], np.ndarray]) -> "CardiacSequenceAttributesPCA": """Fits one or multiple PCA model(s) to the attributes samples, depending on the chosen strategy. Args: - samples: Mapping between attributes and their samples, of shape + samples: Mapping between attributes and their samples, of shape (N, L), where N is the number of samples and + L is the length of the time-series. The samples are divided by attributes to allow the model to learn models/statistics on each attribute independently. @@ -219,7 +220,7 @@ def main(): from matplotlib import pyplot as plt from tqdm.auto import tqdm from vital.data.cardinal.config import CardinalTag - from vital.data.cardinal.utils.attributes import IMAGE_ATTR_LABELS + from vital.data.cardinal.utils.attributes import TIME_SERIES_ATTR_LABELS from vital.data.cardinal.utils.itertools import Patients from vital.utils.logging import configure_logging from vital.utils.parsing import yaml_flow_collection @@ -239,7 +240,7 @@ def main(): "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--n_features", @@ -260,7 +261,7 @@ def main(): type=str, choices=["global_pca", "attr_pca"], default="global_pca", - help="Strategies available for handling multi-domain image attributes values. \n" + help="Strategies available for handling multi-domain time-series attributes values. \n" "'global_pca': one global PCA model is trained on the normalized attributes. \n" "'attr_pca': one PCA model is fitted to each attribute, side-stepping the need for normalization.", ) @@ -357,7 +358,7 @@ def _save_cur_fig(title: str, folder: Path) -> None: lambda x: model(x, task="decode", attr=attr_key), attr_samples_embedding, sweep_coeffs, - plots_kwargs={"ylabel": IMAGE_ATTR_LABELS[attr_key[1]]}, + plots_kwargs={"ylabel": TIME_SERIES_ATTR_LABELS[attr_key[1]]}, ): _save_cur_fig(title, output_dir / "sweep" / "_".join(attr_key)) diff --git a/didactic/tasks/utils.py b/didactic/tasks/utils.py index 22321c52..c3d6dd06 100644 --- a/didactic/tasks/utils.py +++ b/didactic/tasks/utils.py @@ -5,7 +5,7 @@ import numpy as np import torch from tqdm.auto import tqdm -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.cardinal.datapipes import process_patient from vital.data.cardinal.utils.data_dis import check_subsets @@ -31,7 +31,7 @@ def encode_patients( Args: model: Transformer encoder model to use for inference. patients: (N) Patients to encode. - mask_tag: Tag of the segmentation mask for which to extract the image attributes. + mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. progress_bar: If ``True``, enables progress bars detailing the progress of the processing and encoding patients data. **forward_kwargs: Keyword arguments to pass along to the encoder's inference method. @@ -39,16 +39,16 @@ def encode_patients( Returns: (N, E), encodings of the patients. """ - clinical_attrs, img_attrs = model.hparams.clinical_attrs, model.hparams.img_attrs + tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs if progress_bar: patients = tqdm(patients, desc="Processing patients' data to prepare it for inference", unit="patient") patients_attrs = [ - process_patient(patient, clinical_attributes=clinical_attrs, image_attributes=img_attrs, mask_tag=mask_tag) + process_patient(patient, tabular_attrs=tab_attrs, time_series_attrs=time_series_attrs, mask_tag=mask_tag) for patient in patients ] # Run inference on one patient at a time, instead of in batches, to avoid a few possible issues: - # i) having to resample image attributes to be of the same constant shape to be able to stack them + # i) having to resample time-series attributes to be of the same constant shape to be able to stack them # ii) out of memory errors, in case of very large collections of patients if progress_bar: patients_attrs = tqdm( @@ -60,8 +60,8 @@ def encode_patients( [ encode_patients_attrs( model, - {attr: patient_attrs[attr] for attr in clinical_attrs}, - {(view, attr): patient_attrs[view][attr] for view in model.hparams.views for attr in img_attrs}, + {attr: patient_attrs[attr] for attr in tab_attrs}, + {(view, attr): patient_attrs[view][attr] for view in model.hparams.views for attr in time_series_attrs}, **forward_kwargs, ) for patient_attrs in patients_attrs @@ -73,30 +73,30 @@ def encode_patients( def encode_patients_attrs( model: CardiacMultimodalRepresentationTask, - clinical_attrs: Dict[ClinicalAttribute, np.ndarray], - img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], np.ndarray], + tabular_attrs: Dict[TabularAttribute, np.ndarray], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], np.ndarray], **forward_kwargs: Dict[str, Any], ) -> np.ndarray: """Wrapper around encoder inference to handle boilerplate code (e.g. numpy to torch, batching/unbatching, etc.). Args: model: Transformer encoder model to use for inference. - clinical_attrs: (K: S, V: [N]) Sequence of (batch of) clinical attributes. - img_attrs: (K: S, V: ([N,] L)), Sequence of (batch of) image attributes, where L is the dimensionality of each - attribute. + tabular_attrs: (K: S, V: [N]) Sequence of (batch of) tabular attributes. + time_series_attrs: (K: S, V: ([N,] L)), Sequence of (batch of) time-series attributes, where L is the + dimensionality of each attribute. **forward_kwargs: Keyword arguments to pass along to the encoder's inference method. Returns: ([N,], E), encoding(s) of the patient/batch of patients. """ - is_batch = list(clinical_attrs.values())[0].ndim == 1 + is_batch = list(tabular_attrs.values())[0].ndim == 1 # If the input isn't a batch of data, add the batch dimension - clinical_attrs = {k: v if is_batch else v[None, ...] for k, v in clinical_attrs.items()} - img_attrs = {k: v if is_batch else v[None, ...] for k, v in img_attrs.items()} + tabular_attrs = {k: v if is_batch else v[None, ...] for k, v in tabular_attrs.items()} + time_series_attrs = {k: v if is_batch else v[None, ...] for k, v in time_series_attrs.items()} with torch.inference_mode(): - out_features = model(numpy_to_torch(clinical_attrs), numpy_to_torch(img_attrs), **forward_kwargs) + out_features = model(numpy_to_torch(tabular_attrs), numpy_to_torch(time_series_attrs), **forward_kwargs) # Squeeze to remove batch dimension, if it wasn't there in the input if not is_batch: @@ -119,7 +119,7 @@ def summarize_patient_attn( Args: model: Transformer encoder model for which we want to analyze the attention. patient: Patient for which to summarize the model's attention. - mask_tag: Tag of the segmentation mask for which to extract the image attributes. + mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. use_attention_rollout: Whether to use attention rollout to compute the summary of the attention. attention_rollout_kwargs: When using attention rollout (`use_attention_rollout` is True), parameters to forward to `didactic.models.explain.attention_rollout`.