From d34aab11398a399896e1f9014ffb5db431eb16b8 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 31 Jan 2021 08:41:09 -0800 Subject: [PATCH] make sure auxiliary loss actually optimizes the compression net --- .../compressive_transformer_pytorch.py | 6 ++++-- setup.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/compressive_transformer_pytorch/compressive_transformer_pytorch.py b/compressive_transformer_pytorch/compressive_transformer_pytorch.py index 8fb9fb5..33c4caf 100644 --- a/compressive_transformer_pytorch/compressive_transformer_pytorch.py +++ b/compressive_transformer_pytorch/compressive_transformer_pytorch.py @@ -260,7 +260,7 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me if old_mem.shape[1] == 0 or self.cmem_len <= 0: return logits, Memory(new_mem, new_cmem), aux_loss - compressed_mem = self.compress_mem_fn(old_mem) + compressed_mem = self.compress_mem_fn(old_mem.detach()) old_cmem, new_cmem = split_at_index(1, -self.cmem_len, torch.cat((cmem, compressed_mem), dim=1)) if not self.training: @@ -268,6 +268,8 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me # calculate compressed memory auxiliary loss if training + self.to_kv.weight.detach_() + cmem_k, cmem_v = self.to_kv(compressed_mem).chunk(2, dim=-1) cmem_k, cmem_v = map(merge_heads, (cmem_k, cmem_v)) cmem_k, cmem_v = map(lambda x: x.expand(-1, h, -1, -1), (cmem_k, cmem_v)) @@ -275,7 +277,7 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me old_mem_range = slice(- min(mem_len, self.mem_len) - self.seq_len, -self.seq_len) old_mem_k, old_mem_v = map(lambda x: x[:, :, old_mem_range].clone(), (k, v)) - q, old_mem_k, old_mem_v, cmem_k, cmem_v = map(torch.detach, (q, old_mem_k, old_mem_v, cmem_k, cmem_v)) + q, old_mem_k, old_mem_v = map(torch.detach, (q, old_mem_k, old_mem_v)) attn_fn = partial(full_attn, dropout_fn = self.reconstruction_attn_dropout) diff --git a/setup.py b/setup.py index 97f6ccf..e0f1d9d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'compressive-transformer-pytorch', packages = find_packages(exclude=['examples']), - version = '0.3.20', + version = '0.3.21', license='MIT', description = 'Implementation of Compressive Transformer in Pytorch', author = 'Phil Wang',