From 4cf6a3cf398cba98c44f528b2639c03ea0aea715 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Thu, 7 Jul 2022 19:36:31 +0200 Subject: [PATCH 1/8] add basic plotting function for metrics --- scib/__init__.py | 3 +- scib/plotting.py | 72 +++++++++++++++++++++++++++++++++++++++ tests/plots/test_plots.py | 43 +++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 scib/plotting.py create mode 100644 tests/plots/test_plots.py diff --git a/scib/__init__.py b/scib/__init__.py index f2ce75ce..1ea78958 100644 --- a/scib/__init__.py +++ b/scib/__init__.py @@ -5,7 +5,7 @@ __version__ = metadata.version("scib") -from . import integration, metrics, preprocessing, utils +from . import integration, metrics, plotting, preprocessing, utils from ._package_tools import rename_func from .metrics import clustering @@ -33,3 +33,4 @@ ig = integration me = metrics cl = clustering +pl = plotting diff --git a/scib/plotting.py b/scib/plotting.py new file mode 100644 index 00000000..aace760b --- /dev/null +++ b/scib/plotting.py @@ -0,0 +1,72 @@ +import numpy as np +import seaborn as sns +from matplotlib import pyplot as plt + + +def metrics(df, batch_metrics=None, bio_metrics=None, palette=None): + sns.set_context("paper") + + if palette is None: + palette = "viridis_r" + # sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True) + + if batch_metrics is None: + batch_metrics = ["ASW_batch", "PCR_batch", "graph_conn", "kBET", "iLISI"] + + if bio_metrics is None: + bio_metrics = [ + "NMI_cluster", + "ARI_cluster", + "ASW_label", + "cell_cycle_conservation", + "isolated_label_F1", + "isolated_label_silhouette", + "cLISI", + "hvg_overlap", + "trajectory", + ] + + df = df.melt(id_vars=["method"], var_name="metric", value_name="value") + + conditions = [(df["metric"].isin(batch_metrics)), (df["metric"].isin(bio_metrics))] + metric_type = ["Batch Correction", "Biological Conservation"] + + df["metric_type"] = np.select(conditions, metric_type) + df["metric"] = df["metric"].str.replace("_", " ") + df["rank"] = df.groupby("metric")["value"].rank(ascending=False) + + dims = df[["metric_type", "metric"]].drop_duplicates()["metric_type"].value_counts() + n_metrics = dims.sum() + n_methods = df["method"].nunique() + dim_x = (n_metrics + dims.shape[0]) * 0.48 + dim_y = np.max([2.5, n_methods]) + + # Build plot + fig, axs = plt.subplots( + nrows=1, + ncols=dims.shape[0], + figsize=(dim_x, dim_y), + sharey=True, + gridspec_kw=dict(width_ratios=list(dims)), + ) + + for i, metric_type in enumerate(dims.index): + sns.despine(bottom=True, left=True) + sns.scatterplot( + data=df.query(f'metric_type == "{metric_type}"'), + x="metric", + y="method", + hue="rank", + palette=palette, + size="value", + sizes=(0, 100), + ax=axs[i], + ) + axs[i].set(title=metric_type, xlabel=None, ylabel=None) + axs[i].tick_params(axis="x", rotation=90) + axs[i].legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) + + for t in axs[i].legend_.texts: + t.set_text(t.get_text()[:5]) + + fig.tight_layout() diff --git a/tests/plots/test_plots.py b/tests/plots/test_plots.py new file mode 100644 index 00000000..4292a830 --- /dev/null +++ b/tests/plots/test_plots.py @@ -0,0 +1,43 @@ +import numpy as np +import pandas as pd + +import scib + + +def test_plot(): + data = { + "ARI_cluster": {0: 0.951112722518898, 1: 0.262192519680191, 2: 0.1, 3: 0.2}, + "ASW_batch": {0: 0.9057019050549192, 1: 0.8448200803913499, 2: 0.7, 3: 0.3}, + "ASW_label": {0: 0.617242477834225, 1: 0.564448088407517, 2: 0.4, 3: 0.7}, + "NMI_cluster": {0: 0.9138665032024672, 1: 0.632615412598558, 2: 0.4, 3: 0.7}, + "PCR_batch": {0: 0.855878437307926, 1: 0.7125446098053699, 2: 0.6, 3: 0.5}, + "cLISI": {0: 1.0, 1: 0.9993835933509928, 2: 0.8, 3: 0.9}, + "cell_cycle_conservation": { + 0: 0.470498471863989, + 1: 0.741363581608263, + 2: 0.6, + 3: 0.8, + }, + "graph_conn": {0: 0.971955345243732, 1: 0.944989571511962, 2: 0.8, 3: 0.7}, + "hvg_overlap": {0: 0.4772209890553079, 1: 0.2025893518406739, 2: 0.1, 3: 0.2}, + "iLISI": {0: 0.07924053136125, 1: 0.004064867867098, 2: 0.1, 3: 0.2}, + "isolated_label_F1": { + 0: 0.107692307692308, + 1: 0.106870229007634, + 2: 0.1, + 3: 0.3, + }, + "isolated_label_silhouette": { + 0: 0.520902156829834, + 1: 0.550404392182827, + 2: 0.4, + 3: 0.6, + }, + "kBET": {0: 0.3197709591957574, 1: 0.2183332674192387, 2: 0.1, 3: 0.2}, + "method": {0: "method1", 1: "method2", 2: "method3", 3: "method4"}, + "trajectory": {0: np.nan, 1: np.nan, 2: np.nan, 3: np.nan}, + } + + df = pd.DataFrame(data) + scib.pl.metrics(df, palette="viridis") + scib.pl.metrics(df[0:2]) From 3291090422ebddff3b63ae84550b87bb96817b73 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Thu, 7 Jul 2022 21:25:16 +0200 Subject: [PATCH 2/8] annotate plotting function and adapted to edge cases --- scib/plotting.py | 73 +++++++++++++++++++++++++++------------ tests/plots/test_plots.py | 5 +-- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/scib/plotting.py b/scib/plotting.py index aace760b..39a9a5ee 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -3,8 +3,25 @@ from matplotlib import pyplot as plt -def metrics(df, batch_metrics=None, bio_metrics=None, palette=None): +def metrics( + metrics_df, + method_column="method", + metric_column="metric", + value_column="value", + batch_metrics=None, + bio_metrics=None, + palette=None, +): + """ + :param metrics_df: dataframe with columns for methods, metrics and metric values + :param metric_column: column in ``metrics_df`` of metrics + :param method_column: column in ``metrics_df`` of methods + :param batch_metrics: list of batch correction metrics in metrics column for annotating metric type + :param bio_metrics: list of biological conservation metrics in the metrics column for annotating metric type + :param palette: color map as input for ``seaborn.scatterplot`` + """ sns.set_context("paper") + sns.set_style("white") if palette is None: palette = "viridis_r" @@ -26,47 +43,57 @@ def metrics(df, batch_metrics=None, bio_metrics=None, palette=None): "trajectory", ] - df = df.melt(id_vars=["method"], var_name="metric", value_name="value") + df = metrics_df.copy() - conditions = [(df["metric"].isin(batch_metrics)), (df["metric"].isin(bio_metrics))] + conditions = [ + (df[metric_column].isin(batch_metrics)), + (df[metric_column].isin(bio_metrics)), + ] metric_type = ["Batch Correction", "Biological Conservation"] - df["metric_type"] = np.select(conditions, metric_type) - df["metric"] = df["metric"].str.replace("_", " ") - df["rank"] = df.groupby("metric")["value"].rank(ascending=False) + df[metric_column] = df[metric_column].str.replace("_", " ") + df["rank"] = df.groupby(metric_column)[value_column].rank(ascending=False) - dims = df[["metric_type", "metric"]].drop_duplicates()["metric_type"].value_counts() + dims = ( + df[["metric_type", metric_column]] + .drop_duplicates()["metric_type"] + .value_counts() + ) + n_metric_types = dims.shape[0] n_metrics = dims.sum() - n_methods = df["method"].nunique() - dim_x = (n_metrics + dims.shape[0]) * 0.48 - dim_y = np.max([2.5, n_methods]) + n_methods = df[method_column].nunique() + dim_x = np.max([4, (n_metrics + n_metric_types) * 0.4]) + dim_y = np.max([2.5, n_methods * 0.9]) # Build plot fig, axs = plt.subplots( nrows=1, - ncols=dims.shape[0], + ncols=n_metric_types, figsize=(dim_x, dim_y), sharey=True, gridspec_kw=dict(width_ratios=list(dims)), ) for i, metric_type in enumerate(dims.index): - sns.despine(bottom=True, left=True) + df_sub = df.query(f'metric_type == "{metric_type}"') + ax = axs if n_metric_types == 1 else axs[i] sns.scatterplot( - data=df.query(f'metric_type == "{metric_type}"'), - x="metric", - y="method", + data=df_sub, + x=metric_column, + y=method_column, hue="rank", palette=palette, - size="value", - sizes=(0, 100), - ax=axs[i], + size=value_column, + sizes=(df_sub["value"].min() * 100, df_sub["value"].max() * 100), + # sizes={x: int(x * 200) for x in df_sub['value'].dropna().unique()}, + legend="brief", + ax=ax, ) - axs[i].set(title=metric_type, xlabel=None, ylabel=None) - axs[i].tick_params(axis="x", rotation=90) - axs[i].legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) - - for t in axs[i].legend_.texts: + ax.set(title=metric_type, xlabel=None, ylabel=None) + ax.tick_params(axis="x", rotation=90) + ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) + for t in ax.legend_.texts: t.set_text(t.get_text()[:5]) + sns.despine(bottom=True, left=True) fig.tight_layout() diff --git a/tests/plots/test_plots.py b/tests/plots/test_plots.py index 4292a830..ebb2fc1e 100644 --- a/tests/plots/test_plots.py +++ b/tests/plots/test_plots.py @@ -39,5 +39,6 @@ def test_plot(): } df = pd.DataFrame(data) - scib.pl.metrics(df, palette="viridis") - scib.pl.metrics(df[0:2]) + df = df.melt(id_vars=["method"], var_name="metric", value_name="value") + scib.pl.metrics(df) + scib.pl.metrics(df[0:1]) From 4f43b4f5062e33f5e06eaeb30f731aea1002c21f Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Thu, 7 Jul 2022 22:07:32 +0200 Subject: [PATCH 3/8] include overall scores --- scib/plotting.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/scib/plotting.py b/scib/plotting.py index 39a9a5ee..635445cf 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import seaborn as sns from matplotlib import pyplot as plt @@ -52,7 +53,26 @@ def metrics( metric_type = ["Batch Correction", "Biological Conservation"] df["metric_type"] = np.select(conditions, metric_type) df[metric_column] = df[metric_column].str.replace("_", " ") - df["rank"] = df.groupby(metric_column)[value_column].rank(ascending=False) + + # overall score + df = pd.concat( + [ + df, + df.groupby([method_column, "metric_type"])[value_column] + .mean() + .reset_index() + .assign(metric="Overall"), + df.groupby(method_column)[value_column] + .mean() + .reset_index() + .assign(metric_type="Overall", metric="Overall"), + ] + ) + + # rank + df["rank"] = df.groupby([metric_column, "metric_type"])[value_column].rank( + ascending=False + ) dims = ( df[["metric_type", metric_column]] @@ -75,6 +95,7 @@ def metrics( ) for i, metric_type in enumerate(dims.index): + legend = None if metric_type == "Overall" else "brief" df_sub = df.query(f'metric_type == "{metric_type}"') ax = axs if n_metric_types == 1 else axs[i] sns.scatterplot( @@ -86,14 +107,15 @@ def metrics( size=value_column, sizes=(df_sub["value"].min() * 100, df_sub["value"].max() * 100), # sizes={x: int(x * 200) for x in df_sub['value'].dropna().unique()}, - legend="brief", + legend=legend, ax=ax, ) ax.set(title=metric_type, xlabel=None, ylabel=None) ax.tick_params(axis="x", rotation=90) - ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) - for t in ax.legend_.texts: - t.set_text(t.get_text()[:5]) + if legend is not None: + ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0) + for t in ax.legend_.texts: + t.set_text(t.get_text()[:5]) sns.despine(bottom=True, left=True) fig.tight_layout() From ffd3b043daea75feb1cf32c2348d6c2b823942c3 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Thu, 7 Jul 2022 22:59:09 +0200 Subject: [PATCH 4/8] fix ranking and make overall score optional --- scib/plotting.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/scib/plotting.py b/scib/plotting.py index 635445cf..cdbf4f90 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -12,6 +12,7 @@ def metrics( batch_metrics=None, bio_metrics=None, palette=None, + overall=True, ): """ :param metrics_df: dataframe with columns for methods, metrics and metric values @@ -20,6 +21,7 @@ def metrics( :param batch_metrics: list of batch correction metrics in metrics column for annotating metric type :param bio_metrics: list of biological conservation metrics in the metrics column for annotating metric type :param palette: color map as input for ``seaborn.scatterplot`` + :param overall: whether to include a column for the overall score """ sns.set_context("paper") sns.set_style("white") @@ -55,25 +57,35 @@ def metrics( df[metric_column] = df[metric_column].str.replace("_", " ") # overall score - df = pd.concat( - [ - df, - df.groupby([method_column, "metric_type"])[value_column] - .mean() - .reset_index() - .assign(metric="Overall"), + df_list = [ + df, + df.groupby([method_column, "metric_type"])[value_column] + .mean() + .reset_index() + .assign(metric="Overall"), + ] + if overall: + df_list.append( df.groupby(method_column)[value_column] .mean() .reset_index() - .assign(metric_type="Overall", metric="Overall"), - ] - ) + .assign(metric_type="Overall", metric="Overall") + ) + df = pd.concat(df_list) - # rank - df["rank"] = df.groupby([metric_column, "metric_type"])[value_column].rank( - ascending=False + # rank metrics + df["rank"] = ( + df.groupby([metric_column, "metric_type"])[value_column] + .rank( + method="min", + ascending=False, + na_option="bottom", + ) + .astype(int) ) + df = df.sort_values("rank") + # get plot dimensions dims = ( df[["metric_type", metric_column]] .drop_duplicates()["metric_type"] @@ -95,7 +107,7 @@ def metrics( ) for i, metric_type in enumerate(dims.index): - legend = None if metric_type == "Overall" else "brief" + legend = "brief" if i == 0 else None df_sub = df.query(f'metric_type == "{metric_type}"') ax = axs if n_metric_types == 1 else axs[i] sns.scatterplot( From 5d2bd36ef97b58deadc12ef1df1308f7be4da98f Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sat, 9 Jul 2022 19:44:04 +0200 Subject: [PATCH 5/8] return figure --- scib/plotting.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scib/plotting.py b/scib/plotting.py index cdbf4f90..fceb0a63 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -13,6 +13,7 @@ def metrics( bio_metrics=None, palette=None, overall=True, + return_fig=False, ): """ :param metrics_df: dataframe with columns for methods, metrics and metric values @@ -22,6 +23,7 @@ def metrics( :param bio_metrics: list of biological conservation metrics in the metrics column for annotating metric type :param palette: color map as input for ``seaborn.scatterplot`` :param overall: whether to include a column for the overall score + :param return_fig: whether to return a fig object """ sns.set_context("paper") sns.set_style("white") @@ -131,3 +133,6 @@ def metrics( sns.despine(bottom=True, left=True) fig.tight_layout() + + if return_fig: + return fig From d0a5d96f0b6dd00337881b47e5ec0abe3d91f903 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sun, 10 Jul 2022 16:04:01 +0200 Subject: [PATCH 6/8] fix scaling for longer method names --- scib/plotting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scib/plotting.py b/scib/plotting.py index fceb0a63..ce7d61d1 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -17,8 +17,9 @@ def metrics( ): """ :param metrics_df: dataframe with columns for methods, metrics and metric values - :param metric_column: column in ``metrics_df`` of metrics :param method_column: column in ``metrics_df`` of methods + :param metric_column: column in ``metrics_df`` of metrics + :param value_column: column in ``metrics_df`` with metric values :param batch_metrics: list of batch correction metrics in metrics column for annotating metric type :param bio_metrics: list of biological conservation metrics in the metrics column for annotating metric type :param palette: color map as input for ``seaborn.scatterplot`` @@ -96,7 +97,8 @@ def metrics( n_metric_types = dims.shape[0] n_metrics = dims.sum() n_methods = df[method_column].nunique() - dim_x = np.max([4, (n_metrics + n_metric_types) * 0.4]) + metric_len = df[metric_column].str.len().max() + dim_x = np.max([4, (n_metrics + n_metric_types + (metric_len / 10)) * 0.4]) dim_y = np.max([2.5, n_methods * 0.9]) # Build plot @@ -121,6 +123,7 @@ def metrics( size=value_column, sizes=(df_sub["value"].min() * 100, df_sub["value"].max() * 100), # sizes={x: int(x * 200) for x in df_sub['value'].dropna().unique()}, + edgecolor="black", legend=legend, ax=ax, ) From bebfdaeab3c08e7168419be65b6746561a13eb02 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Sun, 10 Jul 2022 22:55:44 +0200 Subject: [PATCH 7/8] update figure size --- scib/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scib/plotting.py b/scib/plotting.py index ce7d61d1..e453cec5 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -98,7 +98,7 @@ def metrics( n_metrics = dims.sum() n_methods = df[method_column].nunique() metric_len = df[metric_column].str.len().max() - dim_x = np.max([4, (n_metrics + n_metric_types + (metric_len / 10)) * 0.4]) + dim_x = np.max([4, (n_metrics + n_metric_types) * 0.4 + (metric_len / 10)]) dim_y = np.max([2.5, n_methods * 0.9]) # Build plot From 3f372972e56d4482d865a87f1922b3c8eeada2e4 Mon Sep 17 00:00:00 2001 From: Michaela Mueller Date: Tue, 12 Jul 2022 12:48:41 +0200 Subject: [PATCH 8/8] use overall ranking to sort methods --- scib/plotting.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/scib/plotting.py b/scib/plotting.py index e453cec5..586ddbec 100644 --- a/scib/plotting.py +++ b/scib/plotting.py @@ -66,14 +66,11 @@ def metrics( .mean() .reset_index() .assign(metric="Overall"), + df.groupby(method_column)[value_column] + .mean() + .reset_index() + .assign(metric_type="Overall", metric="Overall"), ] - if overall: - df_list.append( - df.groupby(method_column)[value_column] - .mean() - .reset_index() - .assign(metric_type="Overall", metric="Overall") - ) df = pd.concat(df_list) # rank metrics @@ -86,7 +83,10 @@ def metrics( ) .astype(int) ) - df = df.sort_values("rank") + method_rank = df.query('metric_type == "Overall"').sort_values( + "rank", ascending=True + )[method_column] + df[method_column] = pd.Categorical(df[method_column], categories=method_rank) # get plot dimensions dims = (