-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathinferencer.py
45 lines (35 loc) · 1.53 KB
/
inferencer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from networks import Generator, Discriminator
import torch
import os.path
import torchvision.utils as vutils
import torch.nn.functional as F
class Inferencer:
    """Render sample grids from every saved progressive-growing checkpoint.

    Owns a ``Generator`` (from ``networks``) that lives on the GPU and is grown
    stage by stage so its architecture matches each checkpoint found on disk.
    """

    def __init__(self, generator_channels, nz, style_depth):
        """Build the generator and move it to the GPU.

        Args:
            generator_channels: forwarded to ``Generator``.
            nz: length of each latent vector sampled in ``inference``.
            style_depth: forwarded to ``Generator``.
        """
        self.nz = nz
        self.generator = Generator(generator_channels, nz, style_depth).cuda()

    def inference(self, n, checkpoint_dir='checkpoints', image_dir='images',
                  output_size=256):
        """Save one image grid per available checkpoint resolution.

        Starting at 8x8, loads ``<checkpoint_dir>/{s}x{s}_last.pth``, generates
        ``n`` samples from a fixed set of latents (the same latents are reused
        at every resolution), upsamples them to ``output_size`` and writes
        ``<image_dir>/{s}x{s}.png``. Doubles ``s`` and repeats until no
        checkpoint file exists for the next size.

        NOTE(review): the generator is grown once before the first (8x8) load,
        which appears to assume a freshly constructed generator; calling this
        twice on the same instance would over-grow it. Confirm before reuse.

        Args:
            n: number of latent vectors / images per saved grid.
            checkpoint_dir: directory holding the ``*_last.pth`` checkpoints.
            image_dir: directory for output PNGs (created if missing).
            output_size: side length every sample is resized to before saving.
        """
        test_z = torch.randn(n, self.nz).cuda()
        with torch.no_grad():
            # The first checkpoint on disk is 8x8, so grow one stage beforehand.
            self.grow()
            img_size = 8
            filename = self._checkpoint_path(checkpoint_dir, img_size)
            while os.path.isfile(filename):
                self.load_checkpoint(img_size, filename)
                # grow()/load may add fresh modules, so re-enter eval mode each stage.
                self.generator.eval()
                # alpha=1 presumably means "no fade-in blending" — confirm in networks.
                fake = self.generator(test_z, alpha=1)
                # Assumes generator output is roughly in [-1, 1]; map to [0, 1]
                # and clamp any overshoot for image saving.
                fake = (fake + 1) * 0.5
                fake = torch.clamp(fake, min=0.0, max=1.0)
                fake = F.interpolate(fake, size=(output_size, output_size))
                os.makedirs(image_dir, exist_ok=True)
                vutils.save_image(
                    fake,
                    os.path.join(image_dir, '{}x{}.png'.format(img_size, img_size)),
                )
                self.grow()
                img_size *= 2
                filename = self._checkpoint_path(checkpoint_dir, img_size)

    def grow(self):
        """Add one resolution stage to the generator and keep it on the GPU."""
        self.generator.grow()
        # Re-apply .cuda(): layers added by grow() may be created on the CPU.
        self.generator.cuda()

    def load_checkpoint(self, img_size, filename):
        """Load generator weights from ``filename``, growing first if needed.

        Args:
            img_size: resolution the generator currently produces.
            filename: path to a ``torch.save``-d dict with ``'img_size'`` and
                ``'generator'`` (state dict) entries.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename)
        print('load {}x{} checkpoint'.format(checkpoint['img_size'], checkpoint['img_size']))
        # Each grow() doubles the resolution; count the doublings needed rather
        # than looping on an img_size that never changes.
        for _ in range(self._grow_steps(img_size, checkpoint['img_size'])):
            self.grow()
        self.generator.load_state_dict(checkpoint['generator'])

    @staticmethod
    def _grow_steps(current_size, target_size):
        """Return how many doublings take ``current_size`` to at least ``target_size``.

        Raises:
            ValueError: if ``current_size`` is not positive (doubling would never
                reach the target).
        """
        if current_size <= 0:
            raise ValueError('current_size must be positive, got {}'.format(current_size))
        steps = 0
        while current_size < target_size:
            current_size *= 2
            steps += 1
        return steps

    @staticmethod
    def _checkpoint_path(checkpoint_dir, img_size):
        """Return the ``{s}x{s}_last.pth`` path for ``img_size`` in ``checkpoint_dir``."""
        return os.path.join(checkpoint_dir, '{}x{}_last.pth'.format(img_size, img_size))