diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index a2f3cc7f1c..93bdd57120 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -435,11 +435,13 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256 else: # Use the new net_arch parameter if layers is not None: warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.") - if feature_extraction == "cnn": - raise NotImplementedError() with tf.variable_scope("model", reuse=reuse): - latent = tf.layers.flatten(self.processed_obs) + if feature_extraction == "cnn": + latent = cnn_extractor(self.processed_obs, **kwargs) + latent = tf.layers.flatten(latent) + else: + latent = tf.layers.flatten(self.processed_obs) policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network value_only_layers = [] # Layer sizes of the network that only belongs to the value network diff --git a/tests/test_cnn_lstm_policy.py b/tests/test_cnn_lstm_policy.py new file mode 100644 index 0000000000..58cfef3ca3 --- /dev/null +++ b/tests/test_cnn_lstm_policy.py @@ -0,0 +1,71 @@ +import os + +import numpy as np +import tensorflow as tf +import pytest +from gym import make +from gym.wrappers.time_limit import TimeLimit + +from stable_baselines.ppo2 import PPO2 +from stable_baselines.common.policies import CnnLstmPolicy +from stable_baselines.common.evaluation import evaluate_policy +from stable_baselines.common.tf_layers import conv, linear, conv_to_fc + + +def custom_cnn_extractor(input_images): + activ = tf.nn.relu + layer_1 = activ(conv(input_images, 'c1', n_filters=8, filter_size=3, stride=1, init_scale=np.sqrt(2))) + layer_2 = activ(conv(layer_1, 'c2', n_filters=8, filter_size=3, stride=1, init_scale=np.sqrt(2))) + layer_2 = conv_to_fc(layer_2) + return activ(linear(layer_2, 'fc1', n_hidden=256, init_scale=np.sqrt(2))) + + +class CustomCnnLstmPolicy1(CnnLstmPolicy): + def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=32, reuse=False, **_kwargs): + super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse, + cnn_extractor=custom_cnn_extractor, **_kwargs) + + +class CustomCnnLstmPolicy2(CnnLstmPolicy): + def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=32, reuse=False, **_kwargs): + super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse, + net_arch=['lstm', 8], cnn_extractor=custom_cnn_extractor, **_kwargs) + + +class CustomCnnLstmPolicy3(CnnLstmPolicy): + def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=32, reuse=False, **_kwargs): + super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse, + net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])], + cnn_extractor=custom_cnn_extractor, **_kwargs) + + +POLICIES = [CnnLstmPolicy, CustomCnnLstmPolicy1, CustomCnnLstmPolicy2, CustomCnnLstmPolicy3] + + +def make_env(i): + env = make("Breakout-v0") + env = TimeLimit(env, max_episode_steps=20) + env.seed(i) + return env + + +@pytest.mark.parametrize("policy", POLICIES) +@pytest.mark.expensive +def test_cnn_lstm_policy(request, policy): + model_fname = './test_model_{}.zip'.format(request.node.name) + + try: + env = make_env(0) + model = PPO2(policy, env, nminibatches=1) + model.learn(total_timesteps=15) + env = model.get_env() + evaluate_policy(model, env, n_eval_episodes=5) + # saving + model.save(model_fname) + del model, env + # loading + _ = PPO2.load(model_fname, policy=policy) + + finally: + if os.path.exists(model_fname): + os.remove(model_fname)