Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 10, 2023
1 parent 4152baf commit 06e55c0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 26 deletions.
7 changes: 5 additions & 2 deletions src/moscot/backends/ott/_jax_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
raise ValueError("If `policy_pairs` contains more than 1 value, `sample_to_idx` is required.")
sample_to_idx = {self.policy_pairs[0][0]: 0, self.policy_pairs[0][1]: 1}
self._sample_to_idx = sample_to_idx

@partial(jax.jit, static_argnames=["index"])
def _sample_source(key: jax.random.KeyArray, index: jnp.ndarray) -> Tuple[jnp.ndarray, None]:
"""Jitted sample function."""
Expand All @@ -57,7 +58,7 @@ def _sample_target(key: jax.random.KeyArray, index: jnp.ndarray) -> Tuple[jnp.nd
"""Jitted sample function."""
samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(b[index]))
return samples, None

@partial(jax.jit, static_argnames=["index"])
def _sample_target_conditional(key: jax.random.KeyArray, index: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Jitted sample function."""
Expand Down Expand Up @@ -112,7 +113,9 @@ def __call__(
self.distributions[self.sample_to_idx[policy_pair[0]]]
), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
if sample == "target":
return jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[1]]]), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
return jnp.asarray(
self.distributions[self.sample_to_idx[policy_pair[1]]]
), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
if sample == "both":
return (
jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[0]]]),
Expand Down
43 changes: 23 additions & 20 deletions src/moscot/backends/ott/_neuraldual.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,19 +529,23 @@ def train_neuraldual(

if not self.is_balanced:
# resample source with unbalanced marginals
batch["source"], batch["condition"] = trainloader.unbalanced_resample(source_key, (source, condition), a) # type: ignore[misc] # noqa: E501
batch["source"], batch["condition"] = trainloader.unbalanced_resample(source_key, (source, condition), a) # type: ignore[misc] # noqa: E501
# train step for potential g directly updating the train state
self.state_f, loss, W2_dist, loss_f, loss_g, penalty = self.train_step_f(self.state_f, self.state_g, batch)
self.state_f, loss, W2_dist, loss_f, loss_g, penalty = self.train_step_f(
self.state_f, self.state_g, batch
)
average_meters["train_loss_f"].update(loss_f)
#logs = self._update_logs(logs, None, loss_f, None, is_train_set=True)
# logs = self._update_logs(logs, None, loss_f, None, is_train_set=True)
logs = self._update_logs_new(logs, loss, W2_dist, loss_f, loss_g, penalty, is_train_set=True)
# resample target batch with unbalanced marginals
if not self.is_balanced:
target_key, self.key = jax.random.split(self.key, 2)
batch["target"], batch["condition"] = trainloader.unbalanced_resample(target_key, (target, condition), b)
batch["target"], batch["condition"] = trainloader.unbalanced_resample(
target_key, (target, condition), b
)
# train step for potential f directly updating the train state
self.state_g, loss, W2_dist, loss_f, loss_g, penalty = self.train_step_g(self.state_f, self.state_g, batch)
#logs = self._update_logs(logs, loss_g, None, W2_dist, is_train_set=True)
# logs = self._update_logs(logs, loss_g, None, W2_dist, is_train_set=True)
logs = self._update_logs_new(logs, loss, W2_dist, loss_f, loss_g, penalty, is_train_set=True)
# clip weights of g
if not self.pos_weights:
Expand All @@ -558,10 +562,14 @@ def train_neuraldual(
source_key, policy_pair, sample="source"
)
valid_batch["target"], batch["condition"] = validloader(source_key, policy_pair, sample="target")
valid_loss, valid_W2_dist, valid_loss_f, valid_loss_g, valid_penalty = self.valid_step_f(self.state_f, self.state_g, valid_batch)
#valid_loss_g, valid_w_dist = self.valid_step_g(self.state_f, self.state_g, valid_batch)
#logs = self._update_logs(logs, valid_loss_f, valid_loss_g, valid_W2_dist, is_train_set=False)
logs = self._update_logs_new(logs, valid_loss, valid_W2_dist, valid_loss_f, valid_loss_g, valid_penalty, is_train_set=False)
valid_loss, valid_W2_dist, valid_loss_f, valid_loss_g, valid_penalty = self.valid_step_f(
self.state_f, self.state_g, valid_batch
)
# valid_loss_g, valid_w_dist = self.valid_step_g(self.state_f, self.state_g, valid_batch)
# logs = self._update_logs(logs, valid_loss_f, valid_loss_g, valid_W2_dist, is_train_set=False)
logs = self._update_logs_new(
logs, valid_loss, valid_W2_dist, valid_loss_f, valid_loss_g, valid_penalty, is_train_set=False
)
a, b = validloader.compute_unbalanced_marginals(valid_batch["source"], valid_batch["target"])
_, _, _, _, loss_eta, loss_xi = self.unbalancedness_step_fn(
valid_batch["source"],
Expand All @@ -577,19 +585,19 @@ def train_neuraldual(
)

# update best model and patience as necessary
#try:
# try:
# total_loss = logs[self.patience_metric][-1]
#except ValueError:
# except ValueError:
# f"Unknown metric: {self.patience_metric}."
#if total_loss < best_loss:
# if total_loss < best_loss:
# best_loss = total_loss
# best_iter_distance = valid_average_meters["valid_neural_dual_dist"].avg
# best_params_f = self.state_f.params
# best_params_g = self.state_g.params
# curr_patience = 0
#else:
# else:
# curr_patience += 1
#if curr_patience >= self.patience:
# if curr_patience >= self.patience:
# break
if self.best_model_selection:
self.state_f = self.state_f.replace(params=best_params_f)
Expand All @@ -605,7 +613,6 @@ def get_step_fn(self, train: bool, to_optimize: Literal["f", "g"]):

def loss_fn(params_f, params_g, f_value, g_value, f_gradient, batch):
"""Loss function for both potentials."""

source, target, condition = batch["source"], batch["target"], batch["condition"]
g_target = g_value(params_g)(target, condition)
grad_f_s = f_gradient(params_f)(source, condition)
Expand Down Expand Up @@ -634,7 +641,6 @@ def loss_fn(params_f, params_g, f_value, g_value, f_gradient, batch):
return loss_g, (loss, W2_dist, loss_f, loss_g, penalty)
raise NotImplementedError


@jax.jit
def step_fn(state_f, state_g, batch):
"""Step function of either training or validation."""
Expand Down Expand Up @@ -729,13 +735,11 @@ def _update_logs(
logs["valid_w_dist"].append(float(w_dist))
return logs



@staticmethod
def _update_logs_new(
logs: Dict[str, List[float]],
loss: Optional[jnp.ndarray],
W2_dist: Optional[jnp.ndarray],
W2_dist: Optional[jnp.ndarray],
loss_f: Optional[jnp.ndarray],
loss_g: Optional[jnp.ndarray],
penalty: Optional[jnp.ndarray],
Expand Down Expand Up @@ -765,4 +769,3 @@ def _update_logs_new(
if penalty is not None:
logs["valid_penalty"].append(float(penalty))
return logs

8 changes: 4 additions & 4 deletions src/moscot/backends/ott/nets/_icnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,19 @@ def potential_value_fn(
return lambda x, c: self.apply({"params": params}, x, c=c) # type: ignore[misc]
raise ValueError("`is_potential` must be `True`.")

#assert other_potential_value_fn is not None, (
# assert other_potential_value_fn is not None, (
# "The value of the gradient-based potential depends " "on the value of the other potential."
#)
# )

#def value_fn(x: jnp.ndarray) -> jnp.ndarray:
# def value_fn(x: jnp.ndarray) -> jnp.ndarray:
# squeeze = x.ndim == 1
# if squeeze:
# x = jnp.expand_dims(x, 0)
# grad_g_x = jax.lax.stop_gradient(self.apply({"params": params}, x))
# value = -other_potential_value_fn(grad_g_x) + jax.vmap(jnp.dot)(grad_g_x, x)
# return value.squeeze(0) if squeeze else value

#return value_fn
# return value_fn

def potential_gradient_fn(
self,
Expand Down

0 comments on commit 06e55c0

Please sign in to comment.