Skip to content

Commit

Permalink
Modified dataset processing script for retrain EG3D
Browse files Browse the repository at this point in the history
  • Loading branch information
YanniZhangYZ committed Jan 16, 2024
1 parent f46ae40 commit 0377756
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 117 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ archive
*.ply
eval
out
dataset_preprocessing/ffhq/in-the-wild-images/
dataset_preprocessing/ffhq/final_crops/
dataset_preprocessing/ffhq/realign1500/
dataset_preprocessing/ffhq/FFHQ_512.zip
dataset_preprocessing/ffhq/LICENSE.txt
dataset_preprocessing/ffhq/ffhq-dataset-v2.json
client_secrets.json

# evaluation:
temp/
Expand Down
4 changes: 2 additions & 2 deletions dataset_preprocessing/ffhq/align_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def process_image(kwargs):#item_idx, item, dst_dir="realign1500", output_size=15
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
img = img.resize(rsize, PIL.Image.LANCZOS)
quad /= shrink
qsize /= shrink
# print("shrink--- %s seconds ---" % (time.time() - start_time))
Expand Down Expand Up @@ -147,7 +147,7 @@ def process_image(kwargs):#item_idx, item, dst_dir="realign1500", output_size=15
start_time = time.time()
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
# print("transform--- %s seconds ---" % (time.time() - start_time))

# Save aligned image.
Expand Down
264 changes: 150 additions & 114 deletions dataset_preprocessing/ffhq/download_ffhq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Download Flickr-Faces-HQ (FFHQ) dataset to current working directory."""

import os
import re
import sys
import requests
import html
Expand All @@ -27,7 +28,9 @@
import itertools
import shutil
from collections import OrderedDict, defaultdict
import cv2
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error

#----------------------------------------------------------------------------
Expand Down Expand Up @@ -130,6 +133,50 @@ def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10):
except:
pass

def pydrive_create_drive_manager(cmd_auth):
    """Authenticate with Google and return a Drive client.

    Args:
        cmd_auth: When truthy, authenticate through ``CommandLineAuth()``;
            otherwise through ``LocalWebserverAuth()``.

    Returns:
        A ``GoogleDrive`` object built from the authorized ``GoogleAuth``.
    """
    auth = GoogleAuth()

    # Select the authentication flow first, then run it.
    run_auth_flow = auth.CommandLineAuth if cmd_auth else auth.LocalWebserverAuth
    run_auth_flow()

    auth.Authorize()
    print("authorized access to google drive API!")

    return GoogleDrive(auth)


def pydrive_extract_files_id(drive, link):
    """Extract the Google Drive file id embedded in a share/download link.

    The id is the text that follows the first ``/d/``, ``id=`` or ``rs/``
    (as in ``folders/``) marker in the link. It ends at the next path,
    query or fragment delimiter, so trailing parts such as ``/view``,
    ``&export=download`` or ``?usp=sharing`` are not included.

    Args:
        drive: Unused; kept so existing callers keep working.
        link: URL string to parse.

    Returns:
        The file id as a string, or ``None`` (after printing a diagnostic)
        when the link is not a string or contains no recognizable id.
    """
    try:
        # All lookbehind alternatives are three characters long, which keeps
        # the lookbehind fixed-width as the `re` module requires.
        match = re.search(r"(?<=/d/|id=|rs/)[^/?&#\s]+", link)
    except TypeError as error:
        # `link` was not a string (e.g. None); report it like any bad link.
        match = None
        reason = str(error)
    else:
        reason = "no file id found in link"

    if match is None:
        print("error : " + reason)
        print("Link is probably invalid")
        print(link)
        return None

    return match[0]


def pydrive_download_file(drive, spec, stats, chunk_size=128, num_attempts=10):
    """Download one file through the pydrive interface and update `stats`.

    Args:
        drive: Drive client; must provide ``CreateFile({'id': ...})`` whose
            result provides ``GetContentFile(path)``.
        spec: Mapping with ``'file_url'`` (Drive link) and ``'file_path'``
            (local destination).
        stats: Mapping whose ``'files_done'`` and ``'bytes_done'`` counters
            are incremented after a successful download.
        chunk_size: Unused here; accepted so the same keyword arguments can
            be forwarded to either download function.
        num_attempts: Maximum number of download attempts.

    Raises:
        ValueError: If no file id can be extracted from ``spec['file_url']``.
        Exception: Whatever the last failed download attempt raised.
    """
    link = spec['file_url']
    save_path = spec['file_path']

    file_id = pydrive_extract_files_id(drive, link)
    if not file_id:
        # Fail loudly instead of passing a None id on to the Drive API.
        raise ValueError('Cannot extract Google Drive file id from link: %s' % link)

    file_dir = os.path.dirname(save_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    pydrive_file = drive.CreateFile({'id': file_id})

    # Download to a temporary name and rename on success, so an interrupted
    # attempt never leaves a partial file at `save_path` (download_files
    # treats any existing file at that path as already downloaded).
    tmp_path = save_path + '.tmp'
    try:
        for attempts_left in reversed(range(num_attempts)):
            try:
                pydrive_file.GetContentFile(tmp_path)
                break
            except Exception:
                # Retry on ordinary errors only; KeyboardInterrupt and
                # SystemExit propagate immediately.
                if not attempts_left:
                    raise
        os.replace(tmp_path, save_path)
    finally:
        # Remove leftovers from a failed attempt (no-op after the rename).
        if os.path.isfile(tmp_path):
            os.remove(tmp_path)

    # NOTE(review): these counters are updated without a lock, as in the
    # original; confirm whether `stats` is shared across download threads.
    stats['files_done'] += 1
    stats['bytes_done'] += os.stat(save_path).st_size

#----------------------------------------------------------------------------

def choose_bytes_unit(num_bytes):
Expand All @@ -152,7 +199,7 @@ def format_time(seconds):

#----------------------------------------------------------------------------

def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):
def download_files(file_specs, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):

# Determine which files to download.
done_specs = {spec['file_path']: spec for spec in file_specs if os.path.isfile(spec['file_path'])}
Expand All @@ -169,7 +216,7 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
exception_queue = queue.Queue()
for spec in missing_specs:
spec_queue.put(spec)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, download_kwargs=download_kwargs)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, drive=drive, download_kwargs=download_kwargs)
for _thread_idx in range(min(num_threads, len(missing_specs))):
threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start()

Expand Down Expand Up @@ -206,12 +253,15 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
except queue.Empty:
pass

def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
def _download_thread(spec_queue, exception_queue, stats, drive, download_kwargs):
with requests.Session() as session:
while not spec_queue.empty():
spec = spec_queue.get()
try:
download_file(session, spec, stats, **download_kwargs)
if drive is not None:
pydrive_download_file(drive, spec, stats, **download_kwargs)
else:
download_file(session, spec, stats, **download_kwargs)
except:
exception_queue.put(sys.exc_info())

Expand Down Expand Up @@ -254,126 +304,111 @@ def print_statistics(json_data):
for row in rows:
print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))


#----------------------------------------------------------------------------

def recreate_aligned_images_fast(json_data, dst_dir='realign1024x1024', output_size=1024, transform_size=4096, enable_padding=True, start_index=0):
def recreate_aligned_images(json_data, dst_dir='realign1024x1024', output_size=1024, transform_size=4096, enable_padding=True):
print('Recreating aligned images...')
if dst_dir:
os.makedirs(dst_dir, exist_ok=True)
shutil.copyfile('LICENSE.txt', os.path.join(dst_dir, 'LICENSE.txt'))
print(len(json_data))

for item_idx, item in enumerate(json_data.values()):
if item_idx >= start_index and item_idx <= start_index+5000:
print('\r%d / %d ... ' % (item_idx, len(json_data)), end='', flush=True)

# Parse landmarks.
# pylint: disable=unused-variable
lm = np.array(item['in_the_wild']['face_landmarks'])
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise

# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
q_scale = 1.8
x = q_scale * x
y = q_scale * y
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2

# Load in-the-wild image.
src_file = item['in_the_wild']['file_path']
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
return
img = PIL.Image.open(src_file)

import time

# Shrink.
start_time = time.time()
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
print("shrink--- %s seconds ---" % (time.time() - start_time))

# Crop.
start_time = time.time()
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
print("crop--- %s seconds ---" % (time.time() - start_time))

# Pad.
start_time = time.time()
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
low_res = cv2.resize(img, (0,0), fx=0.1, fy=0.1, interpolation = cv2.INTER_AREA)
blur = qsize * 0.02*0.1
low_res = scipy.ndimage.gaussian_filter(low_res, [blur, blur, 0])
low_res = cv2.resize(low_res, (img.shape[1], img.shape[0]), interpolation = cv2.INTER_LANCZOS4)
img += (low_res - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
median = cv2.resize(img, (0,0), fx=0.1, fy=0.1, interpolation = cv2.INTER_AREA)
median = np.median(median, axis=(0,1))
img += (median - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
print("pad--- %s seconds ---" % (time.time() - start_time))

# Transform.
start_time = time.time()
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
print("transform--- %s seconds ---" % (time.time() - start_time))

# Save aligned image.
dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
os.makedirs(dst_subdir, exist_ok=True)
img.save(os.path.join(dst_subdir, '%05d.png' % item_idx))
print('\r%d / %d ... ' % (item_idx, len(json_data)), end='', flush=True)

# Parse landmarks.
# pylint: disable=unused-variable
lm = np.array(item['in_the_wild']['face_landmarks'])
lm_chin = lm[0 : 17] # left-right
lm_eyebrow_left = lm[17 : 22] # left-right
lm_eyebrow_right = lm[22 : 27] # left-right
lm_nose = lm[27 : 31] # top-down
lm_nostrils = lm[31 : 36] # top-down
lm_eye_left = lm[36 : 42] # left-clockwise
lm_eye_right = lm[42 : 48] # left-clockwise
lm_mouth_outer = lm[48 : 60] # left-clockwise
lm_mouth_inner = lm[60 : 68] # left-clockwise

# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2

# Load in-the-wild image.
src_file = item['in_the_wild']['file_path']
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
return
img = PIL.Image.open(src_file)

# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, PIL.Image.ANTIALIAS)
quad /= shrink
qsize /= shrink

# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]

# Pad.
pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]

# Transform.
img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
if output_size < transform_size:
img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

# Save aligned image.
dst_subdir = os.path.join(dst_dir, '%05d' % (item_idx - item_idx % 1000))
os.makedirs(dst_subdir, exist_ok=True)
img.save(os.path.join(dst_subdir, '%05d.png' % item_idx))

# All done.
print('\r%d / %d ... done' % (len(json_data), len(json_data)))

#----------------------------------------------------------------------------

def run(tasks, start_index, **download_kwargs):
def run(tasks, pydrive, cmd_auth, **download_kwargs):
if pydrive:
drive = pydrive_create_drive_manager(cmd_auth)
else:
drive = None

if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
print('Downloading JSON metadata...')
download_files([json_spec, license_specs['json']], **download_kwargs)
download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)

print('Parsing JSON metadata...')
with open(json_spec['file_path'], 'rb') as f:
Expand All @@ -395,10 +430,10 @@ def run(tasks, start_index, **download_kwargs):
if len(specs):
print('Downloading %d files...' % len(specs))
np.random.shuffle(specs) # to make the workload more homogeneous
download_files(specs, **download_kwargs)
download_files(specs, drive=drive, **download_kwargs)

if 'align' in tasks:
recreate_aligned_images_fast(json_data, dst_dir="realign1500", output_size=1500, start_index=start_index)
recreate_aligned_images(json_data)

#----------------------------------------------------------------------------

Expand All @@ -410,13 +445,14 @@ def run_cmdline(argv):
parser.add_argument('-t', '--thumbs', help='download 128x128 thumbnails as PNG (1.95 GB)', dest='tasks', action='append_const', const='thumbs')
parser.add_argument('-w', '--wilds', help='download in-the-wild images as PNG (955 GB)', dest='tasks', action='append_const', const='wilds')
parser.add_argument('-r', '--tfrecords', help='download multi-resolution TFRecords (273 GB)', dest='tasks', action='append_const', const='tfrecords')
parser.add_argument('--pydrive', help='use pydrive interface to download files. it overrides google drive quota limitation this requires google credentials (default: False)', action='store_true')
parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
parser.add_argument('-a', '--align', help='recreate 1024x1024 images from in-the-wild images', dest='tasks', action='append_const', const='align')
parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')
parser.add_argument('--timing_window', help='samples for estimating download eta (default: 50)', type=int, default=50, metavar='LEN')
parser.add_argument('--chunk_size', help='chunk size for each download thread (default: 128)', type=int, default=128, metavar='KB')
parser.add_argument('--num_attempts', help='number of download attempts per file (default: 10)', type=int, default=10, metavar='NUM')
parser.add_argument('--start_index', help='start index for alignment (default: 0)', type=int, default=0, metavar='NUM')

args = parser.parse_args()
if not args.tasks:
Expand All @@ -429,4 +465,4 @@ def run_cmdline(argv):
if __name__ == "__main__":
run_cmdline(sys.argv)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------
Loading

0 comments on commit 0377756

Please sign in to comment.