From 3bbb4dade629961de3c081d4c0064f6a776e20e3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 19 Oct 2023 10:57:42 +0200 Subject: [PATCH] adapt tests --- .../generic/test_conditional_neural_problem.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/problems/generic/test_conditional_neural_problem.py b/tests/problems/generic/test_conditional_neural_problem.py index 4ef2dd0fc..18a982753 100644 --- a/tests/problems/generic/test_conditional_neural_problem.py +++ b/tests/problems/generic/test_conditional_neural_problem.py @@ -1,6 +1,5 @@ import pytest -from optax import adagrad - +import optax import jax.numpy as jnp import numpy as np @@ -90,9 +89,8 @@ def test_pass_custom_optimizers(self, adata_time: ad.AnnData): problem = ConditionalNeuralProblem(adata=adata_time) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] problem = problem.prepare(key="time", joint_attr="X_pca") - custom_opt_f = adagrad(1e-4) - custom_opt_g = adagrad(1e-3) + custom_opt_f = optax.adagrad(1e-4) + custom_opt_g = optax.adagrad(1e-3) problem = problem.solve(iterations=2, opt_f=custom_opt_f, opt_g=custom_opt_g, cond_dim=1) - assert problem.optimizer_f == custom_opt_f - assert problem.optimizer_g == custom_opt_g + \ No newline at end of file