Skip to content

Commit

Permalink
Merge pull request #6 from yzwxx/master
Browse files Browse the repository at this point in the history
fix download.py for celebA
  • Loading branch information
zsdonghao authored May 12, 2017
2 parents d2f662b + 6174324 commit c910e83
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def download(url, dirpath):
f.close()
return filepath

def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()

response = session.get(URL, params={ 'id': id }, stream=True)
token = get_confirm_token(response)

if token:
params = { 'id' : id, 'confirm' : token }
response = session.get(URL, params=params, stream=True)

save_response_content(response, destination)

def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None

def save_response_content(response, destination, chunk_size=32*1024):
total_size = int(response.headers.get('content-length', 0))
with open(destination, "wb") as f:
for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
unit='B', unit_scale=True, desc=destination):
if chunk: # filter out keep-alive new chunks
f.write(chunk)

def unzip(filepath):
print("Extracting: " + filepath)
dirpath = os.path.dirname(filepath)
Expand All @@ -57,18 +84,25 @@ def unzip(filepath):
os.remove(filepath)

def download_celeb_a(dirpath):
data_dir = 'celebA'
if os.path.exists(os.path.join(dirpath, data_dir)):
print('Found Celeb-A - skip')
return
url = 'https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1&pv=1'
filepath = download(url, dirpath)
zip_dir = ''
with zipfile.ZipFile(filepath) as zf:
zip_dir = zf.namelist()[0]
zf.extractall(dirpath)
os.remove(filepath)
os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
data_dir = 'celebA'
if os.path.exists(os.path.join(dirpath, data_dir)):
print('Found Celeb-A - skip')
return

filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
save_path = os.path.join(dirpath, filename)

if os.path.exists(save_path):
print('[*] {} already exists'.format(save_path))
else:
download_file_from_google_drive(drive_id, save_path)

zip_dir = ''
with zipfile.ZipFile(save_path) as zf:
zip_dir = zf.namelist()[0]
zf.extractall(dirpath)
os.remove(save_path)
os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))

def _list_categories(tag):
url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag
Expand Down

0 comments on commit c910e83

Please sign in to comment.