Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 24, 2020
1 parent 1f9ddd5 commit a73d146
Showing 1 changed file with 11 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def split_at_index(dim, index, t):
r = (*pre_slices, slice(index, None))
return t[l], t[r]

def queue_fifo(*args, length, dim=-2):
queue = torch.cat(args, dim=dim)
if length > 0:
return split_at_index(dim, -length, queue)

device = queue.device
shape = list(queue.shape)
shape[dim] = 0
return queue, torch.empty(shape, device = device)

def shift(x):
*_, i, j = x.shape
zero_pad = torch.zeros((*_, i, i), **to(x))
Expand Down Expand Up @@ -241,7 +251,7 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me

# calculate memory and compressed memory

old_mem, new_mem = split_at_index(1, -self.mem_len, torch.cat((mem, x), dim=1))
old_mem, new_mem = queue_fifo(mem, x, length = self.mem_len, dim = 1)
old_mem_padding = old_mem.shape[1] % self.cmem_ratio

if old_mem_padding != 0:
Expand Down

0 comments on commit a73d146

Please sign in to comment.