Switch flax.optim for optax in mlp_agent. Apply pyink formatting.
PiperOrigin-RevId: 496415945
Joshua Greaves committed Dec 19, 2022
1 parent e47eecc commit 72082fe
Showing 3 changed files with 139 additions and 114 deletions.
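Note on the change: in flax.optim the optimizer object owned the parameters (exposed as optimizer.target), whereas with optax the parameters and the optimizer state are kept separately and threaded through update/apply_updates. A minimal sketch of the pattern this commit adopts, using hypothetical stand-in params and grads:

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}   # stand-in for the network parameters
grads = {'w': jnp.ones(3)}     # stand-in for gradients from jax.value_and_grad

optimizer = optax.sgd(learning_rate=0.001)      # replaces flax.optim.GradientDescent
opt_state = optimizer.init(params)              # optimizer state lives outside the optimizer
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)   # replaces optimizer.apply_gradient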
99 changes: 56 additions & 43 deletions balloon_learning_environment/agents/mlp_agent.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyformat: mode=pyink
r"""A simple multi-layer perceptron (MLP) agent.
This agent learns a mapping from states to Q-values using simple SARSA updates.
@@ -34,93 +35,105 @@
from absl import logging
from balloon_learning_environment.agents import agent
from balloon_learning_environment.agents import networks
from flax import optim
import gin
import jax
import jax.numpy as jnp
import numpy as np
import optax


@gin.configurable
def create_optimizer(learning_rate: float = 0.001):
def create_optimizer(
learning_rate: float = 0.001,
) -> optax.GradientTransformation:
"""Create an SGD optimizer for training."""
return optim.GradientDescent(learning_rate=learning_rate)
return optax.sgd(learning_rate=learning_rate)
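create_optimizer stays gin-configurable, so the learning rate is still bound the same way as before; a small usage sketch with a hypothetical value:

import gin
import optax

gin.bind_parameter('mlp_agent.create_optimizer.learning_rate', 0.01)  # hypothetical value
optimizer = create_optimizer()
assert isinstance(optimizer, optax.GradientTransformation)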


@functools.partial(jax.jit, static_argnums=0)
def select_action(network_def: Any,
network_params: np.ndarray,
state: Any) -> int:
def select_action(
network_def: Any, network_params: np.ndarray, state: Any
) -> int:
"""Select an action greedily from network."""
return jnp.argmax(network_def.apply(network_params, state))


@functools.partial(jax.jit, static_argnums=(0, 7))
def train(network_def: Any,
optimizer: Any,
state: Any,
action: int,
reward: float,
next_state: Any,
next_action: int,
gamma: float):
@functools.partial(jax.jit, static_argnames=('network_def', 'optimizer'))
def train(
network_def: Any,
network_params: Any,
optimizer: optax.GradientTransformation,
optimizer_state: optax.OptState,
state: Any,
action: int,
reward: float,
next_state: Any,
next_action: int,
gamma: float,
):
"""Run a single SARSA-update."""

def loss_fn(params):
q_val = network_def.apply(params, state)[action]
next_action_val = network_def.apply(params, next_state)[next_action]
target = reward + gamma * next_action_val
return (q_val - target)**2
return (q_val - target) ** 2

grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(optimizer.target)
optimizer = optimizer.apply_gradient(grad)
return loss, optimizer
loss, grad = grad_fn(network_params)
updates, optimizer_state = optimizer.update(
grad, optimizer_state, network_params
)
network_params = optax.apply_updates(network_params, updates)
return loss, network_params, optimizer_state
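For reference, loss_fn above is the squared SARSA temporal-difference error: Q(s, a) is regressed toward r + gamma * Q(s', a'). A tiny standalone illustration with hypothetical scalar Q-values:

q_sa = 0.2                     # network_def.apply(params, state)[action]
q_next = 0.5                   # network_def.apply(params, next_state)[next_action]
reward, gamma = 1.0, 0.9
target = reward + gamma * q_next   # 1.45
loss = (q_sa - target) ** 2        # (0.2 - 1.45) ** 2 == 1.5625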


@gin.configurable
class MLPAgent(agent.Agent):
"""An agent using a simple MLP network."""

def __init__(self, num_actions: int, observation_shape: Sequence[int],
gamma: float = 0.9,
seed: Union[int, None] = None):
def __init__(
self,
num_actions: int,
observation_shape: Sequence[int],
gamma: float = 0.9,
seed: Union[int, None] = None,
):
super().__init__(num_actions, observation_shape)
self._gamma = gamma
seed = int(time.time() * 1e6) if seed is None else seed
rng = jax.random.PRNGKey(seed)
self.network_def = networks.MLPNetwork(num_actions=(self._num_actions))
example_state = jnp.zeros(observation_shape)
network_params = self.network_def.init(rng, example_state)
optimizer_def = create_optimizer()
self.optimizer = optimizer_def.create(network_params)

self._mode = agent.AgentMode('train')
self.network_params = self.network_def.init(rng, example_state)
self.optimizer = create_optimizer()
self.optimizer_state = self.optimizer.init(self.network_params)

@property
def network(self):
return self.optimizer.target
self._mode = agent.AgentMode('train')

def begin_episode(self, observation: np.ndarray) -> int:
action = select_action(self.network_def,
self.network,
observation)
action = select_action(self.network_def, self.network_params, observation)
self.last_state = observation
self.last_action = action
return action

def step(self, reward: float, observation: np.ndarray) -> int:
action = select_action(self.network_def,
self.network,
observation)
action = select_action(self.network_def, self.network_params, observation)

if self._mode == agent.AgentMode.TRAIN:
loss, self.optimizer = train(self.network_def,
self.optimizer,
self.last_state,
self.last_action,
reward,
observation,
action,
self._gamma)
loss, self.network_params, self.optimizer_state = train(
self.network_def,
self.network_params,
self.optimizer,
self.optimizer_state,
self.last_state,
self.last_action,
reward,
observation,
action,
self._gamma,
)
logging.info('Loss: %f', loss)

self.last_state = observation
152 changes: 82 additions & 70 deletions balloon_learning_environment/agents/mlp_agent_test.py
@@ -13,17 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pyformat: mode=pyink
"""Tests for balloon_learning_environment.agents.mlp_agent."""

from absl.testing import absltest
from balloon_learning_environment.agents import agent as base_agent
from balloon_learning_environment.agents import agent_registry
from balloon_learning_environment.agents import mlp_agent
from balloon_learning_environment.agents import networks
import flax
import gin
import jax
import jax.numpy as jnp
import optax


class MLPAgentTest(absltest.TestCase):
@@ -40,88 +41,97 @@ def _create_network(self):

def test_select_action(self):
self._create_network()
network_params = self._network_def.init(jax.random.PRNGKey(0),
self._example_state)
network_params = self._network_def.init(
jax.random.PRNGKey(0), self._example_state
)
# A state of all zeros will produce all-zero Q-values, which will result in
# the argmax always selecting action 0.
zeros_state = jnp.zeros_like(self._example_state)
self.assertEqual(0, mlp_agent.select_action(self._network_def,
network_params,
zeros_state))
self.assertEqual(
0,
mlp_agent.select_action(self._network_def, network_params, zeros_state),
)
# Because we are using a fixed seed we can deterministically guarantee that
# a state of all ones will pick action 2.
ones_state = jnp.ones_like(self._example_state)
self.assertEqual(2, mlp_agent.select_action(self._network_def,
network_params,
ones_state))
self.assertEqual(
2,
mlp_agent.select_action(self._network_def, network_params, ones_state),
)

def test_create_optimizer(self):
optim = mlp_agent.create_optimizer()
self.assertIsInstance(optim, flax.optim.sgd.GradientDescent)
self.assertEqual(0.001, optim.hyper_params.learning_rate)
lr = 0.5
gin.bind_parameter('mlp_agent.create_optimizer.learning_rate', lr)
optim = mlp_agent.create_optimizer()
self.assertEqual(lr, optim.hyper_params.learning_rate)
self.assertIsInstance(optim, optax.GradientTransformation)
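The removed assertions on optim.hyper_params.learning_rate have no direct optax equivalent. If the configured step size needed checking, one option (not part of this commit) would be to verify a single plain-SGD step directly, which should move the params by exactly -learning_rate * grad; a sketch with hypothetical values:

lr = 0.5
gin.bind_parameter('mlp_agent.create_optimizer.learning_rate', lr)
optim = mlp_agent.create_optimizer()
params = {'w': jnp.array([1.0, 2.0])}
grads = {'w': jnp.array([0.2, -0.4])}
updates, _ = optim.update(grads, optim.init(params), params)
new_params = optax.apply_updates(params, updates)
assert jnp.allclose(new_params['w'], params['w'] - lr * grads['w'])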

def test_train(self):
self._create_network()
network_params = self._network_def.init(jax.random.PRNGKey(0),
self._example_state)
optim_def = mlp_agent.create_optimizer()
optim = optim_def.create(network_params)
network_params_before = self._network_def.init(
jax.random.PRNGKey(0), self._example_state
)
optim = mlp_agent.create_optimizer()
optim_state_before = optim.init(network_params_before)

# An all-zeros state will produce all-zeros Q values.
state = jnp.zeros_like(self._example_state)
# An all-ones next_state will produce non-zero Q-values, leading to non-zero
# temporal difference and a non-zero gradient. This will in turn change the
# optimizer's target.
next_state = jnp.ones_like(self._example_state)
loss, new_optim = mlp_agent.train(self._network_def,
optim,
state,
0, # action
0., # reward
next_state,
0, # next action
0.9) # gamma
self.assertGreater(loss, 0.) # Non-zero loss.
loss, network_params_after, _ = mlp_agent.train(
self._network_def,
network_params_before,
optim,
optim_state_before,
state,
0, # action
0.0, # reward
next_state,
0, # next action
0.9,
) # gamma

self.assertGreater(loss, 0.0) # Non-zero loss.
# Optimizer target will have changed.
self.assertFalse(jnp.array_equal(
optim.target['params']['Dense_0']['kernel'],
new_optim.target['params']['Dense_0']['kernel']))

def test_agent_defaults(self):
agent = mlp_agent.MLPAgent(self._num_actions, self._observation_shape)
self.assertEqual(0.9, agent._gamma)
# Create an optimizer with a fixed seed.
self._create_network()
network_params = self._network_def.init(jax.random.PRNGKey(0),
self._example_state)
optim_def = mlp_agent.create_optimizer()
optim = optim_def.create(network_params)
self.assertFalse(
jnp.array_equal(
network_params_before['params']['Dense_0']['kernel'],
network_params_after['params']['Dense_0']['kernel'],
)
)

def test_agent_initialized_parameters_randomly_if_no_seed_specified(self):
agent1 = mlp_agent.MLPAgent(self._num_actions, self._observation_shape)
agent2 = mlp_agent.MLPAgent(self._num_actions, self._observation_shape)

# Because we did not specify a seed to the agent, it will use one based on
# time, which will not match the optimizer we created with a fixed seed.
self.assertFalse(jnp.array_equal(
optim.target['params']['Dense_0']['kernel'],
agent.optimizer.target['params']['Dense_0']['kernel']))

def test_agent_network(self):
agent = mlp_agent.MLPAgent(self._num_actions, self._observation_shape,
gamma=0.99, seed=0)
self._create_network()
network_params = self._network_def.init(jax.random.PRNGKey(0),
self._example_state)
optim_def = mlp_agent.create_optimizer()
optim = optim_def.create(network_params)
# Because we specified a seed to the agent, it will match the one we created
# here.
self.assertTrue(jnp.array_equal(
optim.target['params']['Dense_0']['kernel'],
agent.network['params']['Dense_0']['kernel']))
self.assertFalse(
jnp.array_equal(
agent1.network_params['params']['Dense_0']['kernel'],
agent2.network_params['params']['Dense_0']['kernel'],
)
)

def test_agent_generates_parameters_deterministically_if_seeded(self):
agent1 = mlp_agent.MLPAgent(
self._num_actions, self._observation_shape, gamma=0.99, seed=0
)
agent2 = mlp_agent.MLPAgent(
self._num_actions, self._observation_shape, gamma=0.99, seed=0
)

# Because we specified a seed to the agent, the parameters should match.
self.assertTrue(
jnp.array_equal(
agent1.network_params['params']['Dense_0']['kernel'],
agent2.network_params['params']['Dense_0']['kernel'],
)
)

def test_begin_episode(self):
agent = mlp_agent.MLPAgent(self._num_actions, self._observation_shape,
gamma=0.99, seed=0)
agent = mlp_agent.MLPAgent(
self._num_actions, self._observation_shape, gamma=0.99, seed=0
)
# An all-zeros state will produce all-zeros Q values, which will result in
# action 0 selected by the argmax.
action = agent.begin_episode(jnp.zeros_like(self._example_state))
@@ -132,34 +142,36 @@ def test_begin_episode(self):
self.assertEqual(2, action)

def test_step_and_end_episode(self):
agent = mlp_agent.MLPAgent(self._num_actions, self._observation_shape,
gamma=0.99, seed=0)
agent = mlp_agent.MLPAgent(
self._num_actions, self._observation_shape, gamma=0.99, seed=0
)
# Calling step before begin_episode raises an error.
with self.assertRaises(AttributeError):
_ = agent.step(0., jnp.zeros_like(self._example_state))
_ = agent.step(0.0, jnp.zeros_like(self._example_state))
# Call begin_episode to avoid errors.
_ = agent.begin_episode(jnp.zeros_like(self._example_state))
# An all-zeros state will produce all-zeros Q values, which will result in
# action 0 selected by the argmax.
action = agent.step(0., jnp.zeros_like(self._example_state))
action = agent.step(0.0, jnp.zeros_like(self._example_state))
self.assertEqual(0, action)
# Because we are using a fixed seed we can deterministically guarantee that
# a state of all ones will pick action 2.
action = agent.step(0., jnp.ones_like(self._example_state))
action = agent.step(0.0, jnp.ones_like(self._example_state))
self.assertEqual(2, action)
# end_episode doesn't do anything (it exists to conform to the Agent
# interface). This next line just checks that it runs without problems.
agent.end_episode(0., True)
agent.end_episode(0.0, True)

def test_agent_does_not_train_in_eval_mode(self):
agent = mlp_agent.MLPAgent(self._num_actions, self._observation_shape,
gamma=0.99, seed=0)
agent = mlp_agent.MLPAgent(
self._num_actions, self._observation_shape, gamma=0.99, seed=0
)
agent.set_mode(base_agent.AgentMode.EVAL)

params_before = agent.optimizer.target
params_before = agent.network_params
agent.begin_episode(jnp.ones(self._observation_shape, dtype=jnp.float32))
agent.step(1.0, jnp.ones(self._observation_shape, dtype=jnp.float32))
params_after = agent.optimizer.target
params_after = agent.network_params

self.assertEqual(params_before, params_after)

@@ -28,7 +28,7 @@
from balloon_learning_environment.generative import dataset_wind_field_reservoir
from balloon_learning_environment.generative import vae
from balloon_learning_environment.utils import wind
from flax import optim
from flax import optim # TODO(joshgreaves): Switch to optax.
from flax.metrics import tensorboard
from flax.training import checkpoints
import gin