diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c2e6eb3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +.DS_Store +debug* +Models/ +*/model/ +*/checkpoints/ +*/results/ +*/*.pyc +*/**/*.pyc +*/**/**/*.pyc +*/**/**/**/*.pyc +*/**/**/**/**/*.pyc +*/**/__pycache__ \ No newline at end of file diff --git a/Audio/audio/03Fsi1831.wav b/Audio/audio/03Fsi1831.wav new file mode 100644 index 0000000..0e5ac19 Binary files /dev/null and b/Audio/audio/03Fsi1831.wav differ diff --git a/Audio/audio/5_00006.wav b/Audio/audio/5_00006.wav new file mode 100644 index 0000000..cf4a274 Binary files /dev/null and b/Audio/audio/5_00006.wav differ diff --git a/Audio/code/atcnet.py b/Audio/code/atcnet.py new file mode 100644 index 0000000..0541524 --- /dev/null +++ b/Audio/code/atcnet.py @@ -0,0 +1,329 @@ +import os +import glob +import time +import torch +import torch.utils +import torch.nn as nn +import torchvision +from torch.autograd import Variable +from torch.utils.data import DataLoader +from torch.nn.modules.module import _addindent +import numpy as np +from collections import OrderedDict +import argparse + +from dataset import LRW_1D_lstm_3dmm, LRW_1D_lstm_3dmm_pose +from dataset import News_1D_lstm_3dmm_pose + +from models import ATC_net + +from torch.nn import init +import pdb + +def multi2single(model_path, id): + checkpoint = torch.load(model_path) + state_dict = checkpoint + if id ==1: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] + new_state_dict[name] = v + return new_state_dict + else: + return state_dict + +def initialize_weights( net, init_type='normal', gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) + +class Trainer(): + def __init__(self, config): + if config.lstm == True: + if config.pose == 0: + self.generator = ATC_net(config.para_dim) + else: + self.generator = ATC_net(config.para_dim+6) + print('---------- Networks initialized -------------') + num_params = 0 + for param in self.generator.parameters(): + num_params += param.numel() + print('[Network] Total number of parameters : %.3f M' % ( num_params / 1e6)) + print('-----------------------------------------------') + #pdb.set_trace() + self.l1_loss_fn = nn.L1Loss() + self.mse_loss_fn = nn.MSELoss() + self.config = config + + if config.cuda: + device_ids = [int(i) for i in config.device_ids.split(',')] + if len(device_ids) > 1: + self.generator = nn.DataParallel(self.generator, device_ids=device_ids).cuda() + else: + self.generator = self.generator.cuda() + self.mse_loss_fn = self.mse_loss_fn.cuda() + self.l1_loss_fn = self.l1_loss_fn.cuda() + initialize_weights(self.generator) + if config.continue_train: + state_dict = multi2single(config.model_name, 0) + 
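+            # multi2single(path, 1) strips the 7-character 'module.' prefix that nn.DataParallel
+            # prepends to parameter names; with id=0, as used here, the checkpoint is returned unchanged.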
self.generator.load_state_dict(state_dict) + print('load pretrained [{}]'.format(config.model_name)) + self.start_epoch = 0 + if config.load_model: + self.start_epoch = config.start_epoch + self.load(config.pretrained_dir, config.pretrained_epoch) + self.opt_g = torch.optim.Adam( self.generator.parameters(), + lr=config.lr, betas=(config.beta1, config.beta2)) + if config.lstm: + if config.pose == 0: + self.dataset = LRW_1D_lstm_3dmm(config.dataset_dir, train=config.is_train, indexes=config.indexes) + else: + if config.dataset == 'lrw': + self.dataset = LRW_1D_lstm_3dmm_pose(config.dataset_dir, train=config.is_train, indexes=config.indexes, relativeframe=config.relativeframe) + self.dataset2 = LRW_1D_lstm_3dmm_pose(config.dataset_dir, train='test', indexes=config.indexes, relativeframe=config.relativeframe) + elif config.dataset == 'news': + self.dataset = News_1D_lstm_3dmm_pose(config.dataset_dir, train=config.is_train, indexes=config.indexes, relativeframe=config.relativeframe, + newsname=config.newsname, start=config.start, trainN=config.trainN, testN=config.testN) + + self.data_loader = DataLoader(self.dataset, + batch_size=config.batch_size, + num_workers=config.num_thread, + shuffle=True, drop_last=True) + if config.dataset == 'lrw': + self.data_loader_val = DataLoader(self.dataset2, + batch_size=config.batch_size, + num_workers= config.num_thread, + shuffle=False, drop_last=True) + + + def fit(self): + config = self.config + L = config.para_dim + + num_steps_per_epoch = len(self.data_loader) + print("num_steps_per_epoch", num_steps_per_epoch) + cc = 0 + t00 = time.time() + t0 = time.time() + + + for epoch in range(self.start_epoch, config.max_epochs): + for step, (coeff, audio, coeff2) in enumerate(self.data_loader): + t1 = time.time() + + if config.cuda: + coeff = Variable(coeff.float()).cuda() + audio = Variable(audio.float()).cuda() + else: + coeff = Variable(coeff.float()) + audio = Variable(audio.float()) + + #print(audio.shape, coeff.shape) # torch.Size([16, 16, 28, 12]) torch.Size([16, 16, 70]) + fake_coeff= self.generator(audio) + + loss = self.mse_loss_fn(fake_coeff , coeff) + + if config.less_constrain: + loss = self.mse_loss_fn(fake_coeff[:,:,:L], coeff[:,:,:L]) + config.lambda_pose * self.mse_loss_fn(fake_coeff[:,:,L:], coeff[:,:,L:]) + + # put smooth on pose + # tidu ermo pingfang + if config.smooth_loss: + loss1 = loss.clone() + frame_dif = fake_coeff[:,1:,L:] - fake_coeff[:,:-1,L:] # [16, 15, 6] + #norm2 = torch.norm(frame_dif, dim = 1) # default 2-norm, [16, 6] + #norm2_ss1 = torch.sum(torch.mul(norm2, norm2), dim=1) # [16, 1] + norm2_ss = torch.sum(torch.mul(frame_dif,frame_dif), dim=[1,2]) + loss2 = torch.mean(norm2_ss) + #pdb.set_trace() + loss = loss1 + loss2 * config.lambda_smooth + + # put smooth on expression + if config.smooth_loss2: + loss3 = loss.clone() + frame_dif2 = fake_coeff[:,1:,:L] - fake_coeff[:,:-1,:L] + norm2_ss2 = torch.sum(torch.mul(frame_dif2,frame_dif2), dim=[1,2]) + loss4 = torch.mean(norm2_ss2) + loss = loss3 + loss4 * config.lambda_smooth2 + + + loss.backward() + self.opt_g.step() + self._reset_gradients() + + + if (step+1) % 10 == 0 or (step+1) == num_steps_per_epoch: + steps_remain = num_steps_per_epoch-step+1 + \ + (config.max_epochs-epoch+1)*num_steps_per_epoch + + if not config.smooth_loss and not config.smooth_loss2: + print("[{}/{}][{}/{}] loss1: {:.8f},data time: {:.4f}, model time: {} second" + .format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch, loss, t1-t0, time.time() - t1)) + elif config.smooth_loss and not 
config.smooth_loss2: + print("[{}/{}][{}/{}] loss1: {:.8f},lossgt: {:.8f},losstv: {:.8f},data time: {:.4f}, model time: {} second" + .format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch, loss, loss1, loss2*config.lambda_smooth, t1-t0, time.time() - t1)) + elif not config.smooth_loss and config.smooth_loss2: + print("[{}/{}][{}/{}] loss1: {:.8f},lossgt: {:.8f},losstv2: {:.8f},data time: {:.4f}, model time: {} second" + .format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch, loss, loss3, loss4*config.lambda_smooth2, t1-t0, time.time() - t1)) + else: + print("[{}/{}][{}/{}] loss1: {:.8f},lossgt: {:.8f},losstv: {:.8f},losstv2: {:.8f},data time: {:.4f}, model time: {} second" + .format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch, loss, loss1, loss2*config.lambda_smooth, loss4*config.lambda_smooth2, t1-t0, time.time() - t1)) + + if (num_steps_per_epoch > 100 and (step) % (int(num_steps_per_epoch / 10 )) == 0 and step != 0) or (num_steps_per_epoch <= 100 and (step+1) == num_steps_per_epoch): + if config.lstm: + for indx in range(3): + for jj in range(16): + name = "{}/real_{}_{}_{}.npy".format(config.sample_dir,cc, indx,jj) + coeff2 = coeff.data.cpu().numpy() + np.save(name, coeff2[indx,jj]) + if config.relativeframe: + name = "{}/real2_{}_{}_{}.npy".format(config.sample_dir,cc, indx,jj) + np.save(name, coeff2[indx,jj]) + name = "{}/fake_{}_{}_{}.npy".format(config.sample_dir,cc, indx,jj) + fake_coeff2 = fake_coeff.data.cpu().numpy() + np.save(name, fake_coeff2[indx,jj]) + # check val set loss + vloss = 0 + if config.dataset == 'lrw': + for step, (coeff, audio, coeff2) in enumerate(self.data_loader_val): + with torch.no_grad(): + if step == 100: + break + if config.cuda: + coeff = Variable(coeff.float()).cuda() + audio = Variable(audio.float()).cuda() + fake_coeff= self.generator(audio) + valloss = self.mse_loss_fn(fake_coeff,coeff) + if config.less_constrain: + valloss = self.mse_loss_fn(fake_coeff[:,:,:L], coeff[:,:,:L]) + config.lambda_pose * self.mse_loss_fn(fake_coeff[:,:,L:], coeff[:,:,L:]) + vloss += valloss.cpu().numpy() + print("[{}/{}][{}/{}] val loss:{}".format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch, vloss/100.)) + # save model + print("[{}/{}][{}/{}] save model".format(epoch+1, config.max_epochs, + step+1, num_steps_per_epoch)) + torch.save(self.generator.state_dict(), + "{}/atcnet_lstm_{}.pth" + .format(config.model_dir,cc)) + cc += 1 + + t0 = time.time() + print("total time: {} second".format(time.time()-t00)) + + def _reset_gradients(self): + self.generator.zero_grad() + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--lr", + type=float, + default=0.0002) + parser.add_argument("--beta1", + type=float, + default=0.5) + parser.add_argument("--beta2", + type=float, + default=0.999) + parser.add_argument("--lambda1", + type=int, + default=100) + parser.add_argument("--batch_size", + type=int, + default=16) + parser.add_argument("--max_epochs", + type=int, + default=10) + parser.add_argument("--cuda", + default=True) + parser.add_argument("--dataset_dir", + type=str, + default="../dataset/") + parser.add_argument("--model_dir", + type=str, + default="../model/atcnet/") + parser.add_argument("--sample_dir", + type=str, + default="../sample/atcnet/") + parser.add_argument('--device_ids', type=str, default='0') + parser.add_argument('--dataset', type=str, default='lrw') + parser.add_argument('--lstm', type=bool, default= True) + parser.add_argument('--num_thread', type=int, default=2) + 
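+    # note: --weight_decay is parsed below but is not currently passed to the Adam optimizer
+    # created in Trainer.__init__ (which only receives lr and betas).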
parser.add_argument('--weight_decay', type=float, default=4e-4) + parser.add_argument('--load_model', action='store_true') + parser.add_argument('--pretrained_dir', type=str) + parser.add_argument('--pretrained_epoch', type=int) + parser.add_argument('--start_epoch', type=int, default=0, help='start from 0') + parser.add_argument('--rnn', type=bool, default=True) + parser.add_argument('--para_dim', type=int, default=64) + parser.add_argument('--index', type=str, default='80,144', help='index ranges') + parser.add_argument('--pose', type=int, default=0, help='whether predict pose') + parser.add_argument('--relativeframe', type=int, default=0, help='whether use relative frame value for pose') + # for personalized data + parser.add_argument('--newsname', type=str, default='Learn_English') + parser.add_argument('--start', type=int, default=357) + parser.add_argument('--trainN', type=int, default=300) + parser.add_argument('--testN', type=int, default=100) + # for continnue train + parser.add_argument('--continue_train', type=bool, default=False) + parser.add_argument("--model_name", type=str, default='../model/atcnet_pose0/atcnet_lstm_24.pth') + parser.add_argument('--preserve_mouth', type=bool, default=False) + # for remove jittering + parser.add_argument('--smooth_loss', type=bool, default=False) # smooth in time, similar to total variation + parser.add_argument('--smooth_loss2', type=bool, default=False) # smooth in time, for expression + parser.add_argument('--lambda_smooth', type=float, default=0.01) + parser.add_argument('--lambda_smooth2', type=float, default=0.0001) + # for less constrain for pose + parser.add_argument('--less_constrain', type=bool, default=False) + parser.add_argument('--lambda_pose', type=float, default=0.2) + + return parser.parse_args() + + +def main(config): + t = trainer.Trainer(config) + t.fit() + +if __name__ == "__main__": + + config = parse_args() + str_ids = config.index.split(',') + config.indexes = [] + for i in range(int(len(str_ids)/2)): + start = int(str_ids[2*i]) + end = int(str_ids[2*i+1]) + if end > start: + config.indexes += range(start, end) + #print('indexes', config.indexes) + print('device', config.device_ids) + os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids + config.is_train = 'train' + import atcnet as trainer + if not os.path.exists(config.model_dir): + os.mkdir(config.model_dir) + if not os.path.exists(config.sample_dir): + os.mkdir(config.sample_dir) + config.cuda1 = torch.device('cuda:{}'.format(config.device_ids)) + main(config) + diff --git a/Audio/code/atcnet_test1.py b/Audio/code/atcnet_test1.py new file mode 100644 index 0000000..9c9fdf6 --- /dev/null +++ b/Audio/code/atcnet_test1.py @@ -0,0 +1,117 @@ + +#encoding:utf-8 +#测试一个随机的wav +import argparse +import scipy.misc +import os +import glob +import time +import torch +import torch.utils +import torch.nn as nn +import torchvision +from torch.autograd import Variable +import numpy as np +from collections import OrderedDict +import librosa +import python_speech_features + +from models import ATC_net + +def multi2single(model_path, id): + checkpoint = torch.load(model_path) + state_dict = checkpoint + if id ==1: + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] + new_state_dict[name] = v + return new_state_dict + else: + return state_dict + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--cuda", + default=True) + parser.add_argument('-i','--in_file', type=str, default='../audio/test.wav') + 
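+    # example invocation (assuming a trained checkpoint exists at the --model_name default below):
+    #   python atcnet_test1.py -i ../audio/5_00006.wav --sample_dir ../results/atcnet/test/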
parser.add_argument("--model_name", + type=str, + default="../model/atcnet/atcnet_lstm_24.pth") + parser.add_argument("--sample_dir", + type=str, + default="../results/atcnet/test/") + parser.add_argument('--device_ids', type=str, default='0') + parser.add_argument('--dataset', type=str, default='lrw') + parser.add_argument('--lstm', type=bool, default=True) + # parser.add_argument('--flownet_pth', type=str, help='path of flownets model') + parser.add_argument('--para_dim', type=int, default=64) + parser.add_argument('--index', type=str, default='80,144', help='index ranges') + parser.add_argument('--pose', type=int, default=0, help='whether predict pose') + parser.add_argument('--relativeframe', type=int, default=0, help='whether use relative frame value for pose') + + + return parser.parse_args() +config = parse_args() +str_ids = config.index.split(',') +config.indexes = [] +for i in range(int(len(str_ids)/2)): + start = int(str_ids[2*i]) + end = int(str_ids[2*i+1]) + if end > start: + config.indexes += range(start, end) +#print('indexes', config.indexes) +print('device', config.device_ids) + + +def test(): + os.environ["CUDA_VISIBLE_DEVICES"] = config.device_ids + config.is_train = 'test' + if config.lstm == True: + if config.pose == 0: + generator = ATC_net(config.para_dim) + else: + generator = ATC_net(config.para_dim+6) + + test_file = config.in_file + speech, sr = librosa.load(test_file, sr=16000) + mfcc = python_speech_features.mfcc(speech ,16000,winstep=0.01) + speech = np.insert(speech, 0, np.zeros(1920)) + speech = np.append(speech, np.zeros(1920)) + mfcc = python_speech_features.mfcc(speech,16000,winstep=0.01) + #print(mfcc.shape) + + state_dict = multi2single(config.model_name, 0) + generator.load_state_dict(state_dict) + print('load pretrained [{}]'.format(config.model_name)) + + if config.cuda: + generator = generator.cuda() + generator.eval() + + ind = 3 + with torch.no_grad(): + input_mfcc = [] + while ind <= int(mfcc.shape[0]/4) - 4: + # take 280 ms segment + t_mfcc =mfcc[( ind - 3)*4: (ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc).cuda() + input_mfcc.append(t_mfcc) + ind += 1 + input_mfcc = torch.stack(input_mfcc,dim = 0) + input_mfcc = input_mfcc.unsqueeze(0) + print(input_mfcc.shape) + if config.cuda: + input_mfcc = Variable(input_mfcc.float()).cuda() + if config.lstm: + fake_coeff= generator(input_mfcc) + fake_coeff = fake_coeff.data.cpu().numpy() + if not os.path.exists(config.sample_dir): + os.makedirs(config.sample_dir) + for jj in range(len(fake_coeff[0])): + name = "%s/%05d.npy"%(config.sample_dir,jj) + np.save(name, fake_coeff[0,jj]) + + +test() \ No newline at end of file diff --git a/Audio/code/choose_bg_gexinghua2_reassign.py b/Audio/code/choose_bg_gexinghua2_reassign.py new file mode 100644 index 0000000..14bdc23 --- /dev/null +++ b/Audio/code/choose_bg_gexinghua2_reassign.py @@ -0,0 +1,290 @@ +#encoding:utf-8 +import os +import glob +import shutil +import numpy as np +from scipy.io import loadmat,savemat +import cv2 +import pdb +import sys +from scipy.signal import argrelextrema +import matplotlib.pyplot as plt + +def IOU(a,b):#quyu > 0 + I = np.sum((a*b)>0)#a>0 and b>0 + U = np.sum((a+b)>0)#a>0 or b>0 + return I/U + +def smooth(x,window_len=11,window='hanning'): + if x.ndim != 1: + raise(ValueError, "smooth only accepts 1 dimension arrays.") + if x.size < window_len: + raise(ValueError, "Input vector needs to be bigger than window size.") + if window_len<3: + return x + if not window in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']: + 
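+        # note: under Python 3, raise(ValueError, msg) raises a TypeError (a tuple is not an
+        # exception); raise ValueError(msg) is the conventional form, here and in the checks above.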
raise(ValueError, "Window is on of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'") + s=np.r_[x[window_len-1:0:-1],x,x[-2:-window_len-1:-1]] + if window == 'flat': + w=np.ones(window_len,'d') + else: + w=eval('np.'+window+'(window_len)') + + y=np.convolve(w/w.sum(),s,mode='valid') + return y[int(window_len/2):-int(window_len/2)] + +def nearest(sucai, query): + diff = np.abs(np.tile(query, [sucai.shape[0],1]) - sucai) + cost = np.sum(diff[:,:3],axis=1) #+ 0.1 * np.sum(diff[:,3:],axis=1) + I = np.argmin(cost) + #print(query, diff[I,:], cost[I]) + #print(query,sucai[I]) + return I + +def nearestIoU(sucai2, query): + cost = np.zeros(sucai2.shape[0]) + for i in range(sucai2.shape[0]): + cost[i] = IOU(query, sucai2[i]) + I = np.argmax(cost) + #print(cost,cost[I]) + #pdb.set_trace() + return I + +def nearest2(sucai, query, sucai2, lastI, lamda=1./255., choice = 1): + diff1 = np.abs(np.tile(query, [sucai.shape[0],1]) - sucai) + cost1 = np.sum(diff1[:,:3],axis=1) + ## + #cost1[lastI] = 100 + lbg = sucai2[lastI] + if choice == 1: + ## BG L1 similarity + diff2 = np.abs(sucai2 - np.tile(lbg, [sucai.shape[0],1,1,1])) + cost2 = np.mean(diff2, axis=(1,2,3)) + #pdb.set_trace() + I = np.argmin(cost1+cost2*lamda) + elif choice == 2: + ## BG IOU + cost2 = np.zeros(cost1.shape) + for i in range(len(sucai2)): + cost2[i] = IOU(sucai2[i], lbg) #iou larger better + I = np.argmin(cost1+(1-cost2)*lamda) + #pdb.set_trace() + elif choice == 0: + # coeff similarity + lbg = sucai[lastI] + diff2 = np.abs(np.tile(lbg, [sucai.shape[0],1]) - sucai) + cost2 = np.sum(diff2[:,:3],axis=1) + #pdb.set_trace() + I = np.argmin(cost1+cost2*lamda) + #print(query,sucai[I],cost1[I],np.sum(diff1[I,:3])) + if I != lastI: + print(I) + return I + +def nearest2IoU(query, sucai2, lastI, lamda=1./255.): + cost1 = np.zeros(sucai2.shape[0]) + for i in range(sucai2.shape[0]): + cost1[i] = IOU(query, sucai2[i]) + ## + #cost1[lastI] = -100 + lbg = sucai2[lastI] + ## BG IOU + cost2 = np.zeros(cost1.shape) + for i in range(len(sucai2)): + cost2[i] = IOU(sucai2[i], lbg) #iou larger better + I = np.argmax(cost1+cost2*lamda) + #pdb.set_trace() + return I + +def choose_bg_gexinghua2_reassign2(video, audio, start, audiomodel='', num=300, debug=0, tran=0, speed=2, aaxis=2): + print('choose_bg_gexinghua2',video,audio,start,audiomodel) + rootdir = '../../Deep3DFaceReconstruction/' + matdir = os.path.join(rootdir,'output/coeff',video) + pngdir = os.path.join(rootdir,'output/render',video) + L = 64 + if audiomodel == '': + folder_to_process = '../results/atcnet_pose0/' + audio + else: + folder_to_process = '../results/' + audiomodel + files = sorted(glob.glob(os.path.join(folder_to_process,'*.npy'))) + tardir = os.path.join('../results/chosenbg','%s_%s'%(audio,video)) + if audiomodel != '': + tardir = os.path.join('../results/chosenbg','%s_%s_%s'%(audio,video,audiomodel.replace('/','_'))) + tardir2 = os.path.join(tardir, 'reassign') + print(tardir2) + if not os.path.exists(tardir2): + os.makedirs(tardir2) + + sucai = np.zeros((num,6)) + lm_5p = np.zeros((num,2)) + for i in range(start,start+num): + coeff = loadmat(os.path.join(matdir,'frame%d.mat')%i) + sucai[i-start,:3] = coeff['coeff'][:,224:227] + sucai[i-start,3:] = coeff['coeff'][:,254:257] + if tran: + lm_5p[i-start,:] = np.mean(coeff['lm_5p'],axis=0) + + cnt = 0 + Is = [] + period = 0 + periods = [] + # find max and mins + N = len(files) + datas = np.zeros((N,3)) #存姿势的3个角度 + datasall = np.zeros((N,70)) #存表情和姿势的所有系数 + for i in range(N): + temp = np.load(files[i]) + datas[i] = temp[L:L+3] + 
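+        # datas[i]: the 3 head-rotation angles stored right after the L=64 expression coefficients;
+        # datasall[i]: the full 70-dim prediction (64 expression + 3 rotation + 3 translation).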
datasall[i] = temp + #pdb.set_trace() + + # 得到关键帧Ids + #y = smooth(datas[:,2],window_len=7) + y = [0,0,0] + Ids = [0,0,0] + y0 = [0,0,0] + n = 0 + if aaxis in [0,1,2]: + axises = [aaxis] + thre = N/10. + elif aaxis in [5]: + axises = [2] + thre = N/10. + else: + axises = [0,1,2] + thre = N/5. + print(aaxis, axises) + for k in axises: + y[k] = smooth(datas[:,k],window_len=7) + y0[k] = y[k] + #if debug: + #plt.plot(y0[k]) + #plt.plot(datas[:,k]) + #plt.legend(['axis0','axis1','axis2']) + # local maxima + maxIds = argrelextrema(y[k],np.greater) + # local minima + minIds = argrelextrema(y[k],np.less) + Ids[k] = np.concatenate((maxIds[0],minIds[0])) + Ids[k] = np.sort(Ids[k]) + n += Ids[k].shape[0] + y[k] = y0[k][Ids[k]] + + while n > thre: + n = 0 + for k in axises: + maxIds = argrelextrema(y[k],np.greater,order=2) + minIds = argrelextrema(y[k],np.less,order=2) + Ids[k] = np.concatenate((Ids[k][maxIds],Ids[k][minIds])) + Ids[k] = np.sort(Ids[k]) + n += Ids[k].shape[0] + y[k] = y0[k][Ids[k]] + + # 关键帧: 0, Ids, N-1 + # 画图看选的关键帧在哪里 + if debug: + pdb.set_trace() + #plt.plot(0,datas[0,2],'+') + #for k in axises: + # for i in range(len(Ids[k])): + # plt.plot(Ids[k][i],y0[k][Ids[k][i]],'+') + #plt.plot(N-1,datas[N-1,2],'+') + + #plt.savefig('theta.jpg') + if aaxis == 4: + Ids = np.concatenate((Ids[0],Ids[1],Ids[2])) + elif aaxis == 5: + Ids = Ids[2] + else: + Ids = Ids[aaxis] + Ids = np.sort(np.unique(Ids)) + print(Ids) + for i in range(1,Ids.shape[0]): + if Ids[i] - Ids[i-1] < 3: + #print(Ids[i-1],Ids[i], datas[Ids[i-1],:], datas[Ids[i],:]) + if np.max(np.abs(datas[Ids[i],:])) > np.max(np.abs(datas[Ids[i-1],:])): + Ids[i-1] = -1 + else: + Ids[i] = -1 + Ids = np.delete(Ids,np.argwhere(Ids==-1)) + print(Ids.shape[0],N) + print(Ids) + + + # 查找和关键帧姿势最接近背景 + if debug: + tempdir = os.path.join(tardir, 'temp') + if not os.path.exists(tempdir): + os.makedirs(tempdir) + Ids=np.insert(Ids,0,0) + Ids=np.append(Ids,N-1) + Is=np.zeros(Ids.shape) + I = nearest(sucai[:,:3], datas[0]) + Is[0] = I + for i in range(1,Ids.shape[0]): + period = Ids[i] - Ids[i-1] + #sucait = sucai[max(0,I-2*period):min(num,I+2*period),:3] + sucait = sucai[max(0,int(I-speed*period)):min(num,I+int(speed*period)),:3] + In = nearest(sucait, datas[Ids[i]]) + #I = max(0,I-2*period) + In + I = max(0,int(I-speed*period)) + In + print(Ids[i],I, Ids[i-1]) + Is[i] = I + if debug: + if tran == 0: + for j in range(Ids[i-1], Ids[i]+1): + shutil.copy(os.path.join(pngdir,'frame%d.png'%(I+start)), + os.path.join(tempdir,'%05d.png'%j)) + else: + for j in range(Ids[i-1], Ids[i]+1): + shutil.copy(os.path.join(pngdir,'frame%d_input2.png'%(I+start)), + os.path.join(tempdir,'%05d.png'%j)) + + if debug: + os.system('ffmpeg -loglevel panic -framerate 25 -i ' + tempdir + '/%05d.png -c:v libx264 -y -vf format=yuv420p ' + tempdir + '.mp4') + + print(Ids,Is) + # reassign,重新设置姿势的系数 + assigns = [0] * N + startI = 0 + for i in range(Ids.shape[0]-1): + l = Ids[i+1] - Ids[i] + assigns[Ids[i]] = int(Is[i]) + for j in range(1,l): + assigns[Ids[i]+j] = int(round(float(j)/l*(Is[i+1]-Is[i]) + Is[i])) + startI += l + assigns[Ids[-1]] = int(Is[-1]) + print(assigns) + if not os.path.exists(folder_to_process+'/reassign'): + os.mkdir(folder_to_process+'/reassign') + for i in range(N): + if tran == 0: + data = datasall[i] + if aaxis == 5 and i in Ids: + #pdb.set_trace() + data[L+3:L+6] = sucai[assigns[i],3:6] + continue + data[L:L+6] = sucai[assigns[i]] + else: + # 把trans存在.npy里 + data = np.zeros((L+9)) + data[:L] = datasall[i,:L] + data[L:L+6] = sucai[assigns[i]] + data[L+6:L+8] 
= lm_5p[assigns[i]] + data[L+8] = assigns[i]+start + print(i,'assigni',assigns[i]+start,'lm_5p',lm_5p[assigns[i]]) + savename = os.path.join(folder_to_process,'reassign','%05d.npy'%i) + np.save(savename, data) + if tran == 0 or tran == 2: + shutil.copy(os.path.join(pngdir,'frame%d.png'%(assigns[i]+start)), + os.path.join(tardir2,'%05d.png'%i)) + elif tran == 1: + print(os.path.join(pngdir,'frame%d_input2.png'%(assigns[i]+start))) + shutil.copy(os.path.join(pngdir,'frame%d_input2.png'%(assigns[i]+start)), + os.path.join(tardir2,'%05d.png'%i)) + + if debug: + os.system('ffmpeg -loglevel panic -framerate 25 -i ' + tardir2 + '/%05d.png -c:v libx264 -y -vf format=yuv420p ' + tardir2 + '.mp4') + + return tardir2 + diff --git a/Audio/code/convolutional_rnn/__init__.py b/Audio/code/convolutional_rnn/__init__.py new file mode 100644 index 0000000..d095f0f --- /dev/null +++ b/Audio/code/convolutional_rnn/__init__.py @@ -0,0 +1,24 @@ +from .module import Conv2dRNN +from .module import Conv2dLSTM +from .module import Conv2dPeepholeLSTM +from .module import Conv2dGRU + +from .module import Conv3dRNN +from .module import Conv3dLSTM +from .module import Conv3dPeepholeLSTM +from .module import Conv3dGRU + +from .module import Conv1dRNNCell +from .module import Conv1dLSTMCell +from .module import Conv1dPeepholeLSTMCell +from .module import Conv1dGRUCell + +from .module import Conv2dRNNCell +from .module import Conv2dLSTMCell +from .module import Conv2dPeepholeLSTMCell +from .module import Conv2dGRUCell + +from .module import Conv3dRNNCell +from .module import Conv3dLSTMCell +from .module import Conv3dPeepholeLSTMCell +from .module import Conv3dGRUCell diff --git a/Audio/code/convolutional_rnn/functional.py b/Audio/code/convolutional_rnn/functional.py new file mode 100644 index 0000000..9b18999 --- /dev/null +++ b/Audio/code/convolutional_rnn/functional.py @@ -0,0 +1,321 @@ +from functools import partial + +import torch +import torch.nn.functional as F +#from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend + +from .utils import _single, _pair, _triple + + +def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): + """ Copied from torch.nn._functions.rnn and modified """ + if linear_func is None: + linear_func = F.linear + hy = F.relu(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) + return hy + + +def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): + """ Copied from torch.nn._functions.rnn and modified """ + if linear_func is None: + linear_func = F.linear + hy = F.tanh(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) + return hy + + +def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): + """ Copied from torch.nn._functions.rnn and modified """ + if linear_func is None: + linear_func = F.linear + if input.is_cuda and linear_func is F.linear: + igates = linear_func(input, w_ih) + hgates = linear_func(hidden[0], w_hh) + #state = fusedBackend.LSTMFused.apply + #return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) + # Slice off the workspace arg (used only for backward) + return _cuda_fused_lstm_cell(igates, hgates, hidden[1], b_ih, b_hh)[:2] + + hx, cx = hidden + gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = F.sigmoid(ingate) + forgetgate = F.sigmoid(forgetgate) + cellgate = F.tanh(cellgate) + outgate = F.sigmoid(outgate) + + 
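+    # standard LSTM update: c_t = f*c_{t-1} + i*g and h_t = o*tanh(c_t); linear_func defaults to
+    # F.linear but may be a same-padding convolution (see ConvNdWithSamePadding below).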
cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * F.tanh(cy) + + return hy, cy + + +def PeepholeLSTMCell(input, hidden, w_ih, w_hh, w_pi, w_pf, w_po, + b_ih=None, b_hh=None, linear_func=None): + if linear_func is None: + linear_func = F.linear + hx, cx = hidden + gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate += linear_func(cx, w_pi) + forgetgate += linear_func(cx, w_pf) + ingate = F.sigmoid(ingate) + forgetgate = F.sigmoid(forgetgate) + cellgate = F.tanh(cellgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + outgate += linear_func(cy, w_po) + outgate = F.sigmoid(outgate) + + hy = outgate * F.tanh(cy) + + return hy, cy + + +def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): + """ Copied from torch.nn._functions.rnn and modified """ + if linear_func is None: + linear_func = F.linear + if input.is_cuda and linear_func is F.linear: + gi = linear_func(input, w_ih) + gh = linear_func(hidden, w_hh) + #state = fusedBackend.GRUFused.apply + #return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh) + return _cuda_fused_gru_cell(gi, gh, hidden, b_ih, b_hh)[0] + gi = linear_func(input, w_ih, b_ih) + gh = linear_func(hidden, w_hh, b_hh) + i_r, i_i, i_n = gi.chunk(3, 1) + h_r, h_i, h_n = gh.chunk(3, 1) + + resetgate = F.sigmoid(i_r + h_r) + inputgate = F.sigmoid(i_i + h_i) + newgate = F.tanh(i_n + resetgate * h_n) + hy = newgate + inputgate * (hidden - newgate) + + return hy + + +def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): + """ Copied from torch.nn._functions.rnn and modified """ + + num_directions = len(inners) + total_layers = num_layers * num_directions + + def forward(input, hidden, weight, batch_sizes): + assert(len(weight) == total_layers) + next_hidden = [] + ch_dim = input.dim() - weight[0][0].dim() + 1 + + if lstm: + hidden = list(zip(*hidden)) + + for i in range(num_layers): + all_output = [] + for j, inner in enumerate(inners): + l = i * num_directions + j + + hy, output = inner(input, hidden[l], weight[l], batch_sizes) + next_hidden.append(hy) + all_output.append(output) + + input = torch.cat(all_output, ch_dim) + + if dropout != 0 and i < num_layers - 1: + input = F.dropout(input, p=dropout, training=train, inplace=False) + + if lstm: + next_h, next_c = zip(*next_hidden) + next_hidden = ( + torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), + torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) + ) + else: + next_hidden = torch.cat(next_hidden, 0).view( + total_layers, *next_hidden[0].size()) + + return next_hidden, input + + return forward + + +def Recurrent(inner, reverse=False): + """ Copied from torch.nn._functions.rnn without any modification """ + def forward(input, hidden, weight, batch_sizes): + output = [] + steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) + for i in steps: + hidden = inner(input[i], hidden, *weight) + # hack to handle LSTM + output.append(hidden[0] if isinstance(hidden, tuple) else hidden) + + if reverse: + output.reverse() + output = torch.cat(output, 0).view(input.size(0), *output[0].size()) + + return hidden, output + + return forward + + +def variable_recurrent_factory(inner, reverse=False): + """ Copied from torch.nn._functions.rnn without any modification """ + if reverse: + return VariableRecurrentReverse(inner) + else: + return VariableRecurrent(inner) + + +def VariableRecurrent(inner): + """ Copied from 
torch.nn._functions.rnn without any modification """ + def forward(input, hidden, weight, batch_sizes): + output = [] + input_offset = 0 + last_batch_size = batch_sizes[0] + hiddens = [] + flat_hidden = not isinstance(hidden, tuple) + if flat_hidden: + hidden = (hidden,) + for batch_size in batch_sizes: + step_input = input[input_offset:input_offset + batch_size] + input_offset += batch_size + + dec = last_batch_size - batch_size + if dec > 0: + hiddens.append(tuple(h[-dec:] for h in hidden)) + hidden = tuple(h[:-dec] for h in hidden) + last_batch_size = batch_size + + if flat_hidden: + hidden = (inner(step_input, hidden[0], *weight),) + else: + hidden = inner(step_input, hidden, *weight) + + output.append(hidden[0]) + hiddens.append(hidden) + hiddens.reverse() + + hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) + assert hidden[0].size(0) == batch_sizes[0] + if flat_hidden: + hidden = hidden[0] + output = torch.cat(output, 0) + + return hidden, output + + return forward + + +def VariableRecurrentReverse(inner): + """ Copied from torch.nn._functions.rnn without any modification """ + def forward(input, hidden, weight, batch_sizes): + output = [] + input_offset = input.size(0) + last_batch_size = batch_sizes[-1] + initial_hidden = hidden + flat_hidden = not isinstance(hidden, tuple) + if flat_hidden: + hidden = (hidden,) + initial_hidden = (initial_hidden,) + hidden = tuple(h[:batch_sizes[-1]] for h in hidden) + for i in reversed(range(len(batch_sizes))): + batch_size = batch_sizes[i] + inc = batch_size - last_batch_size + if inc > 0: + hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) + for h, ih in zip(hidden, initial_hidden)) + last_batch_size = batch_size + step_input = input[input_offset - batch_size:input_offset] + input_offset -= batch_size + + if flat_hidden: + hidden = (inner(step_input, hidden[0], *weight),) + else: + hidden = inner(step_input, hidden, *weight) + output.append(hidden[0]) + + output.reverse() + output = torch.cat(output, 0) + if flat_hidden: + hidden = hidden[0] + return hidden, output + + return forward + + +def ConvNdWithSamePadding(convndim=2, stride=1, dilation=1, groups=1): + def forward(input, w, b=None): + if convndim == 1: + ntuple = _single + elif convndim == 2: + ntuple = _pair + elif convndim == 3: + ntuple = _triple + else: + raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) + + if input.dim() != convndim + 2: + raise RuntimeError('Input dim must be {}, bot got {}'.format(convndim + 2, input.dim())) + if w.dim() != convndim + 2: + raise RuntimeError('w must be {}, bot got {}'.format(convndim + 2, w.dim())) + + insize = input.shape[2:] + kernel_size = w.shape[2:] + _stride = ntuple(stride) + _dilation = ntuple(dilation) + + ps = [(i + 1 - h + s * (h - 1) + d * (k - 1)) // 2 + for h, k, s, d in list(zip(insize, kernel_size, _stride, _dilation))[::-1] for i in range(2)] + # Padding to make the output shape to have the same shape as the input + input = F.pad(input, ps, 'constant', 0) + return getattr(F, 'conv{}d'.format(convndim))( + input, w, b, stride=_stride, padding=ntuple(0), dilation=_dilation, groups=groups) + return forward + + +def _conv_cell_helper(mode, convndim=2, stride=1, dilation=1, groups=1): + linear_func = ConvNdWithSamePadding(convndim=convndim, stride=stride, dilation=dilation, groups=groups) + + if mode == 'RNN_RELU': + cell = partial(RNNReLUCell, linear_func=linear_func) + elif mode == 'RNN_TANH': + cell = partial(RNNTanhCell, linear_func=linear_func) + elif mode == 'LSTM': + cell = 
partial(LSTMCell, linear_func=linear_func) + elif mode == 'GRU': + cell = partial(GRUCell, linear_func=linear_func) + elif mode == 'PeepholeLSTM': + cell = partial(PeepholeLSTMCell, linear_func=linear_func) + else: + raise Exception('Unknown mode: {}'.format(mode)) + return cell + + +def AutogradConvRNN( + mode, num_layers=1, batch_first=False, + dropout=0, train=True, bidirectional=False, variable_length=False, + convndim=2, stride=1, dilation=1, groups=1): + """ Copied from torch.nn._functions.rnn and modified """ + cell = _conv_cell_helper(mode, convndim=convndim, stride=stride, dilation=dilation, groups=groups) + + rec_factory = variable_recurrent_factory if variable_length else Recurrent + + if bidirectional: + layer = (rec_factory(cell), rec_factory(cell, reverse=True)) + else: + layer = (rec_factory(cell),) + + func = StackedRNN(layer, num_layers, (mode in ('LSTM', 'PeepholeLSTM')), dropout=dropout, train=train) + + def forward(input, weight, hidden, batch_sizes): + if batch_first and batch_sizes is None: + input = input.transpose(0, 1) + + nexth, output = func(input, hidden, weight, batch_sizes) + + if batch_first and batch_sizes is None: + output = output.transpose(0, 1) + + return output, nexth + + return forward diff --git a/Audio/code/convolutional_rnn/module.py b/Audio/code/convolutional_rnn/module.py new file mode 100644 index 0000000..9313ff7 --- /dev/null +++ b/Audio/code/convolutional_rnn/module.py @@ -0,0 +1,888 @@ +import math +from typing import Union, Sequence + +import torch +from torch.nn import Parameter +from torch.nn.utils.rnn import PackedSequence + +from .functional import AutogradConvRNN, _conv_cell_helper +from .utils import _single, _pair, _triple + + +class ConvNdRNNBase(torch.nn.Module): + def __init__(self, + mode, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + convndim=2, + stride=1, + dilation=1, + groups=1): + super(ConvNdRNNBase, self).__init__() + self.mode = mode + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.convndim = convndim + + if convndim == 1: + ntuple = _single + elif convndim == 2: + ntuple = _pair + elif convndim == 3: + ntuple = _triple + else: + raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) + + self.kernel_size = ntuple(kernel_size) + self.stride = ntuple(stride) + self.dilation = ntuple(dilation) + + self.groups = groups + + num_directions = 2 if bidirectional else 1 + + if mode in ('LSTM', 'PeepholeLSTM'): + gate_size = 4 * out_channels + elif mode == 'GRU': + gate_size = 3 * out_channels + else: + gate_size = out_channels + + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + layer_input_size = in_channels if layer == 0 else out_channels * num_directions + w_ih = Parameter(torch.Tensor(gate_size, layer_input_size // groups, *self.kernel_size)) + w_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) + + b_ih = Parameter(torch.Tensor(gate_size)) + b_hh = Parameter(torch.Tensor(gate_size)) + + if mode == 'PeepholeLSTM': + w_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) + w_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) + w_po = Parameter(torch.Tensor(out_channels, out_channels // groups, 
*self.kernel_size)) + layer_params = (w_ih, w_hh, w_pi, w_pf, w_po, b_ih, b_hh) + param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}', + 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}'] + else: + layer_params = (w_ih, w_hh, b_ih, b_hh) + param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] + if bias: + param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] + + suffix = '_reverse' if direction == 1 else '' + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params): + setattr(self, name, param) + self._all_weights.append(param_names) + + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.out_channels) + for weight in self.parameters(): + weight.data.uniform_(-stdv, stdv) + + def check_forward_args(self, input, hidden, batch_sizes): + is_input_packed = batch_sizes is not None + expected_input_dim = (2 if is_input_packed else 3) + self.convndim + if input.dim() != expected_input_dim: + raise RuntimeError( + 'input must have {} dimensions, got {}'.format( + expected_input_dim, input.dim())) + ch_dim = 1 if is_input_packed else 2 + if self.in_channels != input.size(ch_dim): + raise RuntimeError( + 'input.size({}) must be equal to in_channels . Expected {}, got {}'.format( + ch_dim, self.in_channels, input.size(ch_dim))) + + if is_input_packed: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.out_channels) + input.shape[ch_dim + 1:] + + def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): + if tuple(hx.size()) != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) + + if self.mode in ('LSTM', 'PeepholeLSTM'): + check_hidden_size(hidden[0], expected_hidden_size, + 'Expected hidden[0] size {}, got {}') + check_hidden_size(hidden[1], expected_hidden_size, + 'Expected hidden[1] size {}, got {}') + else: + check_hidden_size(hidden, expected_hidden_size) + + def forward(self, input, hx=None): + is_packed = isinstance(input, PackedSequence) + if is_packed: + input, batch_sizes = input + max_batch_size = batch_sizes[0] + insize = input.shape[2:] + else: + batch_sizes = None + max_batch_size = input.size(0) if self.batch_first else input.size(1) + insize = input.shape[3:] + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = input.new_zeros(self.num_layers * num_directions, max_batch_size, self.out_channels, + *insize, requires_grad=False) + if self.mode in ('LSTM', 'PeepholeLSTM'): + hx = (hx, hx) + + self.check_forward_args(input, hx, batch_sizes) + func = AutogradConvRNN( + self.mode, + num_layers=self.num_layers, + batch_first=self.batch_first, + dropout=self.dropout, + train=self.training, + bidirectional=self.bidirectional, + variable_length=batch_sizes is not None, + convndim=self.convndim, + stride=self.stride, + dilation=self.dilation, + groups=self.groups + ) + output, hidden = func(input, self.all_weights, hx, batch_sizes) + if is_packed: + output = PackedSequence(output, batch_sizes) + return output, hidden + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.groups != 1: + s += ', groups={groups}' + if self.num_layers != 1: + s += ', num_layers={num_layers}' + 
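+        # the remaining options are appended to the repr only when they differ from their defaults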
if self.bias is not True: + s += ', bias={bias}' + if self.batch_first is not False: + s += ', batch_first={batch_first}' + if self.dropout != 0: + s += ', dropout={dropout}' + if self.bidirectional is not False: + s += ', bidirectional={bidirectional}' + return s.format(**self.__dict__) + + def __setstate__(self, d): + super(ConvNdRNNBase, self).__setstate__(d) + if 'all_weights' in d: + self._all_weights = d['all_weights'] + if isinstance(self._all_weights[0][0], str): + return + num_layers = self.num_layers + num_directions = 2 if self.bidirectional else 1 + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + suffix = '_reverse' if direction == 1 else '' + if self.mode == 'PeepholeLSTM': + weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', + 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}', + 'bias_ih_l{}{}', 'bias_hh_l{}{}'] + else: + weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', + 'bias_ih_l{}{}', 'bias_hh_l{}{}'] + weights = [x.format(layer, suffix) for x in weights] + if self.bias: + self._all_weights += [weights] + else: + self._all_weights += [weights[:len(weights) // 2]] + + @property + def all_weights(self): + return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] + + + +class Conv2dRNN(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + nonlinearity='tanh', + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + if nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) + super().__init__( + mode=mode, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv2dLSTM(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super().__init__( + mode='LSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv2dPeepholeLSTM(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super().__init__( + mode='PeepholeLSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv2dGRU(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super(Conv2dGRU, self).__init__( + mode='GRU', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + 
bidirectional=bidirectional, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv3dRNN(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + nonlinearity='tanh', + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + if nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) + super().__init__( + mode=mode, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv3dLSTM(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super().__init__( + mode='LSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv3dPeepholeLSTM(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super().__init__( + mode='PeepholeLSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups) + + +class Conv3dGRU(ConvNdRNNBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_layers=1, + bias=True, + batch_first=False, + dropout=0., + bidirectional=False, + stride=1, + dilation=1, + groups=1): + super().__init__( + mode='GRU', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_layers=num_layers, + bias=bias, + batch_first=batch_first, + dropout=dropout, + bidirectional=bidirectional, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups) + + +class ConvRNNCellBase(torch.nn.Module): + def __init__(self, + mode, + in_channels, + out_channels, + kernel_size, + bias=True, + convndim=2, + stride=1, + dilation=1, + groups=1 + ): + super().__init__() + self.mode = mode + self.in_channels = in_channels + self.out_channels = out_channels + self.bias = bias + self.convndim = convndim + + if convndim == 1: + ntuple = _single + elif convndim == 2: + ntuple = _pair + elif convndim == 3: + ntuple = _triple + else: + raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) + + self.kernel_size = ntuple(kernel_size) + self.stride = ntuple(stride) + self.dilation = ntuple(dilation) + + self.groups = groups + + if mode in ('LSTM', 'PeepholeLSTM'): + gate_size = 4 * out_channels + elif mode == 'GRU': + gate_size = 3 * out_channels + else: + gate_size = out_channels + + self.weight_ih = Parameter(torch.Tensor(gate_size, in_channels // groups, *self.kernel_size)) + self.weight_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) + + if bias: + self.bias_ih = Parameter(torch.Tensor(gate_size)) + 
self.bias_hh = Parameter(torch.Tensor(gate_size)) + else: + self.register_parameter('bias_ih', None) + self.register_parameter('bias_hh', None) + + if mode == 'PeepholeLSTM': + self.weight_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) + self.weight_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) + self.weight_po = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) + + self.reset_parameters() + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is not True: + s += ', bias={bias}' + if self.bidirectional is not False: + s += ', bidirectional={bidirectional}' + return s.format(**self.__dict__) + + def check_forward_input(self, input): + if input.size(1) != self.in_channels: + raise RuntimeError( + "input has inconsistent channels: got {}, expected {}".format( + input.size(1), self.in_channels)) + + def check_forward_hidden(self, input, hx, hidden_label=''): + if input.size(0) != hx.size(0): + raise RuntimeError( + "Input batch size {} doesn't match hidden{} batch size {}".format( + input.size(0), hidden_label, hx.size(0))) + + if hx.size(1) != self.out_channels: + raise RuntimeError( + "hidden{} has inconsistent hidden_size: got {}, expected {}".format( + hidden_label, hx.size(1), self.out_channels)) + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.out_channels) + for weight in self.parameters(): + weight.data.uniform_(-stdv, stdv) + + def forward(self, input, hx=None): + self.check_forward_input(input) + + if hx is None: + batch_size = input.size(0) + insize = input.shape[2:] + hx = input.new_zeros(batch_size, self.out_channels, *insize, requires_grad=False) + if self.mode in ('LSTM', 'PeepholeLSTM'): + hx = (hx, hx) + if self.mode in ('LSTM', 'PeepholeLSTM'): + self.check_forward_hidden(input, hx[0]) + self.check_forward_hidden(input, hx[1]) + else: + self.check_forward_hidden(input, hx) + + cell = _conv_cell_helper( + self.mode, + convndim=self.convndim, + stride=self.stride, + dilation=self.dilation, + groups=self.groups) + if self.mode == 'PeepholeLSTM': + return cell( + input, hx, + self.weight_ih, self.weight_hh, self.weight_pi, self.weight_pf, self.weight_po, + self.bias_ih, self.bias_hh + ) + else: + return cell( + input, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + + +class Conv1dRNNCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + nonlinearity='tanh', + bias=True, + stride=1, + dilation=1, + groups=1 + ): + if nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) + super().__init__( + mode=mode, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=1, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv1dLSTMCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='LSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=1, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv1dPeepholeLSTMCell(ConvRNNCellBase): + 
def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='PeepholeLSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=1, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv1dGRUCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='GRU', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=1, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv2dRNNCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + nonlinearity='tanh', + bias=True, + stride=1, + dilation=1, + groups=1 + ): + if nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) + super().__init__( + mode=mode, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv2dLSTMCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='LSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv2dPeepholeLSTMCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='PeepholeLSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv2dGRUCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='GRU', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=2, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv3dRNNCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + nonlinearity='tanh', + bias=True, + stride=1, + dilation=1, + groups=1 + ): + if nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) + super().__init__( + mode=mode, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv3dLSTMCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='LSTM', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv3dPeepholeLSTMCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='PeepholeLSTM', + 
in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups + ) + + +class Conv3dGRUCell(ConvRNNCellBase): + def __init__(self, + in_channels, + out_channels, + kernel_size, + bias=True, + stride=1, + dilation=1, + groups=1 + ): + super().__init__( + mode='GRU', + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + bias=bias, + convndim=3, + stride=stride, + dilation=dilation, + groups=groups + ) diff --git a/Audio/code/convolutional_rnn/utils.py b/Audio/code/convolutional_rnn/utils.py new file mode 100644 index 0000000..6161683 --- /dev/null +++ b/Audio/code/convolutional_rnn/utils.py @@ -0,0 +1,19 @@ +import collections +from itertools import repeat + + +""" Copied from torch.nn.modules.utils """ + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) diff --git a/Audio/code/dataset.py b/Audio/code/dataset.py new file mode 100644 index 0000000..bea715f --- /dev/null +++ b/Audio/code/dataset.py @@ -0,0 +1,1211 @@ +import os +import random +import pickle +import numpy as np +import cv2 +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image +from torch.utils.data import DataLoader +import time +import copy +import python_speech_features +from scipy.io import loadmat +#import utils +import pdb +#EIGVECS = np.load('../basics/S.npy') +#MS = np.load('../basics/mean_shape.npy') + + + +class LRW_1D_lstm_landmark_pca(data.Dataset): + def __init__(self, + dataset_dir, + train='train'): + self.train = train + self.num_frames = 16 + self.lmark_root_path = '../dataset/landmark1d' + self.pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:,:6] ) + self.mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')) + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "img_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + + + + def __getitem__(self, index): + if self.train=='train': + lmark_path = os.path.join(self.lmark_root_path , self.train_data[index][0] , self.train_data[index][1],self.train_data[index][2], self.train_data[index][2] + '.npy') + mfcc_path = os.path.join('../dataset/mfcc/', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(3,8)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + landmark =lmark[r+1 : r + 17,:] + example_mfcc = torch.FloatTensor(example_mfcc) + return example_landmark, example_mfcc, landmark, mfccs + if self.train=='test': + lmark_path = os.path.join(self.lmark_root_path , 
self.test_data[index][0] , self.test_data[index][1],self.test_data[index][2], self.test_data[index][2] + '.npy') + mfcc_path = os.path.join('../dataset/mfcc/', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(3,8)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + landmark =lmark[r+1 : r + 17,:] + example_mfcc = torch.FloatTensor(example_mfcc) + return example_landmark, example_mfcc, landmark, mfccs + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass + +class LRW_1D_lstm_3dmm(data.Dataset): + def __init__(self, + dataset_dir, + train='train', + indexes=range(80,144)): + self.train = train + self.num_frames = 16 + if self.train=='train': + _file = open(os.path.join(dataset_dir, "coeff_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "coeff_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + self.indexes = indexes + + def __getitem__(self, index): + if self.train=='train': + mfcc_path = os.path.join('../dataset/mfcc/lrw', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + coeff_path = os.path.join('../dataset/coeff/lrw', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + + mfcc = np.load(mfcc_path) + coeff = np.load(coeff_path) + + r = random.choice( + [x for x in range(3,8)]) + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + #landmark =lmark[r+1 : r + 17,:] + coeff = coeff[r+1:r+17,self.indexes] + return coeff, mfccs + if self.train=='test': + mfcc_path = os.path.join('../dataset/mfcc/lrw', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + coeff_path = os.path.join('../dataset/coeff/lrw', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + + mfcc = np.load(mfcc_path) + coeff = np.load(coeff_path) + + r = random.choice( + [x for x in range(3,8)]) + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + #landmark =lmark[r+1 : r + 17,:] + coeff = coeff[r+1:r+17,self.indexes] + return coeff, mfccs + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass + +class LRW_1D_lstm_3dmm_pose(data.Dataset): + def __init__(self, + dataset_dir, + train='train', + indexes=range(80,144), + relativeframe=0): + self.train = train + self.num_frames = 16 + if self.train=='train': + _file = open(os.path.join(dataset_dir, "coeff_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "coeff_test.pkl"), "rb") + 
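+            # each cached entry is indexed as [0]/[1]/[2] in __getitem__ to build the per-clip mfcc and coeff .npy paths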
self.test_data = pickle.load(_file) + _file.close() + self.indexes = indexes + self.relativeframe = relativeframe + + def __getitem__(self, index): + if self.train=='train': + mfcc_path = os.path.join('../dataset/mfcc/lrw', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + coeff_path = os.path.join('../dataset/coeff/lrw', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + + mfcc = np.load(mfcc_path) + coeff = np.load(coeff_path) + + r = random.choice( + [x for x in range(3,8)]) + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + #landmark =lmark[r+1 : r + 17,:] + L = len(self.indexes) + coeffc = np.zeros((16,L+6),dtype=np.float32) + coeffc[:,:L] = coeff[r+1:r+17,self.indexes] + coeffc[:,L:L+3] = coeff[r+1:r+17,224:227] + coeffc[:,L+3:L+6] = coeff[r+1:r+17,254:257] + coeffc2 = coeffc.copy() + if self.relativeframe == 1: + coeffc[0,L:L+6] = 0 + coeffc[1:,L:L+6] = coeffc2[1:,L:L+6] - coeffc2[:-1,L:L+6] + else: + coeffc[:,L+5] -= 0.5 + coeffc = torch.FloatTensor(coeffc) + coeffc2 = torch.FloatTensor(coeffc2) + return coeffc, mfccs, coeffc2 + if self.train=='test': + mfcc_path = os.path.join('../dataset/mfcc/lrw', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + coeff_path = os.path.join('../dataset/coeff/lrw', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + + mfcc = np.load(mfcc_path) + coeff = np.load(coeff_path) + + r = random.choice( + [x for x in range(3,8)]) + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + #landmark =lmark[r+1 : r + 17,:] + L = len(self.indexes) + coeffc = np.zeros((16,L+6),dtype=np.float32) + coeffc[:,:L] = coeff[r+1:r+17,self.indexes] + coeffc[:,L:L+3] = coeff[r+1:r+17,224:227] + coeffc[:,L+3:L+6] = coeff[r+1:r+17,254:257] + coeffc2 = coeffc.copy() + if self.relativeframe == 1: + coeffc[0,L:L+6] = 0 + coeffc[1:,L:L+6] = coeffc2[1:,L:L+6] - coeffc2[:-1,L:L+6] + else: + coeffc[:,L+5] -= 0.5 + coeffc = torch.FloatTensor(coeffc) + coeffc2 = torch.FloatTensor(coeffc2) + return coeffc, mfccs, coeffc2 + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass + +class News_1D_lstm_3dmm_pose(data.Dataset): + def __init__(self, + dataset_dir, + train='train', + indexes=range(80,144), + relativeframe=0, + newsname='Learn_English', + start=357, + trainN=300, + testN=100): + self.train = train + self.num_frames = 16 + self.indexes = indexes + self.relativeframe = relativeframe + self.newsname = newsname + self.start = start + self.trainN = trainN + self.testN = testN + if self.train=='train': + mfcc_path = os.path.join('../dataset/mfcc/', self.newsname + '.npy') + mfcc = np.load(mfcc_path) + mfccs = [] + ind = 3 + while ind <= int(mfcc.shape[0]/4) - 4: + # take 280 ms segment + t_mfcc = mfcc[(ind - 3)*4: (ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + ind += 1 + mfccs = torch.stack(mfccs, dim = 0) + print(mfccs.shape) #(24101, 28, 12) for Ben_Shapiro, (16282, 28, 12) for BBC_Carrie_Lam + self.mfccs = mfccs[self.start:self.start+self.trainN] + + coeff = np.zeros((self.trainN,257),dtype=np.float32) + for i in 
range(self.trainN): + coeff_path = os.path.join('../../Deep3DFaceReconstruction/output/coeff/', self.newsname, 'frame%d.mat'%(self.start+i)) + #print(coeff_path) + data = loadmat(coeff_path) + coeff[i] = data['coeff'] + L = len(self.indexes) + coeffc = np.zeros((self.trainN,L+6),dtype=np.float32) + coeffc[:,:L] = coeff[:,self.indexes] + coeffc[:,L:L+3] = coeff[:,224:227] + coeffc[:,L+3:L+6] = coeff[:,254:257] + coeffc2 = coeffc.copy() + if self.relativeframe == 1: + coeffc[0,L:L+6] = 0 + coeffc[1:,L:L+6] = coeffc2[1:,L:L+6] - coeffc2[:-1,L:L+6] + else: + coeffc[:,L+5] -= 0.5 + self.coeffc = torch.FloatTensor(coeffc) + self.coeffc2 = torch.FloatTensor(coeffc2) + print(self.coeffc.shape) + + def __getitem__(self, index): + if self.train=='train': + coeffc = self.coeffc[index:index+16] + mfccs = self.mfccs[index:index+16] + coeffc2 = self.coeffc2[index:index+16] + return coeffc, mfccs, coeffc2 + + def __len__(self): + if self.train=='train': + return self.trainN-15 + elif self.train=='test': + return self.testN-15 + else: + pass + +class LRW_1D_single_landmark_pca(data.Dataset): + def __init__(self, + dataset_dir, + train='train'): + self.train = train + self.num_frames = 16 + self.lmark_root_path = '../dataset/landmark1d' + self.audio_root_path = '../dataset/audio' + self.pca = torch.FloatTensor(np.load('../basics/U_lrw1.npy')[:,:6] ) + self.mean = torch.FloatTensor(np.load('../basics/mean_lrw1.npy')) + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "img_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + + + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + lmark_path = os.path.join(self.lmark_root_path , self.train_data[index][0] , self.train_data[index][1],self.train_data[index][2], self.train_data[index][2] + '.npy') + mfcc_path = os.path.join('../dataset/mfcc/', self.train_data[index][0], self.train_data[index][1], self.train_data[index][2] + '.npy') + + lmark = np.load(lmark_path) + + lmark = torch.FloatTensor(lmark) * 5.0 + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(3,25)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + + while True: + current_frame_id = random.choice( + [x for x in range(3,25)]) + if current_frame_id != r: + break + t_mfcc =mfcc[( current_frame_id - 3)*4: (current_frame_id + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + landmark =lmark[current_frame_id , :] + example_landmark = torch.FloatTensor(example_landmark) + example_mfcc = torch.FloatTensor(example_mfcc) + landmark = torch.FloatTensor(landmark) + landmark = landmark + return example_landmark, example_mfcc, landmark, t_mfcc + if self.train=='test': + mfcc_path = os.path.join('../dataset/mfcc/', self.test_data[index][0], self.test_data[index][1], self.test_data[index][2] + '.npy') + lmark_path = os.path.join(self.lmark_root_path , self.test_data[index][0] , self.test_data[index][1],self.test_data[index][2], self.test_data[index][2] + '.npy') + lmark = np.load(lmark_path) + mfcc = np.load(mfcc_path) + example_landmark =lmark[3,:] + example_mfcc = mfcc[0 : 7 
* 4, 1 :] + r =3 + ind = self.test_data[index][3] + + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + landmark =lmark[r+ ind,:] + # example_audio = torch.FloatTensor(example_audio) + example_mfcc = torch.FloatTensor(example_mfcc) + # audio = torch.FloatTensor(audio) + # mfccs = torch.FloatTensor(mfccs) + landmark = torch.FloatTensor(landmark) + # landmark = self.transform(landmark) + landmark = landmark * 5.0 + example_landmark = torch.FloatTensor(example_landmark).view(1,-1) + example_landmark = example_landmark - self.mean.expand_as(example_landmark) + example_landmark = torch.mm(example_landmark,self.pca) + + return example_landmark[0], example_mfcc, landmark, t_mfcc + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass + +class LRWdataset1D_single_gt(data.Dataset): + def __init__(self, + dataset_dir, + output_shape=[128, 128], + train='train'): + self.train = train + self.dataset_dir = dataset_dir + self.output_shape = tuple(output_shape) + + if not len(output_shape) in [2, 3]: + raise ValueError("[*] output_shape must be [H,W] or [C,H,W]") + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "new_img_full_gt_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "new_img_full_gt_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "new_img_full_gt_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + + #load righ img + image_path = '../dataset/regions/' + self.train_data[index][0] + landmark_path = '../dataset/landmark1d/' + self.train_data[index][0][:-8] + '.npy' + + landmark = np.load(landmark_path) * 5.0 + + right_landmark = landmark[self.train_data[index][1] - 1] + right_landmark = torch.FloatTensor(right_landmark.reshape(-1)) + + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + r = random.choice( + [x for x in range(1,30)]) + example_path = image_path[:-8] + '_%03d.jpg'%r + example_landmark = landmark[r - 1] + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + + return example_img, example_landmark, right_img,right_landmark + + elif self.train =='test': + # try: + #load righ img + image_path = '../dataset/regions/' + self.test_data[index][0] + landmark_path = '../dataset/landmark1d/' + self.test_data[index][0][:-8] + '.npy' + landmark = np.load(landmark_path) * 5.0 + right_landmark = landmark[self.test_data[index][1] - 1] + + right_landmark = torch.FloatTensor(right_landmark.reshape(-1)) + + im = cv2.imread(image_path) + + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + example_path = '../image/musk1_region.jpg' + example_landmark = 
np.load('../image/musk1.npy') + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) * 5.0 + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + + return example_img, example_landmark, right_img,right_landmark + + +class LRWdataset1D_lstm_gt(data.Dataset): + def __init__(self, + dataset_dir, + output_shape=[128, 128], + train='train'): + self.train = train + self.dataset_dir = dataset_dir + self.output_shape = tuple(output_shape) + + if not len(output_shape) in [2, 3]: + raise ValueError("[*] output_shape must be [H,W] or [C,H,W]") + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "new_16_full_gt_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "new_16_full_gt_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "new_16_full_gt_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + + #load righ img + image_path = '../dataset/regions/' + self.train_data[index][0] + landmark_path = '../dataset/landmark1d/' + self.train_data[index][0][:-8] + '.npy' + current_frame_id =self.train_data[index][1] + right_img = torch.FloatTensor(16,3,self.output_shape[0],self.output_shape[1]) + for jj in range(16): + this_frame = current_frame_id + jj + image_path = '../dataset/regions/' + self.train_data[index][0][:-7] + '%03d.jpg'%this_frame + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img[jj,:,:,:] = torch.FloatTensor(im) + + + + landmark = np.load(landmark_path) * 5.0 + + right_landmark = landmark[self.train_data[index][1] - 1 : self.train_data[index][1] + 15 ] + + right_landmark = torch.FloatTensor(right_landmark.reshape(16,136)) + + + r = random.choice( + [x for x in range(1,30)]) + example_path = image_path[:-8] + '_%03d.jpg'%r + example_landmark = landmark[r - 1] + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + # print (right_landmark.size()) + + return example_img, example_landmark, right_img,right_landmark + + + elif self.train =='test': + #load righ img + image_path = '../dataset/regions/' + self.test_data[index][0] + landmark_path = '../dataset/landmark1d/' + self.test_data[index][0][:-8] + '.npy' + current_frame_id =self.test_data[index][1] + right_img = torch.FloatTensor(16,3,self.output_shape[0],self.output_shape[1]) + for jj in range(16): + this_frame = current_frame_id + jj + image_path = '../dataset/regions/' + self.test_data[index][0][:-7] + '%03d.jpg'%this_frame + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img[jj,:,:,:] = torch.FloatTensor(im) + + landmark = np.load(landmark_path) * 5.0 + + right_landmark = 
landmark[self.test_data[index][1] - 1 : self.test_data[index][1] + 15 ] + + right_landmark = torch.FloatTensor(right_landmark.reshape(16,136)) + r = random.choice( + [x for x in range(1,30)]) + r = current_frame_id + example_path = image_path[:-8] + '_%03d.jpg'%r + example_landmark = landmark[r - 1] + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + # print (right_landmark.size()) + + return example_img, example_landmark, right_img,right_landmark + + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + return len(self.demo_data) + + +class LRWdataset1D_single(data.Dataset): + def __init__(self, + dataset_dir, + output_shape=[128, 128], + train='train'): + self.train = train + self.dataset_dir = dataset_dir + self.output_shape = tuple(output_shape) + + if not len(output_shape) in [2, 3]: + raise ValueError("[*] output_shape must be [H,W] or [C,H,W]") + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "new_img_small_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "new_img_small_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "img_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + while True: + # try: + #load righ img + image_path = self.train_data[index][0] + landmark_path = self.train_data[index][1] + landmark = np.load(landmark_path) + + right_landmark = landmark[self.train_data[index][2]] + tp = ( np.dot(right_landmark.reshape(1,6), EIGVECS))[0,:].reshape(68,3) + tp = tp[:,:-1].reshape(-1) + right_landmark = torch.FloatTensor(tp) + im = cv2.imread(image_path) + if im is None: + raise IOError + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + r = random.choice( + [x for x in range(1,30)]) + example_path = image_path[:-8] + '_%03d.jpg'%r + example_landmark = landmark[r] + + + tp = ( np.dot(example_landmark.reshape(1,6), EIGVECS))[0,:].reshape(68,3) + tp = tp[:,:-1].reshape(-1) + + + example_landmark = torch.FloatTensor(tp) + + example_img = cv2.imread(example_path) + if example_img is None: + raise IOError + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + + return example_img, example_landmark, right_img,right_landmark + + elif self.train =='test': + while True: + image_path = self.test_data[index][0] + landmark_path = self.test_data[index][1] + landmark = np.load(landmark_path) + + right_landmark = landmark[self.test_data[index][2]] + right_landmark = torch.FloatTensor((MS + np.dot(right_landmark.reshape(1,6), EIGVECS)).reshape(-1)) + print (right_landmark.shape) + im = cv2.imread(image_path) + if im is None: + raise IOError + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = 
cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + r = random.choice( + [x for x in range(1,30)]) + #load example image + example_path = image_path[:-8] + '_%03d.jpg'%r + example_landmark = landmark[self.train_data[r][2]] + example_landmark = torch.FloatTensor((MS + np.dot(example_landmark.reshape(1,6), EIGVECS)).reshape(-1)) + + example_img = cv2.imread(example_path) + if example_img is None: + raise IOError + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + return example_img, example_landmark, right_img,right_landmark + elif self.train =='demo': + landmarks = np.load('/home/lchen63/obama_fake.npy') + landmarks =np.reshape(landmarks, (landmarks.shape[0], 136)) + while True: + # try: + + image_path = self.demo_data[index][0] + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + + right_img = torch.FloatTensor(im) + example_path = '/mnt/disk1/dat/lchen63/lrw/demo/musk1_region.jpg' + + example_landmark = landmarks[0] + + + example_lip = cv2.imread(example_path) + + example_lip = cv2.cvtColor(example_lip, cv2.COLOR_BGR2RGB) + example_lip = cv2.resize(example_lip, self.output_shape) + example_lip = self.transform(example_lip) + + + right_landmark = torch.FloatTensor(landmarks[index-1]) + + wrong_landmark = right_landmark + return example_lip, example_landmark, right_img,right_landmark, wrong_landmark + + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + return len(self.demo_data) + +#############################################################grid +class GRIDdataset1D_single_gt(data.Dataset): + def __init__(self, + dataset_dir, + output_shape=[128, 128], + train='train'): + self.train = train + self.dataset_dir = dataset_dir + self.output_shape = tuple(output_shape) + + if not len(output_shape) in [2, 3]: + raise ValueError("[*] output_shape must be [H,W] or [C,H,W]") + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "new_img_full_gt_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + + #load righ img + image_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.train_data[index][0], self.train_data[index][0], '%05d.jpg'%(self.train_data[index][1] + 1)) + + landmark_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.train_data[index][0], self.train_data[index][0] + '_norm_lmarks.npy') + + landmark = np.load(landmark_path) + + right_landmark = landmark[self.train_data[index][1]] + + right_landmark = torch.FloatTensor(right_landmark.reshape(-1)) + + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + r = random.choice( + [x for 
x in range(1, 76)]) + example_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.train_data[index][0], self.train_data[index][0], '%05d.jpg'%(r)) + example_landmark = landmark[r - 1] + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + return example_img, example_landmark, right_img,right_landmark + + + + + elif self.train =='test': + # try: + #load righ img + image_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.test_data[index][0], self.test_data[index][0], '%05d.jpg'%(self.test_data[index][1] + 1)) + + landmark_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.test_data[index][0], self.test_data[index][0] + '_norm_lmarks.npy') + + landmark = np.load(landmark_path) + + right_landmark = landmark[self.test_data[index][1]] + + right_landmark = torch.FloatTensor(right_landmark.reshape(-1)) + + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + r = random.choice( + [x for x in range(1, 76)]) + example_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.test_data[index][0], self.test_data[index][0], '%05d.jpg'%(r)) + example_landmark = landmark[r - 1] + + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + return example_img, example_landmark, right_img,right_landmark + + elif self.train =='demo': + + # try: + #load righ img + image_path = '/mnt/ssd0/dat/lchen63/lrw/demo/regions/' + self.demo_data[index][0] + landmark_path = '/mnt/ssd0/dat/lchen63/lrw/demo/landmark1d/' + self.demo_data[index][1].replace('obama_', 'obama_ge_') + right_landmark = np.load(landmark_path) + right_landmark = torch.FloatTensor(right_landmark.reshape(-1)) + # print ('=========================') + # print ('real path: ' + image_path) + im = cv2.imread(image_path) + if im is None: + print (image_path) + raise IOError + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img = torch.FloatTensor(im) + + + example_path = '../image/musk1_region.jpg' + example_landmark = np.load('../image/musk1.npy') + # tp = ( np.dot(example_landmark.reshape(1,6), EIGVECS))[0,:].reshape(68,3) + # tp = tp[:,:-1].reshape(-1) + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + # print (right_landmark.size()) + + return example_img, example_landmark, right_img,right_landmark + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + return len(self.demo_data) + + + +class GRIDdataset1D_lstm_gt(data.Dataset): + def __init__(self, + dataset_dir, + output_shape=[128, 128], + train='train'): + self.train = train + self.dataset_dir = dataset_dir + self.output_shape = tuple(output_shape) + + if not len(output_shape) in [2, 3]: + raise ValueError("[*] output_shape must be 
[H,W] or [C,H,W]") + + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_16_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_16_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "new_16_full_gt_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]) + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + + #load righ img + + image_path_root = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.train_data[index][0], self.train_data[index][0]) + + landmark_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.train_data[index][0], self.train_data[index][0] + '_norm_lmarks.npy') + + current_frame_id =self.train_data[index][1] + right_img = torch.FloatTensor(16,3,self.output_shape[0],self.output_shape[1]) + for jj in range(16): + this_frame = current_frame_id + jj + image_path = os.path.join(image_path_root, '%05d.jpg'%this_frame) + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img[jj,:,:,:] = torch.FloatTensor(im) + + + + landmark = np.load(landmark_path) + + right_landmark = landmark[self.train_data[index][1] - 1 : self.train_data[index][1] + 15 ] + + right_landmark = torch.FloatTensor(right_landmark.reshape(16,136)) + + r = random.choice( + [x for x in range(1,76)]) + example_path = os.path.join(image_path_root, '%05d.jpg'%r) + example_landmark = landmark[r - 1] + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + + return example_img, example_landmark, right_img,right_landmark + + + elif self.train =='test': + image_path_root = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.test_data[index][0], self.test_data[index][0]) + + landmark_path = os.path.join('/mnt/ssd0/dat/lchen63/grid/data' , self.test_data[index][0], self.test_data[index][0] + '_norm_lmarks.npy') + + current_frame_id =self.test_data[index][1] + right_img = torch.FloatTensor(16,3,self.output_shape[0],self.output_shape[1]) + for jj in range(16): + this_frame = current_frame_id + jj + image_path = os.path.join(image_path_root, '%05d.jpg'%this_frame) + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + im = cv2.resize(im, self.output_shape) + im = self.transform(im) + right_img[jj,:,:,:] = torch.FloatTensor(im) + + + + landmark = np.load(landmark_path) + + right_landmark = landmark[self.test_data[index][1] - 1 : self.test_data[index][1] + 15 ] + + right_landmark = torch.FloatTensor(right_landmark.reshape(16,136)) + + r = random.choice( + [x for x in range(1,76)]) + example_path = os.path.join(image_path_root, '%05d.jpg'%r) + example_landmark = landmark[r - 1] + + example_landmark = torch.FloatTensor(example_landmark.reshape(-1)) + + example_img = cv2.imread(example_path) + example_img = cv2.cvtColor(example_img, cv2.COLOR_BGR2RGB) + example_img = cv2.resize(example_img, self.output_shape) + example_img = self.transform(example_img) + + return 
example_img, example_landmark, right_img,right_landmark + + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + return len(self.demo_data) +class GRID_1D_lstm_landmark_pca(data.Dataset): + def __init__(self, + dataset_dir, + train='train'): + self.train = train + self.num_frames = 16 + self.root_path = '/mnt/ssd0/dat/lchen63/grid/data' + self.pca = torch.FloatTensor(np.load('../basics/U_grid.npy')[:,:6] ) + self.mean = torch.FloatTensor(np.load('../basics/mean_grid.npy')) + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_train.pkl"), "rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "img_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + try: + lmark_path = os.path.join(self.root_path , self.train_data[index][0] , self.train_data[index][0] + '_norm_lmarks.npy') + mfcc_path = os.path.join(self.root_path, self.train_data[index][0], self.train_data[index][0] +'_mfcc.npy') + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(6,50)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + landmark =lmark[r+1 : r + 17,:] + + example_mfcc = torch.FloatTensor(example_mfcc) + return example_landmark, example_mfcc, landmark, mfccs + except: + self.__getitem__(index + 1) + if self.train=='test': + lmark_path = os.path.join(self.root_path , self.test_data[index][0] , self.test_data[index][0] + '_norm_lmarks.npy') + mfcc_path = os.path.join(self.root_path, self.test_data[index][0], self.test_data[index][0] +'_mfcc.npy') + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(3,70)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + landmark =lmark[r+1 : r + 17,:] + + example_mfcc = torch.FloatTensor(example_mfcc) + + return example_landmark, example_mfcc, landmark, mfccs + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass + + +class GRID_1D_single_landmark_pca(data.Dataset): + def __init__(self, + dataset_dir, + train='train'): + self.train = train + self.num_frames = 16 + self.root_path = '/mnt/ssd0/dat/lchen63/grid/data' + self.pca = torch.FloatTensor(np.load('../basics/U_grid.npy')[:,:6] ) + self.mean = torch.FloatTensor(np.load('../basics/mean_grid.npy')) + if self.train=='train': + _file = open(os.path.join(dataset_dir, "lmark_train.pkl"), 
"rb") + self.train_data = pickle.load(_file) + _file.close() + elif self.train =='test': + _file = open(os.path.join(dataset_dir, "lmark_test.pkl"), "rb") + self.test_data = pickle.load(_file) + _file.close() + elif self.train =='demo' : + _file = open(os.path.join(dataset_dir, "img_demo.pkl"), "rb") + self.demo_data = pickle.load(_file) + _file.close() + + def __getitem__(self, index): + # In training phase, it return real_image, wrong_image, text + if self.train=='train': + # try: + + lmark_path = os.path.join(self.root_path , self.train_data[index][0] , self.train_data[index][0] + '_norm_lmarks.npy') + mfcc_path = os.path.join(self.root_path, self.train_data[index][0], self.train_data[index][0] +'_mfcc.npy') + ind = self.train_data[index][1] + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(6,50)]) + example_landmark =lmark[r,:] + t_mfcc =mfcc[(ind - 3)*4: (ind + 4)*4, 1:] + + t_mfcc = torch.FloatTensor(t_mfcc) + landmark =lmark[ind,:] + + return example_landmark, t_mfcc, landmark, t_mfcc + + if self.train=='test': + lmark_path = os.path.join(self.root_path , self.test_data[index][0] , self.test_data[index][0] + '_norm_lmarks.npy') + mfcc_path = os.path.join(self.root_path, self.test_data[index][0], self.test_data[index][0] +'_mfcc.npy') + + lmark = np.load(lmark_path) * 5.0 + lmark = torch.FloatTensor(lmark) + lmark = lmark - self.mean.expand_as(lmark) + lmark = torch.mm(lmark,self.pca) + + mfcc = np.load(mfcc_path) + + r = random.choice( + [x for x in range(3,70)]) + example_landmark =lmark[r,:] + example_mfcc = mfcc[(r -3) * 4 : (r + 4) * 4, 1 :] + + mfccs = [] + for ind in range(1,17): + t_mfcc =mfcc[(r + ind - 3)*4: (r + ind + 4)*4, 1:] + t_mfcc = torch.FloatTensor(t_mfcc) + mfccs.append(t_mfcc) + mfccs = torch.stack(mfccs, dim = 0) + landmark =lmark[r+1 : r + 17,:] + + example_mfcc = torch.FloatTensor(example_mfcc) + + return example_landmark, example_mfcc, landmark, mfccs + + def __len__(self): + if self.train=='train': + return len(self.train_data) + elif self.train=='test': + return len(self.test_data) + else: + pass diff --git a/Audio/code/mesh_renderer/camera_utils.py b/Audio/code/mesh_renderer/camera_utils.py new file mode 100644 index 0000000..f28c555 --- /dev/null +++ b/Audio/code/mesh_renderer/camera_utils.py @@ -0,0 +1,183 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Collection of TF functions for managing 3D camera matrices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import tensorflow as tf + + +def perspective(aspect_ratio, fov_y, near_clip, far_clip): + """Computes perspective transformation matrices. + + Functionality mimes gluPerspective (third_party/GL/glu/include/GLU/glu.h). + + Args: + aspect_ratio: float value specifying the image aspect ratio (width/height). 
+ fov_y: 1-D float32 Tensor with shape [batch_size] specifying output vertical + field of views in degrees. + near_clip: 1-D float32 Tensor with shape [batch_size] specifying near + clipping plane distance. + far_clip: 1-D float32 Tensor with shape [batch_size] specifying far clipping + plane distance. + + Returns: + A [batch_size, 4, 4] float tensor that maps from right-handed points in eye + space to left-handed points in clip space. + """ + # The multiplication of fov_y by pi/360.0 simultaneously converts to radians + # and adds the half-angle factor of .5. + focal_lengths_y = 1.0 / tf.tan(fov_y * (math.pi / 360.0)) + depth_range = far_clip - near_clip + p_22 = -(far_clip + near_clip) / depth_range + p_23 = -2.0 * (far_clip * near_clip / depth_range) + + zeros = tf.zeros_like(p_23, dtype=tf.float32) + # pyformat: disable + perspective_transform = tf.concat( + [ + focal_lengths_y / aspect_ratio, zeros, zeros, zeros, + zeros, focal_lengths_y, zeros, zeros, + zeros, zeros, p_22, p_23, + zeros, zeros, -tf.ones_like(p_23, dtype=tf.float32), zeros + ], axis=0) + # pyformat: enable + perspective_transform = tf.reshape(perspective_transform, [4, 4, -1]) + return tf.transpose(perspective_transform, [2, 0, 1]) + + +def look_at(eye, center, world_up): + """Computes camera viewing matrices. + + Functionality mimes gluLookAt (third_party/GL/glu/include/GLU/glu.h). + + Args: + eye: 2-D float32 tensor with shape [batch_size, 3] containing the XYZ world + space position of the camera. + center: 2-D float32 tensor with shape [batch_size, 3] containing a position + along the center of the camera's gaze. + world_up: 2-D float32 tensor with shape [batch_size, 3] specifying the + world's up direction; the output camera will have no tilt with respect + to this direction. + + Returns: + A [batch_size, 4, 4] float tensor containing a right-handed camera + extrinsics matrix that maps points from world space to points in eye space. + """ + batch_size = center.shape[0].value + vector_degeneracy_cutoff = 1e-6 + forward = center - eye + forward_norm = tf.norm(forward, ord='euclidean', axis=1, keepdims=True) + #tf.assert_greater( + # forward_norm, + # vector_degeneracy_cutoff, + # message='Camera matrix is degenerate because eye and center are close.') + forward = tf.divide(forward, forward_norm) + + to_side = tf.linalg.cross(forward, world_up) + to_side_norm = tf.norm(to_side, ord='euclidean', axis=1, keepdims=True) + #tf.assert_greater( + # to_side_norm, + # vector_degeneracy_cutoff, + # message='Camera matrix is degenerate because up and gaze are close or' + # 'because up is degenerate.') + to_side = tf.divide(to_side, to_side_norm) + cam_up = tf.linalg.cross(to_side, forward) + + w_column = tf.constant( + batch_size * [[0., 0., 0., 1.]], dtype=tf.float32) # [batch_size, 4] + w_column = tf.reshape(w_column, [batch_size, 4, 1]) + view_rotation = tf.stack( + [to_side, cam_up, -forward, + tf.zeros_like(to_side, dtype=tf.float32)], + axis=1) # [batch_size, 4, 3] matrix + view_rotation = tf.concat( + [view_rotation, w_column], axis=2) # [batch_size, 4, 4] + + identity_batch = tf.tile(tf.expand_dims(tf.eye(3), 0), [batch_size, 1, 1]) + view_translation = tf.concat([identity_batch, tf.expand_dims(-eye, 2)], 2) + view_translation = tf.concat( + [view_translation, + tf.reshape(w_column, [batch_size, 1, 4])], 1) + camera_matrices = tf.matmul(view_rotation, view_translation) + return camera_matrices + + +def euler_matrices(angles): + """Computes a XYZ Tait-Bryan (improper Euler angle) rotation. 
+ + Returns 4x4 matrices for convenient multiplication with other transformations. + + Args: + angles: a [batch_size, 3] tensor containing X, Y, and Z angles in radians. + + Returns: + a [batch_size, 4, 4] tensor of matrices. + """ + s = tf.sin(angles) + c = tf.cos(angles) + # Rename variables for readability in the matrix definition below. + c0, c1, c2 = (c[:, 0], c[:, 1], c[:, 2]) + s0, s1, s2 = (s[:, 0], s[:, 1], s[:, 2]) + + zeros = tf.zeros_like(s[:, 0]) + ones = tf.ones_like(s[:, 0]) + + # pyformat: disable + flattened = tf.concat( + [ + c2 * c1, c2 * s1 * s0 - c0 * s2, s2 * s0 + c2 * c0 * s1, zeros, + c1 * s2, c2 * c0 + s2 * s1 * s0, c0 * s2 * s1 - c2 * s0, zeros, + -s1, c1 * s0, c1 * c0, zeros, + zeros, zeros, zeros, ones + ], + axis=0) + # pyformat: enable + reshaped = tf.reshape(flattened, [4, 4, -1]) + return tf.transpose(reshaped, [2, 0, 1]) + + +def transform_homogeneous(matrices, vertices): + """Applies batched 4x4 homogenous matrix transformations to 3-D vertices. + + The vertices are input and output as as row-major, but are interpreted as + column vectors multiplied on the right-hand side of the matrices. More + explicitly, this function computes (MV^T)^T. + Vertices are assumed to be xyz, and are extended to xyzw with w=1. + + Args: + matrices: a [batch_size, 4, 4] tensor of matrices. + vertices: a [batch_size, N, 3] tensor of xyz vertices. + + Returns: + a [batch_size, N, 4] tensor of xyzw vertices. + + Raises: + ValueError: if matrices or vertices have the wrong number of dimensions. + """ + if len(matrices.shape) != 3: + raise ValueError( + 'matrices must have 3 dimensions (missing batch dimension?)') + if len(vertices.shape) != 3: + raise ValueError( + 'vertices must have 3 dimensions (missing batch dimension?)') + homogeneous_coord = tf.ones( + [tf.shape(vertices)[0], tf.shape(vertices)[1], 1], dtype=tf.float32) + vertices_homogeneous = tf.concat([vertices, homogeneous_coord], 2) + + return tf.matmul(vertices_homogeneous, matrices, transpose_b=True) diff --git a/Audio/code/mesh_renderer/mesh_renderer.py b/Audio/code/mesh_renderer/mesh_renderer.py new file mode 100644 index 0000000..c614afa --- /dev/null +++ b/Audio/code/mesh_renderer/mesh_renderer.py @@ -0,0 +1,402 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Differentiable 3-D rendering of a triangle mesh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from . import camera_utils +from . import rasterize_triangles + + +def phong_shader(normals, + alphas, + pixel_positions, + light_positions, + light_intensities, + diffuse_colors=None, + camera_position=None, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None): + """Computes pixelwise lighting from rasterized buffers with the Phong model. + + Args: + normals: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. 
The inner dimension is the world space XYZ normal for + the corresponding pixel. Should be already normalized. + alphas: a 3D float32 tensor with shape [batch_size, image_height, + image_width]. The inner dimension is the alpha value (transparency) + for the corresponding pixel. + pixel_positions: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the world space XYZ position for + the corresponding pixel. + light_positions: a 3D tensor with shape [batch_size, light_count, 3]. The + XYZ position of each light in the scene. In the same coordinate space as + pixel_positions. + light_intensities: a 3D tensor with shape [batch_size, light_count, 3]. The + RGB intensity values for each light. Intensities may be above one. + diffuse_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the diffuse RGB coefficients at + a pixel in the range [0, 1]. + camera_position: a 1D tensor with shape [batch_size, 3]. The XYZ camera + position in the scene. If supplied, specular reflections will be + computed. If not supplied, specular_colors and shininess_coefficients + are expected to be None. In the same coordinate space as + pixel_positions. + specular_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the specular RGB coefficients at + a pixel in the range [0, 1]. If None, assumed to be tf.zeros() + shininess_coefficients: A 3D float32 tensor that is broadcasted to shape + [batch_size, image_height, image_width]. The inner dimension is the + shininess coefficient for the object at a pixel. Dimensions that are + constant can be given length 1, so [batch_size, 1, 1] and [1, 1, 1] are + also valid input shapes. + ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient + color, which is added to each pixel before tone mapping. If None, it is + assumed to be tf.zeros(). + Returns: + A 4D float32 tensor of shape [batch_size, image_height, image_width, 4] + containing the lit RGBA color values for each image at each pixel. Colors + are in the range [0,1]. + + Raises: + ValueError: An invalid argument to the method is detected. + """ + batch_size, image_height, image_width = [s.value for s in normals.shape[:-1]] + light_count = light_positions.shape[1].value + pixel_count = image_height * image_width + # Reshape all values to easily do pixelwise computations: + normals = tf.reshape(normals, [batch_size, -1, 3]) + alphas = tf.reshape(alphas, [batch_size, -1, 1]) + diffuse_colors = tf.reshape(diffuse_colors, [batch_size, -1, 3]) + if camera_position is not None: + specular_colors = tf.reshape(specular_colors, [batch_size, -1, 3]) + + # Ambient component + output_colors = tf.zeros([batch_size, image_height * image_width, 3]) + if ambient_color is not None: + ambient_reshaped = tf.expand_dims(ambient_color, axis=1) + output_colors = tf.add(output_colors, ambient_reshaped * diffuse_colors) + + # Diffuse component + pixel_positions = tf.reshape(pixel_positions, [batch_size, -1, 3]) + per_light_pixel_positions = tf.stack( + [pixel_positions] * light_count, + axis=1) # [batch_size, light_count, pixel_count, 3] + directions_to_lights = tf.nn.l2_normalize( + tf.expand_dims(light_positions, axis=2) - per_light_pixel_positions, + axis=3) # [batch_size, light_count, pixel_count, 3] + # The specular component should only contribute when the light and normal + # face one another (i.e. 
the dot product is nonnegative): + normals_dot_lights = tf.clip_by_value( + tf.reduce_sum( + tf.expand_dims(normals, axis=1) * directions_to_lights, axis=3), 0.0, + 1.0) # [batch_size, light_count, pixel_count] + diffuse_output = tf.expand_dims( + diffuse_colors, axis=1) * tf.expand_dims( + normals_dot_lights, axis=3) * tf.expand_dims( + light_intensities, axis=2) + diffuse_output = tf.reduce_sum( + diffuse_output, axis=1) # [batch_size, pixel_count, 3] + output_colors = tf.add(output_colors, diffuse_output) + + # Specular component + if camera_position is not None: + camera_position = tf.reshape(camera_position, [batch_size, 1, 3]) + mirror_reflection_direction = tf.nn.l2_normalize( + 2.0 * tf.expand_dims(normals_dot_lights, axis=3) * tf.expand_dims( + normals, axis=1) - directions_to_lights, + dim=3) + direction_to_camera = tf.nn.l2_normalize( + camera_position - pixel_positions, dim=2) + reflection_direction_dot_camera_direction = tf.reduce_sum( + tf.expand_dims(direction_to_camera, axis=1) * + mirror_reflection_direction, + axis=3) + # The specular component should only contribute when the reflection is + # external: + reflection_direction_dot_camera_direction = tf.clip_by_value( + tf.nn.l2_normalize(reflection_direction_dot_camera_direction, dim=2), + 0.0, 1.0) + # The specular component should also only contribute when the diffuse + # component contributes: + reflection_direction_dot_camera_direction = tf.where( + normals_dot_lights != 0.0, reflection_direction_dot_camera_direction, + tf.zeros_like( + reflection_direction_dot_camera_direction, dtype=tf.float32)) + # Reshape to support broadcasting the shininess coefficient, which rarely + # varies per-vertex: + reflection_direction_dot_camera_direction = tf.reshape( + reflection_direction_dot_camera_direction, + [batch_size, light_count, image_height, image_width]) + shininess_coefficients = tf.expand_dims(shininess_coefficients, axis=1) + specularity = tf.reshape( + tf.pow(reflection_direction_dot_camera_direction, + shininess_coefficients), + [batch_size, light_count, pixel_count, 1]) + specular_output = tf.expand_dims( + specular_colors, axis=1) * specularity * tf.expand_dims( + light_intensities, axis=2) + specular_output = tf.reduce_sum(specular_output, axis=1) + output_colors = tf.add(output_colors, specular_output) + rgb_images = tf.reshape(output_colors, + [batch_size, image_height, image_width, 3]) + alpha_images = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + valid_rgb_values = tf.concat(3 * [alpha_images > 0.5], axis=3) + rgb_images = tf.where(valid_rgb_values, rgb_images, + tf.zeros_like(rgb_images, dtype=tf.float32)) + return tf.reverse(tf.concat([rgb_images, alpha_images], axis=3), axis=[1]) + + +def tone_mapper(image, gamma): + """Applies gamma correction to the input image. + + Tone maps the input image batch in order to make scenes with a high dynamic + range viewable. The gamma correction factor is computed separately per image, + but is shared between all provided channels. The exact function computed is: + + image_out = A*image_in^gamma, where A is an image-wide constant computed so + that the maximum image value is approximately 1. The correction is applied + to all channels. + + Args: + image: 4-D float32 tensor with shape [batch_size, image_height, + image_width, channel_count]. The batch of images to tone map. + gamma: 0-D float32 nonnegative tensor. Values of gamma below one compress + relative contrast in the image, and values above one increase it. 
A + value of 1 is equivalent to scaling the image to have a maximum value + of 1. + Returns: + 4-D float32 tensor with shape [batch_size, image_height, image_width, + channel_count]. Contains the gamma-corrected images, clipped to the range + [0, 1]. + """ + batch_size = image.shape[0].value + corrected_image = tf.pow(image, gamma) + image_max = tf.reduce_max( + tf.reshape(corrected_image, [batch_size, -1]), axis=1) + scaled_image = tf.divide(corrected_image, + tf.reshape(image_max, [batch_size, 1, 1, 1])) + return tf.clip_by_value(scaled_image, 0.0, 1.0) + + +def mesh_renderer(vertices, + triangles, + normals, + diffuse_colors, + camera_position, + camera_lookat, + camera_up, + light_positions, + light_intensities, + image_width, + image_height, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None, + fov_y=40.0, + near_clip=0.01, + far_clip=10.0): + """Renders an input scene using phong shading, and returns an output image. + + Args: + vertices: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is an xyz position in world space. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + normals: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is the xyz vertex normal for its corresponding vertex. Each + vector is assumed to be already normalized. + diffuse_colors: 3-D float32 tensor with shape [batch_size, + vertex_count, 3]. The RGB diffuse reflection in the range [0,1] for + each vertex. + camera_position: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] specifying the XYZ world space camera position. + camera_lookat: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] containing an XYZ point along the center of the camera's gaze. + camera_up: 2-D tensor with shape [batch_size, 3] or 1-D tensor with shape + [3] containing the up direction for the camera. The camera will have no + tilt with respect to this direction. + light_positions: a 3-D tensor with shape [batch_size, light_count, 3]. The + XYZ position of each light in the scene. In the same coordinate space as + pixel_positions. + light_intensities: a 3-D tensor with shape [batch_size, light_count, 3]. The + RGB intensity values for each light. Intensities may be above one. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + specular_colors: 3-D float32 tensor with shape [batch_size, + vertex_count, 3]. The RGB specular reflection in the range [0, 1] for + each vertex. If supplied, specular reflections will be computed, and + both specular_colors and shininess_coefficients are expected. + shininess_coefficients: a 0D-2D float32 tensor with maximum shape + [batch_size, vertex_count]. The phong shininess coefficient of each + vertex. A 0D tensor or float gives a constant shininess coefficient + across all batches and images. A 1D tensor must have shape [batch_size], + and a single shininess coefficient per image is used. + ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient + color, which is added to each pixel in the scene. If None, it is + assumed to be black. 
+ fov_y: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + desired output image y field of view in degrees. + near_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + near clipping plane distance. + far_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying + far clipping plane distance. + + Returns: + A 4-D float32 tensor of shape [batch_size, image_height, image_width, 4] + containing the lit RGBA color values for each image at each pixel. RGB + colors are the intensity values before tonemapping and can be in the range + [0, infinity]. Clipping to the range [0,1] with tf.clip_by_value is likely + reasonable for both viewing and training most scenes. More complex scenes + with multiple lights should tone map color values for display only. One + simple tonemapping approach is to rescale color values as x/(1+x); gamma + compression is another common techinque. Alpha values are zero for + background pixels and near one for mesh pixels. + Raises: + ValueError: An invalid argument to the method is detected. + """ + if len(vertices.shape) != 3: + raise ValueError('Vertices must have shape [batch_size, vertex_count, 3].') + batch_size = vertices.shape[0].value + if len(normals.shape) != 3: + raise ValueError('Normals must have shape [batch_size, vertex_count, 3].') + if len(light_positions.shape) != 3: + raise ValueError( + 'Light_positions must have shape [batch_size, light_count, 3].') + if len(light_intensities.shape) != 3: + raise ValueError( + 'Light_intensities must have shape [batch_size, light_count, 3].') + if len(diffuse_colors.shape) != 3: + raise ValueError( + 'vertex_diffuse_colors must have shape [batch_size, vertex_count, 3].') + if (ambient_color is not None and + ambient_color.get_shape().as_list() != [batch_size, 3]): + raise ValueError('Ambient_color must have shape [batch_size, 3].') + if camera_position.get_shape().as_list() == [3]: + camera_position = tf.tile( + tf.expand_dims(camera_position, axis=0), [batch_size, 1]) + elif camera_position.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_position must have shape [batch_size, 3]') + if camera_lookat.get_shape().as_list() == [3]: + camera_lookat = tf.tile( + tf.expand_dims(camera_lookat, axis=0), [batch_size, 1]) + elif camera_lookat.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_lookat must have shape [batch_size, 3]') + if camera_up.get_shape().as_list() == [3]: + camera_up = tf.tile(tf.expand_dims(camera_up, axis=0), [batch_size, 1]) + elif camera_up.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_up must have shape [batch_size, 3]') + if isinstance(fov_y, float): + fov_y = tf.constant(batch_size * [fov_y], dtype=tf.float32) + elif not fov_y.get_shape().as_list(): + fov_y = tf.tile(tf.expand_dims(fov_y, 0), [batch_size]) + elif fov_y.get_shape().as_list() != [batch_size]: + raise ValueError('Fov_y must be a float, a 0D tensor, or a 1D tensor with' + 'shape [batch_size]') + if isinstance(near_clip, float): + near_clip = tf.constant(batch_size * [near_clip], dtype=tf.float32) + elif not near_clip.get_shape().as_list(): + near_clip = tf.tile(tf.expand_dims(near_clip, 0), [batch_size]) + elif near_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Near_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if isinstance(far_clip, float): + far_clip = tf.constant(batch_size * [far_clip], dtype=tf.float32) + elif not far_clip.get_shape().as_list(): + far_clip = 
tf.tile(tf.expand_dims(far_clip, 0), [batch_size]) + elif far_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Far_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if specular_colors is not None and shininess_coefficients is None: + raise ValueError( + 'Specular colors were supplied without shininess coefficients.') + if shininess_coefficients is not None and specular_colors is None: + raise ValueError( + 'Shininess coefficients were supplied without specular colors.') + if specular_colors is not None: + # Since a 0-D float32 tensor is accepted, also accept a float. + if isinstance(shininess_coefficients, float): + shininess_coefficients = tf.constant( + shininess_coefficients, dtype=tf.float32) + if len(specular_colors.shape) != 3: + raise ValueError('The specular colors must have shape [batch_size, ' + 'vertex_count, 3].') + if len(shininess_coefficients.shape) > 2: + raise ValueError('The shininess coefficients must have shape at most' + '[batch_size, vertex_count].') + # If we don't have per-vertex coefficients, we can just reshape the + # input shininess to broadcast later, rather than interpolating an + # additional vertex attribute: + if len(shininess_coefficients.shape) < 2: + vertex_attributes = tf.concat( + [normals, vertices, diffuse_colors, specular_colors], axis=2) + else: + vertex_attributes = tf.concat( + [ + normals, vertices, diffuse_colors, specular_colors, + tf.expand_dims(shininess_coefficients, axis=2) + ], + axis=2) + else: + vertex_attributes = tf.concat([normals, vertices, diffuse_colors], axis=2) + + camera_matrices = camera_utils.look_at(camera_position, camera_lookat, + camera_up) + + perspective_transforms = camera_utils.perspective(image_width / image_height, + fov_y, near_clip, far_clip) + + clip_space_transforms = tf.matmul(perspective_transforms, camera_matrices) + + pixel_attributes = rasterize_triangles.rasterize( + vertices, vertex_attributes, triangles, clip_space_transforms, + image_width, image_height, [-1] * vertex_attributes.shape[2].value) + + # Extract the interpolated vertex attributes from the pixel buffer and + # supply them to the shader: + pixel_normals = tf.nn.l2_normalize(pixel_attributes[:, :, :, 0:3], axis=3) + pixel_positions = pixel_attributes[:, :, :, 3:6] + diffuse_colors = pixel_attributes[:, :, :, 6:9] + if specular_colors is not None: + specular_colors = pixel_attributes[:, :, :, 9:12] + # Retrieve the interpolated shininess coefficients if necessary, or just + # reshape our input for broadcasting: + if len(shininess_coefficients.shape) == 2: + shininess_coefficients = pixel_attributes[:, :, :, 12] + else: + shininess_coefficients = tf.reshape(shininess_coefficients, [-1, 1, 1]) + + pixel_mask = tf.cast(tf.reduce_any(diffuse_colors >= 0, axis=3), tf.float32) + + renders = phong_shader( + normals=pixel_normals, + alphas=pixel_mask, + pixel_positions=pixel_positions, + light_positions=light_positions, + light_intensities=light_intensities, + diffuse_colors=diffuse_colors, + camera_position=camera_position if specular_colors is not None else None, + specular_colors=specular_colors, + shininess_coefficients=shininess_coefficients, + ambient_color=ambient_color) + return renders diff --git a/Audio/code/mesh_renderer/mesh_renderer_test.py b/Audio/code/mesh_renderer/mesh_renderer_test.py new file mode 100644 index 0000000..930305a --- /dev/null +++ b/Audio/code/mesh_renderer/mesh_renderer_test.py @@ -0,0 +1,317 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import os + +import numpy as np +import tensorflow as tf + +import camera_utils +import mesh_renderer +import test_utils + + +class RenderTest(tf.test.TestCase): + + def setUp(self): + self.test_data_directory = ( + 'mesh_renderer/test_data/') + + tf.reset_default_graph() + # Set up a basic cube centered at the origin, with vertex normals pointing + # outwards along the line from the origin to the cube vertices: + self.cube_vertices = tf.constant( + [[-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1], [1, -1, 1], + [1, -1, -1], [1, 1, -1], [1, 1, 1]], + dtype=tf.float32) + self.cube_normals = tf.nn.l2_normalize(self.cube_vertices, dim=1) + self.cube_triangles = tf.constant( + [[0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3], [7, 6, 5], [5, 4, 7], + [4, 5, 1], [1, 0, 4], [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7]], + dtype=tf.int32) + + def testRendersSimpleCube(self): + """Renders a simple cube to test the full forward pass. + + Verifies the functionality of both the custom kernel and the python wrapper. + """ + + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant(2 * [[0.0, 0.0, 6.0]], dtype=tf.float32) + center = tf.constant(2 * [[0.0, 0.0, 0.0]], dtype=tf.float32) + world_up = tf.constant(2 * [[0.0, 1.0, 0.0]], dtype=tf.float32) + image_width = 640 + image_height = 480 + light_positions = tf.constant([[[0.0, 0.0, 6.0]], [[0.0, 0.0, 6.0]]]) + light_intensities = tf.ones([2, 1, 3], dtype=tf.float32) + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + + rendered = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + + with self.test_session() as sess: + images = sess.run(rendered, feed_dict={}) + for image_id in range(images.shape[0]): + target_image_name = 'Gray_Cube_%i.png' % image_id + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, images[image_id, :, :, :]) + + def testComplexShading(self): + """Tests specular highlights, colors, and multiple lights per image.""" + # rotate the cube for the test: + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + 
model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant([[0.0, 0.0, 6.0], [0., 0.2, 18.0]], dtype=tf.float32) + center = tf.constant([[0.0, 0.0, 0.0], [0.1, -0.1, 0.1]], dtype=tf.float32) + world_up = tf.constant( + [[0.0, 1.0, 0.0], [0.1, 1.0, 0.15]], dtype=tf.float32) + fov_y = tf.constant([40., 13.3], dtype=tf.float32) + near_clip = tf.constant(0.1, dtype=tf.float32) + far_clip = tf.constant(25.0, dtype=tf.float32) + image_width = 640 + image_height = 480 + light_positions = tf.constant([[[0.0, 0.0, 6.0], [1.0, 2.0, 6.0]], + [[0.0, -2.0, 4.0], [1.0, 3.0, 4.0]]]) + light_intensities = tf.constant( + [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[2.0, 0.0, 1.0], [0.0, 2.0, + 1.0]]], + dtype=tf.float32) + # pyformat: disable + vertex_diffuse_colors = tf.constant(2*[[[1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + [0.5, 0.5, 0.5]]], + dtype=tf.float32) + vertex_specular_colors = tf.constant(2*[[[0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + [0.5, 0.5, 0.5], + [1.0, 0.0, 0.0]]], + dtype=tf.float32) + # pyformat: enable + shininess_coefficients = 6.0 * tf.ones([2, 8], dtype=tf.float32) + ambient_color = tf.constant( + [[0., 0., 0.], [0.1, 0.1, 0.2]], dtype=tf.float32) + renders = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height, vertex_specular_colors, + shininess_coefficients, ambient_color, fov_y, near_clip, far_clip) + tonemapped_renders = tf.concat( + [ + mesh_renderer.tone_mapper(renders[:, :, :, 0:3], 0.7), + renders[:, :, :, 3:4] + ], + axis=3) + + # Check that shininess coefficient broadcasting works by also rendering + # with a scalar shininess coefficient, and ensuring the result is identical: + broadcasted_renders = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height, vertex_specular_colors, + 6.0, ambient_color, fov_y, near_clip, far_clip) + tonemapped_broadcasted_renders = tf.concat( + [ + mesh_renderer.tone_mapper(broadcasted_renders[:, :, :, 0:3], 0.7), + broadcasted_renders[:, :, :, 3:4] + ], + axis=3) + + with self.test_session() as sess: + images, broadcasted_images = sess.run( + [tonemapped_renders, tonemapped_broadcasted_renders], feed_dict={}) + + for image_id in range(images.shape[0]): + target_image_name = 'Colored_Cube_%i.png' % image_id + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, images[image_id, :, :, :]) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, + broadcasted_images[image_id, :, :, :]) + + def testFullRenderGradientComputation(self): + """Verifies the Jacobian matrix for the entire renderer. + + This ensures correct gradients are propagated backwards through the entire + process, not just through the rasterization kernel. Uses the simple cube + forward pass. 
+ """ + image_height = 21 + image_width = 28 + + # rotate the cube for the test: + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant([0.0, 0.0, 6.0], dtype=tf.float32) + center = tf.constant([0.0, 0.0, 0.0], dtype=tf.float32) + world_up = tf.constant([0.0, 1.0, 0.0], dtype=tf.float32) + + # Scene has a single light from the viewer's eye. + light_positions = tf.expand_dims(tf.stack([eye, eye], axis=0), axis=1) + light_intensities = tf.ones([2, 1, 3], dtype=tf.float32) + + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + + rendered = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + + with self.test_session(): + theoretical, numerical = tf.test.compute_gradient( + self.cube_vertices, (8, 3), + rendered, (2, image_height, image_width, 4), + x_init_value=self.cube_vertices.eval(), + delta=1e-3) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.01)) + self.assertTrue(jacobians_match, message) + + def testThatCubeRotates(self): + """Optimize a simple cube's rotation using pixel loss. + + The rotation is represented as static-basis euler angles. This test checks + that the computed gradients are useful. + """ + image_height = 480 + image_width = 640 + initial_euler_angles = [[0.0, 0.0, 0.0]] + + euler_angles = tf.Variable(initial_euler_angles) + model_rotation = camera_utils.euler_matrices(euler_angles)[0, :3, :3] + + vertices_world_space = tf.reshape( + tf.matmul(self.cube_vertices, model_rotation, transpose_b=True), + [1, 8, 3]) + + normals_world_space = tf.reshape( + tf.matmul(self.cube_normals, model_rotation, transpose_b=True), + [1, 8, 3]) + + # camera position: + eye = tf.constant([[0.0, 0.0, 6.0]], dtype=tf.float32) + center = tf.constant([[0.0, 0.0, 0.0]], dtype=tf.float32) + world_up = tf.constant([[0.0, 1.0, 0.0]], dtype=tf.float32) + + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + light_positions = tf.reshape(eye, [1, 1, 3]) + light_intensities = tf.ones([1, 1, 3], dtype=tf.float32) + + render = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + render = tf.reshape(render, [image_height, image_width, 4]) + + # Pick the desired cube rotation for the test: + test_model_rotation = camera_utils.euler_matrices([[-20.0, 0.0, + 60.0]])[0, :3, :3] + + desired_vertex_positions = tf.reshape( + tf.matmul(self.cube_vertices, test_model_rotation, transpose_b=True), + [1, 8, 3]) + desired_normals = tf.reshape( + tf.matmul(self.cube_normals, test_model_rotation, transpose_b=True), + [1, 8, 3]) + desired_render = mesh_renderer.mesh_renderer( + desired_vertex_positions, self.cube_triangles, desired_normals, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + desired_render = tf.reshape(desired_render, [image_height, image_width, 4]) + + loss = 
tf.reduce_mean(tf.abs(render - desired_render)) + optimizer = tf.train.MomentumOptimizer(0.7, 0.1) + grad = tf.gradients(loss, [euler_angles]) + grad, _ = tf.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients([(grad[0], euler_angles)]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(35): + sess.run([loss, opt_func]) + final_image, desired_image = sess.run([render, desired_render]) + + target_image_name = 'Gray_Cube_0.png' + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, desired_image) + test_utils.expect_image_file_and_render_are_near( + self, + sess, + baseline_image_path, + final_image, + max_outlier_fraction=0.01, + pixel_error_threshold=0.04) + + +if __name__ == '__main__': + tf.test.main() diff --git a/Audio/code/mesh_renderer/rasterize_triangles.py b/Audio/code/mesh_renderer/rasterize_triangles.py new file mode 100644 index 0000000..ac8d106 --- /dev/null +++ b/Audio/code/mesh_renderer/rasterize_triangles.py @@ -0,0 +1,178 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Differentiable triangle rasterizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + +from . import camera_utils + +rasterize_triangles_module = tf.load_op_library( + #os.path.join(os.environ['TEST_SRCDIR'], + os.path.join('/home4/yiran/TalkingFace/Pipeline/Deep3DFaceReconstruction', + 'tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so')) + + +def rasterize(world_space_vertices, attributes, triangles, camera_matrices, + image_width, image_height, background_value): + """Rasterizes a mesh and computes interpolated vertex attributes. + + Applies projection matrices and then calls rasterize_clip_space(). + + Args: + world_space_vertices: 3-D float32 tensor of xyz positions with shape + [batch_size, vertex_count, 3]. + attributes: 3-D float32 tensor with shape [batch_size, vertex_count, + attribute_count]. Each vertex attribute is interpolated across the + triangle using barycentric interpolation. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + camera_matrices: 3-D float tensor with shape [batch_size, 4, 4] containing + model-view-perspective projection matrices. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + background_value: a 1-D float32 tensor with shape [attribute_count]. Pixels + that lie outside all triangles take this value. 
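+
+  Example:
+    An illustrative sketch of a typical call (an assumption about usage, not
+    code from this repository); shapes follow the argument descriptions above:
+
+      # world_space_vertices: [batch_size, vertex_count, 3]
+      # attributes:           [batch_size, vertex_count, attribute_count]
+      # triangles:            [triangle_count, 3]
+      # camera_matrices:      [batch_size, 4, 4], e.g. a perspective transform
+      #                       composed with a look_at matrix from camera_utils.
+      images = rasterize(world_space_vertices, attributes, triangles,
+                         camera_matrices, image_width=640, image_height=480,
+                         background_value=[0.0] * attribute_count)
+      # images: [batch_size, 480, 640, attribute_count]
+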
+ + Returns: + A 4-D float32 tensor with shape [batch_size, image_height, image_width, + attribute_count], containing the interpolated vertex attributes at + each pixel. + + Raises: + ValueError: An invalid argument to the method is detected. + """ + clip_space_vertices = camera_utils.transform_homogeneous( + camera_matrices, world_space_vertices) + return rasterize_clip_space(clip_space_vertices, attributes, triangles, + image_width, image_height, background_value) + + +def rasterize_clip_space(clip_space_vertices, attributes, triangles, + image_width, image_height, background_value): + """Rasterizes the input mesh expressed in clip-space (xyzw) coordinates. + + Interpolates vertex attributes using perspective-correct interpolation and + clips triangles that lie outside the viewing frustum. + + Args: + clip_space_vertices: 3-D float32 tensor of homogeneous vertices (xyzw) with + shape [batch_size, vertex_count, 4]. + attributes: 3-D float32 tensor with shape [batch_size, vertex_count, + attribute_count]. Each vertex attribute is interpolated across the + triangle using barycentric interpolation. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + background_value: a 1-D float32 tensor with shape [attribute_count]. Pixels + that lie outside all triangles take this value. + + Returns: + A 4-D float32 tensor with shape [batch_size, image_height, image_width, + attribute_count], containing the interpolated vertex attributes at + each pixel. + + Raises: + ValueError: An invalid argument to the method is detected.
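+
+  Example:
+    A minimal sketch (assuming the clip-space vertices are produced the same
+    way rasterize() above produces them):
+
+      clip_space_vertices = camera_utils.transform_homogeneous(
+          camera_matrices, world_space_vertices)
+      images = rasterize_clip_space(clip_space_vertices, attributes, triangles,
+                                    image_width, image_height,
+                                    background_value)
+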
+ """ + if not image_width > 0: + raise ValueError('Image width must be > 0.') + if not image_height > 0: + raise ValueError('Image height must be > 0.') + if len(clip_space_vertices.shape) != 3: + raise ValueError('The vertex buffer must be 3D.') + + vertex_count = clip_space_vertices.shape[1].value + + batch_size = tf.shape(clip_space_vertices)[0] + + per_image_barycentric_coordinates = tf.TensorArray(dtype=tf.float32, + size=batch_size) + per_image_vertex_ids = tf.TensorArray(dtype=tf.int32, size=batch_size) + + def batch_loop_condition(b, *args): + return b < batch_size + + def batch_loop_iteration(b, per_image_barycentric_coordinates, + per_image_vertex_ids): + barycentric_coords, triangle_ids, _ = ( + rasterize_triangles_module.rasterize_triangles( + clip_space_vertices[b, :, :], triangles, image_width, + image_height)) + per_image_barycentric_coordinates = \ + per_image_barycentric_coordinates.write( + b, tf.reshape(barycentric_coords, [-1, 3])) + + vertex_ids = tf.gather(triangles, tf.reshape(triangle_ids, [-1])) + reindexed_ids = tf.add(vertex_ids, b * clip_space_vertices.shape[1].value) + per_image_vertex_ids = per_image_vertex_ids.write(b, reindexed_ids) + + return b+1, per_image_barycentric_coordinates, per_image_vertex_ids + + b = tf.constant(0) + _, per_image_barycentric_coordinates, per_image_vertex_ids = tf.while_loop( + batch_loop_condition, batch_loop_iteration, + [b, per_image_barycentric_coordinates, per_image_vertex_ids]) + + barycentric_coordinates = tf.reshape( + per_image_barycentric_coordinates.stack(), [-1, 3]) + vertex_ids = tf.reshape(per_image_vertex_ids.stack(), [-1, 3]) + + # Indexes with each pixel's clip-space triangle's extrema (the pixel's + # 'corner points') ids to get the relevant properties for deferred shading. + flattened_vertex_attributes = tf.reshape(attributes, + [batch_size * vertex_count, -1]) + corner_attributes = tf.gather(flattened_vertex_attributes, vertex_ids) + + # Computes the pixel attributes by interpolating the known attributes at the + # corner points of the triangle interpolated with the barycentric coordinates. + weighted_vertex_attributes = tf.multiply( + corner_attributes, tf.expand_dims(barycentric_coordinates, axis=2)) + summed_attributes = tf.reduce_sum(weighted_vertex_attributes, axis=1) + attribute_images = tf.reshape(summed_attributes, + [batch_size, image_height, image_width, -1]) + + # Barycentric coordinates should approximately sum to one where there is + # rendered geometry, but be exactly zero where there is not. + alphas = tf.clip_by_value( + tf.reduce_sum(2.0 * barycentric_coordinates, axis=1), 0.0, 1.0) + alphas = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + + attributes_with_background = ( + alphas * attribute_images + (1.0 - alphas) * background_value) + + return attributes_with_background + + +@tf.RegisterGradient('RasterizeTriangles') +def _rasterize_triangles_grad(op, df_dbarys, df_dids, df_dz): + # Gradients are only supported for barycentric coordinates. Gradients for the + # z-buffer are not currently implemented. If you need gradients w.r.t. z, + # include z as a vertex attribute when calling rasterize_triangles. 
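+  # Illustrative sketch (an assumption, not part of the original kernel API):
+  # a caller that needs depth gradients could append z as one more vertex
+  # attribute before rasterizing, e.g.
+  #   attributes_with_z = tf.concat(
+  #       [attributes, clip_space_vertices[:, :, 2:3]], axis=2)
+  # and then read the interpolated z channel back out of the attribute image.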
+ del df_dids, df_dz + return rasterize_triangles_module.rasterize_triangles_grad( + op.inputs[0], op.inputs[1], op.outputs[0], op.outputs[1], df_dbarys, + op.get_attr('image_width'), op.get_attr('image_height')), None diff --git a/Audio/code/mesh_renderer/rasterize_triangles_test.py b/Audio/code/mesh_renderer/rasterize_triangles_test.py new file mode 100644 index 0000000..ccd7e7c --- /dev/null +++ b/Audio/code/mesh_renderer/rasterize_triangles_test.py @@ -0,0 +1,196 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +import tensorflow as tf + +import test_utils +import camera_utils +import rasterize_triangles + + +class RenderTest(tf.test.TestCase): + + def setUp(self): + self.test_data_directory = 'mesh_renderer/test_data/' + + tf.reset_default_graph() + self.cube_vertex_positions = tf.constant( + [[-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1], [1, -1, 1], + [1, -1, -1], [1, 1, -1], [1, 1, 1]], + dtype=tf.float32) + self.cube_triangles = tf.constant( + [[0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3], [7, 6, 5], [5, 4, 7], + [4, 5, 1], [1, 0, 4], [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7]], + dtype=tf.int32) + + self.tf_float = lambda x: tf.constant(x, dtype=tf.float32) + + self.image_width = 640 + self.image_height = 480 + + self.perspective = camera_utils.perspective( + self.image_width / self.image_height, + self.tf_float([40.0]), self.tf_float([0.01]), + self.tf_float([10.0])) + + def runTriangleTest(self, w_vector, target_image_name): + """Directly renders a rasterized triangle's barycentric coordinates. + + Tests only the kernel (rasterize_triangles_module). + + Args: + w_vector: 3 element vector of w components to scale triangle vertices. + target_image_name: image file name to compare result against. 
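+
+    For example, w_vector = (1.0, 1.0, 1.0) leaves the clip-space triangle
+    unchanged, while w_vector = (0.2, 0.5, 2.0) exercises perspective-correct
+    interpolation; both cases appear in the tests below.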
+ """ + clip_init = np.array( + [[-0.5, -0.5, 0.8, 1.0], [0.0, 0.5, 0.3, 1.0], [0.5, -0.5, 0.3, 1.0]], + dtype=np.float32) + clip_init = clip_init * np.reshape( + np.array(w_vector, dtype=np.float32), [3, 1]) + + clip_coordinates = tf.constant(clip_init) + triangles = tf.constant([[0, 1, 2]], dtype=tf.int32) + + rendered_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, triangles, self.image_width, self.image_height)) + rendered_coordinates = tf.concat( + [rendered_coordinates, + tf.ones([self.image_height, self.image_width, 1])], axis=2) + with self.test_session() as sess: + image = rendered_coordinates.eval() + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, image) + + def testRendersSimpleTriangle(self): + self.runTriangleTest((1.0, 1.0, 1.0), 'Simple_Triangle.png') + + def testRendersPerspectiveCorrectTriangle(self): + self.runTriangleTest((0.2, 0.5, 2.0), 'Perspective_Corrected_Triangle.png') + + def testRendersTwoCubesInBatch(self): + """Renders a simple cube in two viewpoints to test the python wrapper.""" + + vertex_rgb = (self.cube_vertex_positions * 0.5 + 0.5) + vertex_rgba = tf.concat([vertex_rgb, tf.ones([8, 1])], axis=1) + + center = self.tf_float([[0.0, 0.0, 0.0]]) + world_up = self.tf_float([[0.0, 1.0, 0.0]]) + look_at_1 = camera_utils.look_at(self.tf_float([[2.0, 3.0, 6.0]]), + center, world_up) + look_at_2 = camera_utils.look_at(self.tf_float([[-3.0, 1.0, 6.0]]), + center, world_up) + projection_1 = tf.matmul(self.perspective, look_at_1) + projection_2 = tf.matmul(self.perspective, look_at_2) + projection = tf.concat([projection_1, projection_2], axis=0) + background_value = [0.0, 0.0, 0.0, 0.0] + + rendered = rasterize_triangles.rasterize( + tf.stack([self.cube_vertex_positions, self.cube_vertex_positions]), + tf.stack([vertex_rgba, vertex_rgba]), self.cube_triangles, projection, + self.image_width, self.image_height, background_value) + + with self.test_session() as sess: + images = sess.run(rendered, feed_dict={}) + for i in (0, 1): + image = images[i, :, :, :] + baseline_image_name = 'Unlit_Cube_{}.png'.format(i) + baseline_image_path = os.path.join(self.test_data_directory, + baseline_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, image) + + def testSimpleTriangleGradientComputation(self): + """Verifies the Jacobian matrix for a single pixel. + + The pixel is in the center of a triangle facing the camera. This makes it + easy to check which entries of the Jacobian might not make sense without + worrying about corner cases. 
+ """ + test_pixel_x = 325 + test_pixel_y = 245 + + clip_coordinates = tf.placeholder(tf.float32, shape=[3, 4]) + + triangles = tf.constant([[0, 1, 2]], dtype=tf.int32) + + barycentric_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, triangles, self.image_width, self.image_height)) + + pixels_to_compare = barycentric_coordinates[ + test_pixel_y:test_pixel_y + 1, test_pixel_x:test_pixel_x + 1, :] + + with self.test_session(): + ndc_init = np.array( + [[-0.5, -0.5, 0.8, 1.0], [0.0, 0.5, 0.3, 1.0], [0.5, -0.5, 0.3, 1.0]], + dtype=np.float32) + theoretical, numerical = tf.test.compute_gradient( + clip_coordinates, (3, 4), + pixels_to_compare, (1, 1, 3), + x_init_value=ndc_init, + delta=4e-2) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.0, True)) + self.assertTrue(jacobians_match, message) + + def testInternalRenderGradientComputation(self): + """Isolates and verifies the Jacobian matrix for the custom kernel.""" + image_height = 21 + image_width = 28 + + clip_coordinates = tf.placeholder(tf.float32, shape=[8, 4]) + + barycentric_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, self.cube_triangles, image_width, image_height)) + + with self.test_session(): + # Precomputed transformation of the simple cube to normalized device + # coordinates, in order to isolate the rasterization gradient. + # pyformat: disable + ndc_init = np.array( + [[-0.43889722, -0.53184521, 0.85293502, 1.0], + [-0.37635487, 0.22206162, 0.90555805, 1.0], + [-0.22849123, 0.76811147, 0.80993629, 1.0], + [-0.2805393, -0.14092168, 0.71602166, 1.0], + [0.18631913, -0.62634289, 0.88603103, 1.0], + [0.16183566, 0.08129397, 0.93020856, 1.0], + [0.44147962, 0.53497446, 0.85076219, 1.0], + [0.53008741, -0.31276882, 0.77620775, 1.0]], + dtype=np.float32) + # pyformat: enable + theoretical, numerical = tf.test.compute_gradient( + clip_coordinates, (8, 4), + barycentric_coordinates, (image_height, image_width, 3), + x_init_value=ndc_init, + delta=4e-2) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.01)) + self.assertTrue(jacobians_match, message) + + +if __name__ == '__main__': + tf.test.main() diff --git a/Audio/code/mesh_renderer/test_utils.py b/Audio/code/mesh_renderer/test_utils.py new file mode 100644 index 0000000..6c0b46e --- /dev/null +++ b/Audio/code/mesh_renderer/test_utils.py @@ -0,0 +1,124 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Common functions for the rasterizer and mesh renderer tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import tensorflow as tf + + +def check_jacobians_are_nearly_equal(theoretical, + numerical, + outlier_relative_error_threshold, + max_outlier_fraction, + include_jacobians_in_error_message=False): + """Compares two Jacobian matrices, allowing for some fraction of outliers. + + Args: + theoretical: 2D numpy array containing a Jacobian matrix with entries + computed via gradient functions. The layout should be as in the output + of gradient_checker. + numerical: 2D numpy array of the same shape as theoretical containing a + Jacobian matrix with entries computed via finite difference + approximations. The layout should be as in the output + of gradient_checker. + outlier_relative_error_threshold: float prescribing the maximum relative + error (from the finite difference approximation) is tolerated before + and entry is considered an outlier. + max_outlier_fraction: float defining the maximum fraction of entries in + theoretical that may be outliers before the check returns False. + include_jacobians_in_error_message: bool defining whether the jacobian + matrices should be included in the return message should the test fail. + + Returns: + A tuple where the first entry is a boolean describing whether + max_outlier_fraction was exceeded, and where the second entry is a string + containing an error message if one is relevant. + """ + outlier_gradients = np.abs( + numerical - theoretical) / numerical > outlier_relative_error_threshold + outlier_fraction = np.count_nonzero(outlier_gradients) / np.prod( + numerical.shape[:2]) + jacobians_match = outlier_fraction <= max_outlier_fraction + + message = ( + ' %f of theoretical gradients are relative outliers, but the maximum' + ' allowable fraction is %f ' % (outlier_fraction, max_outlier_fraction)) + if include_jacobians_in_error_message: + # the gradient_checker convention is the typical Jacobian transposed: + message += ('\nNumerical Jacobian:\n%s\nTheoretical Jacobian:\n%s' % + (repr(numerical.T), repr(theoretical.T))) + return jacobians_match, message + + +def expect_image_file_and_render_are_near(test_instance, + sess, + baseline_path, + result_image, + max_outlier_fraction=0.001, + pixel_error_threshold=0.01): + """Compares the output of mesh_renderer with an image on disk. + + The comparison is soft: the images are considered identical if at most + max_outlier_fraction of the pixels differ by more than a relative error of + pixel_error_threshold of the full color value. Note that before comparison, + mesh renderer values are clipped to the range [0,1]. + + Uses _images_are_near for the actual comparison. + + Args: + test_instance: a python unit test instance. + sess: a TensorFlow session for decoding the png. + baseline_path: path to the reference image on disk. + result_image: the result image, as a numpy array. + max_outlier_fraction: the maximum fraction of outlier pixels allowed. + pixel_error_threshold: pixel values are considered to differ if their + difference exceeds this amount. Range is 0.0 - 1.0. + """ + baseline_bytes = open(baseline_path, 'rb').read() + baseline_image = sess.run(tf.image.decode_png(baseline_bytes)) + + test_instance.assertEqual(baseline_image.shape, result_image.shape, + 'Image shapes %s and %s do not match.' 
% + (baseline_image.shape, result_image.shape)) + + result_image = np.clip(result_image, 0., 1.).copy(order='C') + baseline_image = baseline_image.astype(float) / 255.0 + + outlier_channels = (np.abs(baseline_image - result_image) > + pixel_error_threshold) + outlier_pixels = np.any(outlier_channels, axis=2) + outlier_count = np.count_nonzero(outlier_pixels) + outlier_fraction = outlier_count / np.prod(baseline_image.shape[:2]) + images_match = outlier_fraction <= max_outlier_fraction + + outputs_dir = "/tmp" #os.environ["TEST_TMPDIR"] + base_prefix = os.path.splitext(os.path.basename(baseline_path))[0] + result_output_path = os.path.join(outputs_dir, base_prefix + "_result.png") + + message = ('%s does not match. (%f of pixels are outliers, %f is allowed.). ' + 'Result image written to %s' % + (baseline_path, outlier_fraction, max_outlier_fraction, result_output_path)) + + if not images_match: + result_bytes = sess.run(tf.image.encode_png(result_image*255.0)) + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + test_instance.assertTrue(images_match, msg=message) diff --git a/Audio/code/models.py b/Audio/code/models.py new file mode 100644 index 0000000..2811a41 --- /dev/null +++ b/Audio/code/models.py @@ -0,0 +1,387 @@ +import torch +import torch.nn as nn +# from pts3d import * +from ops import * +import torchvision.models as models +import functools +from torch.autograd import Variable +import torch.nn.functional as F +from torch.nn import init +import numpy as np +from convolutional_rnn import Conv2dGRU + +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class AT_net(nn.Module): + def __init__(self): + super(AT_net, self).__init__() + self.lmark_encoder = nn.Sequential( + nn.Linear(6,256), + nn.ReLU(True), + nn.Linear(256,512), + nn.ReLU(True), + + ) + self.audio_eocder = nn.Sequential( + conv2d(1,64,3,1,1), + conv2d(64,128,3,1,1), + nn.MaxPool2d(3, stride=(1,2)), + conv2d(128,256,3,1,1), + conv2d(256,256,3,1,1), + conv2d(256,512,3,1,1), + nn.MaxPool2d(3, stride=(2,2)) + ) + self.audio_eocder_fc = nn.Sequential( + nn.Linear(1024 *12,2048), + nn.ReLU(True), + nn.Linear(2048,256), + nn.ReLU(True), + + ) + self.lstm = nn.LSTM(256*3,256,3,batch_first = True) + self.lstm_fc = nn.Sequential( + nn.Linear(256,6), + ) + + def forward(self, example_landmark, audio): + hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), + torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) + example_landmark_f = self.lmark_encoder(example_landmark) + #print 'example_landmark_f', example_landmark_f.shape # (1,512) + lstm_input = [] + for step_t in range(audio.size(1)): + current_audio = audio[ : ,step_t , :, :].unsqueeze(1) + current_feature = self.audio_eocder(current_audio) + current_feature = current_feature.view(current_feature.size(0), -1) + current_feature = self.audio_eocder_fc(current_feature) + features = torch.cat([example_landmark_f, current_feature], 1) + #print 'current_feature', current_feature.shape # (1,256) + #print 'features', features.shape # (1,768) + lstm_input.append(features) + lstm_input = torch.stack(lstm_input, dim = 1) + lstm_out, _ = self.lstm(lstm_input, hidden) + fc_out = [] + for step_t in range(audio.size(1)): + fc_in = lstm_out[:,step_t,:] + fc_out.append(self.lstm_fc(fc_in)) + return torch.stack(fc_out, dim = 1) + +class ATC_net(nn.Module): + def __init__(self, para_dim): + super(ATC_net, self).__init__() + self.audio_eocder = nn.Sequential( + 
conv2d(1,64,3,1,1), + conv2d(64,128,3,1,1), + nn.MaxPool2d(3, stride=(1,2)), + conv2d(128,256,3,1,1), + conv2d(256,256,3,1,1), + conv2d(256,512,3,1,1), + nn.MaxPool2d(3, stride=(2,2)) + ) + self.audio_eocder_fc = nn.Sequential( + nn.Linear(1024 *12,2048), + nn.ReLU(True), + nn.Linear(2048,256), + nn.ReLU(True), + + ) + self.lstm = nn.LSTM(256,256,3,batch_first = True) + self.lstm_fc = nn.Sequential( + nn.Linear(256,para_dim), + ) + + def forward(self, audio): + hidden = ( torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda()), + torch.autograd.Variable(torch.zeros(3, audio.size(0), 256).cuda())) + lstm_input = [] + for step_t in range(audio.size(1)): + current_audio = audio[ : ,step_t , :, :].unsqueeze(1) + current_feature = self.audio_eocder(current_audio) + current_feature = current_feature.view(current_feature.size(0), -1) + current_feature = self.audio_eocder_fc(current_feature) + lstm_input.append(current_feature) + lstm_input = torch.stack(lstm_input, dim = 1) + lstm_out, _ = self.lstm(lstm_input, hidden) # output, (hn,cn) = LSTM(input, (h0,c0)) + fc_out = [] + for step_t in range(audio.size(1)): + fc_in = lstm_out[:,step_t,:] + fc_out.append(self.lstm_fc(fc_in)) + return torch.stack(fc_out, dim = 1) + + +class AT_single(nn.Module): + def __init__(self): + super(AT_single, self).__init__() + # self.lmark_encoder = nn.Sequential( + # nn.Linear(6,256), + # nn.ReLU(True), + # nn.Linear(256,512), + # nn.ReLU(True), + + # ) + self.audio_eocder = nn.Sequential( + conv2d(1,64,3,1,1,normalizer = None), + conv2d(64,128,3,1,1,normalizer = None), + nn.MaxPool2d(3, stride=(1,2)), + conv2d(128,256,3,1,1,normalizer = None), + conv2d(256,256,3,1,1,normalizer = None), + conv2d(256,512,3,1,1,normalizer = None), + nn.MaxPool2d(3, stride=(2,2)) + ) + self.audio_eocder_fc = nn.Sequential( + nn.Linear(1024 *12,2048), + nn.ReLU(True), + nn.Linear(2048,256), + nn.ReLU(True), + nn.Linear(256, 6) + ) + # self.fusion = nn.Sequential( + # nn.Linear(256 *3, 256), + # nn.ReLU(True), + # nn.Linear(256, 6) + # ) + + def forward(self, audio): + current_audio = audio.unsqueeze(1) + current_feature = self.audio_eocder(current_audio) + current_feature = current_feature.view(current_feature.size(0), -1) + + output = self.audio_eocder_fc(current_feature) + + return output + + +class GL_Discriminator(nn.Module): + + + def __init__(self): + super(GL_Discriminator, self).__init__() + + self.image_encoder_dis = nn.Sequential( + conv2d(3,64,3,2, 1,normalizer=None), + # conv2d(64, 64, 4, 2, 1), + conv2d(64, 128, 3, 2, 1), + + conv2d(128, 256, 3, 2, 1), + + conv2d(256, 512, 3, 2, 1), + ) + self.encoder = nn.Sequential( + nn.Linear(136, 256), + nn.ReLU(True), + nn.Linear(256, 512), + nn.ReLU(True), + ) + self.decoder = nn.Sequential( + nn.Linear(1024, 512), + nn.ReLU(True), + nn.Linear(512, 136), + nn.Tanh() + ) + self.img_fc = nn.Sequential( + nn.Linear(512*8*8, 512), + nn.ReLU(True), + ) + + self.lstm = nn.LSTM(1024,256,3,batch_first = True) + self.lstm_fc = nn.Sequential( + nn.Linear(256,136), + nn.Tanh()) + self.decision = nn.Sequential( + nn.Linear(256,1), + ) + self.aggregator = nn.AvgPool1d(kernel_size = 16) + self.activate = nn.Sigmoid() + def forward(self, xs, example_landmark): + hidden = ( torch.autograd.Variable(torch.zeros(3, example_landmark.size(0), 256).cuda()), + torch.autograd.Variable(torch.zeros(3, example_landmark.size(0), 256).cuda())) + lstm_input = list() + lmark_feature= self.encoder(example_landmark) + for step_t in range(xs.size(1)): + x = xs[:,step_t,:,:, :] + x.data = x.data.contiguous() 
+ x = self.image_encoder_dis(x) + x = x.view(x.size(0), -1) + x = self.img_fc(x) + new_feature = torch.cat([lmark_feature, x], 1) + lstm_input.append(new_feature) + lstm_input = torch.stack(lstm_input, dim = 1) + lstm_out, _ = self.lstm(lstm_input, hidden) + fc_out = [] + decision = [] + for step_t in range(xs.size(1)): + fc_in = lstm_out[:,step_t,:] + decision.append(self.decision(fc_in)) + fc_out.append(self.lstm_fc(fc_in)+ example_landmark) + fc_out = torch.stack(fc_out, dim = 1) + decision = torch.stack(decision, dim = 2) + decision = self.aggregator(decision) + decision = self.activate(decision) + return decision.view(decision.size(0)), fc_out + + + +class VG_net(nn.Module): + def __init__(self,input_nc = 3, output_nc = 3,ngf = 64, use_dropout=True, use_bias=False,norm_layer=nn.BatchNorm2d,n_blocks = 9,padding_type='zero'): + super(VG_net,self).__init__() + dtype = torch.FloatTensor + + + self.image_encoder1 = nn.Sequential( + nn.ReflectionPad2d(3), + conv2d(3, 64, 7,1, 0), + + # conv2d(64,16,3,1,1), + conv2d(64,64,3,2,1), + # conv2d(32,64,3,1,1), + conv2d(64,128,3,2,1) + ) + + self.image_encoder2 = nn.Sequential( + conv2d(128,256,3,2,1), + conv2d(256,512,3,2,1) + ) + + self.landmark_encoder = nn.Sequential( + nn.Linear(136, 64), + nn.ReLU(True) + ) + + self.landmark_encoder_stage2 = nn.Sequential( + conv2d(1,256,3), + + ) + self.lmark_att = nn.Sequential( + nn.ConvTranspose2d(512, 256,kernel_size=3, stride=(2),padding=(1), output_padding=1), + nn.BatchNorm2d(256), + nn.ReLU(True), + nn.ConvTranspose2d(256, 128,kernel_size=3, stride=(2),padding=(1), output_padding=1), + nn.BatchNorm2d(128), + nn.ReLU(True), + conv2d(128, 1,3, activation=nn.Sigmoid, normalizer=None) + ) + self.lmark_feature = nn.Sequential( + conv2d(256,512,3)) + + model = [] + n_downsampling = 4 + mult = 2**(n_downsampling -1 ) + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling ): + mult = 2**(n_downsampling-i-1 ) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=(2), + padding=(1), output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + if i == n_downsampling-3: + self.generator1 = nn.Sequential(*model) + model = [] + + self.base = nn.Sequential(*model) + model = [] + model += [nn.Conv2d(int(ngf/2), output_nc, kernel_size=7, padding=3)] + model += [nn.Tanh()] + self.generator_color = nn.Sequential(*model) + + model = [] + model += [nn.Conv2d(int(ngf/2), 1, kernel_size=7, padding=3)] + model += [nn.Sigmoid()] + self.generator_attention = nn.Sequential(*model) + + self.bottle_neck = nn.Sequential(conv2d(1024,128,3,1,1)) + + + self.convGRU = Conv2dGRU(in_channels = 128, out_channels = 512, kernel_size = (3), num_layers = 1, bidirectional = False, dilation = 2, stride = 1, dropout = 0.5 ) + + def forward(self,image, landmarks, example_landmark ): + # ex_landmark1 = self.landmark_encoder(example_landmark.unsqueeze(2).unsqueeze(3).repeat(1, 1, 128,128)) + image_feature1 = self.image_encoder1(image) + image_feature = self.image_encoder2(image_feature1) + ex_landmark1 = self.landmark_encoder(example_landmark.view(example_landmark.size(0), -1)) + ex_landmark1 = ex_landmark1.view(ex_landmark1.size(0), 1, image_feature.size(2), image_feature.size(3) ) + ex_landmark1 = self.landmark_encoder_stage2(ex_landmark1) + ex_landmark = self.lmark_feature(ex_landmark1) + + lstm_input = list() + lmark_atts = list() + for step_t 
in range(landmarks.size(1)): + landmark = landmarks[:,step_t,:] + landmark.data = landmark.data.contiguous() + landmark = self.landmark_encoder(landmark.view(landmark.size(0), -1)) + landmark = landmark.view(landmark.size(0), 1, image_feature.size(2), image_feature.size(3) ) + landmark = self.landmark_encoder_stage2(landmark) + + lmark_att = self.lmark_att( torch.cat([landmark, ex_landmark1], dim=1)) + landmark = self.lmark_feature(landmark) + + inputs = self.bottle_neck(torch.cat([image_feature, landmark - ex_landmark], dim=1)) + lstm_input.append(inputs) + lmark_atts.append(lmark_att) + lmark_atts =torch.stack(lmark_atts, dim = 1) + lstm_input = torch.stack(lstm_input, dim = 1) + lstm_output, _ = self.convGRU(lstm_input) + + outputs = [] + atts = [] + colors = [] + for step_t in range(landmarks.size(1)): + input_t = lstm_output[:,step_t,:,:,:] + v_feature1 = self.generator1(input_t) + v_feature1_f = image_feature1 * (1- lmark_atts[:,step_t,:,:,:] ) + v_feature1 * lmark_atts[:,step_t,:,:,:] + base = self.base(v_feature1_f) + color = self.generator_color(base) + att = self.generator_attention(base) + atts.append(att) + colors.append(color) + output = att * color + (1 - att ) * image + outputs.append(output) + return torch.stack(outputs, dim = 1), torch.stack(atts, dim = 1), torch.stack(colors, dim = 1), lmark_atts + + +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out \ No newline at end of file diff --git a/Audio/code/ops.py b/Audio/code/ops.py new file mode 100644 index 0000000..e956642 --- /dev/null +++ b/Audio/code/ops.py @@ -0,0 +1,96 @@ +import torch +import torchvision +import torch.nn as nn +import torch.nn.init as init +from torch.autograd import Variable + + +class ResidualBlock(nn.Module): + def __init__(self, channel_in, channel_out): + super(ResidualBlock, self).__init__() + + self.block = nn.Sequential( + conv3d(channel_in, channel_out, 3, 1, 1), + conv3d(channel_out, channel_out, 3, 1, 1, activation=None) + ) + + self.lrelu = nn.ReLU(0.2) + + def forward(self, x): + residual = x + out = self.block(x) + + out += residual + out = self.lrelu(out) + return out + +def linear(channel_in, channel_out, + activation=nn.ReLU, + normalizer=nn.BatchNorm1d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Linear(channel_in, channel_out, 
bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def conv_transpose2d(channel_in, channel_out, + ksize=4, stride=2, padding=1, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.ConvTranspose2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[0].weight) + + return nn.Sequential(*layer) + + +def nn_conv2d(channel_in, channel_out, + ksize=3, stride=1, padding=1, + scale_factor=2, + activation=nn.ReLU, + normalizer=nn.BatchNorm2d): + layer = list() + bias = True if not normalizer else False + + layer.append(nn.UpsamplingNearest2d(scale_factor=scale_factor)) + layer.append(nn.Conv2d(channel_in, channel_out, + ksize, stride, padding, + bias=bias)) + _apply(layer, activation, normalizer, channel_out) + # init.kaiming_normal(layer[1].weight) + + return nn.Sequential(*layer) + + +def _apply(layer, activation, normalizer, channel_out=None): + if normalizer: + layer.append(normalizer(channel_out)) + if activation: + layer.append(activation()) + return layer + diff --git a/Audio/code/reconstruct_mesh.py b/Audio/code/reconstruct_mesh.py new file mode 100644 index 0000000..2b37e30 --- /dev/null +++ b/Audio/code/reconstruct_mesh.py @@ -0,0 +1,285 @@ +import numpy as np + +# input: coeff with shape [1,257] +def Split_coeff(coeff): + id_coeff = coeff[:,:80] # identity(shape) coeff of dim 80 + ex_coeff = coeff[:,80:144] # expression coeff of dim 64 + tex_coeff = coeff[:,144:224] # texture(albedo) coeff of dim 80 + angles = coeff[:,224:227] # ruler angles(x,y,z) for rotation of dim 3 + gamma = coeff[:,227:254] # lighting coeff for 3 channel SH function of dim 27 + translation = coeff[:,254:] # translation coeff of dim 3 + + return id_coeff,ex_coeff,tex_coeff,angles,gamma,translation + + +# compute face shape with identity and expression coeff, based on BFM model +# input: id_coeff with shape [1,80] +# ex_coeff with shape [1,64] +# output: face_shape with shape [1,N,3], N is number of vertices +def Shape_formation(id_coeff,ex_coeff,facemodel): + face_shape = np.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \ + np.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + \ + facemodel.meanshape + + face_shape = np.reshape(face_shape,[1,-1,3]) + # re-center face shape + face_shape = face_shape - np.mean(np.reshape(facemodel.meanshape,[1,-1,3]), axis = 1, keepdims = True) + + return face_shape + +# compute vertex normal using one-ring neighborhood +# input: face_shape with shape [1,N,3] +# output: v_norm with shape [1,N,3] +def Compute_norm(face_shape,facemodel): + + face_id = facemodel.tri # vertex index for each triangle face, with shape [F,3], F is number of faces + point_id = facemodel.point_buf # adjacent face index for each vertex, with shape [N,8], N is number of vertex + shape = face_shape + face_id = (face_id - 1).astype(np.int32) + point_id = (point_id - 1).astype(np.int32) + v1 = shape[:,face_id[:,0],:] + v2 = shape[:,face_id[:,1],:] + v3 = 
shape[:,face_id[:,2],:] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = np.cross(e1,e2) # compute normal for each face + face_norm = np.concatenate([face_norm,np.zeros([1,1,3])], axis = 1) # concat face_normal with a zero vector at the end + v_norm = np.sum(face_norm[:,point_id,:], axis = 2) # compute vertex normal using one-ring neighborhood + v_norm = v_norm/np.expand_dims(np.linalg.norm(v_norm,axis = 2),2) # normalize normal vectors + + return v_norm + +# compute vertex texture(albedo) with tex_coeff +# input: tex_coeff with shape [1,N,3] +# output: face_texture with shape [1,N,3], RGB order, range from 0-255 +def Texture_formation(tex_coeff,facemodel): + + face_texture = np.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex + face_texture = np.reshape(face_texture,[1,-1,3]) + + return face_texture + +# compute rotation matrix based on 3 ruler angles +# input: angles with shape [1,3] +# output: rotation matrix with shape [1,3,3] +def Compute_rotation_matrix(angles): + + angle_x = angles[:,0][0] + angle_y = angles[:,1][0] + angle_z = angles[:,2][0] + + # compute rotation matrix for X,Y,Z axis respectively + rotation_X = np.array([1.0,0,0,\ + 0,np.cos(angle_x),-np.sin(angle_x),\ + 0,np.sin(angle_x),np.cos(angle_x)]) + rotation_Y = np.array([np.cos(angle_y),0,np.sin(angle_y),\ + 0,1,0,\ + -np.sin(angle_y),0,np.cos(angle_y)]) + rotation_Z = np.array([np.cos(angle_z),-np.sin(angle_z),0,\ + np.sin(angle_z),np.cos(angle_z),0,\ + 0,0,1]) + + rotation_X = np.reshape(rotation_X,[1,3,3]) + rotation_Y = np.reshape(rotation_Y,[1,3,3]) + rotation_Z = np.reshape(rotation_Z,[1,3,3]) + + rotation = np.matmul(np.matmul(rotation_Z,rotation_Y),rotation_X) + rotation = np.transpose(rotation, axes = [0,2,1]) #transpose row and column (dimension 1 and 2) + + return rotation + +# project 3D face onto image plane +# input: face_shape with shape [1,N,3] +# rotation with shape [1,3,3] +# translation with shape [1,3] +# output: face_projection with shape [1,N,2] +# z_buffer with shape [1,N,1] +def Projection_layer(face_shape,rotation,translation,focal=1015.0,center=112.0): # we choose the focal length and camera position empirically + + camera_pos = np.reshape(np.array([0.0,0.0,10.0]),[1,1,3]) # camera position + reverse_z = np.reshape(np.array([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]) + + + p_matrix = np.concatenate([[focal],[0.0],[center],[0.0],[focal],[center],[0.0],[0.0],[1.0]],axis = 0) # projection matrix + p_matrix = np.reshape(p_matrix,[1,3,3]) + + # calculate face position in camera space + face_shape_r = np.matmul(face_shape,rotation) + face_shape_t = face_shape_r + np.reshape(translation,[1,1,3]) + face_shape_t = np.matmul(face_shape_t,reverse_z) + camera_pos + + # calculate projection of face vertex using perspective projection + aug_projection = np.matmul(face_shape_t,np.transpose(p_matrix,[0,2,1])) + face_projection = aug_projection[:,:,0:2]/np.reshape(aug_projection[:,:,2],[1,np.shape(aug_projection)[1],1]) + z_buffer = np.reshape(aug_projection[:,:,2],[1,-1,1]) + + return face_projection,z_buffer + +# compute vertex color using face_texture and SH function lighting approximation +# input: face_texture with shape [1,N,3] +# norm with shape [1,N,3] +# gamma with shape [1,27] +# output: face_color with shape [1,N,3], RGB order, range from 0-255 +# lighting with shape [1,N,3], color under uniform texture +def Illumination_layer(face_texture,norm,gamma): + + num_vertex = np.shape(face_texture)[1] + + init_lit = np.array([0.8,0,0,0,0,0,0,0,0]) + gamma = np.reshape(gamma,[-1,3,9]) + gamma = gamma + 
np.reshape(init_lit,[1,1,9]) + + # parameter of 9 SH function + a0 = np.pi + a1 = 2*np.pi/np.sqrt(3.0) + a2 = 2*np.pi/np.sqrt(8.0) + c0 = 1/np.sqrt(4*np.pi) + c1 = np.sqrt(3.0)/np.sqrt(4*np.pi) + c2 = 3*np.sqrt(5.0)/np.sqrt(12*np.pi) + + Y0 = np.tile(np.reshape(a0*c0,[1,1,1]),[1,num_vertex,1]) + Y1 = np.reshape(-a1*c1*norm[:,:,1],[1,num_vertex,1]) + Y2 = np.reshape(a1*c1*norm[:,:,2],[1,num_vertex,1]) + Y3 = np.reshape(-a1*c1*norm[:,:,0],[1,num_vertex,1]) + Y4 = np.reshape(a2*c2*norm[:,:,0]*norm[:,:,1],[1,num_vertex,1]) + Y5 = np.reshape(-a2*c2*norm[:,:,1]*norm[:,:,2],[1,num_vertex,1]) + Y6 = np.reshape(a2*c2*0.5/np.sqrt(3.0)*(3*np.square(norm[:,:,2])-1),[1,num_vertex,1]) + Y7 = np.reshape(-a2*c2*norm[:,:,0]*norm[:,:,2],[1,num_vertex,1]) + Y8 = np.reshape(a2*c2*0.5*(np.square(norm[:,:,0])-np.square(norm[:,:,1])),[1,num_vertex,1]) + + Y = np.concatenate([Y0,Y1,Y2,Y3,Y4,Y5,Y6,Y7,Y8],axis=2) + + # Y shape:[batch,N,9]. + + lit_r = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,0,:],2)),2) #[batch,N,9] * [batch,9,1] = [batch,N] + lit_g = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,1,:],2)),2) + lit_b = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,2,:],2)),2) + + # shape:[batch,N,3] + face_color = np.stack([lit_r*face_texture[:,:,0],lit_g*face_texture[:,:,1],lit_b*face_texture[:,:,2]],axis = 2) + lighting = np.stack([lit_r,lit_g,lit_b],axis = 2)*128 + + return face_color,lighting + +# face reconstruction with coeff and BFM model +def Reconstruction(coeff,facemodel): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + # compute face shape + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + # compute vertex texture(albedo) + face_texture = Texture_formation(tex_coeff, facemodel) + # vertex normal + face_norm = Compute_norm(face_shape,facemodel) + # rotation matrix + rotation = Compute_rotation_matrix(angles) + face_norm_r = np.matmul(face_norm,rotation) + + # compute vertex projection on image plane (with image sized 224*224) + face_projection,z_buffer = Projection_layer(face_shape,rotation,translation) + face_projection = np.stack([face_projection[:,:,0],224 - face_projection[:,:,1]], axis = 2) + + # compute 68 landmark on image plane + landmarks_2d = face_projection[:,facemodel.keypoints,:] + + # compute vertex color using SH function lighting approximation + face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) + + # vertex index for each face of BFM model + tri = facemodel.tri + + return face_shape,face_texture,face_color,tri,face_projection,z_buffer,landmarks_2d + +# def Reconstruction_for_render(coeff,facemodel): +# id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) +# face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) +# face_texture = Texture_formation(tex_coeff, facemodel) +# face_norm = Compute_norm(face_shape,facemodel) +# rotation = Compute_rotation_matrix(angles) +# face_shape_r = np.matmul(face_shape,rotation) +# face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) +# face_norm_r = np.matmul(face_norm,rotation) +# face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) +# tri = facemodel.face_buf + +# return face_shape_r,face_norm_r,face_color,tri + +def Reconstruction_for_render(coeff,facemodel): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture = Texture_formation(tex_coeff, facemodel) + face_norm = Compute_norm(face_shape,facemodel) + rotation = 
Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) + tri = facemodel.tri + + return face_shape_r,face_norm_r,face_color,tri + +def Reconstruction_for_render_new_given(coeff,facemodel,tex2_path): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture2 = np.load(tex2_path) + face_norm = Compute_norm(face_shape,facemodel) + rotation = Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color2,lighting = Illumination_layer(face_texture2, face_norm_r, gamma) + tri = facemodel.tri + + return face_shape_r,face_norm_r,face_color2,tri + +import tensorflow as tf +import mesh_renderer.mesh_renderer as mesh_renderer + +def Render_layer(face_shape,face_norm,face_color,facemodel,batchsize): + + camera_position = tf.constant([0,0,10.0]) + camera_lookat = tf.constant([0,0,0.0]) + camera_up = tf.constant([0,1.0,0]) + light_positions = tf.tile(tf.reshape(tf.constant([0,0,1e5]),[1,1,3]),[batchsize,1,1]) + light_intensities = tf.tile(tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3]),[batchsize,1,1]) + ambient_color = tf.tile(tf.reshape(tf.constant([1.0,1,1]),[1,3]),[batchsize,1]) + + render = mesh_renderer.mesh_renderer(face_shape, + tf.cast(facemodel.tri-1,tf.int32), + face_norm, + face_color/255, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + light_positions = light_positions, + light_intensities = light_intensities, + image_width = 224, + image_height = 224, + fov_y = 12.5936, + ambient_color = ambient_color) + + return render + +def Render_layer2(face_shape,face_norm,face_color,facemodel,batchsize): + + camera_position = tf.constant([0,0,10.0]) + camera_lookat = tf.constant([0,0,0.0]) + camera_up = tf.constant([0,1.0,0]) + light_positions = tf.tile(tf.reshape(tf.constant([0,0,1e5]),[1,1,3]),[batchsize,1,1]) + light_intensities = tf.tile(tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3]),[batchsize,1,1]) + ambient_color = tf.tile(tf.reshape(tf.constant([1.0,1,1]),[1,3]),[batchsize,1]) + + render = mesh_renderer.mesh_renderer(face_shape, + tf.cast(facemodel.tri-1,tf.int32), + face_norm, + face_color/255, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + light_positions = light_positions, + light_intensities = light_intensities, + image_width = 256, + image_height = 256, + fov_y = 12.5936, + ambient_color = ambient_color) + + return render \ No newline at end of file diff --git a/Audio/code/render_for_view2.py b/Audio/code/render_for_view2.py new file mode 100644 index 0000000..bf1fb25 --- /dev/null +++ b/Audio/code/render_for_view2.py @@ -0,0 +1,177 @@ +import tensorflow as tf +import os +from scipy.io import loadmat,savemat +from reconstruct_mesh import Reconstruction_for_render, Render_layer, Render_layer2 +from reconstruct_mesh import Reconstruction_for_render_new_given +import numpy as np +import sys +import glob +from PIL import Image +import pdb + +rootdir = '../../Deep3DFaceReconstruction/' + +class BFM(): + def __init__(self): + model_path = rootdir+'BFM/BFM_model_front.mat' + model = loadmat(model_path) + self.meanshape = model['meanshape'] # mean face shape + 
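# note: BFM_model_front.mat is the 35709-vertex front-face BFM; the [1,35709,3] placeholders used below must match this vertex count + 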
self.idBase = model['idBase'] # identity basis + self.exBase = model['exBase'] # expression basis + self.meantex = model['meantex'] # mean face texture + self.texBase = model['texBase'] # texture basis + self.point_buf = model['point_buf'] # adjacent face index for each vertex, starts from 1 (only used for calculating face normal) + self.tri = model['tri'] # vertex index for each triangle face, starts from 1 + self.keypoints = np.squeeze(model['keypoints']).astype(np.int32) - 1 # 68 face landmark index, starts from 0 + +class RenderObject(object): + def __init__(self, sess): + # read face model + self.facemodel = BFM() + + self.faceshaper = tf.placeholder(name = "face_shape_r", shape = [1,35709,3], dtype = tf.float32) + self.facenormr = tf.placeholder(name = "face_norm_r", shape = [1,35709,3], dtype = tf.float32) + self.facecolor = tf.placeholder(name = "face_color", shape = [1,35709,3], dtype = tf.float32) + self.rendered = Render_layer(self.faceshaper,self.facenormr,self.facecolor,self.facemodel,1) + self.rendered2 = Render_layer2(self.faceshaper,self.facenormr,self.facecolor,self.facemodel,1) + + self.rstimg = tf.placeholder(name = 'rstimg', dtype=tf.uint8) + self.encode_png = tf.image.encode_png(self.rstimg) + + self.sess = sess + self.last = np.zeros((6)) + + def save_image(self, final_images, result_output_path): + result_image = final_images[0, :, :, :] + result_image = np.clip(result_image, 0., 1.).copy(order='C') + result_bytes = sess.run(self.encode_png, {self.rstimg: result_image*255.0}) + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + def save_image2(self, final_images, result_output_path, tx=0, ty=0): + result_image = final_images[0, :, :, :] + #print(result_image.shape) + result_image = np.clip(result_image, 0., 1.) 
* 255.0 + result_image = np.round(result_image).astype(np.uint8) + im = Image.fromarray(result_image,'RGBA') + #pdb.set_trace() + if tx != 0 or ty != 0: + im = im.transform(im.size, Image.AFFINE, (1, 0, -tx, 0, 1, -ty)) + im.save(result_output_path) + + + def render(self, coef_path): + data = loadmat(coef_path) + coef = data['coeff'] + + result_output_path = os.path.join('output/render_fuse',coef_path[13:-4]+'_render.png') + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,self.facemodel) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color.astype('float32')}) + self.save_image(final_images, result_output_path) + + def render256(self, coef_path): + data = loadmat(coef_path) + coef = data['coeff'] + + result_output_path = os.path.join('output/render_fuse',coef_path[13:-4]+'_render256.png') + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,self.facemodel) + final_images = self.sess.run(self.rendered2, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color.astype('float32')}) + self.save_image(final_images, result_output_path) + + + def render2(self, coef_path1, coef_path2, result_output_path, pose=0, relativeframe=0, frame0=0, tran=0): + data1 = loadmat(coef_path1) + coef1 = data1['coeff'] + if coef_path2[-4:] == '.mat': + data2 = loadmat(coef_path2) + coef2 = data2['coeff'] + #transfer ex_coef + coef1[:,80:144] = coef2[:,80:144] + else: + coef2 = np.load(coef_path2) # shape (64, ) + if pose == 0: + coef1[:,80:144] = coef2 + else: + L = 64 + coef1[:,80:144] = coef2[:L] + if relativeframe == 0:################# + coef1[:,224:227] = coef2[L:L+3] + coef1[:,254:257] = coef2[L+3:L+6] + coef1[:,256] += 0.5 + elif relativeframe == 2: + coef1[:,224:227] = coef2[L:L+3] + coef1[:,254:257] = coef2[L+3:L+6] + else: + if not frame0: + coef1[:,224:227] = coef2[L:L+3] + self.last[:3] + coef1[:,254:257] = coef2[L+3:L+6] + self.last[3:6] + self.last[:3] = coef1[:,224:227] + self.last[3:6] = coef1[:,254:257] + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef1,self.facemodel) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color.astype('float32')}) + if coef2.shape[0] >= L+8 and tran==1: + self.save_image2(final_images, result_output_path, tx=coef2[L+6], ty=coef2[L+7]) + else: + self.save_image(final_images, result_output_path) + + def render2_newtex(self, coef_path1, coef_path2, tex2_path, result_output_path, pose=0, relativeframe=0, frame0=0, tran=0): + data1 = loadmat(coef_path1) + coef1 = data1['coeff'] + if coef_path2[-4:] == '.mat': + data2 = loadmat(coef_path2) + coef2 = data2['coeff'] + #transfer ex_coef + coef1[:,80:144] = coef2[:,80:144] + else: + coef2 = np.load(coef_path2) # shape (64, ) + L = 64 + coef1[:,80:144] = coef2[:L] + if relativeframe == 2: + coef1[:,224:227] = coef2[L:L+3] + coef1[:,254:257] = coef2[L+3:L+6] + if relativeframe == 0: + coef1[:,224:227] = coef2[L:L+3] + coef1[:,254:257] = coef2[L+3:L+6] + coef1[:,256] += 0.5 + 
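# coef1 now holds the predicted expression (coeff[:,80:144]), rotation angles (coeff[:,224:227]) and translation (coeff[:,254:257]); see Split_coeff in reconstruct_mesh.py for the full 257-dim layout + 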
face_shape_r,face_norm_r,face_color2,tri = Reconstruction_for_render_new_given(coef1,self.facemodel,tex2_path) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color2.astype('float32')}) + if coef2.shape[0] >= L+8 and tran==1: + self.save_image2(final_images, result_output_path, tx=coef2[L+6], ty=coef2[L+7]) + else: + self.save_image(final_images, result_output_path) + +if __name__ == '__main__': + coef_dir = sys.argv[1] + coef_path1 = sys.argv[2] + save_dir = sys.argv[3] + pose = int('pose' in coef_dir) + relativeframe = int(sys.argv[4]) + tran = int(sys.argv[5]) + tex2_path = sys.argv[6] if len(sys.argv) > 6 else '' + print('pose',pose) + print('relativeframe',relativeframe) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + coef_paths = sorted(glob.glob(coef_dir+'/*.npy')) + L = len(coef_paths) + with tf.Session() as sess: + render_object = RenderObject(sess) + for i in range(L): + basen = os.path.basename(coef_paths[i]) + save = os.path.join(save_dir,basen[:-4]+'.png') + if tex2_path == '': + # old texture + #render_object.render2(coef_path1, coef_paths[i], save) + render_object.render2(coef_path1, coef_paths[i], save, pose, relativeframe, i==0, tran=tran) + else: + # new texture + render_object.render2_newtex(coef_path1, coef_paths[i], tex2_path, save, pose, relativeframe, i==0, tran=tran) + if i % 100 == 0 and i != 0: + print('rendered', i) \ No newline at end of file diff --git a/Audio/code/test_personalized.py b/Audio/code/test_personalized.py new file mode 100755 index 0000000..b9dd316 --- /dev/null +++ b/Audio/code/test_personalized.py @@ -0,0 +1,101 @@ +#encoding:utf-8 +#test different audio +import os +from choose_bg_gexinghua2_reassign import choose_bg_gexinghua2_reassign2 +from trans_with_bigbg import merge_with_bigbg +import glob +import pdb +from PIL import Image +import numpy as np +import sys + +def getsingle(srcdir,name,varybg=0,multi=0): + srcroot = os.getcwd() + if not varybg: + imgs = glob.glob(os.path.join(srcroot,srcdir,'*_blend.png')) + print('srcdir',os.path.join(srcroot,srcdir,'*_blend.png')) + else: + imgs = glob.glob(os.path.join(srcroot,srcdir,'*_blend2.png')) + print('srcdir',os.path.join(srcroot,srcdir,'*_blend2.png')) + f1 = open('../../render-to-video/datasets/list/testSingle/%s.txt'%name,'w') + imgs = sorted(imgs) + if multi: + imgs = imgs[2:] + for im in imgs: + print(im, file=f1) + f1.close() + +gpu_id = 0 if len(sys.argv) < 4 else int(sys.argv[3]) +start=0;ganepoch=60;audioepoch=99 + + +audiobasen=sys.argv[1] +n = int(sys.argv[2])#person id + +if __name__ == "__main__": + person = str(n) + if os.path.exists(os.path.join('../audio/',audiobasen+'.wav')): + in_file = os.path.join('../audio/',audiobasen+'.wav') + elif os.path.exists(os.path.join('../audio/',audiobasen+'.mp3')): + in_file = os.path.join('../audio/',audiobasen+'.mp3') + else: + print('audio file not exists, please put in %s'%os.path.join(os.getcwd(),'../audio')) + exit(-1) + + audio_exp_name = 'atcnet_pose0_con3/'+person + audiomodel=os.path.join(audio_exp_name,audiobasen+'_%d'%audioepoch) + sample_dir = os.path.join('../results/',audiomodel) + ganmodel='memory_seq_p2p/%s'%person;post='_full9' + pingyi = 1; + seq='rseq_'+person+'_'+audiobasen+post + if audioepoch == 49: + seq='rseq_'+person+'_'+audiobasen+'_%d%s'%(audioepoch,post) + + + ## 1.audio to 3dmm + if not os.path.exists(sample_dir+'/00000.npy'): + add = '--model_name ../model/%s/atcnet_lstm_%d.pth 
--pose 1 --relativeframe 0' % (audio_exp_name,audioepoch) + print('python atcnet_test1.py --device_ids %d %s --sample_dir %s --in_file %s' % (gpu_id,add,sample_dir,in_file)) + os.system('python atcnet_test1.py --device_ids %d %s --sample_dir %s --in_file %s' % (gpu_id,add,sample_dir,in_file)) + + ## 2.background matching + speed=1 + num = 300 + bgdir = choose_bg_gexinghua2_reassign2('19_news/'+person, audiobasen, start, audiomodel, num=num, tran=pingyi, speed=speed) + + + ## 3.render to save_dir + coeff_dir = os.path.join(sample_dir,'reassign') + rootdir = '../../Deep3DFaceReconstruction/output/coeff/' + tex2_path = '' + coef_path1 = rootdir+'19_news/'+person+'/frame%d.mat'%start + save_dir = os.path.join(sample_dir,'R_%s_reassign2'%person) + relativeframe = 2 + os.system('CUDA_VISIBLE_DEVICES=%d python render_for_view2.py %s %s %s %d %d %s'%(gpu_id,coeff_dir,coef_path1,save_dir, relativeframe,pingyi,tex2_path)) + + + ## 4.blend rendered with background + srcdir = save_dir + #if not os.path.exists(save_dir+'/00000_blend2.png'): + cmd = "cd ../results; matlab -nojvm -nosplash -nodesktop -nodisplay -r \"alpha_blend_vbg('" + bgdir + "','" + srcdir + "'); quit;\"" + os.system(cmd) + + ## 5.gan + sample_dir2 = '../../render-to-video/results/%s/test_%d/images%s/'%(ganmodel,ganepoch,seq) + #if not os.path.exists(sample_dir2): + getsingle(save_dir,seq,1,1) + os.system('cd ../../render-to-video; python test_memory.py --dataroot %s --name %s --netG unetac_adain_256 --model test --Nw 3 --norm batch --dataset_mode single_multi --use_memory 1 --attention 1 --num_test 10000 --epoch %d --gpu_ids %d --imagefolder images%s'%(seq,ganmodel,ganepoch,gpu_id,seq)) + + + os.system('cp '+sample_dir2+'/R_'+person+'_reassign2-00002_blend2_fake.png '+sample_dir2+'/R_'+person+'_reassign2-00000_blend2_fake.png') + os.system('cp '+sample_dir2+'/R_'+person+'_reassign2-00002_blend2_fake.png '+sample_dir2+'/R_'+person+'_reassign2-00001_blend2_fake.png') + + video_name = os.path.join(sample_dir,'%s_%swav_results%s.mp4'%(person,audiobasen,post)) + command = 'ffmpeg -loglevel panic -framerate 25 -i ' + sample_dir2 + '/R_' + person + '_reassign2-%05d_blend2_fake.png -c:v libx264 -y -vf format=yuv420p ' + video_name + os.system(command) + command = 'ffmpeg -loglevel panic -i ' + video_name + ' -i ' + in_file + ' -vcodec copy -acodec copy -y ' + video_name.replace('.mp4','.mov') + os.system(command) + os.remove(video_name) + print('saved to',video_name.replace('.mp4','.mov')) + + merge_with_bigbg(audiobasen,n) diff --git a/Audio/code/train_19news_1.py b/Audio/code/train_19news_1.py new file mode 100644 index 0000000..a952079 --- /dev/null +++ b/Audio/code/train_19news_1.py @@ -0,0 +1,63 @@ +import librosa +import python_speech_features +import numpy as np +import os, glob, sys + +def get_mfcc_extend(video, srcdir, tardir): + test_file = os.path.join(srcdir,video) + save_file = os.path.join(tardir,video[:-4]+'.npy') + if os.path.exists(save_file): + mfcc = np.load(save_file) + return mfcc + speech, sr = librosa.load(test_file, sr=16000) + speech = np.insert(speech, 0, np.zeros(1920)) + speech = np.append(speech, np.zeros(1920)) + mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01) + if not os.path.exists(os.path.dirname(save_file)): + os.makedirs(os.path.dirname(save_file)) + np.save(save_file, mfcc) + return mfcc + +def save_each_100(folder): + pths = sorted(glob.glob(folder+'/*.pth')) + for pth in pths: + epoch = int(os.path.basename(pth).split('_')[-1][:-4]) + if epoch % 100 == 99: + continue + #print(epoch) + 
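# keep only checkpoints saved at epoch 99, 199, ...; every other .pth file is deleted to save disk space + 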
os.remove(pth) + +n = int(sys.argv[1]) +gpu_id = int(sys.argv[2]) +# check video +mp4 = '../../Data/%d.mp4'%n +if not os.path.exists(mp4): + print('target video', mp4, 'not exists') + exit(-1) +# check 3d recon +rootdir = '../../Deep3DFaceReconstruction/output/coeff/19_news/%d' % n +valid = True +for i in range(300): + if not os.path.exists(os.path.join(rootdir,'frame%d.mat'%i)): + print(n,'lack','frame%d.mat'%i) + valid = False +if not valid: + print('not all 300 frames are reconstructed successfully') + exit(-1) +# extract mfcc +srcdir = '../../Data/' +tardir = '../dataset/mfcc/19_news' +video = str(n)+'.mp4' +get_mfcc_extend(video, srcdir, tardir) + +# fine tune audio +n = str(n) +if not os.path.exists('../model/atcnet_pose0_con3/%s'%n): + os.makedirs('../model/atcnet_pose0_con3/%s'%n) +if not os.path.exists('../sample/atcnet_pose0_con3/%s'%n): + os.makedirs('../sample/atcnet_pose0_con3/%s'%n) +if not os.path.exists('../model/atcnet_pose0_con3/%s/atcnet_lstm_99.pth'%n): + cmd = 'python atcnet.py --pose 1 --relativeframe 0 --dataset news --newsname 19_news/%s --start 0 --model_dir ../model/atcnet_pose0_con3/%s/ --continue_train 1 --lr 0.0001 --less_constrain 1 --smooth_loss 1 --smooth_loss2 1 --model_name ../model/atcnet_lstm_general.pth --sample_dir ../sample/atcnet_pose0_con3/%s --device_ids %d --max_epochs 100' % (n, n, n, gpu_id) + print(cmd) + os.system(cmd) +save_each_100('../model/atcnet_pose0_con3/%s'%n) \ No newline at end of file diff --git a/Audio/code/trans_with_bigbg.py b/Audio/code/trans_with_bigbg.py new file mode 100644 index 0000000..1704d2d --- /dev/null +++ b/Audio/code/trans_with_bigbg.py @@ -0,0 +1,88 @@ +import os, sys +import glob +import numpy as np +import pdb +from PIL import Image +import cv2 + + +def merge_with_bigbg(audiobasen,n): + start=0;ganepoch=60;audioepoch=99 + seamlessclone = 1 + person = str(n) + if os.path.exists(os.path.join('../audio/',audiobasen+'.wav')): + in_file = os.path.join('../audio/',audiobasen+'.wav') + elif os.path.exists(os.path.join('../audio/',audiobasen+'.mp3')): + in_file = os.path.join('../audio/',audiobasen+'.mp3') + else: + print('audio file not exists, please put in %s'%os.path.join(os.getcwd(),'../audio')) + return + + audio_exp_name = 'atcnet_pose0_con3/'+person + audiomodel=os.path.join(audio_exp_name,audiobasen+'_%d'%audioepoch) + sample_dir = os.path.join('../results/',audiomodel) + ganmodel='memory_seq_p2p/%sold3'%person;post='_full9' + seq='rseq_'+person+'_'+audiobasen+post + if audioepoch == 49: + seq='rseq_'+person+'_'+audiobasen+'_%d%s'%(audioepoch,post) + + coeff_dir = os.path.join(sample_dir,'reassign') + + sample_dir2 = '../../render-to-video/results/%s/test_%d/images%s/'%(ganmodel,ganepoch,seq) + os.system('cp '+sample_dir2+'/R_'+person+'_reassign2-00002_blend2_fake.png '+sample_dir2+'/R_'+person+'_reassign2-00000_blend2_fake.png') + os.system('cp '+sample_dir2+'/R_'+person+'_reassign2-00002_blend2_fake.png '+sample_dir2+'/R_'+person+'_reassign2-00001_blend2_fake.png') + + video_name = os.path.join(sample_dir,'%s_%swav_results%s.mp4'%(person,audiobasen,post)) + command = 'ffmpeg -loglevel panic -framerate 25 -i ' + sample_dir2 + '/R_' + person + '_reassign2-%05d_blend2_fake.png -c:v libx264 -y -vf format=yuv420p ' + video_name + os.system(command) + command = 'ffmpeg -loglevel panic -i ' + video_name + ' -i ' + in_file + ' -vcodec copy -acodec copy -y ' + video_name.replace('.mp4','.mov') + os.system(command) + os.remove(video_name) + + if not os.path.exists(os.path.join('../../Data',str(n),'transbig.npy')): + cmd 
= 'cd ../../Deep3DFaceReconstruction/; python demo_preprocess.py %d %d' % (n,n+1) + os.system(cmd) + transdata = np.load(os.path.join('../../Data',str(n),'transbig.npy')) + w2 = transdata[0] + h2 = transdata[1] + t0 = transdata[2] + t1 = transdata[3] + + coeffs = glob.glob(coeff_dir+'/*.npy') + transbigbgdir = os.path.join(sample_dir,'trans_bigbg') + if not os.path.exists(transbigbgdir): + os.mkdir(transbigbgdir) + for i in range(len(coeffs)): + data = np.load(coeff_dir+'/%05d.npy'%i) + assigni = data[-1] + if seamlessclone == 0: + # direct paste + img = Image.open('../../Data/'+person+'/frame%d.png'%assigni) + img1 = Image.open(sample_dir2+'/R_'+person+'_reassign2-%05d_blend2_fake.png'%i) + img1 = img1.resize((w2,h2),resample = Image.LANCZOS) + img.paste(img1,(t0,t1,t0+img1.size[0],t1+img1.size[1])) + img.save(os.path.join(transbigbgdir,'%05d.png'%i)) + else: + # seamless clone + img = cv2.imread('../../Data/'+person+'/frame%d.png'%assigni) + img1 = cv2.imread(sample_dir2+'/R_'+person+'_reassign2-%05d_blend2_fake.png'%i) + img1 = cv2.resize(img1,(w2,h2),interpolation=cv2.INTER_LANCZOS4) + mask = np.ones(img1.shape,img1.dtype) * 255 + center = (t0+int(img1.shape[0]/2),t1+int(img1.shape[1]/2)) + output = cv2.seamlessClone(img1,img,mask,center,cv2.NORMAL_CLONE) + cv2.imwrite(os.path.join(transbigbgdir,'%05d.png'%i),output) + + transbigbgdir = os.path.join(sample_dir,'trans_bigbg') + video_name = os.path.join(sample_dir,'%s_%swav_results_transbigbg.mp4'%(person,audiobasen)) + command = 'ffmpeg -loglevel panic -framerate 25 -i ' + transbigbgdir + '/%05d.png -c:v libx264 -y -vf format=yuv420p ' + video_name + os.system(command) + command = 'ffmpeg -loglevel panic -i ' + video_name + ' -i ' + in_file + ' -vcodec copy -acodec copy -y ' + video_name.replace('.mp4','.mov') + os.system(command) + os.remove(video_name) + print('saved to', video_name.replace('.mp4','.mov')) + +audiobasen=sys.argv[1] +n = int(sys.argv[2]) + +if __name__ == "__main__": + merge_with_bigbg(audiobasen,n) diff --git a/Audio/dataset/combine_coeff.py b/Audio/dataset/combine_coeff.py new file mode 100644 index 0000000..6db518c --- /dev/null +++ b/Audio/dataset/combine_coeff.py @@ -0,0 +1,37 @@ +import os +import glob +import pdb +from scipy.io import loadmat,savemat +import numpy as np + +srcdir = '/home4/yiran/Dataset/LRW/lipread_mp4/' +tardir = '../../Deep3DFaceReconstruction/output/coeff/lrw/' +coeffdir = 'coeff/lrw/' + +#for lrw, gather all frames' coeff together +for w in sorted(glob.glob(srcdir+'/*')): + word = os.path.basename(w) + if not os.path.exists(os.path.join(coeffdir,word,'train')): + os.makedirs(os.path.join(coeffdir,word,'train')) + if not os.path.exists(os.path.join(coeffdir,word,'test')): + os.makedirs(os.path.join(coeffdir,word,'test')) + for v in sorted(glob.glob(os.path.join(srcdir,word,'*/*.mp4'))): + ss = v.split('/') + video = '%s/%s/%s'%(word,ss[-2],ss[-1][:-4]) + + # all videos in lrw have 29 frames + complete = True + coeff = np.zeros((29,257),np.float32) + for i in range(29): + coffpath = os.path.join(tardir,video,'frame%d.mat'%i) + if not os.path.exists(coffpath): + complete = False + break + data = loadmat(coffpath) + coeff[i,:] = data['coeff'] + if not complete: + continue + save_file = os.path.join(coeffdir,word,ss[-2],ss[-1][:-4]+'.npy') + np.save(save_file,coeff) + print(save_file) + diff --git a/Audio/dataset/extract_mfcc.py b/Audio/dataset/extract_mfcc.py new file mode 100644 index 0000000..c564fdb --- /dev/null +++ b/Audio/dataset/extract_mfcc.py @@ -0,0 +1,42 @@ +#coding:utf-8 +import 
librosa +import python_speech_features +import numpy as np +import os +import glob +import torch +import cv2 +import sys + +def get_mfcc(video, srcdir, tardir): + test_file = os.path.join(srcdir,video) + save_file = os.path.join(tardir,video[:-4]+'.npy') + if os.path.exists(save_file): + mfcc = np.load(save_file) + return mfcc + speech, sr = librosa.load(test_file, sr=16000) + mfcc = python_speech_features.mfcc(speech, 16000, winstep=0.01) + np.save(save_file, mfcc) + return mfcc + +if __name__ == '__main__': + # FOR LRW, all videos + srcdir = '/home4/yiran/Dataset/LRW/lipread_mp4/' + tardir = 'mfcc/lrw/' + print(len(sorted(glob.glob(srcdir+'/*')))) + for w in sorted(glob.glob(srcdir+'/*')): + word = os.path.basename(w) + # for train + if not os.path.exists(os.path.join(tardir,word,'train')): + os.makedirs(os.path.join(tardir,word,'train')) + #print(word,len(sorted(glob.glob(os.path.join(srcdir,word,'train','*.mp4'))))) + for v in sorted(glob.glob(os.path.join(srcdir,word,'train','*.mp4'))): + video = '%s/train/%s'%(word,os.path.join(os.path.basename(v))) + get_mfcc(video, srcdir, tardir) + # for test + if not os.path.exists(os.path.join(tardir,word,'test')): + os.makedirs(os.path.join(tardir,word,'test')) + #print(word,len(sorted(glob.glob(os.path.join(srcdir,word,'test','*.mp4'))))) + for v in sorted(glob.glob(os.path.join(srcdir,word,'test','*.mp4'))): + video = '%s/test/%s'%(word,os.path.join(os.path.basename(v))) + get_mfcc(video, srcdir, tardir) diff --git a/Audio/dataset/get_list_pkl.py b/Audio/dataset/get_list_pkl.py new file mode 100644 index 0000000..003f2e6 --- /dev/null +++ b/Audio/dataset/get_list_pkl.py @@ -0,0 +1,28 @@ +import glob +import pickle +import pdb +import os + +pfile = 'coeff_train.pkl' +if not os.path.exists(pfile): + tlist = [] + files = sorted(glob.glob('coeff/lrw/*/train/*.npy')) + for file in files: + splits = file.split('/') + tlist.append([splits[-3],splits[-2],splits[-1][:-4]]) + + _file = open(pfile,"wb") + pickle.dump(tlist,_file) + _file.close() + +pfile = 'coeff_test.pkl' +if not os.path.exists(pfile): + tlist = [] + files = sorted(glob.glob('coeff/lrw/*/test/*.npy')) + for file in files: + splits = file.split('/') + tlist.append([splits[-3],splits[-2],splits[-1][:-4]]) + + _file = open(pfile,"wb") + pickle.dump(tlist,_file) + _file.close() diff --git a/Data/extract_frame1.py b/Data/extract_frame1.py new file mode 100644 index 0000000..aff232d --- /dev/null +++ b/Data/extract_frame1.py @@ -0,0 +1,65 @@ +import cv2 +import os, sys +import glob +import dlib +import numpy as np +import time +import pdb +detector = dlib.get_frontal_face_detector() +predictor = dlib.shape_predictor('../Deep3DFaceReconstruction/shape_predictor_68_face_landmarks.dat') + +def shape_to_np(shape, dtype="int"): + # initialize the list of (x, y)-coordinates + coords = np.zeros((shape.num_parts, 2), dtype=dtype) + + # loop over all facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, shape.num_parts): + coords[i] = (shape.part(i).x, shape.part(i).y) + + # return the list of (x, y)-coordinates + return coords + +def detect_image(imagename, savepath=""): + image = cv2.imread(imagename) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rects = detector(gray, 1) + for (i, rect) in enumerate(rects): + shape = predictor(gray, rect) + shape = shape_to_np(shape) + for (x, y) in shape: + cv2.circle(image, (x, y), 1, (0, 0, 255), -1) + eyel = np.round(np.mean(shape[36:42,:],axis=0)).astype("int") + eyer = 
np.round(np.mean(shape[42:48,:],axis=0)).astype("int") + nose = shape[33] + mouthl = shape[48] + mouthr = shape[54] + if savepath != "": + message = '%d %d\n%d %d\n%d %d\n%d %d\n%d %d\n' % (eyel[0],eyel[1], + eyer[0],eyer[1],nose[0],nose[1], + mouthl[0],mouthl[1],mouthr[0],mouthr[1]) + with open(savepath, 'w') as s_file: + s_file.write(message) + return +def detect_dir(folder): + for file in sorted(glob.glob(folder+"/*.jpg")+glob.glob(folder+"/*.png")): + print(file) + detect_image(imagename=file, savepath=file[:-4]+'.txt') + +t1 = time.time() +mp4 = sys.argv[1] +videoname = mp4 +cap = cv2.VideoCapture(videoname) +length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) +#print videoname, length +success, image = cap.read() +postfix = ".png" +if not os.path.exists(mp4[:-4]): + os.makedirs(mp4[:-4]) +count = 0 +while count<400: + cv2.imwrite("%s/frame%d%s"%(mp4[:-4],count,postfix),image) + success, image = cap.read() + count += 1 +detect_dir(mp4[:-4]) +t2 = time.time() diff --git a/Deep3DFaceReconstruction/BFM/BFM_exp_idx.mat b/Deep3DFaceReconstruction/BFM/BFM_exp_idx.mat new file mode 100644 index 0000000..1146e4e Binary files /dev/null and b/Deep3DFaceReconstruction/BFM/BFM_exp_idx.mat differ diff --git a/Deep3DFaceReconstruction/BFM/BFM_front_idx.mat b/Deep3DFaceReconstruction/BFM/BFM_front_idx.mat new file mode 100644 index 0000000..b9d7b09 Binary files /dev/null and b/Deep3DFaceReconstruction/BFM/BFM_front_idx.mat differ diff --git a/Deep3DFaceReconstruction/BFM/facemodel_info.mat b/Deep3DFaceReconstruction/BFM/facemodel_info.mat new file mode 100644 index 0000000..3e516ec Binary files /dev/null and b/Deep3DFaceReconstruction/BFM/facemodel_info.mat differ diff --git a/Deep3DFaceReconstruction/BFM/similarity_Lm3D_all.mat b/Deep3DFaceReconstruction/BFM/similarity_Lm3D_all.mat new file mode 100644 index 0000000..a0e2358 Binary files /dev/null and b/Deep3DFaceReconstruction/BFM/similarity_Lm3D_all.mat differ diff --git a/Deep3DFaceReconstruction/BFM/std_exp.txt b/Deep3DFaceReconstruction/BFM/std_exp.txt new file mode 100644 index 0000000..767b8de --- /dev/null +++ b/Deep3DFaceReconstruction/BFM/std_exp.txt @@ -0,0 +1 @@ +453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19 \ No newline at end of file diff --git a/Deep3DFaceReconstruction/demo.py b/Deep3DFaceReconstruction/demo.py new file mode 100644 index 0000000..8c3b0bb --- /dev/null +++ b/Deep3DFaceReconstruction/demo.py @@ -0,0 +1,119 @@ +import tensorflow as tf +import numpy as np +import cv2 +from PIL import Image +import os +import glob +from scipy.io import loadmat,savemat +import sys + +from preprocess_img import Preprocess,Preprocess2 +from load_data import * +from reconstruct_mesh import Reconstruction +from reconstruct_mesh import Reconstruction_for_render, Render_layer +import pdb +import time + +def load_graph(graph_filename): + with tf.gfile.GFile(graph_filename,'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + return graph_def + +def 
demo(image_path): + # input and output folder + save_path = 'output/coeff' + save_path2 = 'output/render' + if image_path[-1] == '/': + image_path = image_path[:-1] + name = os.path.basename(image_path) + print(image_path, name) + img_list = glob.glob(image_path + '/' + '*.txt') + img_list = [e[:-4]+'.png' for e in img_list] + already = glob.glob(save_path + '/' + name + '/*.mat') + already = [e[len(save_path)+1:-4].replace(name,image_path)+'.png' for e in already] + ret = list(set(img_list).difference(set(already))) + img_list = ret + img_list = sorted(img_list) + print('img_list len:', len(img_list)) + if not os.path.exists(os.path.join(save_path,name)): + os.makedirs(os.path.join(save_path,name)) + if not os.path.exists(os.path.join(save_path2,name)): + os.makedirs(os.path.join(save_path2,name)) + + # read BFM face model + # transfer original BFM model to our model + if not os.path.isfile('./BFM/BFM_model_front.mat'): + transferBFM09() + + # read face model + facemodel = BFM() + # read standard landmarks for preprocessing images + lm3D = load_lm3d() + n = 0 + t1 = time.time() + + # build reconstruction model + with tf.Graph().as_default() as graph: + + images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32) + graph_def = load_graph('network/FaceReconModel.pb') + tf.import_graph_def(graph_def,name='resnet',input_map={'input_imgs:0': images}) + + # output coefficients of R-Net (dim = 257) + coeff = graph.get_tensor_by_name('resnet/coeff:0') + + faceshaper = tf.placeholder(name = "face_shape_r", shape = [1,35709,3], dtype = tf.float32) + facenormr = tf.placeholder(name = "face_norm_r", shape = [1,35709,3], dtype = tf.float32) + facecolor = tf.placeholder(name = "face_color", shape = [1,35709,3], dtype = tf.float32) + rendered = Render_layer(faceshaper,facenormr,facecolor,facemodel,1) + + rstimg = tf.placeholder(name = 'rstimg', shape = [224,224,4], dtype=tf.uint8) + encode_png = tf.image.encode_png(rstimg) + + with tf.Session() as sess: + print('reconstructing...') + for file in img_list: + n += 1 + # load images and corresponding 5 facial landmarks + if '_mtcnn' not in image_path: + img,lm = load_img(file,file[:-4]+'.txt') + else: + img,lm = load_img(file,file[:-4].replace(name,name+'_mtcnn')+'.txt') + file = file.replace(image_path.replace('_mtcnn',''), name) + # preprocess input image + input_img,lm_new,transform_params = Preprocess(img,lm,lm3D) + if n==1: + transform_firstflame=transform_params + input_img2,lm_new2 = Preprocess2(img,lm,transform_firstflame) + + coef = sess.run(coeff,feed_dict = {images: input_img}) + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,facemodel) + final_images = sess.run(rendered, feed_dict={faceshaper: face_shape_r.astype('float32'), facenormr: face_norm_r.astype('float32'), facecolor: face_color.astype('float32')}) + result_image = final_images[0, :, :, :] + result_image = np.clip(result_image, 0., 1.).copy(order='C') + result_bytes = sess.run(encode_png,{rstimg: result_image*255.0}) + result_output_path = os.path.join(save_path2,file[:-4]+'_render.png') + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + # reshape outputs + input_img = np.squeeze(input_img) + im = Image.fromarray(input_img[:,:,::-1]) + cropped_output_path = os.path.join(save_path2,file[:-4]+'.png') + im.save(cropped_output_path) + + input_img2 = np.squeeze(input_img2) + im = Image.fromarray(input_img2[:,:,::-1]) + cropped_output_path = os.path.join(save_path2,file[:-4]+'_input2.png') + 
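# input_img2 is the same frame cropped with the first frame's transform (Preprocess2) and is saved with the _input2 suffix + 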
im.save(cropped_output_path) + + # save output files + savemat(os.path.join(save_path,file[:-4]+'.mat'),{'coeff':coef,'lm_5p':lm_new2-lm_new}) + t2 = time.time() + print('Total n:', n, 'Time:', t2-t1) + +if __name__ == '__main__': + demo(sys.argv[1]) \ No newline at end of file diff --git a/Deep3DFaceReconstruction/demo_19news.py b/Deep3DFaceReconstruction/demo_19news.py new file mode 100644 index 0000000..7232fd6 --- /dev/null +++ b/Deep3DFaceReconstruction/demo_19news.py @@ -0,0 +1,113 @@ +import tensorflow as tf +import numpy as np +import cv2 +from PIL import Image +import os +import glob +from scipy.io import loadmat,savemat +import sys + +from preprocess_img import Preprocess,Preprocess2 +from load_data import * +from reconstruct_mesh import Reconstruction +from reconstruct_mesh import Reconstruction_for_render, Render_layer +import pdb +import time + +def load_graph(graph_filename): + with tf.gfile.GFile(graph_filename,'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + return graph_def + +def demo(image_path): + # input and output folder + save_path = 'output/coeff/19_news' + save_path2 = 'output/render/19_news' + if image_path[-1] == '/': + image_path = image_path[:-1] + name = os.path.basename(image_path) + print(image_path, name) + img_list = glob.glob(image_path + '/' + '*.txt') + img_list = [e[:-4]+'.png' for e in img_list] + img_list = sorted(img_list) + print('img_list len:', len(img_list)) + if not os.path.exists(os.path.join(save_path,name)): + os.makedirs(os.path.join(save_path,name)) + if not os.path.exists(os.path.join(save_path2,name)): + os.makedirs(os.path.join(save_path2,name)) + + # read BFM face model + # transfer original BFM model to our model + if not os.path.isfile('./BFM/BFM_model_front.mat'): + transferBFM09() + + # read face model + facemodel = BFM() + # read standard landmarks for preprocessing images + lm3D = load_lm3d() + n = 0 + t1 = time.time() + + # build reconstruction model + #with tf.Graph().as_default() as graph,tf.device('/cpu:0'): + with tf.Graph().as_default() as graph: + + images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32) + graph_def = load_graph('network/FaceReconModel.pb') + tf.import_graph_def(graph_def,name='resnet',input_map={'input_imgs:0': images}) + + # output coefficients of R-Net (dim = 257) + coeff = graph.get_tensor_by_name('resnet/coeff:0') + + faceshaper = tf.placeholder(name = "face_shape_r", shape = [1,35709,3], dtype = tf.float32) + facenormr = tf.placeholder(name = "face_norm_r", shape = [1,35709,3], dtype = tf.float32) + facecolor = tf.placeholder(name = "face_color", shape = [1,35709,3], dtype = tf.float32) + rendered = Render_layer(faceshaper,facenormr,facecolor,facemodel,1) + + rstimg = tf.placeholder(name = 'rstimg', shape = [224,224,4], dtype=tf.uint8) + encode_png = tf.image.encode_png(rstimg) + + with tf.Session() as sess: + print('reconstructing...') + for file in img_list: + n += 1 + # load images and corresponding 5 facial landmarks + img,lm = load_img(file,file[:-4]+'.txt') + file = file.replace(image_path, name) + # preprocess input image + input_img,lm_new,transform_params = Preprocess(img,lm,lm3D) + if n==1: + transform_firstflame=transform_params + input_img2,lm_new2 = Preprocess2(img,lm,transform_firstflame) + + coef = sess.run(coeff,feed_dict = {images: input_img}) + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,facemodel) + final_images = sess.run(rendered, feed_dict={faceshaper: 
face_shape_r.astype('float32'), facenormr: face_norm_r.astype('float32'), facecolor: face_color.astype('float32')}) + result_image = final_images[0, :, :, :] + result_image = np.clip(result_image, 0., 1.).copy(order='C') + result_bytes = sess.run(encode_png,{rstimg: result_image*255.0}) + result_output_path = os.path.join(save_path2,file[:-4]+'_render.png') + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + # reshape outputs + input_img = np.squeeze(input_img) + im = Image.fromarray(input_img[:,:,::-1]) + cropped_output_path = os.path.join(save_path2,file[:-4]+'.png') + im.save(cropped_output_path) + + input_img2 = np.squeeze(input_img2) + im = Image.fromarray(input_img2[:,:,::-1]) + cropped_output_path = os.path.join(save_path2,file[:-4]+'_input2.png') + im.save(cropped_output_path) + + # save output files + savemat(os.path.join(save_path,file[:-4]+'.mat'),{'coeff':coef,'lm_5p':lm_new2-lm_new}) + t2 = time.time() + print('Total n:', n, 'Time:', t2-t1) + +if __name__ == '__main__': + demo(sys.argv[1]) diff --git a/Deep3DFaceReconstruction/demo_gettex.py b/Deep3DFaceReconstruction/demo_gettex.py new file mode 100644 index 0000000..0a6bd00 --- /dev/null +++ b/Deep3DFaceReconstruction/demo_gettex.py @@ -0,0 +1,152 @@ +import tensorflow as tf +import os +from scipy.io import loadmat,savemat +from load_data import * +from reconstruct_mesh import Reconstruction_for_render, Render_layer, Render_layer2 +from reconstruct_mesh import Reconstruction_for_render_new, Reconstruction_for_render_new_given, Reconstruction_for_render_new_given2, Project_layer +import sys +import glob +import pdb +import time +import cv2 +import random +from preprocess_img import Preprocess + +class RenderObject(object): + def __init__(self, sess): + if not os.path.isfile('./BFM/BFM_model_front.mat'): + transferBFM09() + # read face model + self.facemodel = BFM() + + self.faceshaper = tf.placeholder(name = "face_shape_r", shape = [1,35709,3], dtype = tf.float32) + self.facenormr = tf.placeholder(name = "face_norm_r", shape = [1,35709,3], dtype = tf.float32) + self.facecolor = tf.placeholder(name = "face_color", shape = [1,35709,3], dtype = tf.float32) + self.rendered = Render_layer(self.faceshaper,self.facenormr,self.facecolor,self.facemodel,1) + self.rendered2 = Render_layer2(self.faceshaper,self.facenormr,self.facecolor,self.facemodel,1) + #self.project = Project_layer(self.faceshaper) + + self.rstimg = tf.placeholder(name = 'rstimg', dtype=tf.uint8) + self.encode_png = tf.image.encode_png(self.rstimg) + + self.sess = sess + + def save_image(self, final_images, result_output_path): + result_image = final_images[0, :, :, :] + result_image = np.clip(result_image, 0., 1.).copy(order='C') + #result_bytes = sess.run(tf.image.encode_png(result_image*255.0)) + result_bytes = self.sess.run(self.encode_png, {self.rstimg: result_image*255.0}) + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + def save_image2(self, final_images, result_output_path, tx=0, ty=0): + result_image = final_images[0, :, :, :] + result_image = np.clip(result_image, 0., 1.) 
* 255.0 + result_image = np.round(result_image).astype(np.uint8) + im = Image.fromarray(result_image,'RGBA') + if tx != 0 or ty != 0: + im = im.transform(im.size, Image.AFFINE, (1, 0, tx, 0, 1, ty)) + im.save(result_output_path) + + def show_clip_vertices(self, coef_path, clip_vertices, image_width=224, image_height=224): + half_image_width = 0.5 * image_width + half_image_height = 0.5 * image_height + im = cv2.imread(coef_path.replace('coeff','render')[:-4]+'.png') + for i in range(clip_vertices.shape[1]): + if clip_vertices.shape[2] == 4: + v0x = clip_vertices[0,i,0] + v0y = clip_vertices[0,i,1] + v0w = clip_vertices[0,i,3] + px = int(round((v0x / v0w + 1.0) * half_image_width)) + py = int(image_height -1 - round((v0y / v0w + 1.0) * half_image_height)) + elif clip_vertices.shape[2] == 2: + px = int(round(clip_vertices[0,i,0])) + py = int(round(clip_vertices[0,i,1])) + if px >= 0 and px < image_width and py >= 0 and py < image_height: + cv2.circle(im, (px, py), 1, (0, 255, 0), -1) + cv2.imwrite('show_clip_vertices.png',im) + + def gettexture(self, coef_path): + data = loadmat(coef_path) + coef = data['coeff'] + img_path = coef_path.replace('coeff','render')[:-4]+'.png' + face_shape_r,face_norm_r,face_color,face_color2,face_texture2,tri,face_projection = Reconstruction_for_render_new(coef,self.facemodel,img_path) + np.save(coef_path[:-4]+'_tex2.npy',face_texture2) + return coef_path[:-4]+'_tex2.npy', face_texture2 + + def gettexture2(self, coef_path): + data = loadmat(coef_path) + coef = data['coeff'] + img_path = coef_path[:-4]+'.jpg' + face_shape_r,face_norm_r,face_color,face_color2,face_texture2,tri,face_projection = Reconstruction_for_render_new(coef,self.facemodel,img_path) + np.save(coef_path[:-4]+'_tex2.npy',face_texture2) + return coef_path[:-4]+'_tex2.npy' + + def render224_new(self, coef_path, result_output_path, tex2_path): + if not os.path.exists(coef_path): + return + if os.path.exists(result_output_path): + return + data = loadmat(coef_path) + coef = data['coeff'] + + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + face_shape_r,face_norm_r,face_color2,tri = Reconstruction_for_render_new_given(coef,self.facemodel,tex2_path) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color2.astype('float32')}) + self.save_image(final_images, result_output_path) + + def render224_new2(self, coef_path, result_output_path, tex2): + if not os.path.exists(coef_path): + return + #if os.path.exists(result_output_path): + # return + data = loadmat(coef_path) + coef = data['coeff'] + + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + face_shape_r,face_norm_r,face_color2,tri = Reconstruction_for_render_new_given2(coef,self.facemodel,tex2) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color2.astype('float32')}) + self.save_image(final_images, result_output_path) + + def render224(self, coef_path, result_output_path): + if not os.path.exists(coef_path): + return + if os.path.exists(result_output_path): + return + data = loadmat(coef_path) + coef = data['coeff'] + + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + #t00 = time.time() + 
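# Reconstruction_for_render returns posed vertices, normals and SH-lit vertex colors; these three arrays are all the renderer below needs + 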
face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,self.facemodel) + final_images = self.sess.run(self.rendered, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color.astype('float32')}) + #t01 = time.time() + self.save_image(final_images, result_output_path) + #print(t01-t00,time.time()-t01) + + def render256(self, coef_path, savedir): + data = loadmat(coef_path) + coef = data['coeff'] + + basen = os.path.basename(coef_path)[:-4] + result_output_path = os.path.join(savedir,basen+'_render256.png') + if not os.path.exists(os.path.dirname(result_output_path)): + os.makedirs(os.path.dirname(result_output_path)) + + face_shape_r,face_norm_r,face_color,tri = Reconstruction_for_render(coef,self.facemodel) + final_images = self.sess.run(self.rendered2, feed_dict={self.faceshaper: face_shape_r.astype('float32'), self.facenormr: face_norm_r.astype('float32'), self.facecolor: face_color.astype('float32')}) + self.save_image(final_images, result_output_path) + +if __name__ == '__main__': + with tf.Session() as sess: + render_object = RenderObject(sess) + coef_path = sys.argv[1] + tex2_path,face_texture2 = render_object.gettexture(coef_path) + result_output_path = coef_path.replace('output/coeff','output/render')[:-4]+'_rendernew.png' + rp = render_object.render224_new2(coef_path,result_output_path,face_texture2) \ No newline at end of file diff --git a/Deep3DFaceReconstruction/demo_norender.py b/Deep3DFaceReconstruction/demo_norender.py new file mode 100644 index 0000000..9e7b64f --- /dev/null +++ b/Deep3DFaceReconstruction/demo_norender.py @@ -0,0 +1,88 @@ +import tensorflow as tf +import numpy as np +import cv2 +from PIL import Image +import os +import glob +from scipy.io import loadmat,savemat +import sys + +from preprocess_img import Preprocess,Preprocess2 +from load_data import * +from reconstruct_mesh import Reconstruction +from reconstruct_mesh import Reconstruction_for_render, Render_layer +import pdb +import time + +def load_graph(graph_filename): + with tf.gfile.GFile(graph_filename,'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + return graph_def + +def demo(image_path): + # output folder + save_dir = 'output/coeff' + save_coeff_path = save_dir + '/' + image_path + img_list = glob.glob(image_path + '/*.txt') + img_list = img_list + glob.glob(image_path + '/*/*.txt') + img_list = img_list + glob.glob(image_path + '/*/*/*.txt') + img_list = img_list + glob.glob(image_path + '/*/*/*/*.txt') + img_list = [e[:-4]+'.jpg' for e in img_list] + already = glob.glob(save_coeff_path + '/*.mat') + already = already + glob.glob(save_coeff_path + '/*/*.mat') + already = already + glob.glob(save_coeff_path + '/*/*/*.mat') + already = already + glob.glob(save_coeff_path + '/*/*/*/*.mat') + already = [e[len(save_dir)+1:-4]+'.jpg' for e in already] + ret = list(set(img_list).difference(set(already))) + img_list = ret + img_list = sorted(img_list) + print('img_list len:', len(img_list)) + if not os.path.exists(os.path.join(save_dir,image_path)): + os.makedirs(os.path.join(save_dir,image_path)) + for img in img_list: + if not os.path.exists(os.path.join(save_dir,os.path.dirname(img))): + os.makedirs(os.path.join(save_dir,os.path.dirname(img))) + + # read BFM face model + # transfer original BFM model to our model + if not os.path.isfile('./BFM/BFM_model_front.mat'): + transferBFM09() + + # read face model + facemodel = BFM() + # read standard landmarks for preprocessing images 
+ lm3D = load_lm3d() + n = 0 + t1 = time.time() + + # build reconstruction model + #with tf.Graph().as_default() as graph,tf.device('/cpu:0'): + with tf.Graph().as_default() as graph: + + images = tf.placeholder(name = 'input_imgs', shape = [None,224,224,3], dtype = tf.float32) + graph_def = load_graph('network/FaceReconModel.pb') + tf.import_graph_def(graph_def,name='resnet',input_map={'input_imgs:0': images}) + + # output coefficients of R-Net (dim = 257) + coeff = graph.get_tensor_by_name('resnet/coeff:0') + + with tf.Session() as sess: + print('reconstructing...') + for file in img_list: + n += 1 + # load images and corresponding 5 facial landmarks + img,lm = load_img(file,file[:-4]+'.txt') + # preprocess input image + input_img,lm_new,transform_params = Preprocess(img,lm,lm3D) + + coef = sess.run(coeff,feed_dict = {images: input_img}) + + # save output files + savemat(os.path.join(save_dir,file[:-4]+'.mat'),{'coeff':coef,'lm_5p':lm_new}) + t2 = time.time() + print('Total n:', n, 'Time:', t2-t1) + +if __name__ == '__main__': + demo(sys.argv[1]) \ No newline at end of file diff --git a/Deep3DFaceReconstruction/demo_preprocess.py b/Deep3DFaceReconstruction/demo_preprocess.py new file mode 100644 index 0000000..38631a2 --- /dev/null +++ b/Deep3DFaceReconstruction/demo_preprocess.py @@ -0,0 +1,54 @@ +import tensorflow as tf +import numpy as np +import cv2 +from PIL import Image +import os +import glob +from scipy.io import loadmat,savemat +import sys + +from preprocess_img import Preprocess +from load_data import * +from reconstruct_mesh import Reconstruction +from reconstruct_mesh import Reconstruction_for_render, Render_layer +import pdb +import time +import matplotlib.pyplot as plt + +def load_graph(graph_filename): + with tf.gfile.GFile(graph_filename,'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + return graph_def + +def demo_19news(n1,n2): + lm3D = load_lm3d() + n = 0 + for n in range(n1,n2): + print(n) + start = 0 + file = os.path.join('../Data',str(n),'frame%d.png'%start) + print(file) + if not os.path.exists(file[:-4]+'.txt'): + continue + img,lm = load_img(file,file[:-4]+'.txt') + input_img,lm_new,transform_params = Preprocess(img,lm,lm3D) # lm_new 5x2 + input_img = np.squeeze(input_img) + img1 = Image.fromarray(input_img[:,:,::-1]) + + scale = 0.5 * (lm[0][0]-lm[1][0]) / (lm_new[0][0]-lm_new[1][0]) + 0.5 * (lm[3][0]-lm[4][0]) / (lm_new[3][0]-lm_new[4][0]) + print(scale) + trans = np.mean(lm-lm_new*scale, axis=0) + trans = np.round(trans).astype(np.int32) + w,h = img1.size + w2 = int(round(w*scale)) + h2 = int(round(h*scale)) + img1 = img1.resize((w2,h2),resample = Image.LANCZOS) + img.paste(img1,(trans[0],trans[1],trans[0]+img1.size[0],trans[1]+img1.size[1])) + np.save(os.path.join('../Data',str(n),'transbig.npy'),np.array([w2,h2,trans[0],trans[1]])) + print(os.path.join('../Data',str(n),'transbig.npy')) + img.save('combine.png') + +if __name__ == '__main__': + demo_19news(int(sys.argv[1]),int(sys.argv[2])) diff --git a/Deep3DFaceReconstruction/extract_frame_lm.py b/Deep3DFaceReconstruction/extract_frame_lm.py new file mode 100644 index 0000000..40b6903 --- /dev/null +++ b/Deep3DFaceReconstruction/extract_frame_lm.py @@ -0,0 +1,64 @@ +import cv2 +import os, sys +import glob +import dlib +import numpy as np +import time +import pdb +detector = dlib.get_frontal_face_detector() +predictor = dlib.shape_predictor('../Deep3DFaceReconstruction/shape_predictor_68_face_landmarks.dat') + +def shape_to_np(shape, dtype="int"): + # initialize the list of (x, 
y)-coordinates + coords = np.zeros((shape.num_parts, 2), dtype=dtype) + + # loop over all facial landmarks and convert them + # to a 2-tuple of (x, y)-coordinates + for i in range(0, shape.num_parts): + coords[i] = (shape.part(i).x, shape.part(i).y) + + # return the list of (x, y)-coordinates + return coords + +def detect_image(imagename, savepath=""): + image = cv2.imread(imagename) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + rects = detector(gray, 1) + for (i, rect) in enumerate(rects): + shape = predictor(gray, rect) + shape = shape_to_np(shape) + for (x, y) in shape: + cv2.circle(image, (x, y), 1, (0, 0, 255), -1) + eyel = np.round(np.mean(shape[36:42,:],axis=0)).astype("int") + eyer = np.round(np.mean(shape[42:48,:],axis=0)).astype("int") + nose = shape[33] + mouthl = shape[48] + mouthr = shape[54] + if savepath != "": + message = '%d %d\n%d %d\n%d %d\n%d %d\n%d %d\n' % (eyel[0],eyel[1], + eyer[0],eyer[1],nose[0],nose[1], + mouthl[0],mouthl[1],mouthr[0],mouthr[1]) + with open(savepath, 'w') as s_file: + s_file.write(message) + return +def detect_dir(folder): + for file in sorted(glob.glob(folder+"/*.jpg")+glob.glob(folder+"/*.png")): + #print(file) + detect_image(imagename=file, savepath=file[:-4]+'.txt') + +videos = sorted(glob.glob('/home4/yiran/Dataset/LRW/lipread_mp4/*/*/*.mp4')) +print(len(videos),'videos') +for mp4 in videos: + savedir = 'lrw/'+mp4.split('lipread_mp4/')[1][:-4] + cap = cv2.VideoCapture(mp4) + length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + success, image = cap.read() + postfix = ".jpg" + if not os.path.exists(savedir): + os.makedirs(savedir) + count = 0 + while count=2) = 255; + %figure;imshow(trans); + trans = double(trans)/255; + im3 = double(im1).*(1-trans) + double(im2).*trans; + im3 = uint8(im3); + imwrite(im3,fullfile(tardir,[file(1:end-4),'_render_bm.png'])); +end +toc(t1)%1094.343765 seconds \ No newline at end of file diff --git a/Deep3DFaceReconstruction/output/alpha_blend_newsold.m b/Deep3DFaceReconstruction/output/alpha_blend_newsold.m new file mode 100644 index 0000000..2683fe1 --- /dev/null +++ b/Deep3DFaceReconstruction/output/alpha_blend_newsold.m @@ -0,0 +1,27 @@ +function alpha_blend_newsold(video, starti, framenum) +%video = 'Learn_English'; +%starti = 357; % choose 400, 300 for training render-to-video, 100 for testing +%framenum = 400; + +srcdir = ['render/',video,'/']; +srcdir2 = ['render/',video,'/']; +tardir = ['render/',video,'/bm/']; +files = dir(fullfile(srcdir,'*.png')); +t1=tic; +if ~exist(tardir) + mkdir(tardir); +end +for i = starti:(starti+framenum-1) + file = ['frame',num2str(i),'.png']; + im1 = imread(fullfile(srcdir,file)); + [im2,~,trans] = imread(fullfile(srcdir2,[file(1:end-4),'_render.png'])); + [B,L] = bwboundaries(trans); + %imshow(label2rgb(L,@jet,[.5,.5,.5])); + trans(L>=2) = 255; + %figure;imshow(trans); + trans = double(trans)/255; + im3 = double(im1).*(1-trans) + double(im2).*trans; + im3 = uint8(im3); + imwrite(im3,fullfile(tardir,[file(1:end-4),'_renderold_bm.png'])); +end +toc(t1)%1094.343765 seconds \ No newline at end of file diff --git a/Deep3DFaceReconstruction/preprocess_img.py b/Deep3DFaceReconstruction/preprocess_img.py new file mode 100644 index 0000000..447483b --- /dev/null +++ b/Deep3DFaceReconstruction/preprocess_img.py @@ -0,0 +1,88 @@ +import numpy as np +from scipy.io import loadmat,savemat +from PIL import Image + +#calculating least sqaures problem +def POS(xp,x): + npts = xp.shape[1] + + A = np.zeros([2*npts,8]) + + A[0:2*npts-1:2,0:3] = x.transpose() + A[0:2*npts-1:2,3] = 1 + + 
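# each landmark pair (X,Y,Z)->(xp,yp) contributes two rows: [X Y Z 1 0 0 0 0].k = xp and [0 0 0 0 X Y Z 1].k = yp, + # where k packs the scaled rotation rows and translation [s*r1, s*tx, s*r2, s*ty]; np.linalg.lstsq solves for k below and the scale s is recovered from the norms of its rotation parts + 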
A[1:2*npts:2,4:7] = x.transpose() + A[1:2*npts:2,7] = 1; + + b = np.reshape(xp.transpose(),[2*npts,1]) + + k,_,_,_ = np.linalg.lstsq(A,b,rcond=None) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx,sTy],axis = 0) + + return t,s + +def process_img(img,lm,t,s): + w0,h0 = img.size + img = img.transform(img.size, Image.AFFINE, (1, 0, t[0] - w0/2, 0, 1, h0/2 - t[1])) + w = (w0/s*102).astype(np.int32) + h = (h0/s*102).astype(np.int32) + #img = img.resize((w,h),resample = Image.BILINEAR) + img = img.resize((w,h),resample = Image.LANCZOS) + lm = np.stack([lm[:,0] - t[0] + w0/2,lm[:,1] - t[1] + h0/2],axis = 1)/s*102 + + # crop the image to 224*224 from image center + left = (w/2 - 112).astype(np.int32) + right = left + 224 + up = (h/2 - 112).astype(np.int32) + below = up + 224 + + img = img.crop((left,up,right,below)) + img = np.array(img) + img = img[:,:,::-1] + img = np.expand_dims(img,0) + lm = lm - np.reshape(np.array([(w/2 - 112),(h/2-112)]),[1,2]) + + return img,lm + + +# resize and crop input images before sending to the R-Net +def Preprocess(img,lm,lm3D): + + w0,h0 = img.size + + # change from image plane coordinates to 3D sapce coordinates(X-Y plane) + lm = np.stack([lm[:,0],h0 - 1 - lm[:,1]], axis = 1) + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks + t,s = POS(lm.transpose(),lm3D.transpose()) + + # processing the image + img_new,lm_new = process_img(img,lm,t,s) + lm_new = np.stack([lm_new[:,0],223 - lm_new[:,1]], axis = 1) + trans_params = np.array([w0,h0,102.0/s,t[0],t[1]]) + + return img_new,lm_new,trans_params + + +def Preprocess2(img,lm,trans_params): + + w0,h0 = img.size + + # change from image plane coordinates to 3D sapce coordinates(X-Y plane) + lm = np.stack([lm[:,0],h0 - 1 - lm[:,1]], axis = 1) + + # calculate translation and scale factors from 1st frame's transform params + s = 102./trans_params[2] + t = np.stack([trans_params[3],trans_params[4]],axis = 0) + + # processing the image + img_new,lm_new = process_img(img,lm,t,s) + lm_new = np.stack([lm_new[:,0],223 - lm_new[:,1]], axis = 1) + + return img_new,lm_new \ No newline at end of file diff --git a/Deep3DFaceReconstruction/reconstruct_mesh.py b/Deep3DFaceReconstruction/reconstruct_mesh.py new file mode 100644 index 0000000..3403a76 --- /dev/null +++ b/Deep3DFaceReconstruction/reconstruct_mesh.py @@ -0,0 +1,373 @@ +import numpy as np +import cv2 +import pdb + +# input: coeff with shape [1,257] +def Split_coeff(coeff): + id_coeff = coeff[:,:80] # identity(shape) coeff of dim 80 + ex_coeff = coeff[:,80:144] # expression coeff of dim 64 + tex_coeff = coeff[:,144:224] # texture(albedo) coeff of dim 80 + angles = coeff[:,224:227] # ruler angles(x,y,z) for rotation of dim 3 + gamma = coeff[:,227:254] # lighting coeff for 3 channel SH function of dim 27 + translation = coeff[:,254:] # translation coeff of dim 3 + + return id_coeff,ex_coeff,tex_coeff,angles,gamma,translation + + +# compute face shape with identity and expression coeff, based on BFM model +# input: id_coeff with shape [1,80] +# ex_coeff with shape [1,64] +# output: face_shape with shape [1,N,3], N is number of vertices +def Shape_formation(id_coeff,ex_coeff,facemodel): + face_shape = np.einsum('ij,aj->ai',facemodel.idBase,id_coeff) + \ + np.einsum('ij,aj->ai',facemodel.exBase,ex_coeff) + \ + facemodel.meanshape + + face_shape = np.reshape(face_shape,[1,-1,3]) + # re-center face shape + face_shape = face_shape - 
np.mean(np.reshape(facemodel.meanshape,[1,-1,3]), axis = 1, keepdims = True) + + return face_shape + +# compute vertex normal using one-ring neighborhood +# input: face_shape with shape [1,N,3] +# output: v_norm with shape [1,N,3] +def Compute_norm(face_shape,facemodel): + + face_id = facemodel.tri # vertex index for each triangle face, with shape [F,3], F is number of faces + point_id = facemodel.point_buf # adjacent face index for each vertex, with shape [N,8], N is number of vertex + shape = face_shape + face_id = (face_id - 1).astype(np.int32) + point_id = (point_id - 1).astype(np.int32) + v1 = shape[:,face_id[:,0],:] + v2 = shape[:,face_id[:,1],:] + v3 = shape[:,face_id[:,2],:] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = np.cross(e1,e2) # compute normal for each face + face_norm = np.concatenate([face_norm,np.zeros([1,1,3])], axis = 1) # concat face_normal with a zero vector at the end + v_norm = np.sum(face_norm[:,point_id,:], axis = 2) # compute vertex normal using one-ring neighborhood + v_norm = v_norm/np.expand_dims(np.linalg.norm(v_norm,axis = 2),2) # normalize normal vectors + + return v_norm + +# compute vertex texture(albedo) with tex_coeff +# input: tex_coeff with shape [1,N,3] +# output: face_texture with shape [1,N,3], RGB order, range from 0-255 +def Texture_formation(tex_coeff,facemodel): + + face_texture = np.einsum('ij,aj->ai',facemodel.texBase,tex_coeff) + facemodel.meantex + face_texture = np.reshape(face_texture,[1,-1,3]) + + return face_texture + +# compute rotation matrix based on 3 ruler angles +# input: angles with shape [1,3] +# output: rotation matrix with shape [1,3,3] +def Compute_rotation_matrix(angles): + + angle_x = angles[:,0][0] + angle_y = angles[:,1][0] + angle_z = angles[:,2][0] + + # compute rotation matrix for X,Y,Z axis respectively + rotation_X = np.array([1.0,0,0,\ + 0,np.cos(angle_x),-np.sin(angle_x),\ + 0,np.sin(angle_x),np.cos(angle_x)]) + rotation_Y = np.array([np.cos(angle_y),0,np.sin(angle_y),\ + 0,1,0,\ + -np.sin(angle_y),0,np.cos(angle_y)]) + rotation_Z = np.array([np.cos(angle_z),-np.sin(angle_z),0,\ + np.sin(angle_z),np.cos(angle_z),0,\ + 0,0,1]) + + rotation_X = np.reshape(rotation_X,[1,3,3]) + rotation_Y = np.reshape(rotation_Y,[1,3,3]) + rotation_Z = np.reshape(rotation_Z,[1,3,3]) + + rotation = np.matmul(np.matmul(rotation_Z,rotation_Y),rotation_X) + rotation = np.transpose(rotation, axes = [0,2,1]) #transpose row and column (dimension 1 and 2) + + return rotation + +# project 3D face onto image plane +# input: face_shape with shape [1,N,3] +# rotation with shape [1,3,3] +# translation with shape [1,3] +# output: face_projection with shape [1,N,2] +# z_buffer with shape [1,N,1] +def Projection_layer(face_shape,rotation,translation,focal=1015.0,center=112.0): # we choose the focal length and camera position empirically + + camera_pos = np.reshape(np.array([0.0,0.0,10.0]),[1,1,3]) # camera position + reverse_z = np.reshape(np.array([1.0,0,0,0,1,0,0,0,-1.0]),[1,3,3]) + + + p_matrix = np.concatenate([[focal],[0.0],[center],[0.0],[focal],[center],[0.0],[0.0],[1.0]],axis = 0) # projection matrix + p_matrix = np.reshape(p_matrix,[1,3,3]) + + # calculate face position in camera space + face_shape_r = np.matmul(face_shape,rotation) + face_shape_t = face_shape_r + np.reshape(translation,[1,1,3]) + face_shape_t = np.matmul(face_shape_t,reverse_z) + camera_pos + + # calculate projection of face vertex using perspective projection + aug_projection = np.matmul(face_shape_t,np.transpose(p_matrix,[0,2,1])) + face_projection = 
aug_projection[:,:,0:2]/np.reshape(aug_projection[:,:,2],[1,np.shape(aug_projection)[1],1]) + z_buffer = np.reshape(aug_projection[:,:,2],[1,-1,1]) + + return face_projection,z_buffer + +# compute vertex color using face_texture and SH function lighting approximation +# input: face_texture with shape [1,N,3] +# norm with shape [1,N,3] +# gamma with shape [1,27] +# output: face_color with shape [1,N,3], RGB order, range from 0-255 +# lighting with shape [1,N,3], color under uniform texture +def Illumination_layer(face_texture,norm,gamma): + + num_vertex = np.shape(face_texture)[1] + + init_lit = np.array([0.8,0,0,0,0,0,0,0,0]) + gamma = np.reshape(gamma,[-1,3,9]) + gamma = gamma + np.reshape(init_lit,[1,1,9]) + + # parameter of 9 SH function + a0 = np.pi + a1 = 2*np.pi/np.sqrt(3.0) + a2 = 2*np.pi/np.sqrt(8.0) + c0 = 1/np.sqrt(4*np.pi) + c1 = np.sqrt(3.0)/np.sqrt(4*np.pi) + c2 = 3*np.sqrt(5.0)/np.sqrt(12*np.pi) + + Y0 = np.tile(np.reshape(a0*c0,[1,1,1]),[1,num_vertex,1]) + Y1 = np.reshape(-a1*c1*norm[:,:,1],[1,num_vertex,1]) + Y2 = np.reshape(a1*c1*norm[:,:,2],[1,num_vertex,1]) + Y3 = np.reshape(-a1*c1*norm[:,:,0],[1,num_vertex,1]) + Y4 = np.reshape(a2*c2*norm[:,:,0]*norm[:,:,1],[1,num_vertex,1]) + Y5 = np.reshape(-a2*c2*norm[:,:,1]*norm[:,:,2],[1,num_vertex,1]) + Y6 = np.reshape(a2*c2*0.5/np.sqrt(3.0)*(3*np.square(norm[:,:,2])-1),[1,num_vertex,1]) + Y7 = np.reshape(-a2*c2*norm[:,:,0]*norm[:,:,2],[1,num_vertex,1]) + Y8 = np.reshape(a2*c2*0.5*(np.square(norm[:,:,0])-np.square(norm[:,:,1])),[1,num_vertex,1]) + + Y = np.concatenate([Y0,Y1,Y2,Y3,Y4,Y5,Y6,Y7,Y8],axis=2) + + # Y shape:[batch,N,9]. + + lit_r = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,0,:],2)),2) #[batch,N,9] * [batch,9,1] = [batch,N] + lit_g = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,1,:],2)),2) + lit_b = np.squeeze(np.matmul(Y,np.expand_dims(gamma[:,2,:],2)),2) + + # shape:[batch,N,3] + face_color = np.stack([lit_r*face_texture[:,:,0],lit_g*face_texture[:,:,1],lit_b*face_texture[:,:,2]],axis = 2) + lighting = np.stack([lit_r,lit_g,lit_b],axis = 2)*128 + + return face_color,lighting + +def Illumination_inv_layer(face_color,lighting): + face_texture = np.stack([face_color[:,:,0]/lighting[:,:,0],face_color[:,:,1]/lighting[:,:,1],face_color[:,:,2]/lighting[:,:,2]],axis=2)*128 + return face_texture + +# face reconstruction with coeff and BFM model +def Reconstruction(coeff,facemodel): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + # compute face shape + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + # compute vertex texture(albedo) + face_texture = Texture_formation(tex_coeff, facemodel) + # vertex normal + face_norm = Compute_norm(face_shape,facemodel) + # rotation matrix + rotation = Compute_rotation_matrix(angles) + face_norm_r = np.matmul(face_norm,rotation) + + # compute vertex projection on image plane (with image sized 224*224) + face_projection,z_buffer = Projection_layer(face_shape,rotation,translation) + face_projection = np.stack([face_projection[:,:,0],224 - face_projection[:,:,1]], axis = 2) + + # compute 68 landmark on image plane + landmarks_2d = face_projection[:,facemodel.keypoints,:] + + # compute vertex color using SH function lighting approximation + face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) + + # vertex index for each face of BFM model + tri = facemodel.tri + + return face_shape,face_texture,face_color,tri,face_projection,z_buffer,landmarks_2d + +def Reconstruction_new_given(coeff,facemodel,tex2_path): + 
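+    # Same pipeline as Reconstruction(), except that the per-vertex albedo is loaded
+    # from a precomputed texture file (tex2_path, .npy) instead of being generated
+    # from the predicted texture coefficients.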
id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + # compute face shape + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture2 = np.load(tex2_path) + # vertex normal + face_norm = Compute_norm(face_shape,facemodel) + # rotation matrix + rotation = Compute_rotation_matrix(angles) + face_norm_r = np.matmul(face_norm,rotation) + + # compute vertex projection on image plane (with image sized 224*224) + face_projection,z_buffer = Projection_layer(face_shape,rotation,translation) + face_projection = np.stack([face_projection[:,:,0],224 - face_projection[:,:,1]], axis = 2) + + # compute 68 landmark on image plane + landmarks_2d = face_projection[:,facemodel.keypoints,:] + + # compute vertex color using SH function lighting approximation + face_color,lighting = Illumination_layer(face_texture2, face_norm_r, gamma) + + # vertex index for each face of BFM model + tri = facemodel.tri + + return face_shape,face_texture2,face_color,tri,face_projection,z_buffer,landmarks_2d + +# def Reconstruction_for_render(coeff,facemodel): +# id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) +# face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) +# face_texture = Texture_formation(tex_coeff, facemodel) +# face_norm = Compute_norm(face_shape,facemodel) +# rotation = Compute_rotation_matrix(angles) +# face_shape_r = np.matmul(face_shape,rotation) +# face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) +# face_norm_r = np.matmul(face_norm,rotation) +# face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) +# tri = facemodel.face_buf + +# return face_shape_r,face_norm_r,face_color,tri + +def Reconstruction_for_render(coeff,facemodel): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture = Texture_formation(tex_coeff, facemodel) + face_norm = Compute_norm(face_shape,facemodel) + rotation = Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) + tri = facemodel.tri + + return face_shape_r,face_norm_r,face_color,tri + +def Reconstruction_for_render_new(coeff,facemodel,imgpath): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture = Texture_formation(tex_coeff, facemodel) + face_norm = Compute_norm(face_shape,facemodel) + rotation = Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color,lighting = Illumination_layer(face_texture, face_norm_r, gamma) + tri = facemodel.tri + # compute vertex projection on image plane (with image sized 224*224) + face_projection,z_buffer = Projection_layer(face_shape,rotation,translation) + face_projection = np.stack([face_projection[:,:,0],224 - face_projection[:,:,1]], axis = 2) + imcolor = cv2.imread(imgpath) + fp = face_projection.astype('int') + face_color2 = imcolor[fp[0,:,1],fp[0,:,0],::-1] + face_color2 = np.expand_dims(face_color2,0) + face_texture2 = Illumination_inv_layer(face_color2,lighting) + + return face_shape_r,face_norm_r,face_color,face_color2,face_texture2,tri,face_projection + +def 
Reconstruction_for_render_new_given(coeff,facemodel,tex2_path): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_texture2 = np.load(tex2_path) + face_norm = Compute_norm(face_shape,facemodel) + rotation = Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color2,lighting = Illumination_layer(face_texture2, face_norm_r, gamma) + tri = facemodel.tri + + return face_shape_r,face_norm_r,face_color2,tri + +def Reconstruction_for_render_new_given2(coeff,facemodel,face_texture2): + id_coeff,ex_coeff,tex_coeff,angles,gamma,translation = Split_coeff(coeff) + face_shape = Shape_formation(id_coeff, ex_coeff, facemodel) + face_norm = Compute_norm(face_shape,facemodel) + rotation = Compute_rotation_matrix(angles) + face_shape_r = np.matmul(face_shape,rotation) + face_shape_r = face_shape_r + np.reshape(translation,[1,1,3]) + face_norm_r = np.matmul(face_norm,rotation) + face_color2,lighting = Illumination_layer(face_texture2, face_norm_r, gamma) + tri = facemodel.tri + + return face_shape_r,face_norm_r,face_color2,tri + +import tensorflow as tf +#from tf_mesh_renderer import mesh_renderer +import tf_mesh_renderer.mesh_renderer.mesh_renderer as mesh_renderer +import pdb + +def Project_layer(face_shape): + + camera_position = tf.constant([0,0,10.0]) + camera_lookat = tf.constant([0,0,0.0]) + camera_up = tf.constant([0,1.0,0]) + + #pdb.set_trace() + clip_space_vertices = mesh_renderer.clip_vertices(face_shape, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + image_width = 224, + image_height = 224, + fov_y = 12.5936) + + return clip_space_vertices + +def Render_layer(face_shape,face_norm,face_color,facemodel,batchsize): + + camera_position = tf.constant([0,0,10.0]) + camera_lookat = tf.constant([0,0,0.0]) + camera_up = tf.constant([0,1.0,0]) + light_positions = tf.tile(tf.reshape(tf.constant([0,0,1e5]),[1,1,3]),[batchsize,1,1]) + light_intensities = tf.tile(tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3]),[batchsize,1,1]) + ambient_color = tf.tile(tf.reshape(tf.constant([1.0,1,1]),[1,3]),[batchsize,1]) + + #pdb.set_trace() + render = mesh_renderer.mesh_renderer(face_shape, + tf.cast(facemodel.tri-1,tf.int32), + face_norm, + face_color/255, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + light_positions = light_positions, + light_intensities = light_intensities, + image_width = 224, + image_height = 224, + fov_y = 12.5936, + ambient_color = ambient_color) + + return render + +def Render_layer2(face_shape,face_norm,face_color,facemodel,batchsize): + + camera_position = tf.constant([0,0,10.0]) + camera_lookat = tf.constant([0,0,0.0]) + camera_up = tf.constant([0,1.0,0]) + light_positions = tf.tile(tf.reshape(tf.constant([0,0,1e5]),[1,1,3]),[batchsize,1,1]) + light_intensities = tf.tile(tf.reshape(tf.constant([0.0,0.0,0.0]),[1,1,3]),[batchsize,1,1]) + ambient_color = tf.tile(tf.reshape(tf.constant([1.0,1,1]),[1,3]),[batchsize,1]) + + #pdb.set_trace() + render = mesh_renderer.mesh_renderer(face_shape, + tf.cast(facemodel.tri-1,tf.int32), + face_norm, + face_color/255, + camera_position = camera_position, + camera_lookat = camera_lookat, + camera_up = camera_up, + light_positions = light_positions, + light_intensities = light_intensities, + image_width = 256, + image_height = 256, 
+ fov_y = 12.5936, + ambient_color = ambient_color) + + return render \ No newline at end of file diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/CONTRIBUTING.md b/Deep3DFaceReconstruction/tf_mesh_renderer/CONTRIBUTING.md new file mode 100644 index 0000000..92ca112 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/CONTRIBUTING.md @@ -0,0 +1,24 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/LICENSE b/Deep3DFaceReconstruction/tf_mesh_renderer/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/README.md b/Deep3DFaceReconstruction/tf_mesh_renderer/README.md new file mode 100644 index 0000000..ba3ac98 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/README.md @@ -0,0 +1,58 @@ +This is a differentiable, 3D mesh renderer using TensorFlow. + +This is not an official Google product. + +The interface to the renderer is provided by mesh_renderer.py and +rasterize_triangles.py, which provide TensorFlow Ops that can be added to a +TensorFlow graph. The internals of the renderer are handled by a C++ kernel. + +The input to the C++ rendering kernel is a list of 3D vertices and a list of +triangles, where a triangle consists of a list of three vertex ids. The +output of the renderer is a pair of images containing triangle ids and +barycentric weights. Pixel values in the barycentric weight image are the +weights of the pixel center point with respect to the triangle at that pixel +(identified by the triangle id). 
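+For example, a per-vertex attribute (color, normal, UV) can be interpolated at
+each covered pixel by gathering the triangle's three vertices and taking the
+barycentric-weighted sum. A minimal NumPy sketch (the array names are hypothetical,
+not part of this renderer's API):
+
+```python
+import numpy as np
+
+# attributes:   [num_vertices, C]   per-vertex attribute values
+# triangles:    [num_triangles, 3]  vertex ids per triangle
+# triangle_ids: [H, W]              triangle id rendered at each pixel
+# barycentric:  [H, W, 3]           barycentric weights (all zero outside the mesh)
+def interpolate_attribute(attributes, triangles, triangle_ids, barycentric):
+    corner_ids = triangles[triangle_ids]        # [H, W, 3]
+    corner_attrs = attributes[corner_ids]       # [H, W, 3, C]
+    # Weighted sum over the three corners; background pixels stay zero because
+    # their barycentric weights are zero.
+    return np.einsum('hwk,hwkc->hwc', barycentric, corner_attrs)
+```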
The renderer provides derivatives of the +barycentric weights of the pixel centers with respect to the vertex +positions. + +Any approximation error stems from the assumption that the triangle id at a +pixel does not change as the vertices are moved. This is a reasonable +approximation for small changes in vertex position. Even when the triangle id +does change, the derivatives will be computed by extrapolating the barycentric +weights of a neighboring triangle, which will produce a good approximation if +the mesh is smooth. The main source of error occurs at occlusion boundaries, and +particularly at the edge of an open mesh, where the background appears opposite +the triangle's edge. + +The algorithm implemented is described by Olano and Greer, "Triangle Scan +Conversion using 2D Homogeneous Coordinates," HWWS 1997. + +How to Build +------------ + +Follow the instructions to [install TensorFlow using virtualenv](https://www.tensorflow.org/install/install_linux#installing_with_virtualenv). + +Build and run tests using Bazel from inside the (tensorflow) virtualenv: + +`(tensorflow)$ ./runtests.sh` + +The script calls the Bazel rules using the Python interpreter at +`$VIRTUAL_ENV/bin/python`. If you aren't using virtualenv, `bazel test ...` may +be sufficient. + +Citation +-------- + +If you use this renderer in your research, please cite [this paper](http://openaccess.thecvf.com/content_cvpr_2018/html/Genova_Unsupervised_Training_for_CVPR_2018_paper.html "CVF Version"): + +*Unsupervised Training for 3D Morphable Model Regression*. Kyle Genova, Forrester Cole, Aaron Maschinot, Aaron Sarna, Daniel Vlasic, and William T. Freeman. CVPR 2018, pp. 8377-8386. + +``` +@InProceedings{Genova_2018_CVPR, + author = {Genova, Kyle and Cole, Forrester and Maschinot, Aaron and Sarna, Aaron and Vlasic, Daniel and Freeman, William T.}, + title = {Unsupervised Training for 3D Morphable Model Regression}, + booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2018} +} +``` diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/WORKSPACE b/Deep3DFaceReconstruction/tf_mesh_renderer/WORKSPACE new file mode 100644 index 0000000..c02183a --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/WORKSPACE @@ -0,0 +1,10 @@ +workspace(name = "tf_mesh_renderer") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +# GoogleTest/GoogleMock framework. Used by most unit-tests. 
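+# The archive below tracks the head of googletest's master branch (no pinned release).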
+http_archive( + name = "com_google_googletest", + urls = ["https://github.com/google/googletest/archive/master.zip"], + strip_prefix = "googletest-master", +) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/BUILD b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/BUILD new file mode 100644 index 0000000..2fad828 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/BUILD @@ -0,0 +1,54 @@ +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "mesh_renderer", + srcs = ["mesh_renderer.py"], + deps = [ + ":rasterize_triangles", + ], +) + +py_test( + name = "mesh_renderer_test", + size = "medium", + srcs = ["mesh_renderer_test.py"], + data = [ + "//mesh_renderer/test_data:images", + "//mesh_renderer/kernels:rasterize_triangles_kernel", + ], + deps = [ + ":mesh_renderer", + ], +) + +py_library( + name = "camera_utils", + srcs = ["camera_utils.py"], +) + +py_library( + name = "test_utils", + srcs = ["test_utils.py"], +) + +py_library( + name = "rasterize_triangles", + srcs = ["rasterize_triangles.py"], + deps = [ + ":camera_utils", + ], +) + +py_test( + name = "rasterize_triangles_test", + srcs = ["rasterize_triangles_test.py"], + data = [ + "//mesh_renderer/test_data:images", + "//mesh_renderer/kernels:rasterize_triangles_kernel", + ], + deps = [ + ":camera_utils", + ":rasterize_triangles", + ":test_utils", + ], +) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/camera_utils.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/camera_utils.py new file mode 100644 index 0000000..f28c555 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/camera_utils.py @@ -0,0 +1,183 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Collection of TF functions for managing 3D camera matrices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import tensorflow as tf + + +def perspective(aspect_ratio, fov_y, near_clip, far_clip): + """Computes perspective transformation matrices. + + Functionality mimes gluPerspective (third_party/GL/glu/include/GLU/glu.h). + + Args: + aspect_ratio: float value specifying the image aspect ratio (width/height). + fov_y: 1-D float32 Tensor with shape [batch_size] specifying output vertical + field of views in degrees. + near_clip: 1-D float32 Tensor with shape [batch_size] specifying near + clipping plane distance. + far_clip: 1-D float32 Tensor with shape [batch_size] specifying far clipping + plane distance. + + Returns: + A [batch_size, 4, 4] float tensor that maps from right-handed points in eye + space to left-handed points in clip space. + """ + # The multiplication of fov_y by pi/360.0 simultaneously converts to radians + # and adds the half-angle factor of .5. 
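+  # i.e. focal_lengths_y = cot(fov_y / 2); p_22 and p_23 below map eye-space z into
+  # the [-1, 1] clip-space range, following the OpenGL convention.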
+ focal_lengths_y = 1.0 / tf.tan(fov_y * (math.pi / 360.0)) + depth_range = far_clip - near_clip + p_22 = -(far_clip + near_clip) / depth_range + p_23 = -2.0 * (far_clip * near_clip / depth_range) + + zeros = tf.zeros_like(p_23, dtype=tf.float32) + # pyformat: disable + perspective_transform = tf.concat( + [ + focal_lengths_y / aspect_ratio, zeros, zeros, zeros, + zeros, focal_lengths_y, zeros, zeros, + zeros, zeros, p_22, p_23, + zeros, zeros, -tf.ones_like(p_23, dtype=tf.float32), zeros + ], axis=0) + # pyformat: enable + perspective_transform = tf.reshape(perspective_transform, [4, 4, -1]) + return tf.transpose(perspective_transform, [2, 0, 1]) + + +def look_at(eye, center, world_up): + """Computes camera viewing matrices. + + Functionality mimes gluLookAt (third_party/GL/glu/include/GLU/glu.h). + + Args: + eye: 2-D float32 tensor with shape [batch_size, 3] containing the XYZ world + space position of the camera. + center: 2-D float32 tensor with shape [batch_size, 3] containing a position + along the center of the camera's gaze. + world_up: 2-D float32 tensor with shape [batch_size, 3] specifying the + world's up direction; the output camera will have no tilt with respect + to this direction. + + Returns: + A [batch_size, 4, 4] float tensor containing a right-handed camera + extrinsics matrix that maps points from world space to points in eye space. + """ + batch_size = center.shape[0].value + vector_degeneracy_cutoff = 1e-6 + forward = center - eye + forward_norm = tf.norm(forward, ord='euclidean', axis=1, keepdims=True) + #tf.assert_greater( + # forward_norm, + # vector_degeneracy_cutoff, + # message='Camera matrix is degenerate because eye and center are close.') + forward = tf.divide(forward, forward_norm) + + to_side = tf.linalg.cross(forward, world_up) + to_side_norm = tf.norm(to_side, ord='euclidean', axis=1, keepdims=True) + #tf.assert_greater( + # to_side_norm, + # vector_degeneracy_cutoff, + # message='Camera matrix is degenerate because up and gaze are close or' + # 'because up is degenerate.') + to_side = tf.divide(to_side, to_side_norm) + cam_up = tf.linalg.cross(to_side, forward) + + w_column = tf.constant( + batch_size * [[0., 0., 0., 1.]], dtype=tf.float32) # [batch_size, 4] + w_column = tf.reshape(w_column, [batch_size, 4, 1]) + view_rotation = tf.stack( + [to_side, cam_up, -forward, + tf.zeros_like(to_side, dtype=tf.float32)], + axis=1) # [batch_size, 4, 3] matrix + view_rotation = tf.concat( + [view_rotation, w_column], axis=2) # [batch_size, 4, 4] + + identity_batch = tf.tile(tf.expand_dims(tf.eye(3), 0), [batch_size, 1, 1]) + view_translation = tf.concat([identity_batch, tf.expand_dims(-eye, 2)], 2) + view_translation = tf.concat( + [view_translation, + tf.reshape(w_column, [batch_size, 1, 4])], 1) + camera_matrices = tf.matmul(view_rotation, view_translation) + return camera_matrices + + +def euler_matrices(angles): + """Computes a XYZ Tait-Bryan (improper Euler angle) rotation. + + Returns 4x4 matrices for convenient multiplication with other transformations. + + Args: + angles: a [batch_size, 3] tensor containing X, Y, and Z angles in radians. + + Returns: + a [batch_size, 4, 4] tensor of matrices. + """ + s = tf.sin(angles) + c = tf.cos(angles) + # Rename variables for readability in the matrix definition below. 
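+  # Indices 0/1/2 correspond to the X/Y/Z angles; the flattened matrix below is
+  # R = Rz @ Ry @ Rx (the X rotation is applied first), padded to a 4x4
+  # homogeneous transform.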
+ c0, c1, c2 = (c[:, 0], c[:, 1], c[:, 2]) + s0, s1, s2 = (s[:, 0], s[:, 1], s[:, 2]) + + zeros = tf.zeros_like(s[:, 0]) + ones = tf.ones_like(s[:, 0]) + + # pyformat: disable + flattened = tf.concat( + [ + c2 * c1, c2 * s1 * s0 - c0 * s2, s2 * s0 + c2 * c0 * s1, zeros, + c1 * s2, c2 * c0 + s2 * s1 * s0, c0 * s2 * s1 - c2 * s0, zeros, + -s1, c1 * s0, c1 * c0, zeros, + zeros, zeros, zeros, ones + ], + axis=0) + # pyformat: enable + reshaped = tf.reshape(flattened, [4, 4, -1]) + return tf.transpose(reshaped, [2, 0, 1]) + + +def transform_homogeneous(matrices, vertices): + """Applies batched 4x4 homogenous matrix transformations to 3-D vertices. + + The vertices are input and output as as row-major, but are interpreted as + column vectors multiplied on the right-hand side of the matrices. More + explicitly, this function computes (MV^T)^T. + Vertices are assumed to be xyz, and are extended to xyzw with w=1. + + Args: + matrices: a [batch_size, 4, 4] tensor of matrices. + vertices: a [batch_size, N, 3] tensor of xyz vertices. + + Returns: + a [batch_size, N, 4] tensor of xyzw vertices. + + Raises: + ValueError: if matrices or vertices have the wrong number of dimensions. + """ + if len(matrices.shape) != 3: + raise ValueError( + 'matrices must have 3 dimensions (missing batch dimension?)') + if len(vertices.shape) != 3: + raise ValueError( + 'vertices must have 3 dimensions (missing batch dimension?)') + homogeneous_coord = tf.ones( + [tf.shape(vertices)[0], tf.shape(vertices)[1], 1], dtype=tf.float32) + vertices_homogeneous = tf.concat([vertices, homogeneous_coord], 2) + + return tf.matmul(vertices_homogeneous, matrices, transpose_b=True) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/BUILD b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/BUILD new file mode 100644 index 0000000..9206aef --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/BUILD @@ -0,0 +1,31 @@ +cc_library( + name = "rasterize_triangles_impl", + srcs = ["rasterize_triangles_impl.cc"], + hdrs = ["rasterize_triangles_impl.h"], +) + +cc_test( + name = "rasterize_triangles_impl_test", + srcs = ["rasterize_triangles_impl_test.cc"], + data = [ + "//mesh_renderer/test_data:images", + ], + deps = [ + ":rasterize_triangles_impl", + "//third_party:lodepng", + "@com_google_googletest//:gtest_main", + ], +) + +genrule( + name = "rasterize_triangles_kernel", + srcs = ["rasterize_triangles_grad.cc", + "rasterize_triangles_op.cc", + "rasterize_triangles_impl.cc", + "rasterize_triangles_impl.h"], + outs = ["rasterize_triangles_kernel.so"], + cmd = "TF_INC=$$($(PYTHON) -c 'import tensorflow as tf; print(tf.sysconfig.get_include())');\ + TF_LIB=$$($(PYTHON) -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())'); \ + g++ -std=c++11 -shared $(SRCS) -o $@ -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -I$$TF_INC -I$$TF_INC/external/nsync/public -L$$TF_LIB -ltensorflow_framework -O2", + visibility = ["//mesh_renderer:__subpackages__"], +) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/librasterize_triangles_impl.so b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/librasterize_triangles_impl.so new file mode 100755 index 0000000..b500ee7 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/librasterize_triangles_impl.so differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_grad.cc 
b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_grad.cc new file mode 100644 index 0000000..d09867c --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_grad.cc @@ -0,0 +1,236 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace { + +// Threshold for a barycentric coordinate triplet's sum, below which the +// coordinates at a pixel are deemed degenerate. Most such degenerate triplets +// in an image will be exactly zero, as this is how pixels outside the mesh +// are rendered. +constexpr float kDegenerateBarycentricCoordinatesCutoff = 0.9f; + +// If the area of a triangle is very small in screen space, the corner vertices +// are approaching colinearity, and we should drop the gradient to avoid +// numerical instability (in particular, blowup, as the forward pass computation +// already only has 8 bits of precision). +constexpr float kMinimumTriangleArea = 1e-13; + +} // namespace + +namespace tf_mesh_renderer { + + using ::tensorflow::DEVICE_CPU; + using ::tensorflow::OpKernel; + using ::tensorflow::OpKernelConstruction; + using ::tensorflow::OpKernelContext; + using ::tensorflow::PartialTensorShape; + using ::tensorflow::Status; + using ::tensorflow::Tensor; + using ::tensorflow::TensorShape; + using ::tensorflow::errors::InvalidArgument; + + REGISTER_OP("RasterizeTrianglesGrad") + .Input("vertices: float32") + .Input("triangles: int32") + .Input("barycentric_coordinates: float32") + .Input("triangle_ids: int32") + .Input("df_dbarycentric_coordinates: float32") + .Attr("image_width: int") + .Attr("image_height: int") + .Output("df_dvertices: float32"); + + class RasterizeTrianglesGradOp : public OpKernel { + public: + explicit RasterizeTrianglesGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("image_width", &image_width_)); + OP_REQUIRES(context, image_width_ > 0, + InvalidArgument("Image width must be > 0, got ", image_width_)); + + OP_REQUIRES_OK(context, context->GetAttr("image_height", &image_height_)); + OP_REQUIRES( + context, image_height_ > 0, + InvalidArgument("Image height must be > 0, got ", image_height_)); + } + + ~RasterizeTrianglesGradOp() override {} + + void Compute(OpKernelContext* context) override { + const Tensor& vertices_tensor = context->input(0); + OP_REQUIRES( + context, + PartialTensorShape({-1, 4}).IsCompatibleWith(vertices_tensor.shape()), + InvalidArgument( + "RasterizeTrianglesGrad expects vertices to have shape (-1, 4).")); + auto vertices_flat = vertices_tensor.flat(); + const unsigned int vertex_count = vertices_flat.size() / 4; + const float* vertices = vertices_flat.data(); + + const Tensor& triangles_tensor = context->input(1); + OP_REQUIRES( + context, + PartialTensorShape({-1, 3}).IsCompatibleWith(triangles_tensor.shape()), + InvalidArgument( 
+ "RasterizeTrianglesGrad expects triangles to be a matrix.")); + auto triangles_flat = triangles_tensor.flat(); + const int* triangles = triangles_flat.data(); + + const Tensor& barycentric_coordinates_tensor = context->input(2); + OP_REQUIRES(context, + TensorShape({image_height_, image_width_, 3}) == + barycentric_coordinates_tensor.shape(), + InvalidArgument( + "RasterizeTrianglesGrad expects barycentric_coordinates to " + "have shape {image_height, image_width, 3}")); + auto barycentric_coordinates_flat = + barycentric_coordinates_tensor.flat(); + const float* barycentric_coordinates = barycentric_coordinates_flat.data(); + + const Tensor& triangle_ids_tensor = context->input(3); + OP_REQUIRES( + context, + TensorShape({image_height_, image_width_}) == + triangle_ids_tensor.shape(), + InvalidArgument( + "RasterizeTrianglesGrad expected triangle_ids to have shape " + " {image_height, image_width}")); + auto triangle_ids_flat = triangle_ids_tensor.flat(); + const int* triangle_ids = triangle_ids_flat.data(); + + // The naming convention we use for all derivatives is d_d -> + // the partial of y with respect to x. + const Tensor& df_dbarycentric_coordinates_tensor = context->input(4); + OP_REQUIRES( + context, + TensorShape({image_height_, image_width_, 3}) == + df_dbarycentric_coordinates_tensor.shape(), + InvalidArgument( + "RasterizeTrianglesGrad expects df_dbarycentric_coordinates " + "to have shape {image_height, image_width, 3}")); + auto df_dbarycentric_coordinates_flat = + df_dbarycentric_coordinates_tensor.flat(); + const float* df_dbarycentric_coordinates = + df_dbarycentric_coordinates_flat.data(); + + Tensor* df_dvertices_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({vertex_count, 4}), + &df_dvertices_tensor)); + auto df_dvertices_flat = df_dvertices_tensor->flat(); + float* df_dvertices = df_dvertices_flat.data(); + std::fill(df_dvertices, df_dvertices + vertex_count * 4, 0.0f); + + // We first loop over each pixel in the output image, and compute + // dbarycentric_coordinate[0,1,2]/dvertex[0x, 0y, 1x, 1y, 2x, 2y]. + // Next we compute each value above's contribution to + // df/dvertices, building up that matrix as the output of this iteration. + for (unsigned int pixel_id = 0; pixel_id < image_height_ * image_width_; + ++pixel_id) { + // b0, b1, and b2 are the three barycentric coordinate values + // rendered at pixel pixel_id. + const float b0 = barycentric_coordinates[3 * pixel_id]; + const float b1 = barycentric_coordinates[3 * pixel_id + 1]; + const float b2 = barycentric_coordinates[3 * pixel_id + 2]; + + if (b0 + b1 + b2 < kDegenerateBarycentricCoordinatesCutoff) { + continue; + } + + const float df_db0 = df_dbarycentric_coordinates[3 * pixel_id]; + const float df_db1 = df_dbarycentric_coordinates[3 * pixel_id + 1]; + const float df_db2 = df_dbarycentric_coordinates[3 * pixel_id + 2]; + + const int triangle_at_current_pixel = triangle_ids[pixel_id]; + const int* vertices_at_current_pixel = + &triangles[3 * triangle_at_current_pixel]; + + // Extract vertex indices for the current triangle. + const int v0_id = 4 * vertices_at_current_pixel[0]; + const int v1_id = 4 * vertices_at_current_pixel[1]; + const int v2_id = 4 * vertices_at_current_pixel[2]; + + // Extract x,y,w components of the vertices' clip space coordinates. 
+ const float x0 = vertices[v0_id]; + const float y0 = vertices[v0_id + 1]; + const float w0 = vertices[v0_id + 3]; + const float x1 = vertices[v1_id]; + const float y1 = vertices[v1_id + 1]; + const float w1 = vertices[v1_id + 3]; + const float x2 = vertices[v2_id]; + const float y2 = vertices[v2_id + 1]; + const float w2 = vertices[v2_id + 3]; + + // Compute pixel's NDC-s. + const int ix = pixel_id % image_width_; + const int iy = pixel_id / image_width_; + const float px = 2 * (ix + 0.5f) / image_width_ - 1.0f; + const float py = 2 * (iy + 0.5f) / image_height_ - 1.0f; + + // Baricentric gradients wrt each vertex coordinate share a common factor. + const float db0_dx = py * (w1 - w2) - (y1 - y2); + const float db1_dx = py * (w2 - w0) - (y2 - y0); + const float db2_dx = -(db0_dx + db1_dx); + const float db0_dy = (x1 - x2) - px * (w1 - w2); + const float db1_dy = (x2 - x0) - px * (w2 - w0); + const float db2_dy = -(db0_dy + db1_dy); + const float db0_dw = px * (y1 - y2) - py * (x1 - x2); + const float db1_dw = px * (y2 - y0) - py * (x2 - x0); + const float db2_dw = -(db0_dw + db1_dw); + + // Combine them with chain rule. + const float df_dx = df_db0 * db0_dx + df_db1 * db1_dx + df_db2 * db2_dx; + const float df_dy = df_db0 * db0_dy + df_db1 * db1_dy + df_db2 * db2_dy; + const float df_dw = df_db0 * db0_dw + df_db1 * db1_dw + df_db2 * db2_dw; + + // Values of edge equations and inverse w at the current pixel. + const float edge0_over_w = x2 * db0_dx + y2 * db0_dy + w2 * db0_dw; + const float edge1_over_w = x2 * db1_dx + y2 * db1_dy + w2 * db1_dw; + const float edge2_over_w = x1 * db2_dx + y1 * db2_dy + w1 * db2_dw; + const float w_inv = edge0_over_w + edge1_over_w + edge2_over_w; + + // All gradients share a common denominator. + const float w_sqr = 1 / (w_inv * w_inv); + + // Gradients wrt each vertex share a common factor. + const float edge0 = w_sqr * edge0_over_w; + const float edge1 = w_sqr * edge1_over_w; + const float edge2 = w_sqr * edge2_over_w; + + df_dvertices[v0_id + 0] += edge0 * df_dx; + df_dvertices[v0_id + 1] += edge0 * df_dy; + df_dvertices[v0_id + 3] += edge0 * df_dw; + df_dvertices[v1_id + 0] += edge1 * df_dx; + df_dvertices[v1_id + 1] += edge1 * df_dy; + df_dvertices[v1_id + 3] += edge1 * df_dw; + df_dvertices[v2_id + 0] += edge2 * df_dx; + df_dvertices[v2_id + 1] += edge2 * df_dy; + df_dvertices[v2_id + 3] += edge2 * df_dw; + } + } + + private: + int image_width_; + int image_height_; + }; + + REGISTER_KERNEL_BUILDER(Name("RasterizeTrianglesGrad").Device(DEVICE_CPU), + RasterizeTrianglesGradOp); + +} // namespace tf_mesh_renderer diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.cc b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.cc new file mode 100644 index 0000000..b7f34ef --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.cc @@ -0,0 +1,201 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "rasterize_triangles_impl.h" + +namespace tf_mesh_renderer { + +namespace { + +// Takes the minimum of a, b, and c, rounds down, and converts to an integer +// in the range [low, high]. +inline int ClampedIntegerMin(float a, float b, float c, int low, int high) { + return std::min( + std::max(static_cast(std::floor(std::min(std::min(a, b), c))), low), + high); +} + +// Takes the maximum of a, b, and c, rounds up, and converts to an integer +// in the range [low, high]. +inline int ClampedIntegerMax(float a, float b, float c, int low, int high) { + return std::min( + std::max(static_cast(std::ceil(std::max(std::max(a, b), c))), low), + high); +} + +// Computes a 3x3 matrix inverse without dividing by the determinant. +// Instead, makes an unnormalized matrix inverse with the correct sign +// by flipping the sign of the matrix if the determinant is negative. +// By leaving out determinant division, the rows of M^-1 only depend on two out +// of three of the columns of M; i.e., the first row of M^-1 only depends on the +// second and third columns of M, the second only depends on the first and +// third, etc. This means we can compute edge functions for two neighboring +// triangles independently and produce exactly the same numerical result up to +// the sign. This in turn means we can avoid cracks in rasterization without +// using fixed-point arithmetic. +// See http://mathworld.wolfram.com/MatrixInverse.html +void ComputeUnnormalizedMatrixInverse(const float a11, const float a12, + const float a13, const float a21, + const float a22, const float a23, + const float a31, const float a32, + const float a33, float m_inv[9]) { + m_inv[0] = a22 * a33 - a32 * a23; + m_inv[1] = a13 * a32 - a33 * a12; + m_inv[2] = a12 * a23 - a22 * a13; + m_inv[3] = a23 * a31 - a33 * a21; + m_inv[4] = a11 * a33 - a31 * a13; + m_inv[5] = a13 * a21 - a23 * a11; + m_inv[6] = a21 * a32 - a31 * a22; + m_inv[7] = a12 * a31 - a32 * a11; + m_inv[8] = a11 * a22 - a21 * a12; + + // The first column of the unnormalized M^-1 contains intermediate values for + // det(M). + const float det = a11 * m_inv[0] + a12 * m_inv[3] + a13 * m_inv[6]; + + // Transfer the sign of the determinant. + if (det < 0.0f) { + for (int i = 0; i < 9; ++i) { + m_inv[i] = -m_inv[i]; + } + } +} + +// Computes the edge functions from M^-1 as described by Olano and Greer, +// "Triangle Scan Conversion using 2D Homogeneous Coordinates." +// +// This function combines equations (3) and (4). It first computes +// [a b c] = u_i * M^-1, where u_0 = [1 0 0], u_1 = [0 1 0], etc., +// then computes edge_i = aX + bY + c +void ComputeEdgeFunctions(const float px, const float py, const float m_inv[9], + float values[3]) { + for (int i = 0; i < 3; ++i) { + const float a = m_inv[3 * i + 0]; + const float b = m_inv[3 * i + 1]; + const float c = m_inv[3 * i + 2]; + + values[i] = a * px + b * py + c; + } +} + +// Determines whether the point p lies inside a front-facing triangle. +// Counts pixels exactly on an edge as inside the triangle, as long as the +// triangle is not degenerate. Degenerate (zero-area) triangles always fail the +// inside test. +bool PixelIsInsideTriangle(const float edge_values[3]) { + // Check that the edge values are all non-negative and that at least one is + // positive (triangle is non-degenerate). 
+ return (edge_values[0] >= 0 && edge_values[1] >= 0 && edge_values[2] >= 0) && + (edge_values[0] > 0 || edge_values[1] > 0 || edge_values[2] > 0); +} + +} // namespace + +void RasterizeTrianglesImpl(const float* vertices, const int32* triangles, + int32 triangle_count, int32 image_width, + int32 image_height, int32* triangle_ids, + float* barycentric_coordinates, float* z_buffer) { + const float half_image_width = 0.5 * image_width; + const float half_image_height = 0.5 * image_height; + float unnormalized_matrix_inverse[9]; + float b_over_w[3]; + + for (int32 triangle_id = 0; triangle_id < triangle_count; ++triangle_id) { + const int32 v0_x_id = 4 * triangles[3 * triangle_id]; + const int32 v1_x_id = 4 * triangles[3 * triangle_id + 1]; + const int32 v2_x_id = 4 * triangles[3 * triangle_id + 2]; + + const float v0w = vertices[v0_x_id + 3]; + const float v1w = vertices[v1_x_id + 3]; + const float v2w = vertices[v2_x_id + 3]; + // Early exit: if all w < 0, triangle is entirely behind the eye. + if (v0w < 0 && v1w < 0 && v2w < 0) { + continue; + } + + const float v0x = vertices[v0_x_id]; + const float v0y = vertices[v0_x_id + 1]; + const float v1x = vertices[v1_x_id]; + const float v1y = vertices[v1_x_id + 1]; + const float v2x = vertices[v2_x_id]; + const float v2y = vertices[v2_x_id + 1]; + + ComputeUnnormalizedMatrixInverse(v0x, v1x, v2x, v0y, v1y, v2y, v0w, v1w, + v2w, unnormalized_matrix_inverse); + + // Initialize the bounding box to the entire screen. + int left = 0, right = image_width, bottom = 0, top = image_height; + // If the triangle is entirely inside the screen, project the vertices to + // pixel coordinates and find the triangle bounding box enlarged to the + // nearest integer and clamped to the image boundaries. + if (v0w > 0 && v1w > 0 && v2w > 0) { + const float p0x = (v0x / v0w + 1.0) * half_image_width; + const float p1x = (v1x / v1w + 1.0) * half_image_width; + const float p2x = (v2x / v2w + 1.0) * half_image_width; + const float p0y = (v0y / v0w + 1.0) * half_image_height; + const float p1y = (v1y / v1w + 1.0) * half_image_height; + const float p2y = (v2y / v2w + 1.0) * half_image_height; + left = ClampedIntegerMin(p0x, p1x, p2x, 0, image_width); + right = ClampedIntegerMax(p0x, p1x, p2x, 0, image_width); + bottom = ClampedIntegerMin(p0y, p1y, p2y, 0, image_height); + top = ClampedIntegerMax(p0y, p1y, p2y, 0, image_height); + } + + // Iterate over each pixel in the bounding box. + for (int iy = bottom; iy < top; ++iy) { + for (int ix = left; ix < right; ++ix) { + const float px = ((ix + 0.5) / half_image_width) - 1.0; + const float py = ((iy + 0.5) / half_image_height) - 1.0; + const int pixel_idx = iy * image_width + ix; + + ComputeEdgeFunctions(px, py, unnormalized_matrix_inverse, b_over_w); + if (!PixelIsInsideTriangle(b_over_w)) { + continue; + } + + const float one_over_w = b_over_w[0] + b_over_w[1] + b_over_w[2]; + const float b0 = b_over_w[0] / one_over_w; + const float b1 = b_over_w[1] / one_over_w; + const float b2 = b_over_w[2] / one_over_w; + + const float v0z = vertices[v0_x_id + 2]; + const float v1z = vertices[v1_x_id + 2]; + const float v2z = vertices[v2_x_id + 2]; + // Since we computed an unnormalized w above, we need to recompute + // a properly scaled clip-space w value and then divide clip-space z + // by that. 
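+        // z_ndc = (b0*z0 + b1*z1 + b2*z2) / (b0*w0 + b1*w1 + b2*w2), which is then
+        // depth-tested against the running z-buffer below.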
+ const float clip_z = b0 * v0z + b1 * v1z + b2 * v2z; + const float clip_w = b0 * v0w + b1 * v1w + b2 * v2w; + const float z = clip_z / clip_w; + + // Skip the pixel if it is farther than the current z-buffer pixel or + // beyond the near or far clipping plane. + if (z < -1.0 || z > 1.0 || z > z_buffer[pixel_idx]) { + continue; + } + + triangle_ids[pixel_idx] = triangle_id; + z_buffer[pixel_idx] = z; + barycentric_coordinates[3 * pixel_idx + 0] = b0; + barycentric_coordinates[3 * pixel_idx + 1] = b1; + barycentric_coordinates[3 * pixel_idx + 2] = b2; + } + } + } +} + +} // namespace tf_mesh_renderer diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.h b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.h new file mode 100644 index 0000000..a714330 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl.h @@ -0,0 +1,56 @@ +// Copyright 2017 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MESH_RENDERER_KERNELS_RASTERIZE_TRIANGLES_IMPL_H_ +#define MESH_RENDERER_KERNELS_RASTERIZE_TRIANGLES_IMPL_H_ + +namespace tf_mesh_renderer { + +// Copied from tensorflow/core/platform/default/integral_types.h +// to avoid making this file depend on tensorflow. +typedef int int32; +typedef long long int64; + +// Computes the triangle id, barycentric coordinates, and z-buffer at each pixel +// in the image. +// +// vertices: A flattened 2D array with 4*vertex_count elements. +// Each contiguous triplet is the XYZW location of the vertex with that +// triplet's id. The coordinates are assumed to be OpenGL-style clip-space +// (i.e., post-projection, pre-divide), where X points right, Y points up, +// Z points away. +// triangles: A flattened 2D array with 3*triangle_count elements. +// Each contiguous triplet is the three vertex ids indexing into vertices +// describing one triangle with clockwise winding. +// triangle_count: The number of triangles stored in the array triangles. +// triangle_ids: A flattened 2D array with image_height*image_width elements. +// At return, each pixel contains a triangle id in the range +// [0, triangle_count). The id value is also 0 if there is no triangle +// at the pixel. The barycentric_coordinates must be checked to +// distinguish the two cases. +// barycentric_coordinates: A flattened 3D array with +// image_height*image_width*3 elements. At return, contains the triplet of +// barycentric coordinates at each pixel in the same vertex ordering as +// triangles. If no triangle is present, all coordinates are 0. +// z_buffer: A flattened 2D array with image_height*image_width elements. At +// return, contains the normalized device Z coordinates of the rendered +// triangles. 
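+//
+// A minimal calling sketch (mirroring the unit tests and the op kernel that
+// follow): for a single triangle, vertices holds 3 * 4 clip-space floats,
+// triangles holds {0, 1, 2}, triangle_count is 1, and the caller pre-allocates
+// triangle_ids and z_buffer with image_height * image_width entries (z_buffer
+// cleared to 1.0) and barycentric_coordinates with
+// image_height * image_width * 3 entries; this function only writes into
+// those buffers.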
+void RasterizeTrianglesImpl(const float* vertices, const int32* triangles,
+                            int32 triangle_count, int32 image_width,
+                            int32 image_height, int32* triangle_ids,
+                            float* barycentric_coordinates, float* z_buffer);
+
+}  // namespace tf_mesh_renderer
+
+#endif  // MESH_RENDERER_KERNELS_RASTERIZE_TRIANGLES_IMPL_H_
diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl_test.cc b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl_test.cc
new file mode 100644
index 0000000..61bb028
--- /dev/null
+++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_impl_test.cc
@@ -0,0 +1,254 @@
+// Copyright 2017 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdlib>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "rasterize_triangles_impl.h"
+
+#include "third_party/lodepng.h"
+
+namespace tf_mesh_renderer {
+namespace {
+
+typedef unsigned char uint8;
+
+const int kImageHeight = 480;
+const int kImageWidth = 640;
+
+std::string GetRunfilesRelativePath(const std::string& filename) {
+  const std::string srcdir = std::getenv("TEST_SRCDIR");
+  const std::string test_data = "/tf_mesh_renderer/mesh_renderer/test_data/";
+  return srcdir + test_data + filename;
+}
+
+void LoadPng(const std::string& filename, std::vector<uint8>* output) {
+  unsigned width, height;
+  unsigned error = lodepng::decode(*output, width, height, filename.c_str());
+  ASSERT_TRUE(error == 0) << "Decoder error: " << lodepng_error_text(error);
+}
+
+void SavePng(const std::string& filename, const std::vector<uint8>& image) {
+  unsigned error =
+      lodepng::encode(filename.c_str(), image, kImageWidth, kImageHeight);
+  ASSERT_TRUE(error == 0) << "Encoder error: " << lodepng_error_text(error);
+}
+
+void FloatRGBToUint8RGBA(const std::vector<float>& input,
+                         std::vector<uint8>* output) {
+  output->resize(kImageHeight * kImageWidth * 4);
+  for (int y = 0; y < kImageHeight; ++y) {
+    for (int x = 0; x < kImageWidth; ++x) {
+      for (int c = 0; c < 3; ++c) {
+        (*output)[(y * kImageWidth + x) * 4 + c] =
+            input[(y * kImageWidth + x) * 3 + c] * 255;
+      }
+      (*output)[(y * kImageWidth + x) * 4 + 3] = 255;
+    }
+  }
+}
+
+void ExpectImageFileAndImageAreEqual(const std::string& baseline_file,
+                                     const std::vector<float>& result,
+                                     const std::string& comparison_name,
+                                     const std::string& failure_message) {
+  std::vector<uint8> baseline_rgba, result_rgba;
+  LoadPng(GetRunfilesRelativePath(baseline_file), &baseline_rgba);
+  FloatRGBToUint8RGBA(result, &result_rgba);
+
+  const bool images_match = baseline_rgba == result_rgba;
+
+  if (!images_match) {
+    const std::string result_output_path =
+        "/tmp/" + comparison_name + "_result.png";
+    SavePng(result_output_path, result_rgba);
+  }
+
+  EXPECT_TRUE(images_match) << failure_message;
+}
+
+class RasterizeTrianglesImplTest : public ::testing::Test {
+ protected:
+  void CallRasterizeTrianglesImpl(const float* vertices, const int32* triangles,
+                                  int32 triangle_count) {
+    const int num_pixels = image_height_ * image_width_;
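+    // RasterizeTrianglesImpl writes into caller-owned buffers, so size them
+    // here: one id and one depth value per pixel, and three barycentric
+    // weights per pixel.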
+    barycentrics_buffer_.resize(num_pixels * 3);
+    triangle_ids_buffer_.resize(num_pixels);
+
+    constexpr float kClearDepth = 1.0;
+    z_buffer_.resize(num_pixels, kClearDepth);
+
+    RasterizeTrianglesImpl(vertices, triangles, triangle_count, image_width_,
+                           image_height_, triangle_ids_buffer_.data(),
+                           barycentrics_buffer_.data(), z_buffer_.data());
+  }
+
+  // Expects that the sum of barycentric weights at a pixel is close to a
+  // given value.
+  void ExpectBarycentricSumIsNear(int x, int y, float expected) const {
+    constexpr float kEpsilon = 1e-6f;
+    auto it = barycentrics_buffer_.begin() + y * image_width_ * 3 + x * 3;
+    EXPECT_NEAR(*it + *(it + 1) + *(it + 2), expected, kEpsilon);
+  }
+  // Expects that a pixel is covered by verifying that its barycentric
+  // coordinates sum to one.
+  void ExpectIsCovered(int x, int y) const {
+    ExpectBarycentricSumIsNear(x, y, 1.0);
+  }
+  // Expects that a pixel is not covered by verifying that its barycentric
+  // coordinates sum to zero.
+  void ExpectIsNotCovered(int x, int y) const {
+    ExpectBarycentricSumIsNear(x, y, 0.0);
+  }
+
+  int image_height_ = 480;
+  int image_width_ = 640;
+  std::vector<float> barycentrics_buffer_;
+  std::vector<int32> triangle_ids_buffer_;
+  std::vector<float> z_buffer_;
+};
+
+TEST_F(RasterizeTrianglesImplTest, CanRasterizeTriangle) {
+  const std::vector<float> vertices = {-0.5, -0.5, 0.8, 1.0, 0.0,  0.5,
+                                       0.3,  1.0,  0.5, -0.5, 0.3, 1.0};
+  const std::vector<int32> triangles = {0, 1, 2};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 1);
+  ExpectImageFileAndImageAreEqual("Simple_Triangle.png", barycentrics_buffer_,
+                                  "triangle", "simple triangle does not match");
+}
+
+TEST_F(RasterizeTrianglesImplTest, CanRasterizeExternalTriangle) {
+  const std::vector<float> vertices = {-0.5, -0.5,  0.0, 1.0,  0.0, -0.5,
+                                       0.0,  -1.0,  0.5, -0.5, 0.0, 1.0};
+  const std::vector<int32> triangles = {0, 1, 2};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 1);
+
+  ExpectImageFileAndImageAreEqual("External_Triangle.png",
+                                  barycentrics_buffer_, "external triangle",
+                                  "external triangle does not match");
+}
+
+TEST_F(RasterizeTrianglesImplTest, CanRasterizeCameraInsideBox) {
+  const std::vector<float> vertices = {
+      -1.0, -1.0, 0.0, 2.0,  1.0,  -1.0, 0.0, 2.0,  1.0,  1.0, 0.0,
+      2.0,  -1.0, 1.0, 0.0,  2.0,  -1.0, -1.0, 0.0, -2.0, 1.0, -1.0,
+      0.0,  -2.0, 1.0, 1.0,  0.0,  -2.0, -1.0, 1.0, 0.0,  -2.0};
+  const std::vector<int32> triangles = {0, 1, 2, 0, 2, 3, 4, 5, 6, 4, 6, 7,
+                                        2, 3, 7, 2, 7, 6, 1, 0, 4, 1, 4, 5,
+                                        0, 3, 7, 0, 7, 4, 1, 2, 6, 1, 6, 5};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 12);
+
+  ExpectImageFileAndImageAreEqual("Inside_Box.png",
+                                  barycentrics_buffer_, "camera inside box",
+                                  "camera inside box does not match");
+}
+
+TEST_F(RasterizeTrianglesImplTest, CanRasterizeTetrahedron) {
+  const std::vector<float> vertices = {-0.5, -0.5, 0.8, 1.0,  0.0, 0.5,
+                                       0.3,  1.0,  0.5, -0.5, 0.3, 1.0,
+                                       0.0,  0.0,  0.0, 1.0};
+  const std::vector<int32> triangles = {0, 2, 1, 0, 1, 3, 1, 2, 3, 2, 0, 3};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 4);
+
+  ExpectImageFileAndImageAreEqual("Simple_Tetrahedron.png",
+                                  barycentrics_buffer_, "tetrahedron",
+                                  "simple tetrahedron does not match");
+}
+
+TEST_F(RasterizeTrianglesImplTest, CanRasterizeCube) {
+  // Vertex values were obtained by dumping the clip-space vertex values from
+  // the renderSimpleCube test in ../rasterize_triangles_test.py.
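+  // Each group of four values below is one XYZW clip-space vertex, so the w
+  // components differ per vertex because the cube is seen in perspective.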
+  const std::vector<float> vertices = {
+      -2.60648608, -3.22707772,  6.85085106, 6.85714293,
+      -1.30324292, -0.992946863, 8.56856918, 8.5714283,
+      -1.30324292, 3.97178817,   7.70971,    7.71428585,
+      -2.60648608, 1.73765731,   5.991992,   6,
+      1.30324292,  -3.97178817,  6.27827835, 6.28571415,
+      2.60648608,  -1.73765731,  7.99599648, 8,
+      2.60648608,  3.22707772,   7.13713741, 7.14285707,
+      1.30324292,  0.992946863,  5.41941929, 5.4285717};
+
+  const std::vector<int32> triangles = {0, 1, 2, 2, 3, 0, 3, 2, 6, 6, 7, 3,
+                                        7, 6, 5, 5, 4, 7, 4, 5, 1, 1, 0, 4,
+                                        5, 6, 2, 2, 1, 5, 7, 4, 0, 0, 3, 7};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 12);
+
+  ExpectImageFileAndImageAreEqual("Barycentrics_Cube.png",
+                                  barycentrics_buffer_, "cube",
+                                  "cube does not match");
+}
+
+TEST_F(RasterizeTrianglesImplTest, WorksWhenPixelIsOnTriangleEdge) {
+  // Verifies that a pixel that lies exactly on a triangle edge is considered
+  // inside the triangle.
+  image_width_ = 641;
+  const int x_pixel = image_width_ / 2;
+  const float x_ndc = 0.0;
+  constexpr int yPixel = 5;
+
+  const std::vector<float> vertices = {x_ndc, -1.0, 0.5, 1.0,  x_ndc, 1.0,
+                                       0.5,   1.0,  0.5, -1.0, 0.5,   1.0};
+  {
+    const std::vector<int32> triangles = {0, 1, 2};
+
+    CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 1);
+
+    ExpectIsCovered(x_pixel, yPixel);
+  }
+  {
+    // Test the triangle with the same vertices in reverse order.
+    const std::vector<int32> triangles = {2, 1, 0};
+
+    CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 1);
+
+    ExpectIsCovered(x_pixel, yPixel);
+  }
+}
+
+TEST_F(RasterizeTrianglesImplTest, CoversEdgePixelsOfImage) {
+  // Verifies that the pixels along image edges are correctly covered.
+
+  const std::vector<float> vertices = {-1.0, -1.0, 0.0, 1.0, 1.0, -1.0,
+                                       0.0,  1.0,  1.0, 1.0, 0.0, 1.0,
+                                       -1.0, 1.0,  0.0, 1.0};
+  const std::vector<int32> triangles = {0, 1, 2, 0, 2, 3};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 2);
+
+  ExpectIsCovered(0, 0);
+  ExpectIsCovered(image_width_ - 1, 0);
+  ExpectIsCovered(image_width_ - 1, image_height_ - 1);
+  ExpectIsCovered(0, image_height_ - 1);
+}
+
+TEST_F(RasterizeTrianglesImplTest, PixelOnDegenerateTriangleIsNotInside) {
+  // Verifies that a pixel lying exactly on a triangle with zero area is
+  // counted as lying outside the triangle.
+  image_width_ = 1;
+  image_height_ = 1;
+  const std::vector<float> vertices = {-1.0, -1.0, 0.0, 1.0, 1.0, 1.0,
+                                       0.0,  1.0,  0.0, 0.0, 0.0, 1.0};
+  const std::vector<int32> triangles = {0, 1, 2};
+
+  CallRasterizeTrianglesImpl(vertices.data(), triangles.data(), 1);
+
+  ExpectIsNotCovered(0, 0);
+}
+
+}  // namespace
+}  // namespace tf_mesh_renderer
diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so
new file mode 100755
index 0000000..673a21b
Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so differ
diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_op.cc b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_op.cc
new file mode 100644
index 0000000..dbf3ab4
--- /dev/null
+++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_op.cc
@@ -0,0 +1,149 @@
+// Copyright 2017 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <vector>
+
+#include "rasterize_triangles_impl.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tf_mesh_renderer {
+
+using ::tensorflow::DEVICE_CPU;
+using ::tensorflow::int32;
+using ::tensorflow::OpKernel;
+using ::tensorflow::OpKernelConstruction;
+using ::tensorflow::OpKernelContext;
+using ::tensorflow::PartialTensorShape;
+using ::tensorflow::Status;
+using ::tensorflow::Tensor;
+using ::tensorflow::TensorShape;
+using ::tensorflow::TensorShapeUtils;
+using ::tensorflow::errors::Internal;
+using ::tensorflow::errors::InvalidArgument;
+
+REGISTER_OP("RasterizeTriangles")
+    .Input("vertices: float32")
+    .Input("triangles: int32")
+    .Attr("image_width: int")
+    .Attr("image_height: int")
+    .Output("barycentric_coordinates: float32")
+    .Output("triangle_ids: int32")
+    .Output("z_buffer: float32")
+    .Doc(R"doc(
+Implements a rasterization kernel for rendering mesh geometry.
+
+vertices: 2-D tensor with shape [vertex_count, 4]. The 3-D positions of the mesh
+  vertices in clip-space (XYZW).
+triangles: 2-D tensor with shape [triangle_count, 3]. Each row is a tuple of
+  indices into vertices specifying a triangle to be drawn. The triangle has an
+  outward facing normal when the given indices appear in a clockwise winding to
+  the viewer.
+image_width: positive int attribute specifying the width of the output image.
+image_height: positive int attribute specifying the height of the output image.
+barycentric_coordinates: 3-D tensor with shape [image_height, image_width, 3]
+  containing the rendered barycentric coordinate triplet per pixel, before
+  perspective correction. The triplet is the zero vector if the pixel is outside
+  the mesh boundary. For valid pixels, the ordering of the coordinates
+  corresponds to the ordering in triangles.
+triangle_ids: 2-D tensor with shape [image_height, image_width]. Contains the
+  triangle id value for each pixel in the output image. For pixels within the
+  mesh, this is the integer value in the range [0, triangle_count) from
+  triangles. For pixels outside the mesh this is 0; 0 can either indicate
+  belonging to triangle 0, or being outside the mesh. This ensures all returned
+  triangle ids will validly index into the vertex array, enabling the use of
+  tf.gather with indices from this tensor. The barycentric coordinates can be
+  used to determine pixel validity instead.
+z_buffer: 2-D tensor with shape [image_height, image_width]. Contains the Z
+  coordinate in Normalized Device Coordinates for each pixel occupied by a
+  triangle.
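+  Pixels not covered by any triangle keep the clear value of 1.0, the
+  farthest representable NDC depth.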
+)doc");
+
+class RasterizeTrianglesOp : public OpKernel {
+ public:
+  explicit RasterizeTrianglesOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("image_width", &image_width_));
+    OP_REQUIRES(context, image_width_ > 0,
+                InvalidArgument("Image width must be > 0, got ", image_width_));
+
+    OP_REQUIRES_OK(context, context->GetAttr("image_height", &image_height_));
+    OP_REQUIRES(
+        context, image_height_ > 0,
+        InvalidArgument("Image height must be > 0, got ", image_height_));
+  }
+
+  ~RasterizeTrianglesOp() override {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& vertices_tensor = context->input(0);
+    OP_REQUIRES(
+        context,
+        PartialTensorShape({-1, 4}).IsCompatibleWith(vertices_tensor.shape()),
+        InvalidArgument(
+            "RasterizeTriangles expects vertices to have shape (-1, 4)."));
+    auto vertices_flat = vertices_tensor.flat<float>();
+    const float* vertices = vertices_flat.data();
+
+    const Tensor& triangles_tensor = context->input(1);
+    OP_REQUIRES(
+        context,
+        PartialTensorShape({-1, 3}).IsCompatibleWith(triangles_tensor.shape()),
+        InvalidArgument(
+            "RasterizeTriangles expects triangles to be a matrix."));
+    auto triangles_flat = triangles_tensor.flat<int32>();
+    const int32* triangles = triangles_flat.data();
+    const int triangle_count = triangles_flat.size() / 3;
+
+    Tensor* barycentric_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({image_height_, image_width_, 3}),
+                       &barycentric_tensor));
+
+    Tensor* triangle_ids_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({image_height_, image_width_}),
+                                &triangle_ids_tensor));
+
+    Tensor* z_buffer_tensor = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                2, TensorShape({image_height_, image_width_}),
+                                &z_buffer_tensor));
+
+    // Clear barycentric and triangle id buffers to 0.
+    // Clear z-buffer to 1 (the farthest NDC z value).
+    barycentric_tensor->flat<float>().setZero();
+    triangle_ids_tensor->flat<int32>().setZero();
+    z_buffer_tensor->flat<float>().setConstant(1);
+
+    RasterizeTrianglesImpl(vertices, triangles, triangle_count, image_width_,
+                           image_height_,
+                           triangle_ids_tensor->flat<int32>().data(),
+                           barycentric_tensor->flat<float>().data(),
+                           z_buffer_tensor->flat<float>().data());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(RasterizeTrianglesOp);
+
+  int image_width_;
+  int image_height_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RasterizeTriangles").Device(DEVICE_CPU),
+                        RasterizeTrianglesOp);
+
+}  // namespace tf_mesh_renderer
diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer.py
new file mode 100644
index 0000000..618ea76
--- /dev/null
+++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer.py
@@ -0,0 +1,462 @@
+# Copyright 2017 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Differentiable 3-D rendering of a triangle mesh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from . import camera_utils +from . import rasterize_triangles + + +def phong_shader(normals, + alphas, + pixel_positions, + light_positions, + light_intensities, + diffuse_colors=None, + camera_position=None, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None): + """Computes pixelwise lighting from rasterized buffers with the Phong model. + + Args: + normals: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the world space XYZ normal for + the corresponding pixel. Should be already normalized. + alphas: a 3D float32 tensor with shape [batch_size, image_height, + image_width]. The inner dimension is the alpha value (transparency) + for the corresponding pixel. + pixel_positions: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the world space XYZ position for + the corresponding pixel. + light_positions: a 3D tensor with shape [batch_size, light_count, 3]. The + XYZ position of each light in the scene. In the same coordinate space as + pixel_positions. + light_intensities: a 3D tensor with shape [batch_size, light_count, 3]. The + RGB intensity values for each light. Intensities may be above one. + diffuse_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the diffuse RGB coefficients at + a pixel in the range [0, 1]. + camera_position: a 1D tensor with shape [batch_size, 3]. The XYZ camera + position in the scene. If supplied, specular reflections will be + computed. If not supplied, specular_colors and shininess_coefficients + are expected to be None. In the same coordinate space as + pixel_positions. + specular_colors: a 4D float32 tensor with shape [batch_size, image_height, + image_width, 3]. The inner dimension is the specular RGB coefficients at + a pixel in the range [0, 1]. If None, assumed to be tf.zeros() + shininess_coefficients: A 3D float32 tensor that is broadcasted to shape + [batch_size, image_height, image_width]. The inner dimension is the + shininess coefficient for the object at a pixel. Dimensions that are + constant can be given length 1, so [batch_size, 1, 1] and [1, 1, 1] are + also valid input shapes. + ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient + color, which is added to each pixel before tone mapping. If None, it is + assumed to be tf.zeros(). + Returns: + A 4D float32 tensor of shape [batch_size, image_height, image_width, 4] + containing the lit RGBA color values for each image at each pixel. Colors + are in the range [0,1]. + + Raises: + ValueError: An invalid argument to the method is detected. 
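+
+  Shape example (illustrative): with batch_size 1, a 4x4 image and a single
+  light, normals and pixel_positions have shape [1, 4, 4, 3], alphas has shape
+  [1, 4, 4], light_positions and light_intensities have shape [1, 1, 3], and
+  the returned tensor has shape [1, 4, 4, 4].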
+ """ + batch_size, image_height, image_width = [s.value for s in normals.shape[:-1]] + light_count = light_positions.shape[1].value + pixel_count = image_height * image_width + # Reshape all values to easily do pixelwise computations: + normals = tf.reshape(normals, [batch_size, -1, 3]) + alphas = tf.reshape(alphas, [batch_size, -1, 1]) + diffuse_colors = tf.reshape(diffuse_colors, [batch_size, -1, 3]) + if camera_position is not None: + specular_colors = tf.reshape(specular_colors, [batch_size, -1, 3]) + + # Ambient component + output_colors = tf.zeros([batch_size, image_height * image_width, 3]) + if ambient_color is not None: + ambient_reshaped = tf.expand_dims(ambient_color, axis=1) + output_colors = tf.add(output_colors, ambient_reshaped * diffuse_colors) + + # Diffuse component + pixel_positions = tf.reshape(pixel_positions, [batch_size, -1, 3]) + per_light_pixel_positions = tf.stack( + [pixel_positions] * light_count, + axis=1) # [batch_size, light_count, pixel_count, 3] + directions_to_lights = tf.nn.l2_normalize( + tf.expand_dims(light_positions, axis=2) - per_light_pixel_positions, + axis=3) # [batch_size, light_count, pixel_count, 3] + # The specular component should only contribute when the light and normal + # face one another (i.e. the dot product is nonnegative): + normals_dot_lights = tf.clip_by_value( + tf.reduce_sum( + tf.expand_dims(normals, axis=1) * directions_to_lights, axis=3), 0.0, + 1.0) # [batch_size, light_count, pixel_count] + diffuse_output = tf.expand_dims( + diffuse_colors, axis=1) * tf.expand_dims( + normals_dot_lights, axis=3) * tf.expand_dims( + light_intensities, axis=2) + diffuse_output = tf.reduce_sum( + diffuse_output, axis=1) # [batch_size, pixel_count, 3] + output_colors = tf.add(output_colors, diffuse_output) + + # Specular component + if camera_position is not None: + camera_position = tf.reshape(camera_position, [batch_size, 1, 3]) + mirror_reflection_direction = tf.nn.l2_normalize( + 2.0 * tf.expand_dims(normals_dot_lights, axis=3) * tf.expand_dims( + normals, axis=1) - directions_to_lights, + dim=3) + direction_to_camera = tf.nn.l2_normalize( + camera_position - pixel_positions, dim=2) + reflection_direction_dot_camera_direction = tf.reduce_sum( + tf.expand_dims(direction_to_camera, axis=1) * + mirror_reflection_direction, + axis=3) + # The specular component should only contribute when the reflection is + # external: + reflection_direction_dot_camera_direction = tf.clip_by_value( + tf.nn.l2_normalize(reflection_direction_dot_camera_direction, dim=2), + 0.0, 1.0) + # The specular component should also only contribute when the diffuse + # component contributes: + reflection_direction_dot_camera_direction = tf.where( + normals_dot_lights != 0.0, reflection_direction_dot_camera_direction, + tf.zeros_like( + reflection_direction_dot_camera_direction, dtype=tf.float32)) + # Reshape to support broadcasting the shininess coefficient, which rarely + # varies per-vertex: + reflection_direction_dot_camera_direction = tf.reshape( + reflection_direction_dot_camera_direction, + [batch_size, light_count, image_height, image_width]) + shininess_coefficients = tf.expand_dims(shininess_coefficients, axis=1) + specularity = tf.reshape( + tf.pow(reflection_direction_dot_camera_direction, + shininess_coefficients), + [batch_size, light_count, pixel_count, 1]) + specular_output = tf.expand_dims( + specular_colors, axis=1) * specularity * tf.expand_dims( + light_intensities, axis=2) + specular_output = tf.reduce_sum(specular_output, axis=1) + output_colors = 
tf.add(output_colors, specular_output) + rgb_images = tf.reshape(output_colors, + [batch_size, image_height, image_width, 3]) + alpha_images = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + valid_rgb_values = tf.concat(3 * [alpha_images > 0.5], axis=3) + rgb_images = tf.where(valid_rgb_values, rgb_images, + tf.zeros_like(rgb_images, dtype=tf.float32)) + return tf.reverse(tf.concat([rgb_images, alpha_images], axis=3), axis=[1]) + + +def tone_mapper(image, gamma): + """Applies gamma correction to the input image. + + Tone maps the input image batch in order to make scenes with a high dynamic + range viewable. The gamma correction factor is computed separately per image, + but is shared between all provided channels. The exact function computed is: + + image_out = A*image_in^gamma, where A is an image-wide constant computed so + that the maximum image value is approximately 1. The correction is applied + to all channels. + + Args: + image: 4-D float32 tensor with shape [batch_size, image_height, + image_width, channel_count]. The batch of images to tone map. + gamma: 0-D float32 nonnegative tensor. Values of gamma below one compress + relative contrast in the image, and values above one increase it. A + value of 1 is equivalent to scaling the image to have a maximum value + of 1. + Returns: + 4-D float32 tensor with shape [batch_size, image_height, image_width, + channel_count]. Contains the gamma-corrected images, clipped to the range + [0, 1]. + """ + batch_size = image.shape[0].value + corrected_image = tf.pow(image, gamma) + image_max = tf.reduce_max( + tf.reshape(corrected_image, [batch_size, -1]), axis=1) + scaled_image = tf.divide(corrected_image, + tf.reshape(image_max, [batch_size, 1, 1, 1])) + return tf.clip_by_value(scaled_image, 0.0, 1.0) + + +def mesh_renderer(vertices, + triangles, + normals, + diffuse_colors, + camera_position, + camera_lookat, + camera_up, + light_positions, + light_intensities, + image_width, + image_height, + specular_colors=None, + shininess_coefficients=None, + ambient_color=None, + fov_y=40.0, + near_clip=0.01, + far_clip=10.0): + """Renders an input scene using phong shading, and returns an output image. + + Args: + vertices: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is an xyz position in world space. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + normals: 3-D float32 tensor with shape [batch_size, vertex_count, 3]. Each + triplet is the xyz vertex normal for its corresponding vertex. Each + vector is assumed to be already normalized. + diffuse_colors: 3-D float32 tensor with shape [batch_size, + vertex_count, 3]. The RGB diffuse reflection in the range [0,1] for + each vertex. + camera_position: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] specifying the XYZ world space camera position. + camera_lookat: 2-D tensor with shape [batch_size, 3] or 1-D tensor with + shape [3] containing an XYZ point along the center of the camera's gaze. + camera_up: 2-D tensor with shape [batch_size, 3] or 1-D tensor with shape + [3] containing the up direction for the camera. The camera will have no + tilt with respect to this direction. + light_positions: a 3-D tensor with shape [batch_size, light_count, 3]. 
The
+      XYZ position of each light in the scene. In the same coordinate space as
+      pixel_positions.
+    light_intensities: a 3-D tensor with shape [batch_size, light_count, 3]. The
+      RGB intensity values for each light. Intensities may be above one.
+    image_width: int specifying desired output image width in pixels.
+    image_height: int specifying desired output image height in pixels.
+    specular_colors: 3-D float32 tensor with shape [batch_size,
+      vertex_count, 3]. The RGB specular reflection in the range [0, 1] for
+      each vertex. If supplied, specular reflections will be computed, and
+      both specular_colors and shininess_coefficients are expected.
+    shininess_coefficients: a 0D-2D float32 tensor with maximum shape
+      [batch_size, vertex_count]. The phong shininess coefficient of each
+      vertex. A 0D tensor or float gives a constant shininess coefficient
+      across all batches and images. A 1D tensor must have shape [batch_size],
+      and a single shininess coefficient per image is used.
+    ambient_color: a 2D tensor with shape [batch_size, 3]. The RGB ambient
+      color, which is added to each pixel in the scene. If None, it is
+      assumed to be black.
+    fov_y: float, 0D tensor, or 1D tensor with shape [batch_size] specifying
+      desired output image y field of view in degrees.
+    near_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying
+      near clipping plane distance.
+    far_clip: float, 0D tensor, or 1D tensor with shape [batch_size] specifying
+      far clipping plane distance.
+
+  Returns:
+    A 4-D float32 tensor of shape [batch_size, image_height, image_width, 4]
+    containing the lit RGBA color values for each image at each pixel. RGB
+    colors are the intensity values before tonemapping and can be in the range
+    [0, infinity]. Clipping to the range [0,1] with tf.clip_by_value is likely
+    reasonable for both viewing and training most scenes. More complex scenes
+    with multiple lights should tone map color values for display only. One
+    simple tonemapping approach is to rescale color values as x/(1+x); gamma
+    compression is another common technique. Alpha values are zero for
+    background pixels and near one for mesh pixels.
+  Raises:
+    ValueError: An invalid argument to the method is detected.
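+
+  Minimal usage sketch (illustrative; the tensor names are placeholders, and
+  the shapes follow the simple-cube unit test in mesh_renderer_test.py):
+
+    renders = mesh_renderer(
+        cube_vertices,        # [2, 8, 3] world-space positions
+        cube_triangles,       # [12, 3] int32 vertex indices
+        cube_normals,         # [2, 8, 3] unit vertex normals
+        tf.ones([2, 8, 3]),   # white diffuse color per vertex
+        eye, center, world_up,            # each [2, 3]
+        light_positions,                  # [2, 1, 3]
+        tf.ones([2, 1, 3]),               # light intensities
+        640, 480)
+    # renders has shape [2, 480, 640, 4] (RGBA).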
+ """ + if len(vertices.shape) != 3: + raise ValueError('Vertices must have shape [batch_size, vertex_count, 3].') + batch_size = vertices.shape[0].value + if len(normals.shape) != 3: + raise ValueError('Normals must have shape [batch_size, vertex_count, 3].') + if len(light_positions.shape) != 3: + raise ValueError( + 'Light_positions must have shape [batch_size, light_count, 3].') + if len(light_intensities.shape) != 3: + raise ValueError( + 'Light_intensities must have shape [batch_size, light_count, 3].') + if len(diffuse_colors.shape) != 3: + raise ValueError( + 'vertex_diffuse_colors must have shape [batch_size, vertex_count, 3].') + if (ambient_color is not None and + ambient_color.get_shape().as_list() != [batch_size, 3]): + raise ValueError('Ambient_color must have shape [batch_size, 3].') + if camera_position.get_shape().as_list() == [3]: + camera_position = tf.tile( + tf.expand_dims(camera_position, axis=0), [batch_size, 1]) + elif camera_position.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_position must have shape [batch_size, 3]') + if camera_lookat.get_shape().as_list() == [3]: + camera_lookat = tf.tile( + tf.expand_dims(camera_lookat, axis=0), [batch_size, 1]) + elif camera_lookat.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_lookat must have shape [batch_size, 3]') + if camera_up.get_shape().as_list() == [3]: + camera_up = tf.tile(tf.expand_dims(camera_up, axis=0), [batch_size, 1]) + elif camera_up.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_up must have shape [batch_size, 3]') + if isinstance(fov_y, float): + fov_y = tf.constant(batch_size * [fov_y], dtype=tf.float32) + elif not fov_y.get_shape().as_list(): + fov_y = tf.tile(tf.expand_dims(fov_y, 0), [batch_size]) + elif fov_y.get_shape().as_list() != [batch_size]: + raise ValueError('Fov_y must be a float, a 0D tensor, or a 1D tensor with' + 'shape [batch_size]') + if isinstance(near_clip, float): + near_clip = tf.constant(batch_size * [near_clip], dtype=tf.float32) + elif not near_clip.get_shape().as_list(): + near_clip = tf.tile(tf.expand_dims(near_clip, 0), [batch_size]) + elif near_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Near_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if isinstance(far_clip, float): + far_clip = tf.constant(batch_size * [far_clip], dtype=tf.float32) + elif not far_clip.get_shape().as_list(): + far_clip = tf.tile(tf.expand_dims(far_clip, 0), [batch_size]) + elif far_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Far_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if specular_colors is not None and shininess_coefficients is None: + raise ValueError( + 'Specular colors were supplied without shininess coefficients.') + if shininess_coefficients is not None and specular_colors is None: + raise ValueError( + 'Shininess coefficients were supplied without specular colors.') + if specular_colors is not None: + # Since a 0-D float32 tensor is accepted, also accept a float. 
+ if isinstance(shininess_coefficients, float): + shininess_coefficients = tf.constant( + shininess_coefficients, dtype=tf.float32) + if len(specular_colors.shape) != 3: + raise ValueError('The specular colors must have shape [batch_size, ' + 'vertex_count, 3].') + if len(shininess_coefficients.shape) > 2: + raise ValueError('The shininess coefficients must have shape at most' + '[batch_size, vertex_count].') + # If we don't have per-vertex coefficients, we can just reshape the + # input shininess to broadcast later, rather than interpolating an + # additional vertex attribute: + if len(shininess_coefficients.shape) < 2: + vertex_attributes = tf.concat( + [normals, vertices, diffuse_colors, specular_colors], axis=2) + else: + vertex_attributes = tf.concat( + [ + normals, vertices, diffuse_colors, specular_colors, + tf.expand_dims(shininess_coefficients, axis=2) + ], + axis=2) + else: + vertex_attributes = tf.concat([normals, vertices, diffuse_colors], axis=2) + + camera_matrices = camera_utils.look_at(camera_position, camera_lookat, + camera_up) + + perspective_transforms = camera_utils.perspective(image_width / image_height, + fov_y, near_clip, far_clip) + + clip_space_transforms = tf.matmul(perspective_transforms, camera_matrices) + + pixel_attributes = rasterize_triangles.rasterize( + vertices, vertex_attributes, triangles, clip_space_transforms, + image_width, image_height, [-1] * vertex_attributes.shape[2].value) + + # Extract the interpolated vertex attributes from the pixel buffer and + # supply them to the shader: + pixel_normals = tf.nn.l2_normalize(pixel_attributes[:, :, :, 0:3], axis=3) + pixel_positions = pixel_attributes[:, :, :, 3:6] + diffuse_colors = pixel_attributes[:, :, :, 6:9] + if specular_colors is not None: + specular_colors = pixel_attributes[:, :, :, 9:12] + # Retrieve the interpolated shininess coefficients if necessary, or just + # reshape our input for broadcasting: + if len(shininess_coefficients.shape) == 2: + shininess_coefficients = pixel_attributes[:, :, :, 12] + else: + shininess_coefficients = tf.reshape(shininess_coefficients, [-1, 1, 1]) + + pixel_mask = tf.cast(tf.reduce_any(diffuse_colors >= 0, axis=3), tf.float32) + + renders = phong_shader( + normals=pixel_normals, + alphas=pixel_mask, + pixel_positions=pixel_positions, + light_positions=light_positions, + light_intensities=light_intensities, + diffuse_colors=diffuse_colors, + camera_position=camera_position if specular_colors is not None else None, + specular_colors=specular_colors, + shininess_coefficients=shininess_coefficients, + ambient_color=ambient_color) + return renders + +def clip_vertices(vertices, + camera_position, + camera_lookat, + camera_up, + image_width, + image_height, + fov_y=40.0, + near_clip=0.01, + far_clip=10.0): + if len(vertices.shape) != 3: + raise ValueError('Vertices must have shape [batch_size, vertex_count, 3].') + batch_size = vertices.shape[0].value + if camera_position.get_shape().as_list() == [3]: + camera_position = tf.tile( + tf.expand_dims(camera_position, axis=0), [batch_size, 1]) + elif camera_position.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_position must have shape [batch_size, 3]') + if camera_lookat.get_shape().as_list() == [3]: + camera_lookat = tf.tile( + tf.expand_dims(camera_lookat, axis=0), [batch_size, 1]) + elif camera_lookat.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_lookat must have shape [batch_size, 3]') + if camera_up.get_shape().as_list() == [3]: + camera_up = 
tf.tile(tf.expand_dims(camera_up, axis=0), [batch_size, 1]) + elif camera_up.get_shape().as_list() != [batch_size, 3]: + raise ValueError('Camera_up must have shape [batch_size, 3]') + if isinstance(fov_y, float): + fov_y = tf.constant(batch_size * [fov_y], dtype=tf.float32) + elif not fov_y.get_shape().as_list(): + fov_y = tf.tile(tf.expand_dims(fov_y, 0), [batch_size]) + elif fov_y.get_shape().as_list() != [batch_size]: + raise ValueError('Fov_y must be a float, a 0D tensor, or a 1D tensor with' + 'shape [batch_size]') + if isinstance(near_clip, float): + near_clip = tf.constant(batch_size * [near_clip], dtype=tf.float32) + elif not near_clip.get_shape().as_list(): + near_clip = tf.tile(tf.expand_dims(near_clip, 0), [batch_size]) + elif near_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Near_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + if isinstance(far_clip, float): + far_clip = tf.constant(batch_size * [far_clip], dtype=tf.float32) + elif not far_clip.get_shape().as_list(): + far_clip = tf.tile(tf.expand_dims(far_clip, 0), [batch_size]) + elif far_clip.get_shape().as_list() != [batch_size]: + raise ValueError('Far_clip must be a float, a 0D tensor, or a 1D tensor' + 'with shape [batch_size]') + + camera_matrices = camera_utils.look_at(camera_position, camera_lookat, + camera_up) + + perspective_transforms = camera_utils.perspective(image_width / image_height, + fov_y, near_clip, far_clip) + + clip_space_transforms = tf.matmul(perspective_transforms, camera_matrices) + + clip_space_vertices = camera_utils.transform_homogeneous( + clip_space_transforms, vertices) + return clip_space_vertices diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer_test.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer_test.py new file mode 100644 index 0000000..930305a --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/mesh_renderer_test.py @@ -0,0 +1,317 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
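+
+# These tests render golden images with the full pipeline (rasterization plus
+# Phong shading) and compare against PNGs in mesh_renderer/test_data/; they
+# assume a TensorFlow 1.x runtime, since they rely on tf.Session,
+# tf.reset_default_graph, and feed_dict-style evaluation.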
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import os + +import numpy as np +import tensorflow as tf + +import camera_utils +import mesh_renderer +import test_utils + + +class RenderTest(tf.test.TestCase): + + def setUp(self): + self.test_data_directory = ( + 'mesh_renderer/test_data/') + + tf.reset_default_graph() + # Set up a basic cube centered at the origin, with vertex normals pointing + # outwards along the line from the origin to the cube vertices: + self.cube_vertices = tf.constant( + [[-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1], [1, -1, 1], + [1, -1, -1], [1, 1, -1], [1, 1, 1]], + dtype=tf.float32) + self.cube_normals = tf.nn.l2_normalize(self.cube_vertices, dim=1) + self.cube_triangles = tf.constant( + [[0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3], [7, 6, 5], [5, 4, 7], + [4, 5, 1], [1, 0, 4], [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7]], + dtype=tf.int32) + + def testRendersSimpleCube(self): + """Renders a simple cube to test the full forward pass. + + Verifies the functionality of both the custom kernel and the python wrapper. + """ + + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant(2 * [[0.0, 0.0, 6.0]], dtype=tf.float32) + center = tf.constant(2 * [[0.0, 0.0, 0.0]], dtype=tf.float32) + world_up = tf.constant(2 * [[0.0, 1.0, 0.0]], dtype=tf.float32) + image_width = 640 + image_height = 480 + light_positions = tf.constant([[[0.0, 0.0, 6.0]], [[0.0, 0.0, 6.0]]]) + light_intensities = tf.ones([2, 1, 3], dtype=tf.float32) + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + + rendered = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + + with self.test_session() as sess: + images = sess.run(rendered, feed_dict={}) + for image_id in range(images.shape[0]): + target_image_name = 'Gray_Cube_%i.png' % image_id + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, images[image_id, :, :, :]) + + def testComplexShading(self): + """Tests specular highlights, colors, and multiple lights per image.""" + # rotate the cube for the test: + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant([[0.0, 0.0, 6.0], [0., 0.2, 18.0]], dtype=tf.float32) + center = tf.constant([[0.0, 0.0, 0.0], [0.1, -0.1, 0.1]], dtype=tf.float32) + world_up = tf.constant( + [[0.0, 1.0, 0.0], [0.1, 1.0, 0.15]], dtype=tf.float32) + fov_y = tf.constant([40., 13.3], dtype=tf.float32) + near_clip = tf.constant(0.1, dtype=tf.float32) + far_clip = tf.constant(25.0, dtype=tf.float32) + image_width = 640 + image_height = 480 + light_positions 
= tf.constant([[[0.0, 0.0, 6.0], [1.0, 2.0, 6.0]], + [[0.0, -2.0, 4.0], [1.0, 3.0, 4.0]]]) + light_intensities = tf.constant( + [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[2.0, 0.0, 1.0], [0.0, 2.0, + 1.0]]], + dtype=tf.float32) + # pyformat: disable + vertex_diffuse_colors = tf.constant(2*[[[1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + [0.5, 0.5, 0.5]]], + dtype=tf.float32) + vertex_specular_colors = tf.constant(2*[[[0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + [0.5, 0.5, 0.5], + [1.0, 0.0, 0.0]]], + dtype=tf.float32) + # pyformat: enable + shininess_coefficients = 6.0 * tf.ones([2, 8], dtype=tf.float32) + ambient_color = tf.constant( + [[0., 0., 0.], [0.1, 0.1, 0.2]], dtype=tf.float32) + renders = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height, vertex_specular_colors, + shininess_coefficients, ambient_color, fov_y, near_clip, far_clip) + tonemapped_renders = tf.concat( + [ + mesh_renderer.tone_mapper(renders[:, :, :, 0:3], 0.7), + renders[:, :, :, 3:4] + ], + axis=3) + + # Check that shininess coefficient broadcasting works by also rendering + # with a scalar shininess coefficient, and ensuring the result is identical: + broadcasted_renders = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height, vertex_specular_colors, + 6.0, ambient_color, fov_y, near_clip, far_clip) + tonemapped_broadcasted_renders = tf.concat( + [ + mesh_renderer.tone_mapper(broadcasted_renders[:, :, :, 0:3], 0.7), + broadcasted_renders[:, :, :, 3:4] + ], + axis=3) + + with self.test_session() as sess: + images, broadcasted_images = sess.run( + [tonemapped_renders, tonemapped_broadcasted_renders], feed_dict={}) + + for image_id in range(images.shape[0]): + target_image_name = 'Colored_Cube_%i.png' % image_id + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, images[image_id, :, :, :]) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, + broadcasted_images[image_id, :, :, :]) + + def testFullRenderGradientComputation(self): + """Verifies the Jacobian matrix for the entire renderer. + + This ensures correct gradients are propagated backwards through the entire + process, not just through the rasterization kernel. Uses the simple cube + forward pass. + """ + image_height = 21 + image_width = 28 + + # rotate the cube for the test: + model_transforms = camera_utils.euler_matrices( + [[-20.0, 0.0, 60.0], [45.0, 60.0, 0.0]])[:, :3, :3] + + vertices_world_space = tf.matmul( + tf.stack([self.cube_vertices, self.cube_vertices]), + model_transforms, + transpose_b=True) + + normals_world_space = tf.matmul( + tf.stack([self.cube_normals, self.cube_normals]), + model_transforms, + transpose_b=True) + + # camera position: + eye = tf.constant([0.0, 0.0, 6.0], dtype=tf.float32) + center = tf.constant([0.0, 0.0, 0.0], dtype=tf.float32) + world_up = tf.constant([0.0, 1.0, 0.0], dtype=tf.float32) + + # Scene has a single light from the viewer's eye. 
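+    # Stacking eye twice and expanding dims yields light_positions of shape
+    # [2, 1, 3]: one light per image in the batch of two.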
+ light_positions = tf.expand_dims(tf.stack([eye, eye], axis=0), axis=1) + light_intensities = tf.ones([2, 1, 3], dtype=tf.float32) + + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + + rendered = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + + with self.test_session(): + theoretical, numerical = tf.test.compute_gradient( + self.cube_vertices, (8, 3), + rendered, (2, image_height, image_width, 4), + x_init_value=self.cube_vertices.eval(), + delta=1e-3) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.01)) + self.assertTrue(jacobians_match, message) + + def testThatCubeRotates(self): + """Optimize a simple cube's rotation using pixel loss. + + The rotation is represented as static-basis euler angles. This test checks + that the computed gradients are useful. + """ + image_height = 480 + image_width = 640 + initial_euler_angles = [[0.0, 0.0, 0.0]] + + euler_angles = tf.Variable(initial_euler_angles) + model_rotation = camera_utils.euler_matrices(euler_angles)[0, :3, :3] + + vertices_world_space = tf.reshape( + tf.matmul(self.cube_vertices, model_rotation, transpose_b=True), + [1, 8, 3]) + + normals_world_space = tf.reshape( + tf.matmul(self.cube_normals, model_rotation, transpose_b=True), + [1, 8, 3]) + + # camera position: + eye = tf.constant([[0.0, 0.0, 6.0]], dtype=tf.float32) + center = tf.constant([[0.0, 0.0, 0.0]], dtype=tf.float32) + world_up = tf.constant([[0.0, 1.0, 0.0]], dtype=tf.float32) + + vertex_diffuse_colors = tf.ones_like(vertices_world_space, dtype=tf.float32) + light_positions = tf.reshape(eye, [1, 1, 3]) + light_intensities = tf.ones([1, 1, 3], dtype=tf.float32) + + render = mesh_renderer.mesh_renderer( + vertices_world_space, self.cube_triangles, normals_world_space, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + render = tf.reshape(render, [image_height, image_width, 4]) + + # Pick the desired cube rotation for the test: + test_model_rotation = camera_utils.euler_matrices([[-20.0, 0.0, + 60.0]])[0, :3, :3] + + desired_vertex_positions = tf.reshape( + tf.matmul(self.cube_vertices, test_model_rotation, transpose_b=True), + [1, 8, 3]) + desired_normals = tf.reshape( + tf.matmul(self.cube_normals, test_model_rotation, transpose_b=True), + [1, 8, 3]) + desired_render = mesh_renderer.mesh_renderer( + desired_vertex_positions, self.cube_triangles, desired_normals, + vertex_diffuse_colors, eye, center, world_up, light_positions, + light_intensities, image_width, image_height) + desired_render = tf.reshape(desired_render, [image_height, image_width, 4]) + + loss = tf.reduce_mean(tf.abs(render - desired_render)) + optimizer = tf.train.MomentumOptimizer(0.7, 0.1) + grad = tf.gradients(loss, [euler_angles]) + grad, _ = tf.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients([(grad[0], euler_angles)]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(35): + sess.run([loss, opt_func]) + final_image, desired_image = sess.run([render, desired_render]) + + target_image_name = 'Gray_Cube_0.png' + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, desired_image) + 
test_utils.expect_image_file_and_render_are_near( + self, + sess, + baseline_image_path, + final_image, + max_outlier_fraction=0.01, + pixel_error_threshold=0.04) + + +if __name__ == '__main__': + tf.test.main() diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles.py new file mode 100644 index 0000000..ac8d106 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles.py @@ -0,0 +1,178 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Differentiable triangle rasterizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tensorflow as tf + +from . import camera_utils + +rasterize_triangles_module = tf.load_op_library( + #os.path.join(os.environ['TEST_SRCDIR'], + os.path.join('/home4/yiran/TalkingFace/Pipeline/Deep3DFaceReconstruction', + 'tf_mesh_renderer/mesh_renderer/kernels/rasterize_triangles_kernel.so')) + + +def rasterize(world_space_vertices, attributes, triangles, camera_matrices, + image_width, image_height, background_value): + """Rasterizes a mesh and computes interpolated vertex attributes. + + Applies projection matrices and then calls rasterize_clip_space(). + + Args: + world_space_vertices: 3-D float32 tensor of xyz positions with shape + [batch_size, vertex_count, 3]. + attributes: 3-D float32 tensor with shape [batch_size, vertex_count, + attribute_count]. Each vertex attribute is interpolated across the + triangle using barycentric interpolation. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + camera_matrices: 3-D float tensor with shape [batch_size, 4, 4] containing + model-view-perspective projection matrices. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + background_value: a 1-D float32 tensor with shape [attribute_count]. Pixels + that lie outside all triangles take this value. + + Returns: + A 4-D float32 tensor with shape [batch_size, image_height, image_width, + attribute_count], containing the interpolated vertex attributes at + each pixel. + + Raises: + ValueError: An invalid argument to the method is detected. 
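+
+  Shape example (matching the two-cube unit test in
+  rasterize_triangles_test.py): world_space_vertices of shape [2, 8, 3], RGBA
+  vertex attributes of shape [2, 8, 4], triangles of shape [12, 3], and
+  camera_matrices of shape [2, 4, 4] produce a
+  [2, image_height, image_width, 4] attribute image.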
+ """ + clip_space_vertices = camera_utils.transform_homogeneous( + camera_matrices, world_space_vertices) + return rasterize_clip_space(clip_space_vertices, attributes, triangles, + image_width, image_height, background_value) + + +def rasterize_clip_space(clip_space_vertices, attributes, triangles, + image_width, image_height, background_value): + """Rasterizes the input mesh expressed in clip-space (xyzw) coordinates. + + Interpolates vertex attributes using perspective-correct interpolation and + clips triangles that lie outside the viewing frustum. + + Args: + clip_space_vertices: 3-D float32 tensor of homogenous vertices (xyzw) with + shape [batch_size, vertex_count, 4]. + attributes: 3-D float32 tensor with shape [batch_size, vertex_count, + attribute_count]. Each vertex attribute is interpolated across the + triangle using barycentric interpolation. + triangles: 2-D int32 tensor with shape [triangle_count, 3]. Each triplet + should contain vertex indices describing a triangle such that the + triangle's normal points toward the viewer if the forward order of the + triplet defines a clockwise winding of the vertices. Gradients with + respect to this tensor are not available. + image_width: int specifying desired output image width in pixels. + image_height: int specifying desired output image height in pixels. + background_value: a 1-D float32 tensor with shape [attribute_count]. Pixels + that lie outside all triangles take this value. + + Returns: + A 4-D float32 tensor with shape [batch_size, image_height, image_width, + attribute_count], containing the interpolated vertex attributes at + each pixel. + + Raises: + ValueError: An invalid argument to the method is detected. + """ + if not image_width > 0: + raise ValueError('Image width must be > 0.') + if not image_height > 0: + raise ValueError('Image height must be > 0.') + if len(clip_space_vertices.shape) != 3: + raise ValueError('The vertex buffer must be 3D.') + + vertex_count = clip_space_vertices.shape[1].value + + batch_size = tf.shape(clip_space_vertices)[0] + + per_image_barycentric_coordinates = tf.TensorArray(dtype=tf.float32, + size=batch_size) + per_image_vertex_ids = tf.TensorArray(dtype=tf.int32, size=batch_size) + + def batch_loop_condition(b, *args): + return b < batch_size + + def batch_loop_iteration(b, per_image_barycentric_coordinates, + per_image_vertex_ids): + barycentric_coords, triangle_ids, _ = ( + rasterize_triangles_module.rasterize_triangles( + clip_space_vertices[b, :, :], triangles, image_width, + image_height)) + per_image_barycentric_coordinates = \ + per_image_barycentric_coordinates.write( + b, tf.reshape(barycentric_coords, [-1, 3])) + + vertex_ids = tf.gather(triangles, tf.reshape(triangle_ids, [-1])) + reindexed_ids = tf.add(vertex_ids, b * clip_space_vertices.shape[1].value) + per_image_vertex_ids = per_image_vertex_ids.write(b, reindexed_ids) + + return b+1, per_image_barycentric_coordinates, per_image_vertex_ids + + b = tf.constant(0) + _, per_image_barycentric_coordinates, per_image_vertex_ids = tf.while_loop( + batch_loop_condition, batch_loop_iteration, + [b, per_image_barycentric_coordinates, per_image_vertex_ids]) + + barycentric_coordinates = tf.reshape( + per_image_barycentric_coordinates.stack(), [-1, 3]) + vertex_ids = tf.reshape(per_image_vertex_ids.stack(), [-1, 3]) + + # Indexes with each pixel's clip-space triangle's extrema (the pixel's + # 'corner points') ids to get the relevant properties for deferred shading. 
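+  # After the reshapes above, vertex_ids has shape
+  # [batch_size * pixel_count, 3], so the gather below fetches the three
+  # corner attribute vectors for every pixel at once.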
+ flattened_vertex_attributes = tf.reshape(attributes, + [batch_size * vertex_count, -1]) + corner_attributes = tf.gather(flattened_vertex_attributes, vertex_ids) + + # Computes the pixel attributes by interpolating the known attributes at the + # corner points of the triangle interpolated with the barycentric coordinates. + weighted_vertex_attributes = tf.multiply( + corner_attributes, tf.expand_dims(barycentric_coordinates, axis=2)) + summed_attributes = tf.reduce_sum(weighted_vertex_attributes, axis=1) + attribute_images = tf.reshape(summed_attributes, + [batch_size, image_height, image_width, -1]) + + # Barycentric coordinates should approximately sum to one where there is + # rendered geometry, but be exactly zero where there is not. + alphas = tf.clip_by_value( + tf.reduce_sum(2.0 * barycentric_coordinates, axis=1), 0.0, 1.0) + alphas = tf.reshape(alphas, [batch_size, image_height, image_width, 1]) + + attributes_with_background = ( + alphas * attribute_images + (1.0 - alphas) * background_value) + + return attributes_with_background + + +@tf.RegisterGradient('RasterizeTriangles') +def _rasterize_triangles_grad(op, df_dbarys, df_dids, df_dz): + # Gradients are only supported for barycentric coordinates. Gradients for the + # z-buffer are not currently implemented. If you need gradients w.r.t. z, + # include z as a vertex attribute when calling rasterize_triangles. + del df_dids, df_dz + return rasterize_triangles_module.rasterize_triangles_grad( + op.inputs[0], op.inputs[1], op.outputs[0], op.outputs[1], df_dbarys, + op.get_attr('image_width'), op.get_attr('image_height')), None diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles_test.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles_test.py new file mode 100644 index 0000000..ccd7e7c --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles_test.py @@ -0,0 +1,196 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
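# A condensed sketch of how the rasterizer above is typically driven (this
# mirrors the tests below; the eye position, field of view and clip planes are
# arbitrary example values):
#
#   look_at = camera_utils.look_at(eye, center, world_up)           # [1, 4, 4]
#   perspective = camera_utils.perspective(width / height,
#                                           fov_y, near, far)       # [1, 4, 4]
#   camera_matrices = tf.matmul(perspective, look_at)
#   images = rasterize_triangles.rasterize(
#       vertices, vertex_rgba, triangles, camera_matrices,
#       width, height, background_value=[0.0, 0.0, 0.0, 0.0])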
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +import tensorflow as tf + +import test_utils +import camera_utils +import rasterize_triangles + + +class RenderTest(tf.test.TestCase): + + def setUp(self): + self.test_data_directory = 'mesh_renderer/test_data/' + + tf.reset_default_graph() + self.cube_vertex_positions = tf.constant( + [[-1, -1, 1], [-1, -1, -1], [-1, 1, -1], [-1, 1, 1], [1, -1, 1], + [1, -1, -1], [1, 1, -1], [1, 1, 1]], + dtype=tf.float32) + self.cube_triangles = tf.constant( + [[0, 1, 2], [2, 3, 0], [3, 2, 6], [6, 7, 3], [7, 6, 5], [5, 4, 7], + [4, 5, 1], [1, 0, 4], [5, 6, 2], [2, 1, 5], [7, 4, 0], [0, 3, 7]], + dtype=tf.int32) + + self.tf_float = lambda x: tf.constant(x, dtype=tf.float32) + + self.image_width = 640 + self.image_height = 480 + + self.perspective = camera_utils.perspective( + self.image_width / self.image_height, + self.tf_float([40.0]), self.tf_float([0.01]), + self.tf_float([10.0])) + + def runTriangleTest(self, w_vector, target_image_name): + """Directly renders a rasterized triangle's barycentric coordinates. + + Tests only the kernel (rasterize_triangles_module). + + Args: + w_vector: 3 element vector of w components to scale triangle vertices. + target_image_name: image file name to compare result against. + """ + clip_init = np.array( + [[-0.5, -0.5, 0.8, 1.0], [0.0, 0.5, 0.3, 1.0], [0.5, -0.5, 0.3, 1.0]], + dtype=np.float32) + clip_init = clip_init * np.reshape( + np.array(w_vector, dtype=np.float32), [3, 1]) + + clip_coordinates = tf.constant(clip_init) + triangles = tf.constant([[0, 1, 2]], dtype=tf.int32) + + rendered_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, triangles, self.image_width, self.image_height)) + rendered_coordinates = tf.concat( + [rendered_coordinates, + tf.ones([self.image_height, self.image_width, 1])], axis=2) + with self.test_session() as sess: + image = rendered_coordinates.eval() + baseline_image_path = os.path.join(self.test_data_directory, + target_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, image) + + def testRendersSimpleTriangle(self): + self.runTriangleTest((1.0, 1.0, 1.0), 'Simple_Triangle.png') + + def testRendersPerspectiveCorrectTriangle(self): + self.runTriangleTest((0.2, 0.5, 2.0), 'Perspective_Corrected_Triangle.png') + + def testRendersTwoCubesInBatch(self): + """Renders a simple cube in two viewpoints to test the python wrapper.""" + + vertex_rgb = (self.cube_vertex_positions * 0.5 + 0.5) + vertex_rgba = tf.concat([vertex_rgb, tf.ones([8, 1])], axis=1) + + center = self.tf_float([[0.0, 0.0, 0.0]]) + world_up = self.tf_float([[0.0, 1.0, 0.0]]) + look_at_1 = camera_utils.look_at(self.tf_float([[2.0, 3.0, 6.0]]), + center, world_up) + look_at_2 = camera_utils.look_at(self.tf_float([[-3.0, 1.0, 6.0]]), + center, world_up) + projection_1 = tf.matmul(self.perspective, look_at_1) + projection_2 = tf.matmul(self.perspective, look_at_2) + projection = tf.concat([projection_1, projection_2], axis=0) + background_value = [0.0, 0.0, 0.0, 0.0] + + rendered = rasterize_triangles.rasterize( + tf.stack([self.cube_vertex_positions, self.cube_vertex_positions]), + tf.stack([vertex_rgba, vertex_rgba]), self.cube_triangles, projection, + self.image_width, self.image_height, background_value) + + with self.test_session() as sess: + images = sess.run(rendered, feed_dict={}) + for i in (0, 1): + image = 
images[i, :, :, :] + baseline_image_name = 'Unlit_Cube_{}.png'.format(i) + baseline_image_path = os.path.join(self.test_data_directory, + baseline_image_name) + test_utils.expect_image_file_and_render_are_near( + self, sess, baseline_image_path, image) + + def testSimpleTriangleGradientComputation(self): + """Verifies the Jacobian matrix for a single pixel. + + The pixel is in the center of a triangle facing the camera. This makes it + easy to check which entries of the Jacobian might not make sense without + worrying about corner cases. + """ + test_pixel_x = 325 + test_pixel_y = 245 + + clip_coordinates = tf.placeholder(tf.float32, shape=[3, 4]) + + triangles = tf.constant([[0, 1, 2]], dtype=tf.int32) + + barycentric_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, triangles, self.image_width, self.image_height)) + + pixels_to_compare = barycentric_coordinates[ + test_pixel_y:test_pixel_y + 1, test_pixel_x:test_pixel_x + 1, :] + + with self.test_session(): + ndc_init = np.array( + [[-0.5, -0.5, 0.8, 1.0], [0.0, 0.5, 0.3, 1.0], [0.5, -0.5, 0.3, 1.0]], + dtype=np.float32) + theoretical, numerical = tf.test.compute_gradient( + clip_coordinates, (3, 4), + pixels_to_compare, (1, 1, 3), + x_init_value=ndc_init, + delta=4e-2) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.0, True)) + self.assertTrue(jacobians_match, message) + + def testInternalRenderGradientComputation(self): + """Isolates and verifies the Jacobian matrix for the custom kernel.""" + image_height = 21 + image_width = 28 + + clip_coordinates = tf.placeholder(tf.float32, shape=[8, 4]) + + barycentric_coordinates, _, _ = ( + rasterize_triangles.rasterize_triangles_module.rasterize_triangles( + clip_coordinates, self.cube_triangles, image_width, image_height)) + + with self.test_session(): + # Precomputed transformation of the simple cube to normalized device + # coordinates, in order to isolate the rasterization gradient. 
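      # How the comparison below judges a match (a rough NumPy restatement of
      # test_utils.check_jacobians_are_nearly_equal, using the thresholds passed
      # further down):
      #
      #   outliers = np.abs(numerical - theoretical) / numerical > 0.01
      #   fraction = np.count_nonzero(outliers) / np.prod(numerical.shape[:2])
      #   jacobians_match = fraction <= 0.01   # max_outlier_fraction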
+ # pyformat: disable + ndc_init = np.array( + [[-0.43889722, -0.53184521, 0.85293502, 1.0], + [-0.37635487, 0.22206162, 0.90555805, 1.0], + [-0.22849123, 0.76811147, 0.80993629, 1.0], + [-0.2805393, -0.14092168, 0.71602166, 1.0], + [0.18631913, -0.62634289, 0.88603103, 1.0], + [0.16183566, 0.08129397, 0.93020856, 1.0], + [0.44147962, 0.53497446, 0.85076219, 1.0], + [0.53008741, -0.31276882, 0.77620775, 1.0]], + dtype=np.float32) + # pyformat: enable + theoretical, numerical = tf.test.compute_gradient( + clip_coordinates, (8, 4), + barycentric_coordinates, (image_height, image_width, 3), + x_init_value=ndc_init, + delta=4e-2) + jacobians_match, message = ( + test_utils.check_jacobians_are_nearly_equal( + theoretical, numerical, 0.01, 0.01)) + self.assertTrue(jacobians_match, message) + + +if __name__ == '__main__': + tf.test.main() diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/BUILD b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/BUILD new file mode 100644 index 0000000..6bf68b3 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/BUILD @@ -0,0 +1,6 @@ +package(default_visibility = ["//visibility:public"]) + +filegroup( + name = "images", + srcs = glob(["*.png"]), +) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Barycentrics_Cube.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Barycentrics_Cube.png new file mode 100644 index 0000000..172142f Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Barycentrics_Cube.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_0.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_0.png new file mode 100644 index 0000000..e7682f5 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_0.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_1.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_1.png new file mode 100644 index 0000000..ac455f9 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Colored_Cube_1.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/External_Triangle.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/External_Triangle.png new file mode 100644 index 0000000..44208e6 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/External_Triangle.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_0.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_0.png new file mode 100644 index 0000000..4b0c82f Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_0.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_1.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_1.png new file mode 100644 index 0000000..e84d75a Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Gray_Cube_1.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Inside_Box.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Inside_Box.png new 
file mode 100644 index 0000000..e9dbf71 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Inside_Box.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Perspective_Corrected_Triangle.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Perspective_Corrected_Triangle.png new file mode 100644 index 0000000..49187e1 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Perspective_Corrected_Triangle.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Tetrahedron.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Tetrahedron.png new file mode 100644 index 0000000..abdddb3 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Tetrahedron.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Triangle.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Triangle.png new file mode 100644 index 0000000..25715b4 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Simple_Triangle.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_0.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_0.png new file mode 100644 index 0000000..8f58fe9 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_0.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_1.png b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_1.png new file mode 100644 index 0000000..7ff5af6 Binary files /dev/null and b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_data/Unlit_Cube_1.png differ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_utils.py b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_utils.py new file mode 100644 index 0000000..6c0b46e --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/test_utils.py @@ -0,0 +1,124 @@ +# Copyright 2017 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common functions for the rasterizer and mesh renderer tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import tensorflow as tf + + +def check_jacobians_are_nearly_equal(theoretical, + numerical, + outlier_relative_error_threshold, + max_outlier_fraction, + include_jacobians_in_error_message=False): + """Compares two Jacobian matrices, allowing for some fraction of outliers. + + Args: + theoretical: 2D numpy array containing a Jacobian matrix with entries + computed via gradient functions. The layout should be as in the output + of gradient_checker. 
+ numerical: 2D numpy array of the same shape as theoretical containing a + Jacobian matrix with entries computed via finite difference + approximations. The layout should be as in the output + of gradient_checker. + outlier_relative_error_threshold: float prescribing the maximum relative + error (from the finite difference approximation) is tolerated before + and entry is considered an outlier. + max_outlier_fraction: float defining the maximum fraction of entries in + theoretical that may be outliers before the check returns False. + include_jacobians_in_error_message: bool defining whether the jacobian + matrices should be included in the return message should the test fail. + + Returns: + A tuple where the first entry is a boolean describing whether + max_outlier_fraction was exceeded, and where the second entry is a string + containing an error message if one is relevant. + """ + outlier_gradients = np.abs( + numerical - theoretical) / numerical > outlier_relative_error_threshold + outlier_fraction = np.count_nonzero(outlier_gradients) / np.prod( + numerical.shape[:2]) + jacobians_match = outlier_fraction <= max_outlier_fraction + + message = ( + ' %f of theoretical gradients are relative outliers, but the maximum' + ' allowable fraction is %f ' % (outlier_fraction, max_outlier_fraction)) + if include_jacobians_in_error_message: + # the gradient_checker convention is the typical Jacobian transposed: + message += ('\nNumerical Jacobian:\n%s\nTheoretical Jacobian:\n%s' % + (repr(numerical.T), repr(theoretical.T))) + return jacobians_match, message + + +def expect_image_file_and_render_are_near(test_instance, + sess, + baseline_path, + result_image, + max_outlier_fraction=0.001, + pixel_error_threshold=0.01): + """Compares the output of mesh_renderer with an image on disk. + + The comparison is soft: the images are considered identical if at most + max_outlier_fraction of the pixels differ by more than a relative error of + pixel_error_threshold of the full color value. Note that before comparison, + mesh renderer values are clipped to the range [0,1]. + + Uses _images_are_near for the actual comparison. + + Args: + test_instance: a python unit test instance. + sess: a TensorFlow session for decoding the png. + baseline_path: path to the reference image on disk. + result_image: the result image, as a numpy array. + max_outlier_fraction: the maximum fraction of outlier pixels allowed. + pixel_error_threshold: pixel values are considered to differ if their + difference exceeds this amount. Range is 0.0 - 1.0. + """ + baseline_bytes = open(baseline_path, 'rb').read() + baseline_image = sess.run(tf.image.decode_png(baseline_bytes)) + + test_instance.assertEqual(baseline_image.shape, result_image.shape, + 'Image shapes %s and %s do not match.' % + (baseline_image.shape, result_image.shape)) + + result_image = np.clip(result_image, 0., 1.).copy(order='C') + baseline_image = baseline_image.astype(float) / 255.0 + + outlier_channels = (np.abs(baseline_image - result_image) > + pixel_error_threshold) + outlier_pixels = np.any(outlier_channels, axis=2) + outlier_count = np.count_nonzero(outlier_pixels) + outlier_fraction = outlier_count / np.prod(baseline_image.shape[:2]) + images_match = outlier_fraction <= max_outlier_fraction + + outputs_dir = "/tmp" #os.environ["TEST_TMPDIR"] + base_prefix = os.path.splitext(os.path.basename(baseline_path))[0] + result_output_path = os.path.join(outputs_dir, base_prefix + "_result.png") + + message = ('%s does not match. 
(%f of pixels are outliers, %f is allowed.). ' + 'Result image written to %s' % + (baseline_path, outlier_fraction, max_outlier_fraction, result_output_path)) + + if not images_match: + result_bytes = sess.run(tf.image.encode_png(result_image*255.0)) + with open(result_output_path, 'wb') as output_file: + output_file.write(result_bytes) + + test_instance.assertTrue(images_match, msg=message) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/runtests.sh b/Deep3DFaceReconstruction/tf_mesh_renderer/runtests.sh new file mode 100755 index 0000000..82b2b42 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/runtests.sh @@ -0,0 +1,2 @@ +#!/bin/bash +bazel test --python_path=$VIRTUAL_ENV/bin/python --define=PYTHON=$VIRTUAL_ENV/bin/python ... diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/BUILD b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/BUILD new file mode 100644 index 0000000..2e302bc --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/BUILD @@ -0,0 +1,7 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "lodepng", + srcs = ["lodepng.cpp"], + hdrs = ["lodepng.h"], +) diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.cpp b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.cpp new file mode 100644 index 0000000..37b0562 --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.cpp @@ -0,0 +1,6232 @@ +/* +LodePNG version 20170917 + +Copyright (c) 2005-2017 Lode Vandevenne + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + + 3. This notice may not be removed or altered from any source + distribution. +*/ + +/* +The manual and changelog are in the header file "lodepng.h" +Rename this file to lodepng.cpp to use it for C++, or to lodepng.c to use it for C. +*/ + +#include "lodepng.h" + +#include +#include +#include + +#if defined(_MSC_VER) && (_MSC_VER >= 1310) /*Visual Studio: A few warning types are not desired here.*/ +#pragma warning( disable : 4244 ) /*implicit conversions: not warned by gcc -Wall -Wextra and requires too much casts*/ +#pragma warning( disable : 4996 ) /*VS does not like fopen, but fopen_s is not standard C so unusable here*/ +#endif /*_MSC_VER */ + +const char* LODEPNG_VERSION_STRING = "20170917"; + +/* +This source file is built up in the following large parts. The code sections +with the "LODEPNG_COMPILE_" #defines divide this up further in an intermixed way. +-Tools for C and common code for PNG and Zlib +-C Code for Zlib (huffman, deflate, ...) +-C Code for PNG (file format chunks, adam7, PNG filters, color conversions, ...) 
+-The C++ wrapper around all of the above +*/ + +/*The malloc, realloc and free functions defined here with "lodepng_" in front +of the name, so that you can easily change them to others related to your +platform if needed. Everything else in the code calls these. Pass +-DLODEPNG_NO_COMPILE_ALLOCATORS to the compiler, or comment out +#define LODEPNG_COMPILE_ALLOCATORS in the header, to disable the ones here and +define them in your own project's source files without needing to change +lodepng source code. Don't forget to remove "static" if you copypaste them +from here.*/ + +#ifdef LODEPNG_COMPILE_ALLOCATORS +static void* lodepng_malloc(size_t size) +{ + return malloc(size); +} + +static void* lodepng_realloc(void* ptr, size_t new_size) +{ + return realloc(ptr, new_size); +} + +static void lodepng_free(void* ptr) +{ + free(ptr); +} +#else /*LODEPNG_COMPILE_ALLOCATORS*/ +void* lodepng_malloc(size_t size); +void* lodepng_realloc(void* ptr, size_t new_size); +void lodepng_free(void* ptr); +#endif /*LODEPNG_COMPILE_ALLOCATORS*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* // Tools for C, and common code for PNG and Zlib. // */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ + +/* +Often in case of an error a value is assigned to a variable and then it breaks +out of a loop (to go to the cleanup phase of a function). This macro does that. +It makes the error handling code shorter and more readable. + +Example: if(!uivector_resizev(&frequencies_ll, 286, 0)) ERROR_BREAK(83); +*/ +#define CERROR_BREAK(errorvar, code)\ +{\ + errorvar = code;\ + break;\ +} + +/*version of CERROR_BREAK that assumes the common case where the error variable is named "error"*/ +#define ERROR_BREAK(code) CERROR_BREAK(error, code) + +/*Set error var to the error code, and return it.*/ +#define CERROR_RETURN_ERROR(errorvar, code)\ +{\ + errorvar = code;\ + return code;\ +} + +/*Try the code, if it returns error, also return the error.*/ +#define CERROR_TRY_RETURN(call)\ +{\ + unsigned error = call;\ + if(error) return error;\ +} + +/*Set error var to the error code, and return from the void function.*/ +#define CERROR_RETURN(errorvar, code)\ +{\ + errorvar = code;\ + return;\ +} + +/* +About uivector, ucvector and string: +-All of them wrap dynamic arrays or text strings in a similar way. +-LodePNG was originally written in C++. The vectors replace the std::vectors that were used in the C++ version. +-The string tools are made to avoid problems with compilers that declare things like strncat as deprecated. +-They're not used in the interface, only internally in this file as static functions. +-As with many other structs in this file, the init and cleanup functions serve as ctor and dtor. 
+*/ + +#ifdef LODEPNG_COMPILE_ZLIB +/*dynamic vector of unsigned ints*/ +typedef struct uivector +{ + unsigned* data; + size_t size; /*size in number of unsigned longs*/ + size_t allocsize; /*allocated size in bytes*/ +} uivector; + +static void uivector_cleanup(void* p) +{ + ((uivector*)p)->size = ((uivector*)p)->allocsize = 0; + lodepng_free(((uivector*)p)->data); + ((uivector*)p)->data = NULL; +} + +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned uivector_reserve(uivector* p, size_t allocsize) +{ + if(allocsize > p->allocsize) + { + size_t newsize = (allocsize > p->allocsize * 2) ? allocsize : (allocsize * 3 / 2); + void* data = lodepng_realloc(p->data, newsize); + if(data) + { + p->allocsize = newsize; + p->data = (unsigned*)data; + } + else return 0; /*error: not enough memory*/ + } + return 1; +} + +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned uivector_resize(uivector* p, size_t size) +{ + if(!uivector_reserve(p, size * sizeof(unsigned))) return 0; + p->size = size; + return 1; /*success*/ +} + +/*resize and give all new elements the value*/ +static unsigned uivector_resizev(uivector* p, size_t size, unsigned value) +{ + size_t oldsize = p->size, i; + if(!uivector_resize(p, size)) return 0; + for(i = oldsize; i < size; ++i) p->data[i] = value; + return 1; +} + +static void uivector_init(uivector* p) +{ + p->data = NULL; + p->size = p->allocsize = 0; +} + +#ifdef LODEPNG_COMPILE_ENCODER +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned uivector_push_back(uivector* p, unsigned c) +{ + if(!uivector_resize(p, p->size + 1)) return 0; + p->data[p->size - 1] = c; + return 1; +} +#endif /*LODEPNG_COMPILE_ENCODER*/ +#endif /*LODEPNG_COMPILE_ZLIB*/ + +/* /////////////////////////////////////////////////////////////////////////// */ + +/*dynamic vector of unsigned chars*/ +typedef struct ucvector +{ + unsigned char* data; + size_t size; /*used size*/ + size_t allocsize; /*allocated size*/ +} ucvector; + +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned ucvector_reserve(ucvector* p, size_t allocsize) +{ + if(allocsize > p->allocsize) + { + size_t newsize = (allocsize > p->allocsize * 2) ? allocsize : (allocsize * 3 / 2); + void* data = lodepng_realloc(p->data, newsize); + if(data) + { + p->allocsize = newsize; + p->data = (unsigned char*)data; + } + else return 0; /*error: not enough memory*/ + } + return 1; +} + +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned ucvector_resize(ucvector* p, size_t size) +{ + if(!ucvector_reserve(p, size * sizeof(unsigned char))) return 0; + p->size = size; + return 1; /*success*/ +} + +#ifdef LODEPNG_COMPILE_PNG + +static void ucvector_cleanup(void* p) +{ + ((ucvector*)p)->size = ((ucvector*)p)->allocsize = 0; + lodepng_free(((ucvector*)p)->data); + ((ucvector*)p)->data = NULL; +} + +static void ucvector_init(ucvector* p) +{ + p->data = NULL; + p->size = p->allocsize = 0; +} +#endif /*LODEPNG_COMPILE_PNG*/ + +#ifdef LODEPNG_COMPILE_ZLIB +/*you can both convert from vector to buffer&size and vica versa. 
If you use +init_buffer to take over a buffer and size, it is not needed to use cleanup*/ +static void ucvector_init_buffer(ucvector* p, unsigned char* buffer, size_t size) +{ + p->data = buffer; + p->allocsize = p->size = size; +} +#endif /*LODEPNG_COMPILE_ZLIB*/ + +#if (defined(LODEPNG_COMPILE_PNG) && defined(LODEPNG_COMPILE_ANCILLARY_CHUNKS)) || defined(LODEPNG_COMPILE_ENCODER) +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned ucvector_push_back(ucvector* p, unsigned char c) +{ + if(!ucvector_resize(p, p->size + 1)) return 0; + p->data[p->size - 1] = c; + return 1; +} +#endif /*defined(LODEPNG_COMPILE_PNG) || defined(LODEPNG_COMPILE_ENCODER)*/ + + +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_PNG +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS +/*returns 1 if success, 0 if failure ==> nothing done*/ +static unsigned string_resize(char** out, size_t size) +{ + char* data = (char*)lodepng_realloc(*out, size + 1); + if(data) + { + data[size] = 0; /*null termination char*/ + *out = data; + } + return data != 0; +} + +/*init a {char*, size_t} pair for use as string*/ +static void string_init(char** out) +{ + *out = NULL; + string_resize(out, 0); +} + +/*free the above pair again*/ +static void string_cleanup(char** out) +{ + lodepng_free(*out); + *out = NULL; +} + +static void string_set(char** out, const char* in) +{ + size_t insize = strlen(in), i; + if(string_resize(out, insize)) + { + for(i = 0; i != insize; ++i) + { + (*out)[i] = in[i]; + } + } +} +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +#endif /*LODEPNG_COMPILE_PNG*/ + +/* ////////////////////////////////////////////////////////////////////////// */ + +unsigned lodepng_read32bitInt(const unsigned char* buffer) +{ + return (unsigned)((buffer[0] << 24) | (buffer[1] << 16) | (buffer[2] << 8) | buffer[3]); +} + +#if defined(LODEPNG_COMPILE_PNG) || defined(LODEPNG_COMPILE_ENCODER) +/*buffer must have at least 4 allocated bytes available*/ +static void lodepng_set32bitInt(unsigned char* buffer, unsigned value) +{ + buffer[0] = (unsigned char)((value >> 24) & 0xff); + buffer[1] = (unsigned char)((value >> 16) & 0xff); + buffer[2] = (unsigned char)((value >> 8) & 0xff); + buffer[3] = (unsigned char)((value ) & 0xff); +} +#endif /*defined(LODEPNG_COMPILE_PNG) || defined(LODEPNG_COMPILE_ENCODER)*/ + +#ifdef LODEPNG_COMPILE_ENCODER +static void lodepng_add32bitInt(ucvector* buffer, unsigned value) +{ + ucvector_resize(buffer, buffer->size + 4); /*todo: give error if resize failed*/ + lodepng_set32bitInt(&buffer->data[buffer->size - 4], value); +} +#endif /*LODEPNG_COMPILE_ENCODER*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / File IO / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_DISK + +/* returns negative value on error. This should be pure C compatible, so no fstat. */ +static long lodepng_filesize(const char* filename) +{ + FILE* file; + long size; + file = fopen(filename, "rb"); + if(!file) return -1; + + if(fseek(file, 0, SEEK_END) != 0) + { + fclose(file); + return -1; + } + + size = ftell(file); + /* It may give LONG_MAX as directory size, this is invalid for us. */ + if(size == LONG_MAX) size = -1; + + fclose(file); + return size; +} + +/* load file into buffer that already has the correct allocated size. 
Returns error code.*/ +static unsigned lodepng_buffer_file(unsigned char* out, size_t size, const char* filename) +{ + FILE* file; + size_t readsize; + file = fopen(filename, "rb"); + if(!file) return 78; + + readsize = fread(out, 1, size, file); + fclose(file); + + if (readsize != size) return 78; + return 0; +} + +unsigned lodepng_load_file(unsigned char** out, size_t* outsize, const char* filename) +{ + long size = lodepng_filesize(filename); + if (size < 0) return 78; + *outsize = (size_t)size; + + *out = (unsigned char*)lodepng_malloc((size_t)size); + if(!(*out) && size > 0) return 83; /*the above malloc failed*/ + + return lodepng_buffer_file(*out, (size_t)size, filename); +} + +/*write given buffer to the file, overwriting the file, it doesn't append to it.*/ +unsigned lodepng_save_file(const unsigned char* buffer, size_t buffersize, const char* filename) +{ + FILE* file; + file = fopen(filename, "wb" ); + if(!file) return 79; + fwrite((char*)buffer , 1 , buffersize, file); + fclose(file); + return 0; +} + +#endif /*LODEPNG_COMPILE_DISK*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* // End of common code and tools. Begin of Zlib related code. // */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_ZLIB +#ifdef LODEPNG_COMPILE_ENCODER +/*TODO: this ignores potential out of memory errors*/ +#define addBitToStream(/*size_t**/ bitpointer, /*ucvector**/ bitstream, /*unsigned char*/ bit)\ +{\ + /*add a new byte at the end*/\ + if(((*bitpointer) & 7) == 0) ucvector_push_back(bitstream, (unsigned char)0);\ + /*earlier bit of huffman code is in a lesser significant bit of an earlier byte*/\ + (bitstream->data[bitstream->size - 1]) |= (bit << ((*bitpointer) & 0x7));\ + ++(*bitpointer);\ +} + +static void addBitsToStream(size_t* bitpointer, ucvector* bitstream, unsigned value, size_t nbits) +{ + size_t i; + for(i = 0; i != nbits; ++i) addBitToStream(bitpointer, bitstream, (unsigned char)((value >> i) & 1)); +} + +static void addBitsToStreamReversed(size_t* bitpointer, ucvector* bitstream, unsigned value, size_t nbits) +{ + size_t i; + for(i = 0; i != nbits; ++i) addBitToStream(bitpointer, bitstream, (unsigned char)((value >> (nbits - 1 - i)) & 1)); +} +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#ifdef LODEPNG_COMPILE_DECODER + +#define READBIT(bitpointer, bitstream) ((bitstream[bitpointer >> 3] >> (bitpointer & 0x7)) & (unsigned char)1) + +static unsigned char readBitFromStream(size_t* bitpointer, const unsigned char* bitstream) +{ + unsigned char result = (unsigned char)(READBIT(*bitpointer, bitstream)); + ++(*bitpointer); + return result; +} + +static unsigned readBitsFromStream(size_t* bitpointer, const unsigned char* bitstream, size_t nbits) +{ + unsigned result = 0, i; + for(i = 0; i != nbits; ++i) + { + result += ((unsigned)READBIT(*bitpointer, bitstream)) << i; + ++(*bitpointer); + } + return result; +} +#endif /*LODEPNG_COMPILE_DECODER*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Deflate - Huffman / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#define FIRST_LENGTH_CODE_INDEX 257 +#define LAST_LENGTH_CODE_INDEX 285 +/*256 literals, the end code, some length codes, and 2 unused codes*/ +#define NUM_DEFLATE_CODE_SYMBOLS 288 +/*the distance codes 
have their own symbols, 30 used, 2 unused*/ +#define NUM_DISTANCE_SYMBOLS 32 +/*the code length codes. 0-15: code lengths, 16: copy previous 3-6 times, 17: 3-10 zeros, 18: 11-138 zeros*/ +#define NUM_CODE_LENGTH_CODES 19 + +/*the base lengths represented by codes 257-285*/ +static const unsigned LENGTHBASE[29] + = {3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, + 67, 83, 99, 115, 131, 163, 195, 227, 258}; + +/*the extra bits used by codes 257-285 (added to base length)*/ +static const unsigned LENGTHEXTRA[29] + = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 0}; + +/*the base backwards distances (the bits of distance codes appear after length codes and use their own huffman tree)*/ +static const unsigned DISTANCEBASE[30] + = {1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, + 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577}; + +/*the extra bits of backwards distances (added to base)*/ +static const unsigned DISTANCEEXTRA[30] + = {0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, + 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13}; + +/*the order in which "code length alphabet code lengths" are stored, out of this +the huffman tree of the dynamic huffman tree lengths is generated*/ +static const unsigned CLCL_ORDER[NUM_CODE_LENGTH_CODES] + = {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + +/* ////////////////////////////////////////////////////////////////////////// */ + +/* +Huffman tree struct, containing multiple representations of the tree +*/ +typedef struct HuffmanTree +{ + unsigned* tree2d; + unsigned* tree1d; + unsigned* lengths; /*the lengths of the codes of the 1d-tree*/ + unsigned maxbitlen; /*maximum number of bits a single code can get*/ + unsigned numcodes; /*number of symbols in the alphabet = number of codes*/ +} HuffmanTree; + +/*function used for debug purposes to draw the tree in ascii art with C++*/ +/* +static void HuffmanTree_draw(HuffmanTree* tree) +{ + std::cout << "tree. length: " << tree->numcodes << " maxbitlen: " << tree->maxbitlen << std::endl; + for(size_t i = 0; i != tree->tree1d.size; ++i) + { + if(tree->lengths.data[i]) + std::cout << i << " " << tree->tree1d.data[i] << " " << tree->lengths.data[i] << std::endl; + } + std::cout << std::endl; +}*/ + +static void HuffmanTree_init(HuffmanTree* tree) +{ + tree->tree2d = 0; + tree->tree1d = 0; + tree->lengths = 0; +} + +static void HuffmanTree_cleanup(HuffmanTree* tree) +{ + lodepng_free(tree->tree2d); + lodepng_free(tree->tree1d); + lodepng_free(tree->lengths); +} + +/*the tree representation used by the decoder. return value is error*/ +static unsigned HuffmanTree_make2DTree(HuffmanTree* tree) +{ + unsigned nodefilled = 0; /*up to which node it is filled*/ + unsigned treepos = 0; /*position in the tree (1 of the numcodes columns)*/ + unsigned n, i; + + tree->tree2d = (unsigned*)lodepng_malloc(tree->numcodes * 2 * sizeof(unsigned)); + if(!tree->tree2d) return 83; /*alloc fail*/ + + /* + convert tree1d[] to tree2d[][]. In the 2D array, a value of 32767 means + uninited, a value >= numcodes is an address to another bit, a value < numcodes + is a code. The 2 rows are the 2 possible bit values (0 or 1), there are as + many columns as codes - 1. + A good huffman tree has N * 2 - 1 nodes, of which N - 1 are internal nodes. + Here, the internal nodes are stored (what their 0 and 1 option point to). 
+ There is only memory for such good tree currently, if there are more nodes + (due to too long length codes), error 55 will happen + */ + for(n = 0; n < tree->numcodes * 2; ++n) + { + tree->tree2d[n] = 32767; /*32767 here means the tree2d isn't filled there yet*/ + } + + for(n = 0; n < tree->numcodes; ++n) /*the codes*/ + { + for(i = 0; i != tree->lengths[n]; ++i) /*the bits for this code*/ + { + unsigned char bit = (unsigned char)((tree->tree1d[n] >> (tree->lengths[n] - i - 1)) & 1); + /*oversubscribed, see comment in lodepng_error_text*/ + if(treepos > 2147483647 || treepos + 2 > tree->numcodes) return 55; + if(tree->tree2d[2 * treepos + bit] == 32767) /*not yet filled in*/ + { + if(i + 1 == tree->lengths[n]) /*last bit*/ + { + tree->tree2d[2 * treepos + bit] = n; /*put the current code in it*/ + treepos = 0; + } + else + { + /*put address of the next step in here, first that address has to be found of course + (it's just nodefilled + 1)...*/ + ++nodefilled; + /*addresses encoded with numcodes added to it*/ + tree->tree2d[2 * treepos + bit] = nodefilled + tree->numcodes; + treepos = nodefilled; + } + } + else treepos = tree->tree2d[2 * treepos + bit] - tree->numcodes; + } + } + + for(n = 0; n < tree->numcodes * 2; ++n) + { + if(tree->tree2d[n] == 32767) tree->tree2d[n] = 0; /*remove possible remaining 32767's*/ + } + + return 0; +} + +/* +Second step for the ...makeFromLengths and ...makeFromFrequencies functions. +numcodes, lengths and maxbitlen must already be filled in correctly. return +value is error. +*/ +static unsigned HuffmanTree_makeFromLengths2(HuffmanTree* tree) +{ + uivector blcount; + uivector nextcode; + unsigned error = 0; + unsigned bits, n; + + uivector_init(&blcount); + uivector_init(&nextcode); + + tree->tree1d = (unsigned*)lodepng_malloc(tree->numcodes * sizeof(unsigned)); + if(!tree->tree1d) error = 83; /*alloc fail*/ + + if(!uivector_resizev(&blcount, tree->maxbitlen + 1, 0) + || !uivector_resizev(&nextcode, tree->maxbitlen + 1, 0)) + error = 83; /*alloc fail*/ + + if(!error) + { + /*step 1: count number of instances of each code length*/ + for(bits = 0; bits != tree->numcodes; ++bits) ++blcount.data[tree->lengths[bits]]; + /*step 2: generate the nextcode values*/ + for(bits = 1; bits <= tree->maxbitlen; ++bits) + { + nextcode.data[bits] = (nextcode.data[bits - 1] + blcount.data[bits - 1]) << 1; + } + /*step 3: generate all the codes*/ + for(n = 0; n != tree->numcodes; ++n) + { + if(tree->lengths[n] != 0) tree->tree1d[n] = nextcode.data[tree->lengths[n]]++; + } + } + + uivector_cleanup(&blcount); + uivector_cleanup(&nextcode); + + if(!error) return HuffmanTree_make2DTree(tree); + else return error; +} + +/* +given the code lengths (as stored in the PNG file), generate the tree as defined +by Deflate. maxbitlen is the maximum bits that a code in the tree can have. +return value is error. 
+*/ +static unsigned HuffmanTree_makeFromLengths(HuffmanTree* tree, const unsigned* bitlen, + size_t numcodes, unsigned maxbitlen) +{ + unsigned i; + tree->lengths = (unsigned*)lodepng_malloc(numcodes * sizeof(unsigned)); + if(!tree->lengths) return 83; /*alloc fail*/ + for(i = 0; i != numcodes; ++i) tree->lengths[i] = bitlen[i]; + tree->numcodes = (unsigned)numcodes; /*number of symbols*/ + tree->maxbitlen = maxbitlen; + return HuffmanTree_makeFromLengths2(tree); +} + +#ifdef LODEPNG_COMPILE_ENCODER + +/*BPM: Boundary Package Merge, see "A Fast and Space-Economical Algorithm for Length-Limited Coding", +Jyrki Katajainen, Alistair Moffat, Andrew Turpin, 1995.*/ + +/*chain node for boundary package merge*/ +typedef struct BPMNode +{ + int weight; /*the sum of all weights in this chain*/ + unsigned index; /*index of this leaf node (called "count" in the paper)*/ + struct BPMNode* tail; /*the next nodes in this chain (null if last)*/ + int in_use; +} BPMNode; + +/*lists of chains*/ +typedef struct BPMLists +{ + /*memory pool*/ + unsigned memsize; + BPMNode* memory; + unsigned numfree; + unsigned nextfree; + BPMNode** freelist; + /*two heads of lookahead chains per list*/ + unsigned listsize; + BPMNode** chains0; + BPMNode** chains1; +} BPMLists; + +/*creates a new chain node with the given parameters, from the memory in the lists */ +static BPMNode* bpmnode_create(BPMLists* lists, int weight, unsigned index, BPMNode* tail) +{ + unsigned i; + BPMNode* result; + + /*memory full, so garbage collect*/ + if(lists->nextfree >= lists->numfree) + { + /*mark only those that are in use*/ + for(i = 0; i != lists->memsize; ++i) lists->memory[i].in_use = 0; + for(i = 0; i != lists->listsize; ++i) + { + BPMNode* node; + for(node = lists->chains0[i]; node != 0; node = node->tail) node->in_use = 1; + for(node = lists->chains1[i]; node != 0; node = node->tail) node->in_use = 1; + } + /*collect those that are free*/ + lists->numfree = 0; + for(i = 0; i != lists->memsize; ++i) + { + if(!lists->memory[i].in_use) lists->freelist[lists->numfree++] = &lists->memory[i]; + } + lists->nextfree = 0; + } + + result = lists->freelist[lists->nextfree++]; + result->weight = weight; + result->index = index; + result->tail = tail; + return result; +} + +/*sort the leaves with stable mergesort*/ +static void bpmnode_sort(BPMNode* leaves, size_t num) +{ + BPMNode* mem = (BPMNode*)lodepng_malloc(sizeof(*leaves) * num); + size_t width, counter = 0; + for(width = 1; width < num; width *= 2) + { + BPMNode* a = (counter & 1) ? mem : leaves; + BPMNode* b = (counter & 1) ? leaves : mem; + size_t p; + for(p = 0; p < num; p += 2 * width) + { + size_t q = (p + width > num) ? num : (p + width); + size_t r = (p + 2 * width > num) ? 
num : (p + 2 * width); + size_t i = p, j = q, k; + for(k = p; k < r; k++) + { + if(i < q && (j >= r || a[i].weight <= a[j].weight)) b[k] = a[i++]; + else b[k] = a[j++]; + } + } + counter++; + } + if(counter & 1) memcpy(leaves, mem, sizeof(*leaves) * num); + lodepng_free(mem); +} + +/*Boundary Package Merge step, numpresent is the amount of leaves, and c is the current chain.*/ +static void boundaryPM(BPMLists* lists, BPMNode* leaves, size_t numpresent, int c, int num) +{ + unsigned lastindex = lists->chains1[c]->index; + + if(c == 0) + { + if(lastindex >= numpresent) return; + lists->chains0[c] = lists->chains1[c]; + lists->chains1[c] = bpmnode_create(lists, leaves[lastindex].weight, lastindex + 1, 0); + } + else + { + /*sum of the weights of the head nodes of the previous lookahead chains.*/ + int sum = lists->chains0[c - 1]->weight + lists->chains1[c - 1]->weight; + lists->chains0[c] = lists->chains1[c]; + if(lastindex < numpresent && sum > leaves[lastindex].weight) + { + lists->chains1[c] = bpmnode_create(lists, leaves[lastindex].weight, lastindex + 1, lists->chains1[c]->tail); + return; + } + lists->chains1[c] = bpmnode_create(lists, sum, lastindex, lists->chains1[c - 1]); + /*in the end we are only interested in the chain of the last list, so no + need to recurse if we're at the last one (this gives measurable speedup)*/ + if(num + 1 < (int)(2 * numpresent - 2)) + { + boundaryPM(lists, leaves, numpresent, c - 1, num); + boundaryPM(lists, leaves, numpresent, c - 1, num); + } + } +} + +unsigned lodepng_huffman_code_lengths(unsigned* lengths, const unsigned* frequencies, + size_t numcodes, unsigned maxbitlen) +{ + unsigned error = 0; + unsigned i; + size_t numpresent = 0; /*number of symbols with non-zero frequency*/ + BPMNode* leaves; /*the symbols, only those with > 0 frequency*/ + + if(numcodes == 0) return 80; /*error: a tree of 0 symbols is not supposed to be made*/ + if((1u << maxbitlen) < numcodes) return 80; /*error: represent all symbols*/ + + leaves = (BPMNode*)lodepng_malloc(numcodes * sizeof(*leaves)); + if(!leaves) return 83; /*alloc fail*/ + + for(i = 0; i != numcodes; ++i) + { + if(frequencies[i] > 0) + { + leaves[numpresent].weight = (int)frequencies[i]; + leaves[numpresent].index = i; + ++numpresent; + } + } + + for(i = 0; i != numcodes; ++i) lengths[i] = 0; + + /*ensure at least two present symbols. There should be at least one symbol + according to RFC 1951 section 3.2.7. Some decoders incorrectly require two. To + make these work as well ensure there are at least two symbols. The + Package-Merge code below also doesn't work correctly if there's only one + symbol, it'd give it the theoritical 0 bits but in practice zlib wants 1 bit*/ + if(numpresent == 0) + { + lengths[0] = lengths[1] = 1; /*note that for RFC 1951 section 3.2.7, only lengths[0] = 1 is needed*/ + } + else if(numpresent == 1) + { + lengths[leaves[0].index] = 1; + lengths[leaves[0].index == 0 ? 
1 : 0] = 1; + } + else + { + BPMLists lists; + BPMNode* node; + + bpmnode_sort(leaves, numpresent); + + lists.listsize = maxbitlen; + lists.memsize = 2 * maxbitlen * (maxbitlen + 1); + lists.nextfree = 0; + lists.numfree = lists.memsize; + lists.memory = (BPMNode*)lodepng_malloc(lists.memsize * sizeof(*lists.memory)); + lists.freelist = (BPMNode**)lodepng_malloc(lists.memsize * sizeof(BPMNode*)); + lists.chains0 = (BPMNode**)lodepng_malloc(lists.listsize * sizeof(BPMNode*)); + lists.chains1 = (BPMNode**)lodepng_malloc(lists.listsize * sizeof(BPMNode*)); + if(!lists.memory || !lists.freelist || !lists.chains0 || !lists.chains1) error = 83; /*alloc fail*/ + + if(!error) + { + for(i = 0; i != lists.memsize; ++i) lists.freelist[i] = &lists.memory[i]; + + bpmnode_create(&lists, leaves[0].weight, 1, 0); + bpmnode_create(&lists, leaves[1].weight, 2, 0); + + for(i = 0; i != lists.listsize; ++i) + { + lists.chains0[i] = &lists.memory[0]; + lists.chains1[i] = &lists.memory[1]; + } + + /*each boundaryPM call adds one chain to the last list, and we need 2 * numpresent - 2 chains.*/ + for(i = 2; i != 2 * numpresent - 2; ++i) boundaryPM(&lists, leaves, numpresent, (int)maxbitlen - 1, (int)i); + + for(node = lists.chains1[maxbitlen - 1]; node; node = node->tail) + { + for(i = 0; i != node->index; ++i) ++lengths[leaves[i].index]; + } + } + + lodepng_free(lists.memory); + lodepng_free(lists.freelist); + lodepng_free(lists.chains0); + lodepng_free(lists.chains1); + } + + lodepng_free(leaves); + return error; +} + +/*Create the Huffman tree given the symbol frequencies*/ +static unsigned HuffmanTree_makeFromFrequencies(HuffmanTree* tree, const unsigned* frequencies, + size_t mincodes, size_t numcodes, unsigned maxbitlen) +{ + unsigned error = 0; + while(!frequencies[numcodes - 1] && numcodes > mincodes) --numcodes; /*trim zeroes*/ + tree->maxbitlen = maxbitlen; + tree->numcodes = (unsigned)numcodes; /*number of symbols*/ + tree->lengths = (unsigned*)lodepng_realloc(tree->lengths, numcodes * sizeof(unsigned)); + if(!tree->lengths) return 83; /*alloc fail*/ + /*initialize all lengths to 0*/ + memset(tree->lengths, 0, numcodes * sizeof(unsigned)); + + error = lodepng_huffman_code_lengths(tree->lengths, frequencies, numcodes, maxbitlen); + if(!error) error = HuffmanTree_makeFromLengths2(tree); + return error; +} + +static unsigned HuffmanTree_getCode(const HuffmanTree* tree, unsigned index) +{ + return tree->tree1d[index]; +} + +static unsigned HuffmanTree_getLength(const HuffmanTree* tree, unsigned index) +{ + return tree->lengths[index]; +} +#endif /*LODEPNG_COMPILE_ENCODER*/ + +/*get the literal and length code tree of a deflated block with fixed tree, as per the deflate specification*/ +static unsigned generateFixedLitLenTree(HuffmanTree* tree) +{ + unsigned i, error = 0; + unsigned* bitlen = (unsigned*)lodepng_malloc(NUM_DEFLATE_CODE_SYMBOLS * sizeof(unsigned)); + if(!bitlen) return 83; /*alloc fail*/ + + /*288 possible codes: 0-255=literals, 256=endcode, 257-285=lengthcodes, 286-287=unused*/ + for(i = 0; i <= 143; ++i) bitlen[i] = 8; + for(i = 144; i <= 255; ++i) bitlen[i] = 9; + for(i = 256; i <= 279; ++i) bitlen[i] = 7; + for(i = 280; i <= 287; ++i) bitlen[i] = 8; + + error = HuffmanTree_makeFromLengths(tree, bitlen, NUM_DEFLATE_CODE_SYMBOLS, 15); + + lodepng_free(bitlen); + return error; +} + +/*get the distance code tree of a deflated block with fixed tree, as specified in the deflate specification*/ +static unsigned generateFixedDistanceTree(HuffmanTree* tree) +{ + unsigned i, error = 0; + unsigned* 
bitlen = (unsigned*)lodepng_malloc(NUM_DISTANCE_SYMBOLS * sizeof(unsigned)); + if(!bitlen) return 83; /*alloc fail*/ + + /*there are 32 distance codes, but 30-31 are unused*/ + for(i = 0; i != NUM_DISTANCE_SYMBOLS; ++i) bitlen[i] = 5; + error = HuffmanTree_makeFromLengths(tree, bitlen, NUM_DISTANCE_SYMBOLS, 15); + + lodepng_free(bitlen); + return error; +} + +#ifdef LODEPNG_COMPILE_DECODER + +/* +returns the code, or (unsigned)(-1) if error happened +inbitlength is the length of the complete buffer, in bits (so its byte length times 8) +*/ +static unsigned huffmanDecodeSymbol(const unsigned char* in, size_t* bp, + const HuffmanTree* codetree, size_t inbitlength) +{ + unsigned treepos = 0, ct; + for(;;) + { + if(*bp >= inbitlength) return (unsigned)(-1); /*error: end of input memory reached without endcode*/ + /* + decode the symbol from the tree. The "readBitFromStream" code is inlined in + the expression below because this is the biggest bottleneck while decoding + */ + ct = codetree->tree2d[(treepos << 1) + READBIT(*bp, in)]; + ++(*bp); + if(ct < codetree->numcodes) return ct; /*the symbol is decoded, return it*/ + else treepos = ct - codetree->numcodes; /*symbol not yet decoded, instead move tree position*/ + + if(treepos >= codetree->numcodes) return (unsigned)(-1); /*error: it appeared outside the codetree*/ + } +} +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_DECODER + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Inflator (Decompressor) / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +/*get the tree of a deflated block with fixed tree, as specified in the deflate specification*/ +static void getTreeInflateFixed(HuffmanTree* tree_ll, HuffmanTree* tree_d) +{ + /*TODO: check for out of memory errors*/ + generateFixedLitLenTree(tree_ll); + generateFixedDistanceTree(tree_d); +} + +/*get the tree of a deflated block with dynamic tree, the tree itself is also Huffman compressed with a known tree*/ +static unsigned getTreeInflateDynamic(HuffmanTree* tree_ll, HuffmanTree* tree_d, + const unsigned char* in, size_t* bp, size_t inlength) +{ + /*make sure that length values that aren't filled in will be 0, or a wrong tree will be generated*/ + unsigned error = 0; + unsigned n, HLIT, HDIST, HCLEN, i; + size_t inbitlength = inlength * 8; + + /*see comments in deflateDynamic for explanation of the context and these variables, it is analogous*/ + unsigned* bitlen_ll = 0; /*lit,len code lengths*/ + unsigned* bitlen_d = 0; /*dist code lengths*/ + /*code length code lengths ("clcl"), the bit lengths of the huffman tree used to compress bitlen_ll and bitlen_d*/ + unsigned* bitlen_cl = 0; + HuffmanTree tree_cl; /*the code tree for code length codes (the huffman tree for compressed huffman trees)*/ + + if((*bp) + 14 > (inlength << 3)) return 49; /*error: the bit pointer is or will go past the memory*/ + + /*number of literal/length codes + 257. Unlike the spec, the value 257 is added to it here already*/ + HLIT = readBitsFromStream(bp, in, 5) + 257; + /*number of distance codes. Unlike the spec, the value 1 is added to it here already*/ + HDIST = readBitsFromStream(bp, in, 5) + 1; + /*number of code length codes. 
Unlike the spec, the value 4 is added to it here already*/ + HCLEN = readBitsFromStream(bp, in, 4) + 4; + + if((*bp) + HCLEN * 3 > (inlength << 3)) return 50; /*error: the bit pointer is or will go past the memory*/ + + HuffmanTree_init(&tree_cl); + + while(!error) + { + /*read the code length codes out of 3 * (amount of code length codes) bits*/ + + bitlen_cl = (unsigned*)lodepng_malloc(NUM_CODE_LENGTH_CODES * sizeof(unsigned)); + if(!bitlen_cl) ERROR_BREAK(83 /*alloc fail*/); + + for(i = 0; i != NUM_CODE_LENGTH_CODES; ++i) + { + if(i < HCLEN) bitlen_cl[CLCL_ORDER[i]] = readBitsFromStream(bp, in, 3); + else bitlen_cl[CLCL_ORDER[i]] = 0; /*if not, it must stay 0*/ + } + + error = HuffmanTree_makeFromLengths(&tree_cl, bitlen_cl, NUM_CODE_LENGTH_CODES, 7); + if(error) break; + + /*now we can use this tree to read the lengths for the tree that this function will return*/ + bitlen_ll = (unsigned*)lodepng_malloc(NUM_DEFLATE_CODE_SYMBOLS * sizeof(unsigned)); + bitlen_d = (unsigned*)lodepng_malloc(NUM_DISTANCE_SYMBOLS * sizeof(unsigned)); + if(!bitlen_ll || !bitlen_d) ERROR_BREAK(83 /*alloc fail*/); + for(i = 0; i != NUM_DEFLATE_CODE_SYMBOLS; ++i) bitlen_ll[i] = 0; + for(i = 0; i != NUM_DISTANCE_SYMBOLS; ++i) bitlen_d[i] = 0; + + /*i is the current symbol we're reading in the part that contains the code lengths of lit/len and dist codes*/ + i = 0; + while(i < HLIT + HDIST) + { + unsigned code = huffmanDecodeSymbol(in, bp, &tree_cl, inbitlength); + if(code <= 15) /*a length code*/ + { + if(i < HLIT) bitlen_ll[i] = code; + else bitlen_d[i - HLIT] = code; + ++i; + } + else if(code == 16) /*repeat previous*/ + { + unsigned replength = 3; /*read in the 2 bits that indicate repeat length (3-6)*/ + unsigned value; /*set value to the previous code*/ + + if(i == 0) ERROR_BREAK(54); /*can't repeat previous if i is 0*/ + + if((*bp + 2) > inbitlength) ERROR_BREAK(50); /*error, bit pointer jumps past memory*/ + replength += readBitsFromStream(bp, in, 2); + + if(i < HLIT + 1) value = bitlen_ll[i - 1]; + else value = bitlen_d[i - HLIT - 1]; + /*repeat this value in the next lengths*/ + for(n = 0; n < replength; ++n) + { + if(i >= HLIT + HDIST) ERROR_BREAK(13); /*error: i is larger than the amount of codes*/ + if(i < HLIT) bitlen_ll[i] = value; + else bitlen_d[i - HLIT] = value; + ++i; + } + } + else if(code == 17) /*repeat "0" 3-10 times*/ + { + unsigned replength = 3; /*read in the bits that indicate repeat length*/ + if((*bp + 3) > inbitlength) ERROR_BREAK(50); /*error, bit pointer jumps past memory*/ + replength += readBitsFromStream(bp, in, 3); + + /*repeat this value in the next lengths*/ + for(n = 0; n < replength; ++n) + { + if(i >= HLIT + HDIST) ERROR_BREAK(14); /*error: i is larger than the amount of codes*/ + + if(i < HLIT) bitlen_ll[i] = 0; + else bitlen_d[i - HLIT] = 0; + ++i; + } + } + else if(code == 18) /*repeat "0" 11-138 times*/ + { + unsigned replength = 11; /*read in the bits that indicate repeat length*/ + if((*bp + 7) > inbitlength) ERROR_BREAK(50); /*error, bit pointer jumps past memory*/ + replength += readBitsFromStream(bp, in, 7); + + /*repeat this value in the next lengths*/ + for(n = 0; n < replength; ++n) + { + if(i >= HLIT + HDIST) ERROR_BREAK(15); /*error: i is larger than the amount of codes*/ + + if(i < HLIT) bitlen_ll[i] = 0; + else bitlen_d[i - HLIT] = 0; + ++i; + } + } + else /*if(code == (unsigned)(-1))*/ /*huffmanDecodeSymbol returns (unsigned)(-1) in case of error*/ + { + if(code == (unsigned)(-1)) + { + /*return error code 10 or 11 depending on the situation that happened 
in huffmanDecodeSymbol + (10=no endcode, 11=wrong jump outside of tree)*/ + error = (*bp) > inbitlength ? 10 : 11; + } + else error = 16; /*unexisting code, this can never happen*/ + break; + } + } + if(error) break; + + if(bitlen_ll[256] == 0) ERROR_BREAK(64); /*the length of the end code 256 must be larger than 0*/ + + /*now we've finally got HLIT and HDIST, so generate the code trees, and the function is done*/ + error = HuffmanTree_makeFromLengths(tree_ll, bitlen_ll, NUM_DEFLATE_CODE_SYMBOLS, 15); + if(error) break; + error = HuffmanTree_makeFromLengths(tree_d, bitlen_d, NUM_DISTANCE_SYMBOLS, 15); + + break; /*end of error-while*/ + } + + lodepng_free(bitlen_cl); + lodepng_free(bitlen_ll); + lodepng_free(bitlen_d); + HuffmanTree_cleanup(&tree_cl); + + return error; +} + +/*inflate a block with dynamic of fixed Huffman tree*/ +static unsigned inflateHuffmanBlock(ucvector* out, const unsigned char* in, size_t* bp, + size_t* pos, size_t inlength, unsigned btype) +{ + unsigned error = 0; + HuffmanTree tree_ll; /*the huffman tree for literal and length codes*/ + HuffmanTree tree_d; /*the huffman tree for distance codes*/ + size_t inbitlength = inlength * 8; + + HuffmanTree_init(&tree_ll); + HuffmanTree_init(&tree_d); + + if(btype == 1) getTreeInflateFixed(&tree_ll, &tree_d); + else if(btype == 2) error = getTreeInflateDynamic(&tree_ll, &tree_d, in, bp, inlength); + + while(!error) /*decode all symbols until end reached, breaks at end code*/ + { + /*code_ll is literal, length or end code*/ + unsigned code_ll = huffmanDecodeSymbol(in, bp, &tree_ll, inbitlength); + if(code_ll <= 255) /*literal symbol*/ + { + /*ucvector_push_back would do the same, but for some reason the two lines below run 10% faster*/ + if(!ucvector_resize(out, (*pos) + 1)) ERROR_BREAK(83 /*alloc fail*/); + out->data[*pos] = (unsigned char)code_ll; + ++(*pos); + } + else if(code_ll >= FIRST_LENGTH_CODE_INDEX && code_ll <= LAST_LENGTH_CODE_INDEX) /*length code*/ + { + unsigned code_d, distance; + unsigned numextrabits_l, numextrabits_d; /*extra bits for length and distance*/ + size_t start, forward, backward, length; + + /*part 1: get length base*/ + length = LENGTHBASE[code_ll - FIRST_LENGTH_CODE_INDEX]; + + /*part 2: get extra bits and add the value of that to length*/ + numextrabits_l = LENGTHEXTRA[code_ll - FIRST_LENGTH_CODE_INDEX]; + if((*bp + numextrabits_l) > inbitlength) ERROR_BREAK(51); /*error, bit pointer will jump past memory*/ + length += readBitsFromStream(bp, in, numextrabits_l); + + /*part 3: get distance code*/ + code_d = huffmanDecodeSymbol(in, bp, &tree_d, inbitlength); + if(code_d > 29) + { + if(code_d == (unsigned)(-1)) /*huffmanDecodeSymbol returns (unsigned)(-1) in case of error*/ + { + /*return error code 10 or 11 depending on the situation that happened in huffmanDecodeSymbol + (10=no endcode, 11=wrong jump outside of tree)*/ + error = (*bp) > inlength * 8 ? 
10 : 11; + } + else error = 18; /*error: invalid distance code (30-31 are never used)*/ + break; + } + distance = DISTANCEBASE[code_d]; + + /*part 4: get extra bits from distance*/ + numextrabits_d = DISTANCEEXTRA[code_d]; + if((*bp + numextrabits_d) > inbitlength) ERROR_BREAK(51); /*error, bit pointer will jump past memory*/ + distance += readBitsFromStream(bp, in, numextrabits_d); + + /*part 5: fill in all the out[n] values based on the length and dist*/ + start = (*pos); + if(distance > start) ERROR_BREAK(52); /*too long backward distance*/ + backward = start - distance; + + if(!ucvector_resize(out, (*pos) + length)) ERROR_BREAK(83 /*alloc fail*/); + if (distance < length) { + for(forward = 0; forward < length; ++forward) + { + out->data[(*pos)++] = out->data[backward++]; + } + } else { + memcpy(out->data + *pos, out->data + backward, length); + *pos += length; + } + } + else if(code_ll == 256) + { + break; /*end code, break the loop*/ + } + else /*if(code == (unsigned)(-1))*/ /*huffmanDecodeSymbol returns (unsigned)(-1) in case of error*/ + { + /*return error code 10 or 11 depending on the situation that happened in huffmanDecodeSymbol + (10=no endcode, 11=wrong jump outside of tree)*/ + error = ((*bp) > inlength * 8) ? 10 : 11; + break; + } + } + + HuffmanTree_cleanup(&tree_ll); + HuffmanTree_cleanup(&tree_d); + + return error; +} + +static unsigned inflateNoCompression(ucvector* out, const unsigned char* in, size_t* bp, size_t* pos, size_t inlength) +{ + size_t p; + unsigned LEN, NLEN, n, error = 0; + + /*go to first boundary of byte*/ + while(((*bp) & 0x7) != 0) ++(*bp); + p = (*bp) / 8; /*byte position*/ + + /*read LEN (2 bytes) and NLEN (2 bytes)*/ + if(p + 4 >= inlength) return 52; /*error, bit pointer will jump past memory*/ + LEN = in[p] + 256u * in[p + 1]; p += 2; + NLEN = in[p] + 256u * in[p + 1]; p += 2; + + /*check if 16-bit NLEN is really the one's complement of LEN*/ + if(LEN + NLEN != 65535) return 21; /*error: NLEN is not one's complement of LEN*/ + + if(!ucvector_resize(out, (*pos) + LEN)) return 83; /*alloc fail*/ + + /*read the literal data: LEN bytes are now stored in the out buffer*/ + if(p + LEN > inlength) return 23; /*error: reading outside of in buffer*/ + for(n = 0; n < LEN; ++n) out->data[(*pos)++] = in[p++]; + + (*bp) = p * 8; + + return error; +} + +static unsigned lodepng_inflatev(ucvector* out, + const unsigned char* in, size_t insize, + const LodePNGDecompressSettings* settings) +{ + /*bit pointer in the "in" data, current byte is bp >> 3, current bit is bp & 0x7 (from lsb to msb of the byte)*/ + size_t bp = 0; + unsigned BFINAL = 0; + size_t pos = 0; /*byte position in the out buffer*/ + unsigned error = 0; + + (void)settings; + + while(!BFINAL) + { + unsigned BTYPE; + if(bp + 2 >= insize * 8) return 52; /*error, bit pointer will jump past memory*/ + BFINAL = readBitFromStream(&bp, in); + BTYPE = 1u * readBitFromStream(&bp, in); + BTYPE += 2u * readBitFromStream(&bp, in); + + if(BTYPE == 3) return 20; /*error: invalid BTYPE*/ + else if(BTYPE == 0) error = inflateNoCompression(out, in, &bp, &pos, insize); /*no compression*/ + else error = inflateHuffmanBlock(out, in, &bp, &pos, insize, BTYPE); /*compression, BTYPE 01 or 10*/ + + if(error) return error; + } + + return error; +} + +unsigned lodepng_inflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGDecompressSettings* settings) +{ + unsigned error; + ucvector v; + ucvector_init_buffer(&v, *out, *outsize); + error = lodepng_inflatev(&v, in, insize, 
settings); + *out = v.data; + *outsize = v.size; + return error; +} + +static unsigned inflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGDecompressSettings* settings) +{ + if(settings->custom_inflate) + { + return settings->custom_inflate(out, outsize, in, insize, settings); + } + else + { + return lodepng_inflate(out, outsize, in, insize, settings); + } +} + +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Deflator (Compressor) / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +static const size_t MAX_SUPPORTED_DEFLATE_LENGTH = 258; + +/*bitlen is the size in bits of the code*/ +static void addHuffmanSymbol(size_t* bp, ucvector* compressed, unsigned code, unsigned bitlen) +{ + addBitsToStreamReversed(bp, compressed, code, bitlen); +} + +/*search the index in the array, that has the largest value smaller than or equal to the given value, +given array must be sorted (if no value is smaller, it returns the size of the given array)*/ +static size_t searchCodeIndex(const unsigned* array, size_t array_size, size_t value) +{ + /*binary search (only small gain over linear). TODO: use CPU log2 instruction for getting symbols instead*/ + size_t left = 1; + size_t right = array_size - 1; + + while(left <= right) { + size_t mid = (left + right) >> 1; + if (array[mid] >= value) right = mid - 1; + else left = mid + 1; + } + if(left >= array_size || array[left] > value) left--; + return left; +} + +static void addLengthDistance(uivector* values, size_t length, size_t distance) +{ + /*values in encoded vector are those used by deflate: + 0-255: literal bytes + 256: end + 257-285: length/distance pair (length code, followed by extra length bits, distance code, extra distance bits) + 286-287: invalid*/ + + unsigned length_code = (unsigned)searchCodeIndex(LENGTHBASE, 29, length); + unsigned extra_length = (unsigned)(length - LENGTHBASE[length_code]); + unsigned dist_code = (unsigned)searchCodeIndex(DISTANCEBASE, 30, distance); + unsigned extra_distance = (unsigned)(distance - DISTANCEBASE[dist_code]); + + uivector_push_back(values, length_code + FIRST_LENGTH_CODE_INDEX); + uivector_push_back(values, extra_length); + uivector_push_back(values, dist_code); + uivector_push_back(values, extra_distance); +} + +/*3 bytes of data get encoded into two bytes. The hash cannot use more than 3 +bytes as input because 3 is the minimum match length for deflate*/ +static const unsigned HASH_NUM_VALUES = 65536; +static const unsigned HASH_BIT_MASK = 65535; /*HASH_NUM_VALUES - 1, but C90 does not like that as initializer*/ + +typedef struct Hash +{ + int* head; /*hash value to head circular pos - can be outdated if went around window*/ + /*circular pos to prev circular pos*/ + unsigned short* chain; + int* val; /*circular pos to hash value*/ + + /*TODO: do this not only for zeros but for any repeated byte. 
However for PNG + it's always going to be the zeros that dominate, so not important for PNG*/ + int* headz; /*similar to head, but for chainz*/ + unsigned short* chainz; /*those with same amount of zeros*/ + unsigned short* zeros; /*length of zeros streak, used as a second hash chain*/ +} Hash; + +static unsigned hash_init(Hash* hash, unsigned windowsize) +{ + unsigned i; + hash->head = (int*)lodepng_malloc(sizeof(int) * HASH_NUM_VALUES); + hash->val = (int*)lodepng_malloc(sizeof(int) * windowsize); + hash->chain = (unsigned short*)lodepng_malloc(sizeof(unsigned short) * windowsize); + + hash->zeros = (unsigned short*)lodepng_malloc(sizeof(unsigned short) * windowsize); + hash->headz = (int*)lodepng_malloc(sizeof(int) * (MAX_SUPPORTED_DEFLATE_LENGTH + 1)); + hash->chainz = (unsigned short*)lodepng_malloc(sizeof(unsigned short) * windowsize); + + if(!hash->head || !hash->chain || !hash->val || !hash->headz|| !hash->chainz || !hash->zeros) + { + return 83; /*alloc fail*/ + } + + /*initialize hash table*/ + for(i = 0; i != HASH_NUM_VALUES; ++i) hash->head[i] = -1; + for(i = 0; i != windowsize; ++i) hash->val[i] = -1; + for(i = 0; i != windowsize; ++i) hash->chain[i] = i; /*same value as index indicates uninitialized*/ + + for(i = 0; i <= MAX_SUPPORTED_DEFLATE_LENGTH; ++i) hash->headz[i] = -1; + for(i = 0; i != windowsize; ++i) hash->chainz[i] = i; /*same value as index indicates uninitialized*/ + + return 0; +} + +static void hash_cleanup(Hash* hash) +{ + lodepng_free(hash->head); + lodepng_free(hash->val); + lodepng_free(hash->chain); + + lodepng_free(hash->zeros); + lodepng_free(hash->headz); + lodepng_free(hash->chainz); +} + + + +static unsigned getHash(const unsigned char* data, size_t size, size_t pos) +{ + unsigned result = 0; + if(pos + 2 < size) + { + /*A simple shift and xor hash is used. Since the data of PNGs is dominated + by zeroes due to the filters, a better hash does not have a significant + effect on speed in traversing the chain, and causes more time spend on + calculating the hash.*/ + result ^= (unsigned)(data[pos + 0] << 0u); + result ^= (unsigned)(data[pos + 1] << 4u); + result ^= (unsigned)(data[pos + 2] << 8u); + } else { + size_t amount, i; + if(pos >= size) return 0; + amount = size - pos; + for(i = 0; i != amount; ++i) result ^= (unsigned)(data[pos + i] << (i * 8u)); + } + return result & HASH_BIT_MASK; +} + +static unsigned countZeros(const unsigned char* data, size_t size, size_t pos) +{ + const unsigned char* start = data + pos; + const unsigned char* end = start + MAX_SUPPORTED_DEFLATE_LENGTH; + if(end > data + size) end = data + size; + data = start; + while(data != end && *data == 0) ++data; + /*subtracting two addresses returned as 32-bit number (max value is MAX_SUPPORTED_DEFLATE_LENGTH)*/ + return (unsigned)(data - start); +} + +/*wpos = pos & (windowsize - 1)*/ +static void updateHashChain(Hash* hash, size_t wpos, unsigned hashval, unsigned short numzeros) +{ + hash->val[wpos] = (int)hashval; + if(hash->head[hashval] != -1) hash->chain[wpos] = hash->head[hashval]; + hash->head[hashval] = wpos; + + hash->zeros[wpos] = numzeros; + if(hash->headz[numzeros] != -1) hash->chainz[wpos] = hash->headz[numzeros]; + hash->headz[numzeros] = wpos; +} + +/* +LZ77-encode the data. Return value is error code. The input are raw bytes, the output +is in the form of unsigned integers with codes representing for example literal bytes, or +length/distance pairs. +It uses a hash table technique to let it encode faster. 
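A hash of the next 3 bytes (see getHash) selects a chain of earlier positions in the window that share the same hash value, so only plausible match candidates need to be compared.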
When doing LZ77 encoding, a +sliding window (of windowsize) is used, and all past bytes in that window can be used as +the "dictionary". A brute force search through all possible distances would be slow, and +this hash technique is one out of several ways to speed this up. +*/ +static unsigned encodeLZ77(uivector* out, Hash* hash, + const unsigned char* in, size_t inpos, size_t insize, unsigned windowsize, + unsigned minmatch, unsigned nicematch, unsigned lazymatching) +{ + size_t pos; + unsigned i, error = 0; + /*for large window lengths, assume the user wants no compression loss. Otherwise, max hash chain length speedup.*/ + unsigned maxchainlength = windowsize >= 8192 ? windowsize : windowsize / 8; + unsigned maxlazymatch = windowsize >= 8192 ? MAX_SUPPORTED_DEFLATE_LENGTH : 64; + + unsigned usezeros = 1; /*not sure if setting it to false for windowsize < 8192 is better or worse*/ + unsigned numzeros = 0; + + unsigned offset; /*the offset represents the distance in LZ77 terminology*/ + unsigned length; + unsigned lazy = 0; + unsigned lazylength = 0, lazyoffset = 0; + unsigned hashval; + unsigned current_offset, current_length; + unsigned prev_offset; + const unsigned char *lastptr, *foreptr, *backptr; + unsigned hashpos; + + if(windowsize == 0 || windowsize > 32768) return 60; /*error: windowsize smaller/larger than allowed*/ + if((windowsize & (windowsize - 1)) != 0) return 90; /*error: must be power of two*/ + + if(nicematch > MAX_SUPPORTED_DEFLATE_LENGTH) nicematch = MAX_SUPPORTED_DEFLATE_LENGTH; + + for(pos = inpos; pos < insize; ++pos) + { + size_t wpos = pos & (windowsize - 1); /*position for in 'circular' hash buffers*/ + unsigned chainlength = 0; + + hashval = getHash(in, insize, pos); + + if(usezeros && hashval == 0) + { + if(numzeros == 0) numzeros = countZeros(in, insize, pos); + else if(pos + numzeros > insize || in[pos + numzeros - 1] != 0) --numzeros; + } + else + { + numzeros = 0; + } + + updateHashChain(hash, wpos, hashval, numzeros); + + /*the length and offset found for the current position*/ + length = 0; + offset = 0; + + hashpos = hash->chain[wpos]; + + lastptr = &in[insize < pos + MAX_SUPPORTED_DEFLATE_LENGTH ? insize : pos + MAX_SUPPORTED_DEFLATE_LENGTH]; + + /*search for the longest string*/ + prev_offset = 0; + for(;;) + { + if(chainlength++ >= maxchainlength) break; + current_offset = hashpos <= wpos ? wpos - hashpos : wpos - hashpos + windowsize; + + if(current_offset < prev_offset) break; /*stop when went completely around the circular buffer*/ + prev_offset = current_offset; + if(current_offset > 0) + { + /*test the next characters*/ + foreptr = &in[pos]; + backptr = &in[pos - current_offset]; + + /*common case in PNGs is lots of zeros. Quickly skip over them as a speedup*/ + if(numzeros >= 3) + { + unsigned skip = hash->zeros[hashpos]; + if(skip > numzeros) skip = numzeros; + backptr += skip; + foreptr += skip; + } + + while(foreptr != lastptr && *backptr == *foreptr) /*maximum supported length by deflate is max length*/ + { + ++backptr; + ++foreptr; + } + current_length = (unsigned)(foreptr - &in[pos]); + + if(current_length > length) + { + length = current_length; /*the longest length*/ + offset = current_offset; /*the offset that is related to this longest length*/ + /*jump out once a length of max length is found (speed gain). 
This also jumps + out if length is MAX_SUPPORTED_DEFLATE_LENGTH*/ + if(current_length >= nicematch) break; + } + } + + if(hashpos == hash->chain[hashpos]) break; + + if(numzeros >= 3 && length > numzeros) + { + hashpos = hash->chainz[hashpos]; + if(hash->zeros[hashpos] != numzeros) break; + } + else + { + hashpos = hash->chain[hashpos]; + /*outdated hash value, happens if particular value was not encountered in whole last window*/ + if(hash->val[hashpos] != (int)hashval) break; + } + } + + if(lazymatching) + { + if(!lazy && length >= 3 && length <= maxlazymatch && length < MAX_SUPPORTED_DEFLATE_LENGTH) + { + lazy = 1; + lazylength = length; + lazyoffset = offset; + continue; /*try the next byte*/ + } + if(lazy) + { + lazy = 0; + if(pos == 0) ERROR_BREAK(81); + if(length > lazylength + 1) + { + /*push the previous character as literal*/ + if(!uivector_push_back(out, in[pos - 1])) ERROR_BREAK(83 /*alloc fail*/); + } + else + { + length = lazylength; + offset = lazyoffset; + hash->head[hashval] = -1; /*the same hashchain update will be done, this ensures no wrong alteration*/ + hash->headz[numzeros] = -1; /*idem*/ + --pos; + } + } + } + if(length >= 3 && offset > windowsize) ERROR_BREAK(86 /*too big (or overflown negative) offset*/); + + /*encode it as length/distance pair or literal value*/ + if(length < 3) /*only lengths of 3 or higher are supported as length/distance pair*/ + { + if(!uivector_push_back(out, in[pos])) ERROR_BREAK(83 /*alloc fail*/); + } + else if(length < minmatch || (length == 3 && offset > 4096)) + { + /*compensate for the fact that longer offsets have more extra bits, a + length of only 3 may be not worth it then*/ + if(!uivector_push_back(out, in[pos])) ERROR_BREAK(83 /*alloc fail*/); + } + else + { + addLengthDistance(out, length, offset); + for(i = 1; i < length; ++i) + { + ++pos; + wpos = pos & (windowsize - 1); + hashval = getHash(in, insize, pos); + if(usezeros && hashval == 0) + { + if(numzeros == 0) numzeros = countZeros(in, insize, pos); + else if(pos + numzeros > insize || in[pos + numzeros - 1] != 0) --numzeros; + } + else + { + numzeros = 0; + } + updateHashChain(hash, wpos, hashval, numzeros); + } + } + } /*end of the loop through each character of input*/ + + return error; +} + +/* /////////////////////////////////////////////////////////////////////////// */ + +static unsigned deflateNoCompression(ucvector* out, const unsigned char* data, size_t datasize) +{ + /*non compressed deflate block data: 1 bit BFINAL,2 bits BTYPE,(5 bits): it jumps to start of next byte, + 2 bytes LEN, 2 bytes NLEN, LEN bytes literal DATA*/ + + size_t i, j, numdeflateblocks = (datasize + 65534) / 65535; + unsigned datapos = 0; + for(i = 0; i != numdeflateblocks; ++i) + { + unsigned BFINAL, BTYPE, LEN, NLEN; + unsigned char firstbyte; + + BFINAL = (i == numdeflateblocks - 1); + BTYPE = 0; + + firstbyte = (unsigned char)(BFINAL + ((BTYPE & 1) << 1) + ((BTYPE & 2) << 1)); + ucvector_push_back(out, firstbyte); + + LEN = 65535; + if(datasize - datapos < 65535) LEN = (unsigned)datasize - datapos; + NLEN = 65535 - LEN; + + ucvector_push_back(out, (unsigned char)(LEN & 255)); + ucvector_push_back(out, (unsigned char)(LEN >> 8)); + ucvector_push_back(out, (unsigned char)(NLEN & 255)); + ucvector_push_back(out, (unsigned char)(NLEN >> 8)); + + /*Decompressed data*/ + for(j = 0; j < 65535 && datapos < datasize; ++j) + { + ucvector_push_back(out, data[datapos++]); + } + } + + return 0; +} + +/* +write the lz77-encoded data, which has lit, len and dist codes, to compressed stream using huffman 
trees. +tree_ll: the tree for lit and len codes. +tree_d: the tree for distance codes. +*/ +static void writeLZ77data(size_t* bp, ucvector* out, const uivector* lz77_encoded, + const HuffmanTree* tree_ll, const HuffmanTree* tree_d) +{ + size_t i = 0; + for(i = 0; i != lz77_encoded->size; ++i) + { + unsigned val = lz77_encoded->data[i]; + addHuffmanSymbol(bp, out, HuffmanTree_getCode(tree_ll, val), HuffmanTree_getLength(tree_ll, val)); + if(val > 256) /*for a length code, 3 more things have to be added*/ + { + unsigned length_index = val - FIRST_LENGTH_CODE_INDEX; + unsigned n_length_extra_bits = LENGTHEXTRA[length_index]; + unsigned length_extra_bits = lz77_encoded->data[++i]; + + unsigned distance_code = lz77_encoded->data[++i]; + + unsigned distance_index = distance_code; + unsigned n_distance_extra_bits = DISTANCEEXTRA[distance_index]; + unsigned distance_extra_bits = lz77_encoded->data[++i]; + + addBitsToStream(bp, out, length_extra_bits, n_length_extra_bits); + addHuffmanSymbol(bp, out, HuffmanTree_getCode(tree_d, distance_code), + HuffmanTree_getLength(tree_d, distance_code)); + addBitsToStream(bp, out, distance_extra_bits, n_distance_extra_bits); + } + } +} + +/*Deflate for a block of type "dynamic", that is, with freely, optimally, created huffman trees*/ +static unsigned deflateDynamic(ucvector* out, size_t* bp, Hash* hash, + const unsigned char* data, size_t datapos, size_t dataend, + const LodePNGCompressSettings* settings, unsigned final) +{ + unsigned error = 0; + + /* + A block is compressed as follows: The PNG data is lz77 encoded, resulting in + literal bytes and length/distance pairs. This is then huffman compressed with + two huffman trees. One huffman tree is used for the lit and len values ("ll"), + another huffman tree is used for the dist values ("d"). These two trees are + stored using their code lengths, and to compress even more these code lengths + are also run-length encoded and huffman compressed. This gives a huffman tree + of code lengths "cl". The code lenghts used to describe this third tree are + the code length code lengths ("clcl"). + */ + + /*The lz77 encoded data, represented with integers since there will also be length and distance codes in it*/ + uivector lz77_encoded; + HuffmanTree tree_ll; /*tree for lit,len values*/ + HuffmanTree tree_d; /*tree for distance codes*/ + HuffmanTree tree_cl; /*tree for encoding the code lengths representing tree_ll and tree_d*/ + uivector frequencies_ll; /*frequency of lit,len codes*/ + uivector frequencies_d; /*frequency of dist codes*/ + uivector frequencies_cl; /*frequency of code length codes*/ + uivector bitlen_lld; /*lit,len,dist code lenghts (int bits), literally (without repeat codes).*/ + uivector bitlen_lld_e; /*bitlen_lld encoded with repeat codes (this is a rudemtary run length compression)*/ + /*bitlen_cl is the code length code lengths ("clcl"). The bit lengths of codes to represent tree_cl + (these are written as is in the file, it would be crazy to compress these using yet another huffman + tree that needs to be represented by yet another set of code lengths)*/ + uivector bitlen_cl; + size_t datasize = dataend - datapos; + + /* + Due to the huffman compression of huffman tree representations ("two levels"), there are some anologies: + bitlen_lld is to tree_cl what data is to tree_ll and tree_d. + bitlen_lld_e is to bitlen_lld what lz77_encoded is to data. + bitlen_cl is to bitlen_lld_e what bitlen_lld is to lz77_encoded. 
+ */ + + unsigned BFINAL = final; + size_t numcodes_ll, numcodes_d, i; + unsigned HLIT, HDIST, HCLEN; + + uivector_init(&lz77_encoded); + HuffmanTree_init(&tree_ll); + HuffmanTree_init(&tree_d); + HuffmanTree_init(&tree_cl); + uivector_init(&frequencies_ll); + uivector_init(&frequencies_d); + uivector_init(&frequencies_cl); + uivector_init(&bitlen_lld); + uivector_init(&bitlen_lld_e); + uivector_init(&bitlen_cl); + + /*This while loop never loops due to a break at the end, it is here to + allow breaking out of it to the cleanup phase on error conditions.*/ + while(!error) + { + if(settings->use_lz77) + { + error = encodeLZ77(&lz77_encoded, hash, data, datapos, dataend, settings->windowsize, + settings->minmatch, settings->nicematch, settings->lazymatching); + if(error) break; + } + else + { + if(!uivector_resize(&lz77_encoded, datasize)) ERROR_BREAK(83 /*alloc fail*/); + for(i = datapos; i < dataend; ++i) lz77_encoded.data[i - datapos] = data[i]; /*no LZ77, but still will be Huffman compressed*/ + } + + if(!uivector_resizev(&frequencies_ll, 286, 0)) ERROR_BREAK(83 /*alloc fail*/); + if(!uivector_resizev(&frequencies_d, 30, 0)) ERROR_BREAK(83 /*alloc fail*/); + + /*Count the frequencies of lit, len and dist codes*/ + for(i = 0; i != lz77_encoded.size; ++i) + { + unsigned symbol = lz77_encoded.data[i]; + ++frequencies_ll.data[symbol]; + if(symbol > 256) + { + unsigned dist = lz77_encoded.data[i + 2]; + ++frequencies_d.data[dist]; + i += 3; + } + } + frequencies_ll.data[256] = 1; /*there will be exactly 1 end code, at the end of the block*/ + + /*Make both huffman trees, one for the lit and len codes, one for the dist codes*/ + error = HuffmanTree_makeFromFrequencies(&tree_ll, frequencies_ll.data, 257, frequencies_ll.size, 15); + if(error) break; + /*2, not 1, is chosen for mincodes: some buggy PNG decoders require at least 2 symbols in the dist tree*/ + error = HuffmanTree_makeFromFrequencies(&tree_d, frequencies_d.data, 2, frequencies_d.size, 15); + if(error) break; + + numcodes_ll = tree_ll.numcodes; if(numcodes_ll > 286) numcodes_ll = 286; + numcodes_d = tree_d.numcodes; if(numcodes_d > 30) numcodes_d = 30; + /*store the code lengths of both generated trees in bitlen_lld*/ + for(i = 0; i != numcodes_ll; ++i) uivector_push_back(&bitlen_lld, HuffmanTree_getLength(&tree_ll, (unsigned)i)); + for(i = 0; i != numcodes_d; ++i) uivector_push_back(&bitlen_lld, HuffmanTree_getLength(&tree_d, (unsigned)i)); + + /*run-length compress bitlen_ldd into bitlen_lld_e by using repeat codes 16 (copy length 3-6 times), + 17 (3-10 zeroes), 18 (11-138 zeroes)*/ + for(i = 0; i != (unsigned)bitlen_lld.size; ++i) + { + unsigned j = 0; /*amount of repititions*/ + while(i + j + 1 < (unsigned)bitlen_lld.size && bitlen_lld.data[i + j + 1] == bitlen_lld.data[i]) ++j; + + if(bitlen_lld.data[i] == 0 && j >= 2) /*repeat code for zeroes*/ + { + ++j; /*include the first zero*/ + if(j <= 10) /*repeat code 17 supports max 10 zeroes*/ + { + uivector_push_back(&bitlen_lld_e, 17); + uivector_push_back(&bitlen_lld_e, j - 3); + } + else /*repeat code 18 supports max 138 zeroes*/ + { + if(j > 138) j = 138; + uivector_push_back(&bitlen_lld_e, 18); + uivector_push_back(&bitlen_lld_e, j - 11); + } + i += (j - 1); + } + else if(j >= 3) /*repeat code for value other than zero*/ + { + size_t k; + unsigned num = j / 6, rest = j % 6; + uivector_push_back(&bitlen_lld_e, bitlen_lld.data[i]); + for(k = 0; k < num; ++k) + { + uivector_push_back(&bitlen_lld_e, 16); + uivector_push_back(&bitlen_lld_e, 6 - 3); + } + if(rest >= 3) + { + 
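        /*rest is 3..5 here: emit one final repeat code 16 covering the remaining repetitions*/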
uivector_push_back(&bitlen_lld_e, 16); + uivector_push_back(&bitlen_lld_e, rest - 3); + } + else j -= rest; + i += j; + } + else /*too short to benefit from repeat code*/ + { + uivector_push_back(&bitlen_lld_e, bitlen_lld.data[i]); + } + } + + /*generate tree_cl, the huffmantree of huffmantrees*/ + + if(!uivector_resizev(&frequencies_cl, NUM_CODE_LENGTH_CODES, 0)) ERROR_BREAK(83 /*alloc fail*/); + for(i = 0; i != bitlen_lld_e.size; ++i) + { + ++frequencies_cl.data[bitlen_lld_e.data[i]]; + /*after a repeat code come the bits that specify the number of repetitions, + those don't need to be in the frequencies_cl calculation*/ + if(bitlen_lld_e.data[i] >= 16) ++i; + } + + error = HuffmanTree_makeFromFrequencies(&tree_cl, frequencies_cl.data, + frequencies_cl.size, frequencies_cl.size, 7); + if(error) break; + + if(!uivector_resize(&bitlen_cl, tree_cl.numcodes)) ERROR_BREAK(83 /*alloc fail*/); + for(i = 0; i != tree_cl.numcodes; ++i) + { + /*lenghts of code length tree is in the order as specified by deflate*/ + bitlen_cl.data[i] = HuffmanTree_getLength(&tree_cl, CLCL_ORDER[i]); + } + while(bitlen_cl.data[bitlen_cl.size - 1] == 0 && bitlen_cl.size > 4) + { + /*remove zeros at the end, but minimum size must be 4*/ + if(!uivector_resize(&bitlen_cl, bitlen_cl.size - 1)) ERROR_BREAK(83 /*alloc fail*/); + } + if(error) break; + + /* + Write everything into the output + + After the BFINAL and BTYPE, the dynamic block consists out of the following: + - 5 bits HLIT, 5 bits HDIST, 4 bits HCLEN + - (HCLEN+4)*3 bits code lengths of code length alphabet + - HLIT + 257 code lenghts of lit/length alphabet (encoded using the code length + alphabet, + possible repetition codes 16, 17, 18) + - HDIST + 1 code lengths of distance alphabet (encoded using the code length + alphabet, + possible repetition codes 16, 17, 18) + - compressed data + - 256 (end code) + */ + + /*Write block type*/ + addBitToStream(bp, out, BFINAL); + addBitToStream(bp, out, 0); /*first bit of BTYPE "dynamic"*/ + addBitToStream(bp, out, 1); /*second bit of BTYPE "dynamic"*/ + + /*write the HLIT, HDIST and HCLEN values*/ + HLIT = (unsigned)(numcodes_ll - 257); + HDIST = (unsigned)(numcodes_d - 1); + HCLEN = (unsigned)bitlen_cl.size - 4; + /*trim zeroes for HCLEN. 
HLIT and HDIST were already trimmed at tree creation*/ + while(!bitlen_cl.data[HCLEN + 4 - 1] && HCLEN > 0) --HCLEN; + addBitsToStream(bp, out, HLIT, 5); + addBitsToStream(bp, out, HDIST, 5); + addBitsToStream(bp, out, HCLEN, 4); + + /*write the code lenghts of the code length alphabet*/ + for(i = 0; i != HCLEN + 4; ++i) addBitsToStream(bp, out, bitlen_cl.data[i], 3); + + /*write the lenghts of the lit/len AND the dist alphabet*/ + for(i = 0; i != bitlen_lld_e.size; ++i) + { + addHuffmanSymbol(bp, out, HuffmanTree_getCode(&tree_cl, bitlen_lld_e.data[i]), + HuffmanTree_getLength(&tree_cl, bitlen_lld_e.data[i])); + /*extra bits of repeat codes*/ + if(bitlen_lld_e.data[i] == 16) addBitsToStream(bp, out, bitlen_lld_e.data[++i], 2); + else if(bitlen_lld_e.data[i] == 17) addBitsToStream(bp, out, bitlen_lld_e.data[++i], 3); + else if(bitlen_lld_e.data[i] == 18) addBitsToStream(bp, out, bitlen_lld_e.data[++i], 7); + } + + /*write the compressed data symbols*/ + writeLZ77data(bp, out, &lz77_encoded, &tree_ll, &tree_d); + /*error: the length of the end code 256 must be larger than 0*/ + if(HuffmanTree_getLength(&tree_ll, 256) == 0) ERROR_BREAK(64); + + /*write the end code*/ + addHuffmanSymbol(bp, out, HuffmanTree_getCode(&tree_ll, 256), HuffmanTree_getLength(&tree_ll, 256)); + + break; /*end of error-while*/ + } + + /*cleanup*/ + uivector_cleanup(&lz77_encoded); + HuffmanTree_cleanup(&tree_ll); + HuffmanTree_cleanup(&tree_d); + HuffmanTree_cleanup(&tree_cl); + uivector_cleanup(&frequencies_ll); + uivector_cleanup(&frequencies_d); + uivector_cleanup(&frequencies_cl); + uivector_cleanup(&bitlen_lld_e); + uivector_cleanup(&bitlen_lld); + uivector_cleanup(&bitlen_cl); + + return error; +} + +static unsigned deflateFixed(ucvector* out, size_t* bp, Hash* hash, + const unsigned char* data, + size_t datapos, size_t dataend, + const LodePNGCompressSettings* settings, unsigned final) +{ + HuffmanTree tree_ll; /*tree for literal values and length codes*/ + HuffmanTree tree_d; /*tree for distance codes*/ + + unsigned BFINAL = final; + unsigned error = 0; + size_t i; + + HuffmanTree_init(&tree_ll); + HuffmanTree_init(&tree_d); + + generateFixedLitLenTree(&tree_ll); + generateFixedDistanceTree(&tree_d); + + addBitToStream(bp, out, BFINAL); + addBitToStream(bp, out, 1); /*first bit of BTYPE*/ + addBitToStream(bp, out, 0); /*second bit of BTYPE*/ + + if(settings->use_lz77) /*LZ77 encoded*/ + { + uivector lz77_encoded; + uivector_init(&lz77_encoded); + error = encodeLZ77(&lz77_encoded, hash, data, datapos, dataend, settings->windowsize, + settings->minmatch, settings->nicematch, settings->lazymatching); + if(!error) writeLZ77data(bp, out, &lz77_encoded, &tree_ll, &tree_d); + uivector_cleanup(&lz77_encoded); + } + else /*no LZ77, but still will be Huffman compressed*/ + { + for(i = datapos; i < dataend; ++i) + { + addHuffmanSymbol(bp, out, HuffmanTree_getCode(&tree_ll, data[i]), HuffmanTree_getLength(&tree_ll, data[i])); + } + } + /*add END code*/ + if(!error) addHuffmanSymbol(bp, out, HuffmanTree_getCode(&tree_ll, 256), HuffmanTree_getLength(&tree_ll, 256)); + + /*cleanup*/ + HuffmanTree_cleanup(&tree_ll); + HuffmanTree_cleanup(&tree_d); + + return error; +} + +static unsigned lodepng_deflatev(ucvector* out, const unsigned char* in, size_t insize, + const LodePNGCompressSettings* settings) +{ + unsigned error = 0; + size_t i, blocksize, numdeflateblocks; + size_t bp = 0; /*the bit pointer*/ + Hash hash; + + if(settings->btype > 2) return 61; + else if(settings->btype == 0) return deflateNoCompression(out, in, 
insize); + else if(settings->btype == 1) blocksize = insize; + else /*if(settings->btype == 2)*/ + { + /*on PNGs, deflate blocks of 65-262k seem to give most dense encoding*/ + blocksize = insize / 8 + 8; + if(blocksize < 65536) blocksize = 65536; + if(blocksize > 262144) blocksize = 262144; + } + + numdeflateblocks = (insize + blocksize - 1) / blocksize; + if(numdeflateblocks == 0) numdeflateblocks = 1; + + error = hash_init(&hash, settings->windowsize); + if(error) return error; + + for(i = 0; i != numdeflateblocks && !error; ++i) + { + unsigned final = (i == numdeflateblocks - 1); + size_t start = i * blocksize; + size_t end = start + blocksize; + if(end > insize) end = insize; + + if(settings->btype == 1) error = deflateFixed(out, &bp, &hash, in, start, end, settings, final); + else if(settings->btype == 2) error = deflateDynamic(out, &bp, &hash, in, start, end, settings, final); + } + + hash_cleanup(&hash); + + return error; +} + +unsigned lodepng_deflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGCompressSettings* settings) +{ + unsigned error; + ucvector v; + ucvector_init_buffer(&v, *out, *outsize); + error = lodepng_deflatev(&v, in, insize, settings); + *out = v.data; + *outsize = v.size; + return error; +} + +static unsigned deflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGCompressSettings* settings) +{ + if(settings->custom_deflate) + { + return settings->custom_deflate(out, outsize, in, insize, settings); + } + else + { + return lodepng_deflate(out, outsize, in, insize, settings); + } +} + +#endif /*LODEPNG_COMPILE_DECODER*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Adler32 */ +/* ////////////////////////////////////////////////////////////////////////// */ + +static unsigned update_adler32(unsigned adler, const unsigned char* data, unsigned len) +{ + unsigned s1 = adler & 0xffff; + unsigned s2 = (adler >> 16) & 0xffff; + + while(len > 0) + { + /*at least 5550 sums can be done before the sums overflow, saving a lot of module divisions*/ + unsigned amount = len > 5550 ? 
5550 : len; + len -= amount; + while(amount > 0) + { + s1 += (*data++); + s2 += s1; + --amount; + } + s1 %= 65521; + s2 %= 65521; + } + + return (s2 << 16) | s1; +} + +/*Return the adler32 of the bytes data[0..len-1]*/ +static unsigned adler32(const unsigned char* data, unsigned len) +{ + return update_adler32(1L, data, len); +} + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Zlib / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_DECODER + +unsigned lodepng_zlib_decompress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGDecompressSettings* settings) +{ + unsigned error = 0; + unsigned CM, CINFO, FDICT; + + if(insize < 2) return 53; /*error, size of zlib data too small*/ + /*read information from zlib header*/ + if((in[0] * 256 + in[1]) % 31 != 0) + { + /*error: 256 * in[0] + in[1] must be a multiple of 31, the FCHECK value is supposed to be made that way*/ + return 24; + } + + CM = in[0] & 15; + CINFO = (in[0] >> 4) & 15; + /*FCHECK = in[1] & 31;*/ /*FCHECK is already tested above*/ + FDICT = (in[1] >> 5) & 1; + /*FLEVEL = (in[1] >> 6) & 3;*/ /*FLEVEL is not used here*/ + + if(CM != 8 || CINFO > 7) + { + /*error: only compression method 8: inflate with sliding window of 32k is supported by the PNG spec*/ + return 25; + } + if(FDICT != 0) + { + /*error: the specification of PNG says about the zlib stream: + "The additional flags shall not specify a preset dictionary."*/ + return 26; + } + + error = inflate(out, outsize, in + 2, insize - 2, settings); + if(error) return error; + + if(!settings->ignore_adler32) + { + unsigned ADLER32 = lodepng_read32bitInt(&in[insize - 4]); + unsigned checksum = adler32(*out, (unsigned)(*outsize)); + if(checksum != ADLER32) return 58; /*error, adler checksum not correct, data must be corrupted*/ + } + + return 0; /*no error*/ +} + +static unsigned zlib_decompress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGDecompressSettings* settings) +{ + if(settings->custom_zlib) + { + return settings->custom_zlib(out, outsize, in, insize, settings); + } + else + { + return lodepng_zlib_decompress(out, outsize, in, insize, settings); + } +} + +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER + +unsigned lodepng_zlib_compress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGCompressSettings* settings) +{ + /*initially, *out must be NULL and outsize 0, if you just give some random *out + that's pointing to a non allocated buffer, this'll crash*/ + ucvector outv; + size_t i; + unsigned error; + unsigned char* deflatedata = 0; + size_t deflatesize = 0; + + /*zlib data: 1 byte CMF (CM+CINFO), 1 byte FLG, deflate data, 4 byte ADLER32 checksum of the Decompressed data*/ + unsigned CMF = 120; /*0b01111000: CM 8, CINFO 7. 
With CINFO 7, any window size up to 32768 can be used.*/ + unsigned FLEVEL = 0; + unsigned FDICT = 0; + unsigned CMFFLG = 256 * CMF + FDICT * 32 + FLEVEL * 64; + unsigned FCHECK = 31 - CMFFLG % 31; + CMFFLG += FCHECK; + + /*ucvector-controlled version of the output buffer, for dynamic array*/ + ucvector_init_buffer(&outv, *out, *outsize); + + ucvector_push_back(&outv, (unsigned char)(CMFFLG >> 8)); + ucvector_push_back(&outv, (unsigned char)(CMFFLG & 255)); + + error = deflate(&deflatedata, &deflatesize, in, insize, settings); + + if(!error) + { + unsigned ADLER32 = adler32(in, (unsigned)insize); + for(i = 0; i != deflatesize; ++i) ucvector_push_back(&outv, deflatedata[i]); + lodepng_free(deflatedata); + lodepng_add32bitInt(&outv, ADLER32); + } + + *out = outv.data; + *outsize = outv.size; + + return error; +} + +/* compress using the default or custom zlib function */ +static unsigned zlib_compress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGCompressSettings* settings) +{ + if(settings->custom_zlib) + { + return settings->custom_zlib(out, outsize, in, insize, settings); + } + else + { + return lodepng_zlib_compress(out, outsize, in, insize, settings); + } +} + +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#else /*no LODEPNG_COMPILE_ZLIB*/ + +#ifdef LODEPNG_COMPILE_DECODER +static unsigned zlib_decompress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGDecompressSettings* settings) +{ + if(!settings->custom_zlib) return 87; /*no custom zlib function provided */ + return settings->custom_zlib(out, outsize, in, insize, settings); +} +#endif /*LODEPNG_COMPILE_DECODER*/ +#ifdef LODEPNG_COMPILE_ENCODER +static unsigned zlib_compress(unsigned char** out, size_t* outsize, const unsigned char* in, + size_t insize, const LodePNGCompressSettings* settings) +{ + if(!settings->custom_zlib) return 87; /*no custom zlib function provided */ + return settings->custom_zlib(out, outsize, in, insize, settings); +} +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#endif /*LODEPNG_COMPILE_ZLIB*/ + +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_ENCODER + +/*this is a good tradeoff between speed and compression ratio*/ +#define DEFAULT_WINDOWSIZE 2048 + +void lodepng_compress_settings_init(LodePNGCompressSettings* settings) +{ + /*compress with dynamic huffman tree (not in the mathematical sense, just not the predefined one)*/ + settings->btype = 2; + settings->use_lz77 = 1; + settings->windowsize = DEFAULT_WINDOWSIZE; + settings->minmatch = 3; + settings->nicematch = 128; + settings->lazymatching = 1; + + settings->custom_zlib = 0; + settings->custom_deflate = 0; + settings->custom_context = 0; +} + +const LodePNGCompressSettings lodepng_default_compress_settings = {2, 1, DEFAULT_WINDOWSIZE, 3, 128, 1, 0, 0, 0}; + + +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#ifdef LODEPNG_COMPILE_DECODER + +void lodepng_decompress_settings_init(LodePNGDecompressSettings* settings) +{ + settings->ignore_adler32 = 0; + + settings->custom_zlib = 0; + settings->custom_inflate = 0; + settings->custom_context = 0; +} + +const LodePNGDecompressSettings lodepng_default_decompress_settings = {0, 0, 0, 0}; + +#endif /*LODEPNG_COMPILE_DECODER*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* // End of Zlib related code. Begin of PNG related code. 
// */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_PNG + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / CRC32 / */ +/* ////////////////////////////////////////////////////////////////////////// */ + + +#ifndef LODEPNG_NO_COMPILE_CRC +/* CRC polynomial: 0xedb88320 */ +static unsigned lodepng_crc32_table[256] = { + 0u, 1996959894u, 3993919788u, 2567524794u, 124634137u, 1886057615u, 3915621685u, 2657392035u, + 249268274u, 2044508324u, 3772115230u, 2547177864u, 162941995u, 2125561021u, 3887607047u, 2428444049u, + 498536548u, 1789927666u, 4089016648u, 2227061214u, 450548861u, 1843258603u, 4107580753u, 2211677639u, + 325883990u, 1684777152u, 4251122042u, 2321926636u, 335633487u, 1661365465u, 4195302755u, 2366115317u, + 997073096u, 1281953886u, 3579855332u, 2724688242u, 1006888145u, 1258607687u, 3524101629u, 2768942443u, + 901097722u, 1119000684u, 3686517206u, 2898065728u, 853044451u, 1172266101u, 3705015759u, 2882616665u, + 651767980u, 1373503546u, 3369554304u, 3218104598u, 565507253u, 1454621731u, 3485111705u, 3099436303u, + 671266974u, 1594198024u, 3322730930u, 2970347812u, 795835527u, 1483230225u, 3244367275u, 3060149565u, + 1994146192u, 31158534u, 2563907772u, 4023717930u, 1907459465u, 112637215u, 2680153253u, 3904427059u, + 2013776290u, 251722036u, 2517215374u, 3775830040u, 2137656763u, 141376813u, 2439277719u, 3865271297u, + 1802195444u, 476864866u, 2238001368u, 4066508878u, 1812370925u, 453092731u, 2181625025u, 4111451223u, + 1706088902u, 314042704u, 2344532202u, 4240017532u, 1658658271u, 366619977u, 2362670323u, 4224994405u, + 1303535960u, 984961486u, 2747007092u, 3569037538u, 1256170817u, 1037604311u, 2765210733u, 3554079995u, + 1131014506u, 879679996u, 2909243462u, 3663771856u, 1141124467u, 855842277u, 2852801631u, 3708648649u, + 1342533948u, 654459306u, 3188396048u, 3373015174u, 1466479909u, 544179635u, 3110523913u, 3462522015u, + 1591671054u, 702138776u, 2966460450u, 3352799412u, 1504918807u, 783551873u, 3082640443u, 3233442989u, + 3988292384u, 2596254646u, 62317068u, 1957810842u, 3939845945u, 2647816111u, 81470997u, 1943803523u, + 3814918930u, 2489596804u, 225274430u, 2053790376u, 3826175755u, 2466906013u, 167816743u, 2097651377u, + 4027552580u, 2265490386u, 503444072u, 1762050814u, 4150417245u, 2154129355u, 426522225u, 1852507879u, + 4275313526u, 2312317920u, 282753626u, 1742555852u, 4189708143u, 2394877945u, 397917763u, 1622183637u, + 3604390888u, 2714866558u, 953729732u, 1340076626u, 3518719985u, 2797360999u, 1068828381u, 1219638859u, + 3624741850u, 2936675148u, 906185462u, 1090812512u, 3747672003u, 2825379669u, 829329135u, 1181335161u, + 3412177804u, 3160834842u, 628085408u, 1382605366u, 3423369109u, 3138078467u, 570562233u, 1426400815u, + 3317316542u, 2998733608u, 733239954u, 1555261956u, 3268935591u, 3050360625u, 752459403u, 1541320221u, + 2607071920u, 3965973030u, 1969922972u, 40735498u, 2617837225u, 3943577151u, 1913087877u, 83908371u, + 2512341634u, 3803740692u, 2075208622u, 213261112u, 2463272603u, 3855990285u, 2094854071u, 198958881u, + 2262029012u, 4057260610u, 1759359992u, 534414190u, 2176718541u, 4139329115u, 1873836001u, 414664567u, + 2282248934u, 4279200368u, 1711684554u, 285281116u, 2405801727u, 4167216745u, 1634467795u, 376229701u, + 2685067896u, 3608007406u, 1308918612u, 956543938u, 2808555105u, 3495958263u, 1231636301u, 1047427035u, + 2932959818u, 3654703836u, 1088359270u, 
936918000u, 2847714899u, 3736837829u, 1202900863u, 817233897u, + 3183342108u, 3401237130u, 1404277552u, 615818150u, 3134207493u, 3453421203u, 1423857449u, 601450431u, + 3009837614u, 3294710456u, 1567103746u, 711928724u, 3020668471u, 3272380065u, 1510334235u, 755167117u +}; + +/*Return the CRC of the bytes buf[0..len-1].*/ +unsigned lodepng_crc32(const unsigned char* data, size_t length) +{ + unsigned r = 0xffffffffu; + size_t i; + for(i = 0; i < length; ++i) + { + r = lodepng_crc32_table[(r ^ data[i]) & 0xff] ^ (r >> 8); + } + return r ^ 0xffffffffu; +} +#else /* !LODEPNG_NO_COMPILE_CRC */ +unsigned lodepng_crc32(const unsigned char* data, size_t length); +#endif /* !LODEPNG_NO_COMPILE_CRC */ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Reading and writing single bits and bytes from/to stream for LodePNG / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +static unsigned char readBitFromReversedStream(size_t* bitpointer, const unsigned char* bitstream) +{ + unsigned char result = (unsigned char)((bitstream[(*bitpointer) >> 3] >> (7 - ((*bitpointer) & 0x7))) & 1); + ++(*bitpointer); + return result; +} + +static unsigned readBitsFromReversedStream(size_t* bitpointer, const unsigned char* bitstream, size_t nbits) +{ + unsigned result = 0; + size_t i; + for(i = 0 ; i < nbits; ++i) + { + result <<= 1; + result |= (unsigned)readBitFromReversedStream(bitpointer, bitstream); + } + return result; +} + +#ifdef LODEPNG_COMPILE_DECODER +static void setBitOfReversedStream0(size_t* bitpointer, unsigned char* bitstream, unsigned char bit) +{ + /*the current bit in bitstream must be 0 for this to work*/ + if(bit) + { + /*earlier bit of huffman code is in a lesser significant bit of an earlier byte*/ + bitstream[(*bitpointer) >> 3] |= (bit << (7 - ((*bitpointer) & 0x7))); + } + ++(*bitpointer); +} +#endif /*LODEPNG_COMPILE_DECODER*/ + +static void setBitOfReversedStream(size_t* bitpointer, unsigned char* bitstream, unsigned char bit) +{ + /*the current bit in bitstream may be 0 or 1 for this to work*/ + if(bit == 0) bitstream[(*bitpointer) >> 3] &= (unsigned char)(~(1 << (7 - ((*bitpointer) & 0x7)))); + else bitstream[(*bitpointer) >> 3] |= (1 << (7 - ((*bitpointer) & 0x7))); + ++(*bitpointer); +} + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / PNG chunks / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +unsigned lodepng_chunk_length(const unsigned char* chunk) +{ + return lodepng_read32bitInt(&chunk[0]); +} + +void lodepng_chunk_type(char type[5], const unsigned char* chunk) +{ + unsigned i; + for(i = 0; i != 4; ++i) type[i] = (char)chunk[4 + i]; + type[4] = 0; /*null termination char*/ +} + +unsigned char lodepng_chunk_type_equals(const unsigned char* chunk, const char* type) +{ + if(strlen(type) != 4) return 0; + return (chunk[4] == type[0] && chunk[5] == type[1] && chunk[6] == type[2] && chunk[7] == type[3]); +} + +unsigned char lodepng_chunk_ancillary(const unsigned char* chunk) +{ + return((chunk[4] & 32) != 0); +} + +unsigned char lodepng_chunk_private(const unsigned char* chunk) +{ + return((chunk[6] & 32) != 0); +} + +unsigned char lodepng_chunk_safetocopy(const unsigned char* chunk) +{ + return((chunk[7] & 32) != 0); +} + +unsigned char* lodepng_chunk_data(unsigned char* chunk) +{ + return &chunk[8]; +} + +const unsigned char* lodepng_chunk_data_const(const unsigned char* chunk) +{ + return &chunk[8]; +} + +unsigned 
lodepng_chunk_check_crc(const unsigned char* chunk) +{ + unsigned length = lodepng_chunk_length(chunk); + unsigned CRC = lodepng_read32bitInt(&chunk[length + 8]); + /*the CRC is taken of the data and the 4 chunk type letters, not the length*/ + unsigned checksum = lodepng_crc32(&chunk[4], length + 4); + if(CRC != checksum) return 1; + else return 0; +} + +void lodepng_chunk_generate_crc(unsigned char* chunk) +{ + unsigned length = lodepng_chunk_length(chunk); + unsigned CRC = lodepng_crc32(&chunk[4], length + 4); + lodepng_set32bitInt(chunk + 8 + length, CRC); +} + +unsigned char* lodepng_chunk_next(unsigned char* chunk) +{ + unsigned total_chunk_length = lodepng_chunk_length(chunk) + 12; + return &chunk[total_chunk_length]; +} + +const unsigned char* lodepng_chunk_next_const(const unsigned char* chunk) +{ + unsigned total_chunk_length = lodepng_chunk_length(chunk) + 12; + return &chunk[total_chunk_length]; +} + +unsigned lodepng_chunk_append(unsigned char** out, size_t* outlength, const unsigned char* chunk) +{ + unsigned i; + unsigned total_chunk_length = lodepng_chunk_length(chunk) + 12; + unsigned char *chunk_start, *new_buffer; + size_t new_length = (*outlength) + total_chunk_length; + if(new_length < total_chunk_length || new_length < (*outlength)) return 77; /*integer overflow happened*/ + + new_buffer = (unsigned char*)lodepng_realloc(*out, new_length); + if(!new_buffer) return 83; /*alloc fail*/ + (*out) = new_buffer; + (*outlength) = new_length; + chunk_start = &(*out)[new_length - total_chunk_length]; + + for(i = 0; i != total_chunk_length; ++i) chunk_start[i] = chunk[i]; + + return 0; +} + +unsigned lodepng_chunk_create(unsigned char** out, size_t* outlength, unsigned length, + const char* type, const unsigned char* data) +{ + unsigned i; + unsigned char *chunk, *new_buffer; + size_t new_length = (*outlength) + length + 12; + if(new_length < length + 12 || new_length < (*outlength)) return 77; /*integer overflow happened*/ + new_buffer = (unsigned char*)lodepng_realloc(*out, new_length); + if(!new_buffer) return 83; /*alloc fail*/ + (*out) = new_buffer; + (*outlength) = new_length; + chunk = &(*out)[(*outlength) - length - 12]; + + /*1: length*/ + lodepng_set32bitInt(chunk, (unsigned)length); + + /*2: chunk name (4 letters)*/ + chunk[4] = (unsigned char)type[0]; + chunk[5] = (unsigned char)type[1]; + chunk[6] = (unsigned char)type[2]; + chunk[7] = (unsigned char)type[3]; + + /*3: the data*/ + for(i = 0; i != length; ++i) chunk[8 + i] = data[i]; + + /*4: CRC (of the chunkname characters and the data)*/ + lodepng_chunk_generate_crc(chunk); + + return 0; +} + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / Color types and such / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +/*return type is a LodePNG error code*/ +static unsigned checkColorValidity(LodePNGColorType colortype, unsigned bd) /*bd = bitdepth*/ +{ + switch(colortype) + { + case 0: if(!(bd == 1 || bd == 2 || bd == 4 || bd == 8 || bd == 16)) return 37; break; /*grey*/ + case 2: if(!( bd == 8 || bd == 16)) return 37; break; /*RGB*/ + case 3: if(!(bd == 1 || bd == 2 || bd == 4 || bd == 8 )) return 37; break; /*palette*/ + case 4: if(!( bd == 8 || bd == 16)) return 37; break; /*grey + alpha*/ + case 6: if(!( bd == 8 || bd == 16)) return 37; break; /*RGBA*/ + default: return 31; + } + return 0; /*allowed color type / bits combination*/ +} + +static unsigned getNumColorChannels(LodePNGColorType colortype) +{ + switch(colortype) + { + case 0: 
return 1; /*grey*/ + case 2: return 3; /*RGB*/ + case 3: return 1; /*palette*/ + case 4: return 2; /*grey + alpha*/ + case 6: return 4; /*RGBA*/ + } + return 0; /*unexisting color type*/ +} + +static unsigned lodepng_get_bpp_lct(LodePNGColorType colortype, unsigned bitdepth) +{ + /*bits per pixel is amount of channels * bits per channel*/ + return getNumColorChannels(colortype) * bitdepth; +} + +/* ////////////////////////////////////////////////////////////////////////// */ + +void lodepng_color_mode_init(LodePNGColorMode* info) +{ + info->key_defined = 0; + info->key_r = info->key_g = info->key_b = 0; + info->colortype = LCT_RGBA; + info->bitdepth = 8; + info->palette = 0; + info->palettesize = 0; +} + +void lodepng_color_mode_cleanup(LodePNGColorMode* info) +{ + lodepng_palette_clear(info); +} + +unsigned lodepng_color_mode_copy(LodePNGColorMode* dest, const LodePNGColorMode* source) +{ + size_t i; + lodepng_color_mode_cleanup(dest); + *dest = *source; + if(source->palette) + { + dest->palette = (unsigned char*)lodepng_malloc(1024); + if(!dest->palette && source->palettesize) return 83; /*alloc fail*/ + for(i = 0; i != source->palettesize * 4; ++i) dest->palette[i] = source->palette[i]; + } + return 0; +} + +static int lodepng_color_mode_equal(const LodePNGColorMode* a, const LodePNGColorMode* b) +{ + size_t i; + if(a->colortype != b->colortype) return 0; + if(a->bitdepth != b->bitdepth) return 0; + if(a->key_defined != b->key_defined) return 0; + if(a->key_defined) + { + if(a->key_r != b->key_r) return 0; + if(a->key_g != b->key_g) return 0; + if(a->key_b != b->key_b) return 0; + } + /*if one of the palette sizes is 0, then we consider it to be the same as the + other: it means that e.g. the palette was not given by the user and should be + considered the same as the palette inside the PNG.*/ + if(1/*a->palettesize != 0 && b->palettesize != 0*/) { + if(a->palettesize != b->palettesize) return 0; + for(i = 0; i != a->palettesize * 4; ++i) + { + if(a->palette[i] != b->palette[i]) return 0; + } + } + return 1; +} + +void lodepng_palette_clear(LodePNGColorMode* info) +{ + if(info->palette) lodepng_free(info->palette); + info->palette = 0; + info->palettesize = 0; +} + +unsigned lodepng_palette_add(LodePNGColorMode* info, + unsigned char r, unsigned char g, unsigned char b, unsigned char a) +{ + unsigned char* data; + /*the same resize technique as C++ std::vectors is used, and here it's made so that for a palette with + the max of 256 colors, it'll have the exact alloc size*/ + if(!info->palette) /*allocate palette if empty*/ + { + /*room for 256 colors with 4 bytes each*/ + data = (unsigned char*)lodepng_realloc(info->palette, 1024); + if(!data) return 83; /*alloc fail*/ + else info->palette = data; + } + info->palette[4 * info->palettesize + 0] = r; + info->palette[4 * info->palettesize + 1] = g; + info->palette[4 * info->palettesize + 2] = b; + info->palette[4 * info->palettesize + 3] = a; + ++info->palettesize; + return 0; +} + +unsigned lodepng_get_bpp(const LodePNGColorMode* info) +{ + /*calculate bits per pixel out of colortype and bitdepth*/ + return lodepng_get_bpp_lct(info->colortype, info->bitdepth); +} + +unsigned lodepng_get_channels(const LodePNGColorMode* info) +{ + return getNumColorChannels(info->colortype); +} + +unsigned lodepng_is_greyscale_type(const LodePNGColorMode* info) +{ + return info->colortype == LCT_GREY || info->colortype == LCT_GREY_ALPHA; +} + +unsigned lodepng_is_alpha_type(const LodePNGColorMode* info) +{ + return (info->colortype & 4) != 0; /*4 or 6*/ +} 
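The color-mode helpers above derive bits per pixel as the channel count of the color type times the bit depth per channel. As a minimal standalone sketch of that rule (not part of the patch; the helper name channels_for and the printed cases are purely illustrative, mirroring getNumColorChannels and lodepng_get_bpp_lct):

#include <stdio.h>

/*sketch only: restates the channel counts used by getNumColorChannels above*/
static unsigned channels_for(unsigned colortype)
{
  switch(colortype)
  {
    case 0: return 1; /*grey*/
    case 2: return 3; /*RGB*/
    case 3: return 1; /*palette index*/
    case 4: return 2; /*grey + alpha*/
    case 6: return 4; /*RGBA*/
  }
  return 0; /*unknown color type*/
}

int main(void)
{
  /*bits per pixel = channels * bit depth per channel*/
  printf("RGBA, 8-bit:    %u bpp\n", channels_for(6) * 8);
  printf("palette, 4-bit: %u bpp\n", channels_for(3) * 4);
  printf("grey, 16-bit:   %u bpp\n", channels_for(0) * 16);
  return 0;
}

With the defaults set by lodepng_color_mode_init above (LCT_RGBA, bit depth 8), this works out to the familiar 32 bits per pixel.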
+ +unsigned lodepng_is_palette_type(const LodePNGColorMode* info) +{ + return info->colortype == LCT_PALETTE; +} + +unsigned lodepng_has_palette_alpha(const LodePNGColorMode* info) +{ + size_t i; + for(i = 0; i != info->palettesize; ++i) + { + if(info->palette[i * 4 + 3] < 255) return 1; + } + return 0; +} + +unsigned lodepng_can_have_alpha(const LodePNGColorMode* info) +{ + return info->key_defined + || lodepng_is_alpha_type(info) + || lodepng_has_palette_alpha(info); +} + +size_t lodepng_get_raw_size(unsigned w, unsigned h, const LodePNGColorMode* color) +{ + /*will not overflow for any color type if roughly w * h < 268435455*/ + size_t bpp = lodepng_get_bpp(color); + size_t n = w * h; + return ((n / 8) * bpp) + ((n & 7) * bpp + 7) / 8; +} + +size_t lodepng_get_raw_size_lct(unsigned w, unsigned h, LodePNGColorType colortype, unsigned bitdepth) +{ + /*will not overflow for any color type if roughly w * h < 268435455*/ + size_t bpp = lodepng_get_bpp_lct(colortype, bitdepth); + size_t n = w * h; + return ((n / 8) * bpp) + ((n & 7) * bpp + 7) / 8; +} + + +#ifdef LODEPNG_COMPILE_PNG +#ifdef LODEPNG_COMPILE_DECODER +/*in an idat chunk, each scanline is a multiple of 8 bits, unlike the lodepng output buffer*/ +static size_t lodepng_get_raw_size_idat(unsigned w, unsigned h, const LodePNGColorMode* color) +{ + /*will not overflow for any color type if roughly w * h < 268435455*/ + size_t bpp = lodepng_get_bpp(color); + size_t line = ((w / 8) * bpp) + ((w & 7) * bpp + 7) / 8; + return h * line; +} +#endif /*LODEPNG_COMPILE_DECODER*/ +#endif /*LODEPNG_COMPILE_PNG*/ + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + +static void LodePNGUnknownChunks_init(LodePNGInfo* info) +{ + unsigned i; + for(i = 0; i != 3; ++i) info->unknown_chunks_data[i] = 0; + for(i = 0; i != 3; ++i) info->unknown_chunks_size[i] = 0; +} + +static void LodePNGUnknownChunks_cleanup(LodePNGInfo* info) +{ + unsigned i; + for(i = 0; i != 3; ++i) lodepng_free(info->unknown_chunks_data[i]); +} + +static unsigned LodePNGUnknownChunks_copy(LodePNGInfo* dest, const LodePNGInfo* src) +{ + unsigned i; + + LodePNGUnknownChunks_cleanup(dest); + + for(i = 0; i != 3; ++i) + { + size_t j; + dest->unknown_chunks_size[i] = src->unknown_chunks_size[i]; + dest->unknown_chunks_data[i] = (unsigned char*)lodepng_malloc(src->unknown_chunks_size[i]); + if(!dest->unknown_chunks_data[i] && dest->unknown_chunks_size[i]) return 83; /*alloc fail*/ + for(j = 0; j < src->unknown_chunks_size[i]; ++j) + { + dest->unknown_chunks_data[i][j] = src->unknown_chunks_data[i][j]; + } + } + + return 0; +} + +/******************************************************************************/ + +static void LodePNGText_init(LodePNGInfo* info) +{ + info->text_num = 0; + info->text_keys = NULL; + info->text_strings = NULL; +} + +static void LodePNGText_cleanup(LodePNGInfo* info) +{ + size_t i; + for(i = 0; i != info->text_num; ++i) + { + string_cleanup(&info->text_keys[i]); + string_cleanup(&info->text_strings[i]); + } + lodepng_free(info->text_keys); + lodepng_free(info->text_strings); +} + +static unsigned LodePNGText_copy(LodePNGInfo* dest, const LodePNGInfo* source) +{ + size_t i = 0; + dest->text_keys = 0; + dest->text_strings = 0; + dest->text_num = 0; + for(i = 0; i != source->text_num; ++i) + { + CERROR_TRY_RETURN(lodepng_add_text(dest, source->text_keys[i], source->text_strings[i])); + } + return 0; +} + +void lodepng_clear_text(LodePNGInfo* info) +{ + LodePNGText_cleanup(info); +} + +unsigned lodepng_add_text(LodePNGInfo* info, const char* key, const char* str) +{ + 
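  /*grow the parallel key/string arrays by one entry each; if either realloc fails, free the new
  pointers and report error 83 (alloc fail) without updating info*/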
char** new_keys = (char**)(lodepng_realloc(info->text_keys, sizeof(char*) * (info->text_num + 1))); + char** new_strings = (char**)(lodepng_realloc(info->text_strings, sizeof(char*) * (info->text_num + 1))); + if(!new_keys || !new_strings) + { + lodepng_free(new_keys); + lodepng_free(new_strings); + return 83; /*alloc fail*/ + } + + ++info->text_num; + info->text_keys = new_keys; + info->text_strings = new_strings; + + string_init(&info->text_keys[info->text_num - 1]); + string_set(&info->text_keys[info->text_num - 1], key); + + string_init(&info->text_strings[info->text_num - 1]); + string_set(&info->text_strings[info->text_num - 1], str); + + return 0; +} + +/******************************************************************************/ + +static void LodePNGIText_init(LodePNGInfo* info) +{ + info->itext_num = 0; + info->itext_keys = NULL; + info->itext_langtags = NULL; + info->itext_transkeys = NULL; + info->itext_strings = NULL; +} + +static void LodePNGIText_cleanup(LodePNGInfo* info) +{ + size_t i; + for(i = 0; i != info->itext_num; ++i) + { + string_cleanup(&info->itext_keys[i]); + string_cleanup(&info->itext_langtags[i]); + string_cleanup(&info->itext_transkeys[i]); + string_cleanup(&info->itext_strings[i]); + } + lodepng_free(info->itext_keys); + lodepng_free(info->itext_langtags); + lodepng_free(info->itext_transkeys); + lodepng_free(info->itext_strings); +} + +static unsigned LodePNGIText_copy(LodePNGInfo* dest, const LodePNGInfo* source) +{ + size_t i = 0; + dest->itext_keys = 0; + dest->itext_langtags = 0; + dest->itext_transkeys = 0; + dest->itext_strings = 0; + dest->itext_num = 0; + for(i = 0; i != source->itext_num; ++i) + { + CERROR_TRY_RETURN(lodepng_add_itext(dest, source->itext_keys[i], source->itext_langtags[i], + source->itext_transkeys[i], source->itext_strings[i])); + } + return 0; +} + +void lodepng_clear_itext(LodePNGInfo* info) +{ + LodePNGIText_cleanup(info); +} + +unsigned lodepng_add_itext(LodePNGInfo* info, const char* key, const char* langtag, + const char* transkey, const char* str) +{ + char** new_keys = (char**)(lodepng_realloc(info->itext_keys, sizeof(char*) * (info->itext_num + 1))); + char** new_langtags = (char**)(lodepng_realloc(info->itext_langtags, sizeof(char*) * (info->itext_num + 1))); + char** new_transkeys = (char**)(lodepng_realloc(info->itext_transkeys, sizeof(char*) * (info->itext_num + 1))); + char** new_strings = (char**)(lodepng_realloc(info->itext_strings, sizeof(char*) * (info->itext_num + 1))); + if(!new_keys || !new_langtags || !new_transkeys || !new_strings) + { + lodepng_free(new_keys); + lodepng_free(new_langtags); + lodepng_free(new_transkeys); + lodepng_free(new_strings); + return 83; /*alloc fail*/ + } + + ++info->itext_num; + info->itext_keys = new_keys; + info->itext_langtags = new_langtags; + info->itext_transkeys = new_transkeys; + info->itext_strings = new_strings; + + string_init(&info->itext_keys[info->itext_num - 1]); + string_set(&info->itext_keys[info->itext_num - 1], key); + + string_init(&info->itext_langtags[info->itext_num - 1]); + string_set(&info->itext_langtags[info->itext_num - 1], langtag); + + string_init(&info->itext_transkeys[info->itext_num - 1]); + string_set(&info->itext_transkeys[info->itext_num - 1], transkey); + + string_init(&info->itext_strings[info->itext_num - 1]); + string_set(&info->itext_strings[info->itext_num - 1], str); + + return 0; +} +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +void lodepng_info_init(LodePNGInfo* info) +{ + lodepng_color_mode_init(&info->color); + 
info->interlace_method = 0; + info->compression_method = 0; + info->filter_method = 0; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + info->background_defined = 0; + info->background_r = info->background_g = info->background_b = 0; + + LodePNGText_init(info); + LodePNGIText_init(info); + + info->time_defined = 0; + info->phys_defined = 0; + + LodePNGUnknownChunks_init(info); +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} + +void lodepng_info_cleanup(LodePNGInfo* info) +{ + lodepng_color_mode_cleanup(&info->color); +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + LodePNGText_cleanup(info); + LodePNGIText_cleanup(info); + + LodePNGUnknownChunks_cleanup(info); +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} + +unsigned lodepng_info_copy(LodePNGInfo* dest, const LodePNGInfo* source) +{ + lodepng_info_cleanup(dest); + *dest = *source; + lodepng_color_mode_init(&dest->color); + CERROR_TRY_RETURN(lodepng_color_mode_copy(&dest->color, &source->color)); + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + CERROR_TRY_RETURN(LodePNGText_copy(dest, source)); + CERROR_TRY_RETURN(LodePNGIText_copy(dest, source)); + + LodePNGUnknownChunks_init(dest); + CERROR_TRY_RETURN(LodePNGUnknownChunks_copy(dest, source)); +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + return 0; +} + +void lodepng_info_swap(LodePNGInfo* a, LodePNGInfo* b) +{ + LodePNGInfo temp = *a; + *a = *b; + *b = temp; +} + +/* ////////////////////////////////////////////////////////////////////////// */ + +/*index: bitgroup index, bits: bitgroup size(1, 2 or 4), in: bitgroup value, out: octet array to add bits to*/ +static void addColorBits(unsigned char* out, size_t index, unsigned bits, unsigned in) +{ + unsigned m = bits == 1 ? 7 : bits == 2 ? 3 : 1; /*8 / bits - 1*/ + /*p = the partial index in the byte, e.g. with 4 palettebits it is 0 for first half or 1 for second half*/ + unsigned p = index & m; + in &= (1u << bits) - 1u; /*filter out any other bits of the input value*/ + in = in << (bits * (m - p)); + if(p == 0) out[index * bits / 8] = in; + else out[index * bits / 8] |= in; +} + +typedef struct ColorTree ColorTree; + +/* +One node of a color tree +This is the data structure used to count the number of unique colors and to get a palette +index for a color. It's like an octree, but because the alpha channel is used too, each +node has 16 instead of 8 children. +*/ +struct ColorTree +{ + ColorTree* children[16]; /*up to 16 pointers to ColorTree of next level*/ + int index; /*the payload. Only has a meaningful value if this is in the last level*/ +}; + +static void color_tree_init(ColorTree* tree) +{ + int i; + for(i = 0; i != 16; ++i) tree->children[i] = 0; + tree->index = -1; +} + +static void color_tree_cleanup(ColorTree* tree) +{ + int i; + for(i = 0; i != 16; ++i) + { + if(tree->children[i]) + { + color_tree_cleanup(tree->children[i]); + lodepng_free(tree->children[i]); + } + } +} + +/*returns -1 if color not present, its index otherwise*/ +static int color_tree_get(ColorTree* tree, unsigned char r, unsigned char g, unsigned char b, unsigned char a) +{ + int bit = 0; + for(bit = 0; bit < 8; ++bit) + { + int i = 8 * ((r >> bit) & 1) + 4 * ((g >> bit) & 1) + 2 * ((b >> bit) & 1) + 1 * ((a >> bit) & 1); + if(!tree->children[i]) return -1; + else tree = tree->children[i]; + } + return tree ? 
tree->index : -1; +} + +#ifdef LODEPNG_COMPILE_ENCODER +static int color_tree_has(ColorTree* tree, unsigned char r, unsigned char g, unsigned char b, unsigned char a) +{ + return color_tree_get(tree, r, g, b, a) >= 0; +} +#endif /*LODEPNG_COMPILE_ENCODER*/ + +/*color is not allowed to already exist. +Index should be >= 0 (it's signed to be compatible with using -1 for "doesn't exist")*/ +static void color_tree_add(ColorTree* tree, + unsigned char r, unsigned char g, unsigned char b, unsigned char a, unsigned index) +{ + int bit; + for(bit = 0; bit < 8; ++bit) + { + int i = 8 * ((r >> bit) & 1) + 4 * ((g >> bit) & 1) + 2 * ((b >> bit) & 1) + 1 * ((a >> bit) & 1); + if(!tree->children[i]) + { + tree->children[i] = (ColorTree*)lodepng_malloc(sizeof(ColorTree)); + color_tree_init(tree->children[i]); + } + tree = tree->children[i]; + } + tree->index = (int)index; +} + +/*put a pixel, given its RGBA color, into image of any color type*/ +static unsigned rgba8ToPixel(unsigned char* out, size_t i, + const LodePNGColorMode* mode, ColorTree* tree /*for palette*/, + unsigned char r, unsigned char g, unsigned char b, unsigned char a) +{ + if(mode->colortype == LCT_GREY) + { + unsigned char grey = r; /*((unsigned short)r + g + b) / 3*/; + if(mode->bitdepth == 8) out[i] = grey; + else if(mode->bitdepth == 16) out[i * 2 + 0] = out[i * 2 + 1] = grey; + else + { + /*take the most significant bits of grey*/ + grey = (grey >> (8 - mode->bitdepth)) & ((1 << mode->bitdepth) - 1); + addColorBits(out, i, mode->bitdepth, grey); + } + } + else if(mode->colortype == LCT_RGB) + { + if(mode->bitdepth == 8) + { + out[i * 3 + 0] = r; + out[i * 3 + 1] = g; + out[i * 3 + 2] = b; + } + else + { + out[i * 6 + 0] = out[i * 6 + 1] = r; + out[i * 6 + 2] = out[i * 6 + 3] = g; + out[i * 6 + 4] = out[i * 6 + 5] = b; + } + } + else if(mode->colortype == LCT_PALETTE) + { + int index = color_tree_get(tree, r, g, b, a); + if(index < 0) return 82; /*color not in palette*/ + if(mode->bitdepth == 8) out[i] = index; + else addColorBits(out, i, mode->bitdepth, (unsigned)index); + } + else if(mode->colortype == LCT_GREY_ALPHA) + { + unsigned char grey = r; /*((unsigned short)r + g + b) / 3*/; + if(mode->bitdepth == 8) + { + out[i * 2 + 0] = grey; + out[i * 2 + 1] = a; + } + else if(mode->bitdepth == 16) + { + out[i * 4 + 0] = out[i * 4 + 1] = grey; + out[i * 4 + 2] = out[i * 4 + 3] = a; + } + } + else if(mode->colortype == LCT_RGBA) + { + if(mode->bitdepth == 8) + { + out[i * 4 + 0] = r; + out[i * 4 + 1] = g; + out[i * 4 + 2] = b; + out[i * 4 + 3] = a; + } + else + { + out[i * 8 + 0] = out[i * 8 + 1] = r; + out[i * 8 + 2] = out[i * 8 + 3] = g; + out[i * 8 + 4] = out[i * 8 + 5] = b; + out[i * 8 + 6] = out[i * 8 + 7] = a; + } + } + + return 0; /*no error*/ +} + +/*put a pixel, given its RGBA16 color, into image of any color 16-bitdepth type*/ +static void rgba16ToPixel(unsigned char* out, size_t i, + const LodePNGColorMode* mode, + unsigned short r, unsigned short g, unsigned short b, unsigned short a) +{ + if(mode->colortype == LCT_GREY) + { + unsigned short grey = r; /*((unsigned)r + g + b) / 3*/; + out[i * 2 + 0] = (grey >> 8) & 255; + out[i * 2 + 1] = grey & 255; + } + else if(mode->colortype == LCT_RGB) + { + out[i * 6 + 0] = (r >> 8) & 255; + out[i * 6 + 1] = r & 255; + out[i * 6 + 2] = (g >> 8) & 255; + out[i * 6 + 3] = g & 255; + out[i * 6 + 4] = (b >> 8) & 255; + out[i * 6 + 5] = b & 255; + } + else if(mode->colortype == LCT_GREY_ALPHA) + { + unsigned short grey = r; /*((unsigned)r + g + b) / 3*/; + out[i * 4 + 0] = (grey >> 8) & 
255; + out[i * 4 + 1] = grey & 255; + out[i * 4 + 2] = (a >> 8) & 255; + out[i * 4 + 3] = a & 255; + } + else if(mode->colortype == LCT_RGBA) + { + out[i * 8 + 0] = (r >> 8) & 255; + out[i * 8 + 1] = r & 255; + out[i * 8 + 2] = (g >> 8) & 255; + out[i * 8 + 3] = g & 255; + out[i * 8 + 4] = (b >> 8) & 255; + out[i * 8 + 5] = b & 255; + out[i * 8 + 6] = (a >> 8) & 255; + out[i * 8 + 7] = a & 255; + } +} + +/*Get RGBA8 color of pixel with index i (y * width + x) from the raw image with given color type.*/ +static void getPixelColorRGBA8(unsigned char* r, unsigned char* g, + unsigned char* b, unsigned char* a, + const unsigned char* in, size_t i, + const LodePNGColorMode* mode) +{ + if(mode->colortype == LCT_GREY) + { + if(mode->bitdepth == 8) + { + *r = *g = *b = in[i]; + if(mode->key_defined && *r == mode->key_r) *a = 0; + else *a = 255; + } + else if(mode->bitdepth == 16) + { + *r = *g = *b = in[i * 2 + 0]; + if(mode->key_defined && 256U * in[i * 2 + 0] + in[i * 2 + 1] == mode->key_r) *a = 0; + else *a = 255; + } + else + { + unsigned highest = ((1U << mode->bitdepth) - 1U); /*highest possible value for this bit depth*/ + size_t j = i * mode->bitdepth; + unsigned value = readBitsFromReversedStream(&j, in, mode->bitdepth); + *r = *g = *b = (value * 255) / highest; + if(mode->key_defined && value == mode->key_r) *a = 0; + else *a = 255; + } + } + else if(mode->colortype == LCT_RGB) + { + if(mode->bitdepth == 8) + { + *r = in[i * 3 + 0]; *g = in[i * 3 + 1]; *b = in[i * 3 + 2]; + if(mode->key_defined && *r == mode->key_r && *g == mode->key_g && *b == mode->key_b) *a = 0; + else *a = 255; + } + else + { + *r = in[i * 6 + 0]; + *g = in[i * 6 + 2]; + *b = in[i * 6 + 4]; + if(mode->key_defined && 256U * in[i * 6 + 0] + in[i * 6 + 1] == mode->key_r + && 256U * in[i * 6 + 2] + in[i * 6 + 3] == mode->key_g + && 256U * in[i * 6 + 4] + in[i * 6 + 5] == mode->key_b) *a = 0; + else *a = 255; + } + } + else if(mode->colortype == LCT_PALETTE) + { + unsigned index; + if(mode->bitdepth == 8) index = in[i]; + else + { + size_t j = i * mode->bitdepth; + index = readBitsFromReversedStream(&j, in, mode->bitdepth); + } + + if(index >= mode->palettesize) + { + /*This is an error according to the PNG spec, but common PNG decoders make it black instead. + Done here too, slightly faster due to no error handling needed.*/ + *r = *g = *b = 0; + *a = 255; + } + else + { + *r = mode->palette[index * 4 + 0]; + *g = mode->palette[index * 4 + 1]; + *b = mode->palette[index * 4 + 2]; + *a = mode->palette[index * 4 + 3]; + } + } + else if(mode->colortype == LCT_GREY_ALPHA) + { + if(mode->bitdepth == 8) + { + *r = *g = *b = in[i * 2 + 0]; + *a = in[i * 2 + 1]; + } + else + { + *r = *g = *b = in[i * 4 + 0]; + *a = in[i * 4 + 2]; + } + } + else if(mode->colortype == LCT_RGBA) + { + if(mode->bitdepth == 8) + { + *r = in[i * 4 + 0]; + *g = in[i * 4 + 1]; + *b = in[i * 4 + 2]; + *a = in[i * 4 + 3]; + } + else + { + *r = in[i * 8 + 0]; + *g = in[i * 8 + 2]; + *b = in[i * 8 + 4]; + *a = in[i * 8 + 6]; + } + } +} + +/*Similar to getPixelColorRGBA8, but with all the for loops inside of the color +mode test cases, optimized to convert the colors much faster, when converting +to RGBA or RGB with 8 bit per cannel. buffer must be RGBA or RGB output with +enough memory, if has_alpha is true the output is RGBA. 
mode has the color mode +of the input buffer.*/ +static void getPixelColorsRGBA8(unsigned char* buffer, size_t numpixels, + unsigned has_alpha, const unsigned char* in, + const LodePNGColorMode* mode) +{ + unsigned num_channels = has_alpha ? 4 : 3; + size_t i; + if(mode->colortype == LCT_GREY) + { + if(mode->bitdepth == 8) + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = buffer[1] = buffer[2] = in[i]; + if(has_alpha) buffer[3] = mode->key_defined && in[i] == mode->key_r ? 0 : 255; + } + } + else if(mode->bitdepth == 16) + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = buffer[1] = buffer[2] = in[i * 2]; + if(has_alpha) buffer[3] = mode->key_defined && 256U * in[i * 2 + 0] + in[i * 2 + 1] == mode->key_r ? 0 : 255; + } + } + else + { + unsigned highest = ((1U << mode->bitdepth) - 1U); /*highest possible value for this bit depth*/ + size_t j = 0; + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + unsigned value = readBitsFromReversedStream(&j, in, mode->bitdepth); + buffer[0] = buffer[1] = buffer[2] = (value * 255) / highest; + if(has_alpha) buffer[3] = mode->key_defined && value == mode->key_r ? 0 : 255; + } + } + } + else if(mode->colortype == LCT_RGB) + { + if(mode->bitdepth == 8) + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = in[i * 3 + 0]; + buffer[1] = in[i * 3 + 1]; + buffer[2] = in[i * 3 + 2]; + if(has_alpha) buffer[3] = mode->key_defined && buffer[0] == mode->key_r + && buffer[1]== mode->key_g && buffer[2] == mode->key_b ? 0 : 255; + } + } + else + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = in[i * 6 + 0]; + buffer[1] = in[i * 6 + 2]; + buffer[2] = in[i * 6 + 4]; + if(has_alpha) buffer[3] = mode->key_defined + && 256U * in[i * 6 + 0] + in[i * 6 + 1] == mode->key_r + && 256U * in[i * 6 + 2] + in[i * 6 + 3] == mode->key_g + && 256U * in[i * 6 + 4] + in[i * 6 + 5] == mode->key_b ? 0 : 255; + } + } + } + else if(mode->colortype == LCT_PALETTE) + { + unsigned index; + size_t j = 0; + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + if(mode->bitdepth == 8) index = in[i]; + else index = readBitsFromReversedStream(&j, in, mode->bitdepth); + + if(index >= mode->palettesize) + { + /*This is an error according to the PNG spec, but most PNG decoders make it black instead. 
+ Done here too, slightly faster due to no error handling needed.*/ + buffer[0] = buffer[1] = buffer[2] = 0; + if(has_alpha) buffer[3] = 255; + } + else + { + buffer[0] = mode->palette[index * 4 + 0]; + buffer[1] = mode->palette[index * 4 + 1]; + buffer[2] = mode->palette[index * 4 + 2]; + if(has_alpha) buffer[3] = mode->palette[index * 4 + 3]; + } + } + } + else if(mode->colortype == LCT_GREY_ALPHA) + { + if(mode->bitdepth == 8) + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = buffer[1] = buffer[2] = in[i * 2 + 0]; + if(has_alpha) buffer[3] = in[i * 2 + 1]; + } + } + else + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = buffer[1] = buffer[2] = in[i * 4 + 0]; + if(has_alpha) buffer[3] = in[i * 4 + 2]; + } + } + } + else if(mode->colortype == LCT_RGBA) + { + if(mode->bitdepth == 8) + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = in[i * 4 + 0]; + buffer[1] = in[i * 4 + 1]; + buffer[2] = in[i * 4 + 2]; + if(has_alpha) buffer[3] = in[i * 4 + 3]; + } + } + else + { + for(i = 0; i != numpixels; ++i, buffer += num_channels) + { + buffer[0] = in[i * 8 + 0]; + buffer[1] = in[i * 8 + 2]; + buffer[2] = in[i * 8 + 4]; + if(has_alpha) buffer[3] = in[i * 8 + 6]; + } + } + } +} + +/*Get RGBA16 color of pixel with index i (y * width + x) from the raw image with +given color type, but the given color type must be 16-bit itself.*/ +static void getPixelColorRGBA16(unsigned short* r, unsigned short* g, unsigned short* b, unsigned short* a, + const unsigned char* in, size_t i, const LodePNGColorMode* mode) +{ + if(mode->colortype == LCT_GREY) + { + *r = *g = *b = 256 * in[i * 2 + 0] + in[i * 2 + 1]; + if(mode->key_defined && 256U * in[i * 2 + 0] + in[i * 2 + 1] == mode->key_r) *a = 0; + else *a = 65535; + } + else if(mode->colortype == LCT_RGB) + { + *r = 256u * in[i * 6 + 0] + in[i * 6 + 1]; + *g = 256u * in[i * 6 + 2] + in[i * 6 + 3]; + *b = 256u * in[i * 6 + 4] + in[i * 6 + 5]; + if(mode->key_defined + && 256u * in[i * 6 + 0] + in[i * 6 + 1] == mode->key_r + && 256u * in[i * 6 + 2] + in[i * 6 + 3] == mode->key_g + && 256u * in[i * 6 + 4] + in[i * 6 + 5] == mode->key_b) *a = 0; + else *a = 65535; + } + else if(mode->colortype == LCT_GREY_ALPHA) + { + *r = *g = *b = 256u * in[i * 4 + 0] + in[i * 4 + 1]; + *a = 256u * in[i * 4 + 2] + in[i * 4 + 3]; + } + else if(mode->colortype == LCT_RGBA) + { + *r = 256u * in[i * 8 + 0] + in[i * 8 + 1]; + *g = 256u * in[i * 8 + 2] + in[i * 8 + 3]; + *b = 256u * in[i * 8 + 4] + in[i * 8 + 5]; + *a = 256u * in[i * 8 + 6] + in[i * 8 + 7]; + } +} + +unsigned lodepng_convert(unsigned char* out, const unsigned char* in, + const LodePNGColorMode* mode_out, const LodePNGColorMode* mode_in, + unsigned w, unsigned h) +{ + size_t i; + ColorTree tree; + size_t numpixels = w * h; + unsigned error = 0; + + if(lodepng_color_mode_equal(mode_out, mode_in)) + { + size_t numbytes = lodepng_get_raw_size(w, h, mode_in); + for(i = 0; i != numbytes; ++i) out[i] = in[i]; + return 0; + } + + if(mode_out->colortype == LCT_PALETTE) + { + size_t palettesize = mode_out->palettesize; + const unsigned char* palette = mode_out->palette; + size_t palsize = 1u << mode_out->bitdepth; + /*if the user specified output palette but did not give the values, assume + they want the values of the input color type (assuming that one is palette). 
+ Note that we never create a new palette ourselves.*/ + if(palettesize == 0) + { + palettesize = mode_in->palettesize; + palette = mode_in->palette; + } + if(palettesize < palsize) palsize = palettesize; + color_tree_init(&tree); + for(i = 0; i != palsize; ++i) + { + const unsigned char* p = &palette[i * 4]; + color_tree_add(&tree, p[0], p[1], p[2], p[3], i); + } + } + + if(mode_in->bitdepth == 16 && mode_out->bitdepth == 16) + { + for(i = 0; i != numpixels; ++i) + { + unsigned short r = 0, g = 0, b = 0, a = 0; + getPixelColorRGBA16(&r, &g, &b, &a, in, i, mode_in); + rgba16ToPixel(out, i, mode_out, r, g, b, a); + } + } + else if(mode_out->bitdepth == 8 && mode_out->colortype == LCT_RGBA) + { + getPixelColorsRGBA8(out, numpixels, 1, in, mode_in); + } + else if(mode_out->bitdepth == 8 && mode_out->colortype == LCT_RGB) + { + getPixelColorsRGBA8(out, numpixels, 0, in, mode_in); + } + else + { + unsigned char r = 0, g = 0, b = 0, a = 0; + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA8(&r, &g, &b, &a, in, i, mode_in); + error = rgba8ToPixel(out, i, mode_out, &tree, r, g, b, a); + if (error) break; + } + } + + if(mode_out->colortype == LCT_PALETTE) + { + color_tree_cleanup(&tree); + } + + return error; +} + +#ifdef LODEPNG_COMPILE_ENCODER + +void lodepng_color_profile_init(LodePNGColorProfile* profile) +{ + profile->colored = 0; + profile->key = 0; + profile->key_r = profile->key_g = profile->key_b = 0; + profile->alpha = 0; + profile->numcolors = 0; + profile->bits = 1; +} + +/*function used for debug purposes with C++*/ +/*void printColorProfile(LodePNGColorProfile* p) +{ + std::cout << "colored: " << (int)p->colored << ", "; + std::cout << "key: " << (int)p->key << ", "; + std::cout << "key_r: " << (int)p->key_r << ", "; + std::cout << "key_g: " << (int)p->key_g << ", "; + std::cout << "key_b: " << (int)p->key_b << ", "; + std::cout << "alpha: " << (int)p->alpha << ", "; + std::cout << "numcolors: " << (int)p->numcolors << ", "; + std::cout << "bits: " << (int)p->bits << std::endl; +}*/ + +/*Returns how many bits needed to represent given value (max 8 bit)*/ +static unsigned getValueRequiredBits(unsigned char value) +{ + if(value == 0 || value == 255) return 1; + /*The scaling of 2-bit and 4-bit values uses multiples of 85 and 17*/ + if(value % 17 == 0) return value % 85 == 0 ? 2 : 4; + return 8; +} + +/*profile must already have been inited with mode. +It's ok to set some parameters of profile to done already.*/ +unsigned lodepng_get_color_profile(LodePNGColorProfile* profile, + const unsigned char* in, unsigned w, unsigned h, + const LodePNGColorMode* mode) +{ + unsigned error = 0; + size_t i; + ColorTree tree; + size_t numpixels = w * h; + + unsigned colored_done = lodepng_is_greyscale_type(mode) ? 1 : 0; + unsigned alpha_done = lodepng_can_have_alpha(mode) ? 0 : 1; + unsigned numcolors_done = 0; + unsigned bpp = lodepng_get_bpp(mode); + unsigned bits_done = bpp == 1 ? 1 : 0; + unsigned maxnumcolors = 257; + unsigned sixteen = 0; + if(bpp <= 8) maxnumcolors = bpp == 1 ? 2 : (bpp == 2 ? 4 : (bpp == 4 ? 
16 : 256)); + + color_tree_init(&tree); + + /*Check if the 16-bit input is truly 16-bit*/ + if(mode->bitdepth == 16) + { + unsigned short r, g, b, a; + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA16(&r, &g, &b, &a, in, i, mode); + if((r & 255) != ((r >> 8) & 255) || (g & 255) != ((g >> 8) & 255) || + (b & 255) != ((b >> 8) & 255) || (a & 255) != ((a >> 8) & 255)) /*first and second byte differ*/ + { + sixteen = 1; + break; + } + } + } + + if(sixteen) + { + unsigned short r = 0, g = 0, b = 0, a = 0; + profile->bits = 16; + bits_done = numcolors_done = 1; /*counting colors no longer useful, palette doesn't support 16-bit*/ + + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA16(&r, &g, &b, &a, in, i, mode); + + if(!colored_done && (r != g || r != b)) + { + profile->colored = 1; + colored_done = 1; + } + + if(!alpha_done) + { + unsigned matchkey = (r == profile->key_r && g == profile->key_g && b == profile->key_b); + if(a != 65535 && (a != 0 || (profile->key && !matchkey))) + { + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + } + else if(a == 0 && !profile->alpha && !profile->key) + { + profile->key = 1; + profile->key_r = r; + profile->key_g = g; + profile->key_b = b; + } + else if(a == 65535 && profile->key && matchkey) + { + /* Color key cannot be used if an opaque pixel also has that RGB color. */ + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + } + } + if(alpha_done && numcolors_done && colored_done && bits_done) break; + } + + if(profile->key && !profile->alpha) + { + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA16(&r, &g, &b, &a, in, i, mode); + if(a != 0 && r == profile->key_r && g == profile->key_g && b == profile->key_b) + { + /* Color key cannot be used if an opaque pixel also has that RGB color. */ + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + } + } + } + } + else /* < 16-bit */ + { + unsigned char r = 0, g = 0, b = 0, a = 0; + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA8(&r, &g, &b, &a, in, i, mode); + + if(!bits_done && profile->bits < 8) + { + /*only r is checked, < 8 bits is only relevant for greyscale*/ + unsigned bits = getValueRequiredBits(r); + if(bits > profile->bits) profile->bits = bits; + } + bits_done = (profile->bits >= bpp); + + if(!colored_done && (r != g || r != b)) + { + profile->colored = 1; + colored_done = 1; + if(profile->bits < 8) profile->bits = 8; /*PNG has no colored modes with less than 8-bit per channel*/ + } + + if(!alpha_done) + { + unsigned matchkey = (r == profile->key_r && g == profile->key_g && b == profile->key_b); + if(a != 255 && (a != 0 || (profile->key && !matchkey))) + { + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + if(profile->bits < 8) profile->bits = 8; /*PNG has no alphachannel modes with less than 8-bit per channel*/ + } + else if(a == 0 && !profile->alpha && !profile->key) + { + profile->key = 1; + profile->key_r = r; + profile->key_g = g; + profile->key_b = b; + } + else if(a == 255 && profile->key && matchkey) + { + /* Color key cannot be used if an opaque pixel also has that RGB color. 
*/ + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + if(profile->bits < 8) profile->bits = 8; /*PNG has no alphachannel modes with less than 8-bit per channel*/ + } + } + + if(!numcolors_done) + { + if(!color_tree_has(&tree, r, g, b, a)) + { + color_tree_add(&tree, r, g, b, a, profile->numcolors); + if(profile->numcolors < 256) + { + unsigned char* p = profile->palette; + unsigned n = profile->numcolors; + p[n * 4 + 0] = r; + p[n * 4 + 1] = g; + p[n * 4 + 2] = b; + p[n * 4 + 3] = a; + } + ++profile->numcolors; + numcolors_done = profile->numcolors >= maxnumcolors; + } + } + + if(alpha_done && numcolors_done && colored_done && bits_done) break; + } + + if(profile->key && !profile->alpha) + { + for(i = 0; i != numpixels; ++i) + { + getPixelColorRGBA8(&r, &g, &b, &a, in, i, mode); + if(a != 0 && r == profile->key_r && g == profile->key_g && b == profile->key_b) + { + /* Color key cannot be used if an opaque pixel also has that RGB color. */ + profile->alpha = 1; + profile->key = 0; + alpha_done = 1; + if(profile->bits < 8) profile->bits = 8; /*PNG has no alphachannel modes with less than 8-bit per channel*/ + } + } + } + + /*make the profile's key always 16-bit for consistency - repeat each byte twice*/ + profile->key_r += (profile->key_r << 8); + profile->key_g += (profile->key_g << 8); + profile->key_b += (profile->key_b << 8); + } + + color_tree_cleanup(&tree); + return error; +} + +/*Automatically chooses color type that gives smallest amount of bits in the +output image, e.g. grey if there are only greyscale pixels, palette if there +are less than 256 colors, ... +Updates values of mode with a potentially smaller color model. mode_out should +contain the user chosen color model, but will be overwritten with the new chosen one.*/ +unsigned lodepng_auto_choose_color(LodePNGColorMode* mode_out, + const unsigned char* image, unsigned w, unsigned h, + const LodePNGColorMode* mode_in) +{ + LodePNGColorProfile prof; + unsigned error = 0; + unsigned i, n, palettebits, palette_ok; + + lodepng_color_profile_init(&prof); + error = lodepng_get_color_profile(&prof, image, w, h, mode_in); + if(error) return error; + mode_out->key_defined = 0; + + if(prof.key && w * h <= 16) + { + prof.alpha = 1; /*too few pixels to justify tRNS chunk overhead*/ + prof.key = 0; + if(prof.bits < 8) prof.bits = 8; /*PNG has no alphachannel modes with less than 8-bit per channel*/ + } + n = prof.numcolors; + palettebits = n <= 2 ? 1 : (n <= 4 ? 2 : (n <= 16 ? 4 : 8)); + palette_ok = n <= 256 && prof.bits <= 8; + if(w * h < n * 2) palette_ok = 0; /*don't add palette overhead if image has only a few pixels*/ + if(!prof.colored && prof.bits <= palettebits) palette_ok = 0; /*grey is less overhead*/ + + if(palette_ok) + { + unsigned char* p = prof.palette; + lodepng_palette_clear(mode_out); /*remove potential earlier palette*/ + for(i = 0; i != prof.numcolors; ++i) + { + error = lodepng_palette_add(mode_out, p[i * 4 + 0], p[i * 4 + 1], p[i * 4 + 2], p[i * 4 + 3]); + if(error) break; + } + + mode_out->colortype = LCT_PALETTE; + mode_out->bitdepth = palettebits; + + if(mode_in->colortype == LCT_PALETTE && mode_in->palettesize >= mode_out->palettesize + && mode_in->bitdepth == mode_out->bitdepth) + { + /*If input should have same palette colors, keep original to preserve its order and prevent conversion*/ + lodepng_color_mode_cleanup(mode_out); + lodepng_color_mode_copy(mode_out, mode_in); + } + } + else /*8-bit or 16-bit per channel*/ + { + mode_out->bitdepth = prof.bits; + mode_out->colortype = prof.alpha ? 
(prof.colored ? LCT_RGBA : LCT_GREY_ALPHA) + : (prof.colored ? LCT_RGB : LCT_GREY); + + if(prof.key) + { + unsigned mask = (1u << mode_out->bitdepth) - 1u; /*profile always uses 16-bit, mask converts it*/ + mode_out->key_r = prof.key_r & mask; + mode_out->key_g = prof.key_g & mask; + mode_out->key_b = prof.key_b & mask; + mode_out->key_defined = 1; + } + } + + return error; +} + +#endif /* #ifdef LODEPNG_COMPILE_ENCODER */ + +/* +Paeth predicter, used by PNG filter type 4 +The parameters are of type short, but should come from unsigned chars, the shorts +are only needed to make the paeth calculation correct. +*/ +static unsigned char paethPredictor(short a, short b, short c) +{ + short pa = abs(b - c); + short pb = abs(a - c); + short pc = abs(a + b - c - c); + + if(pc < pa && pc < pb) return (unsigned char)c; + else if(pb < pa) return (unsigned char)b; + else return (unsigned char)a; +} + +/*shared values used by multiple Adam7 related functions*/ + +static const unsigned ADAM7_IX[7] = { 0, 4, 0, 2, 0, 1, 0 }; /*x start values*/ +static const unsigned ADAM7_IY[7] = { 0, 0, 4, 0, 2, 0, 1 }; /*y start values*/ +static const unsigned ADAM7_DX[7] = { 8, 8, 4, 4, 2, 2, 1 }; /*x delta values*/ +static const unsigned ADAM7_DY[7] = { 8, 8, 8, 4, 4, 2, 2 }; /*y delta values*/ + +/* +Outputs various dimensions and positions in the image related to the Adam7 reduced images. +passw: output containing the width of the 7 passes +passh: output containing the height of the 7 passes +filter_passstart: output containing the index of the start and end of each + reduced image with filter bytes +padded_passstart output containing the index of the start and end of each + reduced image when without filter bytes but with padded scanlines +passstart: output containing the index of the start and end of each reduced + image without padding between scanlines, but still padding between the images +w, h: width and height of non-interlaced image +bpp: bits per pixel +"padded" is only relevant if bpp is less than 8 and a scanline or image does not + end at a full byte +*/ +static void Adam7_getpassvalues(unsigned passw[7], unsigned passh[7], size_t filter_passstart[8], + size_t padded_passstart[8], size_t passstart[8], unsigned w, unsigned h, unsigned bpp) +{ + /*the passstart values have 8 values: the 8th one indicates the byte after the end of the 7th (= last) pass*/ + unsigned i; + + /*calculate width and height in pixels of each pass*/ + for(i = 0; i != 7; ++i) + { + passw[i] = (w + ADAM7_DX[i] - ADAM7_IX[i] - 1) / ADAM7_DX[i]; + passh[i] = (h + ADAM7_DY[i] - ADAM7_IY[i] - 1) / ADAM7_DY[i]; + if(passw[i] == 0) passh[i] = 0; + if(passh[i] == 0) passw[i] = 0; + } + + filter_passstart[0] = padded_passstart[0] = passstart[0] = 0; + for(i = 0; i != 7; ++i) + { + /*if passw[i] is 0, it's 0 bytes, not 1 (no filtertype-byte)*/ + filter_passstart[i + 1] = filter_passstart[i] + + ((passw[i] && passh[i]) ? passh[i] * (1 + (passw[i] * bpp + 7) / 8) : 0); + /*bits padded if needed to fill full byte at end of each scanline*/ + padded_passstart[i + 1] = padded_passstart[i] + passh[i] * ((passw[i] * bpp + 7) / 8); + /*only padded at end of reduced image*/ + passstart[i + 1] = passstart[i] + (passh[i] * passw[i] * bpp + 7) / 8; + } +} + +#ifdef LODEPNG_COMPILE_DECODER + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / PNG Decoder / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +/*read the information from the header and store it in the LodePNGInfo. 
return value is error*/ +unsigned lodepng_inspect(unsigned* w, unsigned* h, LodePNGState* state, + const unsigned char* in, size_t insize) +{ + LodePNGInfo* info = &state->info_png; + if(insize == 0 || in == 0) + { + CERROR_RETURN_ERROR(state->error, 48); /*error: the given data is empty*/ + } + if(insize < 33) + { + CERROR_RETURN_ERROR(state->error, 27); /*error: the data length is smaller than the length of a PNG header*/ + } + + /*when decoding a new PNG image, make sure all parameters created after previous decoding are reset*/ + lodepng_info_cleanup(info); + lodepng_info_init(info); + + if(in[0] != 137 || in[1] != 80 || in[2] != 78 || in[3] != 71 + || in[4] != 13 || in[5] != 10 || in[6] != 26 || in[7] != 10) + { + CERROR_RETURN_ERROR(state->error, 28); /*error: the first 8 bytes are not the correct PNG signature*/ + } + if(lodepng_chunk_length(in + 8) != 13) + { + CERROR_RETURN_ERROR(state->error, 94); /*error: header size must be 13 bytes*/ + } + if(!lodepng_chunk_type_equals(in + 8, "IHDR")) + { + CERROR_RETURN_ERROR(state->error, 29); /*error: it doesn't start with a IHDR chunk!*/ + } + + /*read the values given in the header*/ + *w = lodepng_read32bitInt(&in[16]); + *h = lodepng_read32bitInt(&in[20]); + info->color.bitdepth = in[24]; + info->color.colortype = (LodePNGColorType)in[25]; + info->compression_method = in[26]; + info->filter_method = in[27]; + info->interlace_method = in[28]; + + if(*w == 0 || *h == 0) + { + CERROR_RETURN_ERROR(state->error, 93); + } + + if(!state->decoder.ignore_crc) + { + unsigned CRC = lodepng_read32bitInt(&in[29]); + unsigned checksum = lodepng_crc32(&in[12], 17); + if(CRC != checksum) + { + CERROR_RETURN_ERROR(state->error, 57); /*invalid CRC*/ + } + } + + /*error: only compression method 0 is allowed in the specification*/ + if(info->compression_method != 0) CERROR_RETURN_ERROR(state->error, 32); + /*error: only filter method 0 is allowed in the specification*/ + if(info->filter_method != 0) CERROR_RETURN_ERROR(state->error, 33); + /*error: only interlace methods 0 and 1 exist in the specification*/ + if(info->interlace_method > 1) CERROR_RETURN_ERROR(state->error, 34); + + state->error = checkColorValidity(info->color.colortype, info->color.bitdepth); + return state->error; +} + +static unsigned unfilterScanline(unsigned char* recon, const unsigned char* scanline, const unsigned char* precon, + size_t bytewidth, unsigned char filterType, size_t length) +{ + /* + For PNG filter method 0 + unfilter a PNG image scanline by scanline. when the pixels are smaller than 1 byte, + the filter works byte per byte (bytewidth = 1) + precon is the previous unfiltered scanline, recon the result, scanline the current one + the incoming scanlines do NOT include the filtertype byte, that one is given in the parameter filterType instead + recon and scanline MAY be the same memory address! precon must be disjoint. 
+ */ + + size_t i; + switch(filterType) + { + case 0: + for(i = 0; i != length; ++i) recon[i] = scanline[i]; + break; + case 1: + for(i = 0; i != bytewidth; ++i) recon[i] = scanline[i]; + for(i = bytewidth; i < length; ++i) recon[i] = scanline[i] + recon[i - bytewidth]; + break; + case 2: + if(precon) + { + for(i = 0; i != length; ++i) recon[i] = scanline[i] + precon[i]; + } + else + { + for(i = 0; i != length; ++i) recon[i] = scanline[i]; + } + break; + case 3: + if(precon) + { + for(i = 0; i != bytewidth; ++i) recon[i] = scanline[i] + (precon[i] >> 1); + for(i = bytewidth; i < length; ++i) recon[i] = scanline[i] + ((recon[i - bytewidth] + precon[i]) >> 1); + } + else + { + for(i = 0; i != bytewidth; ++i) recon[i] = scanline[i]; + for(i = bytewidth; i < length; ++i) recon[i] = scanline[i] + (recon[i - bytewidth] >> 1); + } + break; + case 4: + if(precon) + { + for(i = 0; i != bytewidth; ++i) + { + recon[i] = (scanline[i] + precon[i]); /*paethPredictor(0, precon[i], 0) is always precon[i]*/ + } + for(i = bytewidth; i < length; ++i) + { + recon[i] = (scanline[i] + paethPredictor(recon[i - bytewidth], precon[i], precon[i - bytewidth])); + } + } + else + { + for(i = 0; i != bytewidth; ++i) + { + recon[i] = scanline[i]; + } + for(i = bytewidth; i < length; ++i) + { + /*paethPredictor(recon[i - bytewidth], 0, 0) is always recon[i - bytewidth]*/ + recon[i] = (scanline[i] + recon[i - bytewidth]); + } + } + break; + default: return 36; /*error: unexisting filter type given*/ + } + return 0; +} + +static unsigned unfilter(unsigned char* out, const unsigned char* in, unsigned w, unsigned h, unsigned bpp) +{ + /* + For PNG filter method 0 + this function unfilters a single image (e.g. without interlacing this is called once, with Adam7 seven times) + out must have enough bytes allocated already, in must have the scanlines + 1 filtertype byte per scanline + w and h are image dimensions or dimensions of reduced image, bpp is bits per pixel + in and out are allowed to be the same memory address (but aren't the same size since in has the extra filter bytes) + */ + + unsigned y; + unsigned char* prevline = 0; + + /*bytewidth is used for filtering, is 1 when bpp < 8, number of bytes per pixel otherwise*/ + size_t bytewidth = (bpp + 7) / 8; + size_t linebytes = (w * bpp + 7) / 8; + + for(y = 0; y < h; ++y) + { + size_t outindex = linebytes * y; + size_t inindex = (1 + linebytes) * y; /*the extra filterbyte added to each row*/ + unsigned char filterType = in[inindex]; + + CERROR_TRY_RETURN(unfilterScanline(&out[outindex], &in[inindex + 1], prevline, bytewidth, filterType, linebytes)); + + prevline = &out[outindex]; + } + + return 0; +} + +/* +in: Adam7 interlaced image, with no padding bits between scanlines, but between + reduced images so that each reduced image starts at a byte. +out: the same pixels, but re-ordered so that they're now a non-interlaced image with size w*h +bpp: bits per pixel +out has the following size in bits: w * h * bpp. +in is possibly bigger due to padding bits between reduced images. 
+out must be big enough AND must be 0 everywhere if bpp < 8 in the current implementation +(because that's likely a little bit faster) +NOTE: comments about padding bits are only relevant if bpp < 8 +*/ +static void Adam7_deinterlace(unsigned char* out, const unsigned char* in, unsigned w, unsigned h, unsigned bpp) +{ + unsigned passw[7], passh[7]; + size_t filter_passstart[8], padded_passstart[8], passstart[8]; + unsigned i; + + Adam7_getpassvalues(passw, passh, filter_passstart, padded_passstart, passstart, w, h, bpp); + + if(bpp >= 8) + { + for(i = 0; i != 7; ++i) + { + unsigned x, y, b; + size_t bytewidth = bpp / 8; + for(y = 0; y < passh[i]; ++y) + for(x = 0; x < passw[i]; ++x) + { + size_t pixelinstart = passstart[i] + (y * passw[i] + x) * bytewidth; + size_t pixeloutstart = ((ADAM7_IY[i] + y * ADAM7_DY[i]) * w + ADAM7_IX[i] + x * ADAM7_DX[i]) * bytewidth; + for(b = 0; b < bytewidth; ++b) + { + out[pixeloutstart + b] = in[pixelinstart + b]; + } + } + } + } + else /*bpp < 8: Adam7 with pixels < 8 bit is a bit trickier: with bit pointers*/ + { + for(i = 0; i != 7; ++i) + { + unsigned x, y, b; + unsigned ilinebits = bpp * passw[i]; + unsigned olinebits = bpp * w; + size_t obp, ibp; /*bit pointers (for out and in buffer)*/ + for(y = 0; y < passh[i]; ++y) + for(x = 0; x < passw[i]; ++x) + { + ibp = (8 * passstart[i]) + (y * ilinebits + x * bpp); + obp = (ADAM7_IY[i] + y * ADAM7_DY[i]) * olinebits + (ADAM7_IX[i] + x * ADAM7_DX[i]) * bpp; + for(b = 0; b < bpp; ++b) + { + unsigned char bit = readBitFromReversedStream(&ibp, in); + /*note that this function assumes the out buffer is completely 0, use setBitOfReversedStream otherwise*/ + setBitOfReversedStream0(&obp, out, bit); + } + } + } + } +} + +static void removePaddingBits(unsigned char* out, const unsigned char* in, + size_t olinebits, size_t ilinebits, unsigned h) +{ + /* + After filtering there are still padding bits if scanlines have non multiple of 8 bit amounts. They need + to be removed (except at last scanline of (Adam7-reduced) image) before working with pure image buffers + for the Adam7 code, the color convert code and the output to the user. + in and out are allowed to be the same buffer, in may also be higher but still overlapping; in must + have >= ilinebits*h bits, out must have >= olinebits*h bits, olinebits must be <= ilinebits + also used to move bits after earlier such operations happened, e.g. in a sequence of reduced images from Adam7 + only useful if (ilinebits - olinebits) is a value in the range 1..7 + */ + unsigned y; + size_t diff = ilinebits - olinebits; + size_t ibp = 0, obp = 0; /*input and output bit pointers*/ + for(y = 0; y < h; ++y) + { + size_t x; + for(x = 0; x < olinebits; ++x) + { + unsigned char bit = readBitFromReversedStream(&ibp, in); + setBitOfReversedStream(&obp, out, bit); + } + ibp += diff; + } +} + +/*out must be buffer big enough to contain full image, and in must contain the full decompressed data from +the IDAT chunks (with filter index bytes and possible padding bits) +return value is error*/ +static unsigned postProcessScanlines(unsigned char* out, unsigned char* in, + unsigned w, unsigned h, const LodePNGInfo* info_png) +{ + /* + This function converts the filtered-padded-interlaced data into pure 2D image buffer with the PNG's colortype. + Steps: + *) if no Adam7: 1) unfilter 2) remove padding bits (= posible extra bits per scanline if bpp < 8) + *) if adam7: 1) 7x unfilter 2) 7x remove padding bits 3) Adam7_deinterlace + NOTE: the in buffer will be overwritten with intermediate data! 
+ */ + unsigned bpp = lodepng_get_bpp(&info_png->color); + if(bpp == 0) return 31; /*error: invalid colortype*/ + + if(info_png->interlace_method == 0) + { + if(bpp < 8 && w * bpp != ((w * bpp + 7) / 8) * 8) + { + CERROR_TRY_RETURN(unfilter(in, in, w, h, bpp)); + removePaddingBits(out, in, w * bpp, ((w * bpp + 7) / 8) * 8, h); + } + /*we can immediately filter into the out buffer, no other steps needed*/ + else CERROR_TRY_RETURN(unfilter(out, in, w, h, bpp)); + } + else /*interlace_method is 1 (Adam7)*/ + { + unsigned passw[7], passh[7]; size_t filter_passstart[8], padded_passstart[8], passstart[8]; + unsigned i; + + Adam7_getpassvalues(passw, passh, filter_passstart, padded_passstart, passstart, w, h, bpp); + + for(i = 0; i != 7; ++i) + { + CERROR_TRY_RETURN(unfilter(&in[padded_passstart[i]], &in[filter_passstart[i]], passw[i], passh[i], bpp)); + /*TODO: possible efficiency improvement: if in this reduced image the bits fit nicely in 1 scanline, + move bytes instead of bits or move not at all*/ + if(bpp < 8) + { + /*remove padding bits in scanlines; after this there still may be padding + bits between the different reduced images: each reduced image still starts nicely at a byte*/ + removePaddingBits(&in[passstart[i]], &in[padded_passstart[i]], passw[i] * bpp, + ((passw[i] * bpp + 7) / 8) * 8, passh[i]); + } + } + + Adam7_deinterlace(out, in, w, h, bpp); + } + + return 0; +} + +static unsigned readChunk_PLTE(LodePNGColorMode* color, const unsigned char* data, size_t chunkLength) +{ + unsigned pos = 0, i; + if(color->palette) lodepng_free(color->palette); + color->palettesize = chunkLength / 3; + color->palette = (unsigned char*)lodepng_malloc(4 * color->palettesize); + if(!color->palette && color->palettesize) + { + color->palettesize = 0; + return 83; /*alloc fail*/ + } + if(color->palettesize > 256) return 38; /*error: palette too big*/ + + for(i = 0; i != color->palettesize; ++i) + { + color->palette[4 * i + 0] = data[pos++]; /*R*/ + color->palette[4 * i + 1] = data[pos++]; /*G*/ + color->palette[4 * i + 2] = data[pos++]; /*B*/ + color->palette[4 * i + 3] = 255; /*alpha*/ + } + + return 0; /* OK */ +} + +static unsigned readChunk_tRNS(LodePNGColorMode* color, const unsigned char* data, size_t chunkLength) +{ + unsigned i; + if(color->colortype == LCT_PALETTE) + { + /*error: more alpha values given than there are palette entries*/ + if(chunkLength > color->palettesize) return 38; + + for(i = 0; i != chunkLength; ++i) color->palette[4 * i + 3] = data[i]; + } + else if(color->colortype == LCT_GREY) + { + /*error: this chunk must be 2 bytes for greyscale image*/ + if(chunkLength != 2) return 30; + + color->key_defined = 1; + color->key_r = color->key_g = color->key_b = 256u * data[0] + data[1]; + } + else if(color->colortype == LCT_RGB) + { + /*error: this chunk must be 6 bytes for RGB image*/ + if(chunkLength != 6) return 41; + + color->key_defined = 1; + color->key_r = 256u * data[0] + data[1]; + color->key_g = 256u * data[2] + data[3]; + color->key_b = 256u * data[4] + data[5]; + } + else return 42; /*error: tRNS chunk not allowed for other color models*/ + + return 0; /* OK */ +} + + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS +/*background color chunk (bKGD)*/ +static unsigned readChunk_bKGD(LodePNGInfo* info, const unsigned char* data, size_t chunkLength) +{ + if(info->color.colortype == LCT_PALETTE) + { + /*error: this chunk must be 1 byte for indexed color image*/ + if(chunkLength != 1) return 43; + + info->background_defined = 1; + info->background_r = info->background_g = 
info->background_b = data[0]; + } + else if(info->color.colortype == LCT_GREY || info->color.colortype == LCT_GREY_ALPHA) + { + /*error: this chunk must be 2 bytes for greyscale image*/ + if(chunkLength != 2) return 44; + + info->background_defined = 1; + info->background_r = info->background_g = info->background_b = 256u * data[0] + data[1]; + } + else if(info->color.colortype == LCT_RGB || info->color.colortype == LCT_RGBA) + { + /*error: this chunk must be 6 bytes for greyscale image*/ + if(chunkLength != 6) return 45; + + info->background_defined = 1; + info->background_r = 256u * data[0] + data[1]; + info->background_g = 256u * data[2] + data[3]; + info->background_b = 256u * data[4] + data[5]; + } + + return 0; /* OK */ +} + +/*text chunk (tEXt)*/ +static unsigned readChunk_tEXt(LodePNGInfo* info, const unsigned char* data, size_t chunkLength) +{ + unsigned error = 0; + char *key = 0, *str = 0; + unsigned i; + + while(!error) /*not really a while loop, only used to break on error*/ + { + unsigned length, string2_begin; + + length = 0; + while(length < chunkLength && data[length] != 0) ++length; + /*even though it's not allowed by the standard, no error is thrown if + there's no null termination char, if the text is empty*/ + if(length < 1 || length > 79) CERROR_BREAK(error, 89); /*keyword too short or long*/ + + key = (char*)lodepng_malloc(length + 1); + if(!key) CERROR_BREAK(error, 83); /*alloc fail*/ + + key[length] = 0; + for(i = 0; i != length; ++i) key[i] = (char)data[i]; + + string2_begin = length + 1; /*skip keyword null terminator*/ + + length = chunkLength < string2_begin ? 0 : chunkLength - string2_begin; + str = (char*)lodepng_malloc(length + 1); + if(!str) CERROR_BREAK(error, 83); /*alloc fail*/ + + str[length] = 0; + for(i = 0; i != length; ++i) str[i] = (char)data[string2_begin + i]; + + error = lodepng_add_text(info, key, str); + + break; + } + + lodepng_free(key); + lodepng_free(str); + + return error; +} + +/*compressed text chunk (zTXt)*/ +static unsigned readChunk_zTXt(LodePNGInfo* info, const LodePNGDecompressSettings* zlibsettings, + const unsigned char* data, size_t chunkLength) +{ + unsigned error = 0; + unsigned i; + + unsigned length, string2_begin; + char *key = 0; + ucvector decoded; + + ucvector_init(&decoded); + + while(!error) /*not really a while loop, only used to break on error*/ + { + for(length = 0; length < chunkLength && data[length] != 0; ++length) ; + if(length + 2 >= chunkLength) CERROR_BREAK(error, 75); /*no null termination, corrupt?*/ + if(length < 1 || length > 79) CERROR_BREAK(error, 89); /*keyword too short or long*/ + + key = (char*)lodepng_malloc(length + 1); + if(!key) CERROR_BREAK(error, 83); /*alloc fail*/ + + key[length] = 0; + for(i = 0; i != length; ++i) key[i] = (char)data[i]; + + if(data[length + 1] != 0) CERROR_BREAK(error, 72); /*the 0 byte indicating compression must be 0*/ + + string2_begin = length + 2; + if(string2_begin > chunkLength) CERROR_BREAK(error, 75); /*no null termination, corrupt?*/ + + length = chunkLength - string2_begin; + /*will fail if zlib error, e.g. 
if length is too small*/ + error = zlib_decompress(&decoded.data, &decoded.size, + (unsigned char*)(&data[string2_begin]), + length, zlibsettings); + if(error) break; + ucvector_push_back(&decoded, 0); + + error = lodepng_add_text(info, key, (char*)decoded.data); + + break; + } + + lodepng_free(key); + ucvector_cleanup(&decoded); + + return error; +} + +/*international text chunk (iTXt)*/ +static unsigned readChunk_iTXt(LodePNGInfo* info, const LodePNGDecompressSettings* zlibsettings, + const unsigned char* data, size_t chunkLength) +{ + unsigned error = 0; + unsigned i; + + unsigned length, begin, compressed; + char *key = 0, *langtag = 0, *transkey = 0; + ucvector decoded; + ucvector_init(&decoded); + + while(!error) /*not really a while loop, only used to break on error*/ + { + /*Quick check if the chunk length isn't too small. Even without check + it'd still fail with other error checks below if it's too short. This just gives a different error code.*/ + if(chunkLength < 5) CERROR_BREAK(error, 30); /*iTXt chunk too short*/ + + /*read the key*/ + for(length = 0; length < chunkLength && data[length] != 0; ++length) ; + if(length + 3 >= chunkLength) CERROR_BREAK(error, 75); /*no null termination char, corrupt?*/ + if(length < 1 || length > 79) CERROR_BREAK(error, 89); /*keyword too short or long*/ + + key = (char*)lodepng_malloc(length + 1); + if(!key) CERROR_BREAK(error, 83); /*alloc fail*/ + + key[length] = 0; + for(i = 0; i != length; ++i) key[i] = (char)data[i]; + + /*read the compression method*/ + compressed = data[length + 1]; + if(data[length + 2] != 0) CERROR_BREAK(error, 72); /*the 0 byte indicating compression must be 0*/ + + /*even though it's not allowed by the standard, no error is thrown if + there's no null termination char, if the text is empty for the next 3 texts*/ + + /*read the langtag*/ + begin = length + 3; + length = 0; + for(i = begin; i < chunkLength && data[i] != 0; ++i) ++length; + + langtag = (char*)lodepng_malloc(length + 1); + if(!langtag) CERROR_BREAK(error, 83); /*alloc fail*/ + + langtag[length] = 0; + for(i = 0; i != length; ++i) langtag[i] = (char)data[begin + i]; + + /*read the transkey*/ + begin += length + 1; + length = 0; + for(i = begin; i < chunkLength && data[i] != 0; ++i) ++length; + + transkey = (char*)lodepng_malloc(length + 1); + if(!transkey) CERROR_BREAK(error, 83); /*alloc fail*/ + + transkey[length] = 0; + for(i = 0; i != length; ++i) transkey[i] = (char)data[begin + i]; + + /*read the actual text*/ + begin += length + 1; + + length = chunkLength < begin ? 0 : chunkLength - begin; + + if(compressed) + { + /*will fail if zlib error, e.g. 
if length is too small*/ + error = zlib_decompress(&decoded.data, &decoded.size, + (unsigned char*)(&data[begin]), + length, zlibsettings); + if(error) break; + if(decoded.allocsize < decoded.size) decoded.allocsize = decoded.size; + ucvector_push_back(&decoded, 0); + } + else + { + if(!ucvector_resize(&decoded, length + 1)) CERROR_BREAK(error, 83 /*alloc fail*/); + + decoded.data[length] = 0; + for(i = 0; i != length; ++i) decoded.data[i] = data[begin + i]; + } + + error = lodepng_add_itext(info, key, langtag, transkey, (char*)decoded.data); + + break; + } + + lodepng_free(key); + lodepng_free(langtag); + lodepng_free(transkey); + ucvector_cleanup(&decoded); + + return error; +} + +static unsigned readChunk_tIME(LodePNGInfo* info, const unsigned char* data, size_t chunkLength) +{ + if(chunkLength != 7) return 73; /*invalid tIME chunk size*/ + + info->time_defined = 1; + info->time.year = 256u * data[0] + data[1]; + info->time.month = data[2]; + info->time.day = data[3]; + info->time.hour = data[4]; + info->time.minute = data[5]; + info->time.second = data[6]; + + return 0; /* OK */ +} + +static unsigned readChunk_pHYs(LodePNGInfo* info, const unsigned char* data, size_t chunkLength) +{ + if(chunkLength != 9) return 74; /*invalid pHYs chunk size*/ + + info->phys_defined = 1; + info->phys_x = 16777216u * data[0] + 65536u * data[1] + 256u * data[2] + data[3]; + info->phys_y = 16777216u * data[4] + 65536u * data[5] + 256u * data[6] + data[7]; + info->phys_unit = data[8]; + + return 0; /* OK */ +} +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +/*read a PNG, the result will be in the same color type as the PNG (hence "generic")*/ +static void decodeGeneric(unsigned char** out, unsigned* w, unsigned* h, + LodePNGState* state, + const unsigned char* in, size_t insize) +{ + unsigned char IEND = 0; + const unsigned char* chunk; + size_t i; + ucvector idat; /*the data from idat chunks*/ + ucvector scanlines; + size_t predict; + size_t numpixels; + size_t outsize = 0; + + /*for unknown chunk order*/ + unsigned unknown = 0; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + unsigned critical_pos = 1; /*1 = after IHDR, 2 = after PLTE, 3 = after IDAT*/ +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + + /*provide some proper output values if error will happen*/ + *out = 0; + + state->error = lodepng_inspect(w, h, state, in, insize); /*reads header and resets other parameters in state->info_png*/ + if(state->error) return; + + numpixels = *w * *h; + + /*multiplication overflow*/ + if(*h != 0 && numpixels / *h != *w) CERROR_RETURN(state->error, 92); + /*multiplication overflow possible further below. Allows up to 2^31-1 pixel + bytes with 16-bit RGBA, the rest is room for filter bytes.*/ + if(numpixels > 268435455) CERROR_RETURN(state->error, 92); + + ucvector_init(&idat); + chunk = &in[33]; /*first byte of the first chunk after the header*/ + + /*loop through the chunks, ignoring unknown chunks and stopping at IEND chunk. 
+ IDAT data is put at the start of the in buffer*/ + while(!IEND && !state->error) + { + unsigned chunkLength; + const unsigned char* data; /*the data in the chunk*/ + + /*error: size of the in buffer too small to contain next chunk*/ + if((size_t)((chunk - in) + 12) > insize || chunk < in) CERROR_BREAK(state->error, 30); + + /*length of the data of the chunk, excluding the length bytes, chunk type and CRC bytes*/ + chunkLength = lodepng_chunk_length(chunk); + /*error: chunk length larger than the max PNG chunk size*/ + if(chunkLength > 2147483647) CERROR_BREAK(state->error, 63); + + if((size_t)((chunk - in) + chunkLength + 12) > insize || (chunk + chunkLength + 12) < in) + { + CERROR_BREAK(state->error, 64); /*error: size of the in buffer too small to contain next chunk*/ + } + + data = lodepng_chunk_data_const(chunk); + + /*IDAT chunk, containing compressed image data*/ + if(lodepng_chunk_type_equals(chunk, "IDAT")) + { + size_t oldsize = idat.size; + if(!ucvector_resize(&idat, oldsize + chunkLength)) CERROR_BREAK(state->error, 83 /*alloc fail*/); + for(i = 0; i != chunkLength; ++i) idat.data[oldsize + i] = data[i]; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + critical_pos = 3; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + } + /*IEND chunk*/ + else if(lodepng_chunk_type_equals(chunk, "IEND")) + { + IEND = 1; + } + /*palette chunk (PLTE)*/ + else if(lodepng_chunk_type_equals(chunk, "PLTE")) + { + state->error = readChunk_PLTE(&state->info_png.color, data, chunkLength); + if(state->error) break; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + critical_pos = 2; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + } + /*palette transparency chunk (tRNS)*/ + else if(lodepng_chunk_type_equals(chunk, "tRNS")) + { + state->error = readChunk_tRNS(&state->info_png.color, data, chunkLength); + if(state->error) break; + } +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /*background color chunk (bKGD)*/ + else if(lodepng_chunk_type_equals(chunk, "bKGD")) + { + state->error = readChunk_bKGD(&state->info_png, data, chunkLength); + if(state->error) break; + } + /*text chunk (tEXt)*/ + else if(lodepng_chunk_type_equals(chunk, "tEXt")) + { + if(state->decoder.read_text_chunks) + { + state->error = readChunk_tEXt(&state->info_png, data, chunkLength); + if(state->error) break; + } + } + /*compressed text chunk (zTXt)*/ + else if(lodepng_chunk_type_equals(chunk, "zTXt")) + { + if(state->decoder.read_text_chunks) + { + state->error = readChunk_zTXt(&state->info_png, &state->decoder.zlibsettings, data, chunkLength); + if(state->error) break; + } + } + /*international text chunk (iTXt)*/ + else if(lodepng_chunk_type_equals(chunk, "iTXt")) + { + if(state->decoder.read_text_chunks) + { + state->error = readChunk_iTXt(&state->info_png, &state->decoder.zlibsettings, data, chunkLength); + if(state->error) break; + } + } + else if(lodepng_chunk_type_equals(chunk, "tIME")) + { + state->error = readChunk_tIME(&state->info_png, data, chunkLength); + if(state->error) break; + } + else if(lodepng_chunk_type_equals(chunk, "pHYs")) + { + state->error = readChunk_pHYs(&state->info_png, data, chunkLength); + if(state->error) break; + } +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + else /*it's not an implemented chunk type, so ignore it: skip over the data*/ + { + /*error: unknown critical chunk (5th bit of first byte of chunk type is 0)*/ + if(!lodepng_chunk_ancillary(chunk)) CERROR_BREAK(state->error, 69); + + unknown = 1; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + if(state->decoder.remember_unknown_chunks) + { + state->error = 
lodepng_chunk_append(&state->info_png.unknown_chunks_data[critical_pos - 1], + &state->info_png.unknown_chunks_size[critical_pos - 1], chunk); + if(state->error) break; + } +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + } + + if(!state->decoder.ignore_crc && !unknown) /*check CRC if wanted, only on known chunk types*/ + { + if(lodepng_chunk_check_crc(chunk)) CERROR_BREAK(state->error, 57); /*invalid CRC*/ + } + + if(!IEND) chunk = lodepng_chunk_next_const(chunk); + } + + ucvector_init(&scanlines); + /*predict output size, to allocate exact size for output buffer to avoid more dynamic allocation. + If the decompressed size does not match the prediction, the image must be corrupt.*/ + if(state->info_png.interlace_method == 0) + { + /*The extra *h is added because this are the filter bytes every scanline starts with*/ + predict = lodepng_get_raw_size_idat(*w, *h, &state->info_png.color) + *h; + } + else + { + /*Adam-7 interlaced: predicted size is the sum of the 7 sub-images sizes*/ + const LodePNGColorMode* color = &state->info_png.color; + predict = 0; + predict += lodepng_get_raw_size_idat((*w + 7) >> 3, (*h + 7) >> 3, color) + ((*h + 7) >> 3); + if(*w > 4) predict += lodepng_get_raw_size_idat((*w + 3) >> 3, (*h + 7) >> 3, color) + ((*h + 7) >> 3); + predict += lodepng_get_raw_size_idat((*w + 3) >> 2, (*h + 3) >> 3, color) + ((*h + 3) >> 3); + if(*w > 2) predict += lodepng_get_raw_size_idat((*w + 1) >> 2, (*h + 3) >> 2, color) + ((*h + 3) >> 2); + predict += lodepng_get_raw_size_idat((*w + 1) >> 1, (*h + 1) >> 2, color) + ((*h + 1) >> 2); + if(*w > 1) predict += lodepng_get_raw_size_idat((*w + 0) >> 1, (*h + 1) >> 1, color) + ((*h + 1) >> 1); + predict += lodepng_get_raw_size_idat((*w + 0), (*h + 0) >> 1, color) + ((*h + 0) >> 1); + } + if(!state->error && !ucvector_reserve(&scanlines, predict)) state->error = 83; /*alloc fail*/ + if(!state->error) + { + state->error = zlib_decompress(&scanlines.data, &scanlines.size, idat.data, + idat.size, &state->decoder.zlibsettings); + if(!state->error && scanlines.size != predict) state->error = 91; /*decompressed size doesn't match prediction*/ + } + ucvector_cleanup(&idat); + + if(!state->error) + { + outsize = lodepng_get_raw_size(*w, *h, &state->info_png.color); + *out = (unsigned char*)lodepng_malloc(outsize); + if(!*out) state->error = 83; /*alloc fail*/ + } + if(!state->error) + { + for(i = 0; i < outsize; i++) (*out)[i] = 0; + state->error = postProcessScanlines(*out, scanlines.data, *w, *h, &state->info_png); + } + ucvector_cleanup(&scanlines); +} + +unsigned lodepng_decode(unsigned char** out, unsigned* w, unsigned* h, + LodePNGState* state, + const unsigned char* in, size_t insize) +{ + *out = 0; + decodeGeneric(out, w, h, state, in, insize); + if(state->error) return state->error; + if(!state->decoder.color_convert || lodepng_color_mode_equal(&state->info_raw, &state->info_png.color)) + { + /*same color type, no copying or converting of data needed*/ + /*store the info_png color settings on the info_raw so that the info_raw still reflects what colortype + the raw image has to the end user*/ + if(!state->decoder.color_convert) + { + state->error = lodepng_color_mode_copy(&state->info_raw, &state->info_png.color); + if(state->error) return state->error; + } + } + else + { + /*color conversion needed; sort of copy of the data*/ + unsigned char* data = *out; + size_t outsize; + + /*TODO: check if this works according to the statement in the documentation: "The converter can convert + from greyscale input color type, to 8-bit greyscale or 
greyscale with alpha"*/ + if(!(state->info_raw.colortype == LCT_RGB || state->info_raw.colortype == LCT_RGBA) + && !(state->info_raw.bitdepth == 8)) + { + return 56; /*unsupported color mode conversion*/ + } + + outsize = lodepng_get_raw_size(*w, *h, &state->info_raw); + *out = (unsigned char*)lodepng_malloc(outsize); + if(!(*out)) + { + state->error = 83; /*alloc fail*/ + } + else state->error = lodepng_convert(*out, data, &state->info_raw, + &state->info_png.color, *w, *h); + lodepng_free(data); + } + return state->error; +} + +unsigned lodepng_decode_memory(unsigned char** out, unsigned* w, unsigned* h, const unsigned char* in, + size_t insize, LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned error; + LodePNGState state; + lodepng_state_init(&state); + state.info_raw.colortype = colortype; + state.info_raw.bitdepth = bitdepth; + error = lodepng_decode(out, w, h, &state, in, insize); + lodepng_state_cleanup(&state); + return error; +} + +unsigned lodepng_decode32(unsigned char** out, unsigned* w, unsigned* h, const unsigned char* in, size_t insize) +{ + return lodepng_decode_memory(out, w, h, in, insize, LCT_RGBA, 8); +} + +unsigned lodepng_decode24(unsigned char** out, unsigned* w, unsigned* h, const unsigned char* in, size_t insize) +{ + return lodepng_decode_memory(out, w, h, in, insize, LCT_RGB, 8); +} + +#ifdef LODEPNG_COMPILE_DISK +unsigned lodepng_decode_file(unsigned char** out, unsigned* w, unsigned* h, const char* filename, + LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned char* buffer = 0; + size_t buffersize; + unsigned error; + error = lodepng_load_file(&buffer, &buffersize, filename); + if(!error) error = lodepng_decode_memory(out, w, h, buffer, buffersize, colortype, bitdepth); + lodepng_free(buffer); + return error; +} + +unsigned lodepng_decode32_file(unsigned char** out, unsigned* w, unsigned* h, const char* filename) +{ + return lodepng_decode_file(out, w, h, filename, LCT_RGBA, 8); +} + +unsigned lodepng_decode24_file(unsigned char** out, unsigned* w, unsigned* h, const char* filename) +{ + return lodepng_decode_file(out, w, h, filename, LCT_RGB, 8); +} +#endif /*LODEPNG_COMPILE_DISK*/ + +void lodepng_decoder_settings_init(LodePNGDecoderSettings* settings) +{ + settings->color_convert = 1; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + settings->read_text_chunks = 1; + settings->remember_unknown_chunks = 0; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + settings->ignore_crc = 0; + lodepng_decompress_settings_init(&settings->zlibsettings); +} + +#endif /*LODEPNG_COMPILE_DECODER*/ + +#if defined(LODEPNG_COMPILE_DECODER) || defined(LODEPNG_COMPILE_ENCODER) + +void lodepng_state_init(LodePNGState* state) +{ +#ifdef LODEPNG_COMPILE_DECODER + lodepng_decoder_settings_init(&state->decoder); +#endif /*LODEPNG_COMPILE_DECODER*/ +#ifdef LODEPNG_COMPILE_ENCODER + lodepng_encoder_settings_init(&state->encoder); +#endif /*LODEPNG_COMPILE_ENCODER*/ + lodepng_color_mode_init(&state->info_raw); + lodepng_info_init(&state->info_png); + state->error = 1; +} + +void lodepng_state_cleanup(LodePNGState* state) +{ + lodepng_color_mode_cleanup(&state->info_raw); + lodepng_info_cleanup(&state->info_png); +} + +void lodepng_state_copy(LodePNGState* dest, const LodePNGState* source) +{ + lodepng_state_cleanup(dest); + *dest = *source; + lodepng_color_mode_init(&dest->info_raw); + lodepng_info_init(&dest->info_png); + dest->error = lodepng_color_mode_copy(&dest->info_raw, &source->info_raw); if(dest->error) return; + dest->error = lodepng_info_copy(&dest->info_png, 
&source->info_png); if(dest->error) return; +} + +#endif /* defined(LODEPNG_COMPILE_DECODER) || defined(LODEPNG_COMPILE_ENCODER) */ + +#ifdef LODEPNG_COMPILE_ENCODER + +/* ////////////////////////////////////////////////////////////////////////// */ +/* / PNG Encoder / */ +/* ////////////////////////////////////////////////////////////////////////// */ + +/*chunkName must be string of 4 characters*/ +static unsigned addChunk(ucvector* out, const char* chunkName, const unsigned char* data, size_t length) +{ + CERROR_TRY_RETURN(lodepng_chunk_create(&out->data, &out->size, (unsigned)length, chunkName, data)); + out->allocsize = out->size; /*fix the allocsize again*/ + return 0; +} + +static void writeSignature(ucvector* out) +{ + /*8 bytes PNG signature, aka the magic bytes*/ + ucvector_push_back(out, 137); + ucvector_push_back(out, 80); + ucvector_push_back(out, 78); + ucvector_push_back(out, 71); + ucvector_push_back(out, 13); + ucvector_push_back(out, 10); + ucvector_push_back(out, 26); + ucvector_push_back(out, 10); +} + +static unsigned addChunk_IHDR(ucvector* out, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth, unsigned interlace_method) +{ + unsigned error = 0; + ucvector header; + ucvector_init(&header); + + lodepng_add32bitInt(&header, w); /*width*/ + lodepng_add32bitInt(&header, h); /*height*/ + ucvector_push_back(&header, (unsigned char)bitdepth); /*bit depth*/ + ucvector_push_back(&header, (unsigned char)colortype); /*color type*/ + ucvector_push_back(&header, 0); /*compression method*/ + ucvector_push_back(&header, 0); /*filter method*/ + ucvector_push_back(&header, interlace_method); /*interlace method*/ + + error = addChunk(out, "IHDR", header.data, header.size); + ucvector_cleanup(&header); + + return error; +} + +static unsigned addChunk_PLTE(ucvector* out, const LodePNGColorMode* info) +{ + unsigned error = 0; + size_t i; + ucvector PLTE; + ucvector_init(&PLTE); + for(i = 0; i != info->palettesize * 4; ++i) + { + /*add all channels except alpha channel*/ + if(i % 4 != 3) ucvector_push_back(&PLTE, info->palette[i]); + } + error = addChunk(out, "PLTE", PLTE.data, PLTE.size); + ucvector_cleanup(&PLTE); + + return error; +} + +static unsigned addChunk_tRNS(ucvector* out, const LodePNGColorMode* info) +{ + unsigned error = 0; + size_t i; + ucvector tRNS; + ucvector_init(&tRNS); + if(info->colortype == LCT_PALETTE) + { + size_t amount = info->palettesize; + /*the tail of palette values that all have 255 as alpha, does not have to be encoded*/ + for(i = info->palettesize; i != 0; --i) + { + if(info->palette[4 * (i - 1) + 3] == 255) --amount; + else break; + } + /*add only alpha channel*/ + for(i = 0; i != amount; ++i) ucvector_push_back(&tRNS, info->palette[4 * i + 3]); + } + else if(info->colortype == LCT_GREY) + { + if(info->key_defined) + { + ucvector_push_back(&tRNS, (unsigned char)(info->key_r >> 8)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_r & 255)); + } + } + else if(info->colortype == LCT_RGB) + { + if(info->key_defined) + { + ucvector_push_back(&tRNS, (unsigned char)(info->key_r >> 8)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_r & 255)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_g >> 8)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_g & 255)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_b >> 8)); + ucvector_push_back(&tRNS, (unsigned char)(info->key_b & 255)); + } + } + + error = addChunk(out, "tRNS", tRNS.data, tRNS.size); + ucvector_cleanup(&tRNS); + + return error; +} + +static unsigned 
addChunk_IDAT(ucvector* out, const unsigned char* data, size_t datasize, + LodePNGCompressSettings* zlibsettings) +{ + ucvector zlibdata; + unsigned error = 0; + + /*compress with the Zlib compressor*/ + ucvector_init(&zlibdata); + error = zlib_compress(&zlibdata.data, &zlibdata.size, data, datasize, zlibsettings); + if(!error) error = addChunk(out, "IDAT", zlibdata.data, zlibdata.size); + ucvector_cleanup(&zlibdata); + + return error; +} + +static unsigned addChunk_IEND(ucvector* out) +{ + unsigned error = 0; + error = addChunk(out, "IEND", 0, 0); + return error; +} + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + +static unsigned addChunk_tEXt(ucvector* out, const char* keyword, const char* textstring) +{ + unsigned error = 0; + size_t i; + ucvector text; + ucvector_init(&text); + for(i = 0; keyword[i] != 0; ++i) ucvector_push_back(&text, (unsigned char)keyword[i]); + if(i < 1 || i > 79) return 89; /*error: invalid keyword size*/ + ucvector_push_back(&text, 0); /*0 termination char*/ + for(i = 0; textstring[i] != 0; ++i) ucvector_push_back(&text, (unsigned char)textstring[i]); + error = addChunk(out, "tEXt", text.data, text.size); + ucvector_cleanup(&text); + + return error; +} + +static unsigned addChunk_zTXt(ucvector* out, const char* keyword, const char* textstring, + LodePNGCompressSettings* zlibsettings) +{ + unsigned error = 0; + ucvector data, compressed; + size_t i, textsize = strlen(textstring); + + ucvector_init(&data); + ucvector_init(&compressed); + for(i = 0; keyword[i] != 0; ++i) ucvector_push_back(&data, (unsigned char)keyword[i]); + if(i < 1 || i > 79) return 89; /*error: invalid keyword size*/ + ucvector_push_back(&data, 0); /*0 termination char*/ + ucvector_push_back(&data, 0); /*compression method: 0*/ + + error = zlib_compress(&compressed.data, &compressed.size, + (unsigned char*)textstring, textsize, zlibsettings); + if(!error) + { + for(i = 0; i != compressed.size; ++i) ucvector_push_back(&data, compressed.data[i]); + error = addChunk(out, "zTXt", data.data, data.size); + } + + ucvector_cleanup(&compressed); + ucvector_cleanup(&data); + return error; +} + +static unsigned addChunk_iTXt(ucvector* out, unsigned compressed, const char* keyword, const char* langtag, + const char* transkey, const char* textstring, LodePNGCompressSettings* zlibsettings) +{ + unsigned error = 0; + ucvector data; + size_t i, textsize = strlen(textstring); + + ucvector_init(&data); + + for(i = 0; keyword[i] != 0; ++i) ucvector_push_back(&data, (unsigned char)keyword[i]); + if(i < 1 || i > 79) return 89; /*error: invalid keyword size*/ + ucvector_push_back(&data, 0); /*null termination char*/ + ucvector_push_back(&data, compressed ? 
1 : 0); /*compression flag*/ + ucvector_push_back(&data, 0); /*compression method*/ + for(i = 0; langtag[i] != 0; ++i) ucvector_push_back(&data, (unsigned char)langtag[i]); + ucvector_push_back(&data, 0); /*null termination char*/ + for(i = 0; transkey[i] != 0; ++i) ucvector_push_back(&data, (unsigned char)transkey[i]); + ucvector_push_back(&data, 0); /*null termination char*/ + + if(compressed) + { + ucvector compressed_data; + ucvector_init(&compressed_data); + error = zlib_compress(&compressed_data.data, &compressed_data.size, + (unsigned char*)textstring, textsize, zlibsettings); + if(!error) + { + for(i = 0; i != compressed_data.size; ++i) ucvector_push_back(&data, compressed_data.data[i]); + } + ucvector_cleanup(&compressed_data); + } + else /*not compressed*/ + { + for(i = 0; textstring[i] != 0; ++i) ucvector_push_back(&data, (unsigned char)textstring[i]); + } + + if(!error) error = addChunk(out, "iTXt", data.data, data.size); + ucvector_cleanup(&data); + return error; +} + +static unsigned addChunk_bKGD(ucvector* out, const LodePNGInfo* info) +{ + unsigned error = 0; + ucvector bKGD; + ucvector_init(&bKGD); + if(info->color.colortype == LCT_GREY || info->color.colortype == LCT_GREY_ALPHA) + { + ucvector_push_back(&bKGD, (unsigned char)(info->background_r >> 8)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_r & 255)); + } + else if(info->color.colortype == LCT_RGB || info->color.colortype == LCT_RGBA) + { + ucvector_push_back(&bKGD, (unsigned char)(info->background_r >> 8)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_r & 255)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_g >> 8)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_g & 255)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_b >> 8)); + ucvector_push_back(&bKGD, (unsigned char)(info->background_b & 255)); + } + else if(info->color.colortype == LCT_PALETTE) + { + ucvector_push_back(&bKGD, (unsigned char)(info->background_r & 255)); /*palette index*/ + } + + error = addChunk(out, "bKGD", bKGD.data, bKGD.size); + ucvector_cleanup(&bKGD); + + return error; +} + +static unsigned addChunk_tIME(ucvector* out, const LodePNGTime* time) +{ + unsigned error = 0; + unsigned char* data = (unsigned char*)lodepng_malloc(7); + if(!data) return 83; /*alloc fail*/ + data[0] = (unsigned char)(time->year >> 8); + data[1] = (unsigned char)(time->year & 255); + data[2] = (unsigned char)time->month; + data[3] = (unsigned char)time->day; + data[4] = (unsigned char)time->hour; + data[5] = (unsigned char)time->minute; + data[6] = (unsigned char)time->second; + error = addChunk(out, "tIME", data, 7); + lodepng_free(data); + return error; +} + +static unsigned addChunk_pHYs(ucvector* out, const LodePNGInfo* info) +{ + unsigned error = 0; + ucvector data; + ucvector_init(&data); + + lodepng_add32bitInt(&data, info->phys_x); + lodepng_add32bitInt(&data, info->phys_y); + ucvector_push_back(&data, info->phys_unit); + + error = addChunk(out, "pHYs", data.data, data.size); + ucvector_cleanup(&data); + + return error; +} + +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +static void filterScanline(unsigned char* out, const unsigned char* scanline, const unsigned char* prevline, + size_t length, size_t bytewidth, unsigned char filterType) +{ + size_t i; + switch(filterType) + { + case 0: /*None*/ + for(i = 0; i != length; ++i) out[i] = scanline[i]; + break; + case 1: /*Sub*/ + for(i = 0; i != bytewidth; ++i) out[i] = scanline[i]; + for(i = bytewidth; i < length; ++i) out[i] = 
scanline[i] - scanline[i - bytewidth]; + break; + case 2: /*Up*/ + if(prevline) + { + for(i = 0; i != length; ++i) out[i] = scanline[i] - prevline[i]; + } + else + { + for(i = 0; i != length; ++i) out[i] = scanline[i]; + } + break; + case 3: /*Average*/ + if(prevline) + { + for(i = 0; i != bytewidth; ++i) out[i] = scanline[i] - (prevline[i] >> 1); + for(i = bytewidth; i < length; ++i) out[i] = scanline[i] - ((scanline[i - bytewidth] + prevline[i]) >> 1); + } + else + { + for(i = 0; i != bytewidth; ++i) out[i] = scanline[i]; + for(i = bytewidth; i < length; ++i) out[i] = scanline[i] - (scanline[i - bytewidth] >> 1); + } + break; + case 4: /*Paeth*/ + if(prevline) + { + /*paethPredictor(0, prevline[i], 0) is always prevline[i]*/ + for(i = 0; i != bytewidth; ++i) out[i] = (scanline[i] - prevline[i]); + for(i = bytewidth; i < length; ++i) + { + out[i] = (scanline[i] - paethPredictor(scanline[i - bytewidth], prevline[i], prevline[i - bytewidth])); + } + } + else + { + for(i = 0; i != bytewidth; ++i) out[i] = scanline[i]; + /*paethPredictor(scanline[i - bytewidth], 0, 0) is always scanline[i - bytewidth]*/ + for(i = bytewidth; i < length; ++i) out[i] = (scanline[i] - scanline[i - bytewidth]); + } + break; + default: return; /*unexisting filter type given*/ + } +} + +/* log2 approximation. A slight bit faster than std::log. */ +static float flog2(float f) +{ + float result = 0; + while(f > 32) { result += 4; f /= 16; } + while(f > 2) { ++result; f /= 2; } + return result + 1.442695f * (f * f * f / 3 - 3 * f * f / 2 + 3 * f - 1.83333f); +} + +static unsigned filter(unsigned char* out, const unsigned char* in, unsigned w, unsigned h, + const LodePNGColorMode* info, const LodePNGEncoderSettings* settings) +{ + /* + For PNG filter method 0 + out must be a buffer with as size: h + (w * h * bpp + 7) / 8, because there are + the scanlines with 1 extra byte per scanline + */ + + unsigned bpp = lodepng_get_bpp(info); + /*the width of a scanline in bytes, not including the filter type*/ + size_t linebytes = (w * bpp + 7) / 8; + /*bytewidth is used for filtering, is 1 when bpp < 8, number of bytes per pixel otherwise*/ + size_t bytewidth = (bpp + 7) / 8; + const unsigned char* prevline = 0; + unsigned x, y; + unsigned error = 0; + LodePNGFilterStrategy strategy = settings->filter_strategy; + + /* + There is a heuristic called the minimum sum of absolute differences heuristic, suggested by the PNG standard: + * If the image type is Palette, or the bit depth is smaller than 8, then do not filter the image (i.e. + use fixed filtering, with the filter None). + * (The other case) If the image type is Grayscale or RGB (with or without Alpha), and the bit depth is + not smaller than 8, then use adaptive filtering heuristic as follows: independently for each row, apply + all five filters and select the filter that produces the smallest sum of absolute values per row. + This heuristic is used if filter strategy is LFS_MINSUM and filter_palette_zero is true. + + If filter_palette_zero is true and filter_strategy is not LFS_MINSUM, the above heuristic is followed, + but for "the other case", whatever strategy filter_strategy is set to instead of the minimum sum + heuristic is used. 
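+ The strategies handled below are LFS_ZERO (always filter type 0), LFS_MINSUM (the heuristic above), LFS_ENTROPY (pick the type giving the lowest Shannon entropy per row), LFS_PREDEFINED (use settings->predefined_filters) and LFS_BRUTE_FORCE (deflate every attempt and keep the smallest result).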
+ */ + if(settings->filter_palette_zero && + (info->colortype == LCT_PALETTE || info->bitdepth < 8)) strategy = LFS_ZERO; + + if(bpp == 0) return 31; /*error: invalid color type*/ + + if(strategy == LFS_ZERO) + { + for(y = 0; y != h; ++y) + { + size_t outindex = (1 + linebytes) * y; /*the extra filterbyte added to each row*/ + size_t inindex = linebytes * y; + out[outindex] = 0; /*filter type byte*/ + filterScanline(&out[outindex + 1], &in[inindex], prevline, linebytes, bytewidth, 0); + prevline = &in[inindex]; + } + } + else if(strategy == LFS_MINSUM) + { + /*adaptive filtering*/ + size_t sum[5]; + unsigned char* attempt[5]; /*five filtering attempts, one for each filter type*/ + size_t smallest = 0; + unsigned char type, bestType = 0; + + for(type = 0; type != 5; ++type) + { + attempt[type] = (unsigned char*)lodepng_malloc(linebytes); + if(!attempt[type]) return 83; /*alloc fail*/ + } + + if(!error) + { + for(y = 0; y != h; ++y) + { + /*try the 5 filter types*/ + for(type = 0; type != 5; ++type) + { + filterScanline(attempt[type], &in[y * linebytes], prevline, linebytes, bytewidth, type); + + /*calculate the sum of the result*/ + sum[type] = 0; + if(type == 0) + { + for(x = 0; x != linebytes; ++x) sum[type] += (unsigned char)(attempt[type][x]); + } + else + { + for(x = 0; x != linebytes; ++x) + { + /*For differences, each byte should be treated as signed, values above 127 are negative + (converted to signed char). Filtertype 0 isn't a difference though, so use unsigned there. + This means filtertype 0 is almost never chosen, but that is justified.*/ + unsigned char s = attempt[type][x]; + sum[type] += s < 128 ? s : (255U - s); + } + } + + /*check if this is smallest sum (or if type == 0 it's the first case so always store the values)*/ + if(type == 0 || sum[type] < smallest) + { + bestType = type; + smallest = sum[type]; + } + } + + prevline = &in[y * linebytes]; + + /*now fill the out values*/ + out[y * (linebytes + 1)] = bestType; /*the first byte of a scanline will be the filter type*/ + for(x = 0; x != linebytes; ++x) out[y * (linebytes + 1) + 1 + x] = attempt[bestType][x]; + } + } + + for(type = 0; type != 5; ++type) lodepng_free(attempt[type]); + } + else if(strategy == LFS_ENTROPY) + { + float sum[5]; + unsigned char* attempt[5]; /*five filtering attempts, one for each filter type*/ + float smallest = 0; + unsigned type, bestType = 0; + unsigned count[256]; + + for(type = 0; type != 5; ++type) + { + attempt[type] = (unsigned char*)lodepng_malloc(linebytes); + if(!attempt[type]) return 83; /*alloc fail*/ + } + + for(y = 0; y != h; ++y) + { + /*try the 5 filter types*/ + for(type = 0; type != 5; ++type) + { + filterScanline(attempt[type], &in[y * linebytes], prevline, linebytes, bytewidth, type); + for(x = 0; x != 256; ++x) count[x] = 0; + for(x = 0; x != linebytes; ++x) ++count[attempt[type][x]]; + ++count[type]; /*the filter type itself is part of the scanline*/ + sum[type] = 0; + for(x = 0; x != 256; ++x) + { + float p = count[x] / (float)(linebytes + 1); + sum[type] += count[x] == 0 ? 
0 : flog2(1 / p) * p; + } + /*check if this is smallest sum (or if type == 0 it's the first case so always store the values)*/ + if(type == 0 || sum[type] < smallest) + { + bestType = type; + smallest = sum[type]; + } + } + + prevline = &in[y * linebytes]; + + /*now fill the out values*/ + out[y * (linebytes + 1)] = bestType; /*the first byte of a scanline will be the filter type*/ + for(x = 0; x != linebytes; ++x) out[y * (linebytes + 1) + 1 + x] = attempt[bestType][x]; + } + + for(type = 0; type != 5; ++type) lodepng_free(attempt[type]); + } + else if(strategy == LFS_PREDEFINED) + { + for(y = 0; y != h; ++y) + { + size_t outindex = (1 + linebytes) * y; /*the extra filterbyte added to each row*/ + size_t inindex = linebytes * y; + unsigned char type = settings->predefined_filters[y]; + out[outindex] = type; /*filter type byte*/ + filterScanline(&out[outindex + 1], &in[inindex], prevline, linebytes, bytewidth, type); + prevline = &in[inindex]; + } + } + else if(strategy == LFS_BRUTE_FORCE) + { + /*brute force filter chooser. + deflate the scanline after every filter attempt to see which one deflates best. + This is very slow and gives only slightly smaller, sometimes even larger, result*/ + size_t size[5]; + unsigned char* attempt[5]; /*five filtering attempts, one for each filter type*/ + size_t smallest = 0; + unsigned type = 0, bestType = 0; + unsigned char* dummy; + LodePNGCompressSettings zlibsettings = settings->zlibsettings; + /*use fixed tree on the attempts so that the tree is not adapted to the filtertype on purpose, + to simulate the true case where the tree is the same for the whole image. Sometimes it gives + better result with dynamic tree anyway. Using the fixed tree sometimes gives worse, but in rare + cases better compression. It does make this a bit less slow, so it's worth doing this.*/ + zlibsettings.btype = 1; + /*a custom encoder likely doesn't read the btype setting and is optimized for complete PNG + images only, so disable it*/ + zlibsettings.custom_zlib = 0; + zlibsettings.custom_deflate = 0; + for(type = 0; type != 5; ++type) + { + attempt[type] = (unsigned char*)lodepng_malloc(linebytes); + if(!attempt[type]) return 83; /*alloc fail*/ + } + for(y = 0; y != h; ++y) /*try the 5 filter types*/ + { + for(type = 0; type != 5; ++type) + { + unsigned testsize = linebytes; + /*if(testsize > 8) testsize /= 8;*/ /*it already works good enough by testing a part of the row*/ + + filterScanline(attempt[type], &in[y * linebytes], prevline, linebytes, bytewidth, type); + size[type] = 0; + dummy = 0; + zlib_compress(&dummy, &size[type], attempt[type], testsize, &zlibsettings); + lodepng_free(dummy); + /*check if this is smallest size (or if type == 0 it's the first case so always store the values)*/ + if(type == 0 || size[type] < smallest) + { + bestType = type; + smallest = size[type]; + } + } + prevline = &in[y * linebytes]; + out[y * (linebytes + 1)] = bestType; /*the first byte of a scanline will be the filter type*/ + for(x = 0; x != linebytes; ++x) out[y * (linebytes + 1) + 1 + x] = attempt[bestType][x]; + } + for(type = 0; type != 5; ++type) lodepng_free(attempt[type]); + } + else return 88; /* unknown filter strategy */ + + return error; +} + +static void addPaddingBits(unsigned char* out, const unsigned char* in, + size_t olinebits, size_t ilinebits, unsigned h) +{ + /*The opposite of the removePaddingBits function + olinebits must be >= ilinebits*/ + unsigned y; + size_t diff = olinebits - ilinebits; + size_t obp = 0, ibp = 0; /*bit pointers*/ + for(y = 0; y != h; 
++y) + { + size_t x; + for(x = 0; x < ilinebits; ++x) + { + unsigned char bit = readBitFromReversedStream(&ibp, in); + setBitOfReversedStream(&obp, out, bit); + } + /*obp += diff; --> no, fill in some value in the padding bits too, to avoid + "Use of uninitialised value of size ###" warning from valgrind*/ + for(x = 0; x != diff; ++x) setBitOfReversedStream(&obp, out, 0); + } +} + +/* +in: non-interlaced image with size w*h +out: the same pixels, but re-ordered according to PNG's Adam7 interlacing, with + no padding bits between scanlines, but between reduced images so that each + reduced image starts at a byte. +bpp: bits per pixel +there are no padding bits, not between scanlines, not between reduced images +in has the following size in bits: w * h * bpp. +out is possibly bigger due to padding bits between reduced images +NOTE: comments about padding bits are only relevant if bpp < 8 +*/ +static void Adam7_interlace(unsigned char* out, const unsigned char* in, unsigned w, unsigned h, unsigned bpp) +{ + unsigned passw[7], passh[7]; + size_t filter_passstart[8], padded_passstart[8], passstart[8]; + unsigned i; + + Adam7_getpassvalues(passw, passh, filter_passstart, padded_passstart, passstart, w, h, bpp); + + if(bpp >= 8) + { + for(i = 0; i != 7; ++i) + { + unsigned x, y, b; + size_t bytewidth = bpp / 8; + for(y = 0; y < passh[i]; ++y) + for(x = 0; x < passw[i]; ++x) + { + size_t pixelinstart = ((ADAM7_IY[i] + y * ADAM7_DY[i]) * w + ADAM7_IX[i] + x * ADAM7_DX[i]) * bytewidth; + size_t pixeloutstart = passstart[i] + (y * passw[i] + x) * bytewidth; + for(b = 0; b < bytewidth; ++b) + { + out[pixeloutstart + b] = in[pixelinstart + b]; + } + } + } + } + else /*bpp < 8: Adam7 with pixels < 8 bit is a bit trickier: with bit pointers*/ + { + for(i = 0; i != 7; ++i) + { + unsigned x, y, b; + unsigned ilinebits = bpp * passw[i]; + unsigned olinebits = bpp * w; + size_t obp, ibp; /*bit pointers (for out and in buffer)*/ + for(y = 0; y < passh[i]; ++y) + for(x = 0; x < passw[i]; ++x) + { + ibp = (ADAM7_IY[i] + y * ADAM7_DY[i]) * olinebits + (ADAM7_IX[i] + x * ADAM7_DX[i]) * bpp; + obp = (8 * passstart[i]) + (y * ilinebits + x * bpp); + for(b = 0; b < bpp; ++b) + { + unsigned char bit = readBitFromReversedStream(&ibp, in); + setBitOfReversedStream(&obp, out, bit); + } + } + } + } +} + +/*out must be buffer big enough to contain uncompressed IDAT chunk data, and in must contain the full image. +return value is error**/ +static unsigned preProcessScanlines(unsigned char** out, size_t* outsize, const unsigned char* in, + unsigned w, unsigned h, + const LodePNGInfo* info_png, const LodePNGEncoderSettings* settings) +{ + /* + This function converts the pure 2D image with the PNG's colortype, into filtered-padded-interlaced data. 
Steps: + *) if no Adam7: 1) add padding bits (= posible extra bits per scanline if bpp < 8) 2) filter + *) if adam7: 1) Adam7_interlace 2) 7x add padding bits 3) 7x filter + */ + unsigned bpp = lodepng_get_bpp(&info_png->color); + unsigned error = 0; + + if(info_png->interlace_method == 0) + { + *outsize = h + (h * ((w * bpp + 7) / 8)); /*image size plus an extra byte per scanline + possible padding bits*/ + *out = (unsigned char*)lodepng_malloc(*outsize); + if(!(*out) && (*outsize)) error = 83; /*alloc fail*/ + + if(!error) + { + /*non multiple of 8 bits per scanline, padding bits needed per scanline*/ + if(bpp < 8 && w * bpp != ((w * bpp + 7) / 8) * 8) + { + unsigned char* padded = (unsigned char*)lodepng_malloc(h * ((w * bpp + 7) / 8)); + if(!padded) error = 83; /*alloc fail*/ + if(!error) + { + addPaddingBits(padded, in, ((w * bpp + 7) / 8) * 8, w * bpp, h); + error = filter(*out, padded, w, h, &info_png->color, settings); + } + lodepng_free(padded); + } + else + { + /*we can immediately filter into the out buffer, no other steps needed*/ + error = filter(*out, in, w, h, &info_png->color, settings); + } + } + } + else /*interlace_method is 1 (Adam7)*/ + { + unsigned passw[7], passh[7]; + size_t filter_passstart[8], padded_passstart[8], passstart[8]; + unsigned char* adam7; + + Adam7_getpassvalues(passw, passh, filter_passstart, padded_passstart, passstart, w, h, bpp); + + *outsize = filter_passstart[7]; /*image size plus an extra byte per scanline + possible padding bits*/ + *out = (unsigned char*)lodepng_malloc(*outsize); + if(!(*out)) error = 83; /*alloc fail*/ + + adam7 = (unsigned char*)lodepng_malloc(passstart[7]); + if(!adam7 && passstart[7]) error = 83; /*alloc fail*/ + + if(!error) + { + unsigned i; + + Adam7_interlace(adam7, in, w, h, bpp); + for(i = 0; i != 7; ++i) + { + if(bpp < 8) + { + unsigned char* padded = (unsigned char*)lodepng_malloc(padded_passstart[i + 1] - padded_passstart[i]); + if(!padded) ERROR_BREAK(83); /*alloc fail*/ + addPaddingBits(padded, &adam7[passstart[i]], + ((passw[i] * bpp + 7) / 8) * 8, passw[i] * bpp, passh[i]); + error = filter(&(*out)[filter_passstart[i]], padded, + passw[i], passh[i], &info_png->color, settings); + lodepng_free(padded); + } + else + { + error = filter(&(*out)[filter_passstart[i]], &adam7[padded_passstart[i]], + passw[i], passh[i], &info_png->color, settings); + } + + if(error) break; + } + } + + lodepng_free(adam7); + } + + return error; +} + +/* +palette must have 4 * palettesize bytes allocated, and given in format RGBARGBARGBARGBA... +returns 0 if the palette is opaque, +returns 1 if the palette has a single color with alpha 0 ==> color key +returns 2 if the palette is semi-translucent. 
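+ The encoder below uses this result to decide whether a tRNS chunk is needed for a palette image: any non-zero return value means tRNS is written.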
+*/ +static unsigned getPaletteTranslucency(const unsigned char* palette, size_t palettesize) +{ + size_t i; + unsigned key = 0; + unsigned r = 0, g = 0, b = 0; /*the value of the color with alpha 0, so long as color keying is possible*/ + for(i = 0; i != palettesize; ++i) + { + if(!key && palette[4 * i + 3] == 0) + { + r = palette[4 * i + 0]; g = palette[4 * i + 1]; b = palette[4 * i + 2]; + key = 1; + i = (size_t)(-1); /*restart from beginning, to detect earlier opaque colors with key's value*/ + } + else if(palette[4 * i + 3] != 255) return 2; + /*when key, no opaque RGB may have key's RGB*/ + else if(key && r == palette[i * 4 + 0] && g == palette[i * 4 + 1] && b == palette[i * 4 + 2]) return 2; + } + return key; +} + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS +static unsigned addUnknownChunks(ucvector* out, unsigned char* data, size_t datasize) +{ + unsigned char* inchunk = data; + while((size_t)(inchunk - data) < datasize) + { + CERROR_TRY_RETURN(lodepng_chunk_append(&out->data, &out->size, inchunk)); + out->allocsize = out->size; /*fix the allocsize again*/ + inchunk = lodepng_chunk_next(inchunk); + } + return 0; +} +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +unsigned lodepng_encode(unsigned char** out, size_t* outsize, + const unsigned char* image, unsigned w, unsigned h, + LodePNGState* state) +{ + LodePNGInfo info; + ucvector outv; + unsigned char* data = 0; /*uncompressed version of the IDAT chunk data*/ + size_t datasize = 0; + + /*provide some proper output values if error will happen*/ + *out = 0; + *outsize = 0; + state->error = 0; + + /*check input values validity*/ + if((state->info_png.color.colortype == LCT_PALETTE || state->encoder.force_palette) + && (state->info_png.color.palettesize == 0 || state->info_png.color.palettesize > 256)) + { + CERROR_RETURN_ERROR(state->error, 68); /*invalid palette size, it is only allowed to be 1-256*/ + } + if(state->encoder.zlibsettings.btype > 2) + { + CERROR_RETURN_ERROR(state->error, 61); /*error: unexisting btype*/ + } + if(state->info_png.interlace_method > 1) + { + CERROR_RETURN_ERROR(state->error, 71); /*error: unexisting interlace mode*/ + } + state->error = checkColorValidity(state->info_png.color.colortype, state->info_png.color.bitdepth); + if(state->error) return state->error; /*error: unexisting color type given*/ + state->error = checkColorValidity(state->info_raw.colortype, state->info_raw.bitdepth); + if(state->error) return state->error; /*error: unexisting color type given*/ + + /* color convert and compute scanline filter types */ + lodepng_info_init(&info); + lodepng_info_copy(&info, &state->info_png); + if(state->encoder.auto_convert) + { + state->error = lodepng_auto_choose_color(&info.color, image, w, h, &state->info_raw); + } + if (!state->error) + { + if(!lodepng_color_mode_equal(&state->info_raw, &info.color)) + { + unsigned char* converted; + size_t size = (w * h * (size_t)lodepng_get_bpp(&info.color) + 7) / 8; + + converted = (unsigned char*)lodepng_malloc(size); + if(!converted && size) state->error = 83; /*alloc fail*/ + if(!state->error) + { + state->error = lodepng_convert(converted, image, &info.color, &state->info_raw, w, h); + } + if(!state->error) preProcessScanlines(&data, &datasize, converted, w, h, &info, &state->encoder); + lodepng_free(converted); + } + else preProcessScanlines(&data, &datasize, image, w, h, &info, &state->encoder); + } + + /* output all PNG chunks */ + ucvector_init(&outv); + while(!state->error) /*while only executed once, to break on error*/ + { +#ifdef 
LODEPNG_COMPILE_ANCILLARY_CHUNKS + size_t i; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + /*write signature and chunks*/ + writeSignature(&outv); + /*IHDR*/ + addChunk_IHDR(&outv, w, h, info.color.colortype, info.color.bitdepth, info.interlace_method); +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /*unknown chunks between IHDR and PLTE*/ + if(info.unknown_chunks_data[0]) + { + state->error = addUnknownChunks(&outv, info.unknown_chunks_data[0], info.unknown_chunks_size[0]); + if(state->error) break; + } +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + /*PLTE*/ + if(info.color.colortype == LCT_PALETTE) + { + addChunk_PLTE(&outv, &info.color); + } + if(state->encoder.force_palette && (info.color.colortype == LCT_RGB || info.color.colortype == LCT_RGBA)) + { + addChunk_PLTE(&outv, &info.color); + } + /*tRNS*/ + if(info.color.colortype == LCT_PALETTE && getPaletteTranslucency(info.color.palette, info.color.palettesize) != 0) + { + addChunk_tRNS(&outv, &info.color); + } + if((info.color.colortype == LCT_GREY || info.color.colortype == LCT_RGB) && info.color.key_defined) + { + addChunk_tRNS(&outv, &info.color); + } +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /*bKGD (must come between PLTE and the IDAt chunks*/ + if(info.background_defined) addChunk_bKGD(&outv, &info); + /*pHYs (must come before the IDAT chunks)*/ + if(info.phys_defined) addChunk_pHYs(&outv, &info); + + /*unknown chunks between PLTE and IDAT*/ + if(info.unknown_chunks_data[1]) + { + state->error = addUnknownChunks(&outv, info.unknown_chunks_data[1], info.unknown_chunks_size[1]); + if(state->error) break; + } +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + /*IDAT (multiple IDAT chunks must be consecutive)*/ + state->error = addChunk_IDAT(&outv, data, datasize, &state->encoder.zlibsettings); + if(state->error) break; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /*tIME*/ + if(info.time_defined) addChunk_tIME(&outv, &info.time); + /*tEXt and/or zTXt*/ + for(i = 0; i != info.text_num; ++i) + { + if(strlen(info.text_keys[i]) > 79) + { + state->error = 66; /*text chunk too large*/ + break; + } + if(strlen(info.text_keys[i]) < 1) + { + state->error = 67; /*text chunk too small*/ + break; + } + if(state->encoder.text_compression) + { + addChunk_zTXt(&outv, info.text_keys[i], info.text_strings[i], &state->encoder.zlibsettings); + } + else + { + addChunk_tEXt(&outv, info.text_keys[i], info.text_strings[i]); + } + } + /*LodePNG version id in text chunk*/ + if(state->encoder.add_id) + { + unsigned alread_added_id_text = 0; + for(i = 0; i != info.text_num; ++i) + { + if(!strcmp(info.text_keys[i], "LodePNG")) + { + alread_added_id_text = 1; + break; + } + } + if(alread_added_id_text == 0) + { + addChunk_tEXt(&outv, "LodePNG", LODEPNG_VERSION_STRING); /*it's shorter as tEXt than as zTXt chunk*/ + } + } + /*iTXt*/ + for(i = 0; i != info.itext_num; ++i) + { + if(strlen(info.itext_keys[i]) > 79) + { + state->error = 66; /*text chunk too large*/ + break; + } + if(strlen(info.itext_keys[i]) < 1) + { + state->error = 67; /*text chunk too small*/ + break; + } + addChunk_iTXt(&outv, state->encoder.text_compression, + info.itext_keys[i], info.itext_langtags[i], info.itext_transkeys[i], info.itext_strings[i], + &state->encoder.zlibsettings); + } + + /*unknown chunks between IDAT and IEND*/ + if(info.unknown_chunks_data[2]) + { + state->error = addUnknownChunks(&outv, info.unknown_chunks_data[2], info.unknown_chunks_size[2]); + if(state->error) break; + } +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + addChunk_IEND(&outv); + + break; /*this isn't really a while 
loop; no error happened so break out now!*/ + } + + lodepng_info_cleanup(&info); + lodepng_free(data); + /*instead of cleaning the vector up, give it to the output*/ + *out = outv.data; + *outsize = outv.size; + + return state->error; +} + +unsigned lodepng_encode_memory(unsigned char** out, size_t* outsize, const unsigned char* image, + unsigned w, unsigned h, LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned error; + LodePNGState state; + lodepng_state_init(&state); + state.info_raw.colortype = colortype; + state.info_raw.bitdepth = bitdepth; + state.info_png.color.colortype = colortype; + state.info_png.color.bitdepth = bitdepth; + lodepng_encode(out, outsize, image, w, h, &state); + error = state.error; + lodepng_state_cleanup(&state); + return error; +} + +unsigned lodepng_encode32(unsigned char** out, size_t* outsize, const unsigned char* image, unsigned w, unsigned h) +{ + return lodepng_encode_memory(out, outsize, image, w, h, LCT_RGBA, 8); +} + +unsigned lodepng_encode24(unsigned char** out, size_t* outsize, const unsigned char* image, unsigned w, unsigned h) +{ + return lodepng_encode_memory(out, outsize, image, w, h, LCT_RGB, 8); +} + +#ifdef LODEPNG_COMPILE_DISK +unsigned lodepng_encode_file(const char* filename, const unsigned char* image, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned char* buffer; + size_t buffersize; + unsigned error = lodepng_encode_memory(&buffer, &buffersize, image, w, h, colortype, bitdepth); + if(!error) error = lodepng_save_file(buffer, buffersize, filename); + lodepng_free(buffer); + return error; +} + +unsigned lodepng_encode32_file(const char* filename, const unsigned char* image, unsigned w, unsigned h) +{ + return lodepng_encode_file(filename, image, w, h, LCT_RGBA, 8); +} + +unsigned lodepng_encode24_file(const char* filename, const unsigned char* image, unsigned w, unsigned h) +{ + return lodepng_encode_file(filename, image, w, h, LCT_RGB, 8); +} +#endif /*LODEPNG_COMPILE_DISK*/ + +void lodepng_encoder_settings_init(LodePNGEncoderSettings* settings) +{ + lodepng_compress_settings_init(&settings->zlibsettings); + settings->filter_palette_zero = 1; + settings->filter_strategy = LFS_MINSUM; + settings->auto_convert = 1; + settings->force_palette = 0; + settings->predefined_filters = 0; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + settings->add_id = 0; + settings->text_compression = 1; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} + +#endif /*LODEPNG_COMPILE_ENCODER*/ +#endif /*LODEPNG_COMPILE_PNG*/ + +#ifdef LODEPNG_COMPILE_ERROR_TEXT +/* +This returns the description of a numerical error code in English. This is also +the documentation of all the error codes. 
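+ Illustrative usage sketch (filename and variable names are placeholders):
+   unsigned error = lodepng_decode32_file(&image, &width, &height, "input.png");
+   if(error) printf("decoder error %u: %s\n", error, lodepng_error_text(error));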
+*/ +const char* lodepng_error_text(unsigned code) +{ + switch(code) + { + case 0: return "no error, everything went ok"; + case 1: return "nothing done yet"; /*the Encoder/Decoder has done nothing yet, error checking makes no sense yet*/ + case 10: return "end of input memory reached without huffman end code"; /*while huffman decoding*/ + case 11: return "error in code tree made it jump outside of huffman tree"; /*while huffman decoding*/ + case 13: return "problem while processing dynamic deflate block"; + case 14: return "problem while processing dynamic deflate block"; + case 15: return "problem while processing dynamic deflate block"; + case 16: return "unexisting code while processing dynamic deflate block"; + case 17: return "end of out buffer memory reached while inflating"; + case 18: return "invalid distance code while inflating"; + case 19: return "end of out buffer memory reached while inflating"; + case 20: return "invalid deflate block BTYPE encountered while decoding"; + case 21: return "NLEN is not ones complement of LEN in a deflate block"; + /*end of out buffer memory reached while inflating: + This can happen if the inflated deflate data is longer than the amount of bytes required to fill up + all the pixels of the image, given the color depth and image dimensions. Something that doesn't + happen in a normal, well encoded, PNG image.*/ + case 22: return "end of out buffer memory reached while inflating"; + case 23: return "end of in buffer memory reached while inflating"; + case 24: return "invalid FCHECK in zlib header"; + case 25: return "invalid compression method in zlib header"; + case 26: return "FDICT encountered in zlib header while it's not used for PNG"; + case 27: return "PNG file is smaller than a PNG header"; + /*Checks the magic file header, the first 8 bytes of the PNG file*/ + case 28: return "incorrect PNG signature, it's no PNG or corrupted"; + case 29: return "first chunk is not the header chunk"; + case 30: return "chunk length too large, chunk broken off at end of file"; + case 31: return "illegal PNG color type or bpp"; + case 32: return "illegal PNG compression method"; + case 33: return "illegal PNG filter method"; + case 34: return "illegal PNG interlace method"; + case 35: return "chunk length of a chunk is too large or the chunk too small"; + case 36: return "illegal PNG filter type encountered"; + case 37: return "illegal bit depth for this color type given"; + case 38: return "the palette is too big"; /*more than 256 colors*/ + case 39: return "more palette alpha values given in tRNS chunk than there are colors in the palette"; + case 40: return "tRNS chunk has wrong size for greyscale image"; + case 41: return "tRNS chunk has wrong size for RGB image"; + case 42: return "tRNS chunk appeared while it was not allowed for this color type"; + case 43: return "bKGD chunk has wrong size for palette image"; + case 44: return "bKGD chunk has wrong size for greyscale image"; + case 45: return "bKGD chunk has wrong size for RGB image"; + case 48: return "empty input buffer given to decoder. 
Maybe caused by non-existing file?"; + case 49: return "jumped past memory while generating dynamic huffman tree"; + case 50: return "jumped past memory while generating dynamic huffman tree"; + case 51: return "jumped past memory while inflating huffman block"; + case 52: return "jumped past memory while inflating"; + case 53: return "size of zlib data too small"; + case 54: return "repeat symbol in tree while there was no value symbol yet"; + /*jumped past tree while generating huffman tree, this could be when the + tree will have more leaves than symbols after generating it out of the + given lenghts. They call this an oversubscribed dynamic bit lengths tree in zlib.*/ + case 55: return "jumped past tree while generating huffman tree"; + case 56: return "given output image colortype or bitdepth not supported for color conversion"; + case 57: return "invalid CRC encountered (checking CRC can be disabled)"; + case 58: return "invalid ADLER32 encountered (checking ADLER32 can be disabled)"; + case 59: return "requested color conversion not supported"; + case 60: return "invalid window size given in the settings of the encoder (must be 0-32768)"; + case 61: return "invalid BTYPE given in the settings of the encoder (only 0, 1 and 2 are allowed)"; + /*LodePNG leaves the choice of RGB to greyscale conversion formula to the user.*/ + case 62: return "conversion from color to greyscale not supported"; + case 63: return "length of a chunk too long, max allowed for PNG is 2147483647 bytes per chunk"; /*(2^31-1)*/ + /*this would result in the inability of a deflated block to ever contain an end code. It must be at least 1.*/ + case 64: return "the length of the END symbol 256 in the Huffman tree is 0"; + case 66: return "the length of a text chunk keyword given to the encoder is longer than the maximum of 79 bytes"; + case 67: return "the length of a text chunk keyword given to the encoder is smaller than the minimum of 1 byte"; + case 68: return "tried to encode a PLTE chunk with a palette that has less than 1 or more than 256 colors"; + case 69: return "unknown chunk type with 'critical' flag encountered by the decoder"; + case 71: return "unexisting interlace mode given to encoder (must be 0 or 1)"; + case 72: return "while decoding, unexisting compression method encountering in zTXt or iTXt chunk (it must be 0)"; + case 73: return "invalid tIME chunk size"; + case 74: return "invalid pHYs chunk size"; + /*length could be wrong, or data chopped off*/ + case 75: return "no null termination char found while decoding text chunk"; + case 76: return "iTXt chunk too short to contain required bytes"; + case 77: return "integer overflow in buffer size"; + case 78: return "failed to open file for reading"; /*file doesn't exist or couldn't be opened for reading*/ + case 79: return "failed to open file for writing"; + case 80: return "tried creating a tree of 0 symbols"; + case 81: return "lazy matching at pos 0 is impossible"; + case 82: return "color conversion to palette requested while a color isn't in palette"; + case 83: return "memory allocation failed"; + case 84: return "given image too small to contain all pixels to be encoded"; + case 86: return "impossible offset in lz77 encoding (internal bug)"; + case 87: return "must provide custom zlib function pointer if LODEPNG_COMPILE_ZLIB is not defined"; + case 88: return "invalid filter strategy given for LodePNGEncoderSettings.filter_strategy"; + case 89: return "text chunk keyword too short or long: must have size 1-79"; + /*the windowsize in the 
LodePNGCompressSettings. Requiring POT(==> & instead of %) makes encoding 12% faster.*/ + case 90: return "windowsize must be a power of two"; + case 91: return "invalid decompressed idat size"; + case 92: return "too many pixels, not supported"; + case 93: return "zero width or height is invalid"; + case 94: return "header chunk must have a size of 13 bytes"; + } + return "unknown error code"; +} +#endif /*LODEPNG_COMPILE_ERROR_TEXT*/ + +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* // C++ Wrapper // */ +/* ////////////////////////////////////////////////////////////////////////// */ +/* ////////////////////////////////////////////////////////////////////////// */ + +#ifdef LODEPNG_COMPILE_CPP +namespace lodepng +{ + +#ifdef LODEPNG_COMPILE_DISK +unsigned load_file(std::vector& buffer, const std::string& filename) +{ + long size = lodepng_filesize(filename.c_str()); + if(size < 0) return 78; + buffer.resize((size_t)size); + return size == 0 ? 0 : lodepng_buffer_file(&buffer[0], (size_t)size, filename.c_str()); +} + +/*write given buffer to the file, overwriting the file, it doesn't append to it.*/ +unsigned save_file(const std::vector& buffer, const std::string& filename) +{ + return lodepng_save_file(buffer.empty() ? 0 : &buffer[0], buffer.size(), filename.c_str()); +} +#endif /* LODEPNG_COMPILE_DISK */ + +#ifdef LODEPNG_COMPILE_ZLIB +#ifdef LODEPNG_COMPILE_DECODER +unsigned decompress(std::vector& out, const unsigned char* in, size_t insize, + const LodePNGDecompressSettings& settings) +{ + unsigned char* buffer = 0; + size_t buffersize = 0; + unsigned error = zlib_decompress(&buffer, &buffersize, in, insize, &settings); + if(buffer) + { + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + lodepng_free(buffer); + } + return error; +} + +unsigned decompress(std::vector& out, const std::vector& in, + const LodePNGDecompressSettings& settings) +{ + return decompress(out, in.empty() ? 0 : &in[0], in.size(), settings); +} +#endif /* LODEPNG_COMPILE_DECODER */ + +#ifdef LODEPNG_COMPILE_ENCODER +unsigned compress(std::vector& out, const unsigned char* in, size_t insize, + const LodePNGCompressSettings& settings) +{ + unsigned char* buffer = 0; + size_t buffersize = 0; + unsigned error = zlib_compress(&buffer, &buffersize, in, insize, &settings); + if(buffer) + { + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + lodepng_free(buffer); + } + return error; +} + +unsigned compress(std::vector& out, const std::vector& in, + const LodePNGCompressSettings& settings) +{ + return compress(out, in.empty() ? 
0 : &in[0], in.size(), settings); +} +#endif /* LODEPNG_COMPILE_ENCODER */ +#endif /* LODEPNG_COMPILE_ZLIB */ + + +#ifdef LODEPNG_COMPILE_PNG + +State::State() +{ + lodepng_state_init(this); +} + +State::State(const State& other) +{ + lodepng_state_init(this); + lodepng_state_copy(this, &other); +} + +State::~State() +{ + lodepng_state_cleanup(this); +} + +State& State::operator=(const State& other) +{ + lodepng_state_copy(this, &other); + return *this; +} + +#ifdef LODEPNG_COMPILE_DECODER + +unsigned decode(std::vector& out, unsigned& w, unsigned& h, const unsigned char* in, + size_t insize, LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned char* buffer; + unsigned error = lodepng_decode_memory(&buffer, &w, &h, in, insize, colortype, bitdepth); + if(buffer && !error) + { + State state; + state.info_raw.colortype = colortype; + state.info_raw.bitdepth = bitdepth; + size_t buffersize = lodepng_get_raw_size(w, h, &state.info_raw); + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + lodepng_free(buffer); + } + return error; +} + +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + const std::vector& in, LodePNGColorType colortype, unsigned bitdepth) +{ + return decode(out, w, h, in.empty() ? 0 : &in[0], (unsigned)in.size(), colortype, bitdepth); +} + +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + State& state, + const unsigned char* in, size_t insize) +{ + unsigned char* buffer = NULL; + unsigned error = lodepng_decode(&buffer, &w, &h, &state, in, insize); + if(buffer && !error) + { + size_t buffersize = lodepng_get_raw_size(w, h, &state.info_raw); + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + } + lodepng_free(buffer); + return error; +} + +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + State& state, + const std::vector& in) +{ + return decode(out, w, h, state, in.empty() ? 0 : &in[0], in.size()); +} + +#ifdef LODEPNG_COMPILE_DISK +unsigned decode(std::vector& out, unsigned& w, unsigned& h, const std::string& filename, + LodePNGColorType colortype, unsigned bitdepth) +{ + std::vector buffer; + unsigned error = load_file(buffer, filename); + if(error) return error; + return decode(out, w, h, buffer, colortype, bitdepth); +} +#endif /* LODEPNG_COMPILE_DECODER */ +#endif /* LODEPNG_COMPILE_DISK */ + +#ifdef LODEPNG_COMPILE_ENCODER +unsigned encode(std::vector& out, const unsigned char* in, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth) +{ + unsigned char* buffer; + size_t buffersize; + unsigned error = lodepng_encode_memory(&buffer, &buffersize, in, w, h, colortype, bitdepth); + if(buffer) + { + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + lodepng_free(buffer); + } + return error; +} + +unsigned encode(std::vector& out, + const std::vector& in, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth) +{ + if(lodepng_get_raw_size_lct(w, h, colortype, bitdepth) > in.size()) return 84; + return encode(out, in.empty() ? 
0 : &in[0], w, h, colortype, bitdepth); +} + +unsigned encode(std::vector& out, + const unsigned char* in, unsigned w, unsigned h, + State& state) +{ + unsigned char* buffer; + size_t buffersize; + unsigned error = lodepng_encode(&buffer, &buffersize, in, w, h, &state); + if(buffer) + { + out.insert(out.end(), &buffer[0], &buffer[buffersize]); + lodepng_free(buffer); + } + return error; +} + +unsigned encode(std::vector& out, + const std::vector& in, unsigned w, unsigned h, + State& state) +{ + if(lodepng_get_raw_size(w, h, &state.info_raw) > in.size()) return 84; + return encode(out, in.empty() ? 0 : &in[0], w, h, state); +} + +#ifdef LODEPNG_COMPILE_DISK +unsigned encode(const std::string& filename, + const unsigned char* in, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth) +{ + std::vector buffer; + unsigned error = encode(buffer, in, w, h, colortype, bitdepth); + if(!error) error = save_file(buffer, filename); + return error; +} + +unsigned encode(const std::string& filename, + const std::vector& in, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth) +{ + if(lodepng_get_raw_size_lct(w, h, colortype, bitdepth) > in.size()) return 84; + return encode(filename, in.empty() ? 0 : &in[0], w, h, colortype, bitdepth); +} +#endif /* LODEPNG_COMPILE_DISK */ +#endif /* LODEPNG_COMPILE_ENCODER */ +#endif /* LODEPNG_COMPILE_PNG */ +} /* namespace lodepng */ +#endif /*LODEPNG_COMPILE_CPP*/ diff --git a/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.h b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.h new file mode 100644 index 0000000..d633bfa --- /dev/null +++ b/Deep3DFaceReconstruction/tf_mesh_renderer/third_party/lodepng.h @@ -0,0 +1,1762 @@ +/* +LodePNG version 20170917 + +Copyright (c) 2005-2017 Lode Vandevenne + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + + 3. This notice may not be removed or altered from any source + distribution. +*/ + +#ifndef LODEPNG_H +#define LODEPNG_H + +#include /*for size_t*/ + +extern const char* LODEPNG_VERSION_STRING; + +/* +The following #defines are used to create code sections. They can be disabled +to disable code sections, which can give faster compile time and smaller binary. +The "NO_COMPILE" defines are designed to be used to pass as defines to the +compiler command to disable them without modifying this header, e.g. +-DLODEPNG_NO_COMPILE_ZLIB for gcc. +In addition to those below, you can also define LODEPNG_NO_COMPILE_CRC to +allow implementing a custom lodepng_crc32. +*/ +/*deflate & zlib. 
If disabled, you must specify alternative zlib functions in +the custom_zlib field of the compress and decompress settings*/ +#ifndef LODEPNG_NO_COMPILE_ZLIB +#define LODEPNG_COMPILE_ZLIB +#endif +/*png encoder and png decoder*/ +#ifndef LODEPNG_NO_COMPILE_PNG +#define LODEPNG_COMPILE_PNG +#endif +/*deflate&zlib decoder and png decoder*/ +#ifndef LODEPNG_NO_COMPILE_DECODER +#define LODEPNG_COMPILE_DECODER +#endif +/*deflate&zlib encoder and png encoder*/ +#ifndef LODEPNG_NO_COMPILE_ENCODER +#define LODEPNG_COMPILE_ENCODER +#endif +/*the optional built in harddisk file loading and saving functions*/ +#ifndef LODEPNG_NO_COMPILE_DISK +#define LODEPNG_COMPILE_DISK +#endif +/*support for chunks other than IHDR, IDAT, PLTE, tRNS, IEND: ancillary and unknown chunks*/ +#ifndef LODEPNG_NO_COMPILE_ANCILLARY_CHUNKS +#define LODEPNG_COMPILE_ANCILLARY_CHUNKS +#endif +/*ability to convert error numerical codes to English text string*/ +#ifndef LODEPNG_NO_COMPILE_ERROR_TEXT +#define LODEPNG_COMPILE_ERROR_TEXT +#endif +/*Compile the default allocators (C's free, malloc and realloc). If you disable this, +you can define the functions lodepng_free, lodepng_malloc and lodepng_realloc in your +source files with custom allocators.*/ +#ifndef LODEPNG_NO_COMPILE_ALLOCATORS +#define LODEPNG_COMPILE_ALLOCATORS +#endif +/*compile the C++ version (you can disable the C++ wrapper here even when compiling for C++)*/ +#ifdef __cplusplus +#ifndef LODEPNG_NO_COMPILE_CPP +#define LODEPNG_COMPILE_CPP +#endif +#endif + +#ifdef LODEPNG_COMPILE_CPP +#include +#include +#endif /*LODEPNG_COMPILE_CPP*/ + +#ifdef LODEPNG_COMPILE_PNG +/*The PNG color types (also used for raw).*/ +typedef enum LodePNGColorType +{ + LCT_GREY = 0, /*greyscale: 1,2,4,8,16 bit*/ + LCT_RGB = 2, /*RGB: 8,16 bit*/ + LCT_PALETTE = 3, /*palette: 1,2,4,8 bit*/ + LCT_GREY_ALPHA = 4, /*greyscale with alpha: 8,16 bit*/ + LCT_RGBA = 6 /*RGB with alpha: 8,16 bit*/ +} LodePNGColorType; + +#ifdef LODEPNG_COMPILE_DECODER +/* +Converts PNG data in memory to raw pixel data. +out: Output parameter. Pointer to buffer that will contain the raw pixel data. + After decoding, its size is w * h * (bytes per pixel) bytes larger than + initially. Bytes per pixel depends on colortype and bitdepth. + Must be freed after usage with free(*out). + Note: for 16-bit per channel colors, uses big endian format like PNG does. +w: Output parameter. Pointer to width of pixel data. +h: Output parameter. Pointer to height of pixel data. +in: Memory buffer with the PNG file. +insize: size of the in buffer. +colortype: the desired color type for the raw output image. See explanation on PNG color types. +bitdepth: the desired bit depth for the raw output image. See explanation on PNG color types. +Return value: LodePNG error code (0 means no error). +*/ +unsigned lodepng_decode_memory(unsigned char** out, unsigned* w, unsigned* h, + const unsigned char* in, size_t insize, + LodePNGColorType colortype, unsigned bitdepth); + +/*Same as lodepng_decode_memory, but always decodes to 32-bit RGBA raw image*/ +unsigned lodepng_decode32(unsigned char** out, unsigned* w, unsigned* h, + const unsigned char* in, size_t insize); + +/*Same as lodepng_decode_memory, but always decodes to 24-bit RGB raw image*/ +unsigned lodepng_decode24(unsigned char** out, unsigned* w, unsigned* h, + const unsigned char* in, size_t insize); + +#ifdef LODEPNG_COMPILE_DISK +/* +Load PNG from disk, from file with given name. +Same as the other decode functions, but instead takes a filename as input. 
+*/ +unsigned lodepng_decode_file(unsigned char** out, unsigned* w, unsigned* h, + const char* filename, + LodePNGColorType colortype, unsigned bitdepth); + +/*Same as lodepng_decode_file, but always decodes to 32-bit RGBA raw image.*/ +unsigned lodepng_decode32_file(unsigned char** out, unsigned* w, unsigned* h, + const char* filename); + +/*Same as lodepng_decode_file, but always decodes to 24-bit RGB raw image.*/ +unsigned lodepng_decode24_file(unsigned char** out, unsigned* w, unsigned* h, + const char* filename); +#endif /*LODEPNG_COMPILE_DISK*/ +#endif /*LODEPNG_COMPILE_DECODER*/ + + +#ifdef LODEPNG_COMPILE_ENCODER +/* +Converts raw pixel data into a PNG image in memory. The colortype and bitdepth + of the output PNG image cannot be chosen, they are automatically determined + by the colortype, bitdepth and content of the input pixel data. + Note: for 16-bit per channel colors, needs big endian format like PNG does. +out: Output parameter. Pointer to buffer that will contain the PNG image data. + Must be freed after usage with free(*out). +outsize: Output parameter. Pointer to the size in bytes of the out buffer. +image: The raw pixel data to encode. The size of this buffer should be + w * h * (bytes per pixel), bytes per pixel depends on colortype and bitdepth. +w: width of the raw pixel data in pixels. +h: height of the raw pixel data in pixels. +colortype: the color type of the raw input image. See explanation on PNG color types. +bitdepth: the bit depth of the raw input image. See explanation on PNG color types. +Return value: LodePNG error code (0 means no error). +*/ +unsigned lodepng_encode_memory(unsigned char** out, size_t* outsize, + const unsigned char* image, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth); + +/*Same as lodepng_encode_memory, but always encodes from 32-bit RGBA raw image.*/ +unsigned lodepng_encode32(unsigned char** out, size_t* outsize, + const unsigned char* image, unsigned w, unsigned h); + +/*Same as lodepng_encode_memory, but always encodes from 24-bit RGB raw image.*/ +unsigned lodepng_encode24(unsigned char** out, size_t* outsize, + const unsigned char* image, unsigned w, unsigned h); + +#ifdef LODEPNG_COMPILE_DISK +/* +Converts raw pixel data into a PNG file on disk. +Same as the other encode functions, but instead takes a filename as output. +NOTE: This overwrites existing files without warning! +*/ +unsigned lodepng_encode_file(const char* filename, + const unsigned char* image, unsigned w, unsigned h, + LodePNGColorType colortype, unsigned bitdepth); + +/*Same as lodepng_encode_file, but always encodes from 32-bit RGBA raw image.*/ +unsigned lodepng_encode32_file(const char* filename, + const unsigned char* image, unsigned w, unsigned h); + +/*Same as lodepng_encode_file, but always encodes from 24-bit RGB raw image.*/ +unsigned lodepng_encode24_file(const char* filename, + const unsigned char* image, unsigned w, unsigned h); +#endif /*LODEPNG_COMPILE_DISK*/ +#endif /*LODEPNG_COMPILE_ENCODER*/ + + +#ifdef LODEPNG_COMPILE_CPP +namespace lodepng +{ +#ifdef LODEPNG_COMPILE_DECODER +/*Same as lodepng_decode_memory, but decodes to an std::vector. The colortype +is the format to output the pixels to. 
Default is RGBA 8-bit per channel.*/ +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + const unsigned char* in, size_t insize, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + const std::vector& in, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +#ifdef LODEPNG_COMPILE_DISK +/* +Converts PNG file from disk to raw pixel data in memory. +Same as the other decode functions, but instead takes a filename as input. +*/ +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + const std::string& filename, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +#endif /* LODEPNG_COMPILE_DISK */ +#endif /* LODEPNG_COMPILE_DECODER */ + +#ifdef LODEPNG_COMPILE_ENCODER +/*Same as lodepng_encode_memory, but encodes to an std::vector. colortype +is that of the raw input data. The output PNG color type will be auto chosen.*/ +unsigned encode(std::vector& out, + const unsigned char* in, unsigned w, unsigned h, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +unsigned encode(std::vector& out, + const std::vector& in, unsigned w, unsigned h, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +#ifdef LODEPNG_COMPILE_DISK +/* +Converts 32-bit RGBA raw pixel data into a PNG file on disk. +Same as the other encode functions, but instead takes a filename as output. +NOTE: This overwrites existing files without warning! +*/ +unsigned encode(const std::string& filename, + const unsigned char* in, unsigned w, unsigned h, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +unsigned encode(const std::string& filename, + const std::vector& in, unsigned w, unsigned h, + LodePNGColorType colortype = LCT_RGBA, unsigned bitdepth = 8); +#endif /* LODEPNG_COMPILE_DISK */ +#endif /* LODEPNG_COMPILE_ENCODER */ +} /* namespace lodepng */ +#endif /*LODEPNG_COMPILE_CPP*/ +#endif /*LODEPNG_COMPILE_PNG*/ + +#ifdef LODEPNG_COMPILE_ERROR_TEXT +/*Returns an English description of the numerical error code.*/ +const char* lodepng_error_text(unsigned code); +#endif /*LODEPNG_COMPILE_ERROR_TEXT*/ + +#ifdef LODEPNG_COMPILE_DECODER +/*Settings for zlib decompression*/ +typedef struct LodePNGDecompressSettings LodePNGDecompressSettings; +struct LodePNGDecompressSettings +{ + unsigned ignore_adler32; /*if 1, continue and don't give an error message if the Adler32 checksum is corrupted*/ + + /*use custom zlib decoder instead of built in one (default: null)*/ + unsigned (*custom_zlib)(unsigned char**, size_t*, + const unsigned char*, size_t, + const LodePNGDecompressSettings*); + /*use custom deflate decoder instead of built in one (default: null) + if custom_zlib is used, custom_deflate is ignored since only the built in + zlib function will call custom_deflate*/ + unsigned (*custom_inflate)(unsigned char**, size_t*, + const unsigned char*, size_t, + const LodePNGDecompressSettings*); + + const void* custom_context; /*optional custom settings for custom functions*/ +}; + +extern const LodePNGDecompressSettings lodepng_default_decompress_settings; +void lodepng_decompress_settings_init(LodePNGDecompressSettings* settings); +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER +/* +Settings for zlib compression. Tweaking these settings tweaks the balance +between speed and compression ratio. 
+*/ +typedef struct LodePNGCompressSettings LodePNGCompressSettings; +struct LodePNGCompressSettings /*deflate = compress*/ +{ + /*LZ77 related settings*/ + unsigned btype; /*the block type for LZ (0, 1, 2 or 3, see zlib standard). Should be 2 for proper compression.*/ + unsigned use_lz77; /*whether or not to use LZ77. Should be 1 for proper compression.*/ + unsigned windowsize; /*must be a power of two <= 32768. higher compresses more but is slower. Default value: 2048.*/ + unsigned minmatch; /*mininum lz77 length. 3 is normally best, 6 can be better for some PNGs. Default: 0*/ + unsigned nicematch; /*stop searching if >= this length found. Set to 258 for best compression. Default: 128*/ + unsigned lazymatching; /*use lazy matching: better compression but a bit slower. Default: true*/ + + /*use custom zlib encoder instead of built in one (default: null)*/ + unsigned (*custom_zlib)(unsigned char**, size_t*, + const unsigned char*, size_t, + const LodePNGCompressSettings*); + /*use custom deflate encoder instead of built in one (default: null) + if custom_zlib is used, custom_deflate is ignored since only the built in + zlib function will call custom_deflate*/ + unsigned (*custom_deflate)(unsigned char**, size_t*, + const unsigned char*, size_t, + const LodePNGCompressSettings*); + + const void* custom_context; /*optional custom settings for custom functions*/ +}; + +extern const LodePNGCompressSettings lodepng_default_compress_settings; +void lodepng_compress_settings_init(LodePNGCompressSettings* settings); +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#ifdef LODEPNG_COMPILE_PNG +/* +Color mode of an image. Contains all information required to decode the pixel +bits to RGBA colors. This information is the same as used in the PNG file +format, and is used both for PNG and raw image data in LodePNG. +*/ +typedef struct LodePNGColorMode +{ + /*header (IHDR)*/ + LodePNGColorType colortype; /*color type, see PNG standard or documentation further in this header file*/ + unsigned bitdepth; /*bits per sample, see PNG standard or documentation further in this header file*/ + + /* + palette (PLTE and tRNS) + + Dynamically allocated with the colors of the palette, including alpha. + When encoding a PNG, to store your colors in the palette of the LodePNGColorMode, first use + lodepng_palette_clear, then for each color use lodepng_palette_add. + If you encode an image without alpha with palette, don't forget to put value 255 in each A byte of the palette. + + When decoding, by default you can ignore this palette, since LodePNG already + fills the palette colors in the pixels of the raw RGBA output. + + The palette is only supported for color type 3. + */ + unsigned char* palette; /*palette in RGBARGBA... order. When allocated, must be either 0, or have size 1024*/ + size_t palettesize; /*palette size in number of colors (amount of bytes is 4 * palettesize)*/ + + /* + transparent color key (tRNS) + + This color uses the same bit depth as the bitdepth value in this struct, which can be 1-bit to 16-bit. + For greyscale PNGs, r, g and b will all 3 be set to the same. + + When decoding, by default you can ignore this information, since LodePNG sets + pixels with this key to transparent already in the raw RGBA output. + + The color key is only supported for color types 0 and 2. + */ + unsigned key_defined; /*is a transparent color key given? 
0 = false, 1 = true*/ + unsigned key_r; /*red/greyscale component of color key*/ + unsigned key_g; /*green component of color key*/ + unsigned key_b; /*blue component of color key*/ +} LodePNGColorMode; + +/*init, cleanup and copy functions to use with this struct*/ +void lodepng_color_mode_init(LodePNGColorMode* info); +void lodepng_color_mode_cleanup(LodePNGColorMode* info); +/*return value is error code (0 means no error)*/ +unsigned lodepng_color_mode_copy(LodePNGColorMode* dest, const LodePNGColorMode* source); + +void lodepng_palette_clear(LodePNGColorMode* info); +/*add 1 color to the palette*/ +unsigned lodepng_palette_add(LodePNGColorMode* info, + unsigned char r, unsigned char g, unsigned char b, unsigned char a); + +/*get the total amount of bits per pixel, based on colortype and bitdepth in the struct*/ +unsigned lodepng_get_bpp(const LodePNGColorMode* info); +/*get the amount of color channels used, based on colortype in the struct. +If a palette is used, it counts as 1 channel.*/ +unsigned lodepng_get_channels(const LodePNGColorMode* info); +/*is it a greyscale type? (only colortype 0 or 4)*/ +unsigned lodepng_is_greyscale_type(const LodePNGColorMode* info); +/*has it got an alpha channel? (only colortype 2 or 6)*/ +unsigned lodepng_is_alpha_type(const LodePNGColorMode* info); +/*has it got a palette? (only colortype 3)*/ +unsigned lodepng_is_palette_type(const LodePNGColorMode* info); +/*only returns true if there is a palette and there is a value in the palette with alpha < 255. +Loops through the palette to check this.*/ +unsigned lodepng_has_palette_alpha(const LodePNGColorMode* info); +/* +Check if the given color info indicates the possibility of having non-opaque pixels in the PNG image. +Returns true if the image can have translucent or invisible pixels (it still be opaque if it doesn't use such pixels). +Returns false if the image can only have opaque pixels. +In detail, it returns true only if it's a color type with alpha, or has a palette with non-opaque values, +or if "key_defined" is true. +*/ +unsigned lodepng_can_have_alpha(const LodePNGColorMode* info); +/*Returns the byte size of a raw image buffer with given width, height and color mode*/ +size_t lodepng_get_raw_size(unsigned w, unsigned h, const LodePNGColorMode* color); + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS +/*The information of a Time chunk in PNG.*/ +typedef struct LodePNGTime +{ + unsigned year; /*2 bytes used (0-65535)*/ + unsigned month; /*1-12*/ + unsigned day; /*1-31*/ + unsigned hour; /*0-23*/ + unsigned minute; /*0-59*/ + unsigned second; /*0-60 (to allow for leap seconds)*/ +} LodePNGTime; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +/*Information about the PNG image, except pixels, width and height.*/ +typedef struct LodePNGInfo +{ + /*header (IHDR), palette (PLTE) and transparency (tRNS) chunks*/ + unsigned compression_method;/*compression method of the original file. Always 0.*/ + unsigned filter_method; /*filter method of the original file*/ + unsigned interlace_method; /*interlace method of the original file*/ + LodePNGColorMode color; /*color type and bits, palette and transparency of the PNG file*/ + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /* + suggested background color chunk (bKGD) + This color uses the same color mode as the PNG (except alpha channel), which can be 1-bit to 16-bit. + + For greyscale PNGs, r, g and b will all 3 be set to the same. When encoding + the encoder writes the red one. 
For palette PNGs: When decoding, the RGB value + will be stored, not a palette index. But when encoding, specify the index of + the palette in background_r, the other two are then ignored. + + The decoder does not use this background color to edit the color of pixels. + */ + unsigned background_defined; /*is a suggested background color given?*/ + unsigned background_r; /*red component of suggested background color*/ + unsigned background_g; /*green component of suggested background color*/ + unsigned background_b; /*blue component of suggested background color*/ + + /* + non-international text chunks (tEXt and zTXt) + + The char** arrays each contain num strings. The actual messages are in + text_strings, while text_keys are keywords that give a short description what + the actual text represents, e.g. Title, Author, Description, or anything else. + + A keyword is minimum 1 character and maximum 79 characters long. It's + discouraged to use a single line length longer than 79 characters for texts. + + Don't allocate these text buffers yourself. Use the init/cleanup functions + correctly and use lodepng_add_text and lodepng_clear_text. + */ + size_t text_num; /*the amount of texts in these char** buffers (there may be more texts in itext)*/ + char** text_keys; /*the keyword of a text chunk (e.g. "Comment")*/ + char** text_strings; /*the actual text*/ + + /* + international text chunks (iTXt) + Similar to the non-international text chunks, but with additional strings + "langtags" and "transkeys". + */ + size_t itext_num; /*the amount of international texts in this PNG*/ + char** itext_keys; /*the English keyword of the text chunk (e.g. "Comment")*/ + char** itext_langtags; /*language tag for this text's language, ISO/IEC 646 string, e.g. ISO 639 language tag*/ + char** itext_transkeys; /*keyword translated to the international language - UTF-8 string*/ + char** itext_strings; /*the actual international text - UTF-8 string*/ + + /*time chunk (tIME)*/ + unsigned time_defined; /*set to 1 to make the encoder generate a tIME chunk*/ + LodePNGTime time; + + /*phys chunk (pHYs)*/ + unsigned phys_defined; /*if 0, there is no pHYs chunk and the values below are undefined, if 1 else there is one*/ + unsigned phys_x; /*pixels per unit in x direction*/ + unsigned phys_y; /*pixels per unit in y direction*/ + unsigned phys_unit; /*may be 0 (unknown unit) or 1 (metre)*/ + + /* + unknown chunks + There are 3 buffers, one for each position in the PNG where unknown chunks can appear + each buffer contains all unknown chunks for that position consecutively + The 3 buffers are the unknown chunks between certain critical chunks: + 0: IHDR-PLTE, 1: PLTE-IDAT, 2: IDAT-IEND + Do not allocate or traverse this data yourself. Use the chunk traversing functions declared + later, such as lodepng_chunk_next and lodepng_chunk_append, to read/write this struct. 
+ */ + unsigned char* unknown_chunks_data[3]; + size_t unknown_chunks_size[3]; /*size in bytes of the unknown chunks, given for protection*/ +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} LodePNGInfo; + +/*init, cleanup and copy functions to use with this struct*/ +void lodepng_info_init(LodePNGInfo* info); +void lodepng_info_cleanup(LodePNGInfo* info); +/*return value is error code (0 means no error)*/ +unsigned lodepng_info_copy(LodePNGInfo* dest, const LodePNGInfo* source); + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS +void lodepng_clear_text(LodePNGInfo* info); /*use this to clear the texts again after you filled them in*/ +unsigned lodepng_add_text(LodePNGInfo* info, const char* key, const char* str); /*push back both texts at once*/ + +void lodepng_clear_itext(LodePNGInfo* info); /*use this to clear the itexts again after you filled them in*/ +unsigned lodepng_add_itext(LodePNGInfo* info, const char* key, const char* langtag, + const char* transkey, const char* str); /*push back the 4 texts of 1 chunk at once*/ +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ + +/* +Converts raw buffer from one color type to another color type, based on +LodePNGColorMode structs to describe the input and output color type. +See the reference manual at the end of this header file to see which color conversions are supported. +return value = LodePNG error code (0 if all went ok, an error if the conversion isn't supported) +The out buffer must have size (w * h * bpp + 7) / 8, where bpp is the bits per pixel +of the output color type (lodepng_get_bpp). +For < 8 bpp images, there should not be padding bits at the end of scanlines. +For 16-bit per channel colors, uses big endian format like PNG does. +Return value is LodePNG error code +*/ +unsigned lodepng_convert(unsigned char* out, const unsigned char* in, + const LodePNGColorMode* mode_out, const LodePNGColorMode* mode_in, + unsigned w, unsigned h); + +#ifdef LODEPNG_COMPILE_DECODER +/* +Settings for the decoder. This contains settings for the PNG and the Zlib +decoder, but not the Info settings from the Info structs. +*/ +typedef struct LodePNGDecoderSettings +{ + LodePNGDecompressSettings zlibsettings; /*in here is the setting to ignore Adler32 checksums*/ + + unsigned ignore_crc; /*ignore CRC checksums*/ + + unsigned color_convert; /*whether to convert the PNG to the color type you want. Default: yes*/ + +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + unsigned read_text_chunks; /*if false but remember_unknown_chunks is true, they're stored in the unknown chunks*/ + /*store all bytes from unknown chunks in the LodePNGInfo (off by default, useful for a png editor)*/ + unsigned remember_unknown_chunks; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} LodePNGDecoderSettings; + +void lodepng_decoder_settings_init(LodePNGDecoderSettings* settings); +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER +/*automatically use color type with less bits per pixel if losslessly possible. Default: AUTO*/ +typedef enum LodePNGFilterStrategy +{ + /*every filter at zero*/ + LFS_ZERO, + /*Use filter that gives minimum sum, as described in the official PNG filter heuristic.*/ + LFS_MINSUM, + /*Use the filter type that gives smallest Shannon entropy for this scanline. Depending + on the image, this is better or worse than minsum.*/ + LFS_ENTROPY, + /* + Brute-force-search PNG filters by compressing each filter for each scanline. + Experimental, very slow, and only rarely gives better compression than MINSUM. 
+ */ + LFS_BRUTE_FORCE, + /*use predefined_filters buffer: you specify the filter type for each scanline*/ + LFS_PREDEFINED +} LodePNGFilterStrategy; + +/*Gives characteristics about the colors of the image, which helps decide which color model to use for encoding. +Used internally by default if "auto_convert" is enabled. Public because it's useful for custom algorithms.*/ +typedef struct LodePNGColorProfile +{ + unsigned colored; /*not greyscale*/ + unsigned key; /*image is not opaque and color key is possible instead of full alpha*/ + unsigned short key_r; /*key values, always as 16-bit, in 8-bit case the byte is duplicated, e.g. 65535 means 255*/ + unsigned short key_g; + unsigned short key_b; + unsigned alpha; /*image is not opaque and alpha channel or alpha palette required*/ + unsigned numcolors; /*amount of colors, up to 257. Not valid if bits == 16.*/ + unsigned char palette[1024]; /*Remembers up to the first 256 RGBA colors, in no particular order*/ + unsigned bits; /*bits per channel (not for palette). 1,2 or 4 for greyscale only. 16 if 16-bit per channel required.*/ +} LodePNGColorProfile; + +void lodepng_color_profile_init(LodePNGColorProfile* profile); + +/*Get a LodePNGColorProfile of the image.*/ +unsigned lodepng_get_color_profile(LodePNGColorProfile* profile, + const unsigned char* image, unsigned w, unsigned h, + const LodePNGColorMode* mode_in); +/*The function LodePNG uses internally to decide the PNG color with auto_convert. +Chooses an optimal color model, e.g. grey if only grey pixels, palette if < 256 colors, ...*/ +unsigned lodepng_auto_choose_color(LodePNGColorMode* mode_out, + const unsigned char* image, unsigned w, unsigned h, + const LodePNGColorMode* mode_in); + +/*Settings for the encoder.*/ +typedef struct LodePNGEncoderSettings +{ + LodePNGCompressSettings zlibsettings; /*settings for the zlib encoder, such as window size, ...*/ + + unsigned auto_convert; /*automatically choose output PNG color type. Default: true*/ + + /*If true, follows the official PNG heuristic: if the PNG uses a palette or lower than + 8 bit depth, set all filters to zero. Otherwise use the filter_strategy. Note that to + completely follow the official PNG heuristic, filter_palette_zero must be true and + filter_strategy must be LFS_MINSUM*/ + unsigned filter_palette_zero; + /*Which filter strategy to use when not using zeroes due to filter_palette_zero. + Set filter_palette_zero to 0 to ensure always using your chosen strategy. Default: LFS_MINSUM*/ + LodePNGFilterStrategy filter_strategy; + /*used if filter_strategy is LFS_PREDEFINED. In that case, this must point to a buffer with + the same length as the amount of scanlines in the image, and each value must <= 5. You + have to cleanup this buffer, LodePNG will never free it. Don't forget that filter_palette_zero + must be set to 0 to ensure this is also used on palette or low bitdepth images.*/ + const unsigned char* predefined_filters; + + /*force creating a PLTE chunk if colortype is 2 or 6 (= a suggested palette). 
+ If colortype is 3, PLTE is _always_ created.*/ + unsigned force_palette; +#ifdef LODEPNG_COMPILE_ANCILLARY_CHUNKS + /*add LodePNG identifier and version as a text chunk, for debugging*/ + unsigned add_id; + /*encode text chunks as zTXt chunks instead of tEXt chunks, and use compression in iTXt chunks*/ + unsigned text_compression; +#endif /*LODEPNG_COMPILE_ANCILLARY_CHUNKS*/ +} LodePNGEncoderSettings; + +void lodepng_encoder_settings_init(LodePNGEncoderSettings* settings); +#endif /*LODEPNG_COMPILE_ENCODER*/ + + +#if defined(LODEPNG_COMPILE_DECODER) || defined(LODEPNG_COMPILE_ENCODER) +/*The settings, state and information for extended encoding and decoding.*/ +typedef struct LodePNGState +{ +#ifdef LODEPNG_COMPILE_DECODER + LodePNGDecoderSettings decoder; /*the decoding settings*/ +#endif /*LODEPNG_COMPILE_DECODER*/ +#ifdef LODEPNG_COMPILE_ENCODER + LodePNGEncoderSettings encoder; /*the encoding settings*/ +#endif /*LODEPNG_COMPILE_ENCODER*/ + LodePNGColorMode info_raw; /*specifies the format in which you would like to get the raw pixel buffer*/ + LodePNGInfo info_png; /*info of the PNG image obtained after decoding*/ + unsigned error; +#ifdef LODEPNG_COMPILE_CPP + /* For the lodepng::State subclass. */ + virtual ~LodePNGState(){} +#endif +} LodePNGState; + +/*init, cleanup and copy functions to use with this struct*/ +void lodepng_state_init(LodePNGState* state); +void lodepng_state_cleanup(LodePNGState* state); +void lodepng_state_copy(LodePNGState* dest, const LodePNGState* source); +#endif /* defined(LODEPNG_COMPILE_DECODER) || defined(LODEPNG_COMPILE_ENCODER) */ + +#ifdef LODEPNG_COMPILE_DECODER +/* +Same as lodepng_decode_memory, but uses a LodePNGState to allow custom settings and +getting much more information about the PNG image and color mode. +*/ +unsigned lodepng_decode(unsigned char** out, unsigned* w, unsigned* h, + LodePNGState* state, + const unsigned char* in, size_t insize); + +/* +Read the PNG header, but not the actual data. This returns only the information +that is in the header chunk of the PNG, such as width, height and color type. The +information is placed in the info_png field of the LodePNGState. +*/ +unsigned lodepng_inspect(unsigned* w, unsigned* h, + LodePNGState* state, + const unsigned char* in, size_t insize); +#endif /*LODEPNG_COMPILE_DECODER*/ + + +#ifdef LODEPNG_COMPILE_ENCODER +/*This function allocates the out buffer with standard malloc and stores the size in *outsize.*/ +unsigned lodepng_encode(unsigned char** out, size_t* outsize, + const unsigned char* image, unsigned w, unsigned h, + LodePNGState* state); +#endif /*LODEPNG_COMPILE_ENCODER*/ + +/* +The lodepng_chunk functions are normally not needed, except to traverse the +unknown chunks stored in the LodePNGInfo struct, or add new ones to it. +It also allows traversing the chunks of an encoded PNG file yourself. + +PNG standard chunk naming conventions: +First byte: uppercase = critical, lowercase = ancillary +Second byte: uppercase = public, lowercase = private +Third byte: must be uppercase +Fourth byte: uppercase = unsafe to copy, lowercase = safe to copy +*/ + +/* +Gets the length of the data of the chunk. Total chunk length has 12 bytes more. +There must be at least 4 bytes to read from. If the result value is too large, +it may be corrupt data. 
+*/ +unsigned lodepng_chunk_length(const unsigned char* chunk); + +/*puts the 4-byte type in null terminated string*/ +void lodepng_chunk_type(char type[5], const unsigned char* chunk); + +/*check if the type is the given type*/ +unsigned char lodepng_chunk_type_equals(const unsigned char* chunk, const char* type); + +/*0: it's one of the critical chunk types, 1: it's an ancillary chunk (see PNG standard)*/ +unsigned char lodepng_chunk_ancillary(const unsigned char* chunk); + +/*0: public, 1: private (see PNG standard)*/ +unsigned char lodepng_chunk_private(const unsigned char* chunk); + +/*0: the chunk is unsafe to copy, 1: the chunk is safe to copy (see PNG standard)*/ +unsigned char lodepng_chunk_safetocopy(const unsigned char* chunk); + +/*get pointer to the data of the chunk, where the input points to the header of the chunk*/ +unsigned char* lodepng_chunk_data(unsigned char* chunk); +const unsigned char* lodepng_chunk_data_const(const unsigned char* chunk); + +/*returns 0 if the crc is correct, 1 if it's incorrect (0 for OK as usual!)*/ +unsigned lodepng_chunk_check_crc(const unsigned char* chunk); + +/*generates the correct CRC from the data and puts it in the last 4 bytes of the chunk*/ +void lodepng_chunk_generate_crc(unsigned char* chunk); + +/*iterate to next chunks. don't use on IEND chunk, as there is no next chunk then*/ +unsigned char* lodepng_chunk_next(unsigned char* chunk); +const unsigned char* lodepng_chunk_next_const(const unsigned char* chunk); + +/* +Appends chunk to the data in out. The given chunk should already have its chunk header. +The out variable and outlength are updated to reflect the new reallocated buffer. +Returns error code (0 if it went ok) +*/ +unsigned lodepng_chunk_append(unsigned char** out, size_t* outlength, const unsigned char* chunk); + +/* +Appends new chunk to out. The chunk to append is given by giving its length, type +and data separately. The type is a 4-letter string. +The out variable and outlength are updated to reflect the new reallocated buffer. +Returne error code (0 if it went ok) +*/ +unsigned lodepng_chunk_create(unsigned char** out, size_t* outlength, unsigned length, + const char* type, const unsigned char* data); + + +/*Calculate CRC32 of buffer*/ +unsigned lodepng_crc32(const unsigned char* buf, size_t len); +#endif /*LODEPNG_COMPILE_PNG*/ + + +#ifdef LODEPNG_COMPILE_ZLIB +/* +This zlib part can be used independently to zlib compress and decompress a +buffer. It cannot be used to create gzip files however, and it only supports the +part of zlib that is required for PNG, it does not support dictionaries. +*/ + +#ifdef LODEPNG_COMPILE_DECODER +/*Inflate a buffer. Inflate is the decompression step of deflate. Out buffer must be freed after use.*/ +unsigned lodepng_inflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGDecompressSettings* settings); + +/* +Decompresses Zlib data. Reallocates the out buffer and appends the data. The +data must be according to the zlib specification. +Either, *out must be NULL and *outsize must be 0, or, *out must be a valid +buffer and *outsize its size in bytes. out must be freed by user after usage. +*/ +unsigned lodepng_zlib_decompress(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGDecompressSettings* settings); +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER +/* +Compresses data with Zlib. Reallocates the out buffer and appends the data. 
+Zlib adds a small header and trailer around the deflate data. +The data is output in the format of the zlib specification. +Either, *out must be NULL and *outsize must be 0, or, *out must be a valid +buffer and *outsize its size in bytes. out must be freed by user after usage. +*/ +unsigned lodepng_zlib_compress(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGCompressSettings* settings); + +/* +Find length-limited Huffman code for given frequencies. This function is in the +public interface only for tests, it's used internally by lodepng_deflate. +*/ +unsigned lodepng_huffman_code_lengths(unsigned* lengths, const unsigned* frequencies, + size_t numcodes, unsigned maxbitlen); + +/*Compress a buffer with deflate. See RFC 1951. Out buffer must be freed after use.*/ +unsigned lodepng_deflate(unsigned char** out, size_t* outsize, + const unsigned char* in, size_t insize, + const LodePNGCompressSettings* settings); + +#endif /*LODEPNG_COMPILE_ENCODER*/ +#endif /*LODEPNG_COMPILE_ZLIB*/ + +#ifdef LODEPNG_COMPILE_DISK +/* +Load a file from disk into buffer. The function allocates the out buffer, and +after usage you should free it. +out: output parameter, contains pointer to loaded buffer. +outsize: output parameter, size of the allocated out buffer +filename: the path to the file to load +return value: error code (0 means ok) +*/ +unsigned lodepng_load_file(unsigned char** out, size_t* outsize, const char* filename); + +/* +Save a file from buffer to disk. Warning, if it exists, this function overwrites +the file without warning! +buffer: the buffer to write +buffersize: size of the buffer to write +filename: the path to the file to save to +return value: error code (0 means ok) +*/ +unsigned lodepng_save_file(const unsigned char* buffer, size_t buffersize, const char* filename); +#endif /*LODEPNG_COMPILE_DISK*/ + +#ifdef LODEPNG_COMPILE_CPP +/* The LodePNG C++ wrapper uses std::vectors instead of manually allocated memory buffers. */ +namespace lodepng +{ +#ifdef LODEPNG_COMPILE_PNG +class State : public LodePNGState +{ + public: + State(); + State(const State& other); + virtual ~State(); + State& operator=(const State& other); +}; + +#ifdef LODEPNG_COMPILE_DECODER +/* Same as other lodepng::decode, but using a State for more settings and information. */ +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + State& state, + const unsigned char* in, size_t insize); +unsigned decode(std::vector& out, unsigned& w, unsigned& h, + State& state, + const std::vector& in); +#endif /*LODEPNG_COMPILE_DECODER*/ + +#ifdef LODEPNG_COMPILE_ENCODER +/* Same as other lodepng::encode, but using a State for more settings and information. */ +unsigned encode(std::vector& out, + const unsigned char* in, unsigned w, unsigned h, + State& state); +unsigned encode(std::vector& out, + const std::vector& in, unsigned w, unsigned h, + State& state); +#endif /*LODEPNG_COMPILE_ENCODER*/ + +#ifdef LODEPNG_COMPILE_DISK +/* +Load a file from disk into an std::vector. +return value: error code (0 means ok) +*/ +unsigned load_file(std::vector& buffer, const std::string& filename); + +/* +Save the binary data in an std::vector to a file on disk. The file is overwritten +without warning. 
+*/ +unsigned save_file(const std::vector& buffer, const std::string& filename); +#endif /* LODEPNG_COMPILE_DISK */ +#endif /* LODEPNG_COMPILE_PNG */ + +#ifdef LODEPNG_COMPILE_ZLIB +#ifdef LODEPNG_COMPILE_DECODER +/* Zlib-decompress an unsigned char buffer */ +unsigned decompress(std::vector& out, const unsigned char* in, size_t insize, + const LodePNGDecompressSettings& settings = lodepng_default_decompress_settings); + +/* Zlib-decompress an std::vector */ +unsigned decompress(std::vector& out, const std::vector& in, + const LodePNGDecompressSettings& settings = lodepng_default_decompress_settings); +#endif /* LODEPNG_COMPILE_DECODER */ + +#ifdef LODEPNG_COMPILE_ENCODER +/* Zlib-compress an unsigned char buffer */ +unsigned compress(std::vector& out, const unsigned char* in, size_t insize, + const LodePNGCompressSettings& settings = lodepng_default_compress_settings); + +/* Zlib-compress an std::vector */ +unsigned compress(std::vector& out, const std::vector& in, + const LodePNGCompressSettings& settings = lodepng_default_compress_settings); +#endif /* LODEPNG_COMPILE_ENCODER */ +#endif /* LODEPNG_COMPILE_ZLIB */ +} /* namespace lodepng */ +#endif /*LODEPNG_COMPILE_CPP*/ + +/* +TODO: +[.] test if there are no memory leaks or security exploits - done a lot but needs to be checked often +[.] check compatibility with various compilers - done but needs to be redone for every newer version +[X] converting color to 16-bit per channel types +[ ] read all public PNG chunk types (but never let the color profile and gamma ones touch RGB values) +[ ] make sure encoder generates no chunks with size > (2^31)-1 +[ ] partial decoding (stream processing) +[X] let the "isFullyOpaque" function check color keys and transparent palettes too +[X] better name for the variables "codes", "codesD", "codelengthcodes", "clcl" and "lldl" +[ ] don't stop decoding on errors like 69, 57, 58 (make warnings) +[ ] make warnings like: oob palette, checksum fail, data after iend, wrong/unknown crit chunk, no null terminator in text, ... +[ ] let the C++ wrapper catch exceptions coming from the standard library and return LodePNG error codes +[ ] allow user to provide custom color conversion functions, e.g. for premultiplied alpha, padding bits or not, ... +[ ] allow user to give data (void*) to custom allocator +*/ + +#endif /*LODEPNG_H inclusion guard*/ + +/* +LodePNG Documentation +--------------------- + +0. table of contents +-------------------- + + 1. about + 1.1. supported features + 1.2. features not supported + 2. C and C++ version + 3. security + 4. decoding + 5. encoding + 6. color conversions + 6.1. PNG color types + 6.2. color conversions + 6.3. padding bits + 6.4. A note about 16-bits per channel and endianness + 7. error values + 8. chunks and PNG editing + 9. compiler support + 10. examples + 10.1. decoder C++ example + 10.2. decoder C example + 11. state settings reference + 12. changes + 13. contact information + + +1. about +-------- + +PNG is a file format to store raster images losslessly with good compression, +supporting different color types and alpha channel. + +LodePNG is a PNG codec according to the Portable Network Graphics (PNG) +Specification (Second Edition) - W3C Recommendation 10 November 2003. 
+ +The specifications used are: + +*) Portable Network Graphics (PNG) Specification (Second Edition): + http://www.w3.org/TR/2003/REC-PNG-20031110 +*) RFC 1950 ZLIB Compressed Data Format version 3.3: + http://www.gzip.org/zlib/rfc-zlib.html +*) RFC 1951 DEFLATE Compressed Data Format Specification ver 1.3: + http://www.gzip.org/zlib/rfc-deflate.html + +The most recent version of LodePNG can currently be found at +http://lodev.org/lodepng/ + +LodePNG works both in C (ISO C90) and C++, with a C++ wrapper that adds +extra functionality. + +LodePNG exists out of two files: +-lodepng.h: the header file for both C and C++ +-lodepng.c(pp): give it the name lodepng.c or lodepng.cpp (or .cc) depending on your usage + +If you want to start using LodePNG right away without reading this doc, get the +examples from the LodePNG website to see how to use it in code, or check the +smaller examples in chapter 13 here. + +LodePNG is simple but only supports the basic requirements. To achieve +simplicity, the following design choices were made: There are no dependencies +on any external library. There are functions to decode and encode a PNG with +a single function call, and extended versions of these functions taking a +LodePNGState struct allowing to specify or get more information. By default +the colors of the raw image are always RGB or RGBA, no matter what color type +the PNG file uses. To read and write files, there are simple functions to +convert the files to/from buffers in memory. + +This all makes LodePNG suitable for loading textures in games, demos and small +programs, ... It's less suitable for full fledged image editors, loading PNGs +over network (it requires all the image data to be available before decoding can +begin), life-critical systems, ... + +1.1. supported features +----------------------- + +The following features are supported by the decoder: + +*) decoding of PNGs with any color type, bit depth and interlace mode, to a 24- or 32-bit color raw image, + or the same color type as the PNG +*) encoding of PNGs, from any raw image to 24- or 32-bit color, or the same color type as the raw image +*) Adam7 interlace and deinterlace for any color type +*) loading the image from harddisk or decoding it from a buffer from other sources than harddisk +*) support for alpha channels, including RGBA color model, translucent palettes and color keying +*) zlib decompression (inflate) +*) zlib compression (deflate) +*) CRC32 and ADLER32 checksums +*) handling of unknown chunks, allowing making a PNG editor that stores custom and unknown chunks. +*) the following chunks are supported (generated/interpreted) by both encoder and decoder: + IHDR: header information + PLTE: color palette + IDAT: pixel data + IEND: the final chunk + tRNS: transparency for palettized images + tEXt: textual information + zTXt: compressed textual information + iTXt: international textual information + bKGD: suggested background color + pHYs: physical dimensions + tIME: modification time + +1.2. features not supported +--------------------------- + +The following features are _not_ supported: + +*) some features needed to make a conformant PNG-Editor might be still missing. +*) partial loading/stream processing. All data must be available and is processed in one call. 
+*) The following public chunks are not supported but treated as unknown chunks by LodePNG + cHRM, gAMA, iCCP, sRGB, sBIT, hIST, sPLT + Some of these are not supported on purpose: LodePNG wants to provide the RGB values + stored in the pixels, not values modified by system dependent gamma or color models. + + +2. C and C++ version +-------------------- + +The C version uses buffers allocated with alloc that you need to free() +yourself. You need to use init and cleanup functions for each struct whenever +using a struct from the C version to avoid exploits and memory leaks. + +The C++ version has extra functions with std::vectors in the interface and the +lodepng::State class which is a LodePNGState with constructor and destructor. + +These files work without modification for both C and C++ compilers because all +the additional C++ code is in "#ifdef __cplusplus" blocks that make C-compilers +ignore it, and the C code is made to compile both with strict ISO C90 and C++. + +To use the C++ version, you need to rename the source file to lodepng.cpp +(instead of lodepng.c), and compile it with a C++ compiler. + +To use the C version, you need to rename the source file to lodepng.c (instead +of lodepng.cpp), and compile it with a C compiler. + + +3. Security +----------- + +Even if carefully designed, it's always possible that LodePNG contains possible +exploits. If you discover one, please let me know, and it will be fixed. + +When using LodePNG, care has to be taken with the C version of LodePNG, as well +as the C-style structs when working with C++. The following conventions are used +for all C-style structs: + +-if a struct has a corresponding init function, always call the init function when making a new one +-if a struct has a corresponding cleanup function, call it before the struct disappears to avoid memory leaks +-if a struct has a corresponding copy function, use the copy function instead of "=". + The destination must also be inited already. + + +4. Decoding +----------- + +Decoding converts a PNG compressed image to a raw pixel buffer. + +Most documentation on using the decoder is at its declarations in the header +above. For C, simple decoding can be done with functions such as +lodepng_decode32, and more advanced decoding can be done with the struct +LodePNGState and lodepng_decode. For C++, all decoding can be done with the +various lodepng::decode functions, and lodepng::State can be used for advanced +features. + +When using the LodePNGState, it uses the following fields for decoding: +*) LodePNGInfo info_png: it stores extra information about the PNG (the input) in here +*) LodePNGColorMode info_raw: here you can say what color mode of the raw image (the output) you want to get +*) LodePNGDecoderSettings decoder: you can specify a few extra settings for the decoder to use + +LodePNGInfo info_png +-------------------- + +After decoding, this contains extra information of the PNG image, except the actual +pixels, width and height because these are already gotten directly from the decoder +functions. + +It contains for example the original color type of the PNG image, text comments, +suggested background color, etc... More details about the LodePNGInfo struct are +at its declaration documentation. + +LodePNGColorMode info_raw +------------------------- + +When decoding, here you can specify which color type you want +the resulting raw image to be. If this is different from the colortype of the +PNG, then the decoder will automatically convert the result. 
This conversion +always works, except if you want it to convert a color PNG to greyscale or to +a palette with missing colors. + +By default, 32-bit color is used for the result. + +LodePNGDecoderSettings decoder +------------------------------ + +The settings can be used to ignore the errors created by invalid CRC and Adler32 +chunks, and to disable the decoding of tEXt chunks. + +There's also a setting color_convert, true by default. If false, no conversion +is done, the resulting data will be as it was in the PNG (after decompression) +and you'll have to puzzle the colors of the pixels together yourself using the +color type information in the LodePNGInfo. + + +5. Encoding +----------- + +Encoding converts a raw pixel buffer to a PNG compressed image. + +Most documentation on using the encoder is at its declarations in the header +above. For C, simple encoding can be done with functions such as +lodepng_encode32, and more advanced decoding can be done with the struct +LodePNGState and lodepng_encode. For C++, all encoding can be done with the +various lodepng::encode functions, and lodepng::State can be used for advanced +features. + +Like the decoder, the encoder can also give errors. However it gives less errors +since the encoder input is trusted, the decoder input (a PNG image that could +be forged by anyone) is not trusted. + +When using the LodePNGState, it uses the following fields for encoding: +*) LodePNGInfo info_png: here you specify how you want the PNG (the output) to be. +*) LodePNGColorMode info_raw: here you say what color type of the raw image (the input) has +*) LodePNGEncoderSettings encoder: you can specify a few settings for the encoder to use + +LodePNGInfo info_png +-------------------- + +When encoding, you use this the opposite way as when decoding: for encoding, +you fill in the values you want the PNG to have before encoding. By default it's +not needed to specify a color type for the PNG since it's automatically chosen, +but it's possible to choose it yourself given the right settings. + +The encoder will not always exactly match the LodePNGInfo struct you give, +it tries as close as possible. Some things are ignored by the encoder. The +encoder uses, for example, the following settings from it when applicable: +colortype and bitdepth, text chunks, time chunk, the color key, the palette, the +background color, the interlace method, unknown chunks, ... + +When encoding to a PNG with colortype 3, the encoder will generate a PLTE chunk. +If the palette contains any colors for which the alpha channel is not 255 (so +there are translucent colors in the palette), it'll add a tRNS chunk. + +LodePNGColorMode info_raw +------------------------- + +You specify the color type of the raw image that you give to the input here, +including a possible transparent color key and palette you happen to be using in +your raw image data. + +By default, 32-bit color is assumed, meaning your input has to be in RGBA +format with 4 bytes (unsigned chars) per pixel. + +LodePNGEncoderSettings encoder +------------------------------ + +The following settings are supported (some are in sub-structs): +*) auto_convert: when this option is enabled, the encoder will +automatically choose the smallest possible color mode (including color key) that +can encode the colors of all pixels without information loss. +*) btype: the block type for LZ77. 0 = uncompressed, 1 = fixed huffman tree, + 2 = dynamic huffman tree (best compression). Should be 2 for proper + compression. 
+*) use_lz77: whether or not to use LZ77 for compressed block types. Should be + true for proper compression. +*) windowsize: the window size used by the LZ77 encoder (1 - 32768). Has value + 2048 by default, but can be set to 32768 for better, but slow, compression. +*) force_palette: if colortype is 2 or 6, you can make the encoder write a PLTE + chunk if force_palette is true. This can used as suggested palette to convert + to by viewers that don't support more than 256 colors (if those still exist) +*) add_id: add text chunk "Encoder: LodePNG " to the image. +*) text_compression: default 1. If 1, it'll store texts as zTXt instead of tEXt chunks. + zTXt chunks use zlib compression on the text. This gives a smaller result on + large texts but a larger result on small texts (such as a single program name). + It's all tEXt or all zTXt though, there's no separate setting per text yet. + + +6. color conversions +-------------------- + +An important thing to note about LodePNG, is that the color type of the PNG, and +the color type of the raw image, are completely independent. By default, when +you decode a PNG, you get the result as a raw image in the color type you want, +no matter whether the PNG was encoded with a palette, greyscale or RGBA color. +And if you encode an image, by default LodePNG will automatically choose the PNG +color type that gives good compression based on the values of colors and amount +of colors in the image. It can be configured to let you control it instead as +well, though. + +To be able to do this, LodePNG does conversions from one color mode to another. +It can convert from almost any color type to any other color type, except the +following conversions: RGB to greyscale is not supported, and converting to a +palette when the palette doesn't have a required color is not supported. This is +not supported on purpose: this is information loss which requires a color +reduction algorithm that is beyong the scope of a PNG encoder (yes, RGB to grey +is easy, but there are multiple ways if you want to give some channels more +weight). + +By default, when decoding, you get the raw image in 32-bit RGBA or 24-bit RGB +color, no matter what color type the PNG has. And by default when encoding, +LodePNG automatically picks the best color model for the output PNG, and expects +the input image to be 32-bit RGBA or 24-bit RGB. So, unless you want to control +the color format of the images yourself, you can skip this chapter. + +6.1. PNG color types +-------------------- + +A PNG image can have many color types, ranging from 1-bit color to 64-bit color, +as well as palettized color modes. After the zlib decompression and unfiltering +in the PNG image is done, the raw pixel data will have that color type and thus +a certain amount of bits per pixel. If you want the output raw image after +decoding to have another color type, a conversion is done by LodePNG. + +The PNG specification gives the following color types: + +0: greyscale, bit depths 1, 2, 4, 8, 16 +2: RGB, bit depths 8 and 16 +3: palette, bit depths 1, 2, 4 and 8 +4: greyscale with alpha, bit depths 8 and 16 +6: RGBA, bit depths 8 and 16 + +Bit depth is the amount of bits per pixel per color channel. So the total amount +of bits per pixel is: amount of channels * bitdepth. + +6.2. color conversions +---------------------- + +As explained in the sections about the encoder and decoder, you can specify +color types and bit depths in info_png and info_raw to change the default +behaviour. 
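+For illustration only (this sketch is not part of the upstream LodePNG documentation), a
+minimal example of requesting a specific raw color mode when decoding with the C++
+wrapper declared in this header; the filename "in.png" is a placeholder and error
+handling is left to the caller:
+
+  std::vector<unsigned char> png, pixels; /*png = file bytes, pixels = raw output*/
+  unsigned w, h;
+  lodepng::State state;
+  state.info_raw.colortype = LCT_GREY; /*request greyscale raw output*/
+  state.info_raw.bitdepth = 8;         /*request 8 bits per channel*/
+  unsigned error = lodepng::load_file(png, "in.png"); /*placeholder filename*/
+  if(!error) error = lodepng::decode(pixels, w, h, state, png);
+  if(error) { /*handle error; see lodepng_error_text(error)*/ }
+
+Whether a particular conversion is possible follows the rules listed further below.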
+ +If, when decoding, you want the raw image to be something else than the default, +you need to set the color type and bit depth you want in the LodePNGColorMode, +or the parameters colortype and bitdepth of the simple decoding function. + +If, when encoding, you use another color type than the default in the raw input +image, you need to specify its color type and bit depth in the LodePNGColorMode +of the raw image, or use the parameters colortype and bitdepth of the simple +encoding function. + +If, when encoding, you don't want LodePNG to choose the output PNG color type +but control it yourself, you need to set auto_convert in the encoder settings +to false, and specify the color type you want in the LodePNGInfo of the +encoder (including palette: it can generate a palette if auto_convert is true, +otherwise not). + +If the input and output color type differ (whether user chosen or auto chosen), +LodePNG will do a color conversion, which follows the rules below, and may +sometimes result in an error. + +To avoid some confusion: +-the decoder converts from PNG to raw image +-the encoder converts from raw image to PNG +-the colortype and bitdepth in LodePNGColorMode info_raw, are those of the raw image +-the colortype and bitdepth in the color field of LodePNGInfo info_png, are those of the PNG +-when encoding, the color type in LodePNGInfo is ignored if auto_convert + is enabled, it is automatically generated instead +-when decoding, the color type in LodePNGInfo is set by the decoder to that of the original + PNG image, but it can be ignored since the raw image has the color type you requested instead +-if the color type of the LodePNGColorMode and PNG image aren't the same, a conversion + between the color types is done if the color types are supported. If it is not + supported, an error is returned. If the types are the same, no conversion is done. +-even though some conversions aren't supported, LodePNG supports loading PNGs from any + colortype and saving PNGs to any colortype, sometimes it just requires preparing + the raw image correctly before encoding. +-both encoder and decoder use the same color converter. + +Non supported color conversions: +-color to greyscale: no error is thrown, but the result will look ugly because +only the red channel is taken +-anything to palette when that palette does not have that color in it: in this +case an error is thrown + +Supported color conversions: +-anything to 8-bit RGB, 8-bit RGBA, 16-bit RGB, 16-bit RGBA +-any grey or grey+alpha, to grey or grey+alpha +-anything to a palette, as long as the palette has the requested colors in it +-removing alpha channel +-higher to smaller bitdepth, and vice versa + +If you want no color conversion to be done (e.g. for speed or control): +-In the encoder, you can make it save a PNG with any color type by giving the +raw color mode and LodePNGInfo the same color mode, and setting auto_convert to +false. +-In the decoder, you can make it store the pixel data in the same color type +as the PNG has, by setting the color_convert setting to false. Settings in +info_raw are then ignored. + +The function lodepng_convert does the color conversion. It is available in the +interface but normally isn't needed since the encoder and decoder already call +it. + +6.3. padding bits +----------------- + +In the PNG file format, if a less than 8-bit per pixel color type is used and the scanlines +have a bit amount that isn't a multiple of 8, then padding bits are used so that each +scanline starts at a fresh byte. 
But that is NOT true for the LodePNG raw input and output.
+The raw input image you give to the encoder, and the raw output image you get from the decoder
+will NOT have these padding bits, e.g. in the case of a 1-bit image with a width
+of 7 pixels, the first pixel of the second scanline will be the 8th bit of the first byte,
+not the first bit of a new byte.
+
+6.4. A note about 16-bits per channel and endianness
+----------------------------------------------------
+
+LodePNG uses unsigned char arrays for 16-bit per channel colors too, just like
+for any other color format. The 16-bit values are stored in big endian (most
+significant byte first) in these arrays. This is the opposite order of the
+little endian used by x86 CPU's.
+
+LodePNG always uses big endian because the PNG file format does so internally.
+Conversions to other formats than the one PNG uses internally are not supported by
+LodePNG on purpose; there are myriads of formats, including endianness of 16-bit
+colors, the order in which you store R, G, B and A, and so on. Supporting and
+converting to/from all that is outside the scope of LodePNG.
+
+This may mean that, depending on your use case, you may want to convert the big
+endian output of LodePNG to little endian with a for loop. This is certainly not
+always needed, many applications and libraries support big endian 16-bit colors
+anyway, but it means you cannot simply cast the unsigned char* buffer to an
+unsigned short* buffer on x86 CPUs.
+
+
+7. error values
+---------------
+
+All functions in LodePNG that return an error code, return 0 if everything went
+OK, or a non-zero code if there was an error.
+
+The meaning of the LodePNG error values can be retrieved with the function
+lodepng_error_text: given the numerical error code, it returns a description
+of the error in English as a string.
+
+Check the implementation of lodepng_error_text to see the meaning of each code.
+
+
+8. chunks and PNG editing
+-------------------------
+
+If you want to add extra chunks to a PNG you encode, or use LodePNG for a PNG
+editor that should follow the rules about handling of unknown chunks, or if your
+program is able to read other types of chunks than the ones handled by LodePNG,
+then that's possible with the chunk functions of LodePNG.
+
+A PNG chunk has the following layout:
+
+4 bytes length
+4 bytes type name
+length bytes data
+4 bytes CRC
+
+8.1. iterating through chunks
+-----------------------------
+
+If you have a buffer containing the PNG image data, then the first chunk (the
+IHDR chunk) starts at byte number 8 of that buffer. The first 8 bytes are the
+signature of the PNG and are not part of a chunk. But if you start at byte 8
+then you have a chunk, and can check the following things of it.
+
+NOTE: none of these functions check for memory buffer boundaries. To avoid
+exploits, always make sure the buffer contains all the data of the chunks.
+When using lodepng_chunk_next, make sure the returned value is within the
+allocated memory.
+
+unsigned lodepng_chunk_length(const unsigned char* chunk):
+
+Get the length of the chunk's data. The total chunk length is this length + 12.
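+
+For example, a bounds-checked pass over all chunks of an in-memory PNG could look
+roughly like the sketch below (illustrative only; it assumes the complete file is
+in a buffer "png" of "pngsize" bytes and that <stdio.h> is available for printf;
+the remaining chunk functions used here are described next in this subchapter):
+
+/*sketch only: walk the chunks and print each type and data length,
+  checking buffer boundaries as advised in the NOTE above*/
+const unsigned char* chunk = png + 8; /*first chunk starts after the 8-byte signature*/
+const unsigned char* end = png + pngsize;
+while(chunk + 12 <= end && lodepng_chunk_length(chunk) <= (size_t)(end - chunk) - 12)
+{
+  char type[5];
+  lodepng_chunk_type(type, chunk); /*null-terminated 4-letter type name*/
+  printf("%s: %u data bytes\n", type, lodepng_chunk_length(chunk));
+  if(lodepng_chunk_type_equals(chunk, "IEND")) break;
+  chunk = lodepng_chunk_next_const(chunk);
+}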
+
+void lodepng_chunk_type(char type[5], const unsigned char* chunk):
+unsigned char lodepng_chunk_type_equals(const unsigned char* chunk, const char* type):
+
+Get the type of the chunk, or compare if it's a certain type.
+
+unsigned char lodepng_chunk_critical(const unsigned char* chunk):
+unsigned char lodepng_chunk_private(const unsigned char* chunk):
+unsigned char lodepng_chunk_safetocopy(const unsigned char* chunk):
+
+Check if the chunk is critical in the PNG standard (only IHDR, PLTE, IDAT and IEND are).
+Check if the chunk is private (public chunks are part of the standard, private ones not).
+Check if the chunk is safe to copy. If it's not, then, when modifying data in a critical
+chunk, unsafe to copy chunks of the old image may NOT be saved in the new one if your
+program doesn't handle that type of unknown chunk.
+
+unsigned char* lodepng_chunk_data(unsigned char* chunk):
+const unsigned char* lodepng_chunk_data_const(const unsigned char* chunk):
+
+Get a pointer to the start of the data of the chunk.
+
+unsigned lodepng_chunk_check_crc(const unsigned char* chunk):
+void lodepng_chunk_generate_crc(unsigned char* chunk):
+
+Check if the CRC is correct, or generate a correct one.
+
+unsigned char* lodepng_chunk_next(unsigned char* chunk):
+const unsigned char* lodepng_chunk_next_const(const unsigned char* chunk):
+
+Iterate to the next chunk. This works if you have a buffer with consecutive chunks. Note that these
+functions do no boundary checking of the allocated data whatsoever, so make sure there is enough
+data available in the buffer to be able to go to the next chunk.
+
+unsigned lodepng_chunk_append(unsigned char** out, size_t* outlength, const unsigned char* chunk):
+unsigned lodepng_chunk_create(unsigned char** out, size_t* outlength, unsigned length,
+                              const char* type, const unsigned char* data):
+
+These functions are used to create new chunks that are appended to the data in *out that has
+length *outlength. The append function appends an existing chunk to the new data. The create
+function creates a new chunk with the given parameters and appends it. Type is the 4-letter
+name of the chunk.
+
+8.2. chunks in info_png
+-----------------------
+
+The LodePNGInfo struct contains fields with the unknown chunks in it. It has 3
+buffers (each with size) to contain 3 types of unknown chunks:
+the ones that come before the PLTE chunk, the ones that come between the PLTE
+and the IDAT chunks, and the ones that come after the IDAT chunks.
+It's necessary to make the distinction between these 3 cases because the PNG
+standard requires keeping the ordering of unknown chunks relative to the critical
+chunks, but does not force any other ordering rules.
+
+info_png.unknown_chunks_data[0] is the chunks before PLTE
+info_png.unknown_chunks_data[1] is the chunks after PLTE, before IDAT
+info_png.unknown_chunks_data[2] is the chunks after IDAT
+
+The chunks in these 3 buffers can be iterated through and read in the same
+way described in the previous subchapter.
+
+When using the decoder to decode a PNG, you can make it store all unknown chunks
+if you set the option settings.remember_unknown_chunks to 1. By default, this
+option is off (0).
+
+The encoder will always encode unknown chunks that are stored in the info_png.
+If you need it to add a particular chunk that isn't known by LodePNG, you can
+use lodepng_chunk_append or lodepng_chunk_create to add the chunk to the data in
+info_png.unknown_chunks_data[x].
+
+Chunks that are known by LodePNG should not be added in that way. E.g.
to make
+LodePNG add a bKGD chunk, set background_defined to true and add the correct
+parameters there instead.
+
+
+9. compiler support
+-------------------
+
+No libraries other than the current standard C library are needed to compile
+LodePNG. For the C++ version, only the standard C++ library is needed on top.
+Add the files lodepng.c(pp) and lodepng.h to your project, include
+lodepng.h where needed, and your program can read/write PNG files.
+
+It is compatible with C90 and up, and C++03 and up.
+
+If performance is important, use optimization when compiling! For both the
+encoder and decoder, this makes a large difference.
+
+Make sure that LodePNG is compiled with the same compiler of the same version
+and with the same settings as the rest of the program, or the interfaces with
+std::vectors and std::strings in C++ can be incompatible.
+
+CHAR_BIT must be 8 or higher, because LodePNG uses unsigned chars for octets.
+
+*) gcc and g++
+
+LodePNG is developed in gcc so this compiler is natively supported. It gives no
+warnings with compiler options "-Wall -Wextra -pedantic -ansi", with gcc and g++
+version 4.7.1 on Linux, 32-bit and 64-bit.
+
+*) Clang
+
+Fully supported and warning-free.
+
+*) Mingw
+
+The Mingw compiler (a port of gcc for Windows) should be fully supported by
+LodePNG.
+
+*) Visual Studio and Visual C++ Express Edition
+
+LodePNG should be warning-free with warning level W4. Two warnings were disabled
+with pragmas though: warning 4244 about implicit conversions, and warning 4996
+where it wants to use a non-standard function fopen_s instead of the standard C
+fopen.
+
+Visual Studio may want "stdafx.h" files to be included in each source file and
+give an error "unexpected end of file while looking for precompiled header".
+This is not standard C++ and will not be added to the stock LodePNG. You can
+disable it for lodepng.cpp only by right clicking it, Properties, C/C++,
+Precompiled Headers, and set it to Not Using Precompiled Headers there.
+
+NOTE: Modern versions of VS should be fully supported, but old versions, e.g.
+VS6, are not guaranteed to work.
+
+*) Compilers on Macintosh
+
+LodePNG has been reported to work both with gcc and LLVM for Macintosh, both for
+C and C++.
+
+*) Other Compilers
+
+If you encounter problems on any compilers, feel free to let me know and I may
+try to fix it if the compiler is modern and standards compliant.
+
+
+10. examples
+------------
+
+This decoder example shows the most basic usage of LodePNG. More complex
+examples can be found on the LodePNG website.
+
+10.1. decoder C++ example
+-------------------------
+
+#include "lodepng.h"
+#include <iostream>
+
+int main(int argc, char *argv[])
+{
+  const char* filename = argc > 1 ? argv[1] : "test.png";
+
+  //load and decode
+  std::vector<unsigned char> image;
+  unsigned width, height;
+  unsigned error = lodepng::decode(image, width, height, filename);
+
+  //if there's an error, display it
+  if(error) std::cout << "decoder error " << error << ": " << lodepng_error_text(error) << std::endl;
+
+  //the pixels are now in the vector "image", 4 bytes per pixel, ordered RGBARGBA..., use it as texture, draw it, ...
+}
+
+10.2. decoder C example
+-----------------------
+
+#include "lodepng.h"
+#include <stdio.h>
+#include <stdlib.h>
+
+int main(int argc, char *argv[])
+{
+  unsigned error;
+  unsigned char* image;
+  unsigned width, height;
+  const char* filename = argc > 1 ?
argv[1] : "test.png";
+
+  error = lodepng_decode32_file(&image, &width, &height, filename);
+
+  if(error) printf("decoder error %u: %s\n", error, lodepng_error_text(error));
+
+  /*use image here*/
+
+  free(image);
+  return 0;
+}
+
+11. state settings reference
+----------------------------
+
+A quick reference of some settings to set on the LodePNGState:
+
+For decoding:
+
+state.decoder.zlibsettings.ignore_adler32: ignore ADLER32 checksums
+state.decoder.zlibsettings.custom_...: use custom inflate function
+state.decoder.ignore_crc: ignore CRC checksums
+state.decoder.color_convert: convert internal PNG color to chosen one
+state.decoder.read_text_chunks: whether to read in text metadata chunks
+state.decoder.remember_unknown_chunks: whether to read in unknown chunks
+state.info_raw.colortype: desired color type for decoded image
+state.info_raw.bitdepth: desired bit depth for decoded image
+state.info_raw....: more color settings, see struct LodePNGColorMode
+state.info_png....: no settings for decoder but output, see struct LodePNGInfo
+
+For encoding:
+
+state.encoder.zlibsettings.btype: disable compression by setting it to 0
+state.encoder.zlibsettings.use_lz77: use LZ77 in compression
+state.encoder.zlibsettings.windowsize: tweak LZ77 windowsize
+state.encoder.zlibsettings.minmatch: tweak min LZ77 length to match
+state.encoder.zlibsettings.nicematch: tweak LZ77 match where to stop searching
+state.encoder.zlibsettings.lazymatching: try one more LZ77 matching
+state.encoder.zlibsettings.custom_...: use custom deflate function
+state.encoder.auto_convert: choose optimal PNG color type, if 0 uses info_png
+state.encoder.filter_palette_zero: PNG filter strategy for palette
+state.encoder.filter_strategy: PNG filter strategy to encode with
+state.encoder.force_palette: add palette even if not encoding to one
+state.encoder.add_id: add LodePNG identifier and version as a text chunk
+state.encoder.text_compression: use compressed text chunks for metadata
+state.info_raw.colortype: color type of raw input image you provide
+state.info_raw.bitdepth: bit depth of raw input image you provide
+state.info_raw: more color settings, see struct LodePNGColorMode
+state.info_png.color.colortype: desired color type if auto_convert is false
+state.info_png.color.bitdepth: desired bit depth if auto_convert is false
+state.info_png.color....: more color settings, see struct LodePNGColorMode
+state.info_png....: more PNG related settings, see struct LodePNGInfo
+
+
+12. changes
+-----------
+
+The version number of LodePNG is the date of the change given in the format
+yyyymmdd.
+
+Some changes aren't backwards compatible. Those are indicated with a (!)
+symbol.
+
+*) 17 sep 2017: fix memory leak for some encoder input error cases
+*) 27 nov 2016: grey+alpha auto color model detection bugfix
+*) 18 apr 2016: Changed qsort to custom stable sort (for platforms w/o qsort).
+*) 09 apr 2016: Fixed colorkey usage detection, and better file loading (within
+   the limits of pure C90).
+*) 08 dec 2015: Made load_file function return error if file can't be opened.
+*) 24 okt 2015: Bugfix with decoding to palette output.
+*) 18 apr 2015: Boundary PM instead of just package-merge for faster encoding.
+*) 23 aug 2014: Reduced needless memory usage of decoder.
+*) 28 jun 2014: Removed fix_png setting, always support palette OOB for
+   simplicity. Made ColorProfile public.
+*) 09 jun 2014: Faster encoder by fixing hash bug and more zeros optimization.
+*) 22 dec 2013: Power of two windowsize required for optimization.
+*) 15 apr 2013: Fixed bug with LAC_ALPHA and color key. +*) 25 mar 2013: Added an optional feature to ignore some PNG errors (fix_png). +*) 11 mar 2013 (!): Bugfix with custom free. Changed from "my" to "lodepng_" + prefix for the custom allocators and made it possible with a new #define to + use custom ones in your project without needing to change lodepng's code. +*) 28 jan 2013: Bugfix with color key. +*) 27 okt 2012: Tweaks in text chunk keyword length error handling. +*) 8 okt 2012 (!): Added new filter strategy (entropy) and new auto color mode. + (no palette). Better deflate tree encoding. New compression tweak settings. + Faster color conversions while decoding. Some internal cleanups. +*) 23 sep 2012: Reduced warnings in Visual Studio a little bit. +*) 1 sep 2012 (!): Removed #define's for giving custom (de)compression functions + and made it work with function pointers instead. +*) 23 jun 2012: Added more filter strategies. Made it easier to use custom alloc + and free functions and toggle #defines from compiler flags. Small fixes. +*) 6 may 2012 (!): Made plugging in custom zlib/deflate functions more flexible. +*) 22 apr 2012 (!): Made interface more consistent, renaming a lot. Removed + redundant C++ codec classes. Reduced amount of structs. Everything changed, + but it is cleaner now imho and functionality remains the same. Also fixed + several bugs and shrunk the implementation code. Made new samples. +*) 6 nov 2011 (!): By default, the encoder now automatically chooses the best + PNG color model and bit depth, based on the amount and type of colors of the + raw image. For this, autoLeaveOutAlphaChannel replaced by auto_choose_color. +*) 9 okt 2011: simpler hash chain implementation for the encoder. +*) 8 sep 2011: lz77 encoder lazy matching instead of greedy matching. +*) 23 aug 2011: tweaked the zlib compression parameters after benchmarking. + A bug with the PNG filtertype heuristic was fixed, so that it chooses much + better ones (it's quite significant). A setting to do an experimental, slow, + brute force search for PNG filter types is added. +*) 17 aug 2011 (!): changed some C zlib related function names. +*) 16 aug 2011: made the code less wide (max 120 characters per line). +*) 17 apr 2011: code cleanup. Bugfixes. Convert low to 16-bit per sample colors. +*) 21 feb 2011: fixed compiling for C90. Fixed compiling with sections disabled. +*) 11 dec 2010: encoding is made faster, based on suggestion by Peter Eastman + to optimize long sequences of zeros. +*) 13 nov 2010: added LodePNG_InfoColor_hasPaletteAlpha and + LodePNG_InfoColor_canHaveAlpha functions for convenience. +*) 7 nov 2010: added LodePNG_error_text function to get error code description. +*) 30 okt 2010: made decoding slightly faster +*) 26 okt 2010: (!) changed some C function and struct names (more consistent). + Reorganized the documentation and the declaration order in the header. +*) 08 aug 2010: only changed some comments and external samples. +*) 05 jul 2010: fixed bug thanks to warnings in the new gcc version. +*) 14 mar 2010: fixed bug where too much memory was allocated for char buffers. +*) 02 sep 2008: fixed bug where it could create empty tree that linux apps could + read by ignoring the problem but windows apps couldn't. +*) 06 jun 2008: added more error checks for out of memory cases. +*) 26 apr 2008: added a few more checks here and there to ensure more safety. 
+*) 06 mar 2008: crash with encoding of strings fixed +*) 02 feb 2008: support for international text chunks added (iTXt) +*) 23 jan 2008: small cleanups, and #defines to divide code in sections +*) 20 jan 2008: support for unknown chunks allowing using LodePNG for an editor. +*) 18 jan 2008: support for tIME and pHYs chunks added to encoder and decoder. +*) 17 jan 2008: ability to encode and decode compressed zTXt chunks added + Also various fixes, such as in the deflate and the padding bits code. +*) 13 jan 2008: Added ability to encode Adam7-interlaced images. Improved + filtering code of encoder. +*) 07 jan 2008: (!) changed LodePNG to use ISO C90 instead of C++. A + C++ wrapper around this provides an interface almost identical to before. + Having LodePNG be pure ISO C90 makes it more portable. The C and C++ code + are together in these files but it works both for C and C++ compilers. +*) 29 dec 2007: (!) changed most integer types to unsigned int + other tweaks +*) 30 aug 2007: bug fixed which makes this Borland C++ compatible +*) 09 aug 2007: some VS2005 warnings removed again +*) 21 jul 2007: deflate code placed in new namespace separate from zlib code +*) 08 jun 2007: fixed bug with 2- and 4-bit color, and small interlaced images +*) 04 jun 2007: improved support for Visual Studio 2005: crash with accessing + invalid std::vector element [0] fixed, and level 3 and 4 warnings removed +*) 02 jun 2007: made the encoder add a tag with version by default +*) 27 may 2007: zlib and png code separated (but still in the same file), + simple encoder/decoder functions added for more simple usage cases +*) 19 may 2007: minor fixes, some code cleaning, new error added (error 69), + moved some examples from here to lodepng_examples.cpp +*) 12 may 2007: palette decoding bug fixed +*) 24 apr 2007: changed the license from BSD to the zlib license +*) 11 mar 2007: very simple addition: ability to encode bKGD chunks. +*) 04 mar 2007: (!) tEXt chunk related fixes, and support for encoding + palettized PNG images. Plus little interface change with palette and texts. +*) 03 mar 2007: Made it encode dynamic Huffman shorter with repeat codes. + Fixed a bug where the end code of a block had length 0 in the Huffman tree. +*) 26 feb 2007: Huffman compression with dynamic trees (BTYPE 2) now implemented + and supported by the encoder, resulting in smaller PNGs at the output. +*) 27 jan 2007: Made the Adler-32 test faster so that a timewaste is gone. +*) 24 jan 2007: gave encoder an error interface. Added color conversion from any + greyscale type to 8-bit greyscale with or without alpha. +*) 21 jan 2007: (!) Totally changed the interface. It allows more color types + to convert to and is more uniform. See the manual for how it works now. +*) 07 jan 2007: Some cleanup & fixes, and a few changes over the last days: + encode/decode custom tEXt chunks, separate classes for zlib & deflate, and + at last made the decoder give errors for incorrect Adler32 or Crc. +*) 01 jan 2007: Fixed bug with encoding PNGs with less than 8 bits per channel. +*) 29 dec 2006: Added support for encoding images without alpha channel, and + cleaned out code as well as making certain parts faster. +*) 28 dec 2006: Added "Settings" to the encoder. +*) 26 dec 2006: The encoder now does LZ77 encoding and produces much smaller files now. + Removed some code duplication in the decoder. Fixed little bug in an example. +*) 09 dec 2006: (!) Placed output parameters of public functions as first parameter. 
+ Fixed a bug of the decoder with 16-bit per color. +*) 15 okt 2006: Changed documentation structure +*) 09 okt 2006: Encoder class added. It encodes a valid PNG image from the + given image buffer, however for now it's not compressed. +*) 08 sep 2006: (!) Changed to interface with a Decoder class +*) 30 jul 2006: (!) LodePNG_InfoPng , width and height are now retrieved in different + way. Renamed decodePNG to decodePNGGeneric. +*) 29 jul 2006: (!) Changed the interface: image info is now returned as a + struct of type LodePNG::LodePNG_Info, instead of a vector, which was a bit clumsy. +*) 28 jul 2006: Cleaned the code and added new error checks. + Corrected terminology "deflate" into "inflate". +*) 23 jun 2006: Added SDL example in the documentation in the header, this + example allows easy debugging by displaying the PNG and its transparency. +*) 22 jun 2006: (!) Changed way to obtain error value. Added + loadFile function for convenience. Made decodePNG32 faster. +*) 21 jun 2006: (!) Changed type of info vector to unsigned. + Changed position of palette in info vector. Fixed an important bug that + happened on PNGs with an uncompressed block. +*) 16 jun 2006: Internally changed unsigned into unsigned where + needed, and performed some optimizations. +*) 07 jun 2006: (!) Renamed functions to decodePNG and placed them + in LodePNG namespace. Changed the order of the parameters. Rewrote the + documentation in the header. Renamed files to lodepng.cpp and lodepng.h +*) 22 apr 2006: Optimized and improved some code +*) 07 sep 2005: (!) Changed to std::vector interface +*) 12 aug 2005: Initial release (C++, decoder only) + + +13. contact information +----------------------- + +Feel free to contact me with suggestions, problems, comments, ... concerning +LodePNG. If you encounter a PNG image that doesn't work properly with this +decoder, feel free to send it and I'll use it to find and fix the problem. + +My email address is (puzzle the account and domain together with an @ symbol): +Domain: gmail dot com. +Account: lode dot vandevenne. + + +Copyright (c) 2005-2017 Lode Vandevenne +*/ diff --git a/batch_run.py b/batch_run.py new file mode 100644 index 0000000..d7eeaeb --- /dev/null +++ b/batch_run.py @@ -0,0 +1,22 @@ +import os + +n = 31 #person id +gpu = 0 +audio = '03Fsi1831' + +## finetuning on a target person +cmd1='cd Data/; python extract_frame1.py %d.mp4' % n +os.system(cmd1) + +cmd2='cd Deep3DFaceReconstruction/; CUDA_VISIBLE_DEVICES=%d python demo_19news.py ../Data/%d' % (gpu,n) +os.system(cmd2) + +cmd3='cd Audio/code; python train_19news_1.py %d %d' % (n,gpu) +os.system(cmd3) + +cmd4='cd render-to-video; python train_19news_1.py %d %d' % (n,gpu) +os.system(cmd4) + +## test +cmd5='cd Audio/code; python test_personalized.py %s %d %d' % (audio,n,gpu) +os.system(cmd5) \ No newline at end of file diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..010d1d9 --- /dev/null +++ b/readme.md @@ -0,0 +1,68 @@ +# Audio-driven Talking Face Video Generation with Natural Head Pose + +We provide PyTorch implementations for our arxiv paper "Audio-driven Talking Face Video Generation with Natural Head Pose"(http://arxiv.org/abs/2002.10137). + +Note that this code is protected under patent. It is for research purposes only at your university (research institution) only. If you are interested in business purposes/for-profit use, please contact Prof.Liu (the corresponding author, email: liuyongjin@tsinghua.edu.cn). 
+
+## Prerequisites
+- Linux or macOS
+- NVIDIA GPU
+- Python 3
+- MATLAB
+
+## Getting Started
+### Installation
+- You can create a virtual env, and install all the dependencies by
+```bash
+pip install -r requirements.txt
+```
+
+### Download pre-trained models
+- Including pre-trained general models and models needed for face reconstruction, identity feature extraction, etc.
+- Download from xxx and copy to the corresponding subfolders (Audio, Deep3DFaceReconstruction, render-to-video).
+
+### Fine-tune on a target person's short video
+- 1. Prepare a talking video of a single person that is 25 fps and longer than 12 seconds. Rename the video to [person_id].mp4 (e.g. 1.mp4) and copy it to the Data subfolder. You can convert a video to 25 fps by
+```bash
+ffmpeg -i xxx.mp4 -r 25 xxx1.mp4
+```
+- 2. Extract frames and landmarks by
+```bash
+cd Data/
+python extract_frame1.py [person_id].mp4
+```
+- 3. Conduct 3D face reconstruction. First compile the code in `Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/kernels` to .so, following its [readme](Deep3DFaceReconstruction/tf_mesh_renderer/README.md), and modify line 28 in [rasterize_triangles.py](Deep3DFaceReconstruction/tf_mesh_renderer/mesh_renderer/rasterize_triangles.py) to point to your directory. Then run
+```bash
+cd Deep3DFaceReconstruction/
+CUDA_VISIBLE_DEVICES=0 python demo_19news.py ../Data/[person_id]
+```
+This process takes about 2 minutes on a Titan Xp.
+- 4. Fine-tune the audio network. First modify line 28 in [rasterize_triangles.py](Audio/code/mesh_renderer/rasterize_triangles.py) to point to your directory. Then run
+```bash
+cd Audio/code/
+python train_19news_1.py [person_id] [gpu_id]
+```
+The saved models are in `Audio/model/atcnet_pose0_con3/[person_id]`.
+This process takes about 5 minutes on a Titan Xp.
+- 5. Fine-tune the GAN network. Run
+```bash
+cd render-to-video/
+python train_19news_1.py [person_id] [gpu_id]
+```
+The saved models are in `render-to-video/checkpoints/memory_seq_p2p/[person_id]`.
+This process takes about 40 minutes on a Titan Xp.
+
+
+### Test on a target person
+Place the audio file (.wav or .mp3) for testing under `Audio/audio/`.
+Run
+```bash
+cd Audio/code/
+python test_personalized.py [audio] [person_id] [gpu_id]
+```
+This program will print 'saved to xxx.mov' if the videos are successfully generated.
+It will output two .mov files: one with the face only (_full9.mov) and one with the background (_transbigbg.mov).
+
+## Acknowledgments
+The face reconstruction code is from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction), the arcface code is from [insightface](https://github.com/deepinsight/insightface), the GAN code is based on [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
\ No newline at end of file diff --git a/render-to-video/arcface/face_image.py b/render-to-video/arcface/face_image.py new file mode 100644 index 0000000..f69a6de --- /dev/null +++ b/render-to-video/arcface/face_image.py @@ -0,0 +1,269 @@ + +from easydict import EasyDict as edict +import os +import json +import numpy as np + + +def load_property(data_dir): + prop = edict() + for line in open(os.path.join(data_dir, 'property')): + vec = line.strip().split(',') + assert len(vec)==3 + prop.num_classes = int(vec[0]) + prop.image_size = [int(vec[1]), int(vec[2])] + return prop + + + +def get_dataset_webface(input_dir): + clean_list_file = input_dir+"_clean_list.txt" + ret = [] + for line in open(clean_list_file, 'r'): + vec = line.strip().split() + assert len(vec)==2 + fimage = edict() + fimage.id = vec[0].replace("\\", '/') + fimage.classname = vec[1] + fimage.image_path = os.path.join(input_dir, fimage.id) + ret.append(fimage) + return ret + +def get_dataset_celeb(input_dir): + clean_list_file = input_dir+"_clean_list.txt" + ret = [] + dir2label = {} + for line in open(clean_list_file, 'r'): + line = line.strip() + if not line.startswith('./m.'): + continue + line = line[2:] + vec = line.split('/') + assert len(vec)==2 + if vec[0] in dir2label: + label = dir2label[vec[0]] + else: + label = len(dir2label) + dir2label[vec[0]] = label + + fimage = edict() + fimage.id = line + fimage.classname = str(label) + fimage.image_path = os.path.join(input_dir, fimage.id) + ret.append(fimage) + return ret + +def _get_dataset_celeb(input_dir): + list_file = input_dir+"_original_list.txt" + ret = [] + for line in open(list_file, 'r'): + vec = line.strip().split() + assert len(vec)==2 + fimage = edict() + fimage.id = vec[0] + fimage.classname = vec[1] + fimage.image_path = os.path.join(input_dir, fimage.id) + ret.append(fimage) + return ret + +def get_dataset_facescrub(input_dir): + ret = [] + label = 0 + person_names = [] + for person_name in os.listdir(input_dir): + person_names.append(person_name) + person_names = sorted(person_names) + for person_name in person_names: + subdir = os.path.join(input_dir, person_name) + if not os.path.isdir(subdir): + continue + for _img in os.listdir(subdir): + fimage = edict() + fimage.id = os.path.join(person_name, _img) + fimage.classname = str(label) + fimage.image_path = os.path.join(subdir, _img) + fimage.landmark = None + fimage.bbox = None + ret.append(fimage) + label += 1 + return ret + +def get_dataset_megaface(input_dir): + ret = [] + label = 0 + for prefixdir in os.listdir(input_dir): + _prefixdir = os.path.join(input_dir, prefixdir) + for subdir in os.listdir(_prefixdir): + _subdir = os.path.join(_prefixdir, subdir) + if not os.path.isdir(_subdir): + continue + for img in os.listdir(_subdir): + if not img.endswith('.jpg.jpg') and img.endswith('.jpg'): + fimage = edict() + fimage.id = os.path.join(prefixdir, subdir, img) + fimage.classname = str(label) + fimage.image_path = os.path.join(_subdir, img) + json_file = fimage.image_path+".json" + data = None + fimage.bbox = None + fimage.landmark = None + if os.path.exists(json_file): + with open(json_file, 'r') as f: + data = f.read() + data = json.loads(data) + assert data is not None + if 'bounding_box' in data: + fimage.bbox = np.zeros( (4,), dtype=np.float32 ) + bb = data['bounding_box'] + fimage.bbox[0] = bb['x'] + fimage.bbox[1] = bb['y'] + fimage.bbox[2] = bb['x']+bb['width'] + fimage.bbox[3] = bb['y']+bb['height'] + #print('bb') + if 'landmarks' in data: + landmarks = data['landmarks'] + if '1' in landmarks 
and '0' in landmarks and '2' in landmarks: + fimage.landmark = np.zeros( (3,2), dtype=np.float32 ) + fimage.landmark[0][0] = landmarks['1']['x'] + fimage.landmark[0][1] = landmarks['1']['y'] + fimage.landmark[1][0] = landmarks['0']['x'] + fimage.landmark[1][1] = landmarks['0']['y'] + fimage.landmark[2][0] = landmarks['2']['x'] + fimage.landmark[2][1] = landmarks['2']['y'] + #print('lm') + + ret.append(fimage) + label+=1 + return ret + +def get_dataset_fgnet(input_dir): + ret = [] + label = 0 + for subdir in os.listdir(input_dir): + _subdir = os.path.join(input_dir, subdir) + if not os.path.isdir(_subdir): + continue + for img in os.listdir(_subdir): + if img.endswith('.JPG'): + fimage = edict() + fimage.id = os.path.join(_subdir, img) + fimage.classname = str(label) + fimage.image_path = os.path.join(_subdir, img) + json_file = fimage.image_path+".json" + data = None + fimage.bbox = None + fimage.landmark = None + if os.path.exists(json_file): + with open(json_file, 'r') as f: + data = f.read() + data = json.loads(data) + assert data is not None + if 'bounding_box' in data: + fimage.bbox = np.zeros( (4,), dtype=np.float32 ) + bb = data['bounding_box'] + fimage.bbox[0] = bb['x'] + fimage.bbox[1] = bb['y'] + fimage.bbox[2] = bb['x']+bb['width'] + fimage.bbox[3] = bb['y']+bb['height'] + #print('bb') + if 'landmarks' in data: + landmarks = data['landmarks'] + if '1' in landmarks and '0' in landmarks and '2' in landmarks: + fimage.landmark = np.zeros( (3,2), dtype=np.float32 ) + fimage.landmark[0][0] = landmarks['1']['x'] + fimage.landmark[0][1] = landmarks['1']['y'] + fimage.landmark[1][0] = landmarks['0']['x'] + fimage.landmark[1][1] = landmarks['0']['y'] + fimage.landmark[2][0] = landmarks['2']['x'] + fimage.landmark[2][1] = landmarks['2']['y'] + #print('lm') + + #fimage.landmark = None + ret.append(fimage) + label+=1 + return ret + +def get_dataset_ytf(input_dir): + ret = [] + label = 0 + person_names = [] + for person_name in os.listdir(input_dir): + person_names.append(person_name) + person_names = sorted(person_names) + for person_name in person_names: + _subdir = os.path.join(input_dir, person_name) + if not os.path.isdir(_subdir): + continue + for _subdir2 in os.listdir(_subdir): + _subdir2 = os.path.join(_subdir, _subdir2) + if not os.path.isdir(_subdir2): + continue + _ret = [] + for img in os.listdir(_subdir2): + fimage = edict() + fimage.id = os.path.join(_subdir2, img) + fimage.classname = str(label) + fimage.image_path = os.path.join(_subdir2, img) + fimage.bbox = None + fimage.landmark = None + _ret.append(fimage) + ret += _ret + label+=1 + return ret + +def get_dataset_clfw(input_dir): + ret = [] + label = 0 + for img in os.listdir(input_dir): + fimage = edict() + fimage.id = img + fimage.classname = str(0) + fimage.image_path = os.path.join(input_dir, img) + fimage.bbox = None + fimage.landmark = None + ret.append(fimage) + return ret + +def get_dataset_common(input_dir, min_images = 1): + ret = [] + label = 0 + person_names = [] + for person_name in os.listdir(input_dir): + person_names.append(person_name) + person_names = sorted(person_names) + for person_name in person_names: + _subdir = os.path.join(input_dir, person_name) + if not os.path.isdir(_subdir): + continue + _ret = [] + for img in os.listdir(_subdir): + fimage = edict() + fimage.id = os.path.join(person_name, img) + fimage.classname = str(label) + fimage.image_path = os.path.join(_subdir, img) + fimage.bbox = None + fimage.landmark = None + _ret.append(fimage) + if len(_ret)>=min_images: + ret += _ret + label+=1 
+ return ret + +def get_dataset(name, input_dir): + if name=='webface' or name=='lfw' or name=='vgg': + return get_dataset_common(input_dir) + if name=='celeb': + return get_dataset_celeb(input_dir) + if name=='facescrub': + return get_dataset_facescrub(input_dir) + if name=='megaface': + return get_dataset_megaface(input_dir) + if name=='fgnet': + return get_dataset_fgnet(input_dir) + if name=='ytf': + return get_dataset_ytf(input_dir) + if name=='clfw': + return get_dataset_clfw(input_dir) + return None + + diff --git a/render-to-video/arcface/face_model.py b/render-to-video/arcface/face_model.py new file mode 100644 index 0000000..593fe9f --- /dev/null +++ b/render-to-video/arcface/face_model.py @@ -0,0 +1,90 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from scipy import misc +import sys +import os +import argparse +#import tensorflow as tf +import numpy as np +import mxnet as mx +import random +import cv2 +import sklearn +from sklearn.decomposition import PCA +from time import sleep +from easydict import EasyDict as edict +from mtcnn_detector import MtcnnDetector +import face_image +import face_preprocess + + +def do_flip(data): + for idx in range(data.shape[0]): + data[idx,:,:] = np.fliplr(data[idx,:,:]) + +def get_model(ctx, image_size, model_str, layer): + _vec = model_str.split(',') + assert len(_vec)==2 + prefix = _vec[0] + epoch = int(_vec[1]) + print('loading',prefix, epoch) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + all_layers = sym.get_internals() + sym = all_layers[layer+'_output'] + model = mx.mod.Module(symbol=sym, context=ctx, label_names = None) + #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))]) + model.set_params(arg_params, aux_params) + return model + +class FaceModel: + def __init__(self, args): + self.args = args + ctx = mx.gpu(args.gpu) + _vec = args.image_size.split(',') + assert len(_vec)==2 + image_size = (int(_vec[0]), int(_vec[1])) + self.model = None + if len(args.model)>0: + self.model = get_model(ctx, image_size, args.model, 'fc1') + + self.threshold = args.threshold + self.det_minsize = 50 + self.det_threshold = [0.6,0.7,0.8] + #self.det_factor = 0.9 + self.image_size = image_size + mtcnn_path = os.path.join(os.path.dirname(__file__), 'mtcnn-model') + if args.det==0: + detector = MtcnnDetector(model_folder=mtcnn_path, ctx=ctx, num_worker=1, accurate_landmark = True, threshold=self.det_threshold) + else: + detector = MtcnnDetector(model_folder=mtcnn_path, ctx=ctx, num_worker=1, accurate_landmark = True, threshold=[0.0,0.0,0.2]) + self.detector = detector + + + def get_input(self, face_img): + ret = self.detector.detect_face(face_img, det_type = self.args.det) + if ret is None: + return None + bbox, points = ret + if bbox.shape[0]==0: + return None + bbox = bbox[0,0:4] + points = points[0,:].reshape((2,5)).T + #print(bbox) + #print(points) + nimg = face_preprocess.preprocess(face_img, bbox, points, image_size='112,112') + nimg = cv2.cvtColor(nimg, cv2.COLOR_BGR2RGB) + aligned = np.transpose(nimg, (2,0,1)) + return aligned + + def get_feature(self, aligned): + input_blob = np.expand_dims(aligned, axis=0) + data = mx.nd.array(input_blob) + db = mx.io.DataBatch(data=(data,)) + self.model.forward(db, is_train=False) + embedding = self.model.get_outputs()[0].asnumpy() + embedding = 
sklearn.preprocessing.normalize(embedding).flatten() + return embedding + diff --git a/render-to-video/arcface/face_preprocess.py b/render-to-video/arcface/face_preprocess.py new file mode 100644 index 0000000..0b59828 --- /dev/null +++ b/render-to-video/arcface/face_preprocess.py @@ -0,0 +1,113 @@ + +import cv2 +import numpy as np +from skimage import transform as trans + +def parse_lst_line(line): + vec = line.strip().split("\t") + assert len(vec)>=3 + aligned = int(vec[0]) + image_path = vec[1] + label = int(vec[2]) + bbox = None + landmark = None + #print(vec) + if len(vec)>3: + bbox = np.zeros( (4,), dtype=np.int32) + for i in xrange(3,7): + bbox[i-3] = int(vec[i]) + landmark = None + if len(vec)>7: + _l = [] + for i in xrange(7,17): + _l.append(float(vec[i])) + landmark = np.array(_l).reshape( (2,5) ).T + #print(aligned) + return image_path, label, bbox, landmark, aligned + + + + +def read_image(img_path, **kwargs): + mode = kwargs.get('mode', 'rgb') + layout = kwargs.get('layout', 'HWC') + if mode=='gray': + img = cv2.imread(img_path, cv2.CV_LOAD_IMAGE_GRAYSCALE) + else: + img = cv2.imread(img_path, cv2.CV_LOAD_IMAGE_COLOR) + if mode=='rgb': + #print('to rgb') + img = img[...,::-1] + if layout=='CHW': + img = np.transpose(img, (2,0,1)) + return img + + +def preprocess(img, bbox=None, landmark=None, **kwargs): + if isinstance(img, str): + img = read_image(img, **kwargs) + M = None + image_size = [] + str_image_size = kwargs.get('image_size', '') + if len(str_image_size)>0: + image_size = [int(x) for x in str_image_size.split(',')] + if len(image_size)==1: + image_size = [image_size[0], image_size[0]] + assert len(image_size)==2 + assert image_size[0]==112 + assert image_size[0]==112 or image_size[1]==96 + if landmark is not None: + assert len(image_size)==2 + src = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041] ], dtype=np.float32 ) + if image_size[1]==112: + src[:,0] += 8.0 + dst = landmark.astype(np.float32) + + tform = trans.SimilarityTransform() + tform.estimate(dst, src) + M = tform.params[0:2,:] + #M = cv2.estimateRigidTransform( dst.reshape(1,5,2), src.reshape(1,5,2), False) + + if M is None: + if bbox is None: #use center crop + det = np.zeros(4, dtype=np.int32) + det[0] = int(img.shape[1]*0.0625) + det[1] = int(img.shape[0]*0.0625) + det[2] = img.shape[1] - det[0] + det[3] = img.shape[0] - det[1] + else: + det = bbox + margin = kwargs.get('margin', 44) + bb = np.zeros(4, dtype=np.int32) + bb[0] = np.maximum(det[0]-margin/2, 0) + bb[1] = np.maximum(det[1]-margin/2, 0) + bb[2] = np.minimum(det[2]+margin/2, img.shape[1]) + bb[3] = np.minimum(det[3]+margin/2, img.shape[0]) + ret = img[bb[1]:bb[3],bb[0]:bb[2],:] + if len(image_size)>0: + ret = cv2.resize(ret, (image_size[1], image_size[0])) + return ret + else: #do align using landmark + assert len(image_size)==2 + + #src = src[0:3,:] + #dst = dst[0:3,:] + + + #print(src.shape, dst.shape) + #print(src) + #print(dst) + #print(M) + warped = cv2.warpAffine(img,M,(image_size[1],image_size[0]), borderValue = 0.0) + + #tform3 = trans.ProjectiveTransform() + #tform3.estimate(src, dst) + #warped = trans.warp(img, tform3, output_shape=_shape) + return warped + + diff --git a/render-to-video/arcface/helper.py b/render-to-video/arcface/helper.py new file mode 100644 index 0000000..b82c4b7 --- /dev/null +++ b/render-to-video/arcface/helper.py @@ -0,0 +1,168 @@ +# coding: utf-8 +# YuanYang +import math +import cv2 +import numpy as np + + +def nms(boxes, 
overlap_threshold, mode='Union'): + """ + non max suppression + + Parameters: + ---------- + box: numpy array n x 5 + input bbox array + overlap_threshold: float number + threshold of overlap + mode: float number + how to compute overlap ratio, 'Union' or 'Min' + Returns: + ------- + index array of the selected bbox + """ + # if there are no boxes, return an empty list + if len(boxes) == 0: + return [] + + # if the bounding boxes integers, convert them to floats + if boxes.dtype.kind == "i": + boxes = boxes.astype("float") + + # initialize the list of picked indexes + pick = [] + + # grab the coordinates of the bounding boxes + x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + idxs = np.argsort(score) + + # keep looping while some indexes still remain in the indexes list + while len(idxs) > 0: + # grab the last index in the indexes list and add the index value to the list of picked indexes + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + + # compute the width and height of the bounding box + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + + inter = w * h + if mode == 'Min': + overlap = inter / np.minimum(area[i], area[idxs[:last]]) + else: + overlap = inter / (area[i] + area[idxs[:last]] - inter) + + # delete all indexes from the index list that have + idxs = np.delete(idxs, np.concatenate(([last], + np.where(overlap > overlap_threshold)[0]))) + + return pick + +def adjust_input(in_data): + """ + adjust the input from (h, w, c) to ( 1, c, h, w) for network input + + Parameters: + ---------- + in_data: numpy array of shape (h, w, c) + input data + Returns: + ------- + out_data: numpy array of shape (1, c, h, w) + reshaped array + """ + if in_data.dtype is not np.dtype('float32'): + out_data = in_data.astype(np.float32) + else: + out_data = in_data + + out_data = out_data.transpose((2,0,1)) + out_data = np.expand_dims(out_data, 0) + out_data = (out_data - 127.5)*0.0078125 + return out_data + +def generate_bbox(map, reg, scale, threshold): + """ + generate bbox from feature map + Parameters: + ---------- + map: numpy array , n x m x 1 + detect score for each position + reg: numpy array , n x m x 4 + bbox + scale: float number + scale of this detection + threshold: float number + detect threshold + Returns: + ------- + bbox array + """ + stride = 2 + cellsize = 12 + + t_index = np.where(map>threshold) + + # find nothing + if t_index[0].size == 0: + return np.array([]) + + dx1, dy1, dx2, dy2 = [reg[0, i, t_index[0], t_index[1]] for i in range(4)] + + reg = np.array([dx1, dy1, dx2, dy2]) + score = map[t_index[0], t_index[1]] + boundingbox = np.vstack([np.round((stride*t_index[1]+1)/scale), + np.round((stride*t_index[0]+1)/scale), + np.round((stride*t_index[1]+1+cellsize)/scale), + np.round((stride*t_index[0]+1+cellsize)/scale), + score, + reg]) + + return boundingbox.T + + +def detect_first_stage(img, net, scale, threshold): + """ + run PNet for first stage + + Parameters: + ---------- + img: numpy array, bgr order + input image + scale: float number + how much should the input image scale + net: PNet + worker + Returns: + ------- + total_boxes : bboxes + """ + height, width, _ = img.shape + hs = int(math.ceil(height * scale)) + ws = int(math.ceil(width * scale)) + + im_data = cv2.resize(img, (ws,hs)) + + # adjust for the network input + 
input_buf = adjust_input(im_data) + output = net.predict(input_buf) + boxes = generate_bbox(output[1][0,1,:,:], output[0], scale, threshold) + + if boxes.size == 0: + return None + + # nms + pick = nms(boxes[:,0:5], 0.5, mode='Union') + boxes = boxes[pick] + return boxes + +def detect_first_stage_warpper( args ): + return detect_first_stage(*args) diff --git a/render-to-video/arcface/mtcnn-model/det1-0001.params b/render-to-video/arcface/mtcnn-model/det1-0001.params new file mode 100644 index 0000000..e4b04aa Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det1-0001.params differ diff --git a/render-to-video/arcface/mtcnn-model/det1-symbol.json b/render-to-video/arcface/mtcnn-model/det1-symbol.json new file mode 100644 index 0000000..bd9b772 --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det1-symbol.json @@ -0,0 +1,266 @@ +{ + "nodes": [ + { + "op": "null", + "param": {}, + "name": "data", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "10", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1", + "inputs": [[0, 0], [1, 0], [2, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1", + "inputs": [[3, 0], [4, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(2,2)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1", + "inputs": [[5, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "16", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2", + "inputs": [[6, 0], [7, 0], [8, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2", + "inputs": [[9, 0], [10, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "32", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3", + "inputs": [[11, 0], [12, 0], [13, 0]], + 
"backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3", + "inputs": [[14, 0], [15, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(1,1)", + "no_bias": "False", + "num_filter": "4", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv4_2", + "inputs": [[16, 0], [17, 0], [18, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(1,1)", + "no_bias": "False", + "num_filter": "2", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv4_1", + "inputs": [[16, 0], [20, 0], [21, 0]], + "backward_source_id": -1 + }, + { + "op": "SoftmaxActivation", + "param": {"mode": "channel"}, + "name": "prob1", + "inputs": [[22, 0]], + "backward_source_id": -1 + } + ], + "arg_nodes": [ + 0, + 1, + 2, + 4, + 7, + 8, + 10, + 12, + 13, + 15, + 17, + 18, + 20, + 21 + ], + "heads": [[19, 0], [23, 0]] +} \ No newline at end of file diff --git a/render-to-video/arcface/mtcnn-model/det1.caffemodel b/render-to-video/arcface/mtcnn-model/det1.caffemodel new file mode 100644 index 0000000..79e93b4 Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det1.caffemodel differ diff --git a/render-to-video/arcface/mtcnn-model/det1.prototxt b/render-to-video/arcface/mtcnn-model/det1.prototxt new file mode 100644 index 0000000..c5c1657 --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det1.prototxt @@ -0,0 +1,177 @@ +name: "PNet" +input: "data" +input_dim: 1 +input_dim: 3 +input_dim: 12 +input_dim: 12 + +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + num_output: 10 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "PReLU1" + type: "PReLU" + bottom: "conv1" + top: "conv1" +} +layer { + name: "pool1" + type: "Pooling" + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} + +layer { + name: "conv2" + type: "Convolution" + bottom: "pool1" + top: "conv2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + num_output: 16 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "PReLU2" + type: "PReLU" + bottom: "conv2" + top: "conv2" +} + +layer { + name: "conv3" + type: "Convolution" + bottom: "conv2" + top: "conv3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + 
num_output: 32 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "PReLU3" + type: "PReLU" + bottom: "conv3" + top: "conv3" +} + + +layer { + name: "conv4-1" + type: "Convolution" + bottom: "conv3" + top: "conv4-1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + num_output: 2 + kernel_size: 1 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +layer { + name: "conv4-2" + type: "Convolution" + bottom: "conv3" + top: "conv4-2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 0 + } + convolution_param { + num_output: 4 + kernel_size: 1 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prob1" + type: "Softmax" + bottom: "conv4-1" + top: "prob1" +} diff --git a/render-to-video/arcface/mtcnn-model/det2-0001.params b/render-to-video/arcface/mtcnn-model/det2-0001.params new file mode 100644 index 0000000..a14a478 Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det2-0001.params differ diff --git a/render-to-video/arcface/mtcnn-model/det2-symbol.json b/render-to-video/arcface/mtcnn-model/det2-symbol.json new file mode 100644 index 0000000..a13246a --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det2-symbol.json @@ -0,0 +1,324 @@ +{ + "nodes": [ + { + "op": "null", + "param": {}, + "name": "data", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1", + "inputs": [[0, 0], [1, 0], [2, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1", + "inputs": [[3, 0], [4, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1", + "inputs": [[5, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2", + "inputs": [[6, 0], [7, 0], [8, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + 
"name": "prelu2", + "inputs": [[9, 0], [10, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2", + "inputs": [[11, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3", + "inputs": [[12, 0], [13, 0], [14, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3", + "inputs": [[15, 0], [16, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "128" + }, + "name": "conv4", + "inputs": [[17, 0], [18, 0], [19, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4", + "inputs": [[20, 0], [21, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "4" + }, + "name": "conv5_2", + "inputs": [[22, 0], [23, 0], [24, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "conv5_1", + "inputs": [[22, 0], [26, 0], [27, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prob1_label", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "SoftmaxOutput", + "param": { + "grad_scale": "1", + "ignore_label": "-1", + "multi_output": "False", + "normalization": "null", + "use_ignore": "False" + }, + "name": "prob1", + "inputs": [[28, 0], [29, 0]], + "backward_source_id": -1 + } + ], + "arg_nodes": [ + 0, + 1, + 2, + 4, + 7, + 8, + 10, + 13, + 14, + 16, + 18, + 19, + 21, + 23, + 24, + 26, + 27, + 29 + ], + "heads": [[25, 0], [30, 0]] +} \ No newline at end of file diff --git a/render-to-video/arcface/mtcnn-model/det2.caffemodel b/render-to-video/arcface/mtcnn-model/det2.caffemodel new file mode 100644 index 0000000..a5a540c Binary files /dev/null and 
b/render-to-video/arcface/mtcnn-model/det2.caffemodel differ diff --git a/render-to-video/arcface/mtcnn-model/det2.prototxt b/render-to-video/arcface/mtcnn-model/det2.prototxt new file mode 100644 index 0000000..51093e6 --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det2.prototxt @@ -0,0 +1,228 @@ +name: "RNet" +input: "data" +input_dim: 1 +input_dim: 3 +input_dim: 24 +input_dim: 24 + + +########################## +###################### +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu1" + type: "PReLU" + bottom: "conv1" + top: "conv1" + propagate_down: true +} +layer { + name: "pool1" + type: "Pooling" + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2" + type: "Convolution" + bottom: "pool1" + top: "conv2" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu2" + type: "PReLU" + bottom: "conv2" + top: "conv2" + propagate_down: true +} +layer { + name: "pool2" + type: "Pooling" + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +#################################### + +################################## +layer { + name: "conv3" + type: "Convolution" + bottom: "pool2" + top: "conv3" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu3" + type: "PReLU" + bottom: "conv3" + top: "conv3" + propagate_down: true +} +############################### + +############################### + +layer { + name: "conv4" + type: "InnerProduct" + bottom: "conv3" + top: "conv4" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + inner_product_param { + num_output: 128 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu4" + type: "PReLU" + bottom: "conv4" + top: "conv4" +} + +layer { + name: "conv5-1" + type: "InnerProduct" + bottom: "conv4" + top: "conv5-1" + param { + lr_mult: 0 + decay_mult: 0 + } + param { + lr_mult: 0 + decay_mult: 0 + } + inner_product_param { + num_output: 2 + #kernel_size: 1 + #stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "conv5-2" + type: "InnerProduct" + bottom: "conv4" + top: "conv5-2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 4 + #kernel_size: 1 + #stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prob1" + type: "Softmax" + bottom: "conv5-1" + top: "prob1" +} \ No newline at end of file diff --git a/render-to-video/arcface/mtcnn-model/det3-0001.params b/render-to-video/arcface/mtcnn-model/det3-0001.params new file mode 100644 index 0000000..cae898b Binary files /dev/null and 
b/render-to-video/arcface/mtcnn-model/det3-0001.params differ diff --git a/render-to-video/arcface/mtcnn-model/det3-symbol.json b/render-to-video/arcface/mtcnn-model/det3-symbol.json new file mode 100644 index 0000000..00061ed --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det3-symbol.json @@ -0,0 +1,418 @@ +{ + "nodes": [ + { + "op": "null", + "param": {}, + "name": "data", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "32", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1", + "inputs": [[0, 0], [1, 0], [2, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1", + "inputs": [[3, 0], [4, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1", + "inputs": [[5, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2", + "inputs": [[6, 0], [7, 0], [8, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2", + "inputs": [[9, 0], [10, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2", + "inputs": [[11, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3", + "inputs": [[12, 0], [13, 0], [14, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": 
"prelu3", + "inputs": [[15, 0], [16, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(2,2)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool3", + "inputs": [[17, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "128", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv4", + "inputs": [[18, 0], [19, 0], [20, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4", + "inputs": [[21, 0], [22, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "256" + }, + "name": "conv5", + "inputs": [[23, 0], [24, 0], [25, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu5_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu5", + "inputs": [[26, 0], [27, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "10" + }, + "name": "conv6_3", + "inputs": [[28, 0], [29, 0], [30, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "4" + }, + "name": "conv6_2", + "inputs": [[28, 0], [32, 0], [33, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv6_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "conv6_1", + "inputs": [[28, 0], [35, 0], [36, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prob1_label", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "SoftmaxOutput", + "param": { + "grad_scale": "1", + "ignore_label": "-1", + "multi_output": "False", + "normalization": "null", + "use_ignore": "False" + }, + "name": "prob1", + "inputs": [[37, 0], [38, 0]], + "backward_source_id": -1 + } + ], + "arg_nodes": [ + 0, 
+ 1, + 2, + 4, + 7, + 8, + 10, + 13, + 14, + 16, + 19, + 20, + 22, + 24, + 25, + 27, + 29, + 30, + 32, + 33, + 35, + 36, + 38 + ], + "heads": [[31, 0], [34, 0], [39, 0]] +} \ No newline at end of file diff --git a/render-to-video/arcface/mtcnn-model/det3.caffemodel b/render-to-video/arcface/mtcnn-model/det3.caffemodel new file mode 100644 index 0000000..7b4b8a4 Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det3.caffemodel differ diff --git a/render-to-video/arcface/mtcnn-model/det3.prototxt b/render-to-video/arcface/mtcnn-model/det3.prototxt new file mode 100644 index 0000000..a192307 --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det3.prototxt @@ -0,0 +1,294 @@ +name: "ONet" +input: "data" +input_dim: 1 +input_dim: 3 +input_dim: 48 +input_dim: 48 +################################## +layer { + name: "conv1" + type: "Convolution" + bottom: "data" + top: "conv1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 32 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu1" + type: "PReLU" + bottom: "conv1" + top: "conv1" +} +layer { + name: "pool1" + type: "Pooling" + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layer { + name: "conv2" + type: "Convolution" + bottom: "pool1" + top: "conv2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +layer { + name: "prelu2" + type: "PReLU" + bottom: "conv2" + top: "conv2" +} +layer { + name: "pool2" + type: "Pooling" + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv3" + type: "Convolution" + bottom: "pool2" + top: "conv3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 3 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu3" + type: "PReLU" + bottom: "conv3" + top: "conv3" +} +layer { + name: "pool3" + type: "Pooling" + bottom: "conv3" + top: "pool3" + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layer { + name: "conv4" + type: "Convolution" + bottom: "pool3" + top: "conv4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 128 + kernel_size: 2 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prelu4" + type: "PReLU" + bottom: "conv4" + top: "conv4" +} + + +layer { + name: "conv5" + type: "InnerProduct" + bottom: "conv4" + top: "conv5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + #kernel_size: 3 + num_output: 256 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +layer { + name: "drop5" + type: "Dropout" + bottom: "conv5" + top: "conv5" + dropout_param { + dropout_ratio: 0.25 + } +} +layer { + name: "prelu5" + type: "PReLU" + bottom: "conv5" + top: "conv5" +} + + +layer { + name: "conv6-1" + type: "InnerProduct" + bottom: "conv5" + top: "conv6-1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + 
decay_mult: 1 + } + inner_product_param { + #kernel_size: 1 + num_output: 2 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "conv6-2" + type: "InnerProduct" + bottom: "conv5" + top: "conv6-2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + #kernel_size: 1 + num_output: 4 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "conv6-3" + type: "InnerProduct" + bottom: "conv5" + top: "conv6-3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + #kernel_size: 1 + num_output: 10 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layer { + name: "prob1" + type: "Softmax" + bottom: "conv6-1" + top: "prob1" +} diff --git a/render-to-video/arcface/mtcnn-model/det4-0001.params b/render-to-video/arcface/mtcnn-model/det4-0001.params new file mode 100644 index 0000000..efca9a9 Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det4-0001.params differ diff --git a/render-to-video/arcface/mtcnn-model/det4-symbol.json b/render-to-video/arcface/mtcnn-model/det4-symbol.json new file mode 100644 index 0000000..aa90e2a --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det4-symbol.json @@ -0,0 +1,1392 @@ +{ + "nodes": [ + { + "op": "null", + "param": {}, + "name": "data", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "SliceChannel", + "param": { + "axis": "1", + "num_outputs": "5", + "squeeze_axis": "False" + }, + "name": "slice", + "inputs": [[0, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1_1", + "inputs": [[1, 0], [2, 0], [3, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1_1", + "inputs": [[4, 0], [5, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1_1", + "inputs": [[6, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2_1", + "inputs": [[7, 0], [8, 0], [9, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { 
+ "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2_1", + "inputs": [[10, 0], [11, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2_1", + "inputs": [[12, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3_1", + "inputs": [[13, 0], [14, 0], [15, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3_1", + "inputs": [[16, 0], [17, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1_2", + "inputs": [[1, 1], [19, 0], [20, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1_2", + "inputs": [[21, 0], [22, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1_2", + "inputs": [[23, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2_2", + "inputs": [[24, 0], [25, 0], [26, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2_2", + "inputs": [[27, 0], [28, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + 
"pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2_2", + "inputs": [[29, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3_2", + "inputs": [[30, 0], [31, 0], [32, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3_2", + "inputs": [[33, 0], [34, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1_3", + "inputs": [[1, 2], [36, 0], [37, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1_3", + "inputs": [[38, 0], [39, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1_3", + "inputs": [[40, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2_3", + "inputs": [[41, 0], [42, 0], [43, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2_3", + "inputs": [[44, 0], [45, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2_3", + "inputs": [[46, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_3_bias", 
+ "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3_3", + "inputs": [[47, 0], [48, 0], [49, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3_3", + "inputs": [[50, 0], [51, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1_4", + "inputs": [[1, 3], [53, 0], [54, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1_4", + "inputs": [[55, 0], [56, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1_4", + "inputs": [[57, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2_4", + "inputs": [[58, 0], [59, 0], [60, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2_4", + "inputs": [[61, 0], [62, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2_4", + "inputs": [[63, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": 
"conv3_4", + "inputs": [[64, 0], [65, 0], [66, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu3_4", + "inputs": [[67, 0], [68, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv1_5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "28", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv1_5", + "inputs": [[1, 4], [70, 0], [71, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu1_5_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu1_5", + "inputs": [[72, 0], [73, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool1_5", + "inputs": [[74, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv2_5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(3,3)", + "no_bias": "False", + "num_filter": "48", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv2_5", + "inputs": [[75, 0], [76, 0], [77, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu2_5_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu2_5", + "inputs": [[78, 0], [79, 0]], + "backward_source_id": -1 + }, + { + "op": "Pooling", + "param": { + "global_pool": "False", + "kernel": "(3,3)", + "pad": "(0,0)", + "pool_type": "max", + "pooling_convention": "full", + "stride": "(2,2)" + }, + "name": "pool2_5", + "inputs": [[80, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "conv3_5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "Convolution", + "param": { + "cudnn_off": "False", + "cudnn_tune": "off", + "dilate": "(1,1)", + "kernel": "(2,2)", + "no_bias": "False", + "num_filter": "64", + "num_group": "1", + "pad": "(0,0)", + "stride": "(1,1)", + "workspace": "1024" + }, + "name": "conv3_5", + "inputs": [[81, 0], [82, 0], [83, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu3_5_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" 
+ }, + "name": "prelu3_5", + "inputs": [[84, 0], [85, 0]], + "backward_source_id": -1 + }, + { + "op": "Concat", + "param": { + "dim": "1", + "num_args": "5" + }, + "name": "concat", + "inputs": [[18, 0], [35, 0], [52, 0], [69, 0], [86, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "256" + }, + "name": "fc4", + "inputs": [[87, 0], [88, 0], [89, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4", + "inputs": [[90, 0], [91, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "64" + }, + "name": "fc4_1", + "inputs": [[92, 0], [93, 0], [94, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_1_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4_1", + "inputs": [[95, 0], [96, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_1_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_1_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "fc5_1", + "inputs": [[97, 0], [98, 0], [99, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "64" + }, + "name": "fc4_2", + "inputs": [[92, 0], [101, 0], [102, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_2_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4_2", + "inputs": [[103, 0], [104, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_2_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_2_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "fc5_2", + "inputs": [[105, 0], [106, 0], [107, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + 
"num_hidden": "64" + }, + "name": "fc4_3", + "inputs": [[92, 0], [109, 0], [110, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_3_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4_3", + "inputs": [[111, 0], [112, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_3_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_3_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "fc5_3", + "inputs": [[113, 0], [114, 0], [115, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "64" + }, + "name": "fc4_4", + "inputs": [[92, 0], [117, 0], [118, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_4_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4_4", + "inputs": [[119, 0], [120, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_4_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_4_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "fc5_4", + "inputs": [[121, 0], [122, 0], [123, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc4_5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "64" + }, + "name": "fc4_5", + "inputs": [[92, 0], [125, 0], [126, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "prelu4_5_gamma", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "LeakyReLU", + "param": { + "act_type": "prelu", + "lower_bound": "0.125", + "slope": "0.25", + "upper_bound": "0.334" + }, + "name": "prelu4_5", + "inputs": [[127, 0], [128, 0]], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_5_weight", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "null", + "param": {}, + "name": "fc5_5_bias", + "inputs": [], + "backward_source_id": -1 + }, + { + "op": "FullyConnected", + "param": { + "no_bias": "False", + "num_hidden": "2" + }, + "name": "fc5_5", + "inputs": [[129, 0], [130, 0], [131, 0]], + "backward_source_id": -1 + } + ], + "arg_nodes": [ + 0, + 2, + 3, + 5, + 8, + 9, + 11, + 14, + 15, + 17, + 19, + 20, + 22, + 25, + 26, + 28, + 31, + 32, + 34, + 36, + 37, + 39, + 42, + 43, + 45, + 48, + 49, + 51, + 53, + 54, + 56, + 59, + 60, + 62, + 65, + 66, + 68, + 70, + 71, + 73, + 76, + 77, + 79, + 82, + 83, + 85, + 88, + 89, + 91, + 93, + 94, + 96, + 98, + 99, + 101, + 102, + 104, + 106, + 107, + 109, + 110, + 112, + 114, + 115, + 117, + 118, + 
120, + 122, + 123, + 125, + 126, + 128, + 130, + 131 + ], + "heads": [[100, 0], [108, 0], [116, 0], [124, 0], [132, 0]] +} \ No newline at end of file diff --git a/render-to-video/arcface/mtcnn-model/det4.caffemodel b/render-to-video/arcface/mtcnn-model/det4.caffemodel new file mode 100644 index 0000000..38353c4 Binary files /dev/null and b/render-to-video/arcface/mtcnn-model/det4.caffemodel differ diff --git a/render-to-video/arcface/mtcnn-model/det4.prototxt b/render-to-video/arcface/mtcnn-model/det4.prototxt new file mode 100644 index 0000000..4cdc329 --- /dev/null +++ b/render-to-video/arcface/mtcnn-model/det4.prototxt @@ -0,0 +1,995 @@ +name: "LNet" +input: "data" +input_dim: 1 +input_dim: 15 +input_dim: 24 +input_dim: 24 + +layer { + name: "slicer_data" + type: "Slice" + bottom: "data" + top: "data241" + top: "data242" + top: "data243" + top: "data244" + top: "data245" + slice_param { + axis: 1 + slice_point: 3 + slice_point: 6 + slice_point: 9 + slice_point: 12 + } +} +layer { + name: "conv1_1" + type: "Convolution" + bottom: "data241" + top: "conv1_1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu1_1" + type: "PReLU" + bottom: "conv1_1" + top: "conv1_1" + +} +layer { + name: "pool1_1" + type: "Pooling" + bottom: "conv1_1" + top: "pool1_1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2_1" + type: "Convolution" + bottom: "pool1_1" + top: "conv2_1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu2_1" + type: "PReLU" + bottom: "conv2_1" + top: "conv2_1" +} +layer { + name: "pool2_1" + type: "Pooling" + bottom: "conv2_1" + top: "pool2_1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + +} +layer { + name: "conv3_1" + type: "Convolution" + bottom: "pool2_1" + top: "conv3_1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu3_1" + type: "PReLU" + bottom: "conv3_1" + top: "conv3_1" +} +########################## +layer { + name: "conv1_2" + type: "Convolution" + bottom: "data242" + top: "conv1_2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu1_2" + type: "PReLU" + bottom: "conv1_2" + top: "conv1_2" + +} +layer { + name: "pool1_2" + type: "Pooling" + bottom: "conv1_2" + top: "pool1_2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2_2" + type: "Convolution" + bottom: "pool1_2" + top: "conv2_2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu2_2" + type: "PReLU" + bottom: 
"conv2_2" + top: "conv2_2" +} +layer { + name: "pool2_2" + type: "Pooling" + bottom: "conv2_2" + top: "pool2_2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + +} +layer { + name: "conv3_2" + type: "Convolution" + bottom: "pool2_2" + top: "conv3_2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu3_2" + type: "PReLU" + bottom: "conv3_2" + top: "conv3_2" +} +########################## +########################## +layer { + name: "conv1_3" + type: "Convolution" + bottom: "data243" + top: "conv1_3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu1_3" + type: "PReLU" + bottom: "conv1_3" + top: "conv1_3" + +} +layer { + name: "pool1_3" + type: "Pooling" + bottom: "conv1_3" + top: "pool1_3" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2_3" + type: "Convolution" + bottom: "pool1_3" + top: "conv2_3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu2_3" + type: "PReLU" + bottom: "conv2_3" + top: "conv2_3" +} +layer { + name: "pool2_3" + type: "Pooling" + bottom: "conv2_3" + top: "pool2_3" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + +} +layer { + name: "conv3_3" + type: "Convolution" + bottom: "pool2_3" + top: "conv3_3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu3_3" + type: "PReLU" + bottom: "conv3_3" + top: "conv3_3" +} +########################## +########################## +layer { + name: "conv1_4" + type: "Convolution" + bottom: "data244" + top: "conv1_4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu1_4" + type: "PReLU" + bottom: "conv1_4" + top: "conv1_4" + +} +layer { + name: "pool1_4" + type: "Pooling" + bottom: "conv1_4" + top: "pool1_4" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2_4" + type: "Convolution" + bottom: "pool1_4" + top: "conv2_4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu2_4" + type: "PReLU" + bottom: "conv2_4" + top: "conv2_4" +} +layer { + name: "pool2_4" + type: "Pooling" + bottom: "conv2_4" + top: "pool2_4" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + +} +layer { + name: "conv3_4" + type: "Convolution" + bottom: "pool2_4" + top: "conv3_4" + param { + lr_mult: 1 + decay_mult: 1 + 
} + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu3_4" + type: "PReLU" + bottom: "conv3_4" + top: "conv3_4" +} +########################## +########################## +layer { + name: "conv1_5" + type: "Convolution" + bottom: "data245" + top: "conv1_5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 28 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu1_5" + type: "PReLU" + bottom: "conv1_5" + top: "conv1_5" + +} +layer { + name: "pool1_5" + type: "Pooling" + bottom: "conv1_5" + top: "pool1_5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} + +layer { + name: "conv2_5" + type: "Convolution" + bottom: "pool1_5" + top: "conv2_5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 48 + kernel_size: 3 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu2_5" + type: "PReLU" + bottom: "conv2_5" + top: "conv2_5" +} +layer { + name: "pool2_5" + type: "Pooling" + bottom: "conv2_5" + top: "pool2_5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } + +} +layer { + name: "conv3_5" + type: "Convolution" + bottom: "pool2_5" + top: "conv3_5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + convolution_param { + num_output: 64 + kernel_size: 2 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu3_5" + type: "PReLU" + bottom: "conv3_5" + top: "conv3_5" +} +########################## +layer { + name: "concat" + bottom: "conv3_1" + bottom: "conv3_2" + bottom: "conv3_3" + bottom: "conv3_4" + bottom: "conv3_5" + top: "conv3" + type: "Concat" + concat_param { + axis: 1 + } +} +########################## +layer { + name: "fc4" + type: "InnerProduct" + bottom: "conv3" + top: "fc4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 256 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4" + type: "PReLU" + bottom: "fc4" + top: "fc4" +} +############################ +layer { + name: "fc4_1" + type: "InnerProduct" + bottom: "fc4" + top: "fc4_1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 64 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4_1" + type: "PReLU" + bottom: "fc4_1" + top: "fc4_1" +} +layer { + name: "fc5_1" + type: "InnerProduct" + bottom: "fc4_1" + top: "fc5_1" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 2 + weight_filler { + type: "xavier" + #type: "constant" + #value: 0 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + + +######################### +layer { + name: "fc4_2" + type: "InnerProduct" + bottom: "fc4" + top: "fc4_2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 64 + weight_filler { + 
type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4_2" + type: "PReLU" + bottom: "fc4_2" + top: "fc4_2" +} +layer { + name: "fc5_2" + type: "InnerProduct" + bottom: "fc4_2" + top: "fc5_2" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 2 + weight_filler { + type: "xavier" + #type: "constant" + #value: 0 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +######################### +layer { + name: "fc4_3" + type: "InnerProduct" + bottom: "fc4" + top: "fc4_3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 64 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4_3" + type: "PReLU" + bottom: "fc4_3" + top: "fc4_3" +} +layer { + name: "fc5_3" + type: "InnerProduct" + bottom: "fc4_3" + top: "fc5_3" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 2 + weight_filler { + type: "xavier" + #type: "constant" + #value: 0 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +######################### +layer { + name: "fc4_4" + type: "InnerProduct" + bottom: "fc4" + top: "fc4_4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 64 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4_4" + type: "PReLU" + bottom: "fc4_4" + top: "fc4_4" +} +layer { + name: "fc5_4" + type: "InnerProduct" + bottom: "fc4_4" + top: "fc5_4" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 2 + weight_filler { + type: "xavier" + #type: "constant" + #value: 0 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +######################### +layer { + name: "fc4_5" + type: "InnerProduct" + bottom: "fc4" + top: "fc4_5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 64 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + value: 0 + } + } + +} +layer { + name: "prelu4_5" + type: "PReLU" + bottom: "fc4_5" + top: "fc4_5" +} +layer { + name: "fc5_5" + type: "InnerProduct" + bottom: "fc4_5" + top: "fc5_5" + param { + lr_mult: 1 + decay_mult: 1 + } + param { + lr_mult: 2 + decay_mult: 1 + } + inner_product_param { + num_output: 2 + weight_filler { + type: "xavier" + #type: "constant" + #value: 0 + } + bias_filler { + type: "constant" + value: 0 + } + } +} + +######################### + diff --git a/render-to-video/arcface/mtcnn_detector.py b/render-to-video/arcface/mtcnn_detector.py new file mode 100644 index 0000000..c7332a5 --- /dev/null +++ b/render-to-video/arcface/mtcnn_detector.py @@ -0,0 +1,659 @@ +# coding: utf-8 +import os +import mxnet as mx +import numpy as np +import math +import cv2 +from multiprocessing import Pool +from itertools import repeat +try: + from itertools import izip +except ImportError: + izip = zip + +from helper import nms, adjust_input, generate_bbox, detect_first_stage_warpper + +class MtcnnDetector(object): + """ + Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Neural Networks + see https://github.com/kpzhang93/MTCNN_face_detection_alignment + this is a mxnet version + """ + def __init__(self, + 
model_folder='.', + minsize = 20, + threshold = [0.6, 0.7, 0.8], + factor = 0.709, + num_worker = 1, + accurate_landmark = False, + ctx=mx.cpu()): + """ + Initialize the detector + + Parameters: + ---------- + model_folder : string + path for the models + minsize : float number + minimal face to detect + threshold : float number + detect threshold for 3 stages + factor: float number + scale factor for image pyramid + num_worker: int number + number of processes we use for first stage + accurate_landmark: bool + use accurate landmark localization or not + + """ + self.num_worker = num_worker + self.accurate_landmark = accurate_landmark + + # load 4 models from folder + models = ['det1', 'det2', 'det3','det4'] + models = [ os.path.join(model_folder, f) for f in models] + + self.PNets = [] + for i in range(num_worker): + workner_net = mx.model.FeedForward.load(models[0], 1, ctx=ctx) + self.PNets.append(workner_net) + + #self.Pool = Pool(num_worker) + + self.RNet = mx.model.FeedForward.load(models[1], 1, ctx=ctx) + self.ONet = mx.model.FeedForward.load(models[2], 1, ctx=ctx) + self.LNet = mx.model.FeedForward.load(models[3], 1, ctx=ctx) + + self.minsize = float(minsize) + self.factor = float(factor) + self.threshold = threshold + + + def convert_to_square(self, bbox): + """ + convert bbox to square + + Parameters: + ---------- + bbox: numpy array , shape n x 5 + input bbox + + Returns: + ------- + square bbox + """ + square_bbox = bbox.copy() + + h = bbox[:, 3] - bbox[:, 1] + 1 + w = bbox[:, 2] - bbox[:, 0] + 1 + max_side = np.maximum(h,w) + square_bbox[:, 0] = bbox[:, 0] + w*0.5 - max_side*0.5 + square_bbox[:, 1] = bbox[:, 1] + h*0.5 - max_side*0.5 + square_bbox[:, 2] = square_bbox[:, 0] + max_side - 1 + square_bbox[:, 3] = square_bbox[:, 1] + max_side - 1 + return square_bbox + + def calibrate_box(self, bbox, reg): + """ + calibrate bboxes + + Parameters: + ---------- + bbox: numpy array, shape n x 5 + input bboxes + reg: numpy array, shape n x 4 + bboxex adjustment + + Returns: + ------- + bboxes after refinement + + """ + w = bbox[:, 2] - bbox[:, 0] + 1 + w = np.expand_dims(w, 1) + h = bbox[:, 3] - bbox[:, 1] + 1 + h = np.expand_dims(h, 1) + reg_m = np.hstack([w, h, w, h]) + aug = reg_m * reg + bbox[:, 0:4] = bbox[:, 0:4] + aug + return bbox + + + def pad(self, bboxes, w, h): + """ + pad the the bboxes, alse restrict the size of it + + Parameters: + ---------- + bboxes: numpy array, n x 5 + input bboxes + w: float number + width of the input image + h: float number + height of the input image + Returns : + ------s + dy, dx : numpy array, n x 1 + start point of the bbox in target image + edy, edx : numpy array, n x 1 + end point of the bbox in target image + y, x : numpy array, n x 1 + start point of the bbox in original image + ex, ex : numpy array, n x 1 + end point of the bbox in original image + tmph, tmpw: numpy array, n x 1 + height and width of the bbox + + """ + tmpw, tmph = bboxes[:, 2] - bboxes[:, 0] + 1, bboxes[:, 3] - bboxes[:, 1] + 1 + num_box = bboxes.shape[0] + + dx , dy= np.zeros((num_box, )), np.zeros((num_box, )) + edx, edy = tmpw.copy()-1, tmph.copy()-1 + + x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] + + tmp_index = np.where(ex > w-1) + edx[tmp_index] = tmpw[tmp_index] + w - 2 - ex[tmp_index] + ex[tmp_index] = w - 1 + + tmp_index = np.where(ey > h-1) + edy[tmp_index] = tmph[tmp_index] + h - 2 - ey[tmp_index] + ey[tmp_index] = h - 1 + + tmp_index = np.where(x < 0) + dx[tmp_index] = 0 - x[tmp_index] + x[tmp_index] = 0 + + tmp_index = np.where(y < 
0) + dy[tmp_index] = 0 - y[tmp_index] + y[tmp_index] = 0 + + return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] + return_list = [item.astype(np.int32) for item in return_list] + + return return_list + + def slice_index(self, number): + """ + slice the index into (n,n,m), m < n + Parameters: + ---------- + number: int number + number + """ + def chunks(l, n): + """Yield successive n-sized chunks from l.""" + for i in range(0, len(l), n): + yield l[i:i + n] + num_list = range(number) + return list(chunks(num_list, self.num_worker)) + + def detect_face_limited(self, img, det_type=2): + height, width, _ = img.shape + if det_type>=2: + total_boxes = np.array( [ [0.0, 0.0, img.shape[1], img.shape[0], 0.9] ] ,dtype=np.float32) + num_box = total_boxes.shape[0] + + # pad the bbox + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(total_boxes, width, height) + # (3, 24, 24) is the input shape for RNet + input_buf = np.zeros((num_box, 3, 24, 24), dtype=np.float32) + + for i in range(num_box): + tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) + tmp[dy[i]:edy[i]+1, dx[i]:edx[i]+1, :] = img[y[i]:ey[i]+1, x[i]:ex[i]+1, :] + input_buf[i, :, :, :] = adjust_input(cv2.resize(tmp, (24, 24))) + + output = self.RNet.predict(input_buf) + + # filter the total_boxes with threshold + passed = np.where(output[1][:, 1] > self.threshold[1]) + total_boxes = total_boxes[passed] + + if total_boxes.size == 0: + return None + + total_boxes[:, 4] = output[1][passed, 1].reshape((-1,)) + reg = output[0][passed] + + # nms + pick = nms(total_boxes, 0.7, 'Union') + total_boxes = total_boxes[pick] + total_boxes = self.calibrate_box(total_boxes, reg[pick]) + total_boxes = self.convert_to_square(total_boxes) + total_boxes[:, 0:4] = np.round(total_boxes[:, 0:4]) + else: + total_boxes = np.array( [ [0.0, 0.0, img.shape[1], img.shape[0], 0.9] ] ,dtype=np.float32) + num_box = total_boxes.shape[0] + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(total_boxes, width, height) + # (3, 48, 48) is the input shape for ONet + input_buf = np.zeros((num_box, 3, 48, 48), dtype=np.float32) + + for i in range(num_box): + tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.float32) + tmp[dy[i]:edy[i]+1, dx[i]:edx[i]+1, :] = img[y[i]:ey[i]+1, x[i]:ex[i]+1, :] + input_buf[i, :, :, :] = adjust_input(cv2.resize(tmp, (48, 48))) + + output = self.ONet.predict(input_buf) + #print(output[2]) + + # filter the total_boxes with threshold + passed = np.where(output[2][:, 1] > self.threshold[2]) + total_boxes = total_boxes[passed] + + if total_boxes.size == 0: + return None + + total_boxes[:, 4] = output[2][passed, 1].reshape((-1,)) + reg = output[1][passed] + points = output[0][passed] + + # compute landmark points + bbw = total_boxes[:, 2] - total_boxes[:, 0] + 1 + bbh = total_boxes[:, 3] - total_boxes[:, 1] + 1 + points[:, 0:5] = np.expand_dims(total_boxes[:, 0], 1) + np.expand_dims(bbw, 1) * points[:, 0:5] + points[:, 5:10] = np.expand_dims(total_boxes[:, 1], 1) + np.expand_dims(bbh, 1) * points[:, 5:10] + + # nms + total_boxes = self.calibrate_box(total_boxes, reg) + pick = nms(total_boxes, 0.7, 'Min') + total_boxes = total_boxes[pick] + points = points[pick] + + if not self.accurate_landmark: + return total_boxes, points + + ############################################# + # extended stage + ############################################# + num_box = total_boxes.shape[0] + patchw = np.maximum(total_boxes[:, 2]-total_boxes[:, 0]+1, total_boxes[:, 3]-total_boxes[:, 1]+1) + patchw = np.round(patchw*0.25) + + # make it even + 
patchw[np.where(np.mod(patchw,2) == 1)] += 1 + + input_buf = np.zeros((num_box, 15, 24, 24), dtype=np.float32) + for i in range(5): + x, y = points[:, i], points[:, i+5] + x, y = np.round(x-0.5*patchw), np.round(y-0.5*patchw) + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(np.vstack([x, y, x+patchw-1, y+patchw-1]).T, + width, + height) + for j in range(num_box): + tmpim = np.zeros((tmpw[j], tmpw[j], 3), dtype=np.float32) + tmpim[dy[j]:edy[j]+1, dx[j]:edx[j]+1, :] = img[y[j]:ey[j]+1, x[j]:ex[j]+1, :] + input_buf[j, i*3:i*3+3, :, :] = adjust_input(cv2.resize(tmpim, (24, 24))) + + output = self.LNet.predict(input_buf) + + pointx = np.zeros((num_box, 5)) + pointy = np.zeros((num_box, 5)) + + for k in range(5): + # do not make a large movement + tmp_index = np.where(np.abs(output[k]-0.5) > 0.35) + output[k][tmp_index[0]] = 0.5 + + pointx[:, k] = np.round(points[:, k] - 0.5*patchw) + output[k][:, 0]*patchw + pointy[:, k] = np.round(points[:, k+5] - 0.5*patchw) + output[k][:, 1]*patchw + + points = np.hstack([pointx, pointy]) + points = points.astype(np.int32) + + return total_boxes, points + + def detect_face(self, img, det_type=0): + """ + detect face over img + Parameters: + ---------- + img: numpy array, bgr order of shape (1, 3, n, m) + input image + Retures: + ------- + bboxes: numpy array, n x 5 (x1,y2,x2,y2,score) + bboxes + points: numpy array, n x 10 (x1, x2 ... x5, y1, y2 ..y5) + landmarks + """ + + # check input + height, width, _ = img.shape + if det_type==0: + MIN_DET_SIZE = 12 + + if img is None: + return None + + # only works for color image + if len(img.shape) != 3: + return None + + # detected boxes + total_boxes = [] + + minl = min( height, width) + + # get all the valid scales + scales = [] + m = MIN_DET_SIZE/self.minsize + minl *= m + factor_count = 0 + while minl > MIN_DET_SIZE: + scales.append(m*self.factor**factor_count) + minl *= self.factor + factor_count += 1 + + ############################################# + # first stage + ############################################# + #for scale in scales: + # return_boxes = self.detect_first_stage(img, scale, 0) + # if return_boxes is not None: + # total_boxes.append(return_boxes) + + sliced_index = self.slice_index(len(scales)) + total_boxes = [] + for batch in sliced_index: + #local_boxes = self.Pool.map( detect_first_stage_warpper, \ + # izip(repeat(img), self.PNets[:len(batch)], [scales[i] for i in batch], repeat(self.threshold[0])) ) + local_boxes = map( detect_first_stage_warpper, \ + izip(repeat(img), self.PNets[:len(batch)], [scales[i] for i in batch], repeat(self.threshold[0])) ) + total_boxes.extend(local_boxes) + + # remove the Nones + total_boxes = [ i for i in total_boxes if i is not None] + + if len(total_boxes) == 0: + return None + + total_boxes = np.vstack(total_boxes) + + if total_boxes.size == 0: + return None + + # merge the detection from first stage + pick = nms(total_boxes[:, 0:5], 0.7, 'Union') + total_boxes = total_boxes[pick] + + bbw = total_boxes[:, 2] - total_boxes[:, 0] + 1 + bbh = total_boxes[:, 3] - total_boxes[:, 1] + 1 + + # refine the bboxes + total_boxes = np.vstack([total_boxes[:, 0]+total_boxes[:, 5] * bbw, + total_boxes[:, 1]+total_boxes[:, 6] * bbh, + total_boxes[:, 2]+total_boxes[:, 7] * bbw, + total_boxes[:, 3]+total_boxes[:, 8] * bbh, + total_boxes[:, 4] + ]) + + total_boxes = total_boxes.T + total_boxes = self.convert_to_square(total_boxes) + total_boxes[:, 0:4] = np.round(total_boxes[:, 0:4]) + else: + total_boxes = np.array( [ [0.0, 0.0, img.shape[1], img.shape[0], 0.9] ] 
,dtype=np.float32) + + ############################################# + # second stage + ############################################# + num_box = total_boxes.shape[0] + + # pad the bbox + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(total_boxes, width, height) + # (3, 24, 24) is the input shape for RNet + input_buf = np.zeros((num_box, 3, 24, 24), dtype=np.float32) + + for i in range(num_box): + tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) + tmp[dy[i]:edy[i]+1, dx[i]:edx[i]+1, :] = img[y[i]:ey[i]+1, x[i]:ex[i]+1, :] + input_buf[i, :, :, :] = adjust_input(cv2.resize(tmp, (24, 24))) + + output = self.RNet.predict(input_buf) + + # filter the total_boxes with threshold + passed = np.where(output[1][:, 1] > self.threshold[1]) + total_boxes = total_boxes[passed] + + if total_boxes.size == 0: + return None + + total_boxes[:, 4] = output[1][passed, 1].reshape((-1,)) + reg = output[0][passed] + + # nms + pick = nms(total_boxes, 0.7, 'Union') + total_boxes = total_boxes[pick] + total_boxes = self.calibrate_box(total_boxes, reg[pick]) + total_boxes = self.convert_to_square(total_boxes) + total_boxes[:, 0:4] = np.round(total_boxes[:, 0:4]) + + ############################################# + # third stage + ############################################# + num_box = total_boxes.shape[0] + + # pad the bbox + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(total_boxes, width, height) + # (3, 48, 48) is the input shape for ONet + input_buf = np.zeros((num_box, 3, 48, 48), dtype=np.float32) + + for i in range(num_box): + tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.float32) + tmp[dy[i]:edy[i]+1, dx[i]:edx[i]+1, :] = img[y[i]:ey[i]+1, x[i]:ex[i]+1, :] + input_buf[i, :, :, :] = adjust_input(cv2.resize(tmp, (48, 48))) + + output = self.ONet.predict(input_buf) + + # filter the total_boxes with threshold + passed = np.where(output[2][:, 1] > self.threshold[2]) + total_boxes = total_boxes[passed] + + if total_boxes.size == 0: + return None + + total_boxes[:, 4] = output[2][passed, 1].reshape((-1,)) + reg = output[1][passed] + points = output[0][passed] + + # compute landmark points + bbw = total_boxes[:, 2] - total_boxes[:, 0] + 1 + bbh = total_boxes[:, 3] - total_boxes[:, 1] + 1 + points[:, 0:5] = np.expand_dims(total_boxes[:, 0], 1) + np.expand_dims(bbw, 1) * points[:, 0:5] + points[:, 5:10] = np.expand_dims(total_boxes[:, 1], 1) + np.expand_dims(bbh, 1) * points[:, 5:10] + + # nms + total_boxes = self.calibrate_box(total_boxes, reg) + pick = nms(total_boxes, 0.7, 'Min') + total_boxes = total_boxes[pick] + points = points[pick] + + if not self.accurate_landmark: + return total_boxes, points + + ############################################# + # extended stage + ############################################# + num_box = total_boxes.shape[0] + patchw = np.maximum(total_boxes[:, 2]-total_boxes[:, 0]+1, total_boxes[:, 3]-total_boxes[:, 1]+1) + patchw = np.round(patchw*0.25) + + # make it even + patchw[np.where(np.mod(patchw,2) == 1)] += 1 + + input_buf = np.zeros((num_box, 15, 24, 24), dtype=np.float32) + for i in range(5): + x, y = points[:, i], points[:, i+5] + x, y = np.round(x-0.5*patchw), np.round(y-0.5*patchw) + [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(np.vstack([x, y, x+patchw-1, y+patchw-1]).T, + width, + height) + for j in range(num_box): + tmpim = np.zeros((tmpw[j], tmpw[j], 3), dtype=np.float32) + tmpim[dy[j]:edy[j]+1, dx[j]:edx[j]+1, :] = img[y[j]:ey[j]+1, x[j]:ex[j]+1, :] + input_buf[j, i*3:i*3+3, :, :] = adjust_input(cv2.resize(tmpim, (24, 24))) + + output = 
self.LNet.predict(input_buf) + + pointx = np.zeros((num_box, 5)) + pointy = np.zeros((num_box, 5)) + + for k in range(5): + # do not make a large movement + tmp_index = np.where(np.abs(output[k]-0.5) > 0.35) + output[k][tmp_index[0]] = 0.5 + + pointx[:, k] = np.round(points[:, k] - 0.5*patchw) + output[k][:, 0]*patchw + pointy[:, k] = np.round(points[:, k+5] - 0.5*patchw) + output[k][:, 1]*patchw + + points = np.hstack([pointx, pointy]) + points = points.astype(np.int32) + + return total_boxes, points + + + + def list2colmatrix(self, pts_list): + """ + convert list to column matrix + Parameters: + ---------- + pts_list: + input list + Retures: + ------- + colMat: + + """ + assert len(pts_list) > 0 + colMat = [] + for i in range(len(pts_list)): + colMat.append(pts_list[i][0]) + colMat.append(pts_list[i][1]) + colMat = np.matrix(colMat).transpose() + return colMat + + def find_tfrom_between_shapes(self, from_shape, to_shape): + """ + find transform between shapes + Parameters: + ---------- + from_shape: + to_shape: + Retures: + ------- + tran_m: + tran_b: + """ + assert from_shape.shape[0] == to_shape.shape[0] and from_shape.shape[0] % 2 == 0 + + sigma_from = 0.0 + sigma_to = 0.0 + cov = np.matrix([[0.0, 0.0], [0.0, 0.0]]) + + # compute the mean and cov + from_shape_points = from_shape.reshape(from_shape.shape[0]/2, 2) + to_shape_points = to_shape.reshape(to_shape.shape[0]/2, 2) + mean_from = from_shape_points.mean(axis=0) + mean_to = to_shape_points.mean(axis=0) + + for i in range(from_shape_points.shape[0]): + temp_dis = np.linalg.norm(from_shape_points[i] - mean_from) + sigma_from += temp_dis * temp_dis + temp_dis = np.linalg.norm(to_shape_points[i] - mean_to) + sigma_to += temp_dis * temp_dis + cov += (to_shape_points[i].transpose() - mean_to.transpose()) * (from_shape_points[i] - mean_from) + + sigma_from = sigma_from / to_shape_points.shape[0] + sigma_to = sigma_to / to_shape_points.shape[0] + cov = cov / to_shape_points.shape[0] + + # compute the affine matrix + s = np.matrix([[1.0, 0.0], [0.0, 1.0]]) + u, d, vt = np.linalg.svd(cov) + + if np.linalg.det(cov) < 0: + if d[1] < d[0]: + s[1, 1] = -1 + else: + s[0, 0] = -1 + r = u * s * vt + c = 1.0 + if sigma_from != 0: + c = 1.0 / sigma_from * np.trace(np.diag(d) * s) + + tran_b = mean_to.transpose() - c * r * mean_from.transpose() + tran_m = c * r + + return tran_m, tran_b + + def extract_image_chips(self, img, points, desired_size=256, padding=0): + """ + crop and align face + Parameters: + ---------- + img: numpy array, bgr order of shape (1, 3, n, m) + input image + points: numpy array, n x 10 (x1, x2 ... 
x5, y1, y2 ..y5) + desired_size: default 256 + padding: default 0 + Retures: + ------- + crop_imgs: list, n + cropped and aligned faces + """ + crop_imgs = [] + for p in points: + shape =[] + for k in range(len(p)/2): + shape.append(p[k]) + shape.append(p[k+5]) + + if padding > 0: + padding = padding + else: + padding = 0 + # average positions of face points + mean_face_shape_x = [0.224152, 0.75610125, 0.490127, 0.254149, 0.726104] + mean_face_shape_y = [0.2119465, 0.2119465, 0.628106, 0.780233, 0.780233] + + from_points = [] + to_points = [] + + for i in range(len(shape)/2): + x = (padding + mean_face_shape_x[i]) / (2 * padding + 1) * desired_size + y = (padding + mean_face_shape_y[i]) / (2 * padding + 1) * desired_size + to_points.append([x, y]) + from_points.append([shape[2*i], shape[2*i+1]]) + + # convert the points to Mat + from_mat = self.list2colmatrix(from_points) + to_mat = self.list2colmatrix(to_points) + + # compute the similar transfrom + tran_m, tran_b = self.find_tfrom_between_shapes(from_mat, to_mat) + + probe_vec = np.matrix([1.0, 0.0]).transpose() + probe_vec = tran_m * probe_vec + + scale = np.linalg.norm(probe_vec) + angle = 180.0 / math.pi * math.atan2(probe_vec[1, 0], probe_vec[0, 0]) + + from_center = [(shape[0]+shape[2])/2.0, (shape[1]+shape[3])/2.0] + to_center = [0, 0] + to_center[1] = desired_size * 0.4 + to_center[0] = desired_size * 0.5 + + ex = to_center[0] - from_center[0] + ey = to_center[1] - from_center[1] + + rot_mat = cv2.getRotationMatrix2D((from_center[0], from_center[1]), -1*angle, scale) + rot_mat[0][2] += ex + rot_mat[1][2] += ey + + chips = cv2.warpAffine(img, rot_mat, (desired_size, desired_size)) + crop_imgs.append(chips) + + return crop_imgs + diff --git a/render-to-video/arcface/test_batch.py b/render-to-video/arcface/test_batch.py new file mode 100644 index 0000000..2ac6610 --- /dev/null +++ b/render-to-video/arcface/test_batch.py @@ -0,0 +1,57 @@ +import face_model +import argparse +import cv2 +import sys +import numpy as np +import os +import glob +import pdb +import time + +for n in range(26,27): + parser = argparse.ArgumentParser(description='face model test') + # general + parser.add_argument('--image-size', default='112,112', help='') + parser.add_argument('--model', default='models/model-r100-ii/model,0', help='path to load model.') + parser.add_argument('--ga-model', default='models/gamodel-r50/model,0', help='path to load model.') + parser.add_argument('--gpu', default=1, type=int, help='gpu id') + parser.add_argument('--det', default=0, type=int, help='mtcnn option, 1 means using R+O, 0 means detect from begining') + parser.add_argument('--flip', default=0, type=int, help='whether do lr flip aug') + parser.add_argument('--threshold', default=1.24, type=float, help='ver dist threshold') + parser.add_argument('--imglist', default='trainB/'+str(n)+'_bmold_win3.txt', help='imglist name') + parser.add_argument('--listdir', default='../datasets/list/', help='dir to imglist') + parser.add_argument('--savedir', default='iden_feat', help='dir to save 512feats') + args = parser.parse_args() + model = face_model.FaceModel(args) + if not os.path.isdir(os.path.join(args.listdir,args.imglist)): + imglist = open(os.path.join(args.listdir,args.imglist),'r').read().splitlines() + else: + imglist = glob.glob(args.imglist+'/*/*.png') + imglist = [e for e in imglist if '_input2' not in e and '_render' not in e] + #dirname = os.path.basename(os.path.dirname(imglist[0][:-1])) + print('imglist len',len(imglist)) + #print('dirname',dirname) + #if not 
os.path.exists(os.path.join(args.savedir,dirname)): + # os.makedirs(os.path.join(args.savedir,dirname)) + t0 = time.time() + for i in range(len(imglist)): + imgname = imglist[i] + ss = imgname.split('/') + dirname = os.path.join(ss[-3],ss[-2]) + if not os.path.exists(os.path.join(args.savedir,dirname)): + os.makedirs(os.path.join(args.savedir,dirname)) + basen = os.path.basename(imgname) + savename = os.path.join(args.savedir,dirname,basen[:-4]+'.npy') + if os.path.exists(savename): + continue + img = cv2.imread(imgname) + img = model.get_input(img) + if type(img) != np.ndarray: + print(imgname,'Not detected') + continue + f1 = model.get_feature(img) + np.save(savename,f1) + if i % 1000 == 1: + print('saved',i,time.time()-t0) + t0 = time.time() + diff --git a/render-to-video/data/__init__.py b/render-to-video/data/__init__.py new file mode 100644 index 0000000..8cb6186 --- /dev/null +++ b/render-to-video/data/__init__.py @@ -0,0 +1,93 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import importlib +import torch.utils.data +from data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt) + dataset = data_loader.load_data() + return dataset + + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. 
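+
+        A minimal usage sketch (option names follow BaseOptions; 'opt' is assumed
+        to carry dataset_mode, batch_size, serial_batches, num_threads and
+        max_dataset_size):
+
+            >>> loader = CustomDatasetDataLoader(opt)
+            >>> dataset = loader.load_data()   # returns the wrapper itself
+            >>> for batch in dataset:          # yields dicts from the chosen dataset class
+            ...     pass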
+ """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + print("dataset [%s] was created" % type(self.dataset).__name__) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=not opt.serial_batches, + num_workers=int(opt.num_threads)) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/render-to-video/data/aligned_feature_multi_dataset.py b/render-to-video/data/aligned_feature_multi_dataset.py new file mode 100644 index 0000000..fc1fc34 --- /dev/null +++ b/render-to-video/data/aligned_feature_multi_dataset.py @@ -0,0 +1,101 @@ +import os.path +from data.base_dataset import BaseDataset, get_params, get_transform +from data.image_folder import make_dataset +from PIL import Image +import numpy as np +import torch + + +class AlignedFeatureMultiDataset(BaseDataset): + """A dataset class for paired image dataset. + + It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. + During test time, you need to prepare a directory '/path/to/data/test'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot) + imglistB = 'datasets/list/%s/%s.txt' % (opt.phase+'B', opt.dataroot) + + if not os.path.exists(imglistA) or not os.path.exists(imglistB): + self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory + self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths + else: + self.AB_paths = open(imglistA, 'r').read().splitlines() + self.AB_paths2 = open(imglistB, 'r').read().splitlines() + + assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image + self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc + self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc + self.Nw = self.opt.Nw + + def __getitem__(self, index): + """Return a data point and its metadata information. 
+ + Parameters: + index - - a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + A (tensor) - - an image in the input domain + B (tensor) - - its corresponding image in the target domain + A_paths (str) - - image paths + B_paths (str) - - image paths (same as A_paths) + """ + # read a image given a random integer index + # by default A and B are from 2 png + AB_path = self.AB_paths[index] + A = Image.open(AB_path).convert('RGB') + AB_path2 = self.AB_paths2[index] + B = Image.open(AB_path2).convert('RGB') + + # apply the same transform to both A and B + transform_params = get_params(self.opt, A.size) + A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1), method=self.opt.resizemethod) + B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1),method=self.opt.resizemethod) + resnet_transform = get_transform(self.opt, transform_params, grayscale=False, resnet=True, method=self.opt.resizemethod) + + imA = A + A = A_transform(A) + B = B_transform(B) + resnet_input = resnet_transform(imA) + + As = torch.zeros((self.input_nc * self.Nw, self.opt.crop_size, self.opt.crop_size)) + As[-self.input_nc:] = A + frame = os.path.basename(AB_path).split('_')[0] + frameno = int(frame[5:]) + for i in range(1,self.Nw): + # read frameno-i frame + path1 = AB_path.replace(frame,'frame%d'%(frameno-i)) + A = Image.open(path1).convert('RGB') + # store in Nw-i's + As[-(i+1)*self.input_nc:-i*self.input_nc] = A_transform(A) + + item = {'A': As, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path2} + item['resnet_input'] = resnet_input + item['index'] = np.array(([index+0.0])).astype(np.float32)[0] + + if self.opt.isTrain or self.opt.test_use_gt: + AB_path2 = AB_path2.replace('_input2','') + ss = AB_path2.split('/') + if self.opt.dataroot != '300vw_win3' and self.opt.dataroot != 'lrwnewrender_win3' and ss[-3] != '19_news': + B_feat = np.load(os.path.join(self.opt.iden_feat_dir,ss[-2],ss[-1][:-4]+'.npy')) + elif self.opt.dataroot == '300vw_win3': + B_feat = np.load(os.path.join(self.opt.iden_feat_dir,ss[-3],ss[-2],ss[-1][:-4]+'.npy')) + elif self.opt.dataroot == 'lrwnewrender_win3': + B_feat = np.load(os.path.join(self.opt.iden_feat_dir,ss[-4],ss[-3],ss[-2],'frame14.npy')) + elif ss[-3] == '19_news': + B_feat = np.load(os.path.join(self.opt.iden_feat_dir,ss[-3],ss[-2],ss[-1][:-4]+'.npy')) + item['B_feat'] = B_feat + + return item + + def __len__(self): + """Return the total number of images in the dataset.""" + return len(self.AB_paths) diff --git a/render-to-video/data/base_dataset.py b/render-to-video/data/base_dataset.py new file mode 100644 index 0000000..7c72db7 --- /dev/null +++ b/render-to-video/data/base_dataset.py @@ -0,0 +1,161 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABCMeta, abstractmethod + + +class BaseDataset(data.Dataset): + __metaclass__ = ABCMeta + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. 
+ -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataroot + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.preprocess == 'resize_and_crop': + new_h = new_w = opt.load_size + elif opt.preprocess == 'scale_width_and_crop': + new_w = opt.load_size + new_h = opt.load_size * h // w + + x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) + y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) + + flip = random.random() > 0.5 + + return {'crop_pos': (x, y), 'flip': flip} + + +def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, resnet=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + if 'resize' in opt.preprocess: + osize = [opt.load_size, opt.load_size] + transform_list.append(transforms.Resize(osize, method)) + elif 'scale_width' in opt.preprocess: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) + + if 'crop' in opt.preprocess: + if params is None: + transform_list.append(transforms.RandomCrop(opt.crop_size)) + else: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) + + if opt.preprocess == 'none': + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) + + if not opt.no_flip: + if params is None: + transform_list.append(transforms.RandomHorizontalFlip()) + elif params['flip']: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + if convert: + transform_list += [transforms.ToTensor()] + if grayscale: + transform_list += [transforms.Normalize((0.5,), (0.5,))] + else: + if resnet: + transform_list += [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] + else: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + + __print_size_warning(ow, oh, w, h) + return img.resize((w, h), method) + + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + 
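+    # Crop a size x size window at the sampled position; images that are not
+    # larger than the target in either dimension are returned unchanged.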
tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img + + +def __print_size_warning(ow, oh, w, h): + """Print warning information about image size(only print once)""" + if not hasattr(__print_size_warning, 'has_printed'): + print("The image size needs to be a multiple of 4. " + "The loaded image size was (%d, %d), so it was adjusted to " + "(%d, %d). This adjustment will be done to all images " + "whose sizes are not multiples of 4" % (ow, oh, w, h)) + __print_size_warning.has_printed = True diff --git a/render-to-video/data/image_folder.py b/render-to-video/data/image_folder.py new file mode 100644 index 0000000..a9cea74 --- /dev/null +++ b/render-to-video/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/render-to-video/data/single_multi_dataset.py b/render-to-video/data/single_multi_dataset.py new file mode 100644 index 0000000..1f71adb --- /dev/null +++ b/render-to-video/data/single_multi_dataset.py @@ -0,0 +1,68 @@ +from data.base_dataset import BaseDataset, get_params, get_transform +from data.image_folder import make_dataset +from PIL import Image +import os.path +import torch + + +class SingleMultiDataset(BaseDataset): + """This dataset class can load a set of images specified by the path --dataroot /path/to/data. + + It can be used for generating CycleGAN results only for one side with the model option '-model test'. + """ + + def __init__(self, opt): + """Initialize this dataset class. 
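+        The image list is read from datasets/list/<phase>Single/<dataroot>.txt when
+        that file exists; otherwise images are collected directly from --dataroot.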
+ + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'Single', opt.dataroot) + if not os.path.exists(imglistA): + self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) + else: + self.A_paths = open(imglistA, 'r').read().splitlines() + self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc + self.Nw = self.opt.Nw + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns a dictionary that contains A and A_paths + A(tensor) - - an image in one domain + A_paths(str) - - the path of the image + """ + A_path = self.A_paths[index] + A_img = Image.open(A_path).convert('RGB') + + # apply the same transform to both A and resnet_input + transform_params = get_params(self.opt, A_img.size) + A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1), method=self.opt.resizemethod) + A = A_transform(A_img) + As = torch.zeros((self.input_nc * self.Nw, self.opt.crop_size, self.opt.crop_size)) + As[-self.input_nc:] = A + frame = os.path.basename(A_path).split('_')[0] + ext = os.path.basename(A_path).split('_')[1] + frameno = int(frame) + for i in range(1,self.Nw): + # read frameno-i frame + path1 = A_path.replace(frame+'_blend','%05d_blend'%(frameno-i)) + A = Image.open(path1).convert('RGB') + # store in Nw-i's + As[-(i+1)*self.input_nc:-i*self.input_nc] = A_transform(A) + item = {'A': As, 'A_paths': A_path} + + if self.opt.use_memory: + resnet_transform = get_transform(self.opt, transform_params, grayscale=False, resnet=True, method=self.opt.resizemethod) + resnet_input = resnet_transform(A_img) + item['resnet_input'] = resnet_input + + return item + + def __len__(self): + """Return the total number of images in the dataset.""" + return len(self.A_paths) diff --git a/render-to-video/models/__init__.py b/render-to-video/models/__init__.py new file mode 100644 index 0000000..fc01113 --- /dev/null +++ b/render-to-video/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. 
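+
+A minimal skeleton for such a model might look like this (illustrative only;
+only the abstract methods of BaseModel are filled in, and the 'A' key depends
+on the dataset class you pair it with):
+
+    # models/dummy_model.py
+    from .base_model import BaseModel
+
+    class DummyModel(BaseModel):
+        def __init__(self, opt):
+            BaseModel.__init__(self, opt)
+            self.loss_names, self.model_names, self.visual_names = [], [], []
+
+        def set_input(self, input):
+            self.real = input['A'].to(self.device)
+
+        def forward(self):
+            pass
+
+        def optimize_parameters(self):
+            self.forward()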
+""" + +import importlib +from models.base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/render-to-video/models/base_model.py b/render-to-video/models/base_model.py new file mode 100644 index 0000000..928f126 --- /dev/null +++ b/render-to-video/models/base_model.py @@ -0,0 +1,248 @@ +import os +import torch +from collections import OrderedDict +from abc import ABCMeta, abstractmethod +from . import networks + + +class BaseModel(): + __metaclass__ = ABCMeta + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 
+ """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix) + self.print_networks(opt.verbose) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. 
train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if name != 'mem': + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + else: + torch.save({'mem_model' : net.cpu().state_dict(), + 'mem_key' : net.spatial_key.cpu(), + 'mem_value' : net.color_value.cpu(), + 'mem_age' : net.age.cpu(), + 'mem_index' : net.top_index.cpu()}, save_path) + net.cuda(self.gpu_ids[0]) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. 
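+        If a checkpoint is missing from save_dir, the loader falls back to the
+        checkpoint directory of the first component of opt.name; InstanceNorm buffers
+        from pre-0.4 checkpoints are patched, and the 'mem' network additionally
+        restores its stored key/value/age/index tensors.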
+ + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + if not os.path.exists(load_path): + load_path = os.path.join(self.opt.checkpoints_dir, self.opt.name.split('/')[0],load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + print('loading the model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + if name != 'mem': + net.load_state_dict(state_dict) + else: + net.load_state_dict(state_dict['mem_model']) + net.spatial_key = state_dict['mem_key'] + net.color_value = state_dict['mem_value'] + net.age = state_dict['mem_age'] + net.top_index = state_dict['mem_index'] + #print(net.spatial_key) + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/render-to-video/models/memory_network.py b/render-to-video/models/memory_network.py new file mode 100644 index 0000000..bb04383 --- /dev/null +++ b/render-to-video/models/memory_network.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import numpy as np +from torchvision import models +import pdb + +class Memory_Network(nn.Module): + + def __init__(self, mem_size, color_feat_dim = 512, spatial_feat_dim = 512, top_k = 256, alpha = 0.1, age_noise = 4.0, gpu_ids = []): + + super(Memory_Network, self).__init__() + #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if gpu_ids else torch.device('cpu') + self.ResNet18 = ResNet18().to(self.device) + self.ResNet18 = self.ResNet18.eval() + self.mem_size = mem_size + self.color_feat_dim = color_feat_dim + self.spatial_feat_dim = spatial_feat_dim + self.alpha = alpha + self.age_noise = age_noise + self.top_k = top_k + + ## Each color_value is probability distribution + self.color_value = F.normalize(random_uniform((self.mem_size, self.color_feat_dim), 0, 0.01), p = 1, 
dim=1).to(self.device) + + self.spatial_key = F.normalize(random_uniform((self.mem_size, self.spatial_feat_dim), -0.01, 0.01), dim=1).to(self.device) + self.age = torch.zeros(self.mem_size).to(self.device) + + self.top_index = torch.zeros(self.mem_size).to(self.device) + self.top_index = self.top_index - 1.0 + + self.color_value.requires_grad = False + self.spatial_key.requires_grad = False + + self.Linear = nn.Linear(512, spatial_feat_dim) + self.body = [self.ResNet18, self.Linear] + self.body = nn.Sequential(*self.body) + self.body = self.body.to(self.device) + + def forward(self, x): + q = self.body(x) + q = F.normalize(q, dim = 1) + return q + + def unsupervised_loss(self, query, color_feat, color_thres): + + bs = query.size()[0] + cosine_score = torch.matmul(query, torch.t(self.spatial_key)) + + top_k_score, top_k_index = torch.topk(cosine_score, self.top_k, 1) + + ### For unsupervised training + color_value_expand = torch.unsqueeze(torch.t(self.color_value), 0) + color_value_expand = torch.cat([color_value_expand[:,:,idx] for idx in top_k_index], dim = 0) + + color_feat_expand = torch.unsqueeze(color_feat, 2) + color_feat_expand = torch.cat([color_feat_expand for _ in range(self.top_k)], dim = 2) + + #color_similarity = self.KL_divergence(color_value_expand, color_feat_expand, 1) + color_similarity = torch.sum(torch.mul(color_value_expand, color_feat_expand),dim=1) + + #loss_mask = color_similarity < color_thres + loss_mask = color_similarity > color_thres + loss_mask = loss_mask.float() + + pos_score, pos_index = torch.topk(torch.mul(top_k_score, loss_mask), 1, dim = 1) + neg_score, neg_index = torch.topk(torch.mul(top_k_score, 1 - loss_mask), 1, dim = 1) + + loss = self._unsupervised_loss(pos_score, neg_score) + + return loss + + + def memory_update(self, query, color_feat, color_thres, top_index): + + cosine_score = torch.matmul(query, torch.t(self.spatial_key)) + top1_score, top1_index = torch.topk(cosine_score, 1, dim = 1) + top1_index = top1_index[:, 0] + top1_feature = self.spatial_key[top1_index] + top1_color_value = self.color_value[top1_index] + + #color_similarity1 = self.KL_divergence(top1_color_value, color_feat, 1) + color_similarity = torch.sum(torch.mul(top1_color_value, color_feat),dim=1) + + #memory_mask = color_similarity < color_thres + memory_mask = color_similarity > color_thres + self.age = self.age + 1.0 + + ## Case 1 update + case_index = top1_index[memory_mask] + self.spatial_key[case_index] = F.normalize(self.spatial_key[case_index] + query[memory_mask], dim = 1) + self.age[case_index] = 0.0 + #if torch.sum(memory_mask).cpu().numpy()==1: + # print(top_index,'update',self.top_index[case_index],color_similarity) + + ## Case 2 replace + memory_mask = 1.0 - memory_mask + case_index = top1_index[memory_mask] + + random_noise = random_uniform((self.mem_size, 1), -self.age_noise, self.age_noise)[:, 0] + random_noise = random_noise.to(self.device) + age_with_noise = self.age + random_noise + old_values, old_index = torch.topk(age_with_noise, len(case_index), dim=0) + + self.spatial_key[old_index] = query[memory_mask] + self.color_value[old_index] = color_feat[memory_mask] + #if torch.sum(memory_mask).cpu().numpy()==1: + # print(top_index[memory_mask],'replace',self.top_index[old_index],color_similarity) + #pdb.set_trace() + self.top_index[old_index] = top_index[memory_mask] + self.age[old_index] = 0.0 + + return torch.sum(memory_mask).cpu().numpy()==1 # for batch size 1, return number of replace + + def topk_feature(self, query, top_k = 1): + _bs = query.size()[0] 
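+        # query and spatial_key are both L2-normalised, so the matrix product below
+        # is a (batch, mem_size) table of cosine similarities; the top-k memory slots
+        # yield their stored identity features (color_value) and source indices
+        # (top_index).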
+ cosine_score = torch.matmul(query, torch.t(self.spatial_key)) + topk_score, topk_index = torch.topk(cosine_score, top_k, dim = 1) + + topk_feat = torch.cat([torch.unsqueeze(self.color_value[topk_index[i], :], dim = 0) for i in range(_bs)], dim = 0) + topk_idx = torch.cat([torch.unsqueeze(self.top_index[topk_index[i]], dim = 0) for i in range(_bs)], dim = 0) + + return topk_feat, topk_idx, topk_index + + def get_feature(self, k, _bs): + feat = torch.cat([torch.unsqueeze(self.color_value[k, :], dim = 0) for i in range(_bs)], dim = 0) + return feat, self.top_index[k] + + + def KL_divergence(self, a, b, dim, eps = 1e-8): + + b = b + eps + log_val = torch.log10(torch.div(a, b)) + kl_div = torch.mul(a, log_val) + kl_div = torch.sum(kl_div, dim = dim) + + return kl_div + + + def _unsupervised_loss(self, pos_score, neg_score): + + hinge = torch.clamp(neg_score - pos_score + self.alpha, min = 0.0) + loss = torch.mean(hinge) + + return loss + + +def random_uniform(shape, low, high): + x = torch.rand(*shape) + result = (high - low) * x + low + + return result + + +class ResNet18(nn.Module): + def __init__(self, pre_trained = True, require_grad = False): + super(ResNet18, self).__init__() + self.model = models.resnet18(pretrained = True) + + self.body = [layers for layers in self.model.children()] + self.body.pop(-1) + + self.body = nn.Sequential(*self.body) + + if not require_grad: + for parameter in self.parameters(): + parameter.requires_grad = False + + def forward(self, x): + x = self.body(x) + x = x.view(-1, 512) + return x \ No newline at end of file diff --git a/render-to-video/models/memory_seq_model.py b/render-to-video/models/memory_seq_model.py new file mode 100644 index 0000000..0f08993 --- /dev/null +++ b/render-to-video/models/memory_seq_model.py @@ -0,0 +1,199 @@ +import torch +from .base_model import BaseModel +from . import networks +from .memory_network import Memory_Network +import pdb + +class MemorySeqModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) + parser.set_defaults(norm='batch', netG='unetac_adain_256', dataset_mode='aligned_feature_multi', direction='AtoB',Nw=3) + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla') + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + parser.add_argument('--lambda_mask', type=float, default=0.1, help='lambda mask loss') + parser.add_argument('--lambda_mask_smooth', type=float, default=1e-5, help='lambda mask smooth loss') + else: + parser.add_argument('--test_use_gt', type=int, default=0, help='use gt feature in test') + parser.add_argument('--attention', type=int, default=1, help='whether to use attention mechanism') + parser.add_argument('--do_saturate_mask', action="store_true", default=False, help='do use mask_fake for mask_cyc') + # for memory net + parser.add_argument("--iden_feat_dim", type = int, default = 512) + parser.add_argument("--spatial_feat_dim", type = int, default = 512) + parser.add_argument("--mem_size", type = int, default = 30000)#982=819*1.2 + parser.add_argument("--alpha", type = float, default = 0.3) + parser.add_argument("--top_k", type = int, default = 256) + parser.add_argument("--iden_thres", type = float, default = 0.98)#0.5) + parser.add_argument("--iden_feat_dir", type = str, default = 'arcface/iden_feat/') + + + return parser + + def __init__(self, opt): + """Initialize the pix2pix class. 
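+        In addition to the pix2pix generator and discriminator, this variant builds a
+        Memory_Network ('mem') that stores identity features and is optimised with its
+        own Adam optimizer alongside G and D.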
+ + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call + if self.isTrain: + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake', 'mem'] + if self.opt.attention: + self.loss_names += ['G_Att', 'G_Att_smooth', 'G'] + # specify the images you want to save/display. The training/test scripts will call + self.visual_names = ['real_A', 'fake_B', 'real_B'] + if self.opt.Nw > 1: + self.visual_names = ['real_A_0', 'real_A_1', 'real_A_2', 'fake_B', 'real_B'] + if self.opt.attention: + self.visual_names += ['fake_B_img', 'fake_B_mask_vis'] + # specify the models you want to save to the disk. The training/test scripts will call and + if self.isTrain: + self.model_names = ['G', 'D', 'mem'] + else: # during test time, only load G and mem + self.model_names = ['G', 'mem'] + # define networks (both generator and discriminator) + self.netG = networks.define_G(opt.input_nc*opt.Nw, opt.output_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, feat_dim=opt.iden_feat_dim) + + self.netmem = Memory_Network(mem_size = opt.mem_size, color_feat_dim = opt.iden_feat_dim, spatial_feat_dim = opt.spatial_feat_dim, top_k = opt.top_k, alpha = opt.alpha, gpu_ids = self.gpu_ids).to(self.device) + + if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + self.netD = networks.define_D(opt.input_nc*opt.Nw + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + # initialize optimizers; schedulers will be automatically created by function . + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_mem = torch.optim.Adam(self.netmem.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + self.optimizers.append(self.optimizer_mem) + self.replace = 0 + self.update = 0 + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + + The option 'direction' can be used to swap images in domain A and domain B. 
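+
+        During training (or when --test_use_gt is set) the dict also supplies
+        'B_feat' (for the default AtoB direction, the precomputed 512-d identity
+        feature of the target frame), 'resnet_input' and 'index', which drive the
+        memory network.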
+ """ + AtoB = self.opt.direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) # channel is input_nc * Nw + self.real_B = input['B' if AtoB else 'A'].to(self.device) # channel is output_nc + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + # for memory net + if self.isTrain or self.opt.test_use_gt: + self.real_B_feat = input['B_feat' if AtoB else 'A_feat'].to(self.device) + self.resnet_input = input['resnet_input'].to(self.device) + self.idx = input['index'].to(self.device) + + def forward(self): + """Run forward pass; called by both functions and .""" + if self.opt.attention: + if self.isTrain or self.opt.test_use_gt: + self.fake_B_img, self.fake_B_mask = self.netG(self.real_A, self.real_B_feat) + else: + query = self.netmem(self.resnet_input) + #pdb.set_trace() + top1_feature, top1_index, topk_index = self.netmem.topk_feature(query, 1) + top1_feature = top1_feature[:, 0, :] + self.fake_B_img, self.fake_B_mask = self.netG(self.real_A, top1_feature) + self.fake_B_mask = self._do_if_necessary_saturate_mask(self.fake_B_mask, saturate=self.opt.do_saturate_mask) + self.fake_B = self.fake_B_mask * self.real_A[:,-self.opt.input_nc:] + (1 - self.fake_B_mask) * self.fake_B_img + #print(torch.min(self.fake_B_mask), torch.max(self.fake_B_mask)) + self.fake_B_mask_vis = self.fake_B_mask * 2 - 1 + else: + if self.isTrain or self.opt.test_use_gt: + self.fake_B = self.netG(self.real_A, self.real_B_feat) + else: + query = self.netmem(self.resnet_input) + top1_feature, _, _ = self.netmem.topk_feature(query, 1) + top1_feature = top1_feature[:, 0, :] + self.fake_B = self.netG(self.real_A, top1_feature) + if self.opt.Nw > 1: + self.real_A_0 = self.real_A[:,:self.opt.input_nc] + self.real_A_1 = self.real_A[:,self.opt.input_nc:2*self.opt.input_nc] + self.real_A_2 = self.real_A[:,-self.opt.input_nc:] + + def backward_mem(self): + #print(self.image_paths) + resnet_feature = self.netmem(self.resnet_input) + self.loss_mem = self.netmem.unsupervised_loss(resnet_feature, self.real_B_feat, self.opt.iden_thres) + self.loss_mem.backward() + + def update_mem(self): + with torch.no_grad(): + resnet_feature = self.netmem(self.resnet_input) + replace = self.netmem.memory_update(resnet_feature, self.real_B_feat, self.opt.iden_thres, self.idx) + if replace: + self.replace += 1 + else: + self.update += 1 + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + # Loss for attention mask + if self.opt.attention: + # the attention mask can easily saturate to 1, which makes that generator has no effect + self.loss_G_Att = 
torch.mean(self.fake_B_mask) * self.opt.lambda_mask
+            # to enforce a smooth spatial color transformation
+            self.loss_G_Att_smooth = self._compute_loss_smooth(self.fake_B_mask) * self.opt.lambda_mask_smooth
+        # combine loss and calculate gradients
+        self.loss_G = self.loss_G_GAN + self.loss_G_L1
+        if self.opt.attention:
+            self.loss_G += self.loss_G_Att + self.loss_G_Att_smooth
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        # update mem
+        self.optimizer_mem.zero_grad()
+        self.backward_mem()
+        self.optimizer_mem.step()
+        self.update_mem()
+
+        self.forward()                              # compute fake images: G(A)
+        # update D
+        self.set_requires_grad(self.netD, True)     # enable backprop for D
+        self.optimizer_D.zero_grad()                # set D's gradients to zero
+        self.backward_D()                           # calculate gradients for D
+        self.optimizer_D.step()                     # update D's weights
+        # update G
+        self.set_requires_grad(self.netD, False)    # D requires no gradients when optimizing G
+        self.optimizer_G.zero_grad()                # set G's gradients to zero
+        self.backward_G()                           # calculate gradients for G
+        self.optimizer_G.step()                     # update G's weights
+
+    def _compute_loss_smooth(self, mat):
+        return torch.sum(torch.abs(mat[:, :, :, :-1] - mat[:, :, :, 1:])) + \
+               torch.sum(torch.abs(mat[:, :, :-1, :] - mat[:, :, 1:, :]))
+
+    def _do_if_necessary_saturate_mask(self, m, saturate=False):
+        return torch.clamp(0.55*torch.tanh(3*(m-0.5))+0.5, 0, 1) if saturate else m
diff --git a/render-to-video/models/networks.py b/render-to-video/models/networks.py
new file mode 100644
index 0000000..08473d0
--- /dev/null
+++ b/render-to-video/models/networks.py
@@ -0,0 +1,1199 @@
+# -*- coding:utf-8 -*-
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+
+
+###############################################################################
+# Helper Functions
+###############################################################################
+
+
+class Identity(nn.Module):
+    def forward(self, x):
+        return x
+
+
+def get_norm_layer(norm_type='instance'):
+    """Return a normalization layer
+
+    Parameters:
+        norm_type (str) -- the name of the normalization layer: batch | instance | none
+
+    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
+    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
+    """
+    if norm_type == 'batch':
+        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+    elif norm_type == 'instance':
+        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+    elif norm_type == 'none':
+        norm_layer = lambda x: Identity()
+    else:
+        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+    return norm_layer
+
+
+def get_scheduler(optimizer, opt):
+    """Return a learning rate scheduler
+
+    Parameters:
+        optimizer          -- the optimizer of the network
+        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+    For 'linear', we keep the same learning rate for the first opt.niter epochs
+    and linearly decay the rate to zero over the next opt.niter_decay epochs.
+    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+    See https://pytorch.org/docs/stable/optim.html for more details.
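+
+    Example (an illustrative sketch; the option values below are placeholders):
+        >>> import argparse, torch
+        >>> opt = argparse.Namespace(lr_policy='linear', epoch_count=1, niter=100, niter_decay=100)
+        >>> optimizer = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=0.0002)
+        >>> scheduler = get_scheduler(optimizer, opt)
+        >>> # the multiplier stays at 1.0 for the first `niter` epochs, then decays
+        >>> # linearly to zero over the following `niter_decay` epochs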
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + #print(classname, hasattr(m, 'weight'),) + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. 
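+
+    Example (a minimal sketch; gpu_ids=[] keeps the module on the CPU):
+        >>> import torch.nn as nn
+        >>> net = init_net(nn.Conv2d(3, 64, kernel_size=3), init_type='xavier', init_gain=0.02, gpu_ids=[])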
+ """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], feat_dim=512): + """Create a generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm (str) -- the name of normalization layers used in the network: batch | instance | none + use_dropout (bool) -- if use dropout layers. + init_type (str) -- the name of our initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a generator + + Our current implementation provides two types of generators: + U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) + The original U-Net paper: https://arxiv.org/abs/1505.04597 + + Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) + Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). + + + The generator has been initialized by . It uses RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'resnet_9blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif netG == 'resnet_6blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif netG == 'unet_128': + net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_256': + net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unetac_256': + net = UnetACGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unetac_adain_256': + net = unet_generator_ac_adain(input_nc, output_nc, ngf, feat_dim) + elif netG == 'unet_adain_256': + net = unet_generator_adain(input_nc, output_nc, ngf, feat_dim) + elif netG == 'unetr_256': + net = UnetGeneratorRefine(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + return init_net(net, init_type, init_gain, gpu_ids) + + +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): + """Create a discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the first conv layer + netD (str) -- the architecture's name: basic | n_layers | pixel + n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str) -- the type of normalization layers used in the network. + init_type (str) -- the name of the initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+    Returns a discriminator
+
+    Our current implementation provides three types of discriminators:
+        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
+        It can classify whether 70×70 overlapping patches are real or fake.
+        Such a patch-level discriminator architecture has fewer parameters
+        than a full-image discriminator and can work on arbitrarily-sized images
+        in a fully convolutional fashion.
+
+        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
+        with the parameter n_layers_D (default=3, as used in [basic] (PatchGAN)).
+
+        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
+        It encourages greater color diversity but has no effect on spatial statistics.
+
+    The discriminator has been initialized by init_net. It uses Leaky ReLU for non-linearity.
+    """
+    net = None
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    if netD == 'basic':        # default PatchGAN classifier
+        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+    elif netD == 'n_layers':   # more options
+        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+    elif netD == 'pixel':      # classify if each pixel is real or fake
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+    else:
+        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+    """Define different GAN objectives.
+
+    The GANLoss class abstracts away the need to create the target label tensor
+    that has the same size as the input.
+    """
+
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+        """ Initialize the GANLoss class.
+
+        Parameters:
+            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (float) - - label for a real image
+            target_fake_label (float) - - label of a fake image
+
+        Note: Do not use sigmoid as the last layer of Discriminator.
+        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+        """
+        super(GANLoss, self).__init__()
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+        if gan_mode == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif gan_mode in ['wgangp']:
+            self.loss = None
+        else:
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+    def get_target_tensor(self, prediction, target_is_real):
+        """Create label tensors with the same size as the input.
+
+        Parameters:
+            prediction (tensor) - - typically the prediction from a discriminator
+            target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+        Returns:
+            A label tensor filled with ground truth label, and with the size of the input
+        """
+
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        """Calculate loss given Discriminator's output and ground truth labels.
+
+        Parameters:
+            prediction (tensor) - - typically the prediction output from a discriminator
+            target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+        Returns:
+            the calculated loss.
+        """
+        if self.gan_mode in ['lsgan', 'vanilla']:
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        elif self.gan_mode == 'wgangp':
+            if target_is_real:
+                loss = -prediction.mean()
+            else:
+                loss = prediction.mean()
+        return loss
+
+
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
+
+    Arguments:
+        netD (network)           -- discriminator network
+        real_data (tensor array) -- real images
+        fake_data (tensor array) -- generated images from the generator
+        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+        type (str)               -- if we mix real and fake data or not [real | fake | mixed].
+        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
+        lambda_gp (float)        -- weight for this loss
+
+    Returns the gradient penalty loss
+    """
+    if lambda_gp > 0.0:
+        if type == 'real':   # either use real images, fake images, or a linear interpolation of the two.
+            interpolatesv = real_data
+        elif type == 'fake':
+            interpolatesv = fake_data
+        elif type == 'mixed':
+            alpha = torch.rand(real_data.shape[0], 1, device=device)
+            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
+            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+        else:
+            raise NotImplementedError('{} not implemented'.format(type))
+        interpolatesv.requires_grad_(True)
+        disc_interpolates = netD(interpolatesv)
+        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
+                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                                        create_graph=True, retain_graph=True, only_inputs=True)
+        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
+        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
+        return gradient_penalty, gradients
+    else:
+        return 0.0, None
+
+
+class ResnetGenerator(nn.Module):
+    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+ + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
+ use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + +class UnetACGenerator(nn.Module): + """Create a Unet-based generator""" + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. 
For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetACGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionACBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + img, mask = self.model(input) + return img, mask + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + user_dropout (bool) -- if use dropout layers. 
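+
+        Sketch (added for illustration, not in the original docstring): because forward()
+        concatenates the block input with its upsampled output, a non-outermost block
+        returns outer_nc + input_nc channels, e.g.
+            >>> import torch
+            >>> inner = UnetSkipConnectionBlock(128, 128, innermost=True)
+            >>> block = UnetSkipConnectionBlock(64, 128, submodule=inner)
+            >>> y = block(torch.randn(1, 64, 32, 32))   # 64 (skip) + 64 (upsampled) = 128 channels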
+ """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + +class UnetSkipConnectionACBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetSkipConnectionACBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + + # assume outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, inner_nc, + kernel_size=4, stride=2, + padding=1) + upnorm = norm_layer(inner_nc) + down = [downconv] + up = [uprelu, upconv, upnorm, uprelu] + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + layers = [] + layers.append(nn.Conv2d(inner_nc, 3, kernel_size=7, stride=1, padding=3)) + layers.append(nn.Tanh()) + self.img_reg = nn.Sequential(*layers) + + layers = [] + layers.append(nn.Conv2d(inner_nc, 1, kernel_size=7, stride=1, padding=3)) + layers.append(nn.Sigmoid()) + self.attention_reg = nn.Sequential(*layers) + + def forward(self, x): + features = self.model(x) + return self.img_reg(features), self.attention_reg(features) + +################################################################################################################ +# Unet with AdaIN +################################################################################################################ +class unet_generator_ac_adain(nn.Module): + + def __init__(self, input_nc, output_nc, ngf, feat_dim = 512): + super(unet_generator_ac_adain, self).__init__() + + self.e1 = nn.Conv2d(input_nc, ngf, 4, 2, 1) + self.e2 = unet_encoder_block(ngf, ngf * 2) + self.e3 = unet_encoder_block(ngf * 2, ngf * 4) + self.e4 = unet_encoder_block(ngf * 4, ngf * 8) + 
self.e5 = unet_encoder_block(ngf * 8, ngf * 8) + self.e6 = unet_encoder_block(ngf * 8, ngf * 8) + self.e7 = unet_encoder_block(ngf * 8, ngf * 8) + self.e8 = unet_encoder_block(ngf * 8, ngf * 8, norm = None) + + self.d1 = unet_decoder_block(ngf * 8, ngf * 8) + self.d2 = unet_decoder_block(ngf * 8 * 2, ngf * 8) + self.d3 = unet_decoder_block(ngf * 8 * 2, ngf * 8) + self.d4 = unet_decoder_block(ngf * 8 * 2, ngf * 8, drop_out = None) + self.d5 = unet_decoder_block(ngf * 8 * 2, ngf * 4, drop_out = None) + self.d6 = unet_decoder_block(ngf * 4 * 2, ngf * 2, drop_out = None) + self.d7 = unet_decoder_block(ngf * 2 * 2, ngf, drop_out = None) + self.d8 = unet_decoder_ac_block(ngf * 2, output_nc, norm = None, drop_out = None) + + self.layers = [self.e1, self.e2, self.e3, self.e4, self.e5, self.e6, self.e7, self.e8, + self.d1, self.d2, self.d3, self.d4, self.d5, self.d6, self.d7, self.d8] + + self.mlp = MLP(feat_dim, self.get_num_adain_params(self.layers), self.get_num_adain_params(self.layers), 3) + + + def forward(self, x, feat): + + ### AdaIn params + adain_params = self.mlp(feat) + self.assign_adain_params(adain_params, self.layers) + + ### Encoder + e1 = self.e1(x) + e2 = self.e2(e1) + e3 = self.e3(e2) + e4 = self.e4(e3) + e5 = self.e5(e4) + e6 = self.e6(e5) + e7 = self.e7(e6) + e8 = self.e8(e7) + + ### Decoder + d1_ = self.d1(e8) + d1 = torch.cat([d1_, e7], dim = 1) + + d2_ = self.d2(d1) + d2 = torch.cat([d2_, e6], dim = 1) + + d3_ = self.d3(d2) + d3 = torch.cat([d3_, e5], dim = 1) + + d4_ = self.d4(d3) + d4 = torch.cat([d4_, e4], dim = 1) + + d5_ = self.d5(d4) + d5 = torch.cat([d5_, e3], dim = 1) + + d6_ = self.d6(d5) + d6 = torch.cat([d6_, e2], dim = 1) + + d7_ = self.d7(d6) + d7 = torch.cat([d7_, e1], dim = 1) + + color, attention = self.d8(d7) + + return color, attention + + def get_num_adain_params(self, _module): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for model in _module: + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + num_adain_params += 2*m.num_features + return num_adain_params + + def assign_adain_params(self, adain_params, _module): + # assign the adain_params to the AdaIN layers in model + for model in _module: + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + mean = adain_params[:, :m.num_features] + std = adain_params[:, m.num_features:2*m.num_features] + m.bias = mean.contiguous().view(-1) + m.weight = std.contiguous().view(-1) + if adain_params.size(1) > 2*m.num_features: + adain_params = adain_params[:, 2*m.num_features:] + +class unet_generator_adain(nn.Module): + + def __init__(self, input_nc, output_nc, ngf, feat_dim = 512): + super(unet_generator_adain, self).__init__() + + self.e1 = nn.Conv2d(input_nc, ngf, 4, 2, 1) + self.e2 = unet_encoder_block(ngf, ngf * 2) + self.e3 = unet_encoder_block(ngf * 2, ngf * 4) + self.e4 = unet_encoder_block(ngf * 4, ngf * 8) + self.e5 = unet_encoder_block(ngf * 8, ngf * 8) + self.e6 = unet_encoder_block(ngf * 8, ngf * 8) + self.e7 = unet_encoder_block(ngf * 8, ngf * 8) + self.e8 = unet_encoder_block(ngf * 8, ngf * 8, norm = None) + + self.d1 = unet_decoder_block(ngf * 8, ngf * 8) + self.d2 = unet_decoder_block(ngf * 8 * 2, ngf * 8) + self.d3 = unet_decoder_block(ngf * 8 * 2, ngf * 8) + self.d4 = unet_decoder_block(ngf * 8 * 2, ngf * 8, drop_out = None) + self.d5 = unet_decoder_block(ngf * 8 * 2, ngf * 4, drop_out = None) + self.d6 = unet_decoder_block(ngf * 4 * 2, ngf * 2, drop_out = None) + self.d7 = unet_decoder_block(ngf 
* 2 * 2, ngf, drop_out = None) + self.d8 = unet_decoder_block(ngf * 2, output_nc, norm = None, drop_out = None) + self.tanh = nn.Tanh() + + self.layers = [self.e1, self.e2, self.e3, self.e4, self.e5, self.e6, self.e7, self.e8, + self.d1, self.d2, self.d3, self.d4, self.d5, self.d6, self.d7, self.d8] + + self.mlp = MLP(feat_dim, self.get_num_adain_params(self.layers), self.get_num_adain_params(self.layers), 3) + + + def forward(self, x, feat): + + ### AdaIn params + adain_params = self.mlp(feat) + self.assign_adain_params(adain_params, self.layers) + + ### Encoder + e1 = self.e1(x) + e2 = self.e2(e1) + e3 = self.e3(e2) + e4 = self.e4(e3) + e5 = self.e5(e4) + e6 = self.e6(e5) + e7 = self.e7(e6) + e8 = self.e8(e7) + + ### Decoder + d1_ = self.d1(e8) + d1 = torch.cat([d1_, e7], dim = 1) + + d2_ = self.d2(d1) + d2 = torch.cat([d2_, e6], dim = 1) + + d3_ = self.d3(d2) + d3 = torch.cat([d3_, e5], dim = 1) + + d4_ = self.d4(d3) + d4 = torch.cat([d4_, e4], dim = 1) + + d5_ = self.d5(d4) + d5 = torch.cat([d5_, e3], dim = 1) + + d6_ = self.d6(d5) + d6 = torch.cat([d6_, e2], dim = 1) + + d7_ = self.d7(d6) + d7 = torch.cat([d7_, e1], dim = 1) + + d8 = self.d8(d7) + + output = self.tanh(d8) + + return output + + def get_num_adain_params(self, _module): + # return the number of AdaIN parameters needed by the model + num_adain_params = 0 + for model in _module: + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + num_adain_params += 2*m.num_features + return num_adain_params + + def assign_adain_params(self, adain_params, _module): + # assign the adain_params to the AdaIN layers in model + for model in _module: + for m in model.modules(): + if m.__class__.__name__ == "AdaptiveInstanceNorm2d": + mean = adain_params[:, :m.num_features] + std = adain_params[:, m.num_features:2*m.num_features] + m.bias = mean.contiguous().view(-1) + m.weight = std.contiguous().view(-1) + if adain_params.size(1) > 2*m.num_features: + adain_params = adain_params[:, 2*m.num_features:] + +class unet_encoder_block(nn.Module): + + def __init__(self, input_nc, output_nc, ks = 4, stride = 2, padding = 1, norm = 'adain', act = nn.LeakyReLU(inplace = True, negative_slope = 0.2)): + super(unet_encoder_block, self).__init__() + self.conv = nn.Conv2d(input_nc, output_nc, ks, stride, padding) + m = [act, self.conv] + + if norm == 'adain': + m.append(AdaptiveInstanceNorm2d(output_nc)) + + self.body = nn.Sequential(*m) + + def forward(self, x): + return self.body(x) + +class unet_decoder_block(nn.Module): + + def __init__(self, input_nc, output_nc, ks = 4, stride = 2, padding = 1, norm = 'adain', act = nn.ReLU(inplace = True), drop_out = 0.5): + super(unet_decoder_block, self).__init__() + self.deconv = nn.ConvTranspose2d(input_nc, output_nc, ks, stride, padding) + m = [act, self.deconv] + + if norm == 'adain': + m.append(AdaptiveInstanceNorm2d(output_nc)) + + if drop_out is not None: + m.append(nn.Dropout(drop_out)) + + self.body = nn.Sequential(*m) + + def forward(self, x): + return self.body(x) + +class unet_decoder_ac_block(nn.Module): + + def __init__(self, input_nc, output_nc, ks = 4, stride = 2, padding = 1, norm = 'adain', act = nn.ReLU(inplace = True), drop_out = 0.5): + super(unet_decoder_ac_block, self).__init__() + self.deconv = nn.ConvTranspose2d(input_nc, int(input_nc/2), ks, stride, padding) + m = [act, self.deconv, AdaptiveInstanceNorm2d(int(input_nc/2)), act] + + self.body = nn.Sequential(*m) + + layers = [] + layers.append(nn.Conv2d(int(input_nc/2), output_nc, kernel_size=7, stride=1, 
padding=3)) + layers.append(nn.Tanh()) + self.img_reg = nn.Sequential(*layers) + + layers = [] + layers.append(nn.Conv2d(int(input_nc/2), 1, kernel_size=7, stride=1, padding=3)) + layers.append(nn.Sigmoid()) + self.attention_reg = nn.Sequential(*layers) + + def forward(self, x): + features = self.body(x) + return self.img_reg(features), self.attention_reg(features) + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim, dim, n_blk, act = nn.ReLU(inplace = True)): + + super(MLP, self).__init__() + self.model = [] + + self.model.append(nn.Linear(input_dim, dim)) + self.model.append(act) + + for i in range(n_blk - 2): + self.model.append(nn.Linear(dim, dim)) + self.model.append(act) + + self.model.append(nn.Linear(dim, output_dim)) + self.model = nn.Sequential(*self.model) + + def forward(self, x): + return self.model(x.view(x.size(0), -1)) + + +class AdaptiveInstanceNorm2d(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super(AdaptiveInstanceNorm2d, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + # weight and bias are dynamically assigned + self.weight = None + self.bias = None + # just dummy buffers, not used + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + + def forward(self, x): + assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" + b, c = x.size(0), x.size(1) + running_mean = self.running_mean.repeat(b) + running_var = self.running_var.repeat(b) + + # Apply instance norm + x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) + + out = nn.functional.batch_norm( + x_reshaped, running_mean, running_var, self.weight, self.bias, + True, self.momentum, self.eps) + + return out.view(b, c, *x.size()[2:]) + + def __repr__(self): + return self.__class__.__name__ + '(' + str(self.num_features) + ')' +################################################################################################################ + +class UnetGeneratorRefine(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. 
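+
+        Sketch of intended use (shapes are assumptions for illustration): forward()
+        internally prepares a 2x-downsampled copy of the input and feeds it into the
+        Addinput decoder blocks.
+            >>> import torch
+            >>> netG = UnetGeneratorRefine(input_nc=3, output_nc=3, num_downs=8)
+            >>> out = netG(torch.randn(1, 3, 512, 512))   # out: (1, 3, 512, 512)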
+ """ + super(UnetGeneratorRefine, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionRefineBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionRefineBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionRefineBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionRefineAddinputBlock(ngf * 2, ngf * 4, input_nc=None, dadd_nc=input_nc, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionRefineAddinputBlock(ngf, ngf * 2, input_nc=None, dadd_nc=input_nc, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionRefineAddinputBlock(output_nc, ngf, input_nc=input_nc, dadd_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + im = nn.Upsample(scale_factor=[0.5,0.5],mode='bilinear')(input) + #print im.shape, input.shape + return self.model(input, im) + + +class UnetSkipConnectionRefineBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + user_dropout (bool) -- if use dropout layers. 
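+
+        Note (added for clarity; not in the original docstring): compared with
+        UnetSkipConnectionBlock, the up path here appends two extra 3x3 refinement
+        convolutions (each followed by normalization) after the transposed convolution.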
+ """ + super(UnetSkipConnectionRefineBlock, self).__init__() + self.outermost = outermost + self.innermost = innermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + refine = nn.Conv2d(outer_nc, outer_nc, + kernel_size=3, stride=1, + padding=1, bias=use_bias) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm, uprelu, refine, upnorm, uprelu, refine, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + if use_dropout: + up = [uprelu, upconv, upnorm, nn.Dropout(0.5), uprelu, refine, upnorm, nn.Dropout(0.5), uprelu, refine, upnorm, nn.Dropout(0.5)] + else: + up = [uprelu, upconv, upnorm, uprelu, refine, upnorm, uprelu, refine, upnorm] + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + self.down = nn.Sequential(*down) + self.submodule = submodule + self.up = nn.Sequential(*up) + + def forward(self, x): + d = self.down(x) + if not self.innermost: + d = self.submodule(d) + u = self.up(d) + if self.outermost: + return u + else: + return torch.cat([x, u], 1) + +class UnetSkipConnectionRefineAddinputBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, dadd_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + user_dropout (bool) -- if use dropout layers. 
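+
+        Note (added for clarity; not in the original docstring): dadd_nc is the channel
+        count of the downsampled input image that forward(x, im) concatenates with the
+        decoder features, which is why the transposed convolutions below expect
+        inner_nc * 2 + dadd_nc (or inner_nc + dadd_nc in the innermost case) input channels.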
+ """ + super(UnetSkipConnectionRefineAddinputBlock, self).__init__() + self.outermost = outermost + self.innermost = innermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + refine = nn.Conv2d(outer_nc, outer_nc, + kernel_size=3, stride=1, + padding=1, bias=use_bias) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2 + dadd_nc, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc + dadd_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm, uprelu, refine, upnorm, uprelu, refine, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2 + dadd_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + if use_dropout: + up = [uprelu, upconv, upnorm, nn.Dropout(0.5), uprelu, refine, upnorm, nn.Dropout(0.5), uprelu, refine, upnorm, nn.Dropout(0.5)] + else: + up = [uprelu, upconv, upnorm, uprelu, refine, upnorm, uprelu, refine, upnorm] + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + self.down = nn.Sequential(*down) + self.submodule = submodule + self.up = nn.Sequential(*up) + + def forward(self, x, im): + d = self.down(x) + #print x.shape, d.shape, type(self) + if not self.innermost: + if isinstance(self.submodule,UnetSkipConnectionRefineBlock): + d = self.submodule(d) + else: + im2 = nn.Upsample(scale_factor=[0.5,0.5],mode='bilinear')(im) + #print im2.shape, im.shape + d = self.submodule(d, im2) + #print d.shape + u = self.up(torch.cat([im, d], 1)) + if self.outermost: + return u + else: + return torch.cat([x, u], 1) + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence 
+= [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) diff --git a/render-to-video/models/test_model.py b/render-to-video/models/test_model.py new file mode 100644 index 0000000..f4fb09e --- /dev/null +++ b/render-to-video/models/test_model.py @@ -0,0 +1,145 @@ +from .base_model import BaseModel +from . import networks +from .memory_network import Memory_Network + + +class TestModel(BaseModel): + """ This TesteModel can be used to generate CycleGAN results for only one direction. + This model will automatically set '--dataset_mode single', which only loads the images from one collection. + + See the test instruction for more details. + """ + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + + The model can only be used during test time. It requires '--dataset_mode single'. + You need to specify the network using the option '--model_suffix'. 
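+
+        Illustrative invocation (hedged: the test entry-point script name and checkpoint
+        layout are assumptions; only the flags shown are defined in this code):
+            python test.py --dataroot ./datasets/person --name person_exp --model test \
+                --netG unetac_adain_256 --attention 1 --use_memory 1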
+ """ + assert not is_train, 'TestModel cannot be used during training time' + parser.set_defaults(dataset_mode='single') + parser.set_defaults(norm='batch') + parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.') + parser.add_argument('--attention', type=int, default=0, help='whether to use attention mechanism') + parser.add_argument('--do_saturate_mask', action="store_true", default=False, help='do use mask_fake for mask_cyc') + # for memory net + parser.add_argument("--use_memory", type = int, default = 0) + parser.add_argument("--iden_feat_dim", type = int, default = 512) + parser.add_argument("--spatial_feat_dim", type = int, default = 512) + parser.add_argument("--mem_size", type = int, default = 30000)#982=819*1.2 + parser.add_argument("--alpha", type = float, default = 0.3) + parser.add_argument("--top_k", type = int, default = 256) + parser.add_argument("--iden_thres", type = float, default = 0.5) + parser.add_argument("--save2", type = int, default = 0) + parser.add_argument("--fixindex", type = int, default = -1) + + return parser + + def __init__(self, opt): + """Initialize the pix2pix class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + assert(not opt.isTrain) + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call + self.loss_names = [] + # specify the images you want to save/display. The training/test scripts will call + self.visual_names = ['real', 'fake'] + if self.opt.Nw > 1: + self.visual_names = ['real_A_0', 'real_A_1', 'real_A_2', 'fake'] + if self.opt.attention: + self.visual_names += ['fake_B_img', 'fake_B_mask_vis'] + if self.opt.save2: + if self.opt.Nw == 1: + self.visual_names = ['real', 'fake'] + else: + self.visual_names = ['real_A_0', 'real_A_1', 'real_A_2', 'fake'] + # specify the models you want to save to the disk. The training/test scripts will call and + self.model_names = ['G' + opt.model_suffix] # only generator is needed. + self.netG = networks.define_G(opt.input_nc*opt.Nw, opt.output_nc, opt.ngf, opt.netG, + opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.opt.use_memory: + self.netmem = Memory_Network(mem_size = opt.mem_size, color_feat_dim = opt.iden_feat_dim, spatial_feat_dim = opt.spatial_feat_dim, top_k = opt.top_k, alpha = opt.alpha, gpu_ids = self.gpu_ids) + self.model_names.append('mem') + + # assigns the model to self.netG_[suffix] so that it can be loaded + # please see + setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + + We need to use 'single_dataset' dataset mode. It only load images from one domain. 
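+
+        Illustrative sketch (tensor shapes are assumptions): at test time the dict only
+        needs, e.g.,
+            input = {
+                'A': torch.randn(1, opt.input_nc * opt.Nw, 256, 256),
+                'A_paths': ['path/to/frame.png'],
+                'resnet_input': torch.randn(1, 3, 224, 224),   # only read when --use_memory is set
+            }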
+ """ + self.real = input['A'].to(self.device) + self.image_paths = input['A_paths'] + # for memory net + if self.opt.use_memory: + self.resnet_input = input['resnet_input'].to(self.device) + + def forward(self): + """Run forward pass.""" + if not self.opt.attention: + self.fake = self.netG(self.real) # G(A) + else: + if not self.opt.use_memory: + self.fake_B_img, self.fake_B_mask = self.netG(self.real) + else: + if self.opt.fixindex != -1: + top1_idx = self.opt.fixindex + top1_feature, top1_index = self.netmem.get_feature(top1_idx, self.real.size()[0]) + #print(top1_index, top1_idx) + else: + query = self.netmem(self.resnet_input) + top1_feature, top1_index, top1_idx = self.netmem.topk_feature(query, 1) + top1_feature = top1_feature[:, 0, :] + #print(top1_index, top1_idx) + self.fake_B_img, self.fake_B_mask = self.netG(self.real, top1_feature) + self.fake_B_mask = self._do_if_necessary_saturate_mask(self.fake_B_mask, saturate=self.opt.do_saturate_mask) + self.fake = self.fake_B_mask * self.real[:,-self.opt.input_nc:] + (1 - self.fake_B_mask) * self.fake_B_img + self.fake_B_mask_vis = self.fake_B_mask * 2 - 1 + if self.opt.Nw > 1: + self.real_A_0 = self.real[:,:self.opt.input_nc] + self.real_A_1 = self.real[:,self.opt.input_nc:2*self.opt.input_nc] + self.real_A_2 = self.real[:,2*self.opt.input_nc:3*self.opt.input_nc] + + def forward_getfeat(self): + query = self.netmem(self.resnet_input) + top1_feature, top1_index, top1_idx = self.netmem.topk_feature(query, 1) + return top1_feature + + def forward_getfeatk(self, k): + query = self.netmem(self.resnet_input) + topk_feature, topk_index, topk_idx = self.netmem.topk_feature(query, k) + return topk_feature + + def forward_withfeat(self,feat): + self.fake_B_img, self.fake_B_mask = self.netG(self.real, feat) + self.fake_B_mask = self._do_if_necessary_saturate_mask(self.fake_B_mask, saturate=self.opt.do_saturate_mask) + self.fake = self.fake_B_mask * self.real[:,-self.opt.input_nc:] + (1 - self.fake_B_mask) * self.fake_B_img + self.fake_B_mask_vis = self.fake_B_mask * 2 - 1 + if self.opt.Nw > 1: + self.real_A_0 = self.real[:,:self.opt.input_nc] + self.real_A_1 = self.real[:,self.opt.input_nc:2*self.opt.input_nc] + self.real_A_2 = self.real[:,2*self.opt.input_nc:3*self.opt.input_nc] + + + def optimize_parameters(self): + """No optimization for test model.""" + pass + + def _do_if_necessary_saturate_mask(self, m, saturate=False): + return torch.clamp(0.55*torch.tanh(3*(m-0.5))+0.5, 0, 1) if saturate else m \ No newline at end of file diff --git a/render-to-video/options/__init__.py b/render-to-video/options/__init__.py new file mode 100644 index 0000000..e7eedeb --- /dev/null +++ b/render-to-video/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/render-to-video/options/base_options.py b/render-to-video/options/base_options.py new file mode 100644 index 0000000..46d187d --- /dev/null +++ b/render-to-video/options/base_options.py @@ -0,0 +1,152 @@ +import argparse +import os +from util import util +import torch +import models +import data +from PIL import Image + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. 
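+
+    Typical usage (a sketch; TestOptions, defined later in this diff, subclasses this class):
+        >>> from options.test_options import TestOptions
+        >>> opt = TestOptions().parse()   # parses sys.argv, then prints and saves the options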
+ """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + # model parameters + parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') + parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + # dataset parameters + parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') + parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + parser.add_argument('--Nw', type=int, default=1, help='window size of frames') + #parser.add_argument('--resizemethod', default='bicubic', type=str, help='pillow resize method: bicubic, nearest, bilinear, lanczos') + parser.add_argument('--resizemethod', default='lanczos', type=str, help='pillow resize method: bicubic, nearest, bilinear, lanczos') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # modify dataset-related parser options + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + return parser.parse_args() + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + if len(opt.gpu_ids) > 0: + torch.cuda.set_device(opt.gpu_ids[0]) + + # set resizemethod + resizemethod = opt.resizemethod + if resizemethod == 'bicubic': + opt.resizemethod = Image.BICUBIC + elif resizemethod == 'nearest': + opt.resizemethod = Image.NEAREST + elif resizemethod == 'bilinear': + opt.resizemethod = Image.BILINEAR + elif resizemethod == 'lanczos': + opt.resizemethod = Image.LANCZOS + + self.opt = opt + return self.opt diff --git a/render-to-video/options/test_options.py b/render-to-video/options/test_options.py new file mode 100644 index 0000000..71aa2ff --- /dev/null +++ b/render-to-video/options/test_options.py @@ -0,0 +1,28 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') + parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + # Dropout and Batchnorm has different behavioir during training and test. 
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') + parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images') + # rewrite devalue values + parser.set_defaults(model='test') + # To avoid cropping, the load_size should be the same as crop_size + parser.set_defaults(load_size=parser.get_default('crop_size')) + parser.add_argument('--test_batch_list', default='', type=str, help='dataroot list for test_batch') + parser.add_argument('--n', type=int, default=26, help='person id.') + parser.add_argument('--blinkframeid', type=int, default=41, help='blink frame id in 12s short video.') + self.isTrain = False + return parser diff --git a/render-to-video/options/train_options.py b/render-to-video/options/train_options.py new file mode 100644 index 0000000..8b8ebfb --- /dev/null +++ b/render-to-video/options/train_options.py @@ -0,0 +1,40 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # visdom and HTML visualization parameters + parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') + parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') + parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # training parameters + parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') + parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.0002, 
help='initial learning rate for adam') + parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + + self.isTrain = True + return parser diff --git a/render-to-video/test.py b/render-to-video/test.py new file mode 100644 index 0000000..bbd6f02 --- /dev/null +++ b/render-to-video/test.py @@ -0,0 +1,67 @@ +"""General-purpose test script for image-to-image translation. + +Once you have trained your model with train.py, you can use this script to test the model. +It will load a saved model from --checkpoints_dir and save the results to --results_dir. + +It first creates model and dataset given the option. It will hard-code some parameters. +It then runs inference for --num_test images and save results to an HTML file. + +Example (You need to train models first or download pre-trained models from our website): + Test a CycleGAN model (both sides): + python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan + + Test a CycleGAN model (one side only): + python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout + + The option '--model test' is used for generating CycleGAN results only for one side. + This option will automatically set '--dataset_mode single', which only loads the images from one set. + On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, + which is sometimes unnecessary. The results will be saved at ./results/. + Use '--results_dir ' to specify the results directory. + + Test a pix2pix model: + python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA + +See options/base_options.py and options/test_options.py for more test options. +See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md +See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md +""" +import os +from options.test_options import TestOptions +from data import create_dataset +from models import create_model +from util.visualizer import save_images +from util import html + + +if __name__ == '__main__': + opt = TestOptions().parse() # get test options + # hard-code some parameters for test + opt.num_threads = 0 # test code only supports num_threads = 1 + opt.batch_size = 1 # test code only supports batch_size = 1 + opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + opt.no_flip = True # no flip; comment this line if results on flipped images are needed. + opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
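Editor's note: the --eval flag defined in test_options.py above controls whether model.eval() is called before inference in this script. A minimal illustration (not part of this diff) of why that matters: in train() mode Dropout is stochastic and BatchNorm uses per-batch statistics, while in eval() mode Dropout is a no-op and BatchNorm uses its running estimates.

    import torch
    import torch.nn as nn

    layer = nn.Sequential(nn.BatchNorm1d(4), nn.Dropout(p=0.5))
    x = torch.randn(8, 4)

    layer.train()
    y_train = layer(x)  # stochastic: dropout masks active, batch statistics used

    layer.eval()
    y_eval = layer(x)   # deterministic: dropout disabled, running statistics used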
+ dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options + model = create_model(opt) # create a model given opt.model and other options + model.setup(opt) # regular setup: load and print networks; create schedulers + # create a website + web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory + #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder) + # test with eval mode. This only affects layers like batchnorm and dropout. + # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. + # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. + if opt.eval: + model.eval() + for i, data in enumerate(dataset): + if i >= opt.num_test: # only apply our model to opt.num_test images. + break + model.set_input(data) # unpack data from data loader + model.test() # run inference + visuals = model.get_current_visuals() # get image results + img_path = model.get_image_paths() # get image paths + if i % 5 == 0: # save images to an HTML file + print('processing (%04d)-th image... %s' % (i, img_path)) + save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) + webpage.save() # save the HTML diff --git a/render-to-video/test_memory.py b/render-to-video/test_memory.py new file mode 100644 index 0000000..edd5492 --- /dev/null +++ b/render-to-video/test_memory.py @@ -0,0 +1,128 @@ +"""General-purpose test script for image-to-image translation. + +Once you have trained your model with train.py, you can use this script to test the model. +It will load a saved model from --checkpoints_dir and save the results to --results_dir. + +It first creates model and dataset given the option. It will hard-code some parameters. +It then runs inference for --num_test images and save results to an HTML file. + +Example (You need to train models first or download pre-trained models from our website): + Test a CycleGAN model (both sides): + python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan + + Test a CycleGAN model (one side only): + python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout + + The option '--model test' is used for generating CycleGAN results only for one side. + This option will automatically set '--dataset_mode single', which only loads the images from one set. + On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, + which is sometimes unnecessary. The results will be saved at ./results/. + Use '--results_dir ' to specify the results directory. + + Test a pix2pix model: + python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA + +See options/base_options.py and options/test_options.py for more test options. 
+See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md +See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md +""" +#encoding:utf-8 +import os +from options.test_options import TestOptions +from data import create_dataset +from models import create_model +from util.visualizer import save_images +from util import html +import torch +import matplotlib.pyplot as plt +import numpy as np +from sklearn.preprocessing import PolynomialFeatures +from sklearn import linear_model +import pdb + + +if __name__ == '__main__': + opt = TestOptions().parse() # get test options + # hard-code some parameters for test + opt.num_threads = 0 # test code only supports num_threads = 1 + opt.batch_size = 1 # test code only supports batch_size = 1 + opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + opt.no_flip = True # no flip; comment this line if results on flipped images are needed. + opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. + + opt.netG = 'unetac_adain_256' + opt.model = 'test' + opt.Nw = 3 + opt.norm = 'batch' + opt.dataset_mode = 'single_multi' + + dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options + model = create_model(opt) # create a model given opt.model and other options + model.setup(opt) # regular setup: load and print networks; create schedulers + # create a website + web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory + #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) + webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder) + # test with eval mode. This only affects layers like batchnorm and dropout. + # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. + # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. + if opt.eval: + model.eval() + N = dataset.__len__() + features = torch.zeros((N,1,1,512)).cuda(opt.gpu_ids[0]) + control = 1 + for i, data in enumerate(dataset): + if i >= opt.num_test: # only apply our model to opt.num_test images. 
+ break + model.set_input(data) # unpack data from data loader + feature = model.forward_getfeat() + feature5 = model.forward_getfeatk(5) + features[i] = feature + if control == 4: + poly = PolynomialFeatures(degree=512) + fea=features.cpu().numpy() + for m in range(0,512): + x = np.arange(1, features.shape[0]+1, 1) + y = fea[:,0,0,m] + z1 = np.polyfit(x, y, 10) + p1 = np.poly1d(z1) + yvals=p1(x) + fea[:,0,0,m]=yvals + features=torch.Tensor(fea).cuda(opt.gpu_ids[0]) + + #np.save('features.npy',features.cpu().numpy()) + for i, data in enumerate(dataset): + model.set_input(data) + if control == 0: + feature = features[0] + elif control == 1: + # interpolation + M = 25 + if i % M == 0 or i == N-1: + feature = features[i] + else: + l = i // M * M + r = min(l + M, N-1) + feature = (i-l)/float(r-l) * (features[r]-features[l]) + features[l] + elif control == 2: + # average by 3 + if i == 0: + feature = features[i] + elif i == 1: + feature = torch.mean(features[i-1:i+1],dim=0) + else: + feature = torch.mean(features[i-2:i+1],dim=0) + elif control == 3: + # average by all + feature = torch.mean(features,dim=0) + elif control == 4: + # fit + feature = features[i] + model.forward_withfeat(feature) + visuals = model.get_current_visuals() # get image results + img_path = model.get_image_paths() # get image paths + if i % 5 == 0: # save images to an HTML file + print('processing (%04d)-th image... %s' % (i, img_path)) + save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) + webpage.save() # save the HTML + print('control',control) diff --git a/render-to-video/train.py b/render-to-video/train.py new file mode 100644 index 0000000..6a4f931 --- /dev/null +++ b/render-to-video/train.py @@ -0,0 +1,87 @@ +"""General-purpose training script for image-to-image translation. + +This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and +different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). +You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). + +It first creates model, dataset, and visualizer given the option. +It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models. +The script supports continue/resume training. Use '--continue_train' to resume your previous training. + +Example: + Train a CycleGAN model: + python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan + Train a pix2pix model: + python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA + +See options/base_options.py and options/train_options.py for more training options. +See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md +See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md +""" +import time +from options.train_options import TrainOptions +from data import create_dataset +from models import create_model +from util.visualizer import Visualizer +import sys +import pdb + +if __name__ == '__main__': + start = time.time() + opt = TrainOptions().parse() # get training options + dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options + dataset_size = len(dataset) # get the number of images in the dataset. 
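Editor's note: the control == 1 branch in test_memory.py above smooths the retrieved memory features over time by keeping them only at keyframes (every M = 25 frames) and linearly interpolating in between. A standalone restatement of that branch (hypothetical helper name, same arithmetic as the code above):

    import torch

    def interpolate_keyframe_features(features, i, M=25):
        # features: (N, 1, 1, 512) tensor of per-frame memory features
        N = features.shape[0]
        if i % M == 0 or i == N - 1:
            return features[i]
        l = i // M * M
        r = min(l + M, N - 1)
        t = (i - l) / float(r - l)
        return (1 - t) * features[l] + t * features[r]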
+ print('The number of training images = %d' % dataset_size) + model = create_model(opt) # create a model given opt.model and other options + model.setup(opt) # regular setup: load and print networks; create schedulers + visualizer = Visualizer(opt) # create a visualizer that display/save images and plots + total_iters = 0 # the total number of training iterations + + for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by , + + epoch_start_time = time.time() # timer for entire epoch + iter_data_time = time.time() # timer for data loading per iteration + epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch + + for i, data in enumerate(dataset): # inner loop within one epoch + iter_start_time = time.time() # timer for computation per iteration + if total_iters % opt.print_freq == 0: + t_data = iter_start_time - iter_data_time + visualizer.reset() + total_iters += opt.batch_size + epoch_iter += opt.batch_size + model.set_input(data) # unpack data from dataset and apply preprocessing + model.optimize_parameters() # calculate loss functions, get gradients, update network weights + + if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file + save_result = total_iters % opt.update_html_freq == 0 + model.compute_visuals() + visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) + #sys.exit(-1) + + if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk + losses = model.get_current_losses() + t_comp = (time.time() - iter_start_time) / opt.batch_size + visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) + if opt.display_id > 0: + visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) + if 'memory' in opt.model: + print('replace %d, update %d' % (model.replace, model.update)) + + if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations + print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) + save_suffix = '%d_iter_%d' % (epoch,total_iters) if opt.save_by_iter else 'latest' + model.save_networks(save_suffix) + + iter_data_time = time.time() + if epoch % opt.save_epoch_freq == 0: # cache our model every epochs + print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) + model.save_networks('latest') + model.save_networks(epoch) + + if 'memory' in opt.model: + print('End of epoch %d / %d \t Time Taken: %d sec, replace %d, update %d' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time, model.replace, model.update)) + else: + print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + model.update_learning_rate() # update learning rates at the end of every epoch. 
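Editor's note: train.py calls model.update_learning_rate() once per epoch; with the default --lr_policy linear defined in train_options.py, the scheduler is expected to hold the initial learning rate for --niter epochs and then decay it linearly to zero over --niter_decay epochs. The scheduler itself is not part of this diff; a sketch of one implementation consistent with those options (following the upstream pytorch-CycleGAN-and-pix2pix convention) is:

    import torch

    def make_linear_scheduler(optimizer, epoch_count, niter, niter_decay):
        # lr multiplier stays at 1.0 for the first `niter` epochs,
        # then falls linearly toward 0 over `niter_decay` epochs
        def lambda_rule(epoch):
            return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)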
+ + print('Total Time Taken: %d sec' % (time.time() - start)) \ No newline at end of file diff --git a/render-to-video/train_19news_1.py b/render-to-video/train_19news_1.py new file mode 100644 index 0000000..00c1dee --- /dev/null +++ b/render-to-video/train_19news_1.py @@ -0,0 +1,71 @@ +import os, sys + +def get_news(n): + trainN=300; testN=100 + video = '19_news/'+str(n);name = str(n)+'_bmold_win3';start = 0; + print(video,name) + + rootdir = os.path.join(os.getcwd(),'../Deep3DFaceReconstruction/output/render/') + srcdir = os.path.join(rootdir,video) + srcdir2 = srcdir.replace(video,video+'/bm') + + if 'bmold' not in name: + cmd = "cd "+rootdir+"/..; matlab -nojvm -nosplash -nodesktop -nodisplay -r \"alpha_blend_news('" + video + "'," + str(start) + "," + str(trainN+testN) + "); quit;\"" + else: + cmd = "cd "+rootdir+"/..; matlab -nojvm -nosplash -nodesktop -nodisplay -r \"alpha_blend_newsold('" + video + "'," + str(start) + "," + str(trainN+testN) + "); quit;\"" + os.system(cmd) + f1 = open('datasets/list/trainA/%s.txt'%name,'w') + f2 = open('datasets/list/trainB/%s.txt'%name,'w') + if 'win3' in name: + start1 = start + 2 + else: + start1 = start + for i in range(start1,start+trainN): + if 'bmold' not in name: + print(os.path.join(srcdir2,'frame%d_render_bm.png'%i),file=f1) + else: + print(os.path.join(srcdir2,'frame%d_renderold_bm.png'%i),file=f1) + print(os.path.join(srcdir,'frame%d.png'%i),file=f2) + f1.close() + f2.close() + f1 = open('datasets/list/testA/%s.txt'%name,'w') + f2 = open('datasets/list/testB/%s.txt'%name,'w') + for i in range(start+trainN,start+trainN+testN): + if 'bmold' not in name: + print(os.path.join(srcdir2,'frame%d_render_bm.png'%i),file=f1) + else: + print(os.path.join(srcdir2,'frame%d_renderold_bm.png'%i),file=f1) + print(os.path.join(srcdir,'frame%d.png'%i),file=f2) + f1.close() + f2.close() + +def save_each_60(folder): + pths = sorted(glob.glob(folder+'/*.pth')) + for pth in pths: + epoch = os.path.basename(pth).split('_')[0] + if epoch == '60': + continue + os.remove(pth) + +n = int(sys.argv[1]) +gpu_id = int(sys.argv[2]) + +# prepare training data, and write two txt as training list +get_news(n) + +# prepare arcface feature +cmd = 'cd arcface/; python test_batch.py --imglist trainB/%d_bmold_win3.txt --gpu %d' % (n,gpu_id) +os.system(cmd) +cmd = 'cd arcface/; python test_batch.py --imglist testB/%d_bmold_win3.txt --gpu %d' % (n,gpu_id) +os.system(cmd) + + +# fine tune the mapping +n = str(n) +cmd = 'python train.py --dataroot %s_bmold_win3 --name memory_seq_p2p/%s --model memory_seq --continue_train --epoch 0 --epoch_count 1 --lambda_mask 2 --lr 0.0001 --display_env memory_seq_%s --gpu_ids %d --niter 60 --niter_decay 0' % (n,n,n,gpu_id) +os.system(cmd) +save_each_60('checkpoints/memory_seq_p2p/%s'%n) + +epoch = 60 +cmd = 'python test.py --dataroot %s_bmold_win3 --name memory_seq_p2p/%s --model memory_seq --num_test 200 --epoch %d --gpu_ids %d --imagefolder images%d' % (n,n,epoch,gpu_id,epoch) +os.system(cmd) \ No newline at end of file diff --git a/render-to-video/util/__init__.py b/render-to-video/util/__init__.py new file mode 100644 index 0000000..ae36f63 --- /dev/null +++ b/render-to-video/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/render-to-video/util/get_data.py b/render-to-video/util/get_data.py new file mode 100644 index 0000000..97edc3c --- /dev/null +++ b/render-to-video/util/get_data.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import os +import 
tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """A Python script for downloading CycleGAN or pix2pix datasets. + + Parameters: + technique (str) -- One of: 'cyclegan' or 'pix2pix'. + verbose (bool) -- If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' + and 'scripts/download_cyclegan_model.sh'. + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Parameters: + save_path (str) -- A directory to save the data to. + dataset (str) -- (optional). A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full (str) -- the absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/render-to-video/util/html.py b/render-to-video/util/html.py new file mode 100644 index 0000000..26c18f1 --- /dev/null +++ b/render-to-video/util/html.py @@ -0,0 +1,94 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). 
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0, folder='images'): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + #img(style="width:%dpx" % width, src=os.path.join('images', im)) + img(style="width:%dpx" % width, src=os.path.join(self.folder, im)) + + br() + p(txt) + + def save(self): + """save the current content to the HMTL file""" + #html_file = '%s/index.html' % self.web_dir + name = self.folder[6:] if self.folder[:6] == 'images' else self.folder + html_file = '%s/index%s.html' % (self.web_dir, name) + if len(name.split('/')) > 1: + html_file = '%s/%s/index%s.html' % (self.web_dir,os.path.dirname(name),os.path.basename(name)[6:]) + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/render-to-video/util/image_pool.py b/render-to-video/util/image_pool.py new file mode 100644 index 0000000..6d086f8 --- /dev/null +++ b/render-to-video/util/image_pool.py @@ -0,0 +1,54 @@ +import random +import torch + + +class ImagePool(): + """This class implements an image buffer that stores previously generated images. + + This buffer enables us to update discriminators using a history of generated images + rather than the ones produced by the latest generators. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_imgs = 0 + self.images = [] + + def query(self, images): + """Return an image from the pool. + + Parameters: + images: the latest generated images from the generator + + Returns images from the buffer. + + By 50/100, the buffer will return input images. + By 50/100, the buffer will return images previously stored in the buffer, + and insert the current images to the buffer. 
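Editor's note: the 50%/50% behaviour described in the ImagePool docstring below is what lets the discriminator train on a history of generated images rather than only the newest fakes. A typical usage sketch (not shown in this diff; the discriminator and GAN criterion are placeholders):

    import torch

    pool = ImagePool(pool_size=50)         # class defined in this file
    for step in range(3):
        fake = torch.randn(2, 3, 8, 8)     # stand-in for generator output
        fake_for_d = pool.query(fake)      # 50% chance of returning older fakes
        # loss_D_fake = criterionGAN(netD(fake_for_d.detach()), False)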
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images diff --git a/render-to-video/util/util.py b/render-to-video/util/util.py new file mode 100644 index 0000000..c368189 --- /dev/null +++ b/render-to-video/util/util.py @@ -0,0 +1,96 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def 
mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/render-to-video/util/visualizer.py b/render-to-video/util/visualizer.py new file mode 100644 index 0000000..f8f4111 --- /dev/null +++ b/render-to-video/util/visualizer.py @@ -0,0 +1,229 @@ +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from scipy.misc import imresize + +if sys.version_info[0] == 2: + VisdomExceptionBase = Exception +else: + VisdomExceptionBase = ConnectionError + + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + short_path1 = ntpath.basename(ntpath.dirname(image_path[0])) + short_path = short_path1 + '-' + short_path + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + h, w, _ = im.shape + if aspect_ratio > 1.0: + im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') + if aspect_ratio < 1.0: + im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') + util.save_image(im, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: connect to a visdom server + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.port = opt.display_port + self.saved = False + if self.display_id > 0: # connect to a visdom server given and + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) + if not self.vis.check_connection(): + self.create_visdom_connections() + + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' 
% self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + def create_visdom_connections(self): + """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ + cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port + print('\n\nCould not connect to Visdom server. \n Trying to start a server....') + print('Command: %s' % cmd) + Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) + + def display_current_results(self, visuals, epoch, save_result): + """Display current results on visdom; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + if self.display_id > 0: # show images in the browser using visdom + ncols = self.ncols + if ncols > 0: # show all the images in one visdom panel + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) # create a table css + # create a table of images. + title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + try: + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except VisdomExceptionBase: + self.create_visdom_connections() + + else: # show each image in a separate visdom panel; + idx = 1 + try: + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + except VisdomExceptionBase: + self.create_visdom_connections() + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, epoch, counter_ratio, losses): + """display the current losses on visdom display: dictionary of error labels and values + + Parameters: + epoch (int) -- current epoch + counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + """ + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except VisdomExceptionBase: + self.create_visdom_connections() + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..84730af --- /dev/null +++ b/requirements.txt @@ -0,0 +1,83 @@ +absl-py==0.7.1 +astor==0.8.0 +audioread==2.1.8 +backcall==0.1.0 +certifi==2019.9.11 +cffi==1.12.3 +chardet==3.0.4 +colorthief==0.2.1 +cycler==0.10.0 +Cython==0.29.10 
+decorator==4.4.0 +dlib==19.18.0 +dominate==2.4.0 +easydict==1.9 +gast==0.2.2 +google-pasta==0.1.7 +graphviz==0.8.4 +grpcio==1.21.1 +h5py==2.9.0 +idna==2.8 +imageio==2.5.0 +ipython==7.5.0 +ipython-genutils==0.2.0 +jedi==0.14.0 +joblib==0.13.2 +jsonpatch==1.24 +jsonpointer==2.0 +Keras-Applications==1.0.8 +Keras-Preprocessing==1.1.0 +kiwisolver==1.1.0 +librosa==0.7.0 +llvmlite==0.29.0 +Markdown==3.1.1 +matplotlib==3.1.0 +mxnet==1.5.1.post0 +mxnet-cu80==1.5.0 +networkx==2.3 +numba==0.44.1 +numpy==1.16.4 +olefile==0.46 +opencv-contrib-python==4.1.1.26 +opencv-python==4.1.0.25 +pandas==0.25.1 +parso==0.5.0 +pexpect==4.7.0 +pickleshare==0.7.5 +Pillow==4.3.0 +POT==0.4.0 +prompt-toolkit==2.0.9 +protobuf==3.8.0 +ptyprocess==0.6.0 +pycparser==2.19 +Pygments==2.4.2 +pyparsing==2.4.0 +pyssim==0.4 +python-dateutil==2.8.0 +python-speech-features==0.6 +pytz==2019.2 +PyWavelets==1.0.3 +pyzmq==18.1.0 +requests==2.22.0 +resampy==0.2.1 +scikit-image==0.15.0 +scikit-learn==0.21.2 +scipy==1.0.0 +six==1.12.0 +SoundFile==0.10.2 +tensorboard==1.14.0 +tensorflow-estimator==1.14.0rc1 +tensorflow-gpu==1.14.0 +termcolor==1.1.0 +torch==1.1.0 +torchfile==0.1.0 +torchvision==0.4.0 +tornado==6.0.3 +tqdm==4.19.6 +traitlets==4.3.2 +urllib3==1.25.3 +visdom==0.1.8.9 +wcwidth==0.1.7 +websocket-client==0.56.0 +Werkzeug==0.15.4 +wrapt==1.11.1