Skip to content

Commit

Permalink
Merge pull request #3016 from snbianco/download-file-directory
Browse files Browse the repository at this point in the history
Download file with a directory specified by local_path parameter
  • Loading branch information
bsipocz authored Jun 11, 2024
2 parents d96ce9e + 1a69e07 commit ab7cf03
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 12 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ jplhorizons

- Add missing column definitions, especially for ``refraction=True`` and ``extra_precision=True``. [#2986]

mast
^^^^

- Fix bug in which the ``local_path`` parameter for the ``mast.observations.download_file`` method does not accept a directory. [#3016]


0.4.7 (2024-03-08)
==================
Expand Down
17 changes: 12 additions & 5 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
This module contains various methods for querying MAST observations.
"""

from pathlib import Path
import warnings
import time
import os
Expand Down Expand Up @@ -508,7 +509,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
uri : str
The product dataURI, e.g. mast:JWST/product/jw00736-o039_t001_miri_ch1-long_x1d.fits
local_path : str
Directory in which the files will be downloaded. Defaults to current working directory.
Directory or filename to which the file will be downloaded. Defaults to current working directory.
base_url: str
A base url to use when downloading. Default is the MAST Portal API
cache : bool
Expand All @@ -532,10 +533,16 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
base_url = base_url if base_url else self._portal_api_connection.MAST_DOWNLOAD_URL
data_url = base_url + "?uri=" + uri

# create a local file path if none is input. Use current directory as default.
if not local_path:
filename = os.path.basename(uri)
local_path = os.path.join(os.path.abspath('.'), filename)
# parse a local file path from local_path parameter. Use current directory as default.
filename = os.path.basename(uri)
if not local_path: # local file path is not defined
local_path = filename
else:
path = Path(local_path)
if not path.suffix: # local_path is a directory
local_path = path / filename # append filename
if not path.exists(): # create directory if it doesn't exist
path.mkdir(parents=True, exist_ok=True)

# recreate the data_product key for cloud connection check
data_product = {'dataURI': uri}
Expand Down
3 changes: 2 additions & 1 deletion astroquery/mast/tests/test_mast.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def test_observations_download_products(patch_post, tmpdir):

# passing row product
products = mast.Observations.get_product_list('2003738726')
result1 = mast.Observations.download_products(products[0])
result1 = mast.Observations.download_products(products[0],
download_dir=str(tmpdir))
assert isinstance(result1, Table)


Expand Down
21 changes: 18 additions & 3 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,26 @@ def test_observations_download_file(self, tmp_path):

# pull the URI of a single product
uri = products['dataURI'][0]
local_path = Path(tmp_path, Path(uri).name)
filename = Path(uri).name

# download it
result = mast.Observations.download_file(uri, local_path=local_path)
# download with unspecified local_path parameter
# should download to current working directory
result = mast.Observations.download_file(uri)
assert result == ('COMPLETE', None, None)
assert os.path.exists(Path(os.getcwd(), filename))
Path.unlink(filename) # clean up file

# download with directory as local_path parameter
local_path = Path(tmp_path, filename)
result = mast.Observations.download_file(uri, local_path=tmp_path)
assert result == ('COMPLETE', None, None)
assert os.path.exists(local_path)

# download with filename as local_path parameter
local_path_file = Path(tmp_path, "test.fits")
result = mast.Observations.download_file(uri, local_path=local_path_file)
assert result == ('COMPLETE', None, None)
assert os.path.exists(local_path_file)

@pytest.mark.parametrize("test_data_uri, expected_cloud_uri", [
("mast:HST/product/u24r0102t_c1f.fits",
Expand Down
6 changes: 3 additions & 3 deletions docs/mast/mast_obsquery.rst
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ curl script that can be used to download the files at a later time.
Downloading a Single File
-------------------------

You can download a single data product file using the `~astroquery.mast.ObservationsClass.download_file`
method, and passing in a MAST Data URI. The default is to download the file the current working directory,
which can be changed with the ``local_path`` keyword argument.
You can download a single data product file by using the `~astroquery.mast.ObservationsClass.download_file`
method and passing in a MAST Data URI. The default is to download the file to the current working directory, but
you can specify the download directory or filepath with the ``local_path`` keyword argument.

.. doctest-remote-data::

Expand Down

0 comments on commit ab7cf03

Please sign in to comment.