Skip to content

Commit

Permalink
adapt distributioncontainer
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Nov 3, 2023
1 parent 3bd674f commit 61e3a01
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/_jax_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _sample_policy_pair(key: jax.random.KeyArray) -> Tuple[Tuple[Any, Any]]:
"""Sample a policy pair."""
index = jax.random.randint(key, shape=[], minval=0, maxval=len(self.policy_pairs))
return self.policy_pairs[index]

self._sample_source = _sample_source if self.conditions is None else _sample_source_conditional
self._sample_target = _sample_target
self.sample_policy_pair = _sample_policy_pair
Expand Down
1 change: 0 additions & 1 deletion src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from functools import partial
from typing import (
Any,
Expand Down
4 changes: 1 addition & 3 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,4 @@ def evaluate_b(self, cond: ArrayLike, x: ArrayLike) -> ArrayLike:
if cond.n_dim != 2:
cond = cond[:, None]
input = jnp.concatenate((x, cond), axis=-1)
return self._model.state_xi.apply_fn(
{"params": self._model.state_xi.params}, input
) # type:ignore[union-attr]
return self._model.state_xi.apply_fn({"params": self._model.state_xi.params}, input) # type:ignore[union-attr]
6 changes: 1 addition & 5 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@
)
from moscot.backends.ott._jax_data import JaxSampler
from moscot.backends.ott._neuraldual import OTTNeuralDualSolver
from moscot.backends.ott._utils import (
alpha_to_fused_penalty,
check_shapes,
ensure_2d,
)
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
from moscot.backends.ott.output import CondNeuralDualOutput, NeuralDualOutput, OTTOutput
from moscot.base.solver import OTSolver
from moscot.costs import get_cost
Expand Down
3 changes: 1 addition & 2 deletions src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import enum
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Literal, Optional, Tuple, TypeVar, Union

Expand Down Expand Up @@ -337,7 +336,7 @@ def from_adata(
return cls(xy=xy_data, xx=xx_data, a=a, b=b, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn)


class DistributionCollection(OrderedDict[K, DistributionContainer]):
class DistributionCollection(Dict[K, DistributionContainer]):
"""Collection of distributions."""

def __repr__(self) -> str:
Expand Down

0 comments on commit 61e3a01

Please sign in to comment.