diff --git a/benchmarks/README.md b/benchmarks/README.md
index fe22f89a5..65e091eec 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -38,6 +38,7 @@ The environments that were used for testing include:
 |**[Clipped PPO](clipped_ppo)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
 |**[DDPG](ddpg)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
 |**[SAC](sac)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
+|**[TD3](td3)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
 |**[NEC](nec)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
 |**[HER](ddpg_her)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Fetch | |
 |**[DFP](dfp)** | ![#ceffad](https://placehold.it/15/ceffad/000000?text=+) |Doom | Doom Battle was not verified |
diff --git a/benchmarks/ddpg/ant_ddpg.png b/benchmarks/ddpg/ant_ddpg.png
index 61678c120..b1c391a1d 100644
Binary files a/benchmarks/ddpg/ant_ddpg.png and b/benchmarks/ddpg/ant_ddpg.png differ
diff --git a/benchmarks/ddpg/half_cheetah_ddpg.png b/benchmarks/ddpg/half_cheetah_ddpg.png
index 9b6689f76..ed42a2524 100644
Binary files a/benchmarks/ddpg/half_cheetah_ddpg.png and b/benchmarks/ddpg/half_cheetah_ddpg.png differ
diff --git a/benchmarks/ddpg/hopper_ddpg.png b/benchmarks/ddpg/hopper_ddpg.png
index 18061be51..6bafe539d 100644
Binary files a/benchmarks/ddpg/hopper_ddpg.png and b/benchmarks/ddpg/hopper_ddpg.png differ
diff --git a/benchmarks/ddpg/humanoid_ddpg.png b/benchmarks/ddpg/humanoid_ddpg.png
index ba73d2fa2..30b7a5946 100644
Binary files a/benchmarks/ddpg/humanoid_ddpg.png and b/benchmarks/ddpg/humanoid_ddpg.png differ
diff --git a/benchmarks/ddpg/inverted_double_pendulum_ddpg.png b/benchmarks/ddpg/inverted_double_pendulum_ddpg.png
index 519da9e1b..508e67bdd 100644
Binary files a/benchmarks/ddpg/inverted_double_pendulum_ddpg.png and b/benchmarks/ddpg/inverted_double_pendulum_ddpg.png differ
diff --git a/benchmarks/ddpg/inverted_pendulum_ddpg.png b/benchmarks/ddpg/inverted_pendulum_ddpg.png
index bd064a8ac..78ccd5923 100644
Binary files a/benchmarks/ddpg/inverted_pendulum_ddpg.png and b/benchmarks/ddpg/inverted_pendulum_ddpg.png differ
diff --git a/benchmarks/ddpg/reacher_ddpg.png b/benchmarks/ddpg/reacher_ddpg.png
index 114d9cd83..a409ab751 100644
Binary files a/benchmarks/ddpg/reacher_ddpg.png and b/benchmarks/ddpg/reacher_ddpg.png differ
diff --git a/benchmarks/ddpg/swimmer_ddpg.png b/benchmarks/ddpg/swimmer_ddpg.png
index 3e04fd717..7e41528f9 100644
Binary files a/benchmarks/ddpg/swimmer_ddpg.png and b/benchmarks/ddpg/swimmer_ddpg.png differ
diff --git a/benchmarks/ddpg/walker2d_ddpg.png b/benchmarks/ddpg/walker2d_ddpg.png
index 50efd3c14..2734dd79c 100644
Binary files a/benchmarks/ddpg/walker2d_ddpg.png and b/benchmarks/ddpg/walker2d_ddpg.png differ
diff --git a/rl_coach/agents/ddpg_agent.py b/rl_coach/agents/ddpg_agent.py
index be5831991..dbf3821ef 100644
--- a/rl_coach/agents/ddpg_agent.py
+++ b/rl_coach/agents/ddpg_agent.py
@@ -34,9 +34,9 @@


 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ def __init__(self):


 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ def __init__(self):


 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
@@ -170,7 +170,9 @@ def learn_from_batch(self, batch):
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+
+        # the inputs are also needed when applying the gradients, so that batchnorm's running mean and stddev are updated
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
@@ -179,11 +181,12 @@ def learn_from_batch(self, batch):
                                                  outputs=actor.online_network.weighted_gradients[0],
                                                  initial_feed_dict=initial_feed_dict)

+        # the inputs are also needed when applying the gradients, so that batchnorm's running mean and stddev are updated
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads
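The change above makes batchnorm opt-in for DDPG: a single `use_batchnorm` flag now fans out from `DDPGAgentParameters` to the actor and critic networks. A minimal sketch of enabling it, assuming this patch is applied and the `network_wrappers` attribute layout that `AgentParameters` uses; the assertions are illustrative, not part of the patch:

```python
# Sketch: opting back into batchnorm after this change.
from rl_coach.agents.ddpg_agent import DDPGAgentParameters

# batchnorm is now off by default; one flag configures both networks
agent_params = DDPGAgentParameters(use_batchnorm=True)

# the flag propagates to the input embedders, middleware, and actor head
assert agent_params.network_wrappers['actor'].middleware_parameters.batchnorm
assert agent_params.network_wrappers['critic'].input_embedders_parameters['observation'].batchnorm
```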
diff --git a/rl_coach/architectures/head_parameters.py b/rl_coach/architectures/head_parameters.py
index 981251b7f..1c64b63af 100644
--- a/rl_coach/architectures/head_parameters.py
+++ b/rl_coach/architectures/head_parameters.py
@@ -22,7 +22,7 @@ class HeadParameters(NetworkComponentParameters):
     def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head',
                  num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
-                 loss_weight: float=1.0, dense_layer=None):
+                 loss_weight: float=1.0, dense_layer=None, is_training=False):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.name = name
@@ -30,6 +30,7 @@ def __init__(self, parameterized_class_name: str, activation_function: str = 're
         self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
         self.loss_weight = loss_weight
         self.parameterized_class_name = parameterized_class_name
+        self.is_training = is_training

     @property
     def path(self):
diff --git a/rl_coach/architectures/network_wrapper.py b/rl_coach/architectures/network_wrapper.py
index 26f2920c4..dfefc4122 100644
--- a/rl_coach/architectures/network_wrapper.py
+++ b/rl_coach/architectures/network_wrapper.py
@@ -124,31 +124,37 @@ def update_online_network(self, rate=1.0):
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)

-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients

         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
+        :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients

-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.

@@ -157,14 +163,20 @@ def train_and_sync_networks(self, inputs, targets, additional_fetches=[], import
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
             error of this sample. If it is not given, the samples losses won't be scaled
+        :param use_inputs_for_apply_gradients: Add the inputs also when applying the gradients
+            (e.g. for incorporating batchnorm update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                           importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)
+
         return result

-    def apply_gradients_and_sync_networks(self, reset_gradients=True):
+    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
@@ -173,17 +185,22 @@ def apply_gradients_and_sync_networks(self, reset_gradients=True):
             the network. this is useful when the accumulated gradients are overwritten instead of being
             accumulated by the accumulate_gradients function. this allows reducing time complexity for
             this function by around 10%
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
+
         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)

             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()

             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                              additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)

     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """
diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py
index 907593658..68420febb 100644
--- a/rl_coach/architectures/tensorflow_components/architecture.py
+++ b/rl_coach/architectures/tensorflow_components/architecture.py
@@ -270,8 +270,11 @@ def _create_gradient_applying_ops(self):
         elif self.network_is_trainable:
            # not any of the above, but is trainable? -> create an operation for applying the gradients to
            # this network's weights
-            self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
-                zip(self.weights_placeholders, self.weights), global_step=self.global_step)
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+
+            with tf.control_dependencies(update_ops):
+                self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+                    zip(self.weights_placeholders, self.weights), global_step=self.global_step)

     def set_session(self, sess):
         self.sess = sess
@@ -414,13 +417,16 @@ def create_feed_dict(self, inputs):
         return feed_dict

-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
+
         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()

     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False
         self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
         self.sess.run(self.release_init)

-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them. The gradients will
             be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
         """
+
         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ def apply_gradients(self, gradients, scaler=1.):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)

             # release barrier
diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py
index 01036590d..8821ac6cc 100644
--- a/rl_coach/architectures/tensorflow_components/general_network.py
+++ b/rl_coach/architectures/tensorflow_components/general_network.py
@@ -185,6 +185,7 @@ def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderPara
         embedder_path = embedder_params.path(emb_type)
         embedder_params_copy = copy.copy(embedder_params)
+        embedder_params_copy.is_training = self.is_training
         embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
         embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
         embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ def get_middleware(self, middleware_params: MiddlewareParameters):
         middleware_path = middleware_params.path
         middleware_params_copy = copy.copy(middleware_params)
         middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
+        middleware_params_copy.is_training = self.is_training
         module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)

         return module
@@ -218,6 +220,7 @@ def get_output_head(self, head_params: HeadParameters, head_idx: int):
         head_path = head_params.path
         head_params_copy = copy.copy(head_params)
         head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
+        head_params_copy.is_training = self.is_training
         return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
             'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
             'head_idx': head_idx, 'is_local': self.network_is_local})
@@ -339,7 +342,11 @@ def get_model(self) -> List:
             head_count += 1

         # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+        if not self.distributed_training or self.network_is_global:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
+                            'global_step' not in var.name]
+        else:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]

         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
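The `architecture.py` and `general_network.py` hunks above carry the core of the fix: in TF 1.x, `tf.layers.batch_normalization` registers its moving mean/variance updates under `tf.GraphKeys.UPDATE_OPS`, and those ops only run if the training op depends on them and the layer inputs are fed in the same `session.run`. A standalone sketch of that pattern, assuming TF 1.x (illustrative, not Coach code):

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x, as used by Coach

x = tf.placeholder(tf.float32, [None, 8])
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])

# batchnorm adds its moving mean/variance update ops to UPDATE_OPS
h = tf.layers.batch_normalization(tf.layers.dense(x, 16), training=is_training)
loss = tf.reduce_mean(tf.square(h))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # without the control dependency the moving averages are never updated;
    # since the update ops read the layer's inputs, x must be fed when this op runs
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    sess.run(train_op, feed_dict={x: np.random.randn(4, 8).astype(np.float32)})
```

This is why `apply_gradients` now merges `create_feed_dict(additional_inputs)` into the feed: the update ops hang off the forward graph, so the inputs must be fed when the gradient-applying op runs.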
diff --git a/rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py b/rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py
index 45545b45c..d17353c1a 100644
--- a/rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/ddpg_actor_head.py
@@ -26,9 +26,9 @@ class DDPGActor(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
-                 batchnorm: bool=True, dense_layer=Dense):
+                 batchnorm: bool=True, dense_layer=Dense, is_training=False):
         super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
-                         dense_layer=dense_layer)
+                         dense_layer=dense_layer, is_training=is_training)
         self.name = 'ddpg_actor_head'
         self.return_type = ActionProbabilities

@@ -50,7 +50,7 @@ def _build_module(self, input_layer):
                                                      batchnorm=self.batchnorm,
                                                      activation_function=self.activation_function,
                                                      dropout_rate=0,
-                                                     is_training=False,
+                                                     is_training=self.is_training,
                                                      name="BatchnormActivationDropout_0")[-1]
         self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
diff --git a/rl_coach/architectures/tensorflow_components/heads/head.py b/rl_coach/architectures/tensorflow_components/heads/head.py
index d99744240..e971889e9 100644
--- a/rl_coach/architectures/tensorflow_components/heads/head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/head.py
@@ -40,7 +40,7 @@ class Head(object):
     """
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
-                 dense_layer=Dense):
+                 dense_layer=Dense, is_training=False):
         self.head_idx = head_idx
         self.network_name = network_name
         self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition,
             self.dense_layer = Dense
         else:
             self.dense_layer = convert_layer_class(self.dense_layer)
+        self.is_training = is_training

     def __call__(self, input_layer):
         """
diff --git a/rl_coach/architectures/tensorflow_components/layers.py b/rl_coach/architectures/tensorflow_components/layers.py
index 81a1992c7..eb6326234 100644
--- a/rl_coach/architectures/tensorflow_components/layers.py
+++ b/rl_coach/architectures/tensorflow_components/layers.py
@@ -26,6 +26,9 @@ def batchnorm_activation_dropout(input_layer, batchnorm, activation_function,
                                  dropout_rate, is_training, name):
     layers = [input_layer]

+    # Rationale: a plain bool here would freeze the training/inference mode at graph build time, so batchnorm and/or dropout could never be toggled
+    assert not isinstance(is_training, bool)
+
     # batchnorm
     if batchnorm:
         layers.append(
diff --git a/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py
index 61d340b76..4361e171d 100644
--- a/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py
+++ b/rl_coach/architectures/tensorflow_components/middlewares/fc_middleware.py
@@ -17,7 +17,7 @@

 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_FC_Embedding
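The new assert in `layers.py` guards the same mechanism from the other side: a Python bool passed as `is_training` is baked into the graph at construction time, while a tensor lets a single graph switch between training and inference behavior. A short sketch of the difference, again assuming TF 1.x (illustrative only):

```python
import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [None, 4])

# Frozen: with a Python bool, the mode can never change without rebuilding the graph.
frozen = tf.layers.batch_normalization(x, training=True)

# Switchable: a non-trainable variable can be flipped at runtime.
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
switchable = tf.layers.batch_normalization(x, training=is_training)

# toggle between modes without touching the graph structure
set_train_mode = tf.assign(is_training, True)
set_eval_mode = tf.assign(is_training, False)
```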
diff --git a/rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py b/rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py
index 6b7f97d8e..6ca9cd7dd 100644
--- a/rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py
+++ b/rl_coach/architectures/tensorflow_components/middlewares/lstm_middleware.py
@@ -18,7 +18,7 @@
 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_LSTM_Embedding
diff --git a/rl_coach/tests/architectures/tensorflow_components/embedders/test_image_embedder.py b/rl_coach/tests/architectures/tensorflow_components/embedders/test_image_embedder.py
index c4fa08b25..65076d114 100644
--- a/rl_coach/tests/architectures/tensorflow_components/embedders/test_image_embedder.py
+++ b/rl_coach/tests/architectures/tensorflow_components/embedders/test_image_embedder.py
@@ -25,17 +25,20 @@ def test_embedder(reset):
     with pytest.raises(ValueError):
         embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")

+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
     # creating a simple image embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)

-    # make sure the ops where not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    # make sure that only the is_training ops were created
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 100, 100, 10)
@@ -55,7 +58,9 @@ def test_embedder(reset):
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep image embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
+                             is_training=is_training)

     # call the embedder
     embedder()
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep image embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
-                             activation_function=tf.nn.relu)
+                             activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -86,7 +92,7 @@ def test_activation_function(reset):

     # creating a deep image embedder with tanh
     embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                  activation_function=tf.nn.tanh)
+                                  activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()
diff --git a/rl_coach/tests/architectures/tensorflow_components/embedders/test_vector_embedder.py b/rl_coach/tests/architectures/tensorflow_components/embedders/test_vector_embedder.py
index 4ca436986..73482f918 100644
--- a/rl_coach/tests/architectures/tensorflow_components/embedders/test_vector_embedder.py
+++ b/rl_coach/tests/architectures/tensorflow_components/embedders/test_vector_embedder.py
@@ -22,16 +22,19 @@ def test_embedder(reset):
         embedder = VectorEmbedder(np.array([10, 10]), name="test")

     # creating a simple vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test")
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
+    embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)

     # make sure the ops were not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 10)
@@ -51,7 +54,8 @@ def test_embedder(reset):
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)

     # call the embedder
     embedder()
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep vector embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
-                              activation_function=tf.nn.relu)
+                              activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -82,7 +87,7 @@ def test_activation_function(reset):

     # creating a deep vector embedder with tanh
     embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                   activation_function=tf.nn.tanh)
+                                   activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()
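Both embedder test files now follow the same pattern: create the `is_training` variable first, snapshot the operation count, and verify that the embedder stays lazy until it is called. A condensed sketch of that pattern; the `VectorEmbedder` import path is an assumption about the test module layout, while the calls mirror the tests above:

```python
import numpy as np
import tensorflow as tf  # TF 1.x

# import path assumed from the repository layout
from rl_coach.architectures.tensorflow_components.embedders.vector_embedder import VectorEmbedder

# creating is_training adds ops of its own, hence the pre_ops baseline instead of comparing to 0
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
pre_ops = len(tf.get_default_graph().get_operations())

embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)
assert len(tf.get_default_graph().get_operations()) == pre_ops  # graph is built lazily

input_ph, output_ph = embedder()  # ops are created only on call
assert len(tf.get_default_graph().get_operations()) > pre_ops
```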