From 6174324744595771c008c8398e41b91676581673 Mon Sep 17 00:00:00 2001 From: zy5015 Date: Fri, 12 May 2017 15:26:29 +0100 Subject: [PATCH] update download.py according to https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py --- download.py | 58 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/download.py b/download.py index 00a00aa..078afec 100755 --- a/download.py +++ b/download.py @@ -49,6 +49,33 @@ def download(url, dirpath): f.close() return filepath +def download_file_from_google_drive(id, destination): + URL = "https://docs.google.com/uc?export=download" + session = requests.Session() + + response = session.get(URL, params={ 'id': id }, stream=True) + token = get_confirm_token(response) + + if token: + params = { 'id' : id, 'confirm' : token } + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + +def save_response_content(response, destination, chunk_size=32*1024): + total_size = int(response.headers.get('content-length', 0)) + with open(destination, "wb") as f: + for chunk in tqdm(response.iter_content(chunk_size), total=total_size, + unit='B', unit_scale=True, desc=destination): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + def unzip(filepath): print("Extracting: " + filepath) dirpath = os.path.dirname(filepath) @@ -57,18 +84,25 @@ def unzip(filepath): os.remove(filepath) def download_celeb_a(dirpath): - data_dir = 'celebA' - if os.path.exists(os.path.join(dirpath, data_dir)): - print('Found Celeb-A - skip') - return - url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1' - filepath = download(url, dirpath) - zip_dir = '' - with zipfile.ZipFile(filepath) as zf: - zip_dir = zf.namelist()[0] - zf.extractall(dirpath) - os.remove(filepath) - os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) + data_dir = 'celebA' + if os.path.exists(os.path.join(dirpath, data_dir)): + print('Found Celeb-A - skip') + return + + filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" + save_path = os.path.join(dirpath, filename) + + if os.path.exists(save_path): + print('[*] {} already exists'.format(save_path)) + else: + download_file_from_google_drive(drive_id, save_path) + + zip_dir = '' + with zipfile.ZipFile(save_path) as zf: + zip_dir = zf.namelist()[0] + zf.extractall(dirpath) + os.remove(save_path) + os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) def _list_categories(tag): url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag