Skip to content

Commit

Permalink
adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Oct 19, 2023
1 parent 21f3309 commit 3bbb4da
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/problems/generic/test_conditional_neural_problem.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
from optax import adagrad

import optax
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -90,9 +89,8 @@ def test_pass_custom_optimizers(self, adata_time: ad.AnnData):
problem = ConditionalNeuralProblem(adata=adata_time)
adata_time = adata_time[adata_time.obs["time"].isin((0, 1))]
problem = problem.prepare(key="time", joint_attr="X_pca")
custom_opt_f = adagrad(1e-4)
custom_opt_g = adagrad(1e-3)
custom_opt_f = optax.adagrad(1e-4)
custom_opt_g = optax.adagrad(1e-3)

problem = problem.solve(iterations=2, opt_f=custom_opt_f, opt_g=custom_opt_g, cond_dim=1)
assert problem.optimizer_f == custom_opt_f
assert problem.optimizer_g == custom_opt_g

0 comments on commit 3bbb4da

Please sign in to comment.