diff --git a/agent/proto.py b/agent/proto.py index c59087d..8598169 100644 --- a/agent/proto.py +++ b/agent/proto.py @@ -109,7 +109,7 @@ def compute_intr_reward(self, obs, step): # enqueue candidates ptr = self.queue_ptr self.queue[ptr:ptr + self.num_protos] = z[candidates] - queue_ptr = (ptr + self.num_protos) % self.queue.shape[0] + self.queue_ptr = (ptr + self.num_protos) % self.queue.shape[0] # compute distances between the batch and the queue of candidates z_to_q = torch.norm(z[:, None, :] - self.queue[None, :, :], dim=2, p=2)