Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rely on Ruff for TYPE_CHECKING block mgmt #3248

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/scanpy/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import collections.abc as cabc
import os
import sys
from argparse import ArgumentParser, Namespace, _SubParsersAction
from collections.abc import MutableMapping
from functools import lru_cache, partial
from pathlib import Path
from shutil import which
Expand All @@ -27,7 +27,7 @@ def __init__(self, *args, _command: str, _runargs: dict[str, Any], **kwargs):
)


class _CommandDelegator(cabc.MutableMapping):
class _CommandDelegator(MutableMapping):
"""\
Provide the ability to delegate,
but don’t calculate the whole list until necessary
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/external/tl/_wishbone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Collection
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -11,7 +11,7 @@
from ..._utils._doctests import doctest_needs

if TYPE_CHECKING:
from collections.abc import Collection, Iterable
from collections.abc import Iterable

from anndata import AnnData

Expand Down Expand Up @@ -115,7 +115,7 @@
f"Start cell {start_cell} not found in data. "
"Please rerun with correct start cell."
)
if isinstance(num_waypoints, cabc.Collection):
if isinstance(num_waypoints, Collection):

Check warning on line 118 in src/scanpy/external/tl/_wishbone.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/external/tl/_wishbone.py#L118

Added line #L118 was not covered by tests
diff = np.setdiff1d(num_waypoints, adata.obs.index)
if diff.size > 0:
logging.warning(
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

import collections.abc as cabc
from collections import OrderedDict
from collections.abc import Collection, Mapping, Sequence
from itertools import product
from typing import TYPE_CHECKING, get_args

Expand Down Expand Up @@ -38,7 +38,7 @@
)

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Iterable
from typing import Literal, Union

from anndata import AnnData
Expand Down Expand Up @@ -220,7 +220,7 @@
isinstance(layers, str) and layers in adata.layers.keys()
):
layers = (layers, layers, layers)
elif isinstance(layers, cabc.Collection) and len(layers) == 3:
elif isinstance(layers, Collection) and len(layers) == 3:

Check warning on line 223 in src/scanpy/plotting/_anndata.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_anndata.py#L223

Added line #L223 was not covered by tests
layers = tuple(layers)
for layer in layers:
if layer not in adata.layers.keys() and layer not in ["X", None]:
Expand Down Expand Up @@ -299,7 +299,7 @@
palette_was_none = False
if palette is None:
palette_was_none = True
if isinstance(palette, cabc.Sequence) and not isinstance(palette, str):
if isinstance(palette, Sequence) and not isinstance(palette, str):
if not is_color_like(palette[0]):
palettes = palette
else:
Expand Down Expand Up @@ -2665,7 +2665,7 @@
var_names, var_group_labels, var_group_positions

"""
if isinstance(var_names, cabc.Mapping):
if isinstance(var_names, Mapping):
if var_group_labels is not None or var_group_positions is not None:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

import collections.abc as cabc
from collections.abc import Mapping
from typing import TYPE_CHECKING, NamedTuple
from warnings import warn

Expand All @@ -17,7 +17,7 @@
from ._utils import check_colornorm, make_grid_spec

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Sequence
from typing import Literal, Self, Union

import pandas as pd
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def _update_var_groups(self) -> None:

updates var_names, var_group_labels, var_group_positions
"""
if isinstance(self.var_names, cabc.Mapping):
if isinstance(self.var_names, Mapping):
if self.has_var_groups:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
Expand Down
11 changes: 3 additions & 8 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from copy import copy
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -38,7 +37,7 @@
from .scatterplots import _panel_grid, embedding, pca

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import Literal

from anndata import AnnData
Expand Down Expand Up @@ -1611,11 +1610,7 @@ def embedding_density(

# if group is set, then plot it using multiple panels
# (even if only one group is set)
if (
group is not None
and not isinstance(group, str)
and isinstance(group, cabc.Sequence)
):
if group is not None and not isinstance(group, str) and isinstance(group, Sequence):
if ax is not None:
raise ValueError("Can only specify `ax` if no `group` sequence is given.")
fig, gs = _panel_grid(hspace, wspace, ncols, len(group))
Expand Down
17 changes: 7 additions & 10 deletions src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import collections.abc as cabc
import warnings
from collections.abc import Collection, Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING
Expand All @@ -26,7 +26,6 @@
from .._utils import matrix

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Union

from anndata import AnnData
Expand Down Expand Up @@ -506,14 +505,12 @@ def paga(
groups_key = adata.uns["paga"]["groups"]

def is_flat(x):
has_one_per_category = isinstance(x, cabc.Collection) and len(x) == len(
has_one_per_category = isinstance(x, Collection) and len(x) == len(
adata.obs[groups_key].cat.categories
)
return has_one_per_category or x is None or isinstance(x, str)

if isinstance(colors, cabc.Mapping) and isinstance(
colors[next(iter(colors))], cabc.Mapping
):
if isinstance(colors, Mapping) and isinstance(colors[next(iter(colors))], Mapping):
# handle paga pie, remap string keys to integers
names_to_ixs = {
n: i for i, n in enumerate(adata.obs[groups_key].cat.categories)
Expand Down Expand Up @@ -554,7 +551,7 @@ def is_flat(x):
f"it needs to be one of {labels} not {root!r}."
)
root = list(labels).index(root)
if isinstance(root, cabc.Sequence) and root[0] in labels:
if isinstance(root, Sequence) and root[0] in labels:
root = [list(labels).index(r) for r in root]

# define the adjacency matrices
Expand Down Expand Up @@ -600,7 +597,7 @@ def is_flat(x):
sct = _paga_graph(
adata,
axs[icolor],
colors=colors if isinstance(colors, cabc.Mapping) else c,
colors=colors if isinstance(colors, Mapping) else c,
solid_edges=solid_edges,
dashed_edges=dashed_edges,
transitions=transitions,
Expand Down Expand Up @@ -935,7 +932,7 @@ def _paga_graph(
patheffects.withStroke(linewidth=fontoutline, foreground="w")
]
# usual scatter plot
if not isinstance(colors[0], cabc.Mapping):
if not isinstance(colors[0], Mapping):
n_groups = len(pos_array)
sct = ax.scatter(
pos_array[:, 0],
Expand All @@ -959,7 +956,7 @@ def _paga_graph(
# else pie chart plot
else:
for ix, (xx, yy) in enumerate(zip(pos_array[:, 0], pos_array[:, 1])):
if not isinstance(colors[ix], cabc.Mapping):
if not isinstance(colors[ix], Mapping):
raise ValueError(
f"{colors[ix]} is neither a dict of valid "
"matplotlib colors nor a valid matplotlib color."
Expand Down
15 changes: 6 additions & 9 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import collections.abc as cabc
import inspect
import sys
from collections.abc import Mapping, Sequence # noqa: TCH003
Expand Down Expand Up @@ -202,13 +201,13 @@ def embedding(
title = [title] if isinstance(title, str) else list(title)

# turn vmax and vmin into a sequence
if isinstance(vmax, str) or not isinstance(vmax, cabc.Sequence):
if isinstance(vmax, str) or not isinstance(vmax, Sequence):
vmax = [vmax]
if isinstance(vmin, str) or not isinstance(vmin, cabc.Sequence):
if isinstance(vmin, str) or not isinstance(vmin, Sequence):
vmin = [vmin]
if isinstance(vcenter, str) or not isinstance(vcenter, cabc.Sequence):
if isinstance(vcenter, str) or not isinstance(vcenter, Sequence):
vcenter = [vcenter]
if isinstance(norm, Normalize) or not isinstance(norm, cabc.Sequence):
if isinstance(norm, Normalize) or not isinstance(norm, Sequence):
norm = [norm]

# Size
Expand All @@ -219,7 +218,7 @@ def embedding(
# set as ndarray
if (
size is not None
and isinstance(size, (cabc.Sequence, pd.Series, np.ndarray))
and isinstance(size, (Sequence, pd.Series, np.ndarray))
and len(size) == adata.shape[0]
):
size = np.array(size, dtype=float)
Expand All @@ -245,9 +244,7 @@ def embedding(
# Eg. ['Gene1', 'louvain', 'Gene2'].
# component_list is a list of components [[0,1], [1,2]]
if (
not isinstance(color, str)
and isinstance(color, cabc.Sequence)
and len(color) > 1
not isinstance(color, str) and isinstance(color, Sequence) and len(color) > 1
) or len(dimensions) > 1:
if ax is not None:
raise ValueError(
Expand Down
17 changes: 8 additions & 9 deletions src/scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import collections.abc as cabc
import warnings
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Callable, Literal, TypedDict, Union, overload

import matplotlib as mpl
Expand All @@ -23,7 +22,7 @@
from . import palettes

if TYPE_CHECKING:
from collections.abc import Collection, Mapping
from collections.abc import Collection

from anndata import AnnData
from matplotlib.colors import Colormap
Expand Down Expand Up @@ -445,12 +444,12 @@
# this creates a palette from a colormap. E.g. 'Accent, Dark2, tab20'
cmap = plt.get_cmap(palette)
colors_list = [to_hex(x) for x in cmap(np.linspace(0, 1, len(categories)))]
elif isinstance(palette, cabc.Mapping):
elif isinstance(palette, Mapping):
colors_list = [to_hex(palette[k], keep_alpha=True) for k in categories]
else:
# check if palette is a list and convert it to a cycler, thus
# it doesnt matter if the list is shorter than the categories length:
if isinstance(palette, cabc.Sequence):
if isinstance(palette, Sequence):
if len(palette) < len(categories):
logg.warning(
"Length of palette colors is smaller than the number of "
Expand Down Expand Up @@ -551,7 +550,7 @@
def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=None):
import networkx as nx

if not isinstance(axs, cabc.Sequence):
if not isinstance(axs, Sequence):
axs = [axs]

if neighbors_key is None:
Expand All @@ -577,7 +576,7 @@


def plot_arrows(axs, adata, basis, arrows_kwds=None):
if not isinstance(axs, cabc.Sequence):
if not isinstance(axs, Sequence):

Check warning on line 579 in src/scanpy/plotting/_utils.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_utils.py#L579

Added line #L579 was not covered by tests
axs = [axs]
v_prefix = next(
(p for p in ["velocity", "Delta"] if f"{p}_{basis}" in adata.obsm), None
Expand Down Expand Up @@ -724,7 +723,7 @@
ax = plt.axes([left, bottom, width, height], projection="3d")
axs.append(ax)
else:
axs = ax if isinstance(ax, cabc.Sequence) else [ax]
axs = ax if isinstance(ax, Sequence) else [ax]

return axs, panel_pos, draw_region_width, figure_width

Expand Down Expand Up @@ -763,7 +762,7 @@
Depending on whether supplying a single array or a list of arrays,
return a single axis or a list of axes.
"""
if isinstance(highlights, cabc.Mapping):
if isinstance(highlights, Mapping):
highlights_indices = sorted(highlights)
highlights_labels = [highlights[i] for i in highlights_indices]
else:
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/queries/_queries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import collections.abc as cabc
from collections.abc import Iterable
from functools import singledispatch
from types import MappingProxyType
from typing import TYPE_CHECKING
Expand All @@ -12,7 +12,7 @@
from ..get import rank_genes_groups_df

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from collections.abc import Mapping
from typing import Any

import pandas as pd
Expand Down Expand Up @@ -60,7 +60,7 @@
"""
if isinstance(attrs, str):
attrs = [attrs]
elif isinstance(attrs, cabc.Iterable):
elif isinstance(attrs, Iterable):

Check warning on line 63 in src/scanpy/queries/_queries.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/queries/_queries.py#L63

Added line #L63 was not covered by tests
attrs = list(attrs)
else:
raise TypeError(f"attrs must be of type list or str, was {type(attrs)}.")
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/tools/_marker_gene_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

import collections.abc as cabc
from collections.abc import Set
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -190,7 +190,7 @@ def marker_gene_overlap(
if normalize is not None and method != "overlap_count":
raise ValueError("Can only normalize with method=`overlap_count`.")

if not all(isinstance(val, cabc.Set) for val in reference_markers.values()):
if not all(isinstance(val, Set) for val in reference_markers.values()):
try:
reference_markers = {
key: set(val) for key, val in reference_markers.items()
Expand Down
Loading