Fixes to work on more recent versions of jax #9

Open · wants to merge 3 commits into master
18 changes: 10 additions & 8 deletions requirements.txt
@@ -1,8 +1,10 @@
-pyyaml==5.3.1
-tqdm==4.48.2
-numpy==1.19.2
-pandas==1.1.2
-dm-haiku==0.0.2
-gym[atari]==0.17.2
-dm_control==0.0.322773188
-tensorboardX==2.1
+dm-haiku>=0.0.2
+dm_control>=0.0.322773188
+gym[atari]>=0.17.2
+jax>=0.2.7
+numpy>=1.19.2
+optax>=0.0.2
+pandas>=1.1.2
+pyyaml>=5.3.1
+tensorboardX>=2.1
+tqdm>=4.48.2
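The pins are relaxed from exact == versions to >= minimums, and jax and optax are now listed explicitly: jax.experimental.optix, which rljax previously imported, was split out of jax into the standalone optax package and is no longer shipped with recent jax releases.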
6 changes: 3 additions & 3 deletions rljax/algorithm/ddpg.py
@@ -5,7 +5,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.base_class import OffPolicyActorCritic
 from rljax.network import ContinuousQFunction, DeterministicPolicy
@@ -81,13 +81,13 @@ def fn_actor(s):
 # Critic.
 self.critic = hk.without_apply_rng(hk.transform(fn_critic))
 self.params_critic = self.params_critic_target = self.critic.init(next(self.rng), *self.fake_args_critic)
-opt_init, self.opt_critic = optix.adam(lr_critic)
+opt_init, self.opt_critic = optax.adam(lr_critic)
 self.opt_state_critic = opt_init(self.params_critic)

 # Actor.
 self.actor = hk.without_apply_rng(hk.transform(fn_actor))
 self.params_actor = self.params_actor_target = self.actor.init(next(self.rng), *self.fake_args_actor)
-opt_init, self.opt_actor = optix.adam(lr_actor)
+opt_init, self.opt_actor = optax.adam(lr_actor)
 self.opt_state_actor = opt_init(self.params_actor)

 # Other parameters.
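Every constructor in this PR applies the same substitution, which works because optax.adam (like optix.adam before it) returns an (init, update) pair that can be tuple-unpacked. A minimal self-contained sketch of the pattern, with toy parameters and a made-up learning rate rather than anything taken from rljax:

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameter tree standing in for a Haiku params dict (illustrative only).
params = {"w": jnp.zeros(3)}

# optax.adam returns a GradientTransformation(init, update) named tuple,
# so the existing `opt_init, self.opt = ...` unpacking keeps working.
opt_init, opt_update = optax.adam(1e-3)   # was: optix.adam(1e-3)
opt_state = opt_init(params)

def loss_fn(p):
    return jnp.sum((p["w"] - 1.0) ** 2)

grads = jax.grad(loss_fn)(params)
updates, opt_state = opt_update(grads, opt_state)
params = optax.apply_updates(params, updates)  # was: optix.apply_updates
```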
4 changes: 2 additions & 2 deletions rljax/algorithm/dqn.py
@@ -5,7 +5,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.base_class import QLearning
 from rljax.network import DiscreteQFunction
@@ -74,7 +74,7 @@ def fn(s):

 self.net = hk.without_apply_rng(hk.transform(fn))
 self.params = self.params_target = self.net.init(next(self.rng), *self.fake_args)
-opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
+opt_init, self.opt = optax.adam(lr, eps=0.01 / batch_size)
 self.opt_state = opt_init(self.params)

 @partial(jax.jit, static_argnums=0)
6 changes: 3 additions & 3 deletions rljax/algorithm/fqf.py
@@ -6,7 +6,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.qrdqn import QRDQN
 from rljax.network import CumProbNetwork, DiscreteImplicitQuantileFunction, make_quantile_nerwork
@@ -81,13 +81,13 @@ def fn(s, cum_p):

 self.net, self.params, fake_feature = make_quantile_nerwork(self.rng, state_space, action_space, fn, num_quantiles)
 self.params_target = self.params
-opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
+opt_init, self.opt = optax.adam(lr, eps=0.01 / batch_size)
 self.opt_state = opt_init(self.params)

 # Fraction proposal network.
 self.cum_p_net = hk.without_apply_rng(hk.transform(lambda s: CumProbNetwork(num_quantiles=num_quantiles)(s)))
 self.params_cum_p = self.cum_p_net.init(next(self.rng), fake_feature)
-opt_init, self.opt_cum_p = optix.rmsprop(lr_cum_p, decay=0.95, eps=1e-5, centered=True)
+opt_init, self.opt_cum_p = optax.rmsprop(lr_cum_p, decay=0.95, eps=1e-5, centered=True)
 self.opt_state_cum_p = opt_init(self.params_cum_p)

 def forward(self, state):
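The keyword arguments carry over unchanged here as well: optax.adam accepts eps, and optax.rmsprop accepts decay, eps, and centered, so the FQF hyperparameters do not need to be adjusted for the rename.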
4 changes: 2 additions & 2 deletions rljax/algorithm/misc/discor_mixin.py
@@ -6,7 +6,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.network import ContinuousQFunction
 from rljax.util import load_params, save_params
@@ -38,7 +38,7 @@ def fn_error(s, a):
 # Error model.
 self.error = hk.without_apply_rng(hk.transform(fn_error))
 self.params_error = self.params_error_target = self.error.init(next(self.rng), *self.fake_args_critic)
-opt_init, self.opt_error = optix.adam(lr_error)
+opt_init, self.opt_error = optax.adam(lr_error)
 self.opt_state_error = opt_init(self.params_error)
 # Running mean of error.
 self.rm_error_list = [jnp.array(init_error, dtype=jnp.float32) for _ in range(num_critics)]
6 changes: 3 additions & 3 deletions rljax/algorithm/ppo.py
@@ -5,7 +5,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.base_class import OnPolicyActorCritic
 from rljax.network import ContinuousVFunction, StateIndependentGaussianPolicy
@@ -65,13 +65,13 @@ def fn_actor(s):
 # Critic.
 self.critic = hk.without_apply_rng(hk.transform(fn_critic))
 self.params_critic = self.params_critic_target = self.critic.init(next(self.rng), *self.fake_args_critic)
-opt_init, self.opt_critic = optix.adam(lr_critic)
+opt_init, self.opt_critic = optax.adam(lr_critic)
 self.opt_state_critic = opt_init(self.params_critic)

 # Actor.
 self.actor = hk.without_apply_rng(hk.transform(fn_actor))
 self.params_actor = self.params_actor_target = self.actor.init(next(self.rng), *self.fake_args_actor)
-opt_init, self.opt_actor = optix.adam(lr_actor)
+opt_init, self.opt_actor = optax.adam(lr_actor)
 self.opt_state_actor = opt_init(self.params_actor)

 # Other parameters.
8 changes: 4 additions & 4 deletions rljax/algorithm/sac.py
@@ -5,7 +5,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.base_class import OffPolicyActorCritic
 from rljax.network import ContinuousQFunction, StateDependentGaussianPolicy
@@ -95,18 +95,18 @@ def fn_actor(s):
 # Critic.
 self.critic = hk.without_apply_rng(hk.transform(fn_critic))
 self.params_critic = self.params_critic_target = self.critic.init(next(self.rng), *self.fake_args_critic)
-opt_init, self.opt_critic = optix.adam(lr_critic)
+opt_init, self.opt_critic = optax.adam(lr_critic)
 self.opt_state_critic = opt_init(self.params_critic)
 # Actor.
 self.actor = hk.without_apply_rng(hk.transform(fn_actor))
 self.params_actor = self.actor.init(next(self.rng), *self.fake_args_actor)
-opt_init, self.opt_actor = optix.adam(lr_actor)
+opt_init, self.opt_actor = optax.adam(lr_actor)
 self.opt_state_actor = opt_init(self.params_actor)
 # Entropy coefficient.
 if not hasattr(self, "target_entropy"):
 self.target_entropy = -float(self.action_space.shape[0])
 self.log_alpha = jnp.array(np.log(init_alpha), dtype=jnp.float32)
-opt_init, self.opt_alpha = optix.adam(lr_alpha, b1=adam_b1_alpha)
+opt_init, self.opt_alpha = optax.adam(lr_alpha, b1=adam_b1_alpha)
 self.opt_state_alpha = opt_init(self.log_alpha)

 @partial(jax.jit, static_argnums=0)
6 changes: 3 additions & 3 deletions rljax/algorithm/sac_ae.py
@@ -6,7 +6,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.sac import SAC
 from rljax.network import ContinuousQFunction, SACDecoder, SACEncoder, SACLinear, StateDependentGaussianPolicy
@@ -123,11 +123,11 @@ def fn_actor(x):
 # Decoder.
 self.decoder = hk.without_apply_rng(hk.transform(lambda x: SACDecoder(state_space, num_filters=32, num_layers=4)(x)))
 self.params_decoder = self.decoder.init(next(self.rng), fake_feature)
-opt_init, self.opt_ae = optix.adam(lr_ae)
+opt_init, self.opt_ae = optax.adam(lr_ae)
 self.opt_state_ae = opt_init(self.params_ae)

 # Re-define the optimizer for critic.
-opt_init, self.opt_critic = optix.adam(lr_critic)
+opt_init, self.opt_critic = optax.adam(lr_critic)
 self.opt_state_critic = opt_init(self.params_entire_critic)

 # Other parameters.
4 changes: 2 additions & 2 deletions rljax/algorithm/slac.py
@@ -5,7 +5,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.experimental import optix
+import optax

 from rljax.algorithm.misc import SlacMixIn
 from rljax.algorithm.sac import SAC
@@ -123,7 +123,7 @@ def fn_actor(x):
 z2_dim=z2_dim,
 feature_dim=feature_dim,
 )
-opt_init, self.opt_model = optix.adam(lr_model)
+opt_init, self.opt_model = optax.adam(lr_model)
 self.opt_state_model = opt_init(self.params_model)

 @partial(jax.jit, static_argnums=0)
4 changes: 2 additions & 2 deletions rljax/util/optim.py
@@ -4,7 +4,7 @@
 import haiku as hk
 import jax
 import jax.numpy as jnp
-from jax.experimental import optix
+import optax
 from jax.tree_util import tree_flatten


@@ -26,7 +26,7 @@ def optimize(
 if max_grad_norm is not None:
 grad = clip_gradient_norm(grad, max_grad_norm)
 update, opt_state = opt(grad, opt_state)
-params_to_update = optix.apply_updates(params_to_update, update)
+params_to_update = optax.apply_updates(params_to_update, update)
 return opt_state, params_to_update, loss, aux
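optimize still clips with rljax's own clip_gradient_norm before calling the optimizer's update; optax can also express clip-then-update as one chained transformation. The following is only a sketch of that optax idiom, not what the PR does, and the threshold and learning rate are invented for the example:

```python
import optax

max_grad_norm = 10.0  # illustrative value
lr = 3e-4             # illustrative value

# Clip by global norm, then apply Adam, as a single GradientTransformation.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adam(lr),
)
opt_init, opt_update = tx.init, tx.update
```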
1 change: 1 addition & 0 deletions setup.py
@@ -95,6 +95,7 @@ def run(self):
 "Programming Language :: Python :: 3.6",
 "Programming Language :: Python :: 3.7",
 "Programming Language :: Python :: 3.8",
+"Programming Language :: Python :: 3.9",
 "Programming Language :: Python :: Implementation :: CPython",
 "Programming Language :: Python :: Implementation :: PyPy",
 ],
4 changes: 2 additions & 2 deletions tests/util/test_optim.py
@@ -1,8 +1,8 @@
 import haiku as hk
 import jax.numpy as jnp
 import numpy as np
+import optax
 import pytest
-from jax.experimental import optix

 from rljax.util.optim import clip_gradient, clip_gradient_norm, optimize, soft_update, weight_decay

@@ -11,7 +11,7 @@
 def test_optimize(lr, w, x):
 net = hk.without_apply_rng(hk.transform(lambda x: hk.Linear(1, with_bias=False, w_init=hk.initializers.Constant(w))(x)))
 params = net.init(next(hk.PRNGSequence(0)), jnp.zeros((1, 1)))
-opt_init, opt = optix.sgd(lr)
+opt_init, opt = optax.sgd(lr)
 opt_state = opt_init(params)

 def _loss(params, x):