Skip to content

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: NVlabs/eg3d
Failed to load repositories. Confirm that selected base ref is valid, then try again.
base: main
Choose a base ref
head repository: Logitech/eg3d
Failed to load repositories. Confirm that selected head ref is valid, then try again.
compare: main
Choose a head ref
Able to merge. These branches can be automatically merged.
  • 9 commits
  • 8 files changed
  • 1 contributor

Commits on Sep 28, 2023

  1. Make the triplane available when call synthesis(). Useful for building loss function of live3D portrait.
    YanniZhangYZ committed Sep 28, 2023
    Copy the full SHA
    a33fe63 View commit details

Commits on Oct 2, 2023

  1. Copy the full SHA
    bbbab7d View commit details

Commits on Oct 3, 2023

  1. Copy the full SHA
    fd33438 View commit details

Commits on Oct 11, 2023

  1. Copy the full SHA
    55040d2 View commit details
  2. Add camera parameters sampling needed for live3d: fov, principal point, camera radius, camera roll.
    YanniZhangYZ committed Oct 11, 2023
    Copy the full SHA
    c811e35 View commit details
  3. Change the return value of synthesis(): now it additionally returns the feature image and triplane.
    YanniZhangYZ committed Oct 11, 2023
    Copy the full SHA
    b406383 View commit details

Commits on Oct 12, 2023

  1. Copy the full SHA
    50dd9c6 View commit details

Commits on Oct 18, 2023

  1. add more comments

    YanniZhangYZ committed Oct 18, 2023
    Copy the full SHA
    f46ae40 View commit details

Commits on Jan 16, 2024

  1. Copy the full SHA
    0377756 View commit details
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -14,6 +14,13 @@ archive

# evaluation:
4 changes: 2 additions & 2 deletions dataset_preprocessing/ffhq/
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ def process_image(kwargs):#item_idx, item, dst_dir="realign1500", output_size=15
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
img = img.resize(rsize, PIL.Image.LANCZOS)
quad /= shrink
qsize /= shrink
# print("shrink--- %s seconds ---" % (time.time() - start_time))
@@ -147,7 +147,7 @@ def process_image(kwargs):#item_idx, item, dst_dir="realign1500", output_size=15
start_time = time.time()
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
# print("transform--- %s seconds ---" % (time.time() - start_time))

# Save aligned image.
264 changes: 150 additions & 114 deletions dataset_preprocessing/ffhq/
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
"""Download Flickr-Faces-HQ (FFHQ) dataset to current working directory."""

import os
import re
import sys
import requests
import html
@@ -27,7 +28,9 @@
import itertools
import shutil
from collections import OrderedDict, defaultdict
import cv2
from pydrive2.auth import GoogleAuth
from import GoogleDrive

PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error

@@ -130,6 +133,50 @@ def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10):

def pydrive_create_drive_manager(cmd_auth):
gAuth = GoogleAuth()

if cmd_auth:

print("authorized access to google drive API!")

drive: GoogleDrive = GoogleDrive(gAuth)
return drive

def pydrive_extract_files_id(drive, link):
fileID ="(?<=/d/|id=|rs/).+?(?=/|$)", link)[0] # extract the fileID
return fileID
except Exception as error:
print("error : " + str(error))
print("Link is probably invalid")

def pydrive_download_file(drive, spec, stats, chunk_size=128, num_attempts=10):
link = spec['file_url']
save_path = spec['file_path']
id = pydrive_extract_files_id(drive, link)
file_dir = os.path.dirname(save_path)
if file_dir:
os.makedirs(file_dir, exist_ok=True)

pydrive_file = drive.CreateFile({'id': id})
for attempts_left in reversed(range(num_attempts)):
if not attempts_left:
stats['files_done'] += 1
stats['bytes_done'] += os.stat(save_path).st_size


def choose_bytes_unit(num_bytes):
@@ -152,7 +199,7 @@ def format_time(seconds):


def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):
def download_files(file_specs, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):

# Determine which files to download.
done_specs = {spec['file_path']: spec for spec in file_specs if os.path.isfile(spec['file_path'])}
@@ -169,7 +216,7 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
exception_queue = queue.Queue()
for spec in missing_specs:
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, download_kwargs=download_kwargs)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, drive=drive, download_kwargs=download_kwargs)
for _thread_idx in range(min(num_threads, len(missing_specs))):
threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start()

@@ -206,12 +253,15 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
except queue.Empty:

def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
def _download_thread(spec_queue, exception_queue, stats, drive, download_kwargs):
with requests.Session() as session:
while not spec_queue.empty():
spec = spec_queue.get()
download_file(session, spec, stats, **download_kwargs)
if drive is not None:
pydrive_download_file(drive, spec, stats, **download_kwargs)
download_file(session, spec, stats, **download_kwargs)

@@ -254,126 +304,111 @@ def print_statistics(json_data):
for row in rows:
print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))


def recreate_aligned_images_fast(json_data, dst_dir='realign1024x1024', output_size=1024, transform_size=4096, enable_padding=True, start_index=0):
def recreate_aligned_images(json_data, dst_dir='realign1024x1024', output_size=1024, transform_size=4096, enable_padding=True):
print('Recreating aligned images...')
if dst_dir:
os.makedirs(dst_dir, exist_ok=True)
shutil.copyfile('LICENSE.txt', os.path.join(dst_dir, 'LICENSE.txt'))

for item_idx, item in enumerate(json_data.values()):
if item_idx >= start_index and item_idx <= start_index+5000:
print('\r%d / %d ... ' % (item_idx, len(json_data)), end='', flush=True)

# Parse landmarks.
# pylint: disable=unused-variable
lm = np.array(item['in_the_wild']['face_landmarks'])
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise

# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
q_scale = 1.8
x = q_scale * x
y = q_scale * y
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2

# Load in-the-wild image.
src_file = item['in_the_wild']['file_path']
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
img =

import time

# Shrink.
start_time = time.time()
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
print("shrink--- %s seconds ---" % (time.time() - start_time))

# Crop.
start_time = time.time()
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
print("crop--- %s seconds ---" % (time.time() - start_time))

# Pad.
start_time = time.time()
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
low_res = cv2.resize(img, (0,0), fx=0.1, fy=0.1, interpolation = cv2.INTER_AREA)
blur = qsize * 0.02*0.1
low_res = scipy.ndimage.gaussian_filter(low_res, [blur, blur, 0])
low_res = cv2.resize(low_res, (img.shape[1], img.shape[0]), interpolation = cv2.INTER_LANCZOS4)
img += (low_res - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
median = cv2.resize(img, (0,0), fx=0.1, fy=0.1, interpolation = cv2.INTER_AREA)
median = np.median(median, axis=(0,1))
img += (median - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
print("pad--- %s seconds ---" % (time.time() - start_time))

# Transform.
start_time = time.time()
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
print("transform--- %s seconds ---" % (time.time() - start_time))

# Save aligned image.
dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
os.makedirs(dst_subdir, exist_ok=True), '%05d.png' % item_idx))
print('\r%d / %d ... ' % (item_idx, len(json_data)), end='', flush=True)

# Parse landmarks.
# pylint: disable=unused-variable
lm = np.array(item['in_the_wild']['face_landmarks'])
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise

# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2

# Load in-the-wild image.
src_file = item['in_the_wild']['file_path']
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
img =

# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink

# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]

# Pad.
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]

# Transform.
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

# Save aligned image.
dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
os.makedirs(dst_subdir, exist_ok=True), '%05d.png' % item_idx))

# All done.
print('\r%d / %d ... done' % (len(json_data), len(json_data)))


def run(tasks, start_index, **download_kwargs):
def run(tasks, pydrive, cmd_auth, **download_kwargs):
if pydrive:
drive = pydrive_create_drive_manager(cmd_auth)
drive = None

if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
print('Downloading JSON metadata...')
download_files([json_spec, license_specs['json']], **download_kwargs)
download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)

print('Parsing JSON metadata...')
with open(json_spec['file_path'], 'rb') as f:
@@ -395,10 +430,10 @@ def run(tasks, start_index, **download_kwargs):
if len(specs):
print('Downloading %d files...' % len(specs))
np.random.shuffle(specs) # to make the workload more homogeneous
download_files(specs, **download_kwargs)
download_files(specs, drive=drive, **download_kwargs)

if 'align' in tasks:
recreate_aligned_images_fast(json_data, dst_dir="realign1500", output_size=1500, start_index=start_index)


@@ -410,13 +445,14 @@ def run_cmdline(argv):
parser.add_argument('-t', '--thumbs', help='download 128x128 thumbnails as PNG (1.95 GB)', dest='tasks', action='append_const', const='thumbs')
parser.add_argument('-w', '--wilds', help='download in-the-wild images as PNG (955 GB)', dest='tasks', action='append_const', const='wilds')
parser.add_argument('-r', '--tfrecords', help='download multi-resolution TFRecords (273 GB)', dest='tasks', action='append_const', const='tfrecords')
parser.add_argument('--pydrive', help='use pydrive interface to download files. it overrides google drive quota limitation this requires google credentials (default: False)', action='store_true')
parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
parser.add_argument('-a', '--align', help='recreate 1024x1024 images from in-the-wild images', dest='tasks', action='append_const', const='align')
parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')
parser.add_argument('--timing_window', help='samples for estimating download eta (default: 50)', type=int, default=50, metavar='LEN')
parser.add_argument('--chunk_size', help='chunk size for each download thread (default: 128)', type=int, default=128, metavar='KB')
parser.add_argument('--num_attempts', help='number of download attempts per file (default: 10)', type=int, default=10, metavar='NUM')
parser.add_argument('--start_index', help='start index for alignment (default: 0)', type=int, default=0, metavar='NUM')

args = parser.parse_args()
if not args.tasks:
@@ -429,4 +465,4 @@ def run_cmdline(argv):
if __name__ == "__main__":
