
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <[email protected]>
Gal Leibovich and jamescasbon authored Jun 23, 2019
1 parent 7b5d6a3 commit d6795bd
Showing 22 changed files with 104 additions and 49 deletions.
1 change: 1 addition & 0 deletions benchmarks/README.md
@@ -38,6 +38,7 @@ The environments that were used for testing include:
|**[Clipped PPO](clipped_ppo)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[DDPG](ddpg)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[SAC](sac)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[TD3](td3)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[NEC](nec)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
|**[HER](ddpg_her)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Fetch | |
|**[DFP](dfp)** | ![#ceffad](https://placehold.it/15/ceffad/000000?text=+) |Doom | Doom Battle was not verified |
Binary file modified benchmarks/ddpg/ant_ddpg.png
Binary file modified benchmarks/ddpg/half_cheetah_ddpg.png
Binary file modified benchmarks/ddpg/hopper_ddpg.png
Binary file modified benchmarks/ddpg/humanoid_ddpg.png
Binary file modified benchmarks/ddpg/inverted_double_pendulum_ddpg.png
Binary file modified benchmarks/ddpg/inverted_pendulum_ddpg.png
Binary file modified benchmarks/ddpg/reacher_ddpg.png
Binary file modified benchmarks/ddpg/swimmer_ddpg.png
Binary file modified benchmarks/ddpg/walker2d_ddpg.png
27 changes: 15 additions & 12 deletions rl_coach/agents/ddpg_agent.py
@@ -34,9 +34,9 @@


class DDPGCriticNetworkParameters(NetworkParameters):
def __init__(self):
def __init__(self, use_batchnorm=False):
super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ def __init__(self):


class DDPGActorNetworkParameters(NetworkParameters):
def __init__(self):
def __init__(self, use_batchnorm=False):
super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
self.heads_parameters = [DDPGActorHeadParameters()]
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
self.optimizer_type = 'Adam'
self.batch_size = 64
self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ def __init__(self):


class DDPGAgentParameters(AgentParameters):
def __init__(self):
def __init__(self, use_batchnorm=False):
super().__init__(algorithm=DDPGAlgorithmParameters(),
exploration=OUProcessParameters(),
memory=EpisodicExperienceReplayParameters(),
networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
("critic", DDPGCriticNetworkParameters())]))
networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

@property
def path(self):
@@ -170,7 +170,9 @@ def learn_from_batch(self, batch):
# train the critic
critic_inputs = copy.copy(batch.states(critic_keys))
critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
result = critic.train_and_sync_networks(critic_inputs, TD_targets)

# also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
total_loss, losses, unclipped_grads = result[:3]

# apply the gradients from the critic to the actor
@@ -179,11 +181,12 @@ def learn_from_batch(self, batch):
outputs=actor.online_network.weighted_gradients[0],
initial_feed_dict=initial_feed_dict)

# also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
if actor.has_global:
actor.apply_gradients_to_global_network(gradients)
actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
actor.update_online_network()
else:
actor.apply_gradients_to_online_network(gradients)
actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

return total_loss, losses, unclipped_grads

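Usage note (not part of this commit): with this change, batchnorm in DDPG is off by default and is only enabled through the new `use_batchnorm` flag, which is threaded from the agent parameters into the actor and critic network parameters shown above. A minimal sketch, assuming only that rl_coach is importable:

```python
from rl_coach.agents.ddpg_agent import DDPGAgentParameters

# New default: batchnorm disabled in the observation embedders, the actor
# middleware and the actor head.
agent_params = DDPGAgentParameters()

# Opting back in propagates use_batchnorm=True to both network parameter sets.
agent_params_bn = DDPGAgentParameters(use_batchnorm=True)
```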
3 changes: 2 additions & 1 deletion rl_coach/architectures/head_parameters.py
@@ -22,14 +22,15 @@
class HeadParameters(NetworkComponentParameters):
def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head',
num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
loss_weight: float=1.0, dense_layer=None):
loss_weight: float=1.0, dense_layer=None, is_training=False):
super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function
self.name = name
self.num_output_head_copies = num_output_head_copies
self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
self.loss_weight = loss_weight
self.parameterized_class_name = parameterized_class_name
self.is_training = is_training

@property
def path(self):
39 changes: 28 additions & 11 deletions rl_coach/architectures/network_wrapper.py
@@ -124,31 +124,37 @@ def update_online_network(self, rate=1.0):
if self.global_network:
self.online_network.set_weights(self.global_network.get_weights(), rate)

def apply_gradients_to_global_network(self, gradients=None):
def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
"""
Apply gradients from the online network on the global network
:param gradients: optional gradients that will be used instead of the accumulated gradients
:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
update ops also require the inputs)
:return:
"""
if gradients is None:
gradients = self.online_network.accumulated_gradients
if self.network_parameters.shared_optimizer:
self.global_network.apply_gradients(gradients)
self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
else:
self.online_network.apply_gradients(gradients)
self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

def apply_gradients_to_online_network(self, gradients=None):
def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
"""
Apply gradients from the online network on itself
:param gradients: optional gradients that will be used instead of the accumulated gradients
:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
update ops also require the inputs)
:return:
"""
if gradients is None:
gradients = self.online_network.accumulated_gradients
self.online_network.apply_gradients(gradients)
self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
use_inputs_for_apply_gradients=False):
"""
A generic training function that enables multi-threading training using a global network if necessary.
@@ -157,14 +163,20 @@ def train_and_sync_networks(self, inputs, targets, additional_fetches=[], import
:param additional_fetches: Any additional tensor the user wants to fetch
:param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
error of this sample. If it is not given, the samples losses won't be scaled
:param use_inputs_for_apply_gradients: also pass the inputs when applying the gradients
(e.g. for incorporating batchnorm update ops)
:return: The loss of the training iteration
"""
result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
importance_weights=importance_weights, no_accumulation=True)
self.apply_gradients_and_sync_networks(reset_gradients=False)
if use_inputs_for_apply_gradients:
self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
else:
self.apply_gradients_and_sync_networks(reset_gradients=False)

return result

def apply_gradients_and_sync_networks(self, reset_gradients=True):
def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
"""
Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary
@@ -173,17 +185,22 @@ def apply_gradients_and_sync_networks(self, reset_gradients=True):
the network. this is useful when the accumulated gradients are overwritten instead
of accumulated by the accumulate_gradients function. this allows reducing time
complexity for this function by around 10%
:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
update ops also require the inputs)
"""
if self.global_network:
self.apply_gradients_to_global_network()
self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
if reset_gradients:
self.online_network.reset_accumulated_gradients()
self.update_online_network()
else:
if reset_gradients:
self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
additional_inputs=additional_inputs)
else:
self.online_network.apply_gradients(self.online_network.accumulated_gradients)
self.online_network.apply_gradients(self.online_network.accumulated_gradients,
additional_inputs=additional_inputs)

def parallel_prediction(self, network_input_tuples: List[Tuple]):
"""
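Taken together, the new keyword arguments carry the training batch from the top-level training call down to the op that applies the gradients, so that batchnorm's update ops (which read the inputs) can run. A hedged sketch of the call chain, using only names from this diff (`network`, `inputs` and `targets` are placeholder variables):

```python
# use_inputs_for_apply_gradients=True asks the wrapper to re-feed the same batch
# at gradient-application time.
result = network.train_and_sync_networks(inputs, targets,
                                         use_inputs_for_apply_gradients=True)

# Internally this becomes:
#   apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
#     -> apply_gradients_to_global_network(..., additional_inputs=inputs)  # distributed case
#     -> online_network.apply_gradients(..., additional_inputs=inputs)     # local case
```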
21 changes: 16 additions & 5 deletions rl_coach/architectures/tensorflow_components/architecture.py
@@ -270,8 +270,11 @@ def _create_gradient_applying_ops(self):
elif self.network_is_trainable:
# not any of the above but is trainable? -> create an operation for applying the gradients to
# this network weights
self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
zip(self.weights_placeholders, self.weights), global_step=self.global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)

with tf.control_dependencies(update_ops):
self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
zip(self.weights_placeholders, self.weights), global_step=self.global_step)

def set_session(self, sess):
self.sess = sess
@@ -414,13 +417,16 @@ def create_feed_dict(self, inputs):

return feed_dict

def apply_and_reset_gradients(self, gradients, scaler=1.):
def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
"""
Applies the given gradients to the network weights and resets the accumulation placeholder
:param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them
:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
update ops also require the inputs)
"""
self.apply_gradients(gradients, scaler)
self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
self.reset_accumulated_gradients()

def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False
self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
self.sess.run(self.release_init)

def apply_gradients(self, gradients, scaler=1.):
def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
"""
Applies the given gradients to the network weights
:param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them.
The gradients will be MULTIPLIED by this factor
:param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
update ops also require the inputs)
"""

if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
if hasattr(self, 'global_step') and not self.network_is_local:
self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ def apply_gradients(self, gradients, scaler=1.):
# async distributed training / distributed training with independent optimizer
# / non-distributed training - just apply the gradients
feed_dict = dict(zip(self.weights_placeholders, gradients))
if additional_inputs is not None:
feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)

# release barrier
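The `control_dependencies(update_ops)` wrapper above is the standard TF1 recipe for batchnorm: `tf.layers.batch_normalization` registers its moving mean/variance updates in `tf.GraphKeys.UPDATE_OPS`, and those ops read the layer inputs, so the gradient-application op must both depend on them and be fed the inputs. A standalone sketch of the same idea in plain TensorFlow 1.x (illustrative, not Coach code):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name='x')
is_training = tf.placeholder(tf.bool, name='is_training')
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=is_training)
loss = tf.reduce_mean(tf.square(h))

# Group the batchnorm update ops with the optimizer step, mirroring the
# control_dependencies block added above.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The update ops read x, so the feed_dict must include the inputs - the
    # reason apply_gradients() above accepts additional_inputs.
    sess.run(train_op, feed_dict={x: [[0., 1., 2., 3.]], is_training: True})
```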
Original file line number Diff line number Diff line change
@@ -185,6 +185,7 @@ def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderPara

embedder_path = embedder_params.path(emb_type)
embedder_params_copy = copy.copy(embedder_params)
embedder_params_copy.is_training = self.is_training
embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ def get_middleware(self, middleware_params: MiddlewareParameters):
middleware_path = middleware_params.path
middleware_params_copy = copy.copy(middleware_params)
middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
middleware_params_copy.is_training = self.is_training
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
return module

@@ -218,6 +220,7 @@ def get_output_head(self, head_params: HeadParameters, head_idx: int):
head_path = head_params.path
head_params_copy = copy.copy(head_params)
head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
head_params_copy.is_training = self.is_training
return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
'head_idx': head_idx, 'is_local': self.network_is_local})
@@ -339,7 +342,11 @@ def get_model(self) -> List:
head_count += 1

# model weights
self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
if not self.distributed_training or self.network_is_global:
self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
'global_step' not in var.name]
else:
self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]

# Losses
self.losses = tf.losses.get_losses(self.full_name)
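The switch from `TRAINABLE_VARIABLES` to `GLOBAL_VARIABLES` (minus `global_step`) matters because batchnorm's moving statistics are created as non-trainable variables; a trainable-only weight list would never copy them when syncing the online, target and global networks. A small standalone check in plain TF1 (illustrative only):

```python
import tensorflow as tf

with tf.variable_scope('net'):
    x = tf.placeholder(tf.float32, [None, 4])
    h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)

trainable = {v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='net')}
all_vars = {v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net')}

# Expected difference: batch_normalization's moving_mean / moving_variance,
# which are non-trainable and would be skipped by a TRAINABLE_VARIABLES-only sync.
print(sorted(all_vars - trainable))
```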
Original file line number Diff line number Diff line change
@@ -26,9 +26,9 @@
class DDPGActor(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
batchnorm: bool=True, dense_layer=Dense):
batchnorm: bool=True, dense_layer=Dense, is_training=False):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer)
dense_layer=dense_layer, is_training=is_training)
self.name = 'ddpg_actor_head'
self.return_type = ActionProbabilities

@@ -50,7 +50,7 @@ def _build_module(self, input_layer):
batchnorm=self.batchnorm,
activation_function=self.activation_function,
dropout_rate=0,
is_training=False,
is_training=self.is_training,
name="BatchnormActivationDropout_0")[-1]
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')

3 changes: 2 additions & 1 deletion rl_coach/architectures/tensorflow_components/heads/head.py
@@ -40,7 +40,7 @@ class Head(object):
"""
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
dense_layer=Dense):
dense_layer=Dense, is_training=False):
self.head_idx = head_idx
self.network_name = network_name
self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
self.dense_layer = Dense
else:
self.dense_layer = convert_layer_class(self.dense_layer)
self.is_training = is_training

def __call__(self, input_layer):
"""
3 changes: 3 additions & 0 deletions rl_coach/architectures/tensorflow_components/layers.py
@@ -26,6 +26,9 @@
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
layers = [input_layer]

# Rationale: passing a bool here will mean that batchnorm and/or activation will never activate
assert not isinstance(is_training, bool)

# batchnorm
if batchnorm:
layers.append(
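The new assert enforces that `is_training` arrives as a tensor (e.g. a placeholder or variable) rather than a Python bool, so train/eval behaviour can be switched at run time instead of being frozen at graph-construction time. A hedged usage sketch of the helper's signature as shown above:

```python
import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout

input_layer = tf.placeholder(tf.float32, [None, 16])
is_training = tf.placeholder(tf.bool, name='is_training')  # a tensor, not a Python bool

# Passing a plain bool (e.g. is_training=False) would now trip the assert.
output = batchnorm_activation_dropout(input_layer, batchnorm=True,
                                      activation_function=tf.nn.relu,
                                      dropout_rate=0, is_training=is_training,
                                      name='BatchnormActivationDropout_0')[-1]
```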
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.core_types import Middleware_FC_Embedding
