Skip to content

Commit

Permalink
Don't assert for index/column names when comparing series (e.g. in `s…
Browse files Browse the repository at this point in the history
…et_x`, `set_y`, and `set_x`) (#669)

* don't assert for index/column names

* create and use util func

* use the utils everywhere
  • Loading branch information
selmanozleyen authored Apr 2, 2024
1 parent 25c6a0e commit 3151da7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
11 changes: 11 additions & 0 deletions src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ def _validate_args_cell_transition(
raise TypeError(f"Expected argument to be either `str` or `dict`, found `{type(arg)}`.")


def _assert_series_match(a: pd.Series, b: pd.Series) -> None:
"""Assert that two series are equal ignoring the names."""
pd.testing.assert_series_equal(a, b, check_names=False)


def _assert_columns_and_index_match(a: pd.Series, b: pd.DataFrame) -> None:
    """Check that both the index and the columns of ``b`` agree with ``a``.

    Each axis of ``b`` is converted to a series and compared against ``a``
    through ``_assert_series_match``; the index is verified first, then the
    columns.
    """
    for axis_labels in (b.index, b.columns):
        _assert_series_match(a, axis_labels.to_series())


def _get_cell_indices(
adata: AnnData,
key: Optional[str] = None,
Expand Down
33 changes: 15 additions & 18 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from moscot.base.output import BaseSolverOutput, MatrixSolverOutput
from moscot.base.problems._utils import (
TimeScalesHeatKernel,
_assert_columns_and_index_match,
_assert_series_match,
require_solution,
wrap_prepare,
wrap_solve,
Expand Down Expand Up @@ -530,8 +532,8 @@ def set_solution(
raise ValueError(f"`{self}` already contains a solution, use `overwrite=True` to overwrite it.")

if isinstance(solution, pd.DataFrame):
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), solution.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), solution.columns.to_series())
_assert_series_match(self.adata_src.obs_names.to_series(), solution.index.to_series())
_assert_series_match(self.adata_tgt.obs_names.to_series(), solution.columns.to_series())
solution = solution.to_numpy()
if not isinstance(solution, BaseSolverOutput):
solution = MatrixSolverOutput(solution, **kwargs)
Expand Down Expand Up @@ -729,13 +731,12 @@ def set_graph_xy(
"""
expected_series = pd.concat([self.adata_src.obs_names.to_series(), self.adata_tgt.obs_names.to_series()])
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src, index_tgt = data
pd.testing.assert_series_equal(expected_series, index_src)
pd.testing.assert_series_equal(expected_series, index_tgt)
_assert_series_match(expected_series, index_src)
_assert_series_match(expected_series, index_tgt)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`, `pd.Series`), "
Expand Down Expand Up @@ -780,12 +781,11 @@ def set_graph_x(
"""
expected_series = self.adata_src.obs_names.to_series()
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src = data
pd.testing.assert_series_equal(expected_series, index_src)
_assert_series_match(expected_series, index_src)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), "
Expand Down Expand Up @@ -830,12 +830,11 @@ def set_graph_y(
"""
expected_series = self.adata_tgt.obs_names.to_series()
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src = data
pd.testing.assert_series_equal(expected_series, index_src)
_assert_series_match(expected_series, index_src)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), "
Expand Down Expand Up @@ -869,8 +868,8 @@ def set_xy(
- :attr:`xy` - the :term:`linear term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series())
_assert_series_match(self.adata_src.obs_names.to_series(), data.index.to_series())
_assert_series_match(self.adata_tgt.obs_names.to_series(), data.columns.to_series())

self._xy = TaggedArray(data_src=data.to_numpy(), data_tgt=None, tag=Tag(tag), cost="cost")
self._stage = "prepared"
Expand All @@ -893,8 +892,7 @@ def set_x(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No
- :attr:`x` - the source :term:`quadratic term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.columns.to_series())
_assert_columns_and_index_match(self.adata_src.obs_names.to_series(), data)

if self.problem_kind == "linear":
logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.")
Expand All @@ -920,8 +918,7 @@ def set_y(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No
- :attr:`y` - the target :term:`quadratic term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series())
_assert_columns_and_index_match(self.adata_tgt.obs_names.to_series(), data)

if self.problem_kind == "linear":
logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.")
Expand Down

0 comments on commit 3151da7

Please sign in to comment.