Skip to content

Commit

Permalink
Merge pull request #3080 from snbianco/ASB-21598-escape-url
Browse files Browse the repository at this point in the history
Escape MAST Download URIs
  • Loading branch information
bsipocz authored Aug 7, 2024
2 parents 1ac230e + 487a66e commit 515928d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 40 deletions.
6 changes: 4 additions & 2 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
import time
import os
from urllib.parse import quote

import numpy as np

Expand Down Expand Up @@ -534,6 +535,7 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
# create the full data URL
base_url = base_url if base_url else self._portal_api_connection.MAST_DOWNLOAD_URL
data_url = base_url + "?uri=" + uri
escaped_url = base_url + "?uri=" + quote(uri, safe=":/")

# parse a local file path from local_path parameter. Use current directory as default.
filename = os.path.basename(uri)
Expand Down Expand Up @@ -565,11 +567,11 @@ def download_file(self, uri, *, local_path=None, base_url=None, cache=True, clou
status = "SKIPPED"
else:
log.warning("Falling back to mast download...")
self._download_file(data_url, local_path,
self._download_file(escaped_url, local_path,
cache=cache, head_safe=True, continuation=False,
verbose=verbose)
else:
self._download_file(data_url, local_path,
self._download_file(escaped_url, local_path,
cache=cache, head_safe=True, continuation=False,
verbose=verbose)

Expand Down
51 changes: 13 additions & 38 deletions astroquery/mast/tests/test_mast_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ def test_mast_service_request_async(self):
assert isinstance(responses, list)

def test_mast_service_request(self):

# clear columns config
Mast._column_configs = dict()

service = 'Mast.Caom.Cone'
params = {'ra': 184.3,
'dec': 54.5,
Expand All @@ -174,9 +170,6 @@ def test_mast_service_request(self):
assert len(result[np.where(result["obs_id"] == "6374399093149532160")]) == 2

def test_mast_query(self):
# clear columns config
Mast._column_configs = dict()

result = Mast.mast_query('Mast.Caom.Cone', ra=184.3, dec=54.5, radius=0.2)

# Is result in the right format
Expand Down Expand Up @@ -225,9 +218,6 @@ def test_observations_query_region_async(self):
assert isinstance(responses, list)

def test_observations_query_region(self):
# clear columns config
Observations._column_configs = dict()

result = Observations.query_region("322.49324 12.16683", radius="0.005 deg")
assert isinstance(result, Table)
assert len(result) > 500
Expand All @@ -243,9 +233,6 @@ def test_observations_query_object_async(self):
assert isinstance(responses, list)

def test_observations_query_object(self):
# clear columns config
Observations._column_configs = dict()

result = Observations.query_object("M8", radius=".04 deg")
assert isinstance(result, Table)
assert len(result) > 150
Expand All @@ -264,10 +251,6 @@ def test_observations_query_criteria_async(self):
assert isinstance(responses, list)

def test_observations_query_criteria(self):

# clear columns config
Observations._column_configs = dict()

# without position
result = Observations.query_criteria(instrument_name="*WFPC2*",
proposal_id=8169,
Expand Down Expand Up @@ -333,10 +316,6 @@ def test_observations_get_product_list_async(self):
assert isinstance(responses, list)

def test_observations_get_product_list(self):

# clear columns config
Observations._column_configs = dict()

observations = Observations.query_object("M8", radius=".04 deg")
test_obs_id = str(observations[0]['obsid'])
mult_obs_ids = str(observations[0]['obsid']) + ',' + str(observations[1]['obsid'])
Expand Down Expand Up @@ -519,6 +498,19 @@ def test_observations_download_file_cloud(self, tmp_path, in_uri):
assert result == ('COMPLETE', None, None)
assert Path(tmp_path, filename).exists()

def test_observations_download_file_escaped(self, tmp_path):
# test that `download_file` correctly escapes a URI
in_uri = 'mast:HLA/url/cgi-bin/fitscut.cgi?' \
'red=hst_04819_65_wfpc2_f814w_pc&blue=hst_04819_65_wfpc2_f555w_pc&size=ALL&format=fits'
filename = Path(in_uri).name
result = Observations.download_file(uri=in_uri, local_path=tmp_path)
assert result == ('COMPLETE', None, None)
assert Path(tmp_path, filename).exists()

# check that downloaded file is a valid FITS file
f = fits.open(Path(tmp_path, filename))
f.close()

@pytest.mark.parametrize("test_data_uri, expected_cloud_uri", [
("mast:HST/product/u24r0102t_c1f.fits",
"s3://stpubdata/hst/public/u24r/u24r0102t/u24r0102t_c1f.fits"),
Expand Down Expand Up @@ -618,10 +610,7 @@ def check_result(result, row, exp_values):
for k, v in exp_values.items():
assert result[row][k] == v

# clear columns config
Catalogs._column_configs = dict()
in_radius = 0.1 * u.deg

result = Catalogs.query_region("158.47924 -7.30962",
radius=in_radius,
catalog="Gaia")
Expand Down Expand Up @@ -717,9 +706,6 @@ def check_result(result, exp_values):
for k, v in exp_values.items():
assert v in result[k]

# clear columns config
Catalogs._column_configs = dict()

result = Catalogs.query_object("M10",
radius=.001,
catalog="TIC")
Expand Down Expand Up @@ -819,9 +805,6 @@ def check_result(result, exp_vals):
for k, v in exp_vals.items():
assert v in result[k]

# clear columns config
Catalogs._column_configs = dict()

# without position
result = Catalogs.query_criteria(catalog="Tic",
Bmag=[30, 50],
Expand Down Expand Up @@ -897,10 +880,6 @@ def test_catalogs_query_hsc_matchid_async(self):
assert isinstance(responses, list)

def test_catalogs_query_hsc_matchid(self):

# clear columns config
Catalogs._column_configs = dict()

catalogData = Catalogs.query_object("M10",
radius=.001,
catalog="HSC",
Expand All @@ -921,10 +900,6 @@ def test_catalogs_get_hsc_spectra_async(self):
assert isinstance(responses, list)

def test_catalogs_get_hsc_spectra(self):

# clear columns config
Catalogs._column_configs = dict()

result = Catalogs.get_hsc_spectra()
assert isinstance(result, Table)
assert result[np.where(result['MatchID'] == '19657846')]
Expand Down

0 comments on commit 515928d

Please sign in to comment.