diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py
new file mode 100644
index 0000000..9601e29
--- /dev/null
+++ b/imagen_pytorch/trainer.py
@@ -0,0 +1,399 @@
+import time
+import copy
+from pathlib import Path
+from math import ceil
+from functools import partial, wraps
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.cuda.amp import autocast, GradScaler
+
+from imagen_pytorch.imagen_pytorch import Imagen
+
+import numpy as np
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
+def pick_and_pop(keys, d):
+    values = list(map(lambda key: d.pop(key), keys))
+    return dict(zip(keys, values))
+
+def group_dict_by_key(cond, d):
+    return_val = [dict(),dict()]
+    for key in d.keys():
+        match = bool(cond(key))
+        ind = int(not match)
+        return_val[ind][key] = d[key]
+    return (*return_val,)
+
+def string_begins_with(prefix, str):
+    return str.startswith(prefix)
+
+def group_by_key_prefix(prefix, d):
+    return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+def groupby_prefix_and_trim(prefix, d):
+    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+    return kwargs_without_prefix, kwargs
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('imagen-pytorch').version
+
+# decorators
+
+def cast_torch_tensor(fn):
+    @wraps(fn)
+    def inner(model, *args, **kwargs):
+        device = kwargs.pop('_device', next(model.parameters()).device)
+        cast_device = kwargs.pop('_cast_device', True)
+
+        kwargs_keys = kwargs.keys()
+        all_args = (*args, *kwargs.values())
+        split_kwargs_index = len(all_args) - len(kwargs_keys)
+        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+
+        if cast_device:
+            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+
+        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+
+        out = fn(model, *args, **kwargs)
+        return out
+    return inner
+
+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+    accum = []
+    for ind in range(ceil(len(it) / split_size)):
+        start_index = ind * split_size
+        accum.append(it[start_index: (start_index + split_size)])
+    return accum
+
+def split(t, split_size = None):
+    if not exists(split_size):
+        return t
+
+    if isinstance(t, torch.Tensor):
+        return t.split(split_size, dim = 0)
+
+    if isinstance(t, Iterable):
+        return split_iterable(t, split_size)
+
+    raise TypeError(f'unrecognized type {type(t)} for split')
+
+def find_first(cond, arr):
+    for el in arr:
+        if cond(el):
+            return el
+    return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+    all_args = (*args, *kwargs.values())
+    len_all_args = len(all_args)
+    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+    assert exists(first_tensor)
+
+    batch_size = len(first_tensor)
+    split_size = default(split_size, batch_size)
+    num_chunks = ceil(batch_size / split_size)
+
+    dict_len = len(kwargs)
+    dict_keys = kwargs.keys()
+    split_kwargs_index = len_all_args - dict_len
+
+    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
+    chunk_sizes = tuple(map(len, split_all_args[0]))
+
+    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+        chunk_size_frac = chunk_size / batch_size
+        yield chunk_size_frac, (chunked_args, chunked_kwargs)
+
+# exponential moving average wrapper
+
+class EMA(nn.Module):
+    def __init__(
+        self,
+        model,
+        beta = 0.9999,
+        update_after_step = 1000,
+        update_every = 10,
+    ):
+        super().__init__()
+        self.beta = beta
+        self.online_model = model
+        self.ema_model = copy.deepcopy(model)
+
+        self.update_every = update_every
+        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
+
+        self.register_buffer('initted', torch.Tensor([False]))
+        self.register_buffer('step', torch.tensor([0]))
+
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
+    def copy_params_from_model_to_ema(self):
+        self.ema_model.load_state_dict(self.online_model.state_dict())
+
+    def update(self):
+        self.step += 1
+
+        if (self.step % self.update_every) != 0:
+            return
+
+        if self.step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
+            return
+
+        if not self.initted:
+            self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.Tensor([True]))
+
+        self.update_moving_average(self.ema_model, self.online_model)
+
+    def update_moving_average(self, ma_model, current_model):
+        def calculate_ema(beta, old, new):
+            if not exists(old):
+                return new
+            return old * beta + (1 - beta) * new
+
+        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+            old_weight, up_weight = ma_params.data, current_params.data
+            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
+
+        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
+            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
+            ma_buffer.copy_(new_buffer_value)
+
+    def __call__(self, *args, **kwargs):
+        return self.ema_model(*args, **kwargs)
+
+# imagen trainer
+
+def imagen_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.imagen.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
+class ImagenTrainer(nn.Module):
+    def __init__(
+        self,
+        imagen,
+        use_ema = True,
+        lr = 1e-4,
+        wd = 1e-2,
+        eps = 1e-8,
+        max_grad_norm = 0.5,
+        amp = False,
+        group_wd_params = True,
+        **kwargs
+    ):
+        super().__init__()
+        assert isinstance(imagen, Imagen)
+        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+
+        self.imagen = imagen
+        self.num_unets = len(self.imagen.unets)
+
+        self.use_ema = use_ema
+        self.ema_unets = nn.ModuleList([])
+
+        self.amp = amp
+
+        # be able to finely customize learning rate, weight decay
+        # per unet
+
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
+
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.imagen.unets, lr, wd, eps)):
+            optimizer = Adam(
+                unet.parameters(),
+                lr = unet_lr,
+                eps = unet_eps,
+                **kwargs
+            )
+
+            setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
+
+            if self.use_ema:
+                self.ema_unets.append(EMA(unet, **ema_kwargs))
+
+            scaler = GradScaler(enabled = amp)
+            setattr(self, f'scaler{ind}', scaler)
+
+        # gradient clipping if needed
+
+        self.max_grad_norm = max_grad_norm
+
+        self.register_buffer('step', torch.tensor([0.]))
+
+    def save(self, path, overwrite = True, **kwargs):
+        path = Path(path)
+        assert not (path.exists() and not overwrite)
+        path.parent.mkdir(parents = True, exist_ok = True)
+
+        save_obj = dict(
+            model = self.imagen.state_dict(),
+            version = get_pkg_version(),
+            step = self.step.item(),
+            **kwargs
+        )
+
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
+
+        if self.use_ema:
+            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+        torch.save(save_obj, str(path))
+
+    def load(self, path, only_model = False, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        loaded_obj = torch.load(str(path))
+
+        if get_pkg_version() != loaded_obj['version']:
+            print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
+
+        self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
+        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
+
+        if only_model:
+            return loaded_obj
+
+        for ind in range(0, self.num_unets):
+            scaler_key = f'scaler{ind}'
+            optimizer_key = f'optim{ind}'
+            scaler = getattr(self, scaler_key)
+            optimizer = getattr(self, optimizer_key)
+
+            scaler.load_state_dict(loaded_obj[scaler_key])
+            optimizer.load_state_dict(loaded_obj[optimizer_key])
+
+        if self.use_ema:
+            assert 'ema' in loaded_obj
+            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
+        return loaded_obj
+
+    @property
+    def unets(self):
+        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
+
+    def scale(self, loss, *, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        scaler = getattr(self, f'scaler{index}')
+        return scaler.scale(loss)
+
+    def update(self, unet_number = None):
+        if self.num_unets == 1:
+            unet_number = default(unet_number, 1)
+
+        assert exists(unet_number) and 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        unet = self.imagen.unets[index]
+
+        optimizer = getattr(self, f'optim{index}')
+        scaler = getattr(self, f'scaler{index}')
+
+        if exists(self.max_grad_norm):
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
+
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
+
+        if self.use_ema:
+            ema_unet = self.ema_unets[index]
+            ema_unet.update()
+
+        self.step += 1
+
+    @torch.no_grad()
+    @cast_torch_tensor
+    @imagen_sample_in_chunks
+    def sample(self, *args, **kwargs):
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
+            return self.imagen.sample(*args, **kwargs)
+
+        trainable_unets = self.imagen.unets
+        self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling
+
+        output = self.imagen.sample(*args, **kwargs)
+
+        self.imagen.unets = trainable_unets # restore original training unets
+
+        # cast the ema_model unets back to original device
+        for ema in self.ema_unets:
+            ema.restore_ema_model_device()
+
+        return output
+
+    @cast_torch_tensor
+    def forward(
+        self,
+        *args,
+        unet_number = None,
+        max_batch_size = None,
+        **kwargs
+    ):
+        if self.num_unets == 1:
+            unet_number = default(unet_number, 1)
+
+        total_loss = 0.
+
+        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
+            with autocast(enabled = self.amp):
+                loss = self.imagen(*chunked_args, unet_number = unet_number, **chunked_kwargs)
+                loss = loss * chunk_size_frac
+
+            total_loss += loss.item()
+
+            if self.training:
+                self.scale(loss, unet_number = unet_number).backward()
+
+        return total_loss
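
A minimal usage sketch of the trainer added above, assuming an already-constructed Imagen instance and a dataloader yielding (images, text_embeds) batches; the text_embeds conditioning keyword and the batch format are illustrative assumptions, not something this patch defines.

import torch
from imagen_pytorch.imagen_pytorch import Imagen
from imagen_pytorch.trainer import ImagenTrainer

def train_unet(imagen, dataloader, unet_number = 1, epochs = 1):
    # hypothetical training loop; imagen and the (images, text_embeds)
    # batch format are assumed to be set up elsewhere
    assert isinstance(imagen, Imagen)
    trainer = ImagenTrainer(imagen, lr = 1e-4, use_ema = True, amp = False)

    for _ in range(epochs):
        for images, text_embeds in dataloader:
            # forward splits the batch into chunks of at most max_batch_size,
            # backpropagates each (loss-scaled) chunk, and returns the total loss as a float
            loss = trainer(
                images,
                text_embeds = text_embeds,   # assumed conditioning kwarg for Imagen.forward
                unet_number = unet_number,
                max_batch_size = 4
            )

            # unscale, clip, and step the optimizer / grad scaler, then update the EMA copy of this unet
            trainer.update(unet_number = unet_number)

    trainer.save('./imagen-checkpoint.pt')  # hypothetical checkpoint path
    return trainer

# sampling afterwards goes through the EMA unets by default, e.g.
# (the conditioning kwarg for Imagen.sample is likewise an assumption):
# sampled_images = trainer.sample(text_embeds = some_text_embeds, max_batch_size = 2)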