Skip to content

Commit

Permalink
sym
Browse files Browse the repository at this point in the history
  • Loading branch information
symoon11 committed Jan 3, 2022
1 parent a3e4884 commit c21d8aa
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions lifelong_rl/trainers/q_learning/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,20 @@ def train_from_torch(self, batch, indices):
qfs_loss_total = qfs_loss

if self.eta > 0:
qs_pred_grads = None
sample_size = min(qs_pred.size(0), actions.size(1))
indices = np.random.choice(qs_pred.size(0), size=sample_size, replace=False)
indices = torch.from_numpy(indices).long().to(ptu.device)

obs_tile = obs.unsqueeze(0).repeat(self.num_qs, 1, 1)
actions_tile = actions.unsqueeze(0).repeat(self.num_qs, 1, 1).requires_grad_(True)
qs_preds_tile = self.qfs(obs_tile, actions_tile)
qs_pred_grads, = torch.autograd.grad(qs_preds_tile.sum(), actions_tile, retain_graph=True, create_graph=True)
qs_pred_grads = qs_pred_grads / (torch.norm(qs_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)

qs_pred_grads = torch.index_select(qs_pred_grads, dim=0, index=indices).transpose(0, 1)
qs_pred_grads = qs_pred_grads.transpose(0, 1)

qs_pred_grads = torch.einsum('bik,bjk->bij', qs_pred_grads, qs_pred_grads)
masks = torch.eye(sample_size, device=ptu.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(0), 1, 1)
masks = torch.eye(self.num_qs, device=ptu.device).unsqueeze(dim=0).repeat(qs_pred_grads.size(0), 1, 1)
qs_pred_grads = (1 - masks) * qs_pred_grads
grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(1, 2))) / (sample_size - 1)
grad_loss = torch.mean(torch.sum(qs_pred_grads, dim=(1, 2))) / (self.num_qs - 1)

qfs_loss_total += self.eta * grad_loss


if self.use_automatic_entropy_tuning and not self.deterministic_backup:
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
Expand Down

0 comments on commit c21d8aa

Please sign in to comment.