Skip to content

Commit

Permalink
fix update bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
BY571 committed Sep 10, 2021
1 parent d9612af commit 6d2cc9d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
19 changes: 7 additions & 12 deletions CDQL-SAC/agent.py
Original file line number Diff line number Diff line change
@@ -8,9 +8,6 @@
import math
import copy

# inspired by: https://github.com/takuseno/d3rlpy/blob/fd273504c49580ecb11930a330504fe78aee6fd6/d3rlpy/algos/torch/cql_impl.py#L175
# and https://github.com/polixir/OfflineRL/blob/master/offlinerl/algo/modelfree/cql.py


class CQLSAC(nn.Module):
"""Interacts with and learns from the environment."""
@@ -187,16 +184,14 @@ def learn(self, step, experiences, gamma, d=1):

random_values1 = self._compute_random_values(temp_states, random_actions, self.critic1).reshape(states.shape[0], num_repeat, 1)
random_values2 = self._compute_random_values(temp_states, random_actions, self.critic2).reshape(states.shape[0], num_repeat, 1)

current_pi_values1 = current_pi_values1.reshape(states.shape[0], num_repeat, 1)
current_pi_values2 = current_pi_values2.reshape(states.shape[0], num_repeat, 1)
next_pi_values1 = next_pi_values1.reshape(states.shape[0], num_repeat, 1)
next_pi_values2 = next_pi_values2.reshape(states.shape[0], num_repeat, 1)

q1_current_action = self.critic1.get_qvalues(temp_states, current_pi_values1).reshape(states.shape[0], num_repeat, 1)
q2_current_action = self.critic2.get_qvalues(temp_states, current_pi_values2).reshape(states.shape[0], num_repeat, 1)

q1_next_action = self.critic1.get_qvalues(temp_next_states, next_pi_values1).reshape(states.shape[0], num_repeat, 1)
q2_next_action = self.critic2.get_qvalues(temp_next_states, next_pi_values2).reshape(states.shape[0], num_repeat, 1)


cat_q1 = torch.cat([random_values1, q1_current_action, q1_next_action], 1)
cat_q2 = torch.cat([random_values2, q2_current_action, q2_next_action], 1)
cat_q1 = torch.cat([random_values1, current_pi_values1, next_pi_values1], 1)
cat_q2 = torch.cat([random_values2, current_pi_values2, next_pi_values2], 1)

assert cat_q1.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q1 instead has shape: {cat_q1.shape}"
assert cat_q2.shape == (states.shape[0], 3 * num_repeat, 1), f"cat_q2 instead has shape: {cat_q2.shape}"
Expand Down
3 changes: 0 additions & 3 deletions CQL-SAC/agent.py
Original file line number Diff line number Diff line change
@@ -8,9 +8,6 @@
import math
import copy

# inspired by: https://github.com/takuseno/d3rlpy/blob/fd273504c49580ecb11930a330504fe78aee6fd6/d3rlpy/algos/torch/cql_impl.py#L175
# and https://github.com/polixir/OfflineRL/blob/master/offlinerl/algo/modelfree/cql.py


class CQLSAC(nn.Module):
"""Interacts with and learns from the environment."""
Expand Down

0 comments on commit 6d2cc9d

Please sign in to comment.