From ce36aab6528b26a699f5f1cefd330fdaf23a5d72 Mon Sep 17 00:00:00 2001 From: Pablo Samuel Castro Date: Thu, 29 Jun 2023 10:09:12 -0400 Subject: [PATCH] Project import generated by Copybara. PiperOrigin-RevId: 544346317 --- dopamine/jax/agents/dqn/dqn_agent.py | 3 ++- dopamine/jax/agents/rainbow/rainbow_agent.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dopamine/jax/agents/dqn/dqn_agent.py b/dopamine/jax/agents/dqn/dqn_agent.py index c8a6623a..34398b15 100644 --- a/dopamine/jax/agents/dqn/dqn_agent.py +++ b/dopamine/jax/agents/dqn/dqn_agent.py @@ -342,7 +342,8 @@ def __init__(self, def _build_networks_and_optimizer(self): self._rng, rng = jax.random.split(self._rng) - self.online_params = self.network_def.init(rng, x=self.state) + state = self.preprocess_fn(self.state) + self.online_params = self.network_def.init(rng, x=state) self.optimizer = create_optimizer(self._optimizer_name) self.optimizer_state = self.optimizer.init(self.online_params) self.target_network_params = self.online_params diff --git a/dopamine/jax/agents/rainbow/rainbow_agent.py b/dopamine/jax/agents/rainbow/rainbow_agent.py index 63f2223e..42d689ae 100644 --- a/dopamine/jax/agents/rainbow/rainbow_agent.py +++ b/dopamine/jax/agents/rainbow/rainbow_agent.py @@ -277,7 +277,8 @@ def __init__(self, def _build_networks_and_optimizer(self): self._rng, rng = jax.random.split(self._rng) - self.online_params = self.network_def.init(rng, x=self.state, + state = self.preprocess_fn(self.state) + self.online_params = self.network_def.init(rng, x=state, support=self._support) self.optimizer = dqn_agent.create_optimizer(self._optimizer_name) self.optimizer_state = self.optimizer.init(self.online_params) @@ -316,7 +317,7 @@ def begin_episode(self, observation): self._rng, self.action = select_action(self.network_def, self.online_params, - self.state, + self.preprocess_fn(self.state), self._rng, self.num_actions, self.eval_mode,