Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
ArinaDanilina authored Feb 20, 2024
2 parents b3e37f2 + c35430e commit 441339e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 14 deletions.
10 changes: 5 additions & 5 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_create_col_colors,
_heatmap,
_input_to_adatas,
_plot_temporal,
_plot_scatter,
_sankey,
get_plotting_vars,
)
Expand Down Expand Up @@ -292,9 +292,9 @@ def push(
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])

fig = _plot_temporal(
fig = _plot_scatter(
adata=adata,
temporal_key=data["temporal_key"],
generic_key=data["key"],
key_stored=key,
source=data["source"],
target=data["target"],
Expand Down Expand Up @@ -400,9 +400,9 @@ def pull(
if data["data"] is not None and data["subset"] is not None and cmap is None:
cmap = _create_col_colors(adata, data["data"], data["subset"])

fig = _plot_temporal(
fig = _plot_scatter(
adata=adata,
temporal_key=data["temporal_key"],
generic_key=data["key"],
key_stored=key,
source=data["source"],
target=data["target"],
Expand Down
9 changes: 5 additions & 4 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ def _input_to_adatas(
raise ValueError(f"Unable to interpret input of type `{type(inp)}`.")


def _plot_temporal(
def _plot_scatter(
adata: AnnData,
temporal_key: str,
generic_key: str,
key_stored: str,
source: float,
target: float,
Expand Down Expand Up @@ -430,7 +430,8 @@ def _plot_temporal(
titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))])
else:
titles = [
f"{categories if categories is not None else 'Cells'} at time {source if push else target} and {name}"
f"{'Push' if push else 'Pull'} {categories if categories is not None else 'cells'} "
+ f"from {source if push else target} to {target if push else source}"
]
for i, ax in enumerate(axs):
# we need to create adata_view because otherwise the view of the adata is copied in the next step i+1
Expand All @@ -446,7 +447,7 @@ def _plot_temporal(
adata_view = adata
else:
tmp = np.full(len(adata), constant_fill_value)
mask = adata.obs[temporal_key] == time_points[i]
mask = adata.obs[generic_key] == time_points[i]

tmp[mask] = adata[mask].obs[key_stored]
if scale:
Expand Down
10 changes: 9 additions & 1 deletion src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ def push(
if TYPE_CHECKING:
assert isinstance(key_added, str)
plot_vars = {
"distribution_key": self.batch_key,
"source": source,
"target": target,
"key": self.batch_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
set_plotting_vars(self.adata, _constants.PUSH, key=key_added, value=plot_vars)
Expand Down Expand Up @@ -232,6 +236,10 @@ def pull(
if key_added is not None:
plot_vars = {
"key": self.batch_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
"source": source,
"target": target,
}
self.adata.obs[key_added] = self._flatten(result, key=self.batch_key)
set_plotting_vars(self.adata, _constants.PULL, key=key_added, value=plot_vars)
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def push(
plot_vars = {
"source": source,
"target": target,
"temporal_key": self.temporal_key,
"key": self.temporal_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
}
Expand Down Expand Up @@ -526,7 +526,7 @@ def pull(

if key_added is not None:
plot_vars = {
"temporal_key": self.temporal_key,
"key": self.temporal_key,
"data": data if isinstance(data, str) else None,
"subset": subset,
"source": source,
Expand Down
4 changes: 2 additions & 2 deletions tests/plotting/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
@pytest.fixture()
def adata_pl_push(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
set_plotting_vars(adata_time, _constants.PUSH, key=_constants.PUSH, value=plot_vars)
Expand All @@ -60,7 +60,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData:
@pytest.fixture()
def adata_pl_pull(adata_time: AnnData) -> AnnData:
rng = np.random.RandomState(0)
plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
set_plotting_vars(adata_time, _constants.PULL, key=_constants.PULL, value=plot_vars)
Expand Down

0 comments on commit 441339e

Please sign in to comment.