diff --git a/README.md b/README.md
index c7b73fd..e66649e 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,34 @@ logits, _, aux_loss = model(segments[1], mask = masks[1], memories = memo
 # memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)
 ```
 
+When training, you can use the `AutoregressiveWrapper` to have memory management across segments taken care of for you. As easy as it gets.
+
+```python
+import torch
+from compressive_transformer_pytorch import CompressiveTransformer
+from compressive_transformer_pytorch import AutoregressiveWrapper
+
+model = CompressiveTransformer(
+    num_tokens = 20000,
+    dim = 512,
+    depth = 6,
+    seq_len = 1024,
+    mem_len = 1024,
+    cmem_len = 256,
+    cmem_ratio = 4,
+    memory_layers = [5,6]
+).cuda()
+
+model = AutoregressiveWrapper(model)
+
+inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()
+
+for loss, aux_loss in model(inputs, return_loss = True):
+    (loss + aux_loss).backward()
+    # optimizer step and zero grad
+```
+
+
 ## Citations
 
 ```bibtex
diff --git a/compressive_transformer_pytorch/compressive_transformer_pytorch.py b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
index be16eee..819c33b 100644
--- a/compressive_transformer_pytorch/compressive_transformer_pytorch.py
+++ b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -184,6 +184,7 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
         mask_value = max_neg_value(dots)
 
         if pos_emb is not None:
+            pos_emb = pos_emb[:, -kv_len:]
             pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
             pos_dots = shift(pos_dots)
             dots = dots + pos_dots
@@ -289,11 +290,13 @@ def forward(self, x, memories = None, mask = None):
         mem_iter = iterate_tensor(mem)
         cmem_iter = iterate_tensor(cmem)
 
-        for ind, (attn, ff, m, c) in enumerate(zip(self.attn_layers, self.ff_layers, mem, cmem)):
+        for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
             layer_num = ind + 1
             use_memory = layer_num in self.memory_layers
 
-            memories = (next(mem_iter), next(cmem_iter)) if use_memory else None
+            memories = None
+            if use_memory:
+                memories = (next(mem_iter), next(cmem_iter))
 
             x, (mem_out, cmem_out), layer_aux_loss = attn(x, memories = memories, calc_memory = use_memory, input_mask = mask, pos_emb = pos_emb)
             x, = ff(x)
diff --git a/setup.py b/setup.py
index 221a682..f647822 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive_transformer_pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
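
A note on the training snippet added in the README hunk above: the `# optimizer step and zero grad` comment leaves the optimizer out. Below is a minimal sketch of how that step might be filled in; the choice of `torch.optim.Adam` and the learning rate are illustrative assumptions, not part of this patch.

```python
import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

# same model configuration as in the README hunk above
model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

# optimizer and learning rate are assumptions for illustration
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

# the wrapper yields one (loss, aux_loss) pair per segment;
# stepping after each segment is what the README comment
# "# optimizer step and zero grad" points at
for loss, aux_loss in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    optimizer.step()
    optimizer.zero_grad()
```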