diff --git a/src/scanpy/cli.py b/src/scanpy/cli.py index 4a41c3a0d4..04b75c8b74 100644 --- a/src/scanpy/cli.py +++ b/src/scanpy/cli.py @@ -1,9 +1,9 @@ from __future__ import annotations -import collections.abc as cabc import os import sys from argparse import ArgumentParser, Namespace, _SubParsersAction +from collections.abc import MutableMapping from functools import lru_cache, partial from pathlib import Path from shutil import which @@ -27,7 +27,7 @@ def __init__(self, *args, _command: str, _runargs: dict[str, Any], **kwargs): ) -class _CommandDelegator(cabc.MutableMapping): +class _CommandDelegator(MutableMapping): """\ Provide the ability to delegate, but don’t calculate the whole list until necessary diff --git a/src/scanpy/external/tl/_wishbone.py b/src/scanpy/external/tl/_wishbone.py index 7e78d2eec6..e857226feb 100644 --- a/src/scanpy/external/tl/_wishbone.py +++ b/src/scanpy/external/tl/_wishbone.py @@ -1,6 +1,6 @@ from __future__ import annotations -import collections.abc as cabc +from collections.abc import Collection from typing import TYPE_CHECKING import numpy as np @@ -11,7 +11,7 @@ from ..._utils._doctests import doctest_needs if TYPE_CHECKING: - from collections.abc import Collection, Iterable + from collections.abc import Iterable from anndata import AnnData @@ -115,7 +115,7 @@ def wishbone( f"Start cell {start_cell} not found in data. " "Please rerun with correct start cell." ) - if isinstance(num_waypoints, cabc.Collection): + if isinstance(num_waypoints, Collection): diff = np.setdiff1d(num_waypoints, adata.obs.index) if diff.size > 0: logging.warning( diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index ddd639f62d..e3cbf1ae88 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -2,8 +2,8 @@ from __future__ import annotations -import collections.abc as cabc from collections import OrderedDict +from collections.abc import Collection, Mapping, Sequence from itertools import product from typing import TYPE_CHECKING, get_args @@ -38,7 +38,7 @@ ) if TYPE_CHECKING: - from collections.abc import Collection, Iterable, Mapping, Sequence + from collections.abc import Iterable from typing import Literal, Union from anndata import AnnData @@ -220,7 +220,7 @@ def _scatter_obs( isinstance(layers, str) and layers in adata.layers.keys() ): layers = (layers, layers, layers) - elif isinstance(layers, cabc.Collection) and len(layers) == 3: + elif isinstance(layers, Collection) and len(layers) == 3: layers = tuple(layers) for layer in layers: if layer not in adata.layers.keys() and layer not in ["X", None]: @@ -299,7 +299,7 @@ def _scatter_obs( palette_was_none = False if palette is None: palette_was_none = True - if isinstance(palette, cabc.Sequence) and not isinstance(palette, str): + if isinstance(palette, Sequence) and not isinstance(palette, str): if not is_color_like(palette[0]): palettes = palette else: @@ -2665,7 +2665,7 @@ def _check_var_names_type(var_names, var_group_labels, var_group_positions): var_names, var_group_labels, var_group_positions """ - if isinstance(var_names, cabc.Mapping): + if isinstance(var_names, Mapping): if var_group_labels is not None or var_group_positions is not None: logg.warning( "`var_names` is a dictionary. This will reset the current " diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index b3b6803c8c..928fc0057e 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -2,7 +2,7 @@ from __future__ import annotations -import collections.abc as cabc +from collections.abc import Mapping from typing import TYPE_CHECKING, NamedTuple from warnings import warn @@ -17,7 +17,7 @@ from ._utils import check_colornorm, make_grid_spec if TYPE_CHECKING: - from collections.abc import Iterable, Mapping, Sequence + from collections.abc import Iterable, Sequence from typing import Literal, Self, Union import pandas as pd @@ -1097,7 +1097,7 @@ def _update_var_groups(self) -> None: updates var_names, var_group_labels, var_group_positions """ - if isinstance(self.var_names, cabc.Mapping): + if isinstance(self.var_names, Mapping): if self.has_var_groups: logg.warning( "`var_names` is a dictionary. This will reset the current " diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 47f392bd44..eec202d0a5 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations -import collections.abc as cabc -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from copy import copy from typing import TYPE_CHECKING @@ -38,7 +37,7 @@ from .scatterplots import _panel_grid, embedding, pca if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Iterable from typing import Literal from anndata import AnnData @@ -1611,11 +1610,7 @@ def embedding_density( # if group is set, then plot it using multiple panels # (even if only one group is set) - if ( - group is not None - and not isinstance(group, str) - and isinstance(group, cabc.Sequence) - ): + if group is not None and not isinstance(group, str) and isinstance(group, Sequence): if ax is not None: raise ValueError("Can only specify `ax` if no `group` sequence is given.") fig, gs = _panel_grid(hspace, wspace, ncols, len(group)) diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index 6ea2560b10..f0d45e9a80 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -1,7 +1,7 @@ from __future__ import annotations -import collections.abc as cabc import warnings +from collections.abc import Collection, Mapping, Sequence from pathlib import Path from types import MappingProxyType from typing import TYPE_CHECKING @@ -26,7 +26,6 @@ from .._utils import matrix if TYPE_CHECKING: - from collections.abc import Mapping, Sequence from typing import Any, Literal, Union from anndata import AnnData @@ -506,14 +505,12 @@ def paga( groups_key = adata.uns["paga"]["groups"] def is_flat(x): - has_one_per_category = isinstance(x, cabc.Collection) and len(x) == len( + has_one_per_category = isinstance(x, Collection) and len(x) == len( adata.obs[groups_key].cat.categories ) return has_one_per_category or x is None or isinstance(x, str) - if isinstance(colors, cabc.Mapping) and isinstance( - colors[next(iter(colors))], cabc.Mapping - ): + if isinstance(colors, Mapping) and isinstance(colors[next(iter(colors))], Mapping): # handle paga pie, remap string keys to integers names_to_ixs = { n: i for i, n in enumerate(adata.obs[groups_key].cat.categories) @@ -554,7 +551,7 @@ def is_flat(x): f"it needs to be one of {labels} not {root!r}." ) root = list(labels).index(root) - if isinstance(root, cabc.Sequence) and root[0] in labels: + if isinstance(root, Sequence) and root[0] in labels: root = [list(labels).index(r) for r in root] # define the adjacency matrices @@ -600,7 +597,7 @@ def is_flat(x): sct = _paga_graph( adata, axs[icolor], - colors=colors if isinstance(colors, cabc.Mapping) else c, + colors=colors if isinstance(colors, Mapping) else c, solid_edges=solid_edges, dashed_edges=dashed_edges, transitions=transitions, @@ -935,7 +932,7 @@ def _paga_graph( patheffects.withStroke(linewidth=fontoutline, foreground="w") ] # usual scatter plot - if not isinstance(colors[0], cabc.Mapping): + if not isinstance(colors[0], Mapping): n_groups = len(pos_array) sct = ax.scatter( pos_array[:, 0], @@ -959,7 +956,7 @@ def _paga_graph( # else pie chart plot else: for ix, (xx, yy) in enumerate(zip(pos_array[:, 0], pos_array[:, 1])): - if not isinstance(colors[ix], cabc.Mapping): + if not isinstance(colors[ix], Mapping): raise ValueError( f"{colors[ix]} is neither a dict of valid " "matplotlib colors nor a valid matplotlib color." diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 4f2b208ef1..93aff99552 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections.abc as cabc import inspect import sys from collections.abc import Mapping, Sequence # noqa: TCH003 @@ -202,13 +201,13 @@ def embedding( title = [title] if isinstance(title, str) else list(title) # turn vmax and vmin into a sequence - if isinstance(vmax, str) or not isinstance(vmax, cabc.Sequence): + if isinstance(vmax, str) or not isinstance(vmax, Sequence): vmax = [vmax] - if isinstance(vmin, str) or not isinstance(vmin, cabc.Sequence): + if isinstance(vmin, str) or not isinstance(vmin, Sequence): vmin = [vmin] - if isinstance(vcenter, str) or not isinstance(vcenter, cabc.Sequence): + if isinstance(vcenter, str) or not isinstance(vcenter, Sequence): vcenter = [vcenter] - if isinstance(norm, Normalize) or not isinstance(norm, cabc.Sequence): + if isinstance(norm, Normalize) or not isinstance(norm, Sequence): norm = [norm] # Size @@ -219,7 +218,7 @@ def embedding( # set as ndarray if ( size is not None - and isinstance(size, (cabc.Sequence, pd.Series, np.ndarray)) + and isinstance(size, (Sequence, pd.Series, np.ndarray)) and len(size) == adata.shape[0] ): size = np.array(size, dtype=float) @@ -245,9 +244,7 @@ def embedding( # Eg. ['Gene1', 'louvain', 'Gene2']. # component_list is a list of components [[0,1], [1,2]] if ( - not isinstance(color, str) - and isinstance(color, cabc.Sequence) - and len(color) > 1 + not isinstance(color, str) and isinstance(color, Sequence) and len(color) > 1 ) or len(dimensions) > 1: if ax is not None: raise ValueError( diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index b56649e568..a545f1e30c 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -1,8 +1,7 @@ from __future__ import annotations -import collections.abc as cabc import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Callable, Literal, TypedDict, Union, overload import matplotlib as mpl @@ -23,7 +22,7 @@ from . import palettes if TYPE_CHECKING: - from collections.abc import Collection, Mapping + from collections.abc import Collection from anndata import AnnData from matplotlib.colors import Colormap @@ -445,12 +444,12 @@ def _set_colors_for_categorical_obs( # this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20' cmap = plt.get_cmap(palette) colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, len(categories)))] - elif isinstance(palette, cabc.Mapping): + elif isinstance(palette, Mapping): colors_list = [to_hex(palette[k], keep_alpha=True) for k in categories] else: # check if palette is a list and convert it to a cycler, thus # it doesnt matter if the list is shorter than the categories length: - if isinstance(palette, cabc.Sequence): + if isinstance(palette, Sequence): if len(palette) < len(categories): logg.warning( "Length of palette colors is smaller than the number of " @@ -551,7 +550,7 @@ def add_colors_for_categorical_sample_annotation( def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=None): import networkx as nx - if not isinstance(axs, cabc.Sequence): + if not isinstance(axs, Sequence): axs = [axs] if neighbors_key is None: @@ -577,7 +576,7 @@ def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=Non def plot_arrows(axs, adata, basis, arrows_kwds=None): - if not isinstance(axs, cabc.Sequence): + if not isinstance(axs, Sequence): axs = [axs] v_prefix = next( (p for p in ["velocity", "Delta"] if f"{p}_{basis}" in adata.obsm), None @@ -724,7 +723,7 @@ def setup_axes( ax = plt.axes([left, bottom, width, height], projection="3d") axs.append(ax) else: - axs = ax if isinstance(ax, cabc.Sequence) else [ax] + axs = ax if isinstance(ax, Sequence) else [ax] return axs, panel_pos, draw_region_width, figure_width @@ -763,7 +762,7 @@ def scatter_base( Depending on whether supplying a single array or a list of arrays, return a single axis or a list of axes. """ - if isinstance(highlights, cabc.Mapping): + if isinstance(highlights, Mapping): highlights_indices = sorted(highlights) highlights_labels = [highlights[i] for i in highlights_indices] else: diff --git a/src/scanpy/queries/_queries.py b/src/scanpy/queries/_queries.py index 24a2482449..8da90151ce 100644 --- a/src/scanpy/queries/_queries.py +++ b/src/scanpy/queries/_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -import collections.abc as cabc +from collections.abc import Iterable from functools import singledispatch from types import MappingProxyType from typing import TYPE_CHECKING @@ -12,7 +12,7 @@ from ..get import rank_genes_groups_df if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Mapping from typing import Any import pandas as pd @@ -60,7 +60,7 @@ def simple_query( """ if isinstance(attrs, str): attrs = [attrs] - elif isinstance(attrs, cabc.Iterable): + elif isinstance(attrs, Iterable): attrs = list(attrs) else: raise TypeError(f"attrs must be of type list or str, was {type(attrs)}.") diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index 534f5c3b33..1c286e9333 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -4,7 +4,7 @@ from __future__ import annotations -import collections.abc as cabc +from collections.abc import Set from typing import TYPE_CHECKING import numpy as np @@ -190,7 +190,7 @@ def marker_gene_overlap( if normalize is not None and method != "overlap_count": raise ValueError("Can only normalize with method=`overlap_count`.") - if not all(isinstance(val, cabc.Set) for val in reference_markers.values()): + if not all(isinstance(val, Set) for val in reference_markers.values()): try: reference_markers = { key: set(val) for key, val in reference_markers.items()