
Commit 41d810e: fix bugs
lucidrains committed on Jul 4, 2020
1 parent: 57080b8
Showing 2 changed files with 4 additions and 4 deletions.

@@ -198,10 +198,11 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
 
         if input_mask is not None:
             mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
-            mask = F.pad(mask, (mem_len + cmem_len, 0), value = False)
+            mask = F.pad(mask, (mem_len + cmem_len, 0), value = True)
             dots.masked_fill_(~mask, mask_value)
 
-        mask = torch.ones(t, kv_len, **to(x)).triu_(diagonal = 1 + kv_len).bool()
+        total_mem_len = mem_len + cmem_len
+        mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal = 1 + total_mem_len).bool()
         dots.masked_fill_(mask[None, None, ...], mask_value)
 
         attn = dots.softmax(dim=-1)

@@ -248,7 +249,6 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
             full_attn(q, cmem_k, cmem_v)
         )
 
-
         return logits, Memory(new_mem, new_cmem), aux_loss
 
 # transformer
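
The two hunks above fix the attention masks in forward. Below is a minimal sketch of the corrected behavior, using toy sizes that are hypothetical and chosen only for illustration (the repo's **to(x) device/dtype helper is omitted). Padding the input mask with value = True keeps the memory and compressed-memory positions attendable, so only genuinely padded tokens are masked out; the old value = False blanked every memory slot for every query. And if kv_len equals t + total_mem_len, the old triu_ diagonal of 1 + kv_len on a (t, kv_len) matrix kept no entries at all, so causal masking was silently a no-op; the new diagonal of 1 + total_mem_len lets token i attend to all memories plus tokens 0..i.

import torch
import torch.nn.functional as F

# Toy sizes, hypothetical and for illustration only.
b, t, mem_len, cmem_len = 1, 4, 2, 1
total_mem_len = mem_len + cmem_len
kv_len = t + total_mem_len

dots = torch.zeros(b, 1, t, kv_len)   # attention logits: t queries vs (memories + t tokens)
mask_value = float('-inf')

# Input mask where the last token is padding.
input_mask = torch.tensor([[True, True, True, False]])
mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
# Left-pad with True so the memory positions stay attendable; the old
# value = False masked out every memory slot for every query.
mask = F.pad(mask, (total_mem_len, 0), value = True)
dots.masked_fill_(~mask, mask_value)

# Causal mask shifted by the memory length: row i masks keys j >= i + 1 + total_mem_len,
# so token i sees all memories plus tokens 0..i. The old diagonal = 1 + kv_len on a
# (t, kv_len) matrix kept no entries, i.e. nothing was ever masked.
causal = torch.ones(t, t + total_mem_len).triu_(diagonal = 1 + total_mem_len).bool()
dots.masked_fill_(causal[None, None, ...], mask_value)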

setup.py: 2 changes (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.8',
+  version = '0.3.10',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
