import torch
import numpy as np
import os
from tqdm import tqdm

from submodules.GAN_stability.gan_training.train import toggle_grad, Trainer as TrainerBase
from submodules.GAN_stability.gan_training.eval import Evaluator as EvaluatorBase
from submodules.GAN_stability.gan_training.metrics import FIDEvaluator, KIDEvaluator

from .utils import save_video, color_depth_map


class Trainer(TrainerBase):
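    """Thin wrapper around the GAN_stability Trainer that optionally runs the
    generator step with automatic mixed precision (AMP).
    """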
    def __init__(self, *args, use_amp=False, **kwargs):
        super(Trainer, self).__init__(*args, **kwargs)
        self.use_amp = use_amp
        if self.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def generator_trainstep(self, y, z):
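        """Perform a single generator update on labels y and latents z.

        Uses the autocast/GradScaler mixed-precision path when use_amp is set,
        otherwise falls back to the base class implementation.
        """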
        if not self.use_amp:
            return super(Trainer, self).generator_trainstep(y, z)
        assert (y.size(0) == z.size(0))
        toggle_grad(self.generator, True)
        toggle_grad(self.discriminator, False)
        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            x_fake = self.generator(z, y)
            d_fake = self.discriminator(x_fake, y)
            gloss = self.compute_loss(d_fake, 1)
        self.scaler.scale(gloss).backward()

        self.scaler.step(self.g_optimizer)
        self.scaler.update()

        return gloss.item()

    def discriminator_trainstep(self, x_real, y, z, data_aug):
        return super(Trainer, self).discriminator_trainstep(x_real, y, z, data_aug)  # spectral norm raises an error when using amp


class Evaluator(EvaluatorBase):
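    """Evaluation helper built on the GAN_stability Evaluator: renders RGB/depth samples
    from the generator and optionally computes FID and KID.
    """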
    def __init__(self, eval_fid_kid, *args, **kwargs):
        super(Evaluator, self).__init__(*args, **kwargs)
        if eval_fid_kid:
            self.inception_eval = FIDEvaluator(
                device=self.device,
                batch_size=self.batch_size,
                resize=True,
                n_samples=20000,
                n_samples_fake=1000,
            )

    def get_rays(self, pose):
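        """Return the full set of rays for a single camera pose, using the generator's
        validation ray sampler at the generator's resolution and focal length.
        """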
        return self.generator.val_ray_sampler(self.generator.H, self.generator.W,
                                              self.generator.focal, pose)[0]

    def create_samples(self, z, poses=None):
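        """Render the given latent codes into RGB images, color-coded depth maps and
        accumulation maps. If poses is None, ray sampling is left to the generator;
        otherwise one camera pose per latent code is expected.
        """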
        self.generator.eval()
        N_samples = len(z)
        device = self.generator.device
        if self.batch_size > 1:
            z = z.to(device).split(self.batch_size)
        if poses is None:
            rays = [None] * len(z)
        else:
            rays = torch.stack([self.get_rays(poses[i].to(device)) for i in range(N_samples)])
            rays = rays.split(self.batch_size)

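        # render the samples without tracking gradients, chunking by batch_size when possible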
        rgb, disp, acc = [], [], []
        with torch.no_grad():
            if self.batch_size > 1:
                for z_i, rays_i in tqdm(zip(z, rays), total=len(z), desc='Create samples...'):
                    bs = len(z_i)
                    if rays_i is not None:
                        rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2)  # (B, 2, H*W, 3) -> (2, B*H*W, 3)
                    rgb_i, disp_i, acc_i, _ = self.generator(z_i, rays=rays_i)

                    reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2)  # (N*H*W, C) -> (N, C, H, W)
                    rgb.append(reshape(rgb_i).cpu())
                    disp.append(reshape(disp_i).cpu())
                    acc.append(reshape(acc_i).cpu())
            else:
                for rays_i in rays:
                    bs = len(z)
                    if rays_i is not None:
                        rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2)  # (B, 2, H*W, 3) -> (2, B*H*W, 3)
                    rgb_i, disp_i, acc_i, _ = self.generator(z, rays=rays_i)

                    reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2)  # (N*H*W, C) -> (N, C, H, W)
                    rgb.append(reshape(rgb_i).cpu())
                    disp.append(reshape(disp_i).cpu())
                    acc.append(reshape(acc_i).cpu())

        rgb = torch.cat(rgb)
        disp = torch.cat(disp)
        acc = torch.cat(acc)

        depth = self.disp_to_cdepth(disp)

        return rgb, depth, acc

    def make_video(self, basename, z, poses, as_gif=True):
        """Generate images for each latent code along the given camera path and save them as videos.

        z (N_samples, zdim): latent codes
        poses (N_frames, 3, 4): camera poses for all frames of the video
        """
        N_samples, N_frames = len(z), len(poses)

        # reshape inputs
        z = z.unsqueeze(1).expand(-1, N_frames, -1).flatten(0, 1)  # (N_samples*N_frames, z_dim)
        poses = poses.unsqueeze(0) \
            .expand(N_samples, -1, -1, -1).flatten(0, 1)  # (N_samples*N_frames, 3, 4)

        rgbs, depths, accs = self.create_samples(z, poses=poses)

        reshape = lambda x: x.view(N_samples, N_frames, *x.shape[1:])
        rgbs = reshape(rgbs)
        depths = reshape(depths)
        print('Done, saving', rgbs.shape)

        fps = min(int(N_frames / 2.), 25)  # aim for a video of at least 2 seconds
        for i in range(N_samples):
            save_video(rgbs[i], basename + '{:04d}_rgb.mp4'.format(i), as_gif=as_gif, fps=fps)
            save_video(depths[i], basename + '{:04d}_depth.mp4'.format(i), as_gif=as_gif, fps=fps)

    def disp_to_cdepth(self, disps):
        """Convert disparity maps to color-coded depth maps for visualization."""
        if (disps == 2e10).all():  # no values predicted
            return torch.ones_like(disps)

        near, far = self.generator.render_kwargs_test['near'], self.generator.render_kwargs_test['far']

        disps = disps / 2 + 0.5  # [-1, 1] -> [0, 1]

        depth = 1. / torch.max(1e-10 * torch.ones_like(disps), disps)  # disparity -> depth
        depth[disps == 1e10] = far  # set undefined values to far plane

        # scale between near and far plane for better visualization
        depth = (depth - near) / (far - near)

        depth = np.stack([color_depth_map(d) for d in depth[:, 0].detach().cpu().numpy()])  # convert to color
        depth = (torch.from_numpy(depth).permute(0, 3, 1, 2) / 255.) * 2 - 1  # [0, 255] -> [-1, 1]

        return depth

    def compute_fid_kid(self, sample_generator=None):
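        """Compute FID and KID with the Inception evaluator. If no sample generator is
        given, latent codes are drawn from self.zdist and rendered via create_samples.
        """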
        if sample_generator is None:
            def sample():
                while True:
                    z = self.zdist.sample((self.batch_size,))
                    rgb, _, _ = self.create_samples(z)
                    # convert to uint8 and back to get correct binning
                    rgb = (rgb / 2 + 0.5).mul_(255).clamp_(0, 255).to(torch.uint8).to(torch.float) / 255. * 2 - 1
                    yield rgb.cpu()

            sample_generator = sample()

        fid, (kids, vars) = self.inception_eval.get_fid_kid(sample_generator)
        kid = np.mean(kids)
        return fid, kid
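

# Illustrative usage sketch (not part of the original module). Constructor arguments other
# than `use_amp` / `eval_fid_kid` are forwarded to the GAN_stability base classes, so the
# exact keywords below are assumptions that depend on that submodule's API:
#
#   evaluator = Evaluator(eval_fid_kid=True, generator=generator_test, zdist=zdist,
#                         ydist=ydist, batch_size=8, device=device)
#   z = zdist.sample((4,))                            # latent codes for 4 objects
#   rgb, depth, acc = evaluator.create_samples(z)     # rendered images + color-coded depth
#   evaluator.make_video('eval/video_', z, poses)     # poses: (N_frames, 3, 4) camera path
#   fid, kid = evaluator.compute_fid_kid()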