make sure auxiliary loss actually optimizes the compression net
lucidrains committed Jan 31, 2021
1 parent a73d146 commit d34aab1
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -260,22 +260,24 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
         if old_mem.shape[1] == 0 or self.cmem_len <= 0:
             return logits, Memory(new_mem, new_cmem), aux_loss

-        compressed_mem = self.compress_mem_fn(old_mem)
+        compressed_mem = self.compress_mem_fn(old_mem.detach())
         old_cmem, new_cmem = split_at_index(1, -self.cmem_len, torch.cat((cmem, compressed_mem), dim=1))

         if not self.training:
             return logits, Memory(new_mem, new_cmem), aux_loss

         # calculate compressed memory auxiliary loss if training

+        self.to_kv.weight.detach_()
+
         cmem_k, cmem_v = self.to_kv(compressed_mem).chunk(2, dim=-1)
         cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v))
         cmem_k, cmem_v = map(lambda x: x.expand(-1, h, -1, -1), (cmem_k, cmem_v))

         old_mem_range = slice(- min(mem_len, self.mem_len) - self.seq_len, -self.seq_len)
         old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v))

-        q, old_mem_k, old_mem_v, cmem_k, cmem_v = map(torch.detach, (q, old_mem_k, old_mem_v, cmem_k, cmem_v))
+        q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v))

         attn_fn = partial(full_attn, dropout_fn = self.reconstruction_attn_dropout)
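The substance of this change is which tensors are allowed to carry gradients into the auxiliary reconstruction loss. Previously cmem_k and cmem_v were detached along with everything else, so the loss had no path back to compress_mem_fn and the compression network never trained. After the change, old_mem is detached on the way into compress_mem_fn (keeping the auxiliary loss out of the main language-model graph), the shared to_kv weight is frozen in place with detach_(), and only the compressed-memory keys and values stay attached, leaving the compression network as the sole trainable path. Below is a minimal sketch of that gradient routing, using stand-in modules, toy dimensions, and a simplified attention-reconstruction loss rather than the repository's actual classes:

import torch
import torch.nn.functional as F
from torch import nn

# toy dimensions, purely illustrative
batch, mem_len, dim, ratio = 2, 8, 16, 4

to_kv = nn.Linear(dim, 2 * dim, bias = False)                 # stand-in for self.to_kv
compress_mem_fn = nn.Conv1d(dim, dim, ratio, stride = ratio)  # stand-in for the compression network

old_mem = torch.randn(batch, mem_len, dim, requires_grad = True)  # pretend it sits in the main graph
q = torch.randn(batch, mem_len, dim)

# 1. detach the input, so the auxiliary loss never reaches the main transformer graph
compressed_mem = compress_mem_fn(old_mem.detach().transpose(1, 2)).transpose(1, 2)

# 2. freeze the shared key/value projection in place
to_kv.weight.detach_()

# 3. compressed keys/values stay attached -> gradients can flow back to compress_mem_fn
cmem_k, cmem_v = to_kv(compressed_mem).chunk(2, dim = -1)

# 4. reconstruction targets (and the query) are treated as constants
old_mem_k, old_mem_v = to_kv(old_mem).chunk(2, dim = -1)
q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v))

def full_attn(q, k, v):
    # plain softmax attention, no heads, for illustration only
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * (q.shape[-1] ** -0.5)
    return torch.einsum('b i j, b j d -> b i d', sim.softmax(dim = -1), v)

# attention over compressed memory should reconstruct attention over the raw memory
aux_loss = F.mse_loss(full_attn(q, cmem_k, cmem_v), full_attn(q, old_mem_k, old_mem_v))
aux_loss.backward()

assert all(p.grad is not None for p in compress_mem_fn.parameters())  # compression net trains
assert to_kv.weight.grad is None                                       # projection is untouched
assert old_mem.grad is None                                            # main graph is untouched

Freezing to_kv rather than detaching its output is what keeps the path from compressed_mem to the loss intact while still preventing the auxiliary loss from updating the attention projection.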
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.20',
+  version = '0.3.21',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
