From 9254e0cd52b859cbd9317409139e28cac96bc2fc Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 20 Sep 2024 13:13:24 +0200 Subject: [PATCH] Clean up dendrogram typing --- src/scanpy/plotting/_anndata.py | 94 ++++++++++++++------------ src/scanpy/plotting/_baseplot_class.py | 4 +- src/scanpy/plotting/_dotplot.py | 2 +- src/scanpy/plotting/_matrixplot.py | 2 +- src/scanpy/plotting/_stacked_violin.py | 2 +- 5 files changed, 54 insertions(+), 50 deletions(-) diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index ddd639f62..612f47e99 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -2,8 +2,8 @@ from __future__ import annotations -import collections.abc as cabc from collections import OrderedDict +from collections.abc import Collection, Mapping, Sequence from itertools import product from typing import TYPE_CHECKING, get_args @@ -38,7 +38,7 @@ ) if TYPE_CHECKING: - from collections.abc import Collection, Iterable, Mapping, Sequence + from collections.abc import Iterable from typing import Literal, Union from anndata import AnnData @@ -220,7 +220,7 @@ def _scatter_obs( isinstance(layers, str) and layers in adata.layers.keys() ): layers = (layers, layers, layers) - elif isinstance(layers, cabc.Collection) and len(layers) == 3: + elif isinstance(layers, Collection) and len(layers) == 3: layers = tuple(layers) for layer in layers: if layer not in adata.layers.keys() and layer not in ["X", None]: @@ -299,7 +299,7 @@ def _scatter_obs( palette_was_none = False if palette is None: palette_was_none = True - if isinstance(palette, cabc.Sequence) and not isinstance(palette, str): + if isinstance(palette, Sequence) and not isinstance(palette, str): if not is_color_like(palette[0]): palettes = palette else: @@ -1189,7 +1189,7 @@ def heatmap( dendro_data = _reorder_categories_after_dendrogram( adata, groupby, - dendrogram, + dendrogram_key=None if dendrogram is True else dendrogram, var_names=var_names, var_group_labels=var_group_labels, var_group_positions=var_group_positions, @@ -1579,7 +1579,7 @@ def tracksplot( dendro_data = _reorder_categories_after_dendrogram( adata, groupby, - dendrogram, + dendrogram_key=None if dendrogram is True else dendrogram, var_names=var_names, var_group_labels=var_group_labels, var_group_positions=var_group_positions, @@ -2247,13 +2247,13 @@ def _plot_gene_groups_brackets( def _reorder_categories_after_dendrogram( adata: AnnData, - groupby, - dendrogram, + groupby: str | Sequence[str], *, - var_names=None, - var_group_labels=None, - var_group_positions=None, - categories=None, + dendrogram_key: str | None, + var_names: Sequence[str], + var_group_labels: Sequence[str] | None, + var_group_positions: Sequence[tuple[int, int]] | None, + categories: Sequence[str], ): """\ Function used by plotting functions that need to reorder the the groupby @@ -2273,7 +2273,7 @@ def _reorder_categories_after_dendrogram( 'var_group_labels', and 'var_group_positions' """ - key = _get_dendrogram_key(adata, dendrogram, groupby) + key = _get_dendrogram_key(adata, dendrogram_key, groupby) if isinstance(groupby, str): groupby = [groupby] @@ -2305,36 +2305,35 @@ def _reorder_categories_after_dendrogram( ) # reorder var_groups (if any) - if var_names is not None: - var_names_idx_ordered = list(range(len(var_names))) - - if var_group_positions: - if set(var_group_labels) == set(categories): - positions_ordered = [] - labels_ordered = [] - position_start = 0 - var_names_idx_ordered = [] - for cat_name in categories_ordered: - idx = var_group_labels.index(cat_name) - position = var_group_positions[idx] - _var_names = var_names[position[0] : position[1] + 1] - var_names_idx_ordered.extend(range(position[0], position[1] + 1)) - positions_ordered.append( - (position_start, position_start + len(_var_names) - 1) - ) - position_start += len(_var_names) - labels_ordered.append(var_group_labels[idx]) - var_group_labels = labels_ordered - var_group_positions = positions_ordered - else: - logg.warning( - "Groups are not reordered because the `groupby` categories " - "and the `var_group_labels` are different.\n" - f"categories: {_format_first_three_categories(categories)}\n" - f"var_group_labels: {_format_first_three_categories(var_group_labels)}" + if var_group_positions is None or var_group_labels is None: + assert var_group_positions is None + assert var_group_labels is None + var_names_idx_ordered = None + elif set(var_group_labels) == set(categories): + positions_ordered = [] + labels_ordered = [] + position_start = 0 + var_names_idx_ordered = [] + for cat_name in categories_ordered: + idx = var_group_labels.index(cat_name) + position = var_group_positions[idx] + _var_names = var_names[position[0] : position[1] + 1] + var_names_idx_ordered.extend(range(position[0], position[1] + 1)) + positions_ordered.append( + (position_start, position_start + len(_var_names) - 1) ) + position_start += len(_var_names) + labels_ordered.append(var_group_labels[idx]) + var_group_labels = labels_ordered + var_group_positions = positions_ordered else: - var_names_idx_ordered = None + logg.warning( + "Groups are not reordered because the `groupby` categories " + "and the `var_group_labels` are different.\n" + f"categories: {_format_first_three_categories(categories)}\n" + f"var_group_labels: {_format_first_three_categories(var_group_labels)}" + ) + var_names_idx_ordered = list(range(len(var_names))) if var_names_idx_ordered is not None: var_names_ordered = [var_names[x] for x in var_names_idx_ordered] @@ -2358,14 +2357,19 @@ def _format_first_three_categories(categories): return ", ".join(categories) -def _get_dendrogram_key(adata, dendrogram_key, groupby): +def _get_dendrogram_key( + adata: AnnData, dendrogram_key: str | None, groupby: str | Sequence[str] +) -> str: # the `dendrogram_key` can be a bool an NoneType or the name of the # dendrogram key. By default the name of the dendrogram key is 'dendrogram' - if not isinstance(dendrogram_key, str): + if dendrogram_key is None: if isinstance(groupby, str): dendrogram_key = f"dendrogram_{groupby}" - elif isinstance(groupby, list): + elif isinstance(groupby, Sequence): dendrogram_key = f'dendrogram_{"_".join(groupby)}' + else: + msg = f"groupby has wrong type: {type(groupby).__name__}." + raise AssertionError(msg) if dendrogram_key not in adata.uns: from ..tools._dendrogram import dendrogram @@ -2665,7 +2669,7 @@ def _check_var_names_type(var_names, var_group_labels, var_group_positions): var_names, var_group_labels, var_group_positions """ - if isinstance(var_names, cabc.Mapping): + if isinstance(var_names, Mapping): if var_group_labels is not None or var_group_positions is not None: logg.warning( "`var_names` is a dictionary. This will reset the current " diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index b3b6803c8..f0f42fd4d 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -874,7 +874,7 @@ def savefig(self, filename: str, bbox_inches: str | None = "tight", **kwargs): self.make_figure() plt.savefig(filename, bbox_inches=bbox_inches, **kwargs) - def _reorder_categories_after_dendrogram(self, dendrogram) -> None: + def _reorder_categories_after_dendrogram(self, dendrogram_key: str | None) -> None: """\ Function used by plotting functions that need to reorder the the groupby observations based on the dendrogram results. @@ -900,7 +900,7 @@ def _format_first_three_categories(_categories): _categories = _categories[:3] + ["etc."] return ", ".join(_categories) - key = _get_dendrogram_key(self.adata, dendrogram, self.groupby) + key = _get_dendrogram_key(self.adata, dendrogram_key, self.groupby) dendro_info = self.adata.uns[key] if self.groupby != dendro_info["groupby"]: diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index 2048cd0e8..17c676fb4 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -1043,7 +1043,7 @@ def dotplot( ) if dendrogram: - dp.add_dendrogram(dendrogram_key=dendrogram) + dp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram) if swap_axes: dp.swap_axes() diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index f5fc18a72..11024bb95 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -454,7 +454,7 @@ def matrixplot( ) if dendrogram: - mp.add_dendrogram(dendrogram_key=dendrogram) + mp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram) if swap_axes: mp.swap_axes() diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 3dcbbf067..ba43613e1 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -833,7 +833,7 @@ def stacked_violin( ) if dendrogram: - vp.add_dendrogram(dendrogram_key=dendrogram) + vp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram) if swap_axes: vp.swap_axes() vp = vp.style(