
Commit

update readme, and fix bug with variable memory layers
lucidrains committed Jul 3, 2020
1 parent 95edcd4 commit e82fe1f
Showing 3 changed files with 34 additions and 3 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -41,6 +41,34 @@ logits, _, aux_loss = model(segments[1], mask = masks[1], memories = memo
# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)
```

When training, you can use the `AutoregressiveWrapper` to have memory management across segments taken care of for you. As easy as it gets.

```python
import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad
```
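For a fuller picture of the segment-level loop above, here is a minimal training sketch built around the same wrapper; the Adam optimizer, learning rate, and gradient clipping are illustrative choices, not part of the library.

```python
import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

# same model as above, wrapped so memories are managed segment by segment
model = AutoregressiveWrapper(CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
)).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)  # illustrative hyperparameters

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

# the wrapper yields one (loss, aux_loss) pair per segment, as in the README loop above
for loss, aux_loss in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # optional
    optimizer.step()
    optimizer.zero_grad()
```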


## Citations

```bibtex
7 changes: 5 additions & 2 deletions compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -184,6 +184,7 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
    mask_value = max_neg_value(dots)

    if pos_emb is not None:
        pos_emb = pos_emb[:, -kv_len:]
        pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb) * self.scale
        pos_dots = shift(pos_dots)
        dots = dots + pos_dots
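The added `pos_emb = pos_emb[:, -kv_len:]` line trims the shared relative positional embedding to the current key/value length, which now varies per layer once only some layers carry memory. A rough shape sketch of why the slice is needed; all tensor sizes below are illustrative assumptions, not the library's internals.

```python
import torch

heads, dim_head, seq_len = 8, 64, 1024
mem_len, cmem_len = 1024, 256

# positional embedding sized for the largest context: memory + compressed memory + segment
pos_emb = torch.randn(heads, mem_len + cmem_len + seq_len, dim_head)

q = torch.randn(1, heads, seq_len, dim_head)

# a layer that holds no memory only attends over the current segment
kv_len = seq_len
pos_emb_trimmed = pos_emb[:, -kv_len:]  # keep just the last kv_len positions

pos_dots = torch.einsum('bhid,hjd->bhij', q, pos_emb_trimmed)
print(pos_dots.shape)  # torch.Size([1, 8, 1024, 1024]) -- matches the content attention scores
# without the slice, the positional term would have 2304 columns and could not be
# added to the (seq_len, kv_len) attention matrix of a memory-less layer
```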
@@ -289,11 +290,13 @@ def forward(self, x, memories = None, mask = None):
    mem_iter = iterate_tensor(mem)
    cmem_iter = iterate_tensor(cmem)

    for ind, (attn, ff, m, c) in enumerate(zip(self.attn_layers, self.ff_layers, mem, cmem)):
    for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
        layer_num = ind + 1
        use_memory = layer_num in self.memory_layers

        memories = (next(mem_iter), next(cmem_iter)) if use_memory else None
        memories = None
        if use_memory:
            memories = (next(mem_iter), next(cmem_iter))

        x, (mem_out, cmem_out), layer_aux_loss = attn(x, memories = memories, calc_memory = use_memory, input_mask = mask, pos_emb = pos_emb)
        x, = ff(x)
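The reworked loop above is the fix the commit message refers to: when `memory_layers` covers only a subset of layers, `mem` and `cmem` hold fewer tensors than there are layers, so zipping them with the layer lists silently truncated the iteration. A standalone sketch of the corrected control flow, using dummy tensors and a stubbed layer body just to show the iteration:

```python
import torch

depth = 6
memory_layers = [5, 6]          # memories are kept for the last two layers only
dim = 512

# one (mem, cmem) pair per memory-carrying layer -- fewer entries than layers
mem  = [torch.zeros(1, 0, dim) for _ in memory_layers]
cmem = [torch.zeros(1, 0, dim) for _ in memory_layers]

mem_iter, cmem_iter = iter(mem), iter(cmem)

for ind in range(depth):
    layer_num = ind + 1
    use_memory = layer_num in memory_layers

    # the old loop zipped the layer lists with mem/cmem directly, stopping after
    # len(mem) layers; pulling from the iterators only when a layer actually uses
    # memory keeps all `depth` layers in play
    memories = None
    if use_memory:
        memories = (next(mem_iter), next(cmem_iter))

    print(layer_num, 'uses memory' if memories is not None else 'no memory')
```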
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'compressive_transformer_pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.1.0',
  version = '0.1.1',
  license='MIT',
  description = 'Implementation of Compressive Transformer in Pytorch',
  author = 'Phil Wang',
