Commit
add trainer class
lucidrains committed May 25, 2022
1 parent fd6763e commit e2c1895
Showing 1 changed file with 399 additions and 0 deletions.
imagen_pytorch/trainer.py: 399 additions, 0 deletions
@@ -0,0 +1,399 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler

from imagen_pytorch.imagen_pytorch import Imagen

import numpy as np

# helper functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)

def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))

def group_dict_by_key(cond, d):
return_val = [dict(),dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)

def string_begins_with(prefix, string):
    return string.startswith(prefix)

def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)

def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
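
# for example, groupby_prefix_and_trim('ema_', {'ema_beta': 0.995, 'lr': 1e-4})
# returns ({'beta': 0.995}, {'lr': 1e-4}), which is how ImagenTrainer below routes
# 'ema_' prefixed keyword arguments to the EMA wrapper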

def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
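
# for example, num_to_groups(10, 4) -> [4, 4, 2]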

def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('imagen-pytorch').version

# decorators

def cast_torch_tensor(fn):
@wraps(fn)
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)

kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
split_kwargs_index = len(all_args) - len(kwargs_keys)
all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

out = fn(model, *args, **kwargs)
return out
return inner
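
# example: a method decorated with cast_torch_tensor accepts numpy arrays as well as tensors;
# arrays are converted with torch.from_numpy, and unless _cast_device is passed as False,
# all tensor arguments are moved to the wrapped model's device (overridable via _device)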

# gradient accumulation functions

def split_iterable(it, split_size):
accum = []
for ind in range(ceil(len(it) / split_size)):
start_index = ind * split_size
accum.append(it[start_index: (start_index + split_size)])
return accum

def split(t, split_size = None):
if not exists(split_size):
return t

if isinstance(t, torch.Tensor):
return t.split(split_size, dim = 0)

if isinstance(t, Iterable):
return split_iterable(t, split_size)

    raise TypeError(f'cannot split object of type {type(t)}')

def find_first(cond, arr):
for el in arr:
if cond(el):
return el
return None

def split_args_and_kwargs(*args, split_size = None, **kwargs):
all_args = (*args, *kwargs.values())
len_all_args = len(all_args)
first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
assert exists(first_tensor)

batch_size = len(first_tensor)
split_size = default(split_size, batch_size)
num_chunks = ceil(batch_size / split_size)

dict_len = len(kwargs)
dict_keys = kwargs.keys()
split_kwargs_index = len_all_args - dict_len

split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))

for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
chunk_size_frac = chunk_size / batch_size
yield chunk_size_frac, (chunked_args, chunked_kwargs)
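
# for example, with a batch of 10 images and split_size = 4, split_args_and_kwargs yields
# three chunks with fractions 0.4, 0.4 and 0.2, so the training loop below can weight
# each chunk's loss by the share of the batch it covers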

# exponential moving average wrapper

class EMA(nn.Module):
def __init__(
self,
model,
beta = 0.9999,
update_after_step = 1000,
update_every = 10,
):
super().__init__()
self.beta = beta
self.online_model = model
self.ema_model = copy.deepcopy(model)

self.update_every = update_every
        self.update_after_step = update_after_step // update_every # EMA updates only begin after this many calls to update()

self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0]))

def restore_ema_model_device(self):
device = self.initted.device
self.ema_model.to(device)

def copy_params_from_model_to_ema(self):
        self.ema_model.load_state_dict(self.online_model.state_dict())

def update(self):
self.step += 1

if (self.step % self.update_every) != 0:
return

if self.step <= self.update_after_step:
self.copy_params_from_model_to_ema()
return

if not self.initted:
self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True]))

self.update_moving_average(self.ema_model, self.online_model)

def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new):
if not exists(old):
return new
return old * beta + (1 - beta) * new

for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
ma_buffer.copy_(new_buffer_value)

def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
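
# illustrative usage sketch: wrap a unet, call update() once per optimizer step,
# and call the wrapper itself to run the averaged copy
#
#   ema_unet = EMA(unet, beta = 0.9999, update_every = 10)
#   ...
#   ema_unet.update()       # after each optimizer step
#   out = ema_unet(x)       # forward pass through the EMA weights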

# imagen trainer

def imagen_sample_in_chunks(fn):
@wraps(fn)
def inner(self, *args, max_batch_size = None, **kwargs):
if not exists(max_batch_size):
return fn(self, *args, **kwargs)

if self.imagen.unconditional:
batch_size = kwargs.get('batch_size')
batch_sizes = num_to_groups(batch_size, max_batch_size)
outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
else:
outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

return torch.cat(outputs, dim = 0)
return inner
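
# example: calling a decorated sample with max_batch_size = 4 and a requested batch of 10
# runs three sampling passes (of size 4, 4 and 2) and concatenates the images along dim 0;
# for conditional sampling, the conditioning arguments are chunked the same way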

class ImagenTrainer(nn.Module):
def __init__(
self,
imagen,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
eps = 1e-8,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
**kwargs
):
super().__init__()
assert isinstance(imagen, Imagen)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

self.imagen = imagen
self.num_unets = len(self.imagen.unets)

self.use_ema = use_ema
self.ema_unets = nn.ModuleList([])

self.amp = amp

# be able to finely customize learning rate, weight decay
# per unet

lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.imagen.unets, lr, wd, eps)):
optimizer = Adam(
unet.parameters(),
lr = unet_lr,
eps = unet_eps,
**kwargs
)

            setattr(self, f'optim{ind}', optimizer) # optimizers are not nn.Modules, so they cannot be stored in an nn.ModuleList

if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))

scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)

# gradient clipping if needed

self.max_grad_norm = max_grad_norm

self.register_buffer('step', torch.tensor([0.]))

def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
path.parent.mkdir(parents = True, exist_ok = True)

save_obj = dict(
model = self.imagen.state_dict(),
version = get_pkg_version(),
step = self.step.item(),
**kwargs
)

for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)
save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

torch.save(save_obj, str(path))

def load(self, path, only_model = False, strict = True):
path = Path(path)
assert path.exists()

loaded_obj = torch.load(str(path))

if get_pkg_version() != loaded_obj['version']:
print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')

self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

if only_model:
return loaded_obj

for ind in range(0, self.num_unets):
scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
scaler = getattr(self, scaler_key)
optimizer = getattr(self, optimizer_key)

scaler.load_state_dict(loaded_obj[scaler_key])
optimizer.load_state_dict(loaded_obj[optimizer_key])

if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

return loaded_obj

@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)

def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)

assert exists(unet_number) and 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.imagen.unets[index]

optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')

if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()

self.step += 1

@torch.no_grad()
@cast_torch_tensor
@imagen_sample_in_chunks
def sample(self, *args, **kwargs):
if kwargs.pop('use_non_ema', False) or not self.use_ema:
return self.imagen.sample(*args, **kwargs)

trainable_unets = self.imagen.unets
self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling

output = self.imagen.sample(*args, **kwargs)

self.imagen.unets = trainable_unets # restore original training unets

# cast the ema_model unets back to original device
for ema in self.ema_unets:
ema.restore_ema_model_device()

return output

@cast_torch_tensor
def forward(
self,
*args,
unet_number = None,
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)

total_loss = 0.

for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.imagen(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac

total_loss += loss.item()

if self.training:
self.scale(loss, unet_number = unet_number).backward()

return total_loss
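
# illustrative usage sketch, assuming `imagen` is an Imagen instance constructed elsewhere
# and that its forward pass takes a batch of images plus whatever conditioning kwargs it expects:
#
#   trainer = ImagenTrainer(imagen, lr = 1e-4, use_ema = True)
#
#   loss = trainer(images, unet_number = 1, max_batch_size = 4)   # forward + backward, chunked to fit memory
#   trainer.update(unet_number = 1)                               # optimizer step, grad clipping, EMA update
#
#   samples = trainer.sample(batch_size = 16, max_batch_size = 4) # samples with the EMA unets by default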
