Commit
update examples to use gradient accumulation
lucidrains committed Jul 3, 2020
1 parent eb2df46 commit fceca51
Showing 2 changed files with 12 additions and 8 deletions.
18 changes: 11 additions & 7 deletions examples/enwik8_simple/train.py
@@ -13,7 +13,8 @@
 # constants
 
 NUM_BATCHES = int(1e5)
-BATCH_SIZE = 4
+BATCH_SIZE = 16
+MAX_BATCH_SIZE = 4
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100

@@ -87,20 +88,23 @@ def __len__(self):
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     model.train()
 
-    for mlm_loss, aux_loss in model(next(train_loader), return_loss = True):
+    grad_accum_every = BATCH_SIZE / MAX_BATCH_SIZE
+
+    for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
         loss = mlm_loss + aux_loss
-        loss.backward()
+        (loss / grad_accum_every).backward()
 
         print(f'training loss: {mlm_loss.item():.4f} | aux_loss: {aux_loss.item():.4f}')
 
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-    optim.step()
-    optim.zero_grad()
+        if is_last:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+            optim.step()
+            optim.zero_grad()
 
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            for loss, aux_loss in model(next(val_loader), return_loss = True):
+            for loss, aux_loss, _ in model(next(val_loader), return_loss = True):
                 print(f'validation loss: {loss.item():.4f}')
 
     if i % GENERATE_EVERY == 0:
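For readers unfamiliar with the pattern, here is a minimal standalone sketch of the gradient accumulation scheme this diff adopts: the logical batch (BATCH_SIZE) is split into micro-batches (MAX_BATCH_SIZE), each micro-batch's loss is scaled down before backward(), and the optimizer only steps after the last chunk. The toy nn.Linear model, random tensors, and manual chunking are illustrative assumptions, not the repository's API — in the actual train.py the model itself yields per-segment losses plus an is_last flag when given max_batch_size.

import torch
from torch import nn

BATCH_SIZE = 16      # logical batch size
MAX_BATCH_SIZE = 4   # largest micro-batch that fits in memory
grad_accum_every = BATCH_SIZE // MAX_BATCH_SIZE  # 4 accumulation steps

model = nn.Linear(32, 1)                                   # stand-in model (assumption)
optim = torch.optim.Adam(model.parameters(), lr = 1e-4)

batch = torch.randn(BATCH_SIZE, 32)                        # random stand-in data (assumption)
target = torch.randn(BATCH_SIZE, 1)

# split the logical batch into 16 / 4 = 4 micro-batches
chunks = list(zip(batch.split(MAX_BATCH_SIZE), target.split(MAX_BATCH_SIZE)))

for step, (x, y) in enumerate(chunks):
    loss = nn.functional.mse_loss(model(x), y)
    # scaling by the number of accumulation steps keeps the summed gradient
    # equal to the gradient of the mean loss over the full logical batch
    (loss / grad_accum_every).backward()

    is_last = step == len(chunks) - 1
    if is_last:
        # clip and step only once all micro-batch gradients have accumulated
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()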
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.4',
+  version = '0.3.5',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
