Skip to content

Commit

Permalink
adapt tests and valid loader conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Nov 5, 2023
1 parent 295eb6e commit 08fcec2
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def evaluate_a(self, cond: ArrayLike, x: ArrayLike) -> ArrayLike:
"""Conditional marginals of the source distribution."""
if self._model.mlp_xi is None:
raise ValueError("The source marginals have not been traced.")
if cond.n_dim != 2:
if cond.ndim != 2:
cond = cond[:, None]
input = jnp.concatenate((x, cond), axis=-1)
return self._model.state_eta.apply_fn(
Expand All @@ -815,7 +815,7 @@ def evaluate_b(self, cond: ArrayLike, x: ArrayLike) -> ArrayLike:
"""Conditional marginals of the target distribution."""
if self._model.mlp_eta is None:
raise ValueError("The target marginals have not been traced.")
if cond.n_dim != 2:
if cond.ndim != 2:
cond = cond[:, None]
input = jnp.concatenate((x, cond), axis=-1)
return self._model.state_xi.apply_fn({"params": self._model.state_xi.params}, input) # type:ignore[union-attr]
3 changes: 2 additions & 1 deletion src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _prepare( # type: ignore[override]
train_conditions = [d.conditions for d in distributions.values()]
train_a = [d.a for d in distributions.values()]
train_b = [d.b for d in distributions.values()]
valid_data, valid_a, valid_b = train_data, train_a, train_b
valid_data, valid_conditions, valid_a, valid_b = train_data, train_conditions, train_a, train_b
else:
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")
Expand Down Expand Up @@ -579,6 +579,7 @@ def _prepare( # type: ignore[override]
batch_size=batch_size,
**kwargs,
)

return (self._train_sampler, self._valid_sampler)

def _solve(self, data_samplers: Tuple[JaxSampler, JaxSampler]) -> CondNeuralDualOutput: # type: ignore[override]
Expand Down
14 changes: 0 additions & 14 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,20 +1037,6 @@ def solve(
input_dim = self.distributions[tmp].xy.shape[1]
cond_dim = self.distributions[tmp].conditions.shape[1]

# self._solver = backends.get_solver(
# problem_kind=self._problem_kind,
# input_dim=self.distributions[tmp].xy.shape[1], # type:ignore[union-attr, index]
# cond_dim=self.distributions[tmp].conditions.shape[1], # type:ignore[union-attr, index
# sample_pairs=self._sample_pairs,
# **kwargs,
# )

# self._solution = self._solver( # type: ignore[misc]
# distributions=self.distributions, # type: ignore[arg-type] #TODO: handle better
# sample_pairs=self._sample_pairs,
# device=device,
# )

solver_class = backends.get_solver(
self.problem_kind, solver_name=solver_name, backend=backend, return_class=True
)
Expand Down
4 changes: 1 addition & 3 deletions src/moscot/base/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,7 @@ def __call__(
-------
The optimal transport solution.
"""
if is_conditional:
kwargs["distributions"] = xy
else:
if not is_conditional:
data = self._get_array_data(xy=xy, x=x, y=y, tags=tags)
kwargs = {**kwargs, **self._untag(data)}
res = super().__call__(**kwargs)
Expand Down
14 changes: 7 additions & 7 deletions tests/problems/generic/test_conditional_neural_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def test_learning_rescaling_factors(self, adata_time: ad.AnnData):
assert isinstance(problem.solution, BaseSolverOutput)
assert isinstance(problem.solution, NeuralDualOutput)

array = adata_time.obsm["X_pca"]
cond1 = np.array(jnp.ones_like(array))
cond2 = np.array(jnp.zeros_like(array))
learnt_eta_1 = problem.solution.evaluate_a(array, cond1)
learnt_xi_1 = problem.solution.evaluate_b(array, cond1)
learnt_eta_2 = problem.solution.evaluate_a(array, cond2)
learnt_xi_2 = problem.solution.evaluate_b(array, cond2)
array = np.asarray(adata_time.obsm["X_pca"].copy())
cond1 = jnp.ones((array.shape[0],))
cond2 = jnp.zeros((array.shape[0],))
learnt_eta_1 = problem.solution.evaluate_a(cond1, array)
learnt_xi_1 = problem.solution.evaluate_b(cond1, array)
learnt_eta_2 = problem.solution.evaluate_a(cond2, array)
learnt_xi_2 = problem.solution.evaluate_b(cond2, array)
assert learnt_eta_1.shape == (array.shape[0], 1)
assert learnt_xi_1.shape == (array.shape[0], 1)
assert learnt_eta_2.shape == (array.shape[0], 1)
Expand Down

0 comments on commit 08fcec2

Please sign in to comment.