From 8182f6ce0027d088a03cf04924765567a9ec77fc Mon Sep 17 00:00:00 2001 From: SarahOuologuem Date: Mon, 11 Mar 2024 09:21:36 +0000 Subject: [PATCH 1/4] force color to be a string --- muon/_core/plot.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/muon/_core/plot.py b/muon/_core/plot.py index da8cab3..7d72f3e 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -22,7 +22,7 @@ def scatter( data: Union[AnnData, MuData], x: Optional[str] = None, y: Optional[str] = None, - color: Optional[Union[str, Sequence[str]]] = None, + color: Optional[str] = None, use_raw: Optional[bool] = None, layers: Optional[Union[str, Sequence[str]]] = None, **kwargs, @@ -42,8 +42,8 @@ def scatter( x coordinate y : Optional[str] y coordinate - color : Optional[Union[str, Sequence[str]]], optional (default: None) - Keys for variables or annotations of observations (.obs columns), + color : Optional[str], optional (default: None) + Key for variables or annotations of observations (.obs columns), or a hex colour specification. use_raw : Optional[bool], optional (default: None) Use `.raw` attribute of the modality where a feature (from `color`) is derived from. @@ -72,10 +72,10 @@ def scatter( if isinstance(color, str): color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs = pd.DataFrame({color: color_obs}) - color = [color] else: - # scanpy#311 / scanpy#1497 has to be fixed for this to work - color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) + raise TypeError("Expected color to be a string.") + + color_obs.index = data.obs_names obs = pd.concat([obs, color_obs], axis=1, ignore_index=False) @@ -86,14 +86,14 @@ def scatter( # and are now stored in .obs retval = sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs) if color is not None: - for col in color: - try: - data.uns[f"{col}_colors"] = ad.uns[f"{col}_colors"] - except KeyError: - pass + try: + data.uns[f"{color}_colors"] = ad.uns[f"{color}_colors"] + except KeyError: + pass return retval + # # Embedding # From 52f9e1d8d53cb1cc0d01edb66f45fb7b38ecb18b Mon Sep 17 00:00:00 2001 From: Danila Date: Wed, 16 Oct 2024 16:36:56 -0700 Subject: [PATCH 2/4] Use sc.pl.scatter in mu.pl.scatter --- muon/_core/plot.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/muon/_core/plot.py b/muon/_core/plot.py index 7d72f3e..752518c 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -53,9 +53,7 @@ def scatter( No layer is used by default. A single layer value will be expanded to [layer, layer, layer]. """ if isinstance(data, AnnData): - return sc.pl.embedding( - data, x=x, y=y, color=color, use_raw=use_raw, layers=layers, **kwargs - ) + return sc.pl.scatter(data, x=x, y=y, color=color, use_raw=use_raw, layers=layers, **kwargs) if isinstance(layers, str) or layers is None: layers = [layers, layers, layers] @@ -74,7 +72,6 @@ def scatter( color_obs = pd.DataFrame({color: color_obs}) else: raise TypeError("Expected color to be a string.") - color_obs.index = data.obs_names obs = pd.concat([obs, color_obs], axis=1, ignore_index=False) @@ -93,7 +90,6 @@ def scatter( return retval - # # Embedding # From 34c9de2b8af15bf735613566abf758f2e240d669 Mon Sep 17 00:00:00 2001 From: Danila Date: Wed, 16 Oct 2024 16:51:36 -0700 Subject: [PATCH 3/4] Add a simple test for mu.pl.scatter() --- tests/test_muon_plot.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/test_muon_plot.py diff --git a/tests/test_muon_plot.py b/tests/test_muon_plot.py new file mode 100644 index 0000000..d1642cf --- /dev/null +++ b/tests/test_muon_plot.py @@ -0,0 +1,31 @@ +import pytest + +import numpy as np +from scipy import sparse +import pandas as pd +from anndata import AnnData +import muon as mu +from muon import MuData +import matplotlib + +matplotlib.use("Agg") + + +@pytest.fixture() +def mdata(): + mdata = MuData( + { + "mod1": AnnData(np.arange(0, 100, 0.1).reshape(-1, 10)), + "mod2": AnnData(np.arange(101, 2101, 1).reshape(-1, 20)), + } + ) + mdata.var_names_make_unique() + yield mdata + + +class TestScatter: + def test_pl_scatter(self, mdata): + mdata = mdata.copy() + np.random.seed(42) + mdata.obs["condition"] = np.random.choice(["a", "b"], mdata.n_obs) + mu.pl.scatter(mdata, x="mod1:0", y="mod2:0", color="condition") From 9b98f5192458605d0bd2d86e3a7a029c0bbbf484 Mon Sep 17 00:00:00 2001 From: Danila Date: Wed, 16 Oct 2024 17:10:28 -0700 Subject: [PATCH 4/4] Propagate error to scanpy accoding to its promised functionality --- muon/_core/plot.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/muon/_core/plot.py b/muon/_core/plot.py index 752518c..52be000 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -1,4 +1,4 @@ -from typing import Union, List, Optional, Iterable, Sequence, Dict +from typing import Dict, Iterable, List, Optional, Sequence, Union import warnings from matplotlib.axes import Axes @@ -22,7 +22,7 @@ def scatter( data: Union[AnnData, MuData], x: Optional[str] = None, y: Optional[str] = None, - color: Optional[str] = None, + color: Optional[Union[str, Sequence[str]]] = None, use_raw: Optional[bool] = None, layers: Optional[Union[str, Sequence[str]]] = None, **kwargs, @@ -42,8 +42,8 @@ def scatter( x coordinate y : Optional[str] y coordinate - color : Optional[str], optional (default: None) - Key for variables or annotations of observations (.obs columns), + color : Optional[Union[str, Sequence[str]]], optional (default: None) + Keys or a single key for variables or annotations of observations (.obs columns), or a hex colour specification. use_raw : Optional[bool], optional (default: None) Use `.raw` attribute of the modality where a feature (from `color`) is derived from. @@ -71,7 +71,7 @@ def scatter( color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs = pd.DataFrame({color: color_obs}) else: - raise TypeError("Expected color to be a string.") + color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs.index = data.obs_names obs = pd.concat([obs, color_obs], axis=1, ignore_index=False)