-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathload_model.py
40 lines (31 loc) · 1.22 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import torch
from model import SEDD
import utils
from model.ema import ExponentialMovingAverage
import graph_lib
import noise_lib
from omegaconf import OmegaConf
def load_model_hf(dir, device):
    """Load a pretrained SEDD model together with its graph and noise objects.

    The model is obtained through ``SEDD.from_pretrained(dir)`` and moved to
    ``device``; the graph and noise objects are then built from the config
    attached to that model.

    Args:
        dir: Identifier handed unchanged to ``SEDD.from_pretrained``
            (presumably a local directory or hub model name -- confirm
            against the ``SEDD`` implementation).
        device: Target device for the model, the graph and the noise object.

    Returns:
        A ``(score_model, graph, noise)`` tuple.
    """
    model = SEDD.from_pretrained(dir).to(device)
    # Graph and noise are both derived from the config shipped with the model.
    model_cfg = model.config
    graph = graph_lib.get_graph(model_cfg, device)
    noise = noise_lib.get_noise(model_cfg).to(device)
    return model, graph, noise
def load_model_local(root_dir, device):
    """Rebuild a SEDD model from a local run directory.

    The run's hydra config is read with ``utils.load_hydra_config_from_run``
    and used to construct the graph, the noise object and a fresh ``SEDD``
    model. Model and EMA state are then restored from
    ``<root_dir>/checkpoints-meta/checkpoint.pth`` and the EMA parameters are
    copied into the model before it is returned.

    Args:
        root_dir: Run directory holding the hydra config and the
            ``checkpoints-meta/checkpoint.pth`` file.
        device: Device on which the model, graph and noise are placed; also
            used as ``map_location`` when reading the checkpoint.

    Returns:
        A ``(score_model, graph, noise)`` tuple.
    """
    run_cfg = utils.load_hydra_config_from_run(root_dir)

    graph = graph_lib.get_graph(run_cfg, device)
    noise = noise_lib.get_noise(run_cfg).to(device)
    model = SEDD(run_cfg).to(device)
    ema_tracker = ExponentialMovingAverage(
        model.parameters(),
        decay=run_cfg.training.ema,
    )

    checkpoint_file = os.path.join(root_dir, "checkpoints-meta", "checkpoint.pth")
    state = torch.load(checkpoint_file, map_location=device)

    # Restore the model weights first, then the EMA state, from the same file.
    for target, key in ((model, 'model'), (ema_tracker, 'ema')):
        target.load_state_dict(state[key])

    # NOTE(review): presumably ``store`` stashes the raw weights inside the
    # EMA object and ``copy_to`` overwrites the model with the averaged
    # weights -- confirm against model.ema. The EMA object is not returned.
    ema_tracker.store(model.parameters())
    ema_tracker.copy_to(model.parameters())

    return model, graph, noise
def load_model(root_dir, device):
    """Load a SEDD model, trying the pretrained loader before the local one.

    ``load_model_hf`` is attempted first; if it raises an ordinary exception,
    ``load_model_local`` is used with the same arguments.

    Args:
        root_dir: Passed to ``load_model_hf`` as its ``dir`` argument and, on
            fallback, to ``load_model_local`` as the run directory.
        device: Device on which the returned objects are placed.

    Returns:
        A ``(score_model, graph, noise)`` tuple from whichever loader
        succeeded.

    Raises:
        Exception: Whatever ``load_model_local`` raises when the fallback
            fails as well; the original pretrained-loader error stays
            attached as the exception context.
    """
    try:
        return load_model_hf(root_dir, device)
    except Exception:
        # Only ordinary errors trigger the fallback. A bare ``except`` would
        # also swallow KeyboardInterrupt/SystemExit and start a second,
        # unwanted load after the user asked to stop.
        return load_model_local(root_dir, device)