From fceca516d0e4c77faabb4c6b1370780c9db17324 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 3 Jul 2020 15:13:56 -0700
Subject: [PATCH] update examples to use gradient accumulation

---
 examples/enwik8_simple/train.py | 18 +++++++++++-------
 setup.py                        |  2 +-
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/enwik8_simple/train.py b/examples/enwik8_simple/train.py
index 65bdffd..ef95bdf 100644
--- a/examples/enwik8_simple/train.py
+++ b/examples/enwik8_simple/train.py
@@ -13,7 +13,8 @@
 # constants
 
 NUM_BATCHES = int(1e5)
-BATCH_SIZE = 4
+BATCH_SIZE = 16
+MAX_BATCH_SIZE = 4
 LEARNING_RATE = 1e-4
 VALIDATE_EVERY = 100
@@ -87,20 +88,23 @@ def __len__(self):
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
     model.train()
 
-    for mlm_loss, aux_loss in model(next(train_loader), return_loss = True):
+    grad_accum_every = BATCH_SIZE / MAX_BATCH_SIZE
+
+    for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
         loss = mlm_loss + aux_loss
-        loss.backward()
+        (loss / grad_accum_every).backward()
 
         print(f'training loss: {mlm_loss.item():.4f} | aux_loss: {aux_loss.item():.4f}')
 
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-    optim.step()
-    optim.zero_grad()
+        if is_last:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+            optim.step()
+            optim.zero_grad()
 
     if i % VALIDATE_EVERY == 0:
         model.eval()
         with torch.no_grad():
-            for loss, aux_loss in model(next(val_loader), return_loss = True):
+            for loss, aux_loss, _ in model(next(val_loader), return_loss = True):
                 print(f'validation loss: {loss.item():.4f}')
 
     if i % GENERATE_EVERY == 0:
diff --git a/setup.py b/setup.py
index ff53f53..5874530 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.4',
+  version = '0.3.5',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
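
Editor's note: below is a minimal, self-contained sketch of the gradient accumulation pattern this patch introduces: each sub-batch loss is scaled by the number of accumulation steps before backward(), and gradient clipping plus the optimizer step happen only on the sub-batch flagged as last. DummyModel, its yielded (mlm_loss, aux_loss, is_last) tuples, and the tensor shapes are hypothetical stand-ins for the compressive transformer used in train.py, not part of the patch.

# Hypothetical sketch -- DummyModel stands in for the compressive transformer;
# only the accumulate / step-on-is_last control flow mirrors the patched loop.
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, x, max_batch_size = 4):
        # yield one (mlm_loss, aux_loss, is_last) tuple per sub-batch,
        # mimicking the generator interface the training loop above iterates over
        chunks = x.split(max_batch_size)
        for idx, chunk in enumerate(chunks):
            mlm_loss = self.net(chunk).pow(2).mean()
            aux_loss = torch.zeros(())
            yield mlm_loss, aux_loss, idx == len(chunks) - 1

BATCH_SIZE = 16
MAX_BATCH_SIZE = 4
grad_accum_every = BATCH_SIZE / MAX_BATCH_SIZE

model = DummyModel()
optim = torch.optim.Adam(model.parameters(), lr = 1e-4)

batch = torch.randn(BATCH_SIZE, 16)
for mlm_loss, aux_loss, is_last in model(batch, max_batch_size = MAX_BATCH_SIZE):
    loss = mlm_loss + aux_loss
    # scale each sub-batch loss so the accumulated gradient averages over the chunks
    (loss / grad_accum_every).backward()
    if is_last:
        # clip and step once per full batch, then clear the accumulated gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()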