Skip to content

Commit

Permalink
try to restore as much as possible if state_dict fails (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
deepglugs authored Aug 18, 2022
1 parent 423ffe8 commit b9b45ad
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions imagen_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,21 @@ def inner(self, *args, max_batch_size = None, **kwargs):

return inner


def restore_parts(state_dict_target, state_dict_from):
for name, param in state_dict_from.items():

if name not in state_dict_target:
continue

if param.size() == state_dict_target[name].size():
state_dict_target[name].copy_(param)
else:
print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")

return state_dict_target


class ImagenTrainer(nn.Module):
locked = False

Expand Down Expand Up @@ -716,7 +731,12 @@ def load(self, path, only_model = False, strict = True, noop_if_not_exist = Fals
if version.parse(__version__) != version.parse(loaded_obj['version']):
self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')

self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
try:
self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
except RuntimeError:
print("Failed loading state dict. Trying partial load")
self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
loaded_obj['model']))
self.steps.copy_(loaded_obj['steps'])

if only_model:
Expand Down Expand Up @@ -748,7 +768,12 @@ def load(self, path, only_model = False, strict = True, noop_if_not_exist = Fals

if self.use_ema:
assert 'ema' in loaded_obj
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
try:
self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
except RuntimeError:
print("Failed loading state dict. Trying partial load")
self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
loaded_obj['ema']))

self.print(f'checkpoint loaded from {path}')
return loaded_obj
Expand Down

0 comments on commit b9b45ad

Please sign in to comment.