Commit: reduce code
BY571 committed Oct 21, 2021
1 parent 18337eb commit cfbc71b
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions CQL-SAC-discrete/agent.py
@@ -91,20 +91,6 @@ def calc_policy_loss(self, states, alpha):
         actor_loss = (action_probs * (alpha.to(self.device) * log_pis - min_Q )).sum(1).mean()
         log_action_pi = torch.sum(log_pis * action_probs, dim=1)
         return actor_loss, log_action_pi
 
-    def _compute_policy_values(self, obs_pi, obs_q):
-        with torch.no_grad():
-            _, action_probs, log_pis = self.actor_local.evaluate(obs_pi)
-
-        qs1 = self.critic1(obs_q) * action_probs
-        qs2 = self.critic2(obs_q) * action_probs
-
-        return (qs1-log_pis).sum(1), (qs2-log_pis).sum(1)
-
-    def _compute_random_values(self, obs, random_action_probs, critic):
-        random_values = (critic(obs) * random_action_probs).sum(1)
-        random_log_probs = math.log(0.5 ** self.action_size)
-        return random_values - random_log_probs
-
     def learn(self, step, experiences, gamma, d=1):
         """Updates actor, critics and entropy_alpha parameters using given batch of experience tuples.
@@ -138,17 +124,6 @@ def learn(self, step, experiences, gamma, d=1):
         # ---------------------------- update critic ---------------------------- #
         # Get predicted next-state actions and Q values from target models
         with torch.no_grad():
-            # _, action_probs, log_pis = self.actor_local.evaluate(next_states)
-            # action_probs_ = action_probs.unsqueeze(1).repeat(1, 10, 1)
-            # print()
-            # temp_next_states = next_states.unsqueeze(1).repeat(1, 10, 1).view(next_states.shape[0] * 10, next_states.shape[1])
-            # Q_target1_next = (self.critic1_target(temp_next_states).view(states.shape[0], 10, self.action_size) * action_probs_).max(1)[0].view(-1, 1)
-            # Q_target2_next = (self.critic2_target(temp_next_states).view(states.shape[0], 10, self.action_size) * action_probs_).max(1)[0].view(-1, 1)
-            # Q_target_next = action_probs * (torch.min(Q_target1_next, Q_target2_next) - self.alpha.to(self.device) * log_pis)
-            # # Compute Q targets for current states (y_i)
-            # Q_targets = rewards + (gamma * (1 - dones) * Q_target_next.sum(dim=1).unsqueeze(-1))
-
-            #### old
             _, action_probs, log_pis = self.actor_local.evaluate(next_states)
             Q_target1_next = self.critic1_target(next_states)
             Q_target2_next = self.critic2_target(next_states)
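The kept path is truncated by the diff just after the target critics are queried, but the deleted comments above spell out how the discrete soft Q-target is formed from those tensors. The sketch below is a standalone reconstruction of that target for reference; the function wrapper and argument names are assumptions, and only the formula mirrors the commented-out lines.

import torch

def soft_q_target(rewards, dones, gamma, alpha, action_probs, log_pis,
                  q1_target_next, q2_target_next):
    # Expectation over the policy of min(Q1, Q2) - alpha * log pi(a|s'),
    # followed by the usual one-step bootstrap.
    # action_probs, log_pis, q*_target_next: [batch, num_actions]
    # rewards, dones: [batch, 1]
    q_next = action_probs * (torch.min(q1_target_next, q2_target_next) - alpha * log_pis)
    return rewards + gamma * (1 - dones) * q_next.sum(dim=1, keepdim=True)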
