Skip to content

Commit

Permalink
fix bug in SAC
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokarev-TT-33 authored May 19, 2020
1 parent 0840a82 commit 3fb5210
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions rlzoo/algorithms/sac/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def classic_control(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down Expand Up @@ -110,6 +111,7 @@ def box2d(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down Expand Up @@ -165,6 +167,7 @@ def mujoco(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down Expand Up @@ -220,6 +223,7 @@ def robotics(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down Expand Up @@ -275,6 +279,7 @@ def dm_control(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down Expand Up @@ -330,6 +335,7 @@ def rlbench(env, default_seed=True):
with tf.name_scope('Policy'):
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
hidden_dim_list=num_hidden_layer * [hidden_dim],
output_activation=None,
state_conditioned=True)
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
alg_params['net_list'] = net_list
Expand Down

0 comments on commit 3fb5210

Please sign in to comment.