From 167fdaa403b806d69671894c338d683fb85da3fe Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Tue, 14 Nov 2023 18:40:51 +0100 Subject: [PATCH] Remove unused scripts to plot attrs w.r.t. clusters/neighbors after upstreaming generic script to `vital` --- didactic/data/cardinal/utils.py | 145 ------- .../describe_representation_clustering.py | 351 --------------- .../scripts/describe_representation_knn.py | 404 ------------------ 3 files changed, 900 deletions(-) delete mode 100644 didactic/data/cardinal/utils.py delete mode 100644 didactic/scripts/describe_representation_clustering.py delete mode 100644 didactic/scripts/describe_representation_knn.py diff --git a/didactic/data/cardinal/utils.py b/didactic/data/cardinal/utils.py deleted file mode 100644 index 1db71b7..0000000 --- a/didactic/data/cardinal/utils.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import Hashable, Iterable, Mapping, Tuple - -import numpy as np -import pandas as pd -from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute -from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.utils.attributes import build_attributes_dataframe -from vital.data.cardinal.utils.data_struct import Patient -from vital.data.cardinal.utils.itertools import Patients -from vital.data.transforms import Interp1d - - -def build_clusterings_dataframe( - patients: Patients, clusterings: Mapping[str, Mapping[Patient.Id, str]], cat_to_num: bool = False -) -> pd.DataFrame: - """Builds a dataframe to store the data of the patients in each cluster, repeating patients across clusterings. - - Args: - patients: Collection of patients to include in the dataframe. - clusterings: Instances of clustering of the patients population, represented as mappings between patient IDs - and cluster labels. - cat_to_num: Whether to convert the categorical attributes to numerical labels, based on the order of the - categories. - - Returns: - Dataframe, with a multi-index with levels `(model, cluster, patient_id)`, containing the data of the patients - from each cluster for each model/clustering, repeating patients as necessary. - """ - # Convert clusterings from mapping between item IDs and cluster IDs to lists of patient IDs by cluster - clusterings = { - clustering_id: { - cluster_label: sorted( - patient_id for patient_id, patient_cluster in clusters.items() if patient_cluster == cluster_label - ) - for cluster_label in sorted(set(clusters.values())) - } - for clustering_id, clusters in clusterings.items() - } - - data = patients.to_dataframe() - - if cat_to_num: - # Convert the categorical attributes to numerical labels - def _to_num(attr_data: pd.Series) -> pd.Series: - if attr_data.dtype == "category": - attr_data = attr_data.cat.codes - return attr_data - - data = data.apply(_to_num) - - # For each clustering, extract the data of the patients in each cluster - clusterings_data = pd.concat( - { - clustering_id: pd.concat( - { - cluster_label: data.loc[patient_ids_in_cluster] - for cluster_label, patient_ids_in_cluster in clusters.items() - } - ) - for clustering_id, clusters in clusterings.items() - }, - names=["model", "cluster", "patient_id"], - ) - - return clusterings_data - - -def build_knn_dataframe(patients: Patients, kneighbors: np.ndarray, cat_to_num: bool = False) -> pd.DataFrame: - """Builds a dataframe to store the data of the nearest neighbors of each patient, repeating patients as necessary. - - Args: - patients: Collection of patients to include in the dataframe. - kneighbors: Array (of `Patient.Id`s) of shape `(n_encodings, n_patients, n_neighbors)` containing the IDs of the - nearest neighbors of each patient for each encoding. - cat_to_num: Whether to convert the categorical attributes to numerical labels, based on the order of the - categories. - - Returns: - Dataframe, with a multi-index with levels `(model, patient_id, neighbor_id)`, containing the data of the nearest - neighbors of each patient for each model, repeating patients as necessary. - """ - data = patients.to_dataframe() - - if cat_to_num: - # Convert the categorical attributes to numerical labels - def _to_num(attr_data: pd.Series) -> pd.Series: - if attr_data.dtype == "category": - attr_data = attr_data.cat.codes - return attr_data - - data = data.apply(_to_num) - - # For each encoding, extract the data of the nearest neighbors of each patient - neigh_data = pd.concat( - { - f"{enc_idx}": pd.concat( - { - patient_id: data.loc[kneighbors_ids] - for patient_id, kneighbors_ids in zip(patients, enc_kneighbors_ids) - } - ) - for enc_idx, enc_kneighbors_ids in enumerate(kneighbors) - }, - names=["model", "patient_id", "neighbor_id"], - ) - - return neigh_data - - -def build_time_series_attr_by_patient_group_dataframe( - patients_groups: Mapping[Hashable, Iterable[Patient]], - attr: Tuple[ViewEnum, TimeSeriesAttribute], - group_desc: str = "group", - mask_tag: str = CardinalTag.mask, - resampling_rate: int = 128, -) -> pd.DataFrame: - """Builds a dataframe with the average curve of a time-series attribute by patient group. - - Args: - patients_groups: Mapping between group ID and the patients in that group. - attr: A pair of view and time-series attribute to compute the average curve of. - group_desc: Description of the semantic meaning of the groups. - mask_tag: Tag of the segmentation mask for which to extract the time-series attribute data. - resampling_rate: Number of points at which to resample the time-series attribute curves from each patient, so - that they can be easily compared and aggregated together. - - Returns: - Dataframe with the average curve of an time-series attribute by patient group, in long format. - """ - resampling_fn = Interp1d(resampling_rate) - - # For each group, compute the attribute's average curve - data = {} - for group, patients in patients_groups.items(): - # When stacking the attributes from all patients in a group, it's necessary to resample the attributes - # Otherwise the variable number of frames by seq would cause errors because of unequal array shapes - group_attr_data = np.vstack( - [resampling_fn(patient.get_mask_attributes(mask_tag)[attr[0]][attr[1]]) for patient in patients] - ) - # Add the unnecessary nested dict level to conform to the API expected by `build_attributes_dataframe` - data[group] = {attr: group_attr_data.mean(axis=0)} - - # Structure the attribute's average curve by group as a dataframe in long format - # to make it easier to use seaborn's plot - return build_attributes_dataframe(data, data_name=group_desc) diff --git a/didactic/scripts/describe_representation_clustering.py b/didactic/scripts/describe_representation_clustering.py deleted file mode 100644 index d31134e..0000000 --- a/didactic/scripts/describe_representation_clustering.py +++ /dev/null @@ -1,351 +0,0 @@ -import itertools -import logging -from pathlib import Path -from typing import Iterator, Mapping, Sequence, Tuple - -import numpy as np -import pandas as pd -import seaborn as sns -import seaborn.objects as so -from matplotlib.axes import Axes -from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute -from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.utils.attributes import CLINICAL_ATTR_UNITS, TABULAR_CAT_ATTR_LABELS, TIME_SERIES_ATTR_LABELS -from vital.data.cardinal.utils.data_struct import Patient -from vital.data.cardinal.utils.itertools import Patients - -from didactic.data.cardinal.utils import build_clusterings_dataframe, build_time_series_attr_by_patient_group_dataframe - -logger = logging.getLogger(__name__) - - -def plot_tabular_attrs_variability_figures( - patients: Patients, - clusterings: Mapping[str, Mapping[Patient.Id, str]], - tabular_attrs: Sequence[TabularAttribute] = None, - num_clustering_agg: str = None, - cat_plot_kwargs: dict = None, - num_plot_kwargs: dict = None, -) -> Iterator[Tuple[str, Axes]]: - """Plots the variability of cluster-aggregated tabular attributes across multiple clusterings w.r.t. clusters. - - Args: - patients: Collection of patients data from which to extract the attributes. - clusterings: Instances of clustering of the patients population, representation as mappings between patient IDs - and cluster labels. - tabular_attrs: Subset of tabular attributes on which to compile the results. If not provided, will default to - all available attributes. - num_clustering_agg: Aggregation function to use to aggregate the numerical attributes by clusters, before the - aggregation across clusterings. If not provided, the attributes are aggregated by clusters and across - clusterings at the same time, leading to a higher reported variability. - cat_plot_kwargs: Parameters to forward to the call to `seaborn.heatmap` for categorical attributes. - num_plot_kwargs: Parameters to forward to the call to `seaborn.boxplot` for numerical attributes. - - Returns: - Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated tabular - attributes across multiple clusterings w.r.t. clusters. - """ - if cat_plot_kwargs is None: - cat_plot_kwargs = {} - if num_plot_kwargs is None: - num_plot_kwargs = {} - - # Gather the data of the patients in each cluster for each clustering - clusterings_data = build_clusterings_dataframe(patients, clusterings) - if tabular_attrs is not None: - clusterings_data = clusterings_data[tabular_attrs] - - # Ignore `matplotlib.category` logger 'INFO' level logs to avoid repeated logs about categorical units parsable - # as floats - logging.getLogger("matplotlib.category").setLevel(logging.WARNING) - - # For each attribute, plot the variability of the attribute w.r.t. clusters - for attr in clusterings_data.columns: - title = f"{attr}_wrt_clusters" - attr_data = clusterings_data[attr] - - # Based on whether the attribute is categorical or numerical, define different types of plots - if attr in TabularAttribute.categorical_attrs(): - # Compute the occurrence of each category for each cluster (including NA), across all clusterings - attr_stats = attr_data.groupby(["model", "cluster"]).value_counts(normalize=True, dropna=False) * 100 - # After the NA values have been taken into account for the count, drop them - attr_stats = attr_stats.dropna() - - # For unknown reasons, this plot is unable to pickup variables in the multi-index. As a workaround, we - # reset the index and to make the index levels into columns available to the plot - attr_stats = attr_stats.reset_index() - - # For boolean attributes, convert the values to string so that seaborn can properly pick up label names - # Avoids the following error: 'bool' object has no attribute 'startswith' - # At the same time, assign relevant labels/hues/etc. for either boolean or categorical attributes - if attr in TabularAttribute.boolean_attrs(): - attr_stats = attr_stats.astype({attr: str}) - ylabel = "(% true)" - hue_order = [str(val) for val in TABULAR_CAT_ATTR_LABELS[attr]] - else: - ylabel = "(% by label)" - hue_order = TABULAR_CAT_ATTR_LABELS[attr] - - # Use dodged barplots for categorical attributes - with sns.axes_style("darkgrid"): - plot = sns.barplot( - data=attr_stats, - x="cluster", - y="proportion", - hue=attr, - hue_order=hue_order, - estimator="median", - errorbar=lambda data: (np.quantile(data, 0.25), np.quantile(data, 0.75)), - **cat_plot_kwargs, - ) - - plot.set(title=title, ylabel=ylabel) - - else: # attr in TabularAttribute.numerical_attrs() - if num_clustering_agg is not None: - # Aggregate the numerical attributes by clusters, before the aggregation across clusterings - attr_data = attr_data.groupby(["model", "cluster"]).agg(num_clustering_agg) - - # Use boxplots for numerical attributes - with sns.axes_style("darkgrid"): - # Reset index on the data to make the index levels available as values to plot - plot = sns.boxplot(data=attr_data.reset_index(), x="cluster", y=attr, **num_plot_kwargs) - - plot.set(title=title, ylabel=CLINICAL_ATTR_UNITS[attr][0]) - - yield title, plot - - -def plot_time_series_attrs_variability_figures( - patients: Patients, - clusterings: Mapping[str, Mapping[Patient.Id, str]], - time_series_attrs: Sequence[Tuple[ViewEnum, TimeSeriesAttribute]], - mask_tag: str = CardinalTag.mask, -) -> Iterator[Tuple[str, Axes]]: - """Plots the variability of cluster-aggregated time-series attributes across multiple clusterings w.r.t. clusters. - - Args: - patients: Collection of patients data from which to extract the attributes. - clusterings: Instances of clustering of the patients population, representation as mappings between patient IDs - and cluster labels. - time_series_attrs: Subset of time-series attributes derived from segmentations (identified by view/attribute - pairs) for which to plot the variability between bins of the reference tabular attribute. - mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. - - Returns: - Iterator over figures (and their corresponding titles) plotting the variability of cluster-aggregated - time-series attrs across multiple clusterings w.r.t. clusters. - """ - # Convert clusterings from mapping between item IDs and cluster IDs to lists of patient IDs by cluster - clusterings = { - clustering_id: { - cluster_label: sorted( - patient_id for patient_id, patient_cluster in clusters.items() if patient_cluster == cluster_label - ) - for cluster_label in sorted(set(clusters.values())) - } - for clustering_id, clusters in clusterings.items() - } - - # Merge the lists of patients in each cluster for each clustering, to obtain a single list of patients per cluster - # (corresponding to the union of the patients in a specific cluster each clustering) - # At the same time, while we were up until now just working with patient IDs, we now fetch the patient data - any_clustering_label = list(clusterings.keys())[0] - cluster_labels = sorted(set(clusterings[any_clustering_label].keys())) - patients_by_cluster = { - cluster_label: list( - map( - patients.get, - itertools.chain.from_iterable(clusters[cluster_label] for clusters in clusterings.values()), - ) - ) - for cluster_label in cluster_labels - } - - # For each time-series attr, build the dataframe of the mean curve for each bin and plot the curves for each bin - for time_series_attr in time_series_attrs: - time_series_attr_data = build_time_series_attr_by_patient_group_dataframe( - patients_by_cluster, time_series_attr, group_desc="cluster", mask_tag=mask_tag - ) - - with sns.axes_style("darkgrid"): - plot = sns.lineplot( - data=time_series_attr_data, x="time", y="val", hue="cluster", hue_order=sorted(cluster_labels) - ) - title = f"{'/'.join(time_series_attr)}_wrt_clusters" - plot.set(title=title, ylabel=TIME_SERIES_ATTR_LABELS[time_series_attr[1]]) - - yield title, plot - - -def main(): - """Run the script.""" - from argparse import ArgumentParser - - from matplotlib import pyplot as plt - from tqdm import tqdm - from vital.utils.logging import configure_logging - from vital.utils.parsing import yaml_flow_collection - - # Ensure that matplotlib is using 'agg' backend in non-interactive case - plt.switch_backend("agg") - - configure_logging(log_to_console=True, console_level=logging.INFO) - parser = ArgumentParser() - parser.add_argument( - "clusterings", - nargs="+", - type=Path, - help="Path to files/folders describing the different clusterings of the patients for which to describe the " - "distribution of attributes", - ) - parser.add_argument( - "--clusterings_format", - type=str, - choices=["csv", "txt"], - default="csv", - help="Format in which the clusterings to be loaded are stored. `csv` is a single CSV file mapping a `patient` " - "column to a `cluster` column. `txt` is multiple text files inside the folder listing the IDs of the patient " - "in each cluster.", - ) - parser = Patients.add_args(parser) - parser.add_argument( - "--tabular_attrs", - type=TabularAttribute, - nargs="*", - choices=list(TabularAttribute), - help="Subset of tabular attributes on which to compile the results. If not provided, will default to all " - "available attributes", - ) - parser.add_argument( - "--time_series_attrs", - type=TimeSeriesAttribute, - choices=list(TimeSeriesAttribute), - nargs="*", - default=list(TimeSeriesAttribute), - help="Subset of time-series attributes derived from segmentations for which to plot the intra/inter-cluster " - "variability", - ) - parser.add_argument( - "--mask_tag", - type=str, - default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the time-series attributes", - ) - parser.add_argument( - "--num_clustering_agg", - type=str, - help="Aggregation function to use to aggregate the numerical attributes by clusters, before the aggregation " - "across clusterings. If not provided, the attributes are aggregated by clusters and across clusterings at the " - "same time, leading to a higher reported variability.", - ) - parser.add_argument( - "--tabular_cat_plot_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.heatmap` for categorical tabular attributes figures", - ) - parser.add_argument( - "--tabular_num_plot_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.boxplot` for categorical tabular attributes figures", - ) - parser.add_argument( - "--output_dir", - type=Path, - default=Path("compiled_representation_results"), - help="Root directory under which to save the compiled results for all of the methods", - ) - args = parser.parse_args() - kwargs = vars(args) - - ( - clustering_paths, - clusterings_fmt, - tabular_attrs, - time_series_attrs, - mask_tag, - num_clustering_agg, - cat_plot_kwargs, - num_plot_kwargs, - output_dir, - ) = list( - map( - kwargs.pop, - [ - "clusterings", - "clusterings_format", - "tabular_attrs", - "time_series_attrs", - "mask_tag", - "num_clustering_agg", - "tabular_cat_plot_kwargs", - "tabular_num_plot_kwargs", - "output_dir", - ], - ) - ) - time_series_attrs_keys = [ - (view, time_series_attr) for view, time_series_attr in itertools.product(args.views, time_series_attrs) - ] - - # Load the dataset - patients = Patients(**kwargs) - - # Load and interpret the clustering instances - match clusterings_fmt: - case "csv": - clusterings = { - str(idx): pd.read_csv(clustering_file, index_col=0, dtype={"patient": str, "cluster": str})[ - "cluster" - ].to_dict() - for idx, clustering_file in enumerate(clustering_paths) - } - case "txt": - clusterings = { - clustering_dir.stem: { - patient_id: cluster_file.stem - for cluster_file in clustering_dir.glob("*.txt") - for patient_id in cluster_file.read_text().split() - } - for clustering_dir in clustering_paths - } - case _: - raise ValueError(f"Unknown `clusterings_format`: {clusterings_fmt}") - - tabular_attrs_plots = plot_tabular_attrs_variability_figures( - patients, - clusterings, - tabular_attrs=tabular_attrs, - num_clustering_agg=num_clustering_agg, - cat_plot_kwargs=cat_plot_kwargs, - num_plot_kwargs=num_plot_kwargs, - ) - time_series_attrs_plots = plot_time_series_attrs_variability_figures( - patients, clusterings, time_series_attrs_keys, mask_tag=mask_tag - ) - - # Plot the variability of the tabular and time-series attributes - output_dir.mkdir(parents=True, exist_ok=True) # Prepare the output folder for the method - n_plots = (len(tabular_attrs) if tabular_attrs else len(TabularAttribute)) + len(time_series_attrs_keys) - for title, plot in tqdm( - itertools.chain(tabular_attrs_plots, time_series_attrs_plots), - desc="Plotting the variability of the attributes w.r.t. clusters", - unit="attr", - total=n_plots, - ): - title_pathified = title.lower().replace("/", "_").replace(" ", "_") - filepath = output_dir / f"{title_pathified}.svg" - - if isinstance(plot, so.Plot): - plot.save(filepath, bbox_inches="tight") - elif isinstance(plot, Axes): - plt.savefig(filepath) - plt.close() # Close the figure to avoid contamination between plots - else: - raise ValueError(f"Unable to save the figure for plot type: {type(plot)}.") - - -if __name__ == "__main__": - main() diff --git a/didactic/scripts/describe_representation_knn.py b/didactic/scripts/describe_representation_knn.py deleted file mode 100644 index f4d6627..0000000 --- a/didactic/scripts/describe_representation_knn.py +++ /dev/null @@ -1,404 +0,0 @@ -import itertools -import logging -from pathlib import Path -from typing import Iterator, Sequence, Tuple - -import numpy as np -import pandas as pd -import seaborn as sns -import seaborn.objects as so -from matplotlib.axes import Axes -from sklearn.neighbors import NearestNeighbors -from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute -from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.utils.attributes import CLINICAL_ATTR_UNITS, TIME_SERIES_ATTR_LABELS -from vital.data.cardinal.utils.data_struct import Patient -from vital.data.cardinal.utils.itertools import Patients - -from didactic.data.cardinal.utils import build_knn_dataframe, build_time_series_attr_by_patient_group_dataframe -from didactic.tasks.cardiac_multimodal_representation import CardiacMultimodalRepresentationTask - -logger = logging.getLogger(__name__) - - -def find_nearest_neighbors( - patients_encodings: Sequence[np.ndarray], patient_ids: Sequence[Patient.Id], **neigh_kwargs -) -> np.ndarray: - """Finds the nearest neighbors of each patient based on the encodings from each model. - - Args: - patients_encodings: Encodings of the patients corresponding to multiple representations of the population. - patient_ids: IDs of the patients to use when identifying the nearest neighbors. Should be of the same length as - every array in `patients_encodings`. - neigh_kwargs: Keyword arguments to forward to the `KNeighborsClassifier` constructor. - - Returns: - Array (of `Patient.Id`s) of shape `(n_encodings, n_patients, n_neighbors)` containing the IDs of the nearest - neighbors of each patient for each encoding. - """ - # Check that the number of patients is the same for each encoding, and matches the number of patient IDs - for idx, encoding in enumerate(patients_encodings): - if len(encoding) != len(patient_ids): - raise ValueError( - f"The number of patients in the encoding for model #{idx} does not match the number of patient IDs" - ) - - # For each encoding, fit a nearest neighbors model on the encoding and find the nearest neighbors of each patient - kneighbors_indices = np.stack( - [ - NearestNeighbors(**neigh_kwargs).fit(encoding).kneighbors(return_distance=False) - for encoding in patients_encodings - ] - ) - - # Convert the indices of the nearest neighbors to patient IDs - kneighbors_ids = np.array(patient_ids)[kneighbors_indices] - - return kneighbors_ids - - -def plot_tabular_attrs_variability_figures( - patients: Patients, - kneighbors_ids: np.ndarray, - var_attr: TabularAttribute, - tabular_attrs: Sequence[TabularAttribute] = None, - agg: str = "mean", - plot_kwargs: dict = None, - dots_layer_kwargs: dict = None, - polyfit_layer_kwargs: dict = None, -) -> Iterator[Tuple[str, so.Plot]]: - """Plots the variability of locally-aggregated tabular attrs across multiple encodings w.r.t. a ref. tabular attr. - - Args: - patients: Collection of patients data from which to extract the attributes. - kneighbors_ids: Array (of `Patient.Id`s) of shape `(n_encodings, n_patients, n_neighbors)` containing the IDs of - the nearest neighbors of each patient for each encoding. - var_attr: Reference tabular attribute w.r.t. which to plot the variability of the attributes. - tabular_attrs: Subset of tabular attributes on which to compile the results. If not provided, will default to - all available attributes. - agg: Aggregation function to apply to the neighborhood of each patient. - plot_kwargs: Parameters to forward to the call to `seaborn.object.Plot`. - dots_layer_kwargs: Parameters to forward to the call to `seaborn.object.Plot.add` for the scatter plot layer. - polyfit_layer_kwargs: Parameters to forward to the call to `seaborn.object.Plot.scale` for the polynomial fit - layer. - - Returns: - Iterator over figures (and their corresponding titles) plotting the variability of locally-aggregated tabular - attrs across multiple encodings w.r.t. a ref. tabular attr. - """ - if plot_kwargs is None: - plot_kwargs = {} - if dots_layer_kwargs is None: - dots_layer_kwargs = {} - if polyfit_layer_kwargs is None: - polyfit_layer_kwargs = {} - - # Gather the data of the nearest neighbors of each patient for each encoding - neigh_data = build_knn_dataframe(patients, kneighbors_ids, cat_to_num=True) - if tabular_attrs is not None: - neigh_data = neigh_data[tabular_attrs] - - # Compute the mean of the attributes over the nearest neighbors of each patient - neigh_agg_data = neigh_data.groupby(level=["model", "patient_id"]).agg(agg) - - # Ignore `matplotlib.category` logger 'INFO' level logs to avoid repeated logs about categorical units parsable - # as floats - logging.getLogger("matplotlib.category").setLevel(logging.WARNING) - - # For each attribute (regularized by aggregation across neighbors), - # generate a scatter plot of the attribute w.r.t. the reference attribute - for attr in neigh_agg_data.columns: - # Drop rows w/ NA values for the attribute to plot to avoid "Cannot cast ufunc 'lstsq_n'" error - # (the error happens when trying to fit least squares polynomial regression if NA values are present) - # This really only affects attributes with lots of missing values since agg over neighbors typically - # "fills in the gaps" for attributes with few missing values - attr_data = neigh_agg_data.dropna(subset=[attr]) - - plot = ( - so.Plot(data=attr_data, x=var_attr, y=attr, **plot_kwargs) - .add(so.Dots(), so.Jitter(0.3), **dots_layer_kwargs) - .add(so.Line(), so.PolyFit(), **polyfit_layer_kwargs) - ) - - title = f"{attr}_wrt_{var_attr}" - axis_labels = { - axis: "(ratio true/false)" if attr in TabularAttribute.boolean_attrs() else CLINICAL_ATTR_UNITS[attr][0] - for axis, attr in zip(["x", "y"], (var_attr, attr)) - } - plot = plot.label(title=title, **axis_labels) - - yield title, plot - - -def plot_time_series_attrs_variability_figures( - patients: Patients, - kneighbors_ids: np.ndarray, - var_attr: TabularAttribute, - time_series_attrs: Sequence[Tuple[ViewEnum, TimeSeriesAttribute]], - agg: str = "mean", - mask_tag: str = CardinalTag.mask, - n_bins: int = 5, -) -> Iterator[Tuple[str, Axes]]: - """Plots the variability of locally-aggregated time-series attrs across multiple encodings w.r.t. a ref. tab. attr. - - Args: - patients: Collection of patients data from which to extract the attributes. - kneighbors_ids: Array (of `Patient.Id`s) of shape `(n_encodings, n_patients, n_neighbors)` containing the IDs of - the nearest neighbors of each patient for each encoding. - var_attr: Reference tabular attribute w.r.t. which to plot the variability of the attributes. - time_series_attrs: Subset of time-series attributes derived from segmentations (identified by view/attribute - pairs) for which to plot the variability between bins of the reference tabular attribute. - agg: Aggregation function to apply to the neighborhood of each patient. - mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. - n_bins: Number of bins by which to divide the population and over which to compute the variability of the - time-series attributes. - - Returns: - Iterator over figures (and their corresponding titles) plotting the variability of locally-aggregated - time-series attrs across multiple encodings w.r.t. a ref. tabular attr. - """ - # Gather the reference attribute data of the nearest neighbors of each patient for each encoding - neigh_data = build_knn_dataframe(patients, kneighbors_ids, cat_to_num=True)[var_attr] - - # Compute the mean of the reference attribute over the nearest neighbors of each patient - neigh_agg_var = neigh_data.groupby(level=["model", "patient_id"]).agg(agg) - - # Divide the population into bins based on the reference attribute - bins = np.linspace(min(neigh_agg_var), max(neigh_agg_var), num=n_bins + 1) - bins[-1] += 1e-6 # Add epsilon to the last bin's upper bound since it's excluded by `np.digitize` - bin_labels = np.digitize(neigh_agg_var, bins) - 1 # Subtract 1 because bin indexing starts at 1 - # Since the bin labels are attributed based on the aggregation of neighbors, repeat the labels `n_neighbors` times - # to obtain a list of labels for each neighbor of each patient for each kneighbors model - bin_labels = pd.Series(np.repeat(bin_labels, kneighbors_ids.shape[-1], axis=0), index=neigh_data.index) - - # For each bin, flatten the list of neighbors of each patient in the bin - # (to obtain lists of patients, with duplicates, in each bin) - patient_ids_by_bin = { - bin_idx: bin_labels[bin_labels == bin_idx].index.get_level_values("neighbor_id").tolist() - for bin_idx in sorted(bin_labels.unique()) - } - patients_by_bin = { - bin_idx: list(Patients.from_dict({patient_id: patients[patient_id] for patient_id in patient_ids}).values()) - for bin_idx, patient_ids in patient_ids_by_bin.items() - } - - # For each time-series attr, build the dataframe of the mean curve for each bin and plot the curves for each bin - for time_series_attr in time_series_attrs: - neigh_agg_time_series_attr_data = build_time_series_attr_by_patient_group_dataframe( - patients_by_bin, time_series_attr, group_desc="bin", mask_tag=mask_tag - ) - - with sns.axes_style("darkgrid"): - plot = sns.lineplot( - data=neigh_agg_time_series_attr_data, - x="time", - y="val", - hue="bin", - hue_order=sorted(bin_labels.unique()), - ) - title = f"{'/'.join(time_series_attr)}_wrt_{var_attr}_bins" - plot.set(title=title, ylabel=TIME_SERIES_ATTR_LABELS[time_series_attr[1]]) - plot.legend(title=f"{var_attr} bin") - - yield title, plot - - -def main(): - """Run the script.""" - from argparse import ArgumentParser - - from matplotlib import pyplot as plt - from tqdm import tqdm - from vital.utils.logging import configure_logging - from vital.utils.parsing import yaml_flow_collection - from vital.utils.saving import load_from_checkpoint - - from didactic.tasks.utils import encode_patients - - # Ensure that matplotlib is using 'agg' backend in non-interactive case - plt.switch_backend("agg") - - configure_logging(log_to_console=True, console_level=logging.INFO) - parser = ArgumentParser() - parser.add_argument( - "models_ckpts", - nargs="+", - type=Path, - help="Checkpoints of models for which to compare and compile the local neighborhood", - ) - parser = Patients.add_args(parser) - parser.add_argument( - "--neigh_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the `KNeighborsClassifier` constructor", - ) - parser.add_argument( - "--reference_attr", - type=TabularAttribute, - default=TabularAttribute.ht_severity, - help="Reference tabular attribute w.r.t. which to plot the variability of the other attributes", - ) - parser.add_argument( - "--tabular_attrs", - type=TabularAttribute, - nargs="*", - choices=list(TabularAttribute), - help="Subset of tabular attributes on which to compile the results. If not provided, will default to all " - "available attributes", - ) - parser.add_argument( - "--time_series_attrs", - type=TimeSeriesAttribute, - choices=list(TimeSeriesAttribute), - nargs="*", - default=list(TimeSeriesAttribute), - help="Subset of time-series attributes derived from segmentations for which to plot the intra/inter-cluster " - "variability", - ) - parser.add_argument( - "--mask_tag", - type=str, - default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the time-series attributes", - ) - parser.add_argument( - "--agg", - type=str, - choices=["mean", "median"], - default="mean", - help="Aggregation function to apply to the neighborhood of each patient", - ) - parser.add_argument( - "--tabular_plot_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.object.Plot` for tabular attributes figures", - ) - parser.add_argument( - "--tabular_dots_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.object.Plot.add` for the scatter plot layer in tabular " - "attributes figures", - ) - parser.add_argument( - "--tabular_polyfit_kwargs", - type=yaml_flow_collection, - metavar="{ARG1:VAL1,ARG2:VAL2,...}", - help="Parameters to forward to the call to `seaborn.object.Plot.add` for the polynomial regression layer in " - "tabular attributes figures", - ) - parser.add_argument( - "--time_series_n_bins", - type=int, - default=5, - help="Number of bins by which to divide the population and over each of which to aggregate the time-series " - "attributes", - ) - parser.add_argument( - "--output_dir", - type=Path, - default=Path("compiled_representation_knn"), - help="Root directory under which to save the compiled results for all of the methods", - ) - args = parser.parse_args() - kwargs = vars(args) - - ( - models_ckpts, - neigh_kwargs, - ref_attr, - tabular_attrs, - time_series_attrs, - mask_tag, - agg, - plot_kwargs, - dots_kwargs, - polyfit_kwargs, - time_series_n_bins, - output_dir, - ) = list( - map( - kwargs.pop, - [ - "models_ckpts", - "neigh_kwargs", - "reference_attr", - "tabular_attrs", - "time_series_attrs", - "mask_tag", - "agg", - "tabular_plot_kwargs", - "tabular_dots_kwargs", - "tabular_polyfit_kwargs", - "time_series_n_bins", - "output_dir", - ], - ) - ) - time_series_attrs_keys = [ - (view, time_series_attr) for view, time_series_attr in itertools.product(args.views, time_series_attrs) - ] - if neigh_kwargs is None: - neigh_kwargs = {} - - # Load the dataset - patients = Patients(**kwargs) - - # Load the models and process the patients to generate the encodings - models = [ - load_from_checkpoint(ckpt, expected_checkpoint_type=CardiacMultimodalRepresentationTask) - for ckpt in models_ckpts - ] - encodings = [ - encode_patients(model, patients.values(), mask_tag=mask_tag) - for model in tqdm(models, desc="Encoding patients using each model", unit="model") - ] - - # Find the nearest neighbors of each patient for each encoding - kneighbors_ids = find_nearest_neighbors(encodings, list(patients), **neigh_kwargs) - - tabular_attrs_plots = plot_tabular_attrs_variability_figures( - patients, - kneighbors_ids, - ref_attr, - tabular_attrs=tabular_attrs, - agg=agg, - plot_kwargs=plot_kwargs, - dots_layer_kwargs=dots_kwargs, - polyfit_layer_kwargs=polyfit_kwargs, - ) - time_series_attrs_plots = plot_time_series_attrs_variability_figures( - patients, - kneighbors_ids, - ref_attr, - time_series_attrs_keys, - agg=agg, - mask_tag=mask_tag, - n_bins=time_series_n_bins, - ) - - # Plot the variability of the tabular and image attributes - output_dir.mkdir(parents=True, exist_ok=True) # Prepare the output folder for the method - n_plots = (len(tabular_attrs) if tabular_attrs else len(TabularAttribute)) + len(time_series_attrs_keys) - for title, plot in tqdm( - itertools.chain(tabular_attrs_plots, time_series_attrs_plots), - desc=f"Plotting the variability of the attributes w.r.t. {ref_attr}", - unit="attr", - total=n_plots, - ): - title_pathified = title.lower().replace("/", "_").replace(" ", "_") - filepath = output_dir / f"{title_pathified}.svg" - - if isinstance(plot, so.Plot): - plot.save(filepath, bbox_inches="tight") - elif isinstance(plot, Axes): - plt.savefig(filepath) - plt.close() # Close the figure to avoid contamination between plots - else: - raise ValueError(f"Unable to save the figure for plot type: {type(plot)}.") - - -if __name__ == "__main__": - main()