Skip to content

Commit

Permalink
added design matrix and plotting function (scverse#591)
Browse files Browse the repository at this point in the history
Co-authored-by: lehnerl <[email protected]>
Co-authored-by: giovp <[email protected]>
Co-authored-by: Giovanni Palla <[email protected]>
Co-authored-by: Zwaveligerspeeler <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
6 people authored Apr 6, 2023
1 parent 0cd835d commit b644428
Show file tree
Hide file tree
Showing 20 changed files with 676 additions and 15 deletions.
9 changes: 0 additions & 9 deletions .github/workflows/news.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@ jobs:
python -m pip install --upgrade pip
pip install tox
- name: Update dev release notes
env:
PR_NUMBER: ${{ github.event.number }}
run: |
tox -e news -- "$PR_NUMBER" --add-author -v
tox -e update-dev-notes
git add "docs/source/release/changelog/$PR_NUMBER*.rst"
git status -s
- name: Check generated docs
run: |
tox -e check-docs
Expand Down
12 changes: 12 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Plotting
pl.ripley
pl.co_occurrence
pl.extract
pl.var_by_distance

Reading
~~~~~~~
Expand All @@ -69,6 +70,17 @@ Reading
read.vizgen
read.nanostring

Tools
~~~~~~~~

.. module:: squidpy.tl
.. currentmodule:: squidpy

.. autosummary::
:toctree: api

tl.var_by_distance

Datasets
~~~~~~~~

Expand Down
13 changes: 13 additions & 0 deletions docs/source/release/notes-dev.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@ Miscellaneous
or :attr:`anndata.AnnData.obsm.`
`@michalk8 <https://github.com/michalk8>`__
`#672 <https://github.com/scverse/squidpy/pull/672>`__


Squidpy dev (2023-04-02)
========================

Features
--------

- Add :func:`squidpy.tl.var_by_distance` to calculate distances to user-defined anchor points
and stores the resulting design matrix in :attr:`adata.obsm`.
- Add :func:`squidpy.pl.var_by_distance` to visualize a variable such as expression by distance to an anchor points.
`@LLehner <https://github.com/LLehner>`__
`#591 <https://github.com/scverse/squidpy/pull/591>`__
4 changes: 4 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ cond
conda
connectivities
convolutional
covariate
covariates
CP
Crofton
csv
Expand Down Expand Up @@ -113,6 +115,7 @@ seqFISH
seqV
Sfrp
spaceranger
squidpy
Squidpy
StarDist
stromal
Expand All @@ -121,6 +124,7 @@ Tangram
Tensorflow
th
thresholded
tl
tori
transcriptomics
uncommenting
Expand Down
2 changes: 1 addition & 1 deletion squidpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from packaging.version import parse

from squidpy import datasets, gr, im, pl, read
from squidpy import datasets, gr, im, pl, read, tl

__author__ = __maintainer__ = "Theislab"
__email__ = ", ".join(
Expand Down
1 change: 1 addition & 0 deletions squidpy/gr/_sepal.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def _score_helper(

score, sparse = [], issparse(vals)
for i in ixs:
conc = vals[:, i].A.flatten() if sparse else vals[:, i].copy() # type: ignore[union-attr]
conc = vals[:, i].A.flatten() if sparse else vals[:, i].copy() # type: ignore[union-attr]
time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh)
score.append(dt * time_iter)
Expand Down
1 change: 1 addition & 0 deletions squidpy/pl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from squidpy.pl._ligrec import ligrec
from squidpy.pl._spatial import spatial_scatter, spatial_segment
from squidpy.pl._utils import extract
from squidpy.pl._var_by_distance import var_by_distance
10 changes: 8 additions & 2 deletions squidpy/pl/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,16 @@
__all__ = ["centrality_scores", "interaction_matrix", "nhood_enrichment", "ripley", "co_occurrence"]


def _get_data(adata: AnnData, cluster_key: str, func_name: str, **kwargs: Any) -> Any:
def _get_data(adata: AnnData, cluster_key: str, func_name: str, attr: str = "uns", **kwargs: Any) -> Any:
key = getattr(Key.uns, func_name)(cluster_key, **kwargs)

try:
return adata.uns[key]
if attr == "uns":
return adata.uns[key]
elif attr == "obsm":
return adata.obsm[key]
else:
raise ValueError(f"attr must be either 'uns' or 'obsm', got {attr}")
except KeyError:
raise KeyError(
f"Unable to get the data from `adata.uns[{key!r}]`. "
Expand Down
180 changes: 180 additions & 0 deletions squidpy/pl/_var_by_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

from pathlib import Path
from types import MappingProxyType
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from anndata import AnnData
from cycler import Cycler
from matplotlib import rcParams
from matplotlib.axes import Axes
from scanpy.plotting._tools.scatterplots import _panel_grid
from scanpy.plotting._utils import _set_default_colors_for_categorical_obs
from scipy.sparse import issparse

from squidpy._docs import d
from squidpy.pl._utils import save_fig

__all__ = ["var_by_distance"]


@d.dedent
def var_by_distance(
adata: AnnData,
var: str | List[str],
anchor_key: str | List[str],
design_matrix_key: str = "design_matrix",
color: str | None = None,
covariate: str | None = None,
order: int = 5,
show_scatter: bool = True,
line_palette: Union[str, Sequence[str], Cycler, None] = None,
scatter_palette: Union[str, Sequence[str], Cycler, None] = "viridis",
dpi: int | None = None,
figsize: Tuple[int, int] | None = None,
save: str | Path | None = None,
title: str | None = None,
axis_label: str | None = None,
return_ax: Optional[bool] = None,
regplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
scatterplot_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Union[Axes, None]:
"""
Plot a variable using a smooth regression line with increasing distance to an anchor point.
Parameters
----------
%(adata)s
design_matrix_key
Name of the design matrix, previously computed with :func:`squidpy.tl.var_by_distance`, to use.
var
Variables to plot on y-axis.
anchor_key
Anchor point column from which distances are taken.
color
Variables to plot on color palette.
covariate
A covariate for which separate regression lines are plotted for each category.
order
Order of the polynomial fit for :func:`seaborn.regplot`.
show_scatter
Whether to show a scatter plot underlying the regression line.
line_palette
Categorical color palette used in case a covariate is specified.
scatter_palette
Color palette for the scatter plot underlying the `sns.regplot`
%(plotting_save)s
title
Panel titles.
axis_label
Panel axis labels.
regplot_kwargs
Kwargs for `sns.regplot`
scatterplot_kwargs
Kwargs for `sns.scatter`
Returns
-------
%(plotting_returns)s
"""
dpi = rcParams["figure.dpi"] if dpi is None else dpi
regplot_kwargs = dict(regplot_kwargs)
scatterplot_kwargs = dict(regplot_kwargs)

df = adata.obsm[design_matrix_key] # get design matrix
df[var] = np.array(adata[:, var].X.A) if issparse(adata[:, var].X) else np.array(adata[:, var].X) # add var column

# if several variables are plotted, make a panel grid
if isinstance(var, List):
fig, grid = _panel_grid(
hspace=0.25, wspace=0.75 / rcParams["figure.figsize"][0] + 0.02, ncols=4, num_panels=len(var)
)
axs = []
else:
var = [var]

# iterate over the variables to plot
for i, v in enumerate(var):
if len(var) > 1:
ax = plt.subplot(grid[i])
axs.append(ax)
else:
# if a single variable and no grid, then one ax object suffices
fig, ax = plt.subplots(1, 1, figsize=figsize)

# if no covariate is specified, 'sns.regplot' will take the values of all observations
if covariate is None:
sns.regplot(
data=df,
x=anchor_key,
y=v,
order=order,
color=line_palette,
scatter=show_scatter,
ax=ax,
line_kws=regplot_kwargs,
)
else:
# make a categorical color palette if none was specified and there are several regplots to be plotted
if isinstance(line_palette, str) or line_palette is None:
_set_default_colors_for_categorical_obs(adata, covariate)
line_palette = adata.uns[covariate + "_colors"]
covariate_instances = df[covariate].unique()

# iterate over all covariate values and make 'sns.regplot' for each
for i, co in enumerate(covariate_instances):
sns.regplot(
data=df.loc[df[covariate] == co],
x=anchor_key,
y=v,
order=order,
color=line_palette[i],
scatter=show_scatter,
ax=ax,
label=co,
line_kws=regplot_kwargs,
)
label_colors, _ = ax.get_legend_handles_labels()
ax.legend(label_colors, covariate_instances)
# add scatter plot if specified
if show_scatter:
if color is None:
plt.scatter(data=df, x=anchor_key, y=v, color="grey", **scatterplot_kwargs)
# if variable to plot on color palette is categorical, make categorical color palette
elif df[color].dtype.name == "category":
unique_colors = df[color].unique()
cNorm = colors.Normalize(vmin=0, vmax=len(unique_colors))
scalarMap = cm.ScalarMappable(norm=cNorm, cmap=scatter_palette)
for i in range(len(unique_colors)):
plt.scatter(
data=df.loc[df[color] == unique_colors[i]],
x=anchor_key,
y=v,
color=scalarMap.to_rgba(i),
**scatterplot_kwargs,
)
# if variable to plot on color palette is not categorical
else:
plt.scatter(data=df, x=anchor_key, y=v, c=color, cmap=scatter_palette, **scatterplot_kwargs)
if title is not None:
ax.set(title=title)
if axis_label is None:
ax.set(xlabel=f"distance to {anchor_key}")
else:
ax.set(xlabel=axis_label)

# remove line palette if it was made earlier in function
if f"{covariate}_colors" in adata.uns:
del line_palette

axs = axs if len(var) > 1 else ax

if save is not None:
save_fig(fig, path=save, transparent=False, dpi=dpi)
if return_ax:
return axs
2 changes: 2 additions & 0 deletions squidpy/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""The design matrix module."""
from squidpy.tl._var_by_distance import var_by_distance
Loading

0 comments on commit b644428

Please sign in to comment.