diff --git a/Experiments/audio_regression/configs/bach/val_der_sine.txt b/Experiments/audio_regression/configs/bach/val_der_sine.txt new file mode 100644 index 0000000..c9232d1 --- /dev/null +++ b/Experiments/audio_regression/configs/bach/val_der_sine.txt @@ -0,0 +1,5 @@ +exp_name = bach/val_der_sine + +supervision = val_der +filename = gt_bach.wav +activations = [sine, sine, sine, sine] diff --git a/Experiments/audio_regression/configs/bach/val_sine.txt b/Experiments/audio_regression/configs/bach/val_sine.txt new file mode 100644 index 0000000..4426cdd --- /dev/null +++ b/Experiments/audio_regression/configs/bach/val_sine.txt @@ -0,0 +1,5 @@ +exp_name = bach/val_sine + +supervision = val +filename = gt_bach.wav +activations = [sine, sine, sine, sine] diff --git a/Experiments/audio_regression/configs/counting/val_der_sine.txt b/Experiments/audio_regression/configs/counting/val_der_sine.txt new file mode 100644 index 0000000..66809bd --- /dev/null +++ b/Experiments/audio_regression/configs/counting/val_der_sine.txt @@ -0,0 +1,5 @@ +exp_name = counting/val_der_sine + +supervision = val_der +filename = gt_counting.wav +activations = [sine, sine, sine, sine] diff --git a/Experiments/audio_regression/configs/counting/val_sine.txt b/Experiments/audio_regression/configs/counting/val_sine.txt new file mode 100644 index 0000000..a49b6d7 --- /dev/null +++ b/Experiments/audio_regression/configs/counting/val_sine.txt @@ -0,0 +1,5 @@ +exp_name = counting/val_sine + +supervision = val +filename = gt_counting.wav +activations = [sine, sine, sine, sine] diff --git a/Experiments/audio_regression/dataset.py b/Experiments/audio_regression/dataset.py new file mode 100644 index 0000000..74b57df --- /dev/null +++ b/Experiments/audio_regression/dataset.py @@ -0,0 +1,38 @@ +import os + +import kornia +import scipy.io.wavfile as wavfile + +import torch + + +def get_data(data_root, filename, factor): + + rate, wav = wavfile.read(os.path.join(data_root, filename)) + print("Rate: %d" % rate) + print("Raw data shape: ", wav.shape) + + wav = torch.tensor(wav).reshape(-1, 1) + scale = torch.max(torch.abs(wav)) + wav = wav / scale # (N, 1) + + grad = kornia.filters.spatial_gradient(wav.unsqueeze(0).unsqueeze(0), mode='diff', order=1, normalized=True).squeeze() # (2, N) + grad = grad[1, :].reshape(-1, 1) # (N, 1) + + coordinate = torch.linspace(0, len(wav) - 1, len(wav)).reshape(-1, 1) # (N, 1) + + downsampled_wav = wav[::factor, :] + downsampled_grad = grad[::factor, :] + downsampled_coordinate = coordinate[::factor, :] + + return { + 'wav': wav, + 'grad': grad, + 'coordinate': coordinate, + + 'downsampled_wav': downsampled_wav, + 'downsampled_grad': downsampled_grad, + 'downsampled_coordinate': downsampled_coordinate, + } + + diff --git a/Experiments/audio_regression/diff_operators.py b/Experiments/audio_regression/diff_operators.py new file mode 100644 index 0000000..514f4f2 --- /dev/null +++ b/Experiments/audio_regression/diff_operators.py @@ -0,0 +1,9 @@ +import torch + + +def gradient(y, x, grad_outputs=None): + if grad_outputs is None: + grad_outputs = torch.ones_like(y) + grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] + return grad + diff --git a/Experiments/audio_regression/loss.py b/Experiments/audio_regression/loss.py new file mode 100644 index 0000000..67ce35d --- /dev/null +++ b/Experiments/audio_regression/loss.py @@ -0,0 +1,18 @@ +import torch + + +def mse(x, y): + return (x - y).pow(2).mean() + + +def val_mse(gt, pred): + val_loss = mse(gt, pred) + + return {'val_loss': 
val_loss} + + +def der_mse(gt_grad, pred_grad): + weights = torch.ones(gt_grad.shape[1]).to(gt_grad.device) + der_loss = torch.mean((weights * (gt_grad - pred_grad).pow(2)).sum(-1)) + + return {'der_loss': der_loss} diff --git a/Experiments/audio_regression/main.py b/Experiments/audio_regression/main.py new file mode 100644 index 0000000..1f0e93b --- /dev/null +++ b/Experiments/audio_regression/main.py @@ -0,0 +1,167 @@ +import os +import shutil + +import torch +from torch.utils.tensorboard import SummaryWriter + +from dataset import get_data +from model import MLP +from loss import * +from utils import * + +set_random_seed(0) + + +def config_parser(): + + import configargparse + parser = configargparse.ArgumentParser() + parser.add_argument('--config', is_config_file=True, help="Path of config file.") + + # logging options + parser.add_argument('--logging_root', type=str, default='./logs/', help="Where to store ckpts and logs.") + parser.add_argument('--epochs_til_ckpt', type=int, default=1000, help="Time interval in epochs until checkpoint is saved.") + parser.add_argument('--epochs_til_summary', type=int, default=100, help="Time interval in epochs until tensorboard summary is saved.") + + # training options + parser.add_argument('--lrate', type=float, default='5e-5') + parser.add_argument('--num_epochs', type=int, default=8000, help="Number of epochs to train for.") + + # experiment options + parser.add_argument('--exp_name', type=str, default='supervision_val_der', + help="Name of experiment.") + parser.add_argument('--supervision', type=str, default='val_der', choices=('val', 'der', 'val_der')) + parser.add_argument('--activations', nargs='+', default=['sine', 'sine', 'sine', 'sine']) + parser.add_argument('--w0', type=float, default='30.') + parser.add_argument('--has_pos_encoding', action='store_true') + parser.add_argument('--lambda_der', type=float, default='1.') + + # model options + parser.add_argument('--hidden_features', type=int, default=256) + parser.add_argument('--num_hidden_layers', type=int, default=3) + + # dataset options + parser.add_argument('--data_root', type=str, default='../../data/Audio', help="Root path to audio dataset.") + parser.add_argument('--filename', type=str, help="Name of wav file.") + parser.add_argument('--factor', type=int, default=5, help="Factor of downsampling.") + + return parser + + +def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_checkpoint, logging_dir, train_summary_fn, test_summary_fn, log_f): + + summaries_dir = os.path.join(logging_dir, 'summaries') + os.makedirs(summaries_dir) + writer = SummaryWriter(summaries_dir) + + checkpoints_dir = os.path.join(logging_dir, 'checkpoints') + os.makedirs(checkpoints_dir) + + out_train_imgs_dir = os.path.join(logging_dir, 'out_train_imgs') + os.makedirs(out_train_imgs_dir) + + out_test_imgs_dir = os.path.join(logging_dir, 'out_test_imgs') + os.makedirs(out_test_imgs_dir) + + optim = torch.optim.Adam(lr=lrate, params=model.parameters()) + + # move data to GPU + data = {key: value.cuda() for key, value in data.items() if torch.is_tensor(value)} + + for epoch in range(1, epochs + 1): + + # forward and calculate loss + model_output = model(data['downsampled_coordinate'], mode='train') + losses = {} + losses.update(val_mse(data['downsampled_wav'], model_output['pred'])) + losses.update(der_mse(data['downsampled_grad'], model_output['pred_grad'])) + if args.supervision == 'val': + train_loss = losses['val_loss'] + elif args.supervision == 'der': + train_loss = losses['der_loss'] 
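+            # NOTE (added annotation): with der-only supervision the network is trained purely from + # gradient targets, so the reconstruction is determined only up to an additive constant; + # PSNR against the raw ground truth can therefore look poor even when the waveform shape is recovered well.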
+ elif args.supervision == 'val_der': + train_loss = 1. * losses['val_loss'] + args.lambda_der * losses['der_loss'] + # tensorboard + for loss_name, loss in losses.items(): + writer.add_scalar(loss_name, loss, epoch) + writer.add_scalar("train_loss", train_loss, epoch) + + # backward + optim.zero_grad() + train_loss.backward() + optim.step() + + if (not epoch % epochs_til_summary) or (epoch == epochs): + + # training summary + psnr = train_summary_fn(data, model_output, writer, epoch, out_train_imgs_dir) + str_print = "[Train] epoch: (%d/%d) " % (epoch, epochs) + for loss_name, loss in losses.items(): + str_print += loss_name + ": %0.6f, " % loss + str_print += "PSNR: %.3f " % (psnr) + print(str_print) + print(str_print, file=log_f) + + # test summary + with torch.no_grad(): + model_output = model(data['coordinate'], mode='test') + psnr = test_summary_fn(data, model_output, writer, epoch, out_test_imgs_dir, args.factor, args.filename) + str_print = "[Test]: PSNR: %.3f" % (psnr) + print(str_print) + print(str_print, file=log_f) + + # save checkpoint + if (not epoch % epochs_til_checkpoint) or (epoch == epochs): + torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_epoch_%05d.pth' % epoch)) + + torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_final.pth')) + + +def main(): + + parser = config_parser() + args = parser.parse_args() + + logging_dir = os.path.join(args.logging_root, args.exp_name) + if os.path.exists(logging_dir): + if input("The logging directory %s exists. Overwrite? (y/n) " % logging_dir) != 'y': + raise SystemExit("Aborted: refusing to overwrite existing logs.") + shutil.rmtree(logging_dir) + os.makedirs(logging_dir) + + with open(os.path.join(logging_dir, 'log.txt'), 'w') as log_f: + + print("Args:\n", args) + print("Args:\n", args, file=log_f) + + data = get_data(args.data_root, args.filename, args.factor) + print('Shape of original wav:', data['wav'].shape) + print('Shape of downsampled wav:', data['downsampled_wav'].shape) + + model = MLP( + in_features=1, + out_features=1, + w0=args.w0, + activations=args.activations, + hidden_features=args.hidden_features, + num_hidden_layers=args.num_hidden_layers, + has_pos_encoding=args.has_pos_encoding, + length=len(data['wav']), + fn_samples=len(data['downsampled_wav'])) + model.cuda() + + train( + args=args, + model=model, + data=data, + epochs=args.num_epochs, + lrate=args.lrate, + epochs_til_summary=args.epochs_til_summary, + epochs_til_checkpoint=args.epochs_til_ckpt, + logging_dir=logging_dir, + train_summary_fn=write_train_summary, + test_summary_fn=write_test_summary, + log_f=log_f) + + +if __name__ == '__main__': + main() diff --git a/Experiments/audio_regression/model.py b/Experiments/audio_regression/model.py new file mode 100644 index 0000000..f557cf7 --- /dev/null +++ b/Experiments/audio_regression/model.py @@ -0,0 +1,167 @@ +import math +from functools import partial + +import torch +from torch import nn +import numpy as np + +import diff_operators + + +class Sine(nn.Module): + def __init__(self, w0): + super().__init__() + self.w0 = w0 + + def forward(self, input): + return torch.sin(self.w0 * input) + + +def init_weights_normal(m): + if type(m) == nn.Linear: + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + + +def first_layer_sine_init(m): + with torch.no_grad(): + if hasattr(m, 'weight'): + num_input = m.weight.size(-1) + m.weight.uniform_(-1 / num_input, 1 / num_input) + + +def sine_init(m, w0): + with torch.no_grad(): + if hasattr(m, 'weight'): + num_input = m.weight.size(-1) + 
m.weight.uniform_(-np.sqrt(6 / num_input) / w0, np.sqrt(6 / num_input) / w0) + + +class PosEncodingNeRF(nn.Module): + def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True): + """__init__. + + :param in_features: + :param sidelength: [3, height, width] + :param fn_samples: + :param use_nyquist: + """ + super().__init__() + + self.in_features = in_features + + if self.in_features == 3: + self.num_frequencies = 10 + elif self.in_features == 2: + assert sidelength is not None + if isinstance(sidelength, int): + sidelength = (sidelength, sidelength) + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1])) + elif self.in_features == 1: + assert fn_samples is not None + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples) + + self.out_dim = in_features + 2 * in_features * self.num_frequencies + + def get_num_frequencies_nyquist(self, samples): + nyquist_rate = 1 / (2 * (2 * 1 / samples)) + return int(math.floor(math.log(nyquist_rate, 2))) + + def forward(self, coords): + coords = coords.view(coords.shape[0], -1, self.in_features) + + coords_pos_enc = coords + for i in range(self.num_frequencies): + for j in range(self.in_features): + c = coords[..., j] + + sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1) + cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1) + + coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1) + + return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim) + + +class MLP(nn.Module): + def __init__(self, in_features, w0, activations, out_features, hidden_features, num_hidden_layers, has_pos_encoding, length, fn_samples): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.w0 = w0 + self.activations = activations + self.hidden_features = hidden_features + self.num_hidden_layers = num_hidden_layers + self.has_pos_encoding = has_pos_encoding + self.length = length # N, for normalizing input coordinate + self.fn_samples = fn_samples # N//factor, for positional encoding + + assert(len(self.activations) == (self.num_hidden_layers + 1)) + + activations_and_inits = { + 'sine':(Sine(self.w0), first_layer_sine_init, partial(sine_init, w0=self.w0)), + 'relu':(nn.ReLU(inplace=True), init_weights_normal, init_weights_normal), + } + + + if self.has_pos_encoding: + self.positional_encoding = PosEncodingNeRF(in_features=in_features, + fn_samples=fn_samples, + use_nyquist=True) + in_features = self.positional_encoding.out_dim + + # network architecture + net = [] + net.append(nn.Sequential(nn.Linear(in_features=in_features, out_features=hidden_features), activations_and_inits[self.activations[0]][0])) # input layer + for i in range(num_hidden_layers): # hidden layers + net.append(nn.Sequential(nn.Linear(in_features=hidden_features, out_features=hidden_features), activations_and_inits[self.activations[i + 1]][0])) + + net.append(nn.Sequential(nn.Linear(in_features=hidden_features, out_features=out_features))) # output linear layer, without activation + self.net = nn.Sequential(*net) + + self.net[0].apply(activations_and_inits[self.activations[0]][1]) # input layer + for i in range(self.num_hidden_layers): # hidden layers + self.net[i + 1].apply(activations_and_inits[self.activations[i + 1]][2]) + self.net[-1].apply(activations_and_inits[self.activations[-1]][2]) # output layer, initialize as hidden layers + + print("Network:\n", self) + + + def 
normalize_coordinate(self, coordinate, length): + normalized_coordinate = coordinate / (length - 1) # [0 ~ N-1] to [0 ~ 1] + normalized_coordinate -= 0.5 # [0 ~ 1] to [-0.5 ~ 0.5] + normalized_coordinate *= 200. # [-0.5 ~ 0.5] to [-100 ~ 100] + + return normalized_coordinate + + + def forward(self, input_coordinate, mode='train'): + """forward. + + :param input_coordinate: [N//factor, 1] for train mode, [N, 1] for test mode + :param mode: 'train' or 'test' + """ + + if mode == 'train': + # Enables us to compute gradients w.r.t. coordinates + original_coordinate = input_coordinate.clone().detach().requires_grad_(True) + coordinate = self.normalize_coordinate(original_coordinate, self.length) + elif mode == 'test': + coordinate = input_coordinate.clone() + coordinate = self.normalize_coordinate(coordinate, self.length) + + if self.has_pos_encoding: + coordinate = self.positional_encoding(coordinate).squeeze(1) + + pred = self.net(coordinate) # (n, 1) + + if mode == 'train': + pred_grad = diff_operators.gradient(pred, original_coordinate) + return {'pred': pred, 'pred_grad': pred_grad} + elif mode == 'test': + return {'pred': pred} diff --git a/Experiments/audio_regression/utils.py b/Experiments/audio_regression/utils.py new file mode 100644 index 0000000..c25db8a --- /dev/null +++ b/Experiments/audio_regression/utils.py @@ -0,0 +1,126 @@ +import os +import random + +import numpy as np +import matplotlib.pyplot as plt + +import torch + + +SAVE_FORMAT = 'png' + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def cal_psnr(gt, pred, max_val=1.): + mse = (gt - pred).pow(2).mean() + return 10. 
* torch.log10(max_val ** 2 / mse) + + +def write_train_summary(data, output, writer, epoch, out_train_imgs_dir): + prefix = 'train_' + + gt = data['downsampled_wav'] + pred = output['pred'] + + # plot + start_index = int(0.05 * len(gt)) + end_index = int(0.95 * len(gt)) + x_plot = data['downsampled_coordinate'][start_index:end_index].detach().cpu().numpy() + gt_plot = gt[start_index:end_index].detach().cpu().numpy() + pred_plot = pred[start_index:end_index].detach().cpu().numpy() + + fig, axes = plt.subplots(3, 1) + axes[0].plot(x_plot, gt_plot) + axes[1].plot(x_plot, pred_plot) + axes[2].plot(x_plot, gt_plot - pred_plot) + axes[0].set_ylim([-0.75, 0.75]) + axes[1].set_ylim([-0.75, 0.75]) + axes[2].set_ylim([-0.25, 0.25]) + + axes[0].set_ylabel('GT') + axes[1].set_ylabel('Pred') + axes[2].set_ylabel('Error') + + fig.savefig(os.path.join(out_train_imgs_dir, '%05d.%s' % (epoch, SAVE_FORMAT)), format=SAVE_FORMAT) + plt.close(fig) + + # write metric + psnr = cal_psnr(gt, pred) + writer.add_scalar(prefix + 'PSNR', psnr, epoch) + + return psnr + + +def write_test_summary(data, output, writer, epoch, out_test_imgs_dir, factor, filename): + prefix = 'test_' + + gt = data['wav'] + pred = output['pred'] + + start_index = int(0.05 * len(gt)) + end_index = int(0.95 * len(gt)) + x_plot = data['coordinate'][start_index:end_index].detach().cpu().numpy() + gt_plot = gt[start_index:end_index].detach().cpu().numpy() + pred_plot = pred[start_index:end_index].detach().cpu().numpy() + + zoom_length = 600 + if 'counting' in filename: + zoom_start_index = 445000 + elif 'bach' in filename: + zoom_start_index = 150000 + else: + raise NotImplementedError() + zoom_end_index = zoom_start_index + zoom_length + # print("zoom index: %d ~ %d" % (zoom_start_index, zoom_end_index)) + zoom_x_plot = data['coordinate'][zoom_start_index:zoom_end_index].detach().cpu().numpy() + zoom_gt_plot = gt[zoom_start_index:zoom_end_index].detach().cpu().numpy() + zoom_pred_plot = pred[zoom_start_index:zoom_end_index].detach().cpu().numpy() + + fig, axes = plt.subplots(3, 2) + axes[0, 0].plot(x_plot, gt_plot) + axes[1, 0].plot(x_plot, pred_plot) + axes[2, 0].plot(x_plot, gt_plot - pred_plot) + axes[0, 1].plot(zoom_x_plot, zoom_gt_plot) + axes[1, 1].plot(zoom_x_plot, zoom_pred_plot) + axes[2, 1].plot(zoom_x_plot, zoom_gt_plot - zoom_pred_plot) + + value_ylim = [-0.75, 0.75] + diff_ylim = [-0.25, 0.25] + axes[0, 0].set_ylim(value_ylim) + axes[1, 0].set_ylim(value_ylim) + axes[2, 0].set_ylim(diff_ylim) + axes[0, 1].set_ylim(value_ylim) + axes[1, 1].set_ylim(value_ylim) + axes[2, 1].set_ylim(diff_ylim) + + axes[0, 0].axvline(x=zoom_start_index, color='r', linewidth=0.5) + # axes[0, 0].axvline(x=zoom_end_index, color='r', linewidth=0.5) + axes[1, 0].axvline(x=zoom_start_index, color='r', linewidth=0.5) + axes[2, 0].axvline(x=zoom_start_index, color='r', linewidth=0.5) + + axes[0, 0].set_ylabel('GT') + axes[1, 0].set_ylabel('Pred') + axes[2, 0].set_ylabel('Error') + + fig.savefig(os.path.join(out_test_imgs_dir, '%05d.%s' % (epoch, SAVE_FORMAT)), format=SAVE_FORMAT) + plt.close(fig) + + # write metric + eval_mask = torch.ones_like(gt, dtype=torch.bool) + eval_mask[::factor, :] = False + eval_gt = gt[eval_mask].clone() + eval_pred = pred[eval_mask].clone() + psnr = cal_psnr(eval_gt, eval_pred) + writer.add_scalar(prefix + 'PSNR', psnr, epoch) + + return psnr + + diff --git a/Experiments/image_regression/configs/baby/val/relu.txt b/Experiments/image_regression/configs/baby/val/relu.txt new file mode 100644 index 0000000..a1e5db8 --- /dev/null +++ 
b/Experiments/image_regression/configs/baby/val/relu.txt @@ -0,0 +1,5 @@ +exp_name = baby/val/relu + +filename = baby +supervision = val +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/baby/val/relu_pos.txt b/Experiments/image_regression/configs/baby/val/relu_pos.txt new file mode 100644 index 0000000..3db8b67 --- /dev/null +++ b/Experiments/image_regression/configs/baby/val/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = baby/val/relu_pos + +filename = baby +supervision = val +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/baby/val/sine.txt b/Experiments/image_regression/configs/baby/val/sine.txt new file mode 100644 index 0000000..ae7d22e --- /dev/null +++ b/Experiments/image_regression/configs/baby/val/sine.txt @@ -0,0 +1,5 @@ +exp_name = baby/val/sine + +filename = baby +supervision = val +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/baby/val_der/relu.txt b/Experiments/image_regression/configs/baby/val_der/relu.txt new file mode 100644 index 0000000..baafb6e --- /dev/null +++ b/Experiments/image_regression/configs/baby/val_der/relu.txt @@ -0,0 +1,5 @@ +exp_name = baby/val_der/relu + +filename = baby +supervision = val_der +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/baby/val_der/relu_pos.txt b/Experiments/image_regression/configs/baby/val_der/relu_pos.txt new file mode 100644 index 0000000..5a4e419 --- /dev/null +++ b/Experiments/image_regression/configs/baby/val_der/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = baby/val_der/relu_pos + +filename = baby +supervision = val_der +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/baby/val_der/sine.txt b/Experiments/image_regression/configs/baby/val_der/sine.txt new file mode 100644 index 0000000..e527f9e --- /dev/null +++ b/Experiments/image_regression/configs/baby/val_der/sine.txt @@ -0,0 +1,5 @@ +exp_name = baby/val_der/sine + +filename = baby +supervision = val_der +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/bird/val/relu.txt b/Experiments/image_regression/configs/bird/val/relu.txt new file mode 100644 index 0000000..a83bb04 --- /dev/null +++ b/Experiments/image_regression/configs/bird/val/relu.txt @@ -0,0 +1,5 @@ +exp_name = bird/val/relu + +filename = bird +supervision = val +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/bird/val/relu_pos.txt b/Experiments/image_regression/configs/bird/val/relu_pos.txt new file mode 100644 index 0000000..29a1867 --- /dev/null +++ b/Experiments/image_regression/configs/bird/val/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = bird/val/relu_pos + +filename = bird +supervision = val +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/bird/val/sine.txt b/Experiments/image_regression/configs/bird/val/sine.txt new file mode 100644 index 0000000..f674238 --- /dev/null +++ b/Experiments/image_regression/configs/bird/val/sine.txt @@ -0,0 +1,5 @@ +exp_name = bird/val/sine + +filename = bird +supervision = val +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/bird/val_der/relu.txt b/Experiments/image_regression/configs/bird/val_der/relu.txt new file mode 100644 index 0000000..16d7263 --- /dev/null +++ b/Experiments/image_regression/configs/bird/val_der/relu.txt @@ -0,0 +1,5 @@ +exp_name 
= bird/val_der/relu + +filename = bird +supervision = val_der +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/bird/val_der/relu_pos.txt b/Experiments/image_regression/configs/bird/val_der/relu_pos.txt new file mode 100644 index 0000000..ac8df4b --- /dev/null +++ b/Experiments/image_regression/configs/bird/val_der/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = bird/val_der/relu_pos + +filename = bird +supervision = val_der +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/bird/val_der/sine.txt b/Experiments/image_regression/configs/bird/val_der/sine.txt new file mode 100644 index 0000000..e2fb35c --- /dev/null +++ b/Experiments/image_regression/configs/bird/val_der/sine.txt @@ -0,0 +1,5 @@ +exp_name = bird/val_der/sine + +filename = bird +supervision = val_der +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/butterfly/val/relu.txt b/Experiments/image_regression/configs/butterfly/val/relu.txt new file mode 100644 index 0000000..84846fa --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val/relu.txt @@ -0,0 +1,5 @@ +exp_name = butterfly/val/relu + +filename = butterfly +supervision = val +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/butterfly/val/relu_pos.txt b/Experiments/image_regression/configs/butterfly/val/relu_pos.txt new file mode 100644 index 0000000..6bc5d2e --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = butterfly/val/relu_pos + +filename = butterfly +supervision = val +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/butterfly/val/sine.txt b/Experiments/image_regression/configs/butterfly/val/sine.txt new file mode 100644 index 0000000..bdf1a54 --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val/sine.txt @@ -0,0 +1,5 @@ +exp_name = butterfly/val/sine + +filename = butterfly +supervision = val +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/butterfly/val_der/relu.txt b/Experiments/image_regression/configs/butterfly/val_der/relu.txt new file mode 100644 index 0000000..30f3127 --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val_der/relu.txt @@ -0,0 +1,5 @@ +exp_name = butterfly/val_der/relu + +filename = butterfly +supervision = val_der +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/butterfly/val_der/relu_pos.txt b/Experiments/image_regression/configs/butterfly/val_der/relu_pos.txt new file mode 100644 index 0000000..e756533 --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val_der/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = butterfly/val_der/relu_pos + +filename = butterfly +supervision = val_der +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/butterfly/val_der/sine.txt b/Experiments/image_regression/configs/butterfly/val_der/sine.txt new file mode 100644 index 0000000..4c45a93 --- /dev/null +++ b/Experiments/image_regression/configs/butterfly/val_der/sine.txt @@ -0,0 +1,5 @@ +exp_name = butterfly/val_der/sine + +filename = butterfly +supervision = val_der +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/head/val/relu.txt b/Experiments/image_regression/configs/head/val/relu.txt new file mode 100644 index 
0000000..7bca335 --- /dev/null +++ b/Experiments/image_regression/configs/head/val/relu.txt @@ -0,0 +1,5 @@ +exp_name = head/val/relu + +filename = head +supervision = val +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/head/val/relu_pos.txt b/Experiments/image_regression/configs/head/val/relu_pos.txt new file mode 100644 index 0000000..4a4d0f7 --- /dev/null +++ b/Experiments/image_regression/configs/head/val/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = head/val/relu_pos + +filename = head +supervision = val +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/head/val/sine.txt b/Experiments/image_regression/configs/head/val/sine.txt new file mode 100644 index 0000000..73b62a9 --- /dev/null +++ b/Experiments/image_regression/configs/head/val/sine.txt @@ -0,0 +1,5 @@ +exp_name = head/val/sine + +filename = head +supervision = val +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/head/val_der/relu.txt b/Experiments/image_regression/configs/head/val_der/relu.txt new file mode 100644 index 0000000..500f1ad --- /dev/null +++ b/Experiments/image_regression/configs/head/val_der/relu.txt @@ -0,0 +1,5 @@ +exp_name = head/val_der/relu + +filename = head +supervision = val_der +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/head/val_der/relu_pos.txt b/Experiments/image_regression/configs/head/val_der/relu_pos.txt new file mode 100644 index 0000000..ba62687 --- /dev/null +++ b/Experiments/image_regression/configs/head/val_der/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = head/val_der/relu_pos + +filename = head +supervision = val_der +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/head/val_der/sine.txt b/Experiments/image_regression/configs/head/val_der/sine.txt new file mode 100644 index 0000000..d776d5a --- /dev/null +++ b/Experiments/image_regression/configs/head/val_der/sine.txt @@ -0,0 +1,5 @@ +exp_name = head/val_der/sine + +filename = head +supervision = val_der +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/woman/val/relu.txt b/Experiments/image_regression/configs/woman/val/relu.txt new file mode 100644 index 0000000..1bc26e2 --- /dev/null +++ b/Experiments/image_regression/configs/woman/val/relu.txt @@ -0,0 +1,5 @@ +exp_name = woman/val/relu + +filename = woman +supervision = val +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/woman/val/relu_pos.txt b/Experiments/image_regression/configs/woman/val/relu_pos.txt new file mode 100644 index 0000000..72369f2 --- /dev/null +++ b/Experiments/image_regression/configs/woman/val/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = woman/val/relu_pos + +filename = woman +supervision = val +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/woman/val/sine.txt b/Experiments/image_regression/configs/woman/val/sine.txt new file mode 100644 index 0000000..37b7bb4 --- /dev/null +++ b/Experiments/image_regression/configs/woman/val/sine.txt @@ -0,0 +1,5 @@ +exp_name = woman/val/sine + +filename = woman +supervision = val +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/configs/woman/val_der/relu.txt b/Experiments/image_regression/configs/woman/val_der/relu.txt new file mode 100644 index 0000000..b2a69e9 --- /dev/null +++ 
b/Experiments/image_regression/configs/woman/val_der/relu.txt @@ -0,0 +1,5 @@ +exp_name = woman/val_der/relu + +filename = woman +supervision = val_der +activations = [relu, relu, relu, relu] diff --git a/Experiments/image_regression/configs/woman/val_der/relu_pos.txt b/Experiments/image_regression/configs/woman/val_der/relu_pos.txt new file mode 100644 index 0000000..bd8fd80 --- /dev/null +++ b/Experiments/image_regression/configs/woman/val_der/relu_pos.txt @@ -0,0 +1,6 @@ +exp_name = woman/val_der/relu_pos + +filename = woman +supervision = val_der +activations = [relu, relu, relu, relu] +has_pos_encoding = True diff --git a/Experiments/image_regression/configs/woman/val_der/sine.txt b/Experiments/image_regression/configs/woman/val_der/sine.txt new file mode 100644 index 0000000..ab54910 --- /dev/null +++ b/Experiments/image_regression/configs/woman/val_der/sine.txt @@ -0,0 +1,5 @@ +exp_name = woman/val_der/sine + +filename = woman +supervision = val_der +activations = [sine, sine, sine, sine] diff --git a/Experiments/image_regression/dataset.py b/Experiments/image_regression/dataset.py new file mode 100644 index 0000000..72032bd --- /dev/null +++ b/Experiments/image_regression/dataset.py @@ -0,0 +1,65 @@ +import os +from collections import namedtuple + +from PIL import Image +import kornia +import einops + +import torch +from torchvision import transforms + + +def get_data(data_root, dataset, filename, is_gray, factor, der_operator): + + dataset_root = os.path.join(data_root, dataset) + + img = Image.open(os.path.join(dataset_root, 'HR', filename + '.png')) + if (img.height % factor) or (img.width % factor): + raise ValueError("The width/height of image must be an integer multiple of factor!") + if is_gray: + img = img.convert('L') + + transform = transforms.Compose([ + transforms.Resize((img.height, img.width)), + transforms.ToTensor(), # [0 ~ 255] to [0. ~ 1.] + transforms.Normalize(torch.Tensor([0.5]), torch.Tensor([0.5])) # [0. ~ 1.] to [-1. ~ 1.] 
+ ]) + img = transform(img) # (c, h, w) + + img_shape = namedtuple('shape', 'height width')(img.shape[1], img.shape[2]) + grad = kornia.filters.spatial_gradient(img.unsqueeze(0), mode=der_operator, order=1, normalized=True).squeeze(0) # (c, 2, h, w), grad_x = grad[:, 0], grad_y = grad[:, 1] + coordinate = torch.stack( + torch.meshgrid( + torch.linspace(0, img.shape[2] - 1, img.shape[2]), # (0 ~ w-1) + torch.linspace(0, img.shape[1] - 1, img.shape[1]), # (0 ~ h-1) + indexing='xy'), + -1) # (h, w, 2) + + downsampled_img = img[:, ::factor, ::factor].clone() # equal to F.interpolate with mode='nearest' + # downsampled_img = F.interpolate(img.unsqueeze(0), scale_factor=1/factor, mode='nearest').squeeze(0) + downsampled_img_shape = namedtuple('shape', 'height width')(downsampled_img.shape[1], downsampled_img.shape[2]) + downsampled_grad = grad[:, :, 0::factor, 0::factor].clone() + downsampled_coordinate = coordinate[::factor, ::factor, :].clone() + + # reshape data + img = einops.rearrange(img, 'c h w -> (h w) c') # (h*w, c) + grad = einops.rearrange(grad, 'c d h w -> (h w) (c d)') # (h*w, c*2), if c=3: (dx_r, dy_r, dx_g, dy_g, dx_b, dy_b) + coordinate = einops.rearrange(coordinate, 'h w c -> (h w) c') # (h*w, 2) + + downsampled_img = einops.rearrange(downsampled_img, 'c h w -> (h w) c') + downsampled_grad = einops.rearrange(downsampled_grad, 'c d h w -> (h w) (c d)') + downsampled_coordinate = einops.rearrange(downsampled_coordinate, 'h w c -> (h w) c') + + + return { + 'img': img, + 'img_shape': img_shape, + 'grad': grad, + 'coordinate': coordinate, + + 'downsampled_img': downsampled_img, + 'downsampled_img_shape': downsampled_img_shape, + 'downsampled_grad': downsampled_grad, + 'downsampled_coordinate': downsampled_coordinate, + } + diff --git a/Experiments/image_regression/diff_operators.py b/Experiments/image_regression/diff_operators.py new file mode 100644 index 0000000..514f4f2 --- /dev/null +++ b/Experiments/image_regression/diff_operators.py @@ -0,0 +1,9 @@ +import torch + + +def gradient(y, x, grad_outputs=None): + if grad_outputs is None: + grad_outputs = torch.ones_like(y) + grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] + return grad + diff --git a/Experiments/image_regression/loss.py b/Experiments/image_regression/loss.py new file mode 100644 index 0000000..67ce35d --- /dev/null +++ b/Experiments/image_regression/loss.py @@ -0,0 +1,18 @@ +import torch + + +def mse(x, y): + return (x - y).pow(2).mean() + + +def val_mse(gt, pred): + val_loss = mse(gt, pred) + + return {'val_loss': val_loss} + + +def der_mse(gt_grad, pred_grad): + weights = torch.ones(gt_grad.shape[1]).to(gt_grad.device) + der_loss = torch.mean((weights * (gt_grad - pred_grad).pow(2)).sum(-1)) + + return {'der_loss': der_loss} diff --git a/Experiments/image_regression/main.py b/Experiments/image_regression/main.py new file mode 100644 index 0000000..e14b0b5 --- /dev/null +++ b/Experiments/image_regression/main.py @@ -0,0 +1,181 @@ +import os +import shutil + + +import torch +from torch.utils.tensorboard import SummaryWriter + +from dataset import get_data +from model import MLP +from loss import * +from utils import * + +set_random_seed(0) + + +def config_parser(): + + import configargparse + parser = configargparse.ArgumentParser() + parser.add_argument('--config', is_config_file=True, help="Path of config file.") + + # logging options + parser.add_argument('--logging_root', type=str, default='./logs/', help="Where to store ckpts and logs.") + 
parser.add_argument('--epochs_til_ckpt', type=int, default=1000, help="Time interval in epochs until checkpoint is saved.") + parser.add_argument('--epochs_til_summary', type=int, default=100, help="Time interval in epochs until tensorboard summary is saved.") + + # training options + parser.add_argument('--lrate', type=float, default='1e-4') + parser.add_argument('--num_epochs', type=int, default=50000, help="Number of epochs to train for.") + + # experiment options + parser.add_argument('--exp_name', type=str, default='supervision_val_der', help="Name of experiment.") + parser.add_argument('--supervision', type=str, default='val_der', choices=('val', 'der', 'val_der')) + parser.add_argument('--activations', nargs='+', default=['sine', 'sine', 'sine', 'sine']) + parser.add_argument('--w0', type=float, default='30.') + parser.add_argument('--is_gray', action='store_true') + parser.add_argument('--der_operator', type=str, default='sobel', choices=('sobel', 'diff')) + parser.add_argument('--has_pos_encoding', action='store_true') + parser.add_argument('--has_fourier_feature', action='store_true') + parser.add_argument('--lambda_der', type=float, default='1.') + + # model options + parser.add_argument('--hidden_features', type=int, default=256) + parser.add_argument('--num_hidden_layers', type=int, default=3) + + # dataset options + parser.add_argument('--data_root', type=str, default='../../data', help="Root path to image datasets.") + parser.add_argument('--dataset', type=str, default='Set5', choices=('Set5', 'DIV2K_valid')) + parser.add_argument('--filename', type=str, help="Name of image file.") + parser.add_argument('--factor', type=int, default=4, choices=(1, 2, 3, 4), help="Factor of downsampling.") + + return parser + + +def train(args, model, data, epochs, lrate, epochs_til_summary, epochs_til_checkpoint, logging_dir, train_summary_fn, test_summary_fn, log_f): + + summaries_dir = os.path.join(logging_dir, 'summaries') + os.makedirs(summaries_dir) + writer = SummaryWriter(summaries_dir) + + checkpoints_dir = os.path.join(logging_dir, 'checkpoints') + os.makedirs(checkpoints_dir) + + out_train_imgs_dir = os.path.join(logging_dir, 'out_train_imgs') + os.makedirs(out_train_imgs_dir) + + out_test_imgs_dir = os.path.join(logging_dir, 'out_test_imgs') + os.makedirs(out_test_imgs_dir) + + optim = torch.optim.Adam(lr=lrate, params=model.parameters()) + + img_shape = data['img_shape'] + downsampled_img_shape = data['downsampled_img_shape'] + # move data to GPU + data = {key: value.cuda() for key, value in data.items() if torch.is_tensor(value)} + + for epoch in range(1, epochs + 1): + + # forward and calculate loss + model_output = model(data['downsampled_coordinate'], mode='train') + losses = {} + losses.update(val_mse(data['downsampled_img'], model_output['pred'])) + losses.update(der_mse(data['downsampled_grad'], model_output['pred_grad'])) + if args.supervision == 'val': + train_loss = losses['val_loss'] + elif args.supervision == 'der': + train_loss = losses['der_loss'] + elif args.supervision == 'val_der': + train_loss = 1. 
* losses['val_loss'] + args.lambda_der * losses['der_loss'] + + # tensorboard + for loss_name, loss in losses.items(): + writer.add_scalar(loss_name, loss, epoch) + writer.add_scalar("train_loss", train_loss, epoch) + + # backward + optim.zero_grad() + train_loss.backward() + optim.step() + + if (not epoch % epochs_til_summary) or (epoch == epochs): + + # training summary + psnr, ssim = train_summary_fn(data, model_output, writer, epoch, downsampled_img_shape, out_train_imgs_dir) + + str_print = "[Train] Epoch: (%d/%d) " % (epoch, epochs) + for loss_name, loss in losses.items(): + str_print += loss_name + ": %0.6f, " % loss + str_print += "PSNR: %.3f, SSIM: %.4f, " % (psnr, ssim) + print(str_print) + print(str_print, file=log_f) + + # test summary + with torch.no_grad(): + model_output = model(data['coordinate'], mode='test') + psnr, ssim = test_summary_fn(data, model_output, writer, epoch, img_shape, out_test_imgs_dir, factor=args.factor) + str_print = "[Test]: PSNR: %.3f, SSIM: %.4f" % (psnr, ssim) + print(str_print) + print(str_print, file=log_f) + + # save checkpoint + if (not epoch % epochs_til_checkpoint) or (epoch == epochs): + torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_epoch_%05d.pth' % epoch)) + + torch.save(model.state_dict(), os.path.join(checkpoints_dir, 'model_final.pth')) + + +def main(): + + parser = config_parser() + args = parser.parse_args() + + logging_dir = os.path.join(args.logging_root, args.exp_name) + if os.path.exists(logging_dir): + # if input("The logging directory %s exists. Overwrite? (y/n)" % logging_dir) == 'y': + shutil.rmtree(logging_dir) + os.makedirs(logging_dir) + + with open(os.path.join(logging_dir, 'log.txt'), 'w') as log_f: + + print("Args:\n", args) + print("Args:\n", args, file=log_f) + + data = get_data(args.data_root, args.dataset, args.filename, args.is_gray, args.factor, args.der_operator) + print('Shape of original image:', data['img_shape']) + print('Shape of downsampled image:', data['downsampled_img_shape']) + + if args.is_gray: + out_features = 1 + else: + out_features = 3 + + model = MLP( + in_features=2, + out_features=out_features, + w0=args.w0, + activations=args.activations, + hidden_features=args.hidden_features, + num_hidden_layers=args.num_hidden_layers, + has_pos_encoding=args.has_pos_encoding, + has_fourier_feature=args.has_fourier_feature, + shape=data['img_shape'], + sidelength=data['downsampled_img_shape']) + model.cuda() + + train( + args=args, + model=model, + data=data, + epochs=args.num_epochs, + lrate=args.lrate, + epochs_til_summary=args.epochs_til_summary, + epochs_til_checkpoint=args.epochs_til_ckpt, + logging_dir=logging_dir, + train_summary_fn=write_train_summary, + test_summary_fn=write_test_summary, + log_f=log_f) + + +if __name__=='__main__': + main() diff --git a/Experiments/image_regression/model.py b/Experiments/image_regression/model.py new file mode 100644 index 0000000..75c3b9c --- /dev/null +++ b/Experiments/image_regression/model.py @@ -0,0 +1,179 @@ +import math +from functools import partial + +import torch +from torch import nn +import numpy as np + +import diff_operators + + +class Sine(nn.Module): + def __init__(self, w0): + super().__init__() + self.w0 = w0 + + def forward(self, input): + return torch.sin(self.w0 * input) + + +def init_weights_normal(m): + if type(m) == nn.Linear: + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + + +def first_layer_sine_init(m): + with torch.no_grad(): + if hasattr(m, 'weight'): + 
num_input = m.weight.size(-1) + m.weight.uniform_(-1 / num_input, 1 / num_input) + + +def sine_init(m, w0): + with torch.no_grad(): + if hasattr(m, 'weight'): + num_input = m.weight.size(-1) + m.weight.uniform_(-np.sqrt(6 / num_input) / w0, np.sqrt(6 / num_input) / w0) + + +class PosEncodingNeRF(nn.Module): + def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True): + """__init__. + + :param in_features: + :param sidelength: [3, height, width] + :param fn_samples: + :param use_nyquist: + """ + super().__init__() + + self.in_features = in_features + + if self.in_features == 3: + self.num_frequencies = 10 + elif self.in_features == 2: + assert sidelength is not None + if isinstance(sidelength, int): + sidelength = (sidelength, sidelength) + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1])) + elif self.in_features == 1: + assert fn_samples is not None + self.num_frequencies = 4 + if use_nyquist: + self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples) + + self.out_dim = in_features + 2 * in_features * self.num_frequencies + + def get_num_frequencies_nyquist(self, samples): + nyquist_rate = 1 / (2 * (2 * 1 / samples)) + return int(math.floor(math.log(nyquist_rate, 2))) + + def forward(self, coords): + coords = coords.view(coords.shape[0], -1, self.in_features) + + coords_pos_enc = coords + for i in range(self.num_frequencies): + for j in range(self.in_features): + c = coords[..., j] + + sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1) + cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1) + + coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1) + + return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim) + + +class MLP(nn.Module): + def __init__(self, in_features, w0, activations, out_features, hidden_features, num_hidden_layers, has_pos_encoding, has_fourier_feature, shape, sidelength): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.w0 = w0 + self.activations = activations + self.hidden_features = hidden_features + self.num_hidden_layers = num_hidden_layers + self.has_pos_encoding = has_pos_encoding + self.has_fourier_feature = has_fourier_feature + self.shape = shape # (H, W), for normalizing input coordinate + self.sidelength = sidelength # (H//factor, W//factor), for positional encoding + + assert(len(self.activations) == (self.num_hidden_layers + 1)) + + activations_and_inits = { + 'sine':(Sine(self.w0), first_layer_sine_init, partial(sine_init, w0=self.w0)), + 'relu':(nn.ReLU(inplace=True), init_weights_normal, init_weights_normal), + } + + + if self.has_pos_encoding: + self.positional_encoding = PosEncodingNeRF(in_features=in_features, + sidelength=sidelength, + use_nyquist=True) + in_features = self.positional_encoding.out_dim + + if self.has_fourier_feature: + raise NotImplementedError("Fourier feature network: not implemented!") + + # network architecture + net = [] + net.append(nn.Sequential(nn.Linear(in_features=in_features, out_features=hidden_features), activations_and_inits[self.activations[0]][0])) # input layer + for i in range(num_hidden_layers): # hidden layers + net.append(nn.Sequential(nn.Linear(in_features=hidden_features, out_features=hidden_features), activations_and_inits[self.activations[i + 1]][0])) + + net.append(nn.Sequential(nn.Linear(in_features=hidden_features, out_features=out_features))) # output linear layer, without activation + self.net = 
nn.Sequential(*net) + + self.net[0].apply(activations_and_inits[self.activations[0]][1]) # input layer + for i in range(self.num_hidden_layers): # hidden layers + self.net[i + 1].apply(activations_and_inits[self.activations[i + 1]][2]) # following layer + self.net[-1].apply(activations_and_inits[self.activations[-1]][2]) # output layer, initialize as hidden layers + + print("Network:\n", self) + + + def normalize_coordinate(self, coordinate, shape): + """normalize_coordinate. + + :param coordinate: [h, w, 2], indexing='xy' + :param shape: namedtuple with fields (height, width) + """ + normalized_x = coordinate[..., 0] / (shape.width - 1) # [0 ~ w-1] to [0. ~ 1.] + normalized_y = coordinate[..., 1] / (shape.height - 1) # [0 ~ h-1] to [0. ~ 1.] + normalized_coordinate = torch.stack((normalized_x, normalized_y), dim=-1) + normalized_coordinate -= 0.5 # [0. ~ 1.] to [-0.5 ~ 0.5] + normalized_coordinate *= 2. # [-0.5 ~ 0.5] to [-1. ~ 1.] + + return normalized_coordinate + + + def forward(self, input_coordinate, mode='train'): + """forward. + + :param input_coordinate: [(H//factor)*(W//factor), 2] for train mode, [H*W, 2] for test mode + :param mode: 'train' or 'test' + """ + if mode == 'train': + original_coordinate = input_coordinate.clone().detach().requires_grad_(True) + coordinate = self.normalize_coordinate(original_coordinate, self.shape) + elif mode == 'test': + coordinate = input_coordinate.clone() + coordinate = self.normalize_coordinate(coordinate, self.shape) + + if self.has_pos_encoding: + coordinate = self.positional_encoding(coordinate).squeeze(1) + + pred = self.net(coordinate) # (h*w, c) + + if mode == 'train': + if pred.shape[1] == 1: # gray + pred_grad = diff_operators.gradient(pred, original_coordinate) + else: # color + pred_grad = torch.concat([diff_operators.gradient(pred[..., i], original_coordinate) for i in range(3)], dim=-1) # (h*w, c*2) + return {'pred': pred, 'pred_grad': pred_grad} + elif mode == 'test': + return {'pred': pred} diff --git a/Experiments/image_regression/utils.py b/Experiments/image_regression/utils.py new file mode 100644 index 0000000..d869e2c --- /dev/null +++ b/Experiments/image_regression/utils.py @@ -0,0 +1,149 @@ +import os +import math +import random + +import cv2 +import numpy as np +import einops +import kornia + +import torch +from torchvision.utils import make_grid + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def grad2rgb(grad, q=0.05): + """grad2rgb. + The scale of grad does not affect the final RGB image. + + :param grad: (b, h, w, 2) + :param q: quantile + """ + if grad.shape[-1] != 2: + raise ValueError("grad is not of the right shape.") + B, H, W, _ = grad.shape + grad_cpu = grad.detach().cpu() + grad_x = grad_cpu[..., 0] # mGc, (b, h, w) + grad_y = grad_cpu[..., 1] # mGr + # grad_angle = torch.arctan(grad_x / grad_y) # (b, h, w), numerically unstable + grad_angle = torch.from_numpy(np.arctan2(grad_x.numpy(), grad_y.numpy())) + grad_mag = torch.sqrt(grad_x.pow(2) + grad_y.pow(2)) # (b, h, w) + grad_hsv = torch.zeros((B, 3, H, W), dtype=torch.float32) # (b, 3, h, w) + grad_hsv[:, 0] = (grad_angle + torch.pi) # kornia.color.hsv_to_rgb assumes hue values in the range [0 ~ 2pi] + grad_hsv[:, 1] = 1. 
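+ # hue encodes gradient direction and saturation is fixed at 1; value will + # carry gradient magnitude, robustly rescaled below with the q / (1 - q) + # quantiles so a few extreme pixels do not wash out the visualization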
+ + per_min = torch.quantile(einops.rearrange(grad_mag, 'b h w -> b (h w)'), q=q, dim=-1) + per_max = torch.quantile(einops.rearrange(grad_mag, 'b h w -> b (h w)'), q=1.-q, dim=-1) + + grad_mag = (grad_mag - per_min) / (per_max - per_min) + grad_mag = torch.clip(grad_mag, 0., 1.) + + grad_hsv[:, 2] = grad_mag + grad_rgb = kornia.color.hsv_to_rgb(grad_hsv) # (b, 3, h, w) + + return grad_rgb + + +def cal_psnr(gt, pred, max_val=1.): + mse = (gt - pred).pow(2).mean() + return 10. * torch.log10(max_val ** 2 / mse) + + +def cal_ssim(gt, pred): + return kornia.metrics.ssim(gt, pred, 11).mean() + + +def cal_metric(gt, pred, factor=None): + gt = (gt * 0.5) + 0.5 # [-1. ~ 1.] to [0 ~ 1.] + pred = (pred * 0.5) + 0.5 + + if gt.shape[1] == 3: # color + gt = (gt * torch.tensor([65.738, 129.057, 25.064]).div(256.).view(1, 3, 1, 1).to(gt.device)).sum(dim=1, keepdim=True) # (b, 1, h, w) + pred = (pred * torch.tensor([65.738, 129.057, 25.064]).div(256.).view(1, 3, 1, 1).to(pred.device)).sum(dim=1, keepdim=True) + + if factor is None: # train + psnr = cal_psnr(gt, pred) + ssim = cal_ssim(gt, pred) + else: # test + eval_mask = torch.ones_like(gt, dtype=torch.bool) # (b, 1, H, W) + eval_mask[:, :, ::factor, ::factor] = False + psnr = cal_psnr(gt[eval_mask], pred[eval_mask]) # only evaluate test data + ssim = cal_ssim(gt[..., factor:-factor, factor:-factor], pred[..., factor:-factor, factor:-factor]) # crop borders following common super-resolution evaluation settings + + return psnr, ssim + + +def write_train_summary(data, output, writer, epoch, shape, out_train_imgs_dir): + prefix = 'train_' + + gt = data['downsampled_img'] # (h*w, c) + gt = einops.rearrange(gt, '(h w) c -> c h w', h=shape.height).unsqueeze(0) # (1, c, h, w) + pred = output['pred'] + pred = einops.rearrange(pred, '(h w) c -> c h w', h=shape.height).unsqueeze(0) + + # write image + gt_vs_pred = torch.cat([gt, pred], dim=-1) # (b, c, h, w*2) + gt_vs_pred = make_grid(gt_vs_pred, scale_each=True, normalize=True) # (c, h, w*2) + writer.add_image(prefix + 'gt_vs_pred', gt_vs_pred, global_step=epoch) + cv2.imwrite(os.path.join(out_train_imgs_dir, 'gt_vs_pred' + '-' + '%05d' % epoch + '.png'), (255 * einops.rearrange(torch.flip(gt_vs_pred, dims=[0]), 'c h w -> h w c').cpu().numpy()).astype(np.uint8)) + + # write grad + gt_grad = data['downsampled_grad'] # (h*w, c*2) + gt_grad = einops.rearrange(gt_grad, '(h w) t -> h w t', h=shape.height).unsqueeze(0) # (1, h, w, c*2) + pred_grad = output['pred_grad'] + pred_grad = einops.rearrange(pred_grad, '(h w) t -> h w t', h=shape.height).unsqueeze(0) + + if output['pred'].shape[1] == 1: # gray + gt_grad_rgb = grad2rgb(gt_grad) # (b, c, h, w) + pred_grad_rgb = grad2rgb(pred_grad) + gt_vs_pred_grad = torch.cat([gt_grad_rgb, pred_grad_rgb], dim=-1) # (b, c, h, w*2) + gt_vs_pred_grad = make_grid(gt_vs_pred_grad, scale_each=True, normalize=True) # (c, h, w*2) + writer.add_image(prefix + 'gt_vs_pred_grad', gt_vs_pred_grad, global_step=epoch) + cv2.imwrite(os.path.join(out_train_imgs_dir, 'gt_vs_pred_grad' + '-' + '%05d' % epoch + '.png'), (255 * einops.rearrange(torch.flip(gt_vs_pred_grad, dims=[0]), 'c h w -> h w c').cpu().numpy()).astype(np.uint8)) + + elif output['pred'].shape[1] == 3: # color + for i, c in enumerate(['r', 'g', 'b']): + gt_grad_rgb = grad2rgb(gt_grad[..., 2*i:2*i+2]) + pred_grad_rgb = grad2rgb(pred_grad[..., 2*i:2*i+2]) + gt_vs_pred_grad = torch.cat([gt_grad_rgb, pred_grad_rgb], dim=-1) + gt_vs_pred_grad = make_grid(gt_vs_pred_grad, scale_each=True, normalize=True) + writer.add_image(prefix + c + 
'_gt_vs_pred_grad', gt_vs_pred_grad, global_step=epoch) + cv2.imwrite(os.path.join(out_train_imgs_dir, c + '_gt_vs_pred_grad' + '-' + '%05d' % epoch + '.png'), (255 * einops.rearrange(torch.flip(gt_vs_pred_grad, dims=[0]), 'c h w -> h w c').cpu().numpy()).astype(np.uint8)) + + # write metrics + psnr, ssim = cal_metric(gt, pred) + writer.add_scalar(prefix + 'PSNR', psnr, epoch) + writer.add_scalar(prefix + 'SSIM', ssim, epoch) + + return psnr, ssim + + +def write_test_summary(data, output, writer, epoch, shape, out_test_imgs_dir, factor): + prefix = 'test_' + + gt = data['img'].clone() # [h*w, c) + gt = einops.rearrange(gt, '(h w) c -> c h w', h=shape.height).unsqueeze(0) # (1, c, h, w) + pred = output['pred'].clone() + pred = einops.rearrange(pred, '(h w) c -> c h w', h=shape.height).unsqueeze(0) + + # write img + gt_vs_pred = torch.cat([gt, pred], dim=-1) # (b, c, h, w*2) + gt_vs_pred = make_grid(gt_vs_pred, scale_each=True, normalize=True) # (c, h, 2*w) + cv2.imwrite(os.path.join(out_test_imgs_dir, 'gt_vs_pred' + '-' + '%05d' % epoch + '.png'), (255 * einops.rearrange(torch.flip(gt_vs_pred, dims=[0]), 'c h w -> h w c').cpu().numpy()).astype(np.uint8)) + writer.add_image(prefix + 'gt_vs_pred', gt_vs_pred, global_step=epoch) + + # write metric + psnr, ssim = cal_metric(gt, pred, factor) + writer.add_scalar(prefix + 'PSNR', psnr, epoch) + writer.add_scalar(prefix + 'SSIM', ssim, epoch) + + return psnr, ssim diff --git a/Experiments/inverse_rendering/configs/fern/val/relu.txt b/Experiments/inverse_rendering/configs/fern/val/relu.txt new file mode 100644 index 0000000..d54842b --- /dev/null +++ b/Experiments/inverse_rendering/configs/fern/val/relu.txt @@ -0,0 +1,5 @@ +expname = fern/val/relu + +datadir = ../../data/nerf_llff_data/fern +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/fern/val/sine.txt b/Experiments/inverse_rendering/configs/fern/val/sine.txt new file mode 100644 index 0000000..f8e869e --- /dev/null +++ b/Experiments/inverse_rendering/configs/fern/val/sine.txt @@ -0,0 +1,5 @@ +expname = fern/val/sine + +datadir = ../../data/nerf_llff_data/fern +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/fern/val_der/relu.txt b/Experiments/inverse_rendering/configs/fern/val_der/relu.txt new file mode 100644 index 0000000..473794b --- /dev/null +++ b/Experiments/inverse_rendering/configs/fern/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = fern/val_der/relu + +datadir = ../../data/nerf_llff_data/fern +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/fern/val_der/sine.txt b/Experiments/inverse_rendering/configs/fern/val_der/sine.txt new file mode 100644 index 0000000..1afcf2b --- /dev/null +++ b/Experiments/inverse_rendering/configs/fern/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = fern/val_der/sine + +datadir = ../../data/nerf_llff_data/fern +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/flower/val/relu.txt b/Experiments/inverse_rendering/configs/flower/val/relu.txt new file mode 100644 index 0000000..f451da4 --- /dev/null +++ b/Experiments/inverse_rendering/configs/flower/val/relu.txt @@ -0,0 +1,5 @@ +expname = flower/val/relu + +datadir = ../../data/nerf_llff_data/flower +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] 
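Annotation: these config files are plain key = value text consumed through configargparse's --config flag, the same pattern as the audio and image main.py above. A minimal sketch of sweeping a config tree, assuming the inverse_rendering entry point accepts --config the same way; the script name and config root below are illustrative assumptions, not part of this diff:

import glob
import subprocess

def run_all(config_root='Experiments/inverse_rendering/configs', script='main.py'):
    # one training run per config file; each file pins expname, datadir,
    # supervision (val vs. val_der) and the per-layer activation list
    for cfg in sorted(glob.glob(config_root + '/**/*.txt', recursive=True)):
        subprocess.run(['python', script, '--config', cfg], check=True)

if __name__ == '__main__':
    run_all()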
diff --git a/Experiments/inverse_rendering/configs/flower/val/sine.txt b/Experiments/inverse_rendering/configs/flower/val/sine.txt new file mode 100644 index 0000000..015610f --- /dev/null +++ b/Experiments/inverse_rendering/configs/flower/val/sine.txt @@ -0,0 +1,5 @@ +expname = flower/val/sine + +datadir = ../../data/nerf_llff_data/flower +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/flower/val_der/relu.txt b/Experiments/inverse_rendering/configs/flower/val_der/relu.txt new file mode 100644 index 0000000..6b7baa8 --- /dev/null +++ b/Experiments/inverse_rendering/configs/flower/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = flower/val_der/relu + +datadir = ../../data/nerf_llff_data/flower +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/flower/val_der/sine.txt b/Experiments/inverse_rendering/configs/flower/val_der/sine.txt new file mode 100644 index 0000000..9f72fb1 --- /dev/null +++ b/Experiments/inverse_rendering/configs/flower/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = flower/val_der/sine + +datadir = ../../data/nerf_llff_data/flower +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/fortress/val/relu.txt b/Experiments/inverse_rendering/configs/fortress/val/relu.txt new file mode 100644 index 0000000..f4d9336 --- /dev/null +++ b/Experiments/inverse_rendering/configs/fortress/val/relu.txt @@ -0,0 +1,5 @@ +expname = fortress/val/relu + +datadir = ../../data/nerf_llff_data/fortress +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/fortress/val/sine.txt b/Experiments/inverse_rendering/configs/fortress/val/sine.txt new file mode 100644 index 0000000..7eb32d9 --- /dev/null +++ b/Experiments/inverse_rendering/configs/fortress/val/sine.txt @@ -0,0 +1,5 @@ +expname = fortress/val/sine + +datadir = ../../data/nerf_llff_data/fortress +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/fortress/val_der/relu.txt b/Experiments/inverse_rendering/configs/fortress/val_der/relu.txt new file mode 100644 index 0000000..4f661b5 --- /dev/null +++ b/Experiments/inverse_rendering/configs/fortress/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = fortress/val_der/relu + +datadir = ../../data/nerf_llff_data/fortress +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/fortress/val_der/sine.txt b/Experiments/inverse_rendering/configs/fortress/val_der/sine.txt new file mode 100644 index 0000000..824d629 --- /dev/null +++ b/Experiments/inverse_rendering/configs/fortress/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = fortress/val_der/sine + +datadir = ../../data/nerf_llff_data/fortress +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/horns/val/relu.txt b/Experiments/inverse_rendering/configs/horns/val/relu.txt new file mode 100644 index 0000000..13dc555 --- /dev/null +++ b/Experiments/inverse_rendering/configs/horns/val/relu.txt @@ -0,0 +1,5 @@ +expname = horns/val/relu + +datadir = ../../data/nerf_llff_data/horns +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git 
a/Experiments/inverse_rendering/configs/horns/val/sine.txt b/Experiments/inverse_rendering/configs/horns/val/sine.txt new file mode 100644 index 0000000..56fb2dc --- /dev/null +++ b/Experiments/inverse_rendering/configs/horns/val/sine.txt @@ -0,0 +1,5 @@ +expname = horns/val/sine + +datadir = ../../data/nerf_llff_data/horns +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/horns/val_der/relu.txt b/Experiments/inverse_rendering/configs/horns/val_der/relu.txt new file mode 100644 index 0000000..9b393ce --- /dev/null +++ b/Experiments/inverse_rendering/configs/horns/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = horns/val_der/relu + +datadir = ../../data/nerf_llff_data/horns +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/horns/val_der/sine.txt b/Experiments/inverse_rendering/configs/horns/val_der/sine.txt new file mode 100644 index 0000000..d070cb6 --- /dev/null +++ b/Experiments/inverse_rendering/configs/horns/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = horns/val_der/sine + +datadir = ../../data/nerf_llff_data/horns +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/leaves/val/relu.txt b/Experiments/inverse_rendering/configs/leaves/val/relu.txt new file mode 100644 index 0000000..e690a68 --- /dev/null +++ b/Experiments/inverse_rendering/configs/leaves/val/relu.txt @@ -0,0 +1,5 @@ +expname = leaves/val/relu + +datadir = ../../data/nerf_llff_data/leaves +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/leaves/val/sine.txt b/Experiments/inverse_rendering/configs/leaves/val/sine.txt new file mode 100644 index 0000000..368a80d --- /dev/null +++ b/Experiments/inverse_rendering/configs/leaves/val/sine.txt @@ -0,0 +1,5 @@ +expname = leaves/val/sine + +datadir = ../../data/nerf_llff_data/leaves +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/leaves/val_der/relu.txt b/Experiments/inverse_rendering/configs/leaves/val_der/relu.txt new file mode 100644 index 0000000..2ddf4ad --- /dev/null +++ b/Experiments/inverse_rendering/configs/leaves/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = leaves/val_der/relu + +datadir = ../../data/nerf_llff_data/leaves +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/leaves/val_der/sine.txt b/Experiments/inverse_rendering/configs/leaves/val_der/sine.txt new file mode 100644 index 0000000..3ab15a6 --- /dev/null +++ b/Experiments/inverse_rendering/configs/leaves/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = leaves/val_der/sine + +datadir = ../../data/nerf_llff_data/leaves +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/orchids/val/relu.txt b/Experiments/inverse_rendering/configs/orchids/val/relu.txt new file mode 100644 index 0000000..0c826a2 --- /dev/null +++ b/Experiments/inverse_rendering/configs/orchids/val/relu.txt @@ -0,0 +1,5 @@ +expname = orchids/val/relu + +datadir = ../../data/nerf_llff_data/orchids +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/orchids/val/sine.txt 
b/Experiments/inverse_rendering/configs/orchids/val/sine.txt new file mode 100644 index 0000000..8693262 --- /dev/null +++ b/Experiments/inverse_rendering/configs/orchids/val/sine.txt @@ -0,0 +1,5 @@ +expname = orchids/val/sine + +datadir = ../../data/nerf_llff_data/orchids +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/orchids/val_der/relu.txt b/Experiments/inverse_rendering/configs/orchids/val_der/relu.txt new file mode 100644 index 0000000..5474f93 --- /dev/null +++ b/Experiments/inverse_rendering/configs/orchids/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = orchids/val_der/relu + +datadir = ../../data/nerf_llff_data/orchids +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/orchids/val_der/sine.txt b/Experiments/inverse_rendering/configs/orchids/val_der/sine.txt new file mode 100644 index 0000000..a08bfd1 --- /dev/null +++ b/Experiments/inverse_rendering/configs/orchids/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = orchids/val_der/sine + +datadir = ../../data/nerf_llff_data/orchids +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/room/val/relu.txt b/Experiments/inverse_rendering/configs/room/val/relu.txt new file mode 100644 index 0000000..5ef9aa5 --- /dev/null +++ b/Experiments/inverse_rendering/configs/room/val/relu.txt @@ -0,0 +1,5 @@ +expname = room/val/relu + +datadir = ../../data/nerf_llff_data/room +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/room/val/sine.txt b/Experiments/inverse_rendering/configs/room/val/sine.txt new file mode 100644 index 0000000..5648a01 --- /dev/null +++ b/Experiments/inverse_rendering/configs/room/val/sine.txt @@ -0,0 +1,5 @@ +expname = room/val/sine + +datadir = ../../data/nerf_llff_data/room +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/room/val_der/relu.txt b/Experiments/inverse_rendering/configs/room/val_der/relu.txt new file mode 100644 index 0000000..0091194 --- /dev/null +++ b/Experiments/inverse_rendering/configs/room/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = room/val_der/relu + +datadir = ../../data/nerf_llff_data/room +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/room/val_der/sine.txt b/Experiments/inverse_rendering/configs/room/val_der/sine.txt new file mode 100644 index 0000000..e6de65a --- /dev/null +++ b/Experiments/inverse_rendering/configs/room/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = room/val_der/sine + +datadir = ../../data/nerf_llff_data/room +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/trex/val/relu.txt b/Experiments/inverse_rendering/configs/trex/val/relu.txt new file mode 100644 index 0000000..cb083a2 --- /dev/null +++ b/Experiments/inverse_rendering/configs/trex/val/relu.txt @@ -0,0 +1,5 @@ +expname = trex/val/relu + +datadir = ../../data/nerf_llff_data/trex +supervision = val +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/trex/val/sine.txt b/Experiments/inverse_rendering/configs/trex/val/sine.txt new file mode 100644 index 0000000..40fc864 --- 
/dev/null +++ b/Experiments/inverse_rendering/configs/trex/val/sine.txt @@ -0,0 +1,5 @@ +expname = trex/val/sine + +datadir = ../../data/nerf_llff_data/trex +supervision = val +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/configs/trex/val_der/relu.txt b/Experiments/inverse_rendering/configs/trex/val_der/relu.txt new file mode 100644 index 0000000..1c22585 --- /dev/null +++ b/Experiments/inverse_rendering/configs/trex/val_der/relu.txt @@ -0,0 +1,5 @@ +expname = trex/val_der/relu + +datadir = ../../data/nerf_llff_data/trex +supervision = val_der +activations = [relu, relu, relu, relu, relu, relu, relu, relu] diff --git a/Experiments/inverse_rendering/configs/trex/val_der/sine.txt b/Experiments/inverse_rendering/configs/trex/val_der/sine.txt new file mode 100644 index 0000000..0ccf4db --- /dev/null +++ b/Experiments/inverse_rendering/configs/trex/val_der/sine.txt @@ -0,0 +1,5 @@ +expname = trex/val_der/sine + +datadir = ../../data/nerf_llff_data/trex +supervision = val_der +activations = [sine, sine, sine, sine, sine, sine, sine, sine] diff --git a/Experiments/inverse_rendering/datasets/__init__.py b/Experiments/inverse_rendering/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Experiments/inverse_rendering/datasets/load_llff.py b/Experiments/inverse_rendering/datasets/load_llff.py new file mode 100644 index 0000000..98b7916 --- /dev/null +++ b/Experiments/inverse_rendering/datasets/load_llff.py @@ -0,0 +1,319 @@ +import numpy as np +import os, imageio + + +########## Slightly modified version of LLFF data loading code +########## see https://github.com/Fyusion/LLFF for original + +def _minify(basedir, factors=[], resolutions=[]): + needtoload = False + for r in factors: + imgdir = os.path.join(basedir, 'images_{}'.format(r)) + if not os.path.exists(imgdir): + needtoload = True + for r in resolutions: + imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) + if not os.path.exists(imgdir): + needtoload = True + if not needtoload: + return + + from shutil import copy + from subprocess import check_output + + imgdir = os.path.join(basedir, 'images') + imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] + imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] + imgdir_orig = imgdir + + wd = os.getcwd() + + for r in factors + resolutions: + if isinstance(r, int): + name = 'images_{}'.format(r) + resizearg = '{}%'.format(100./r) + else: + name = 'images_{}x{}'.format(r[1], r[0]) + resizearg = '{}x{}'.format(r[1], r[0]) + imgdir = os.path.join(basedir, name) + if os.path.exists(imgdir): + continue + + print('Minifying', r, basedir) + + os.makedirs(imgdir) + check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) + + ext = imgs[0].split('.')[-1] + args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) + print(args) + os.chdir(imgdir) + check_output(args, shell=True) + os.chdir(wd) + + if ext != 'png': + check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) + print('Removed duplicates') + print('Done') + + + + +def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): + + poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) + poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) + bds = poses_arr[:, -2:].transpose([1,0]) + + img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ + if f.endswith('JPG') 
or f.endswith('jpg') or f.endswith('png')][0] + sh = imageio.imread(img0).shape + + sfx = '' + + if factor is not None: + sfx = '_{}'.format(factor) + _minify(basedir, factors=[factor]) + factor = factor + elif height is not None: + factor = sh[0] / float(height) + width = int(sh[1] / factor) + _minify(basedir, resolutions=[[height, width]]) + sfx = '_{}x{}'.format(width, height) + elif width is not None: + factor = sh[1] / float(width) + height = int(sh[0] / factor) + _minify(basedir, resolutions=[[height, width]]) + sfx = '_{}x{}'.format(width, height) + else: + factor = 1 + + imgdir = os.path.join(basedir, 'images' + sfx) + if not os.path.exists(imgdir): + print( imgdir, 'does not exist, returning' ) + return + + imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] + if poses.shape[-1] != len(imgfiles): + print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) + return + + sh = imageio.imread(imgfiles[0]).shape + poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) + poses[2, 4, :] = poses[2, 4, :] * 1./factor + + if not load_imgs: + return poses, bds + + def imread(f): + if f.endswith('png'): + return imageio.imread(f, ignoregamma=True) + else: + return imageio.imread(f) + + imgs = [imread(f)[...,:3]/255. for f in imgfiles] + imgs = np.stack(imgs, -1) + + print('Loaded image data', imgs.shape, poses[:,-1,0]) + return poses, bds, imgs + + + + + + +def normalize(x): + return x / np.linalg.norm(x) + +def viewmatrix(z, up, pos): + vec2 = normalize(z) + vec1_avg = up + vec0 = normalize(np.cross(vec1_avg, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, pos], 1) + return m + +def ptstocam(pts, c2w): + tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] + return tt + +def poses_avg(poses): + + hwf = poses[0, :3, -1:] + + center = poses[:, :3, 3].mean(0) + vec2 = normalize(poses[:, :3, 2].sum(0)) + up = poses[:, :3, 1].sum(0) + c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) + + return c2w + + + +def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): + render_poses = [] + rads = np.array(list(rads) + [1.]) + hwf = c2w[:,4:5] + + for theta in np.linspace(0., 2.
* np.pi * rots, N+1)[:-1]: + c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) + z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) + render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) + return render_poses + + + +def recenter_poses(poses): + + poses_ = poses+0 + bottom = np.reshape([0,0,0,1.], [1,4]) + c2w = poses_avg(poses) + c2w = np.concatenate([c2w[:3,:4], bottom], -2) + bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) + poses = np.concatenate([poses[:,:3,:4], bottom], -2) + + poses = np.linalg.inv(c2w) @ poses + poses_[:,:3,:4] = poses[:,:3,:4] + poses = poses_ + return poses + + +##################### + + +def spherify_poses(poses, bds): + + p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) + + rays_d = poses[:,:3,2:3] + rays_o = poses[:,:3,3:4] + + def min_line_dist(rays_o, rays_d): + A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) + b_i = -A_i @ rays_o + pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) + return pt_mindist + + pt_mindist = min_line_dist(rays_o, rays_d) + + center = pt_mindist + up = (poses[:,:3,3] - center).mean(0) + + vec0 = normalize(up) + vec1 = normalize(np.cross([.1,.2,.3], vec0)) + vec2 = normalize(np.cross(vec0, vec1)) + pos = center + c2w = np.stack([vec1, vec2, vec0, pos], 1) + + poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) + + rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) + + sc = 1./rad + poses_reset[:,:3,3] *= sc + bds *= sc + rad *= sc + + centroid = np.mean(poses_reset[:,:3,3], 0) + zh = centroid[2] + radcircle = np.sqrt(rad**2-zh**2) + new_poses = [] + + for th in np.linspace(0.,2.*np.pi, 120): + + camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) + up = np.array([0,0,-1.]) + + vec2 = normalize(camorigin) + vec0 = normalize(np.cross(vec2, up)) + vec1 = normalize(np.cross(vec2, vec0)) + pos = camorigin + p = np.stack([vec0, vec1, vec2, pos], 1) + + new_poses.append(p) + + new_poses = np.stack(new_poses, 0) + + new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) + poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) + + return poses_reset, new_poses, bds + + +def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): + + + poses, bds, imgs = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x + print('Loaded', basedir, bds.min(), bds.max()) + + # Correct rotation matrix ordering and move variable dim to axis 0 + poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) + poses = np.moveaxis(poses, -1, 0).astype(np.float32) + imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) + images = imgs + bds = np.moveaxis(bds, -1, 0).astype(np.float32) + + # Rescale if bd_factor is provided + sc = 1. 
if bd_factor is None else 1./(bds.min() * bd_factor) + poses[:,:3,3] *= sc + bds *= sc + + if recenter: + poses = recenter_poses(poses) + + if spherify: + poses, render_poses, bds = spherify_poses(poses, bds) + + else: + + c2w = poses_avg(poses) + print('recentered', c2w.shape) + print(c2w[:3,:4]) + + ## Get spiral + # Get average pose + up = normalize(poses[:, :3, 1].sum(0)) + + # Find a reasonable "focus depth" for this dataset + close_depth, inf_depth = bds.min()*.9, bds.max()*5. + dt = .75 + mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) + focal = mean_dz + + # Get radii for spiral path + shrink_factor = .8 + zdelta = close_depth * .2 + tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T + rads = np.percentile(np.abs(tt), 90, 0) + c2w_path = c2w + N_views = 120 + N_rots = 2 + if path_zflat: +# zloc = np.percentile(tt, 10, 0)[2] + zloc = -close_depth * .1 + c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] + rads[2] = 0. + N_rots = 1 + N_views //= 2 # keep an integer sample count for np.linspace below + + # Generate poses for spiral path + render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) + + + render_poses = np.array(render_poses).astype(np.float32) + + c2w = poses_avg(poses) + print('Data:') + print(poses.shape, images.shape, bds.shape) + + dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) + i_test = np.argmin(dists) + print('HOLDOUT view is', i_test) + + images = images.astype(np.float32) + poses = poses.astype(np.float32) + + return images, poses, bds, render_poses, i_test + + + diff --git a/Experiments/inverse_rendering/datasets/ray_utils.py b/Experiments/inverse_rendering/datasets/ray_utils.py new file mode 100644 index 0000000..d95f15d --- /dev/null +++ b/Experiments/inverse_rendering/datasets/ray_utils.py @@ -0,0 +1,37 @@ +import torch + + +def get_rays(K, c2w, coordinate): + """get_rays. + + :param K: [3, 3] + :param c2w: pose, [3, 4] or [N_rand, 3, 4] + :param coordinate: (x, y) + """ + dirs = torch.stack([(coordinate[..., 0] - K[0][2]) / K[0][0], -(coordinate[..., 1] - K[1][2]) / K[1][1], -torch.ones_like(coordinate[..., 0])], -1) + # Rotate ray directions from camera frame to the world frame + rays_d = torch.sum(dirs[..., None, :] * c2w[..., :3,:3], -1) # dot product, equivalent to: [c2w.dot(dir) for dir in dirs] + # Translate camera frame's origin to the world frame. It is the origin of all rays. + rays_o = c2w[..., :3,-1].expand(rays_d.shape) + + return rays_o, rays_d + + +def ndc_rays(H, W, focal, near, rays_o, rays_d): + # Shift ray origins to near plane + t = -(near + rays_o[...,2]) / rays_d[...,2] + rays_o = rays_o + t[...,None] * rays_d + + # Projection + o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] + o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] + o2 = 1. + 2. * near / rays_o[...,2] + + d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) + d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) + d2 = -2.
* near / rays_o[...,2] + + rays_o = torch.stack([o0,o1,o2], -1) + rays_d = torch.stack([d0,d1,d2], -1) + + return rays_o, rays_d diff --git a/Experiments/inverse_rendering/diff_operators.py b/Experiments/inverse_rendering/diff_operators.py new file mode 100644 index 0000000..514f4f2 --- /dev/null +++ b/Experiments/inverse_rendering/diff_operators.py @@ -0,0 +1,9 @@ +import torch + + +def gradient(y, x, grad_outputs=None): + if grad_outputs is None: + grad_outputs = torch.ones_like(y) + grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] + return grad + diff --git a/Experiments/inverse_rendering/eval.py b/Experiments/inverse_rendering/eval.py new file mode 100644 index 0000000..535874c --- /dev/null +++ b/Experiments/inverse_rendering/eval.py @@ -0,0 +1,94 @@ +import os +import glob + +import imageio +import einops +import numpy as np + +from decimal import Decimal +rounding_half_up_4 = lambda x: Decimal(str(x)).quantize(Decimal("0.0001"), rounding="ROUND_HALF_UP") + +import torch + +from datasets.load_llff import load_llff_data +from train import config_parser +from utils import cal_psnr, cal_ssim + + +def read_data(args): + + # Load LLFF data (use the image downsample factor, matching train.py) + images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.llff_factor, + recenter=True, bd_factor=.75, + spherify=args.spherify) + hwf = poses[0,:3,-1] + poses = poses[:,:3,:4] + print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) + if not isinstance(i_test, list): + i_test = [i_test] + + if args.llffhold > 0: + print('Auto LLFF holdout,', args.llffhold) + i_test = np.arange(images.shape[0])[::args.llffhold] + + return images, i_test + + +def eval_img(gt, pred): + gt = einops.rearrange(torch.tensor(gt).unsqueeze(0), 'b h w c -> b c h w') + pred = einops.rearrange(torch.tensor(pred).unsqueeze(0), 'b h w c -> b c h w') + + return cal_psnr(gt, pred).item(), cal_ssim(gt, pred).item() + + +def eval_testset(testset_gt, testset_pred, f=None): + psnr_list = [] + ssim_list = [] + for i, (gt, pred) in enumerate(zip(testset_gt, testset_pred)): + psnr, ssim = eval_img(gt, pred) + print("%03d.png PSNR: %s, SSIM: %s" % (i, rounding_half_up_4(psnr), rounding_half_up_4(ssim))) + print("%03d.png PSNR: %s, SSIM: %s" % (i, rounding_half_up_4(psnr), rounding_half_up_4(ssim)), file=f) + psnr_list.append(psnr) + ssim_list.append(ssim) + print("Mean PSNR: %s, Mean SSIM: %s" % (rounding_half_up_4(np.mean(psnr_list)), rounding_half_up_4(np.mean(ssim_list)))) + print("Mean PSNR: %s, Mean SSIM: %s" % (rounding_half_up_4(np.mean(psnr_list)), rounding_half_up_4(np.mean(ssim_list))), file=f) + + +def eval_exp(): + parser = config_parser() + args = parser.parse_args() + print("args: ", args) + + exp_log_dir = os.path.join(args.basedir, args.expname) + testset_pred_dirs = glob.glob(os.path.join(exp_log_dir, 'testset_*')) + if len(testset_pred_dirs) == 0: + print("==============> EXP: %s, no testset result, skipping..." % args.expname) + return + + print("==============> EXP: %s" % args.expname) + + testset_pred_to_be_processed = [] + for testset_pred_dir in testset_pred_dirs: + score_file = os.path.join(testset_pred_dir, 'score.txt') + if os.path.exists(score_file): + print("Evaluating: %s, score.txt exists, skipping..."
% (testset_pred_dir)) + else: + testset_pred_to_be_processed.append(testset_pred_dir) + if len(testset_pred_to_be_processed) == 0: + return + + images, i_test = read_data(args) + + for testset_pred_dir in testset_pred_to_be_processed: + testset_pred = [imageio.imread(fname).astype(np.float32) / 255. for fname in sorted(glob.glob(os.path.join(testset_pred_dir, '*.png')))] # sort so predictions line up with i_test order + if len(testset_pred) != len(i_test): + print("Evaluating: %s, the number of predicted images is wrong, skipping..." % (testset_pred_dir)) + else: + score_file = os.path.join(testset_pred_dir, 'score.txt') + with open(score_file, 'w') as f: + print("Evaluating: ", testset_pred_dir) + eval_testset(images[i_test], testset_pred, f) + + +if __name__ == '__main__': + eval_exp() diff --git a/Experiments/inverse_rendering/loss.py b/Experiments/inverse_rendering/loss.py new file mode 100644 index 0000000..58c0229 --- /dev/null +++ b/Experiments/inverse_rendering/loss.py @@ -0,0 +1,23 @@ +import torch +import numpy as np + +import diff_operators + + +# Misc +img2mse = lambda x, y : torch.mean((x - y) ** 2) +mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) +to8b = lambda x : (255 * np.clip(x, 0, 1)).astype(np.uint8) + + +def der_mse(rgb, coordinate, gt_grad): + + pred_grad_r = diff_operators.gradient(rgb[..., 0], coordinate) # [B, N, 2], order: (r_x, r_y) + pred_grad_g = diff_operators.gradient(rgb[..., 1], coordinate) # [B, N, 2], order: (g_x, g_y) + pred_grad_b = diff_operators.gradient(rgb[..., 2], coordinate) # [B, N, 2], order: (b_x, b_y) + pred_grad = torch.concat((pred_grad_r, pred_grad_g, pred_grad_b), dim=-1) # [B, N, 6] + + weights = torch.tensor([1., 1., 1., 1., 1., 1.]).cuda() + der_loss = torch.mean((weights * (gt_grad - pred_grad).pow(2)).sum(-1)) + + return der_loss diff --git a/Experiments/inverse_rendering/models/__init__.py b/Experiments/inverse_rendering/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Experiments/inverse_rendering/models/nerf.py b/Experiments/inverse_rendering/models/nerf.py new file mode 100644 index 0000000..b24029a --- /dev/null +++ b/Experiments/inverse_rendering/models/nerf.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +# Positional encoding (section 5.1) +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x : x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, i=0): + if i == -1: + return nn.Identity(), 3 + + embed_kwargs = { + 'include_input' : True, + 'input_dims' : 3, + 'max_freq_log2' : multires-1, + 'num_freqs' : multires, + 'log_sampling' : True, + 'periodic_fns' : [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj : eo.embed(x) + return embed, embedder_obj.out_dim + 
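As a quick illustration of the positional encoding above (an example, not repo code): with the default `multires=10`, `include_input=True`, and sin/cos periodic functions, the output dimension is 3 + 3 * 2 * 10 = 63, which is the `input_ch` that `create_tiny_nerf` later passes to `TinyNeRF`:

```python
import torch

from models.nerf import get_embedder  # the function defined just above

embed_fn, input_ch = get_embedder(multires=10, i=0)

pts = torch.rand(1024, 3)   # flattened 3-D sample points
encoded = embed_fn(pts)     # identity term + sin/cos at 10 frequencies
assert input_ch == 63 and encoded.shape == (1024, 63)
```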
+class TinyNeRF(nn.Module): + def __init__(self, D, activations, W=256, input_ch=3, input_ch_views=3, output_ch=4): + super(TinyNeRF, self).__init__() + + self.D = D + self.activations = activations + if len(self.activations) != self.D: + raise ValueError("Length of activations must equal D!") + + self.W = W + self.input_ch = input_ch + self.input_ch_views = input_ch_views + + self.pts_linears = nn.ModuleList([nn.Linear(input_ch, W)] + [nn.Linear(W, W) for i in range(D - 1)]) + self.output_linear = nn.Linear(W, output_ch) + + def forward(self, x): + input_pts, _ = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) + h = input_pts + for i, l in enumerate(self.pts_linears): + h = self.pts_linears[i](h) + if self.activations[i] == 'relu': + h = F.relu(h) + elif self.activations[i] == 'sine': + h = torch.sin(h) + else: + raise NotImplementedError + + outputs = self.output_linear(h) + + return outputs diff --git a/Experiments/inverse_rendering/models/rendering.py b/Experiments/inverse_rendering/models/rendering.py new file mode 100644 index 0000000..3c00daa --- /dev/null +++ b/Experiments/inverse_rendering/models/rendering.py @@ -0,0 +1,274 @@ +import torch +import torch.nn.functional as F +import numpy as np +import time +from tqdm import tqdm +import os +import imageio + +from datasets.ray_utils import get_rays, ndc_rays +from loss import to8b + + +DEBUG = False + + +def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): + """Transforms model's predictions to semantically meaningful values. + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) + + dists = z_vals[...,1:] - z_vals[...,:-1] + dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] + + dists = dists * torch.norm(rays_d[...,None,:], dim=-1) + + rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] + noise = 0. + if raw_noise_std > 0.: + noise = torch.randn(raw[...,3].shape) * raw_noise_std + + # Overwrite randomly sampled data if pytest + if pytest: + np.random.seed(0) + noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std + noise = torch.Tensor(noise) + + alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] + # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) + weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] + rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] + + depth_map = torch.sum(weights * z_vals, -1) + disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) + acc_map = torch.sum(weights, -1) + + if white_bkgd: + rgb_map = rgb_map + (1.-acc_map[...,None]) + + return rgb_map, disp_map, acc_map, weights, depth_map + + +def render_rays(ray_batch, + network_fn, + network_query_fn, + N_samples, + retraw=False, + lindisp=False, + perturb=0., + white_bkgd=False, + raw_noise_std=0., + verbose=False, + pytest=False): + """Volumetric rendering.
+ Args: + ray_batch: array of shape [batch_size, ...]. All information necessary + for sampling along a ray, including: ray origin, ray direction, min + dist, max dist, and unit-magnitude viewing direction. + network_fn: function. Model for predicting RGB and density at each point + in space. + network_query_fn: function used for passing queries to network_fn. + N_samples: int. Number of different times to sample along each ray. + retraw: bool. If True, include model's raw, unprocessed predictions. + lindisp: bool. If True, sample linearly in inverse depth rather than in depth. + perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified + random points in time. + white_bkgd: bool. If True, assume a white background. + raw_noise_std: ... + verbose: bool. If True, print more debugging info. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. 1 / depth. + acc_map: [num_rays]. Accumulated opacity along each ray. + raw: [num_rays, num_samples, 4]. Raw predictions from model. + rgb0: See rgb_map. Output for coarse model. + disp0: See disp_map. Output for coarse model. + acc0: See acc_map. Output for coarse model. + z_std: [num_rays]. Standard deviation of distances along ray for each + sample. + """ + N_rays = ray_batch.shape[0] + rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each + viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None + bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) + near, far = bounds[...,0], bounds[...,1] # [-1,1] + + t_vals = torch.linspace(0., 1., steps=N_samples) + if not lindisp: + z_vals = near * (1.-t_vals) + far * (t_vals) + else: + z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) + + z_vals = z_vals.expand([N_rays, N_samples]) + + if perturb > 0.: + # get intervals between samples + mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) + upper = torch.cat([mids, z_vals[...,-1:]], -1) + lower = torch.cat([z_vals[...,:1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape) + + # Pytest, overwrite u with numpy's fixed random numbers + if pytest: + np.random.seed(0) + t_rand = np.random.rand(*list(z_vals.shape)) + t_rand = torch.Tensor(t_rand) + + z_vals = lower + (upper - lower) * t_rand + + pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] + + +# raw = run_network(pts) + raw = network_query_fn(pts, viewdirs, network_fn) + rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) + + ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} + if retraw: + ret['raw'] = raw + + for k in ret: + if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: + print(f"! [Numerical Error] {k} contains nan or inf.") + + return ret + + +def batchify_rays(rays_flat, chunk=1024*32, **kwargs): + """Render rays in smaller minibatches to avoid OOM. + """ + all_ret = {} + for i in range(0, rays_flat.shape[0], chunk): + ret = render_rays(rays_flat[i:i+chunk], **kwargs) + for k in ret: + if k not in all_ret: + all_ret[k] = [] + all_ret[k].append(ret[k]) + + all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} + return all_ret + + +def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, + near=0., far=1., + use_viewdirs=False, c2w_staticcam=None, + **kwargs): + """Render rays + Args: + H: int. Height of image in pixels. + W: int. Width of image in pixels. + focal: float. Focal length of pinhole camera. 
+ chunk: int. Maximum number of rays to process simultaneously. Used to + control maximum memory usage. Does not affect final results. + rays: array of shape [2, batch_size, 3]. Ray origin and direction for + each example in batch. + c2w: array of shape [3, 4]. Camera-to-world transformation matrix. + ndc: bool. If True, represent ray origin, direction in NDC coordinates. + near: float or array of shape [batch_size]. Nearest distance for a ray. + far: float or array of shape [batch_size]. Farthest distance for a ray. + use_viewdirs: bool. If True, use viewing direction of a point in space in model. + c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for + camera while using other c2w argument for viewing directions. + Returns: + rgb_map: [batch_size, 3]. Predicted RGB values for rays. + disp_map: [batch_size]. Disparity map. Inverse of depth. + acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. + extras: dict with everything returned by render_rays(). + """ + if c2w is not None: + # special case to render full image + xy = torch.stack(torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H), indexing='xy'), -1) # [H, W, 2] + rays_o, rays_d = get_rays(K, c2w, xy) + + else: + # use provided ray batch + rays_o, rays_d = rays + + if use_viewdirs: + # provide ray directions as input + viewdirs = rays_d + if c2w_staticcam is not None: + # special case to visualize effect of viewdirs; rebuild the full-image + # coordinate grid to match this repo's get_rays(K, c2w, coordinate) signature + rays_o, rays_d = get_rays(K, c2w_staticcam, torch.stack(torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H), indexing='xy'), -1)) + + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1,3]).float() + + sh = rays_d.shape # [..., 3] + if ndc: + # for forward facing scenes + rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) + + # Create ray batch + rays_o = torch.reshape(rays_o, [-1,3]).float() + rays_d = torch.reshape(rays_d, [-1,3]).float() + + near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) + rays = torch.cat([rays_o, rays_d, near, far], -1) + if use_viewdirs: + rays = torch.cat([rays, viewdirs], -1) + + # Render and reshape + all_ret = batchify_rays(rays, chunk, **kwargs) + for k in all_ret: + k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) + all_ret[k] = torch.reshape(all_ret[k], k_sh) + + k_extract = ['rgb_map', 'disp_map', 'acc_map'] + ret_list = [all_ret[k] for k in k_extract] + ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} + return ret_list + [ret_dict] + + +def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): + + H, W, focal = hwf + + if render_factor!=0: + # Render downsampled for speed + H = H//render_factor + W = W//render_factor + focal = focal/render_factor + + rgbs = [] + disps = [] + + t = time.time() + for i, c2w in enumerate(tqdm(render_poses)): + print(i, time.time() - t) + t = time.time() + rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) + rgbs.append(rgb.cpu().numpy()) + disps.append(disp.cpu().numpy()) + if i==0: + print(rgb.shape, disp.shape) + + """ + if gt_imgs is not None and render_factor==0: + p = -10.
* np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) + print(p) + """ + + if savedir is not None: + rgb8 = to8b(rgbs[-1]) + filename = os.path.join(savedir, '{:03d}.png'.format(i)) + imageio.imwrite(filename, rgb8) + + + rgbs = np.stack(rgbs, 0) + disps = np.stack(disps, 0) + + return rgbs, disps + diff --git a/Experiments/inverse_rendering/train.py b/Experiments/inverse_rendering/train.py new file mode 100644 index 0000000..79bfb95 --- /dev/null +++ b/Experiments/inverse_rendering/train.py @@ -0,0 +1,471 @@ +import os +import time + +import imageio +import kornia +import einops +import numpy as np +from tqdm import tqdm, trange + +import torch + +from datasets.load_llff import load_llff_data +from datasets.ray_utils import get_rays +from models.nerf import TinyNeRF, get_embedder +from models.rendering import render, render_path +from loss import img2mse, mse2psnr, to8b, der_mse +from utils import set_random_seed + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +set_random_seed(seed=0) + + +def batchify(fn, chunk): + """Constructs a version of 'fn' that applies to smaller batches. + """ + if chunk is None: + return fn + def ret(inputs): + return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) + return ret + + +def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): + """Prepares inputs and applies network 'fn'. + """ + inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) + embedded = embed_fn(inputs_flat) + + if viewdirs is not None: + input_dirs = viewdirs[:,None].expand(inputs.shape) + input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) + embedded_dirs = embeddirs_fn(input_dirs_flat) + embedded = torch.cat([embedded, embedded_dirs], -1) + + outputs_flat = batchify(fn, netchunk)(embedded) + outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) + return outputs + + +def create_tiny_nerf(args): + """Instantiate TinyNeRF model. 
+ """ + embed_fn, input_ch = get_embedder(args.multires, args.i_embed) + + input_ch_views = 0 + embeddirs_fn = None + if args.use_viewdirs: + embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) + output_ch = 4 + model = TinyNeRF( + D=args.netdepth, + activations=args.activations, + W=args.netwidth, + input_ch=input_ch, + input_ch_views=input_ch_views, + output_ch=output_ch + ).to(device) + print("model: ", model) + grad_vars = list(model.parameters()) + + network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, + embed_fn=embed_fn, + embeddirs_fn=embeddirs_fn, + netchunk=args.netchunk) + + # Create optimizer + optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) + + start = 0 + basedir = args.basedir + expname = args.expname + + ########################## + + # Load checkpoints + if args.ft_path is not None and args.ft_path!='None': + ckpts = [args.ft_path] + else: + ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] + + print('Found ckpts', ckpts) + if len(ckpts) > 0 and not args.no_reload: + ckpt_path = ckpts[-1] + print('Reloading from', ckpt_path) + ckpt = torch.load(ckpt_path) + + start = ckpt['global_step'] + optimizer.load_state_dict(ckpt['optimizer_state_dict']) + + # Load model + model.load_state_dict(ckpt['network_fn_state_dict']) + + ########################## + + model = torch.nn.DataParallel(model) + + render_kwargs_train = { + 'network_query_fn' : network_query_fn, + 'perturb' : args.perturb, + 'N_samples' : args.N_samples, + 'network_fn' : model, + 'use_viewdirs' : args.use_viewdirs, + 'white_bkgd' : args.white_bkgd, + 'raw_noise_std' : args.raw_noise_std, + } + + render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} + render_kwargs_test['perturb'] = False + render_kwargs_test['raw_noise_std'] = 0. 
+ + return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer + + +def config_parser(): + + import configargparse + parser = configargparse.ArgumentParser() + parser.add_argument('--config', is_config_file=True, + help='config file path') + parser.add_argument("--expname", type=str, + help='experiment name') + parser.add_argument("--basedir", type=str, default='./logs', + help='where to store ckpts and logs') + parser.add_argument("--datadir", type=str, default='../../data/nerf_llff_data/fern', + help='input data directory') + + # training options + parser.add_argument("--netdepth", type=int, default=8, + help='layers in network') + parser.add_argument("--netwidth", type=int, default=256, + help='channels per layer') + parser.add_argument("--N_rand", type=int, default=1024, + help='batch size (number of random rays per gradient step)') + parser.add_argument("--lrate", type=float, default=5e-4, + help='learning rate') + parser.add_argument("--lrate_decay", type=int, default=250, + help='exponential learning rate decay (in 1000 steps)') + parser.add_argument("--chunk", type=int, default=1024*32, + help='number of rays processed in parallel, decrease if running out of memory') + parser.add_argument("--netchunk", type=int, default=1024*64, + help='number of pts sent through network in parallel, decrease if running out of memory') + parser.add_argument("--no_reload", action='store_true', + help='do not reload weights from saved ckpt') + parser.add_argument("--ft_path", type=str, default=None, + help='specific weights npy file to reload for coarse network') + + # rendering options + parser.add_argument("--N_samples", type=int, default=128, + help='number of coarse samples per ray') + parser.add_argument("--perturb", type=float, default=1., + help='set to 0. for no jitter, 1. 
for jitter') + parser.add_argument("--use_viewdirs", action='store_true', + help='use full 5D input instead of 3D') + parser.add_argument("--i_embed", type=int, default=0, + help='set 0 for default positional encoding, -1 for none') + parser.add_argument("--multires", type=int, default=10, + help='log2 of max freq for positional encoding (3D location)') + parser.add_argument("--multires_views", type=int, default=4, + help='log2 of max freq for positional encoding (2D direction)') + parser.add_argument("--raw_noise_std", type=float, default=1e0, + help='std dev of noise added to regularize sigma_a output, 1e0 recommended') + + parser.add_argument("--render_only", action='store_true', + help='do not optimize, reload weights and render out render_poses path') + parser.add_argument("--render_test", action='store_true', + help='render the test set instead of render_poses path') + parser.add_argument("--render_factor", type=int, default=0, + help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') + parser.add_argument("--white_bkgd", action='store_true', + help='set to render synthetic data on a white bkgd (always use for dvoxels)') + + # training options + parser.add_argument("--precrop_iters", type=int, default=0, + help='number of steps to train on central crops') + parser.add_argument("--precrop_frac", type=float, + default=.5, help='fraction of img taken for central crops') + + ## llff flags + parser.add_argument("--llff_factor", type=int, default=4, + help='downsample factor for LLFF images') + parser.add_argument("--no_ndc", action='store_true', + help='do not use normalized device coordinates (set for non-forward facing scenes)') + parser.add_argument("--lindisp", action='store_true', + help='sampling linearly in disparity rather than depth') + parser.add_argument("--spherify", action='store_true', + help='set for spherical 360 scenes') + parser.add_argument("--llffhold", type=int, default=8, + help='will take every 1/N images as LLFF test set, paper uses 8') + + # logging/saving options + parser.add_argument("--i_print", type=int, default=100, + help='frequency of console printout and metric logging') + parser.add_argument("--i_img", type=int, default=500, + help='frequency of tensorboard image logging') + parser.add_argument("--i_weights", type=int, default=10000, + help='frequency of weight ckpt saving') + parser.add_argument("--i_testset", type=int, default=50000, + help='frequency of testset saving') + parser.add_argument("--i_video", type=int, default=400000, + help='frequency of render_poses video saving') + + # added options + parser.add_argument('--activations', nargs='+', default=['sine'] * 8) + parser.add_argument('--der_operator', type=str, default='sobel', choices=('sobel', 'diff')) + parser.add_argument('--supervision', type=str, default='val_der', choices=('val', 'der', 'val_der')) + parser.add_argument("--lambda_der", type=float, default=1.)
+ parser.add_argument("--show_der_loss", action='store_true', + help="for supervision=val, show der_loss values") + + # downsample options + parser.add_argument('--factor', type=int, default=4) + + return parser + + +def train(): + + parser = config_parser() + args = parser.parse_args() + print("args: ", args) + + # Load LLFF data + images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.llff_factor, + recenter=True, bd_factor=.75, + spherify=args.spherify) + hwf = poses[0,:3,-1] + poses = poses[:,:3,:4] + print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) + if not isinstance(i_test, list): + i_test = [i_test] + + if args.llffhold > 0: + print('Auto LLFF holdout,', args.llffhold) + i_test = np.arange(images.shape[0])[::args.llffhold] + + i_val = i_test + i_train = np.array([i for i in np.arange(int(images.shape[0])) if + (i not in i_test and i not in i_val)]) + + print('DEFINING BOUNDS') + if args.no_ndc: + near = np.ndarray.min(bds) * .9 + far = np.ndarray.max(bds) * 1. + + else: + near = 0. + far = 1. + print('NEAR FAR', near, far) + + # Cast intrinsics to right types + H, W, focal = hwf + H, W = int(H), int(W) + hwf = [H, W, focal] + + K = np.array([ + [focal, 0, 0.5*W], + [0, focal, 0.5*H], + [0, 0, 1] + ]) + + # calculate first-order image derivatives + grad = kornia.filters.spatial_gradient(einops.rearrange(torch.Tensor(images), 'b h w c -> b c h w'), mode=args.der_operator, order=1, normalized=True) # [B, C, 2, H, W], grad_x = grad[:, 0], grad_y = grad[:, 1] + grad = einops.rearrange(grad, 'b c d h w -> b h w (c d)') # [B, H, W, C*2] + + if args.render_test: + render_poses = np.array(poses[i_test]) + + # Create log dir and copy the config file + basedir = args.basedir + expname = args.expname + os.makedirs(os.path.join(basedir, expname), exist_ok=True) + f = os.path.join(basedir, expname, 'args.txt') + with open(f, 'w') as file: + for arg in sorted(vars(args)): + attr = getattr(args, arg) + file.write('{} = {}\n'.format(arg, attr)) + if args.config is not None: + f = os.path.join(basedir, expname, 'config.txt') + with open(f, 'w') as file: + file.write(open(args.config, 'r').read()) + + train_logfile = os.path.join(basedir, expname, 'train_log.txt') + + # Create nerf model + render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_tiny_nerf(args) + global_step = start + + bds_dict = { + 'near' : near, + 'far' : far, + } + render_kwargs_train.update(bds_dict) + render_kwargs_test.update(bds_dict) + + # Move testing data to GPU + render_poses = torch.Tensor(render_poses).to(device) + + # Short circuit if only rendering out from trained model + if args.render_only: + print('RENDER ONLY') + with torch.no_grad(): + if args.render_test: + # render_test switches to test poses + images = images[i_test] + else: + # Default is smoother render_poses path + images = None + + testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) + os.makedirs(testsavedir, exist_ok=True) + print('test poses shape', render_poses.shape) + + rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) + print('Done rendering', testsavedir) + imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) + + return + + # Prepare raybatch tensor if batching random rays + N_rand = args.N_rand + # For random ray batching + images = torch.Tensor(images).to(device) # [N, H, W, 3] + 
print('get rays') + coordinate = torch.stack( + torch.meshgrid( + torch.linspace(0, W - 1, W), + torch.linspace(0, H - 1, H), + indexing='xy'), + -1) # [H, W, 2] + coordinate = coordinate.unsqueeze(0).expand(images.shape[0], -1, -1, -1) # [N, H, W, 2] + + downsampled_images = images[:, ::args.factor, ::args.factor, :] # [N, H//factor, W//factor, 3] + print("Downsampled images shape: ", downsampled_images.shape) + downsampled_coordinate = coordinate[:, ::args.factor, ::args.factor, :] # [N, H//factor, W//factor, 2] + downsampled_grad = grad[:, ::args.factor, ::args.factor, :] # [N, H//factor, W//factor, 3*2] + + poses = torch.Tensor(poses).to(device) # [N, 3, 4] + data = torch.cat([downsampled_images, downsampled_coordinate, einops.rearrange(poses, 'n e f -> n () () (e f)').expand(-1, H//args.factor, W//args.factor, -1), downsampled_grad], dim=-1) # [N, H//factor, W//factor, 3+2+12+c*2], (rgb, xy, pose, grad) + data = torch.stack([data[i] for i in i_train]) # [N_train, H//factor, W//factor, 3+2+12+3*2] + data = einops.rearrange(data, 'n h w c -> (n h w) c') # [N_train*(H//factor)*(W//factor), 3+2+12+3*2], (rgb, xy, pose, grad) + + print('shuffle data') + rand_idx = torch.randperm(data.shape[0]) + data = data[rand_idx] + + print('done') + i_batch = 0 + + N_iters = 400000 + 1 + print('Begin') + print('TRAIN views are', i_train) + print('TEST views are', i_test) + print('VAL views are', i_val) + + start = start + 1 + for i in trange(start, N_iters): + time0 = time.time() + + # Sample random ray batch, random over all images + batch_data = data[i_batch:i_batch+N_rand] # [N_rand, 3+2+12+c*2], (rgb, xy, pose, grad) + target_s = batch_data[:, :3] + coordinate_s = batch_data[:, 3:5].float().clone().detach().requires_grad_(True) # enable us to compute gradients w.r.t. coordinates (xy) + poses_s = batch_data[:, 5:17] # [B, 12] + poses_s = einops.rearrange(poses_s, 'b (e f) -> b e f', e=3) # [B, 3, 4] + target_grad_s = batch_data[:, 17:] # [B, c*2] + rays_o, rays_d = get_rays(K, poses_s, coordinate_s) + batch_rays = torch.stack([rays_o, rays_d], dim=0) # [2, B, 3] + + i_batch += N_rand + if i_batch >= data.shape[0]: + print("Shuffle data after an epoch!") + rand_idx = torch.randperm(data.shape[0]) + data = data[rand_idx] + i_batch = 0 + + ##### Core optimization loop ##### + rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, + verbose=i < 10, retraw=True, + **render_kwargs_train) + + optimizer.zero_grad() + val_loss = img2mse(rgb, target_s) + trans = extras['raw'][...,-1] + + if args.supervision == 'val': + loss = val_loss + if args.show_der_loss: + der_loss = der_mse(rgb, coordinate_s, target_grad_s) + der_loss = der_loss.detach() # keep a tensor so the .item() call in the logging below works + elif args.supervision == 'der': + der_loss = der_mse(rgb, coordinate_s, target_grad_s) + loss = der_loss + else: # args.supervision == 'val_der' + der_loss = der_mse(rgb, coordinate_s, target_grad_s) + loss = 1. * val_loss + args.lambda_der * der_loss + + psnr = mse2psnr(val_loss) + + loss.backward() + optimizer.step() + + # NOTE: IMPORTANT!
+ ### update learning rate ### + decay_rate = 0.1 + decay_steps = args.lrate_decay * 1000 + new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) + for param_group in optimizer.param_groups: + param_group['lr'] = new_lrate + ################################ + + dt = time.time()-time0 + # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}") + ##### end ##### + + # Rest is logging + if i%args.i_weights==0: + path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) + torch.save({ + 'global_step': global_step, + 'network_fn_state_dict': render_kwargs_train['network_fn'].module.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, path) + print('Saved checkpoints at', path) + + if i%args.i_testset==0 and i > 0: + testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) + os.makedirs(testsavedir, exist_ok=True) + print('test poses shape', poses[i_test].shape) + with torch.no_grad(): + render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir) + print('Saved test set') + + if i%args.i_video==0 and i > 0: + # Turn on testing mode + with torch.no_grad(): + rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test) + print('Done, saving', rgbs.shape, disps.shape) + moviebase = os.path.join(basedir, expname, 'spiral_{:06d}_'.format(i)) + imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) + imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) + + if i%args.i_print==0: + if args.supervision == 'val': + if args.show_der_loss: + out_str = f"[TRAIN] Iter: {i} val_loss: {val_loss.item()} der_loss: {der_loss.item()} PSNR: {psnr.item()}" + else: + out_str = f"[TRAIN] Iter: {i} val_loss: {val_loss.item()} PSNR: {psnr.item()}" + else: + out_str = f"[TRAIN] Iter: {i} val_loss: {val_loss.item()} der_loss: {der_loss.item()} PSNR: {psnr.item()}" + tqdm.write(out_str) + with open(train_logfile, 'a') as file: + file.write(out_str + '\n') + + global_step += 1 + + +if __name__=='__main__': + torch.set_default_tensor_type('torch.cuda.FloatTensor') + + train() diff --git a/Experiments/inverse_rendering/utils.py b/Experiments/inverse_rendering/utils.py new file mode 100644 index 0000000..e5e03ca --- /dev/null +++ b/Experiments/inverse_rendering/utils.py @@ -0,0 +1,34 @@ +import random +import numpy as np +import kornia + +import torch + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def cal_psnr(gt, pred, max_val=1.): + """cal_psnr. + + :param pred: [B, C, H, W] + :param gt: [B, C, H, W] + """ + mse = (gt - pred).pow(2).mean(dim=(1, 2, 3)) # [B] + return 10. * torch.log10(max_val ** 2 / mse) + + +def cal_ssim(gt, pred): + """cal_ssim. + + :param pred: [B, C, H, W] + :param gt: [B, C, H, W] + """ + return kornia.metrics.ssim(gt, pred, 11).mean(dim=(1, 2, 3)) # [B] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b270919 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2022 Megvii Inc. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..9c81a1d --- /dev/null +++ b/README.md @@ -0,0 +1,103 @@ +# Sobolev Training for Implicit Neural Representations with Approximated Image Derivatives +The experimental code of "[Sobolev Training for Implicit Neural Representations with Approximated Image Derivatives](https://arxiv.org/abs/xxxx.xxxxx)" in ECCV 2022. + +## Abstract +Recently, Implicit Neural Representations (INRs) parameterized by neural networks have emerged as a powerful and promising tool to represent all kinds of signals, owing to their continuous, differentiable properties and showing many advantages over classical discretized representations. Nevertheless, the training of neural networks for INRs utilizes only input-output pairs, while the derivatives of the target output with respect to the input, which can be accessed in some cases, are usually ignored. In this paper, we propose a training paradigm for INRs whose target output is image pixels, to encode image derivatives in addition to image values within the neural network. +Specifically, we use finite differences to approximate image derivatives. +Further, neural networks activated by ReLUs are in practice poorly suited for representing the derivatives of complex signals under derivative supervision, so +periodic activation functions are adopted to obtain better derivative convergence properties. +Lastly, we show how this training paradigm can be leveraged to solve typical INR problems, such as image regression and inverse rendering, and demonstrate that it can improve the data efficiency and generalization capabilities of INRs. + + + +## Setup + +### Environment + +* Clone this repo + ```shell + git clone https://github.com/prstrive/Sobolev_training_INRs.git + cd Sobolev_training_INRs + ``` +* Install dependencies +
+ Python 3 dependencies (click to expand) + + * PyTorch >= 1.10 + * torchvision + * ConfigArgParse + * einops + * imageio + * kornia + * matplotlib + * numpy + * opencv_python + * Pillow + * scipy + * tqdm +
+ + To set up a conda environment: + ```shell + conda create -n st_inrs python=3.7 + conda activate st_inrs + pip install -r requirements.txt + ``` +### Data +* Create a data directory with the following command: + ```shell + mkdir data + ``` +* Download data: + * Download [Set5](https://drive.google.com/file/d/1C-C2eZIO3AQYi48EJ92MNWmWBGMdRYB6/view?usp=sharing) for the __image regression__ task. + * Download LLFF data from [NeRF authors' drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) for the __inverse rendering__ task. + * Download `gt_bach.wav` and `gt_counting.wav` from [SIREN authors' drive](https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K) for the __audio regression__ task. Put the two WAV files into a folder named `Audio`. +* Create soft links: + ```shell + ln -s [path to nerf_llff_data] ./data + ln -s [path to Set5] ./data + ln -s [path to Audio] ./data + ``` + +## Reproducing Experiments +### Image Regression +```shell +cd Experiments/image_regression +python main.py --config [config txt file] +``` +For example, to train with __value and derivative supervision__ on *Baby* with a __sine-based__ model: +```shell +python main.py --config configs/baby/val_der/sine.txt +``` +### Inverse Rendering +```shell +cd Experiments/inverse_rendering +python train.py --config [config txt file] +``` +For example, to train with __value and derivative supervision__ on *Fern* with a __ReLU-based__ MLP: +```shell +python train.py --config configs/fern/val_der/relu.txt +``` +After training for 400K iterations, you can find novel-view results in `logs/fern/val_der/relu/testset_400000`. You can evaluate them with the following command, which generates `logs/fern/val_der/relu/testset_400000/score.txt`: +```shell +python eval.py --config configs/fern/val_der/relu.txt +``` +### Audio Regression +```shell +cd Experiments/audio_regression +python main.py --config [config txt file] +``` +For example, to train with __value supervision__ on *Bach* with a __sine-based__ model: +```shell +python main.py --config configs/bach/val_sine.txt +``` + +## Citation +If you find our work useful in your research, please cite: +``` +@inproceedings{yuan2022sobolev, + title={Sobolev Training for Implicit Neural Representations with Approximated Image Derivatives}, + author={Wentao Yuan and Qingtian Zhu and Xiangyue Liu and Yikang Ding and Haotian Zhang and Chi Zhang}, + year={2022}, + booktitle={ECCV}, +} +``` diff --git a/imgs/pipeline.png b/imgs/pipeline.png new file mode 100644 index 0000000..5998b7e Binary files /dev/null and b/imgs/pipeline.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..77aae95 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +ConfigArgParse==1.5.3 +einops==0.4.1 +imageio==2.13.5 +kornia==0.6.5 +matplotlib==3.0.3 +numpy==1.18.1 +opencv_python==4.1.2.30 +Pillow==9.2.0 +scipy==1.4.1 +torch==1.12.0 +torchvision==0.13.0 +tqdm==4.28.1
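A closing note on how the pieces fit together (an illustrative sketch, not repo code): `train.py` marks the pixel coordinates with `requires_grad_(True)`, so `diff_operators.gradient` can differentiate the rendered color through autograd, and `der_mse` matches that against the kornia finite-difference targets; the final objective is `val_loss + lambda_der * der_loss`:

```python
import torch


def gradient(y, x, grad_outputs=None):
    # same helper as in diff_operators.py
    if grad_outputs is None:
        grad_outputs = torch.ones_like(y)
    return torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]


xy = torch.rand(8, 2).requires_grad_(True)           # pixel coordinates (x, y)
rgb = (xy ** 2).sum(-1, keepdim=True).repeat(1, 3)   # stand-in for a rendered color

pred_grad_r = gradient(rgb[..., 0], xy)              # (8, 2): (r_x, r_y) per pixel
target_grad_r = 2 * xy                               # analytic derivative of the stand-in
der_loss = (pred_grad_r - target_grad_r).pow(2).mean()
print(der_loss)                                      # ~0: autograd matches the target
```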