Fix save pretrained for TPUs #105

Merged (2 commits) on May 17, 2021
54 changes: 38 additions & 16 deletions aitextgen/train.py
@@ -1,14 +1,17 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import ProgressBarBase
from tqdm.auto import tqdm
import os
import shutil
import subprocess
import sys

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
import os
import shutil
import subprocess

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.utilities import _TPU_AVAILABLE


class ATGTransformer(pl.LightningModule):
@@ -18,12 +21,12 @@ class ATGTransformer(pl.LightningModule):

    def __init__(self, model, dataset, hparams, tokenizer):
        super(ATGTransformer, self).__init__()
        self.model, self.dataset, self.hparams, self.tokenizer = (
        self.model, self.dataset, self.tokenizer = (
            model,
            dataset,
            hparams,
            tokenizer,
        )
        self.save_hyperparameters(hparams)

    def forward(self, inputs):
        return self.model(**inputs, return_dict=False)
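
Side note on the hparams change above: recent PyTorch Lightning releases treat hparams as a managed property, so hyperparameters are registered through save_hyperparameters() rather than assigned directly. A minimal, standalone sketch of that pattern (illustrative only, not part of this diff; the module and hyperparameter names are made up):

```python
# Hypothetical minimal module showing the save_hyperparameters() pattern;
# names (ExampleModule, learning_rate) are illustrative, not from aitextgen.
import pytorch_lightning as pl
import torch


class ExampleModule(pl.LightningModule):
    def __init__(self, hparams: dict):
        super().__init__()
        # Registers the values on self.hparams (and in checkpoints),
        # replacing the older direct assignment `self.hparams = hparams`.
        self.save_hyperparameters(hparams)
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)


module = ExampleModule({"learning_rate": 1e-3, "batch_size": 8})
print(module.hparams.learning_rate)  # values are exposed as attributes of hparams
```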
@@ -112,6 +115,10 @@ def __init__(
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.train_transformers_only = train_transformers_only
        self.num_layers_freeze = num_layers_freeze

    @property
    def save_every_check(self):
        return self.save_every > 0 and self.steps % self.save_every == 0

    def enabled(self):
        self.enabled = True
@@ -172,10 +179,19 @@ def on_batch_end(self, trainer, pl_module):
desc += f" — GPU Mem: {gpu_memory} MB"
self.main_progress_bar.update(self.progress_bar_refresh_rate)
self.main_progress_bar.set_description(desc)


if _TPU_AVAILABLE and self.save_every_check:
[Review comment — @SeanNaren (Contributor), Mar 19, 2021]
Could we get some more information around this and why it's necessary? I think we can iterate on this to get something cleaner; the idea is that the boilerplate can live within lightning, outside of the user code.

            did_unfreeze = False
            if self.enabled:
                self.unfreeze_layers(pl_module)
                did_unfreeze = True
            self.save_pytorch_model(trainer, pl_module, tpu=True)
            if did_unfreeze:
                self.freeze_layers(pl_module)

        if self.enabled:
            did_unfreeze = False
            if self.save_every > 0 and self.steps % self.save_every == 0:
            if not _TPU_AVAILABLE and self.save_every_check:
                self.unfreeze_layers(pl_module)
                self.save_pytorch_model(trainer, pl_module)
                did_unfreeze = True
@@ -228,13 +244,19 @@ def generate_sample_text(self, trainer, pl_module):

        self.main_progress_bar.write("=" * 10)

    def save_pytorch_model(self, trainer, pl_module):
        self.main_progress_bar.write(
            f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
        )
        pl_module.model.save_pretrained(self.output_dir)
    def save_pytorch_model(self, trainer, pl_module, tpu=False):

        if self.enabled:
            self.main_progress_bar.write(
                f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
            )
            if tpu:
                import torch_xla.core.xla_model as xm
                pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
            else:
                pl_module.model.save_pretrained(self.output_dir)

        if self.save_gdrive:
        if self.enabled and self.save_gdrive:
            for pt_file in ["pytorch_model.bin", "config.json"]:
                shutil.copyfile(
                    os.path.join(self.output_dir, pt_file),
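
For context on the fix itself: on TPU, XLA tensors need to be moved to CPU and written from a single process, which is what torch_xla's xm.save does, so the diff passes xm.save as the save_function of Hugging Face's save_pretrained. It also calls the TPU save outside the self.enabled guard, presumably so every TPU process reaches xm.save (which synchronizes across cores) even when the progress bar is only enabled on the main process. A minimal standalone sketch of the pattern, assuming torch_xla is installed and using GPT2LMHeadModel as a stand-in for the trained model:

```python
# Standalone sketch of the TPU-safe save path this PR adopts.
# Assumptions: torch_xla is installed, an XLA device is available, and
# GPT2LMHeadModel stands in for whatever model aitextgen is training.
import torch_xla.core.xla_model as xm
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2").to(xm.xla_device())

# xm.save moves XLA tensors to CPU and writes only from the master ordinal,
# so passing it as save_function avoids serializing device tensors directly.
model.save_pretrained("trained_model", save_function=xm.save)

# On CPU/GPU the default path is unchanged: model.save_pretrained("trained_model")
```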