Commit e2c1895 (1 parent: fd6763e)
Showing 1 changed file with 399 additions and 0 deletions.
@@ -0,0 +1,399 @@
import time
import copy
from pathlib import Path
from math import ceil
from functools import partial, wraps
from collections.abc import Iterable

import torch
from torch import nn
from torch.optim import Adam
from torch.cuda.amp import autocast, GradScaler

from imagen_pytorch.imagen_pytorch import Imagen

import numpy as np

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))

def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

def string_begins_with(prefix, string):
    return string.startswith(prefix)

def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)

def groupby_prefix_and_trim(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
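
# e.g. groupby_prefix_and_trim('ema_', {'ema_beta': 0.999, 'lr': 1e-4}) returns ({'beta': 0.999}, {'lr': 1e-4});
# the trainer below uses this to route 'ema_'-prefixed kwargs to the EMA wrapper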

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
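# e.g. num_to_groups(25, 10) -> [10, 10, 5]; used to break a sample batch into sub-batches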

def get_pkg_version():
    # note: assumes this package is distributed as 'imagen-pytorch'
    from pkg_resources import get_distribution
    return get_distribution('imagen-pytorch').version

# decorators

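# cast_torch_tensor wraps a method so that numpy array arguments are converted to torch tensors
# and, unless _cast_device = False is passed, all tensor arguments are moved to the model's device
# (or to an explicit _device = ... override) before the wrapped method runs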
def cast_torch_tensor(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        device = kwargs.pop('_device', next(model.parameters()).device)
        cast_device = kwargs.pop('_cast_device', True)

        kwargs_keys = kwargs.keys()
        all_args = (*args, *kwargs.values())
        split_kwargs_index = len(all_args) - len(kwargs_keys)
        all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))

        if cast_device:
            all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))

        args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
        kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))

        out = fn(model, *args, **kwargs)
        return out
    return inner

# gradient accumulation functions

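# these helpers split a batch of positional and keyword arguments into chunks of at most
# `split_size` along the batch dimension, so a large batch can be processed as several
# sub-batches; split_args_and_kwargs also yields each chunk's fraction of the full batch,
# which is used below to weight the per-chunk loss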
def split_iterable(it, split_size):
    accum = []
    for ind in range(ceil(len(it) / split_size)):
        start_index = ind * split_size
        accum.append(it[start_index: (start_index + split_size)])
    return accum

def split(t, split_size = None):
    if not exists(split_size):
        return t

    if isinstance(t, torch.Tensor):
        return t.split(split_size, dim = 0)

    if isinstance(t, Iterable):
        return split_iterable(t, split_size)

    raise TypeError(f'unsplittable type {type(t)}')

def find_first(cond, arr):
    for el in arr:
        if cond(el):
            return el
    return None

def split_args_and_kwargs(*args, split_size = None, **kwargs):
    all_args = (*args, *kwargs.values())
    len_all_args = len(all_args)
    first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
    assert exists(first_tensor)

    batch_size = len(first_tensor)
    split_size = default(split_size, batch_size)
    num_chunks = ceil(batch_size / split_size)

    dict_len = len(kwargs)
    dict_keys = kwargs.keys()
    split_kwargs_index = len_all_args - dict_len

    split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
    chunk_sizes = tuple(map(len, split_all_args[0]))

    for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
        chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
        chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
        chunk_size_frac = chunk_size / batch_size
        yield chunk_size_frac, (chunked_args, chunked_kwargs)

# exponential moving average wrapper

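# EMA keeps a deep copy of the wrapped model whose weights track the online weights as
# ema = beta * ema + (1 - beta) * online (applied to parameters and buffers), updated once
# every `update_every` calls after an initial warmup period; calling the wrapper forwards
# through the averaged copy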
class EMA(nn.Module):
    def __init__(
        self,
        model,
        beta = 0.9999,
        update_after_step = 1000,
        update_every = 10,
    ):
        super().__init__()
        self.beta = beta
        self.online_model = model
        self.ema_model = copy.deepcopy(model)

        self.update_every = update_every
        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0

        self.register_buffer('initted', torch.Tensor([False]))
        self.register_buffer('step', torch.tensor([0]))

    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)

    def copy_params_from_model_to_ema(self):
        self.ema_model.load_state_dict(self.online_model.state_dict())

    def update(self):
        self.step += 1

        if (self.step % self.update_every) != 0:
            return

        if self.step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        if not self.initted:
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.Tensor([True]))

        self.update_moving_average(self.ema_model, self.online_model)

    def update_moving_average(self, ma_model, current_model):
        def calculate_ema(beta, old, new):
            if not exists(old):
                return new
            return old * beta + (1 - beta) * new

        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = calculate_ema(self.beta, old_weight, up_weight)

        for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
            new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
            ma_buffer.copy_(new_buffer_value)

    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
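
# an illustrative sketch of how the wrapper is used (`model`, `num_steps` and `train_step`
# are placeholders, not part of this file):
#
#   ema = EMA(model, beta = 0.9999)
#   for _ in range(num_steps):
#       train_step(model)
#       ema.update()
#   out = ema(x)  # evaluate with the averaged weights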

# imagen trainer

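# when a `max_batch_size` is given, sampling is performed in sub-batches (by batch_size for
# unconditional models, otherwise by chunking the conditioning arguments) and the results
# are concatenated along the batch dimension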
def imagen_sample_in_chunks(fn):
    @wraps(fn)
    def inner(self, *args, max_batch_size = None, **kwargs):
        if not exists(max_batch_size):
            return fn(self, *args, **kwargs)

        if self.imagen.unconditional:
            batch_size = kwargs.get('batch_size')
            batch_sizes = num_to_groups(batch_size, max_batch_size)
            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
        else:
            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]

        return torch.cat(outputs, dim = 0)
    return inner

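# ImagenTrainer wraps an Imagen instance and builds, per unet: an Adam optimizer, a grad
# scaler for mixed precision, and (optionally) an EMA copy. forward() accumulates the
# (optionally chunked) loss and backpropagates; update() steps the optimizer for the chosen
# unet with gradient clipping and updates its EMA copy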
class ImagenTrainer(nn.Module):
    def __init__(
        self,
        imagen,
        use_ema = True,
        lr = 1e-4,
        wd = 1e-2,
        eps = 1e-8,
        max_grad_norm = 0.5,
        amp = False,
        group_wd_params = True,
        **kwargs
    ):
        super().__init__()
        assert isinstance(imagen, Imagen)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.imagen = imagen
        self.num_unets = len(self.imagen.unets)

        self.use_ema = use_ema
        self.ema_unets = nn.ModuleList([])

        self.amp = amp

        # be able to finely customize learning rate, weight decay per unet

        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))

        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.imagen.unets, lr, wd, eps)):
            optimizer = Adam(
                unet.parameters(),
                lr = unet_lr,
                eps = unet_eps,
                weight_decay = unet_wd, # assumption: the per-unet wd setting is intended as Adam's weight_decay
                **kwargs
            )

            setattr(self, f'optim{ind}', optimizer) # optimizers are not nn.Modules, so they are stored as plain attributes rather than in an nn.ModuleList

            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = amp)
            setattr(self, f'scaler{ind}', scaler)

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

        self.register_buffer('step', torch.tensor([0.]))

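    # save() / load() checkpoint the imagen weights, the per-unet optimizer and grad scaler
    # states, the EMA unets (when enabled), the package version and the step counter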
    def save(self, path, overwrite = True, **kwargs):
        path = Path(path)
        assert not (path.exists() and not overwrite)
        path.parent.mkdir(parents = True, exist_ok = True)

        save_obj = dict(
            model = self.imagen.state_dict(),
            version = get_pkg_version(),
            step = self.step.item(),
            **kwargs
        )

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)
            save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}

        if self.use_ema:
            save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}

        torch.save(save_obj, str(path))

    def load(self, path, only_model = False, strict = True):
        path = Path(path)
        assert path.exists()

        loaded_obj = torch.load(str(path))

        if get_pkg_version() != loaded_obj['version']:
            print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')

        self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
        self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])

        if only_model:
            return loaded_obj

        for ind in range(0, self.num_unets):
            scaler_key = f'scaler{ind}'
            optimizer_key = f'optim{ind}'
            scaler = getattr(self, scaler_key)
            optimizer = getattr(self, optimizer_key)

            scaler.load_state_dict(loaded_obj[scaler_key])
            optimizer.load_state_dict(loaded_obj[optimizer_key])

        if self.use_ema:
            assert 'ema' in loaded_obj
            self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)

        return loaded_obj

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def scale(self, loss, *, unet_number):
        assert 1 <= unet_number <= self.num_unets
        index = unet_number - 1
        scaler = getattr(self, f'scaler{index}')
        return scaler.scale(loss)

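    # update() steps the optimizer for one unet: unscale and clip gradients (when max_grad_norm
    # is set), step the grad scaler and optimizer, zero the gradients, then update the EMA copy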
    def update(self, unet_number = None):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        assert exists(unet_number) and 1 <= unet_number <= self.num_unets
        index = unet_number - 1
        unet = self.imagen.unets[index]

        optimizer = getattr(self, f'optim{index}')
        scaler = getattr(self, f'scaler{index}')

        if exists(self.max_grad_norm):
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.ema_unets[index]
            ema_unet.update()

        self.step += 1

    @torch.no_grad()
    @cast_torch_tensor
    @imagen_sample_in_chunks
    def sample(self, *args, **kwargs):
        if kwargs.pop('use_non_ema', False) or not self.use_ema:
            return self.imagen.sample(*args, **kwargs)

        trainable_unets = self.imagen.unets
        self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling

        output = self.imagen.sample(*args, **kwargs)

        self.imagen.unets = trainable_unets # restore original training unets

        # cast the ema_model unets back to original device
        for ema in self.ema_unets:
            ema.restore_ema_model_device()

        return output

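    # forward() splits the batch into chunks of at most `max_batch_size`, computes each chunk's
    # loss under autocast, weights it by the chunk's share of the batch so gradient accumulation
    # matches a full-batch update, backpropagates while training, and returns the summed loss value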
    @cast_torch_tensor
    def forward(
        self,
        *args,
        unet_number = None,
        max_batch_size = None,
        **kwargs
    ):
        if self.num_unets == 1:
            unet_number = default(unet_number, 1)

        total_loss = 0.

        for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
            with autocast(enabled = self.amp):
                loss = self.imagen(*chunked_args, unet_number = unet_number, **chunked_kwargs)
                loss = loss * chunk_size_frac

            total_loss += loss.item()

            if self.training:
                self.scale(loss, unet_number = unet_number).backward()

        return total_loss
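
# an illustrative training loop (a sketch only; `imagen`, `dataloader` and the keyword
# arguments forwarded to Imagen.forward / Imagen.sample are assumptions, not part of this file):
#
#   trainer = ImagenTrainer(imagen, lr = 1e-4, use_ema = True)
#   for images, text_embeds in dataloader:
#       loss = trainer(images, text_embeds = text_embeds, unet_number = 1, max_batch_size = 4)
#       trainer.update(unet_number = 1)
#   samples = trainer.sample(text_embeds = text_embeds, batch_size = 4, max_batch_size = 4)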