Skip to content

Commit

Permalink
Fix and clean up generation code, and make it more efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 4, 2020
1 parent 4b3ce19 commit 86ce0a7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
39 changes: 20 additions & 19 deletions compressive_transformer_pytorch/autoregressive_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,49 +57,50 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., fi
self.net.eval()

out = start_tokens
inp = start_tokens

# take care of a primed sequence of any length

mem = None
*primes, inp = inp.split(self.seq_len, dim=1)

for segment in primes:
_, mem, _ = self.net(segment, memories = mem, **kwargs)

# take care of default masking

full_mask_like = lambda x: torch.full_like(x, True, dtype=torch.bool, device=x.device)

mask = kwargs.pop('mask', None)
if mask is None:
mask = full_mask_like(inp)
mask = full_mask_like(out)

# take care of a primed sequence of any length

mem = None
*primes, out = out.split(self.seq_len, dim=1)
*prime_masks, mask = mask.split(self.seq_len, dim=1)

for prime, prime_mask in zip(primes, prime_masks):
_, mem, _ = self.net(prime, memories = mem, mask = prime_mask, **kwargs)

# generate until hit sequence length

input_len = out.shape[1]

for _ in range(seq_len):
logits, mem, aux_loss = self.net(inp, memories = mem, **kwargs)
logits, mem, aux_loss = self.net(out[:, -input_len:], memories = mem, mask = mask[:, -input_len:], **kwargs)
logits = logits[:, -1, :]
filtered_logits = filter_logits_fn(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)

# unlike most models, inputs start from sequence length of 1 once full sequence length is filled

if self.seq_len == inp.shape[1]:
inp = sample
mask = full_mask_like(inp)
else:
inp = torch.cat((inp, sample), dim=-1)
mask = F.pad(mask, (0, 1), value=True)
out = torch.cat((out, sample), dim=-1)
mask = F.pad(mask, (0, 1), value=True)

# append sample to accumulated output
# append sample to accumulated output

out = torch.cat((out, sample), dim=-1)
input_len = input_len % self.seq_len
input_len += 1

if eos_token is not None and (sample == eos_token).all():
break

out = out[:, t:]

if num_dims == 1:
out = out.squeeze(0)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'compressive-transformer-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.3.11',
version = '0.3.12',
license='MIT',
description = 'Implementation of Compressive Transformer in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 86ce0a7

Please sign in to comment.