
Merge pull request #3031 from snbianco/download-products-verbosity
Modulate verbosity on download_products() function
bsipocz committed Jun 14, 2024
2 parents 604f284 + a5189c6 commit 2f2a643
Showing 5 changed files with 73 additions and 37 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
@@ -12,6 +12,7 @@ alma
^^^^

- Added method to return quantities instead of values and regions footprint in alma [#2855]

- Added support for frequency_resolution in KHz [#3035]

mpc
@@ -88,8 +89,12 @@ mast
^^^^

- Fix bug in which the ``local_path`` parameter for the ``mast.observations.download_file`` method does not accept a directory. [#3016]

- Optimize remote test suite to improve performance and reduce execution time. [#3036]

- Add ``verbose`` parameter to modulate output in ``mast.observations.download_products`` method. [#3031]



0.4.7 (2024-03-08)
==================
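For context, a minimal usage sketch of the new parameter (the observation ID and product filter mirror the test added below in astroquery/mast/tests/test_mast.py; the download directory is a placeholder):

    from astroquery.mast import Observations

    # Fetch the product list for an observation and download it without
    # progress output; the returned manifest table (Local Path, Status,
    # Message, URL) is unchanged by verbose=False.
    products = Observations.get_product_list('2003738726')
    manifest = Observations.download_products(products,
                                              download_dir='./mast_downloads',
                                              productType=["SCIENCE"],
                                              verbose=False)
    print(manifest)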
39 changes: 22 additions & 17 deletions astroquery/mast/cloud.py
@@ -161,7 +161,7 @@ def get_cloud_uri_list(self, data_products, include_bucket=True, full_url=False)

return [self.get_cloud_uri(product, include_bucket, full_url) for product in data_products]

def download_file(self, data_product, local_path, cache=True):
def download_file(self, data_product, local_path, cache=True, verbose=True):
"""
Takes a data product in the form of an `~astropy.table.Row` and downloads it from the cloud into
the given directory.
@@ -174,6 +174,8 @@ def download_file(self, data_product, local_path, cache=True):
The local filename to which the downloaded file will be saved.
cache : bool
Default is True. If the file is found on disk it will not be downloaded again.
verbose : bool, optional
Default is True. Whether to show download progress in the console.
"""

s3 = self.boto3.resource('s3', config=self.config)
@@ -203,24 +205,27 @@ def download_file(self, data_product, local_path, cache=True):
.format(local_path, statinfo.st_size))
return

-        with ProgressBarOrSpinner(length, ('Downloading URL s3://{0}/{1} to {2} ...'.format(
-                self.pubdata_bucket, bucket_path, local_path))) as pb:
-
-            # Bytes read tracks how much data has been received so far
-            # This variable will be updated in multiple threads below
-            global bytes_read
-            bytes_read = 0
-
-            progress_lock = threading.Lock()
-
-            def progress_callback(numbytes):
-                # Boto3 calls this from multiple threads pulling the data from S3
-                global bytes_read
-
-                # This callback can be called in multiple threads
-                # Access to updating the console needs to be locked
-                with progress_lock:
-                    bytes_read += numbytes
-                    pb.update(bytes_read)
-
-            bkt.download_file(bucket_path, local_path, Callback=progress_callback)
+        if verbose:
+            with ProgressBarOrSpinner(length, ('Downloading URL s3://{0}/{1} to {2} ...'.format(
+                    self.pubdata_bucket, bucket_path, local_path))) as pb:
+
+                # Bytes read tracks how much data has been received so far
+                # This variable will be updated in multiple threads below
+                global bytes_read
+                bytes_read = 0
+
+                progress_lock = threading.Lock()
+
+                def progress_callback(numbytes):
+                    # Boto3 calls this from multiple threads pulling the data from S3
+                    global bytes_read
+
+                    # This callback can be called in multiple threads
+                    # Access to updating the console needs to be locked
+                    with progress_lock:
+                        bytes_read += numbytes
+                        pb.update(bytes_read)
+
+                bkt.download_file(bucket_path, local_path, Callback=progress_callback)
+        else:
+            bkt.download_file(bucket_path, local_path)
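The S3 path keeps boto3's multi-threaded transfer in both branches and only wires up a progress callback when verbose is set. A self-contained sketch of that locked-callback pattern, with a thread pool standing in for boto3's transfer threads (all names here are illustrative, none are astroquery API):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    from astropy.utils.console import ProgressBarOrSpinner

    total = 10 * 1024 * 1024      # pretend file size in bytes
    chunk = 256 * 1024            # bytes reported per callback invocation
    bytes_read = 0
    lock = threading.Lock()

    with ProgressBarOrSpinner(total, 'Downloading (simulated) ...') as pb:
        def progress_callback(numbytes):
            # Called from several worker threads: the shared counter and the
            # console update are protected by a single lock, as in cloud.py.
            global bytes_read
            with lock:
                bytes_read += numbytes
                pb.update(bytes_read)

        with ThreadPoolExecutor(max_workers=4) as pool:
            for _ in range(total // chunk):
                pool.submit(progress_callback, chunk)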
32 changes: 22 additions & 10 deletions astroquery/mast/observations.py
@@ -500,7 +500,7 @@ def filter_products(self, products, *, mrp_only=False, extension=None, **filters

return products[np.where(filter_mask)]

def download_file(self, uri, *, local_path=None, base_url=None, cache=True, cloud_only=False):
def download_file(self, uri, *, local_path=None, base_url=None, cache=True, cloud_only=False, verbose=True):
"""
Downloads a single file based on the data URI
@@ -518,6 +518,8 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
Default False. If set to True and cloud data access is enabled (see `enable_cloud_dataset`)
files that are not found in the cloud will be skipped rather than downloaded from MAST,
as is the default behavior. If cloud access is not enabled, this argument has no effect.
verbose : bool, optional
Default True. Whether to show download progress in the console.
Returns
-------
@@ -554,7 +556,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
try:
if self._cloud_connection is not None and self._cloud_connection.is_supported(data_product):
try:
self._cloud_connection.download_file(data_product, local_path, cache)
self._cloud_connection.download_file(data_product, local_path, cache, verbose)
except Exception as ex:
log.exception("Error pulling from S3 bucket: {}".format(ex))
if cloud_only:
@@ -564,10 +566,12 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
else:
log.warning("Falling back to mast download...")
self._download_file(data_url, local_path,
cache=cache, head_safe=True, continuation=False)
cache=cache, head_safe=True, continuation=False,
verbose=verbose)
else:
self._download_file(data_url, local_path,
cache=cache, head_safe=True, continuation=False)
cache=cache, head_safe=True, continuation=False,
verbose=verbose)

# check if the file exists; this is also where we would perform an md5 check,
# and also check the file size if the database reliably reported file sizes
@@ -583,7 +587,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou

return status, msg, url

def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_only=False,):
def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_only=False, verbose=True):
"""
Takes an `~astropy.table.Table` of data products and downloads them into the directory given by base_dir.
@@ -602,6 +606,8 @@ def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_o
Default False. If set to True and cloud data access is enabled (see `enable_cloud_dataset`)
files that are not found in the cloud will be skipped rather than downloaded from MAST,
as is the default behavior. If cloud access is not enabled, this argument has no effect.
verbose : bool, optional
Default True. Whether to show download progress in the console.
Returns
-------
@@ -622,15 +628,15 @@ def _download_files(self, products, base_dir, *, flat=False, cache=True, cloud_o

# download the files
status, msg, url = self.download_file(data_product["dataURI"], local_path=local_path,
cache=cache, cloud_only=cloud_only)
cache=cache, cloud_only=cloud_only, verbose=verbose)

manifest_array.append([local_path, status, msg, url])

manifest = Table(rows=manifest_array, names=('Local Path', 'Status', 'Message', "URL"))

return manifest

def _download_curl_script(self, products, out_dir):
def _download_curl_script(self, products, out_dir, verbose=True):
"""
Takes an `~astropy.table.Table` of data products and downloads a curl script to pull the datafiles.
@@ -640,6 +646,8 @@ def _download_curl_script(self, products, out_dir):
Table containing products to be included in the curl script.
out_dir : str
Directory in which the curl script will be saved.
verbose : bool, optional
Default True. Whether to show download progress in the console.
Returns
-------
@@ -651,7 +659,7 @@ def _download_curl_script(self, products, out_dir):
local_path = os.path.join(out_dir, download_file)

self._download_file(self._portal_api_connection.MAST_BUNDLE_URL + ".sh",
local_path, data=url_list, method="POST")
local_path, data=url_list, method="POST", verbose=verbose)

status = "COMPLETE"
msg = None
@@ -666,7 +674,8 @@ def _download_curl_script(self, products, out_dir):
return manifest

def download_products(self, products, *, download_dir=None, flat=False,
cache=True, curl_flag=False, mrp_only=False, cloud_only=False, **filters):
cache=True, curl_flag=False, mrp_only=False, cloud_only=False, verbose=True,
**filters):
"""
Download data products.
If cloud access is enabled, files will be downloaded from the cloud if possible.
@@ -698,6 +707,8 @@ def download_products(self, products, *, download_dir=None, flat=False,
Default False. If set to True and cloud data access is enabled (see `enable_cloud_dataset`)
files that are not found in the cloud will be skipped rather than downloaded from MAST,
as is the default behavior. If cloud access is not enabled, this argument has no effect.
verbose : bool, optional
Default True. Whether to show download progress in the console.
**filters :
Filters to be applied. Valid filters are all products fields returned by
``get_metadata("products")`` and 'extension' which is the desired file extension.
@@ -758,7 +769,8 @@ def download_products(self, products, *, download_dir=None, flat=False,
manifest = self._download_files(products,
base_dir=base_dir, flat=flat,
cache=cache,
cloud_only=cloud_only)
cloud_only=cloud_only,
verbose=verbose)

return manifest

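Taken together, these changes thread verbose from the public download_products entry point through _download_files and download_file down to both transports, so one flag silences the HTTP and S3 paths alike. A usage sketch for the single-file case (the data URI is a placeholder, not a real MAST product):

    from astroquery.mast import Observations

    # Download one product quietly; the (status, message, url) return value is
    # unchanged, so failures stay visible even without progress output.
    status, msg, url = Observations.download_file('mast:MISSION/product/example_file.fits',
                                                  verbose=False)
    print(status, msg, url)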
7 changes: 7 additions & 0 deletions astroquery/mast/tests/test_mast.py
@@ -496,6 +496,13 @@ def test_observations_download_products(patch_post, tmpdir):
mrp_only=False)
assert isinstance(result, Table)

# without console output
result = mast.Observations.download_products('2003738726',
download_dir=str(tmpdir),
productType=["SCIENCE"],
verbose=False)
assert isinstance(result, Table)

# passing row product
products = mast.Observations.get_product_list('2003738726')
result1 = mast.Observations.download_products(products[0],
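The new test only checks that the call still returns a Table. A possible follow-up, not part of this commit, would assert that the quiet path really prints nothing, e.g. with pytest's capsys fixture (reusing the module's existing mast and Table imports and the patch_post fixture):

    def test_observations_download_products_quiet(patch_post, tmpdir, capsys):
        result = mast.Observations.download_products('2003738726',
                                                     download_dir=str(tmpdir),
                                                     productType=["SCIENCE"],
                                                     verbose=False)
        assert isinstance(result, Table)
        # May need loosening if other log output is routed to stdout.
        assert capsys.readouterr().out == ''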
27 changes: 17 additions & 10 deletions astroquery/query.py
@@ -386,7 +386,7 @@ def _request(self, method, url,

def _download_file(self, url, local_filepath, timeout=None, auth=None,
continuation=True, cache=False, method="GET",
head_safe=False, **kwargs):
head_safe=False, verbose=True, **kwargs):
"""
Download a file. Resembles `astropy.utils.data.download_file` but uses
the local ``_session``
@@ -405,6 +405,8 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
Cache downloaded file. Defaults to False.
method : "GET" or "POST"
head_safe : bool
verbose : bool
Whether to show download progress. Defaults to True.
"""

if head_safe:
@@ -492,16 +494,21 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
else:
progress_stream = io.StringIO()

-        with ProgressBarOrSpinner(length, f'Downloading URL {url} to {local_filepath} ...',
-                                  file=progress_stream) as pb:
-            with open(local_filepath, open_mode) as f:
-                for block in response.iter_content(blocksize):
-                    f.write(block)
-                    bytes_read += len(block)
-                    if length is not None:
-                        pb.update(bytes_read if bytes_read <= length else length)
-                    else:
-                        pb.update(bytes_read)
+        if verbose:
+            with ProgressBarOrSpinner(length, f'Downloading URL {url} to {local_filepath} ...',
+                                      file=progress_stream) as pb:
+                with open(local_filepath, open_mode) as f:
+                    for block in response.iter_content(blocksize):
+                        f.write(block)
+                        bytes_read += len(block)
+                        if length is not None:
+                            pb.update(bytes_read if bytes_read <= length else length)
+                        else:
+                            pb.update(bytes_read)
+        else:
+            with open(local_filepath, open_mode) as f:
+                f.write(response.content)

response.close()
return response

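Outside astroquery, the gating pattern used in _download_file looks roughly like the standalone helper below, using requests and astropy's ProgressBarOrSpinner (illustrative only; the function name and URL are placeholders, not library API):

    import requests
    from astropy.utils.console import ProgressBarOrSpinner


    def fetch(url, local_filepath, blocksize=65536, verbose=True):
        """Stream url into local_filepath, optionally showing progress."""
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()
        length = response.headers.get('content-length')
        length = int(length) if length is not None else None

        if verbose:
            with ProgressBarOrSpinner(length, f'Downloading URL {url} to {local_filepath} ...') as pb:
                bytes_read = 0
                with open(local_filepath, 'wb') as f:
                    for block in response.iter_content(blocksize):
                        f.write(block)
                        bytes_read += len(block)
                        pb.update(bytes_read if length is None else min(bytes_read, length))
        else:
            # Quiet path: write the body in one go, no progress reporting.
            with open(local_filepath, 'wb') as f:
                f.write(response.content)
        response.close()


    # fetch('https://example.com/data.fits', 'data.fits', verbose=False)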
