Skip to content

Commit

Permalink
Clean up dendrogram typing
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Sep 20, 2024
1 parent b0597a9 commit 9254e0c
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 50 deletions.
94 changes: 49 additions & 45 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import collections.abc as cabc
from collections import OrderedDict
from collections.abc import Collection, Mapping, Sequence
from itertools import product
from typing import TYPE_CHECKING, get_args

Expand Down Expand Up @@ -38,7 +38,7 @@
)

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Iterable
from typing import Literal, Union

from anndata import AnnData
Expand Down Expand Up @@ -220,7 +220,7 @@ def _scatter_obs(
isinstance(layers, str) and layers in adata.layers.keys()
):
layers = (layers, layers, layers)
elif isinstance(layers, cabc.Collection) and len(layers) == 3:
elif isinstance(layers, Collection) and len(layers) == 3:
layers = tuple(layers)
for layer in layers:
if layer not in adata.layers.keys() and layer not in ["X", None]:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _scatter_obs(
palette_was_none = False
if palette is None:
palette_was_none = True
if isinstance(palette, cabc.Sequence) and not isinstance(palette, str):
if isinstance(palette, Sequence) and not isinstance(palette, str):
if not is_color_like(palette[0]):
palettes = palette
else:
Expand Down Expand Up @@ -1189,7 +1189,7 @@ def heatmap(
dendro_data = _reorder_categories_after_dendrogram(
adata,
groupby,
dendrogram,
dendrogram_key=None if dendrogram is True else dendrogram,
var_names=var_names,
var_group_labels=var_group_labels,
var_group_positions=var_group_positions,
Expand Down Expand Up @@ -1579,7 +1579,7 @@ def tracksplot(
dendro_data = _reorder_categories_after_dendrogram(
adata,
groupby,
dendrogram,
dendrogram_key=None if dendrogram is True else dendrogram,
var_names=var_names,
var_group_labels=var_group_labels,
var_group_positions=var_group_positions,
Expand Down Expand Up @@ -2247,13 +2247,13 @@ def _plot_gene_groups_brackets(

def _reorder_categories_after_dendrogram(
adata: AnnData,
groupby,
dendrogram,
groupby: str | Sequence[str],
*,
var_names=None,
var_group_labels=None,
var_group_positions=None,
categories=None,
dendrogram_key: str | None,
var_names: Sequence[str],
var_group_labels: Sequence[str] | None,
var_group_positions: Sequence[tuple[int, int]] | None,
categories: Sequence[str],
):
"""\
Function used by plotting functions that need to reorder the the groupby
Expand All @@ -2273,7 +2273,7 @@ def _reorder_categories_after_dendrogram(
'var_group_labels', and 'var_group_positions'
"""

key = _get_dendrogram_key(adata, dendrogram, groupby)
key = _get_dendrogram_key(adata, dendrogram_key, groupby)

if isinstance(groupby, str):
groupby = [groupby]
Expand Down Expand Up @@ -2305,36 +2305,35 @@ def _reorder_categories_after_dendrogram(
)

# reorder var_groups (if any)
if var_names is not None:
var_names_idx_ordered = list(range(len(var_names)))

if var_group_positions:
if set(var_group_labels) == set(categories):
positions_ordered = []
labels_ordered = []
position_start = 0
var_names_idx_ordered = []
for cat_name in categories_ordered:
idx = var_group_labels.index(cat_name)
position = var_group_positions[idx]
_var_names = var_names[position[0] : position[1] + 1]
var_names_idx_ordered.extend(range(position[0], position[1] + 1))
positions_ordered.append(
(position_start, position_start + len(_var_names) - 1)
)
position_start += len(_var_names)
labels_ordered.append(var_group_labels[idx])
var_group_labels = labels_ordered
var_group_positions = positions_ordered
else:
logg.warning(
"Groups are not reordered because the `groupby` categories "
"and the `var_group_labels` are different.\n"
f"categories: {_format_first_three_categories(categories)}\n"
f"var_group_labels: {_format_first_three_categories(var_group_labels)}"
if var_group_positions is None or var_group_labels is None:
assert var_group_positions is None
assert var_group_labels is None
var_names_idx_ordered = None
elif set(var_group_labels) == set(categories):
positions_ordered = []
labels_ordered = []
position_start = 0
var_names_idx_ordered = []
for cat_name in categories_ordered:
idx = var_group_labels.index(cat_name)
position = var_group_positions[idx]
_var_names = var_names[position[0] : position[1] + 1]
var_names_idx_ordered.extend(range(position[0], position[1] + 1))
positions_ordered.append(
(position_start, position_start + len(_var_names) - 1)
)
position_start += len(_var_names)
labels_ordered.append(var_group_labels[idx])
var_group_labels = labels_ordered
var_group_positions = positions_ordered
else:
var_names_idx_ordered = None
logg.warning(
"Groups are not reordered because the `groupby` categories "
"and the `var_group_labels` are different.\n"
f"categories: {_format_first_three_categories(categories)}\n"
f"var_group_labels: {_format_first_three_categories(var_group_labels)}"
)
var_names_idx_ordered = list(range(len(var_names)))

if var_names_idx_ordered is not None:
var_names_ordered = [var_names[x] for x in var_names_idx_ordered]
Expand All @@ -2358,14 +2357,19 @@ def _format_first_three_categories(categories):
return ", ".join(categories)


def _get_dendrogram_key(adata, dendrogram_key, groupby):
def _get_dendrogram_key(
adata: AnnData, dendrogram_key: str | None, groupby: str | Sequence[str]
) -> str:
# the `dendrogram_key` can be a bool an NoneType or the name of the
# dendrogram key. By default the name of the dendrogram key is 'dendrogram'
if not isinstance(dendrogram_key, str):
if dendrogram_key is None:
if isinstance(groupby, str):
dendrogram_key = f"dendrogram_{groupby}"
elif isinstance(groupby, list):
elif isinstance(groupby, Sequence):
dendrogram_key = f'dendrogram_{"_".join(groupby)}'
else:
msg = f"groupby has wrong type: {type(groupby).__name__}."
raise AssertionError(msg)

if dendrogram_key not in adata.uns:
from ..tools._dendrogram import dendrogram
Expand Down Expand Up @@ -2665,7 +2669,7 @@ def _check_var_names_type(var_names, var_group_labels, var_group_positions):
var_names, var_group_labels, var_group_positions
"""
if isinstance(var_names, cabc.Mapping):
if isinstance(var_names, Mapping):
if var_group_labels is not None or var_group_positions is not None:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def savefig(self, filename: str, bbox_inches: str | None = "tight", **kwargs):
self.make_figure()
plt.savefig(filename, bbox_inches=bbox_inches, **kwargs)

def _reorder_categories_after_dendrogram(self, dendrogram) -> None:
def _reorder_categories_after_dendrogram(self, dendrogram_key: str | None) -> None:
"""\
Function used by plotting functions that need to reorder the the groupby
observations based on the dendrogram results.
Expand All @@ -900,7 +900,7 @@ def _format_first_three_categories(_categories):
_categories = _categories[:3] + ["etc."]
return ", ".join(_categories)

key = _get_dendrogram_key(self.adata, dendrogram, self.groupby)
key = _get_dendrogram_key(self.adata, dendrogram_key, self.groupby)

dendro_info = self.adata.uns[key]
if self.groupby != dendro_info["groupby"]:
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/plotting/_dotplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def dotplot(
)

if dendrogram:
dp.add_dendrogram(dendrogram_key=dendrogram)
dp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram)
if swap_axes:
dp.swap_axes()

Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/plotting/_matrixplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def matrixplot(
)

if dendrogram:
mp.add_dendrogram(dendrogram_key=dendrogram)
mp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram)
if swap_axes:
mp.swap_axes()

Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/plotting/_stacked_violin.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def stacked_violin(
)

if dendrogram:
vp.add_dendrogram(dendrogram_key=dendrogram)
vp.add_dendrogram(dendrogram_key=None if dendrogram is True else dendrogram)
if swap_axes:
vp.swap_axes()
vp = vp.style(
Expand Down

0 comments on commit 9254e0c

Please sign in to comment.