Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance download_and_extract #8216

Merged
merged 12 commits into from
Dec 21, 2024
37 changes: 34 additions & 3 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
import re
import shutil
import sys
import tarfile
Expand All @@ -27,6 +28,8 @@
from urllib.parse import urlparse
from urllib.request import urlopen, urlretrieve

import requests
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved

from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

Expand Down Expand Up @@ -298,6 +301,20 @@ def extractall(
)


def get_filename_from_url(data_url: str):
try:
response = requests.head(data_url, allow_redirects=True)
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename = re.findall("filename=(.+)", content_disposition)
return filename[0].strip('"').strip("'")
else:
filename = _basename(data_url)
return filename
except Exception as e:
raise Exception(f"Error processing URL: {e}")


def download_and_extract(
url: str,
filepath: PathLike = "",
Expand Down Expand Up @@ -328,6 +345,20 @@ def download_and_extract(
progress: whether to display progress bar.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
if not filepath:
filename = get_filename_from_url(url)
full_path = Path(tmp_dir, filename)
elif os.path.isdir(filepath) or not os.path.splitext(filepath)[1]:
filename = get_filename_from_url(url)
full_path = Path(os.path.join(filepath, filename))
logger.warning(f"No compress file extension provided, downloading as: '{full_path}'")
else:
url_filename_ext = "".join(Path(".", _basename(url)).resolve().suffixes)
filepath_ext = "".join(Path(".", _basename(filepath)).resolve().suffixes)
if filepath_ext != url_filename_ext:
raise ValueError(
f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}"
)
full_path = Path(filepath)
download_url(url=url, filepath=full_path, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=full_path, output_dir=output_dir, file_type=file_type, has_base=has_base)
Loading