Skip to content

Commit

Permalink
Rely on Ruff for TYPE_CHECKING block mgmt
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Sep 20, 2024
1 parent b0597a9 commit 3cce3f2
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 54 deletions.
4 changes: 2 additions & 2 deletions src/scanpy/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import collections.abc as cabc
import os
import sys
from argparse import ArgumentParser, Namespace, _SubParsersAction
from collections.abc import MutableMapping
from functools import lru_cache, partial
from pathlib import Path
from shutil import which
Expand All @@ -27,7 +27,7 @@ def __init__(self, *args, _command: str, _runargs: dict[str, Any], **kwargs):
)


class _CommandDelegator(cabc.MutableMapping):
class _CommandDelegator(MutableMapping):
"""\
Provide the ability to delegate,
but don’t calculate the whole list until necessary
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/external/tl/_wishbone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Collection
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -11,7 +11,7 @@
from ..._utils._doctests import doctest_needs

if TYPE_CHECKING:
from collections.abc import Collection, Iterable
from collections.abc import Iterable

from anndata import AnnData

Expand Down Expand Up @@ -115,7 +115,7 @@ def wishbone(
f"Start cell {start_cell} not found in data. "
"Please rerun with correct start cell."
)
if isinstance(num_waypoints, cabc.Collection):
if isinstance(num_waypoints, Collection):

Check warning on line 118 in src/scanpy/external/tl/_wishbone.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/external/tl/_wishbone.py#L118

Added line #L118 was not covered by tests
diff = np.setdiff1d(num_waypoints, adata.obs.index)
if diff.size > 0:
logging.warning(
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import collections.abc as cabc
from collections import OrderedDict
from collections.abc import Collection, Mapping, Sequence
from itertools import product
from typing import TYPE_CHECKING, get_args

Expand Down Expand Up @@ -38,7 +38,7 @@
)

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Iterable
from typing import Literal, Union

from anndata import AnnData
Expand Down Expand Up @@ -220,7 +220,7 @@ def _scatter_obs(
isinstance(layers, str) and layers in adata.layers.keys()
):
layers = (layers, layers, layers)
elif isinstance(layers, cabc.Collection) and len(layers) == 3:
elif isinstance(layers, Collection) and len(layers) == 3:

Check warning on line 223 in src/scanpy/plotting/_anndata.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_anndata.py#L223

Added line #L223 was not covered by tests
layers = tuple(layers)
for layer in layers:
if layer not in adata.layers.keys() and layer not in ["X", None]:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _scatter_obs(
palette_was_none = False
if palette is None:
palette_was_none = True
if isinstance(palette, cabc.Sequence) and not isinstance(palette, str):
if isinstance(palette, Sequence) and not isinstance(palette, str):
if not is_color_like(palette[0]):
palettes = palette
else:
Expand Down Expand Up @@ -2665,7 +2665,7 @@ def _check_var_names_type(var_names, var_group_labels, var_group_positions):
var_names, var_group_labels, var_group_positions
"""
if isinstance(var_names, cabc.Mapping):
if isinstance(var_names, Mapping):
if var_group_labels is not None or var_group_positions is not None:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

import collections.abc as cabc
from collections.abc import Mapping
from typing import TYPE_CHECKING, NamedTuple
from warnings import warn

Expand All @@ -17,7 +17,7 @@
from ._utils import check_colornorm, make_grid_spec

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal, Self, Union

import pandas as pd
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def _update_var_groups(self) -> None:
updates var_names, var_group_labels, var_group_positions
"""
if isinstance(self.var_names, cabc.Mapping):
if isinstance(self.var_names, Mapping):
if self.has_var_groups:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
Expand Down
11 changes: 3 additions & 8 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import copy
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -38,7 +37,7 @@
from .scatterplots import _panel_grid, embedding, pca

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import Literal

from anndata import AnnData
Expand Down Expand Up @@ -1611,11 +1610,7 @@ def embedding_density(

# if group is set, then plot it using multiple panels
# (even if only one group is set)
if (
group is not None
and not isinstance(group, str)
and isinstance(group, cabc.Sequence)
):
if group is not None and not isinstance(group, str) and isinstance(group, Sequence):
if ax is not None:
raise ValueError("Can only specify `ax` if no `group` sequence is given.")
fig, gs = _panel_grid(hspace, wspace, ncols, len(group))
Expand Down
17 changes: 7 additions & 10 deletions src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import collections.abc as cabc
import warnings
from collections.abc import Collection, Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING
Expand All @@ -26,7 +26,6 @@
from .._utils import matrix

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Union

from anndata import AnnData
Expand Down Expand Up @@ -506,14 +505,12 @@ def paga(
groups_key = adata.uns["paga"]["groups"]

def is_flat(x):
has_one_per_category = isinstance(x, cabc.Collection) and len(x) == len(
has_one_per_category = isinstance(x, Collection) and len(x) == len(
adata.obs[groups_key].cat.categories
)
return has_one_per_category or x is None or isinstance(x, str)

if isinstance(colors, cabc.Mapping) and isinstance(
colors[next(iter(colors))], cabc.Mapping
):
if isinstance(colors, Mapping) and isinstance(colors[next(iter(colors))], Mapping):
# handle paga pie, remap string keys to integers
names_to_ixs = {
n: i for i, n in enumerate(adata.obs[groups_key].cat.categories)
Expand Down Expand Up @@ -554,7 +551,7 @@ def is_flat(x):
f"it needs to be one of {labels} not {root!r}."
)
root = list(labels).index(root)
if isinstance(root, cabc.Sequence) and root[0] in labels:
if isinstance(root, Sequence) and root[0] in labels:
root = [list(labels).index(r) for r in root]

# define the adjacency matrices
Expand Down Expand Up @@ -600,7 +597,7 @@ def is_flat(x):
sct = _paga_graph(
adata,
axs[icolor],
colors=colors if isinstance(colors, cabc.Mapping) else c,
colors=colors if isinstance(colors, Mapping) else c,
solid_edges=solid_edges,
dashed_edges=dashed_edges,
transitions=transitions,
Expand Down Expand Up @@ -935,7 +932,7 @@ def _paga_graph(
patheffects.withStroke(linewidth=fontoutline, foreground="w")
]
# usual scatter plot
if not isinstance(colors[0], cabc.Mapping):
if not isinstance(colors[0], Mapping):
n_groups = len(pos_array)
sct = ax.scatter(
pos_array[:, 0],
Expand All @@ -959,7 +956,7 @@ def _paga_graph(
# else pie chart plot
else:
for ix, (xx, yy) in enumerate(zip(pos_array[:, 0], pos_array[:, 1])):
if not isinstance(colors[ix], cabc.Mapping):
if not isinstance(colors[ix], Mapping):
raise ValueError(
f"{colors[ix]} is neither a dict of valid "
"matplotlib colors nor a valid matplotlib color."
Expand Down
15 changes: 6 additions & 9 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import collections.abc as cabc
import inspect
import sys
from collections.abc import Mapping, Sequence # noqa: TCH003
Expand Down Expand Up @@ -202,13 +201,13 @@ def embedding(
title = [title] if isinstance(title, str) else list(title)

# turn vmax and vmin into a sequence
if isinstance(vmax, str) or not isinstance(vmax, cabc.Sequence):
if isinstance(vmax, str) or not isinstance(vmax, Sequence):
vmax = [vmax]
if isinstance(vmin, str) or not isinstance(vmin, cabc.Sequence):
if isinstance(vmin, str) or not isinstance(vmin, Sequence):
vmin = [vmin]
if isinstance(vcenter, str) or not isinstance(vcenter, cabc.Sequence):
if isinstance(vcenter, str) or not isinstance(vcenter, Sequence):
vcenter = [vcenter]
if isinstance(norm, Normalize) or not isinstance(norm, cabc.Sequence):
if isinstance(norm, Normalize) or not isinstance(norm, Sequence):
norm = [norm]

# Size
Expand All @@ -219,7 +218,7 @@ def embedding(
# set as ndarray
if (
size is not None
and isinstance(size, (cabc.Sequence, pd.Series, np.ndarray))
and isinstance(size, (Sequence, pd.Series, np.ndarray))
and len(size) == adata.shape[0]
):
size = np.array(size, dtype=float)
Expand All @@ -245,9 +244,7 @@ def embedding(
# Eg. ['Gene1', 'louvain', 'Gene2'].
# component_list is a list of components [[0,1], [1,2]]
if (
not isinstance(color, str)
and isinstance(color, cabc.Sequence)
and len(color) > 1
not isinstance(color, str) and isinstance(color, Sequence) and len(color) > 1
) or len(dimensions) > 1:
if ax is not None:
raise ValueError(
Expand Down
17 changes: 8 additions & 9 deletions src/scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import collections.abc as cabc
import warnings
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Callable, Literal, TypedDict, Union, overload

import matplotlib as mpl
Expand All @@ -23,7 +22,7 @@
from . import palettes

if TYPE_CHECKING:
from collections.abc import Collection, Mapping
from collections.abc import Collection

from anndata import AnnData
from matplotlib.colors import Colormap
Expand Down Expand Up @@ -445,12 +444,12 @@ def _set_colors_for_categorical_obs(
# this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20'
cmap = plt.get_cmap(palette)
colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, len(categories)))]
elif isinstance(palette, cabc.Mapping):
elif isinstance(palette, Mapping):
colors_list = [to_hex(palette[k], keep_alpha=True) for k in categories]
else:
# check if palette is a list and convert it to a cycler, thus
# it doesnt matter if the list is shorter than the categories length:
if isinstance(palette, cabc.Sequence):
if isinstance(palette, Sequence):
if len(palette) < len(categories):
logg.warning(
"Length of palette colors is smaller than the number of "
Expand Down Expand Up @@ -551,7 +550,7 @@ def add_colors_for_categorical_sample_annotation(
def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=None):
import networkx as nx

if not isinstance(axs, cabc.Sequence):
if not isinstance(axs, Sequence):
axs = [axs]

if neighbors_key is None:
Expand All @@ -577,7 +576,7 @@ def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=Non


def plot_arrows(axs, adata, basis, arrows_kwds=None):
if not isinstance(axs, cabc.Sequence):
if not isinstance(axs, Sequence):

Check warning on line 579 in src/scanpy/plotting/_utils.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_utils.py#L579

Added line #L579 was not covered by tests
axs = [axs]
v_prefix = next(
(p for p in ["velocity", "Delta"] if f"{p}_{basis}" in adata.obsm), None
Expand Down Expand Up @@ -724,7 +723,7 @@ def setup_axes(
ax = plt.axes([left, bottom, width, height], projection="3d")
axs.append(ax)
else:
axs = ax if isinstance(ax, cabc.Sequence) else [ax]
axs = ax if isinstance(ax, Sequence) else [ax]

return axs, panel_pos, draw_region_width, figure_width

Expand Down Expand Up @@ -763,7 +762,7 @@ def scatter_base(
Depending on whether supplying a single array or a list of arrays,
return a single axis or a list of axes.
"""
if isinstance(highlights, cabc.Mapping):
if isinstance(highlights, Mapping):
highlights_indices = sorted(highlights)
highlights_labels = [highlights[i] for i in highlights_indices]
else:
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/queries/_queries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Iterable
from functools import singledispatch
from types import MappingProxyType
from typing import TYPE_CHECKING
Expand All @@ -12,7 +12,7 @@
from ..get import rank_genes_groups_df

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from collections.abc import Mapping
from typing import Any

import pandas as pd
Expand Down Expand Up @@ -60,7 +60,7 @@ def simple_query(
"""
if isinstance(attrs, str):
attrs = [attrs]
elif isinstance(attrs, cabc.Iterable):
elif isinstance(attrs, Iterable):

Check warning on line 63 in src/scanpy/queries/_queries.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/queries/_queries.py#L63

Added line #L63 was not covered by tests
attrs = list(attrs)
else:
raise TypeError(f"attrs must be of type list or str, was {type(attrs)}.")
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/tools/_marker_gene_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

import collections.abc as cabc
from collections.abc import Set
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -190,7 +190,7 @@ def marker_gene_overlap(
if normalize is not None and method != "overlap_count":
raise ValueError("Can only normalize with method=`overlap_count`.")

if not all(isinstance(val, cabc.Set) for val in reference_markers.values()):
if not all(isinstance(val, Set) for val in reference_markers.values()):
try:
reference_markers = {
key: set(val) for key, val in reference_markers.items()
Expand Down

0 comments on commit 3cce3f2

Please sign in to comment.