Skip to content

Commit

Permalink
Check parts kwargs in url source (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz authored Jan 15, 2024
1 parent ac24d98 commit 625df82
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 36 deletions.
60 changes: 42 additions & 18 deletions earthkit/data/sources/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,18 @@ def __init__(
# TODO: re-enable this feature
extension = None

self.prepare()

if not self.stream:
self.update_if_out_of_date = update_if_out_of_date

# ensure no parts kwargs is used when the parts are defined together with the urls
_parts_kwargs = {}
if not isinstance(url, (list, tuple)):
url = [url]
if isinstance(url[0], (list, tuple)):
if self.parts is not None:
raise ValueError("cannot specify parts both as arg and kwarg")
else:
_parts_kwargs = {"parts": self.parts}

LOG.debug(f"http_headers={self.http_headers}")
LOG.debug(
(
f"url={self.url} parts={self.parts} auth={self.auth}) "
f"http_headers={self.http_headers} parts_kwargs={self.parts_kwargs}"
f" _kwargs={self._kwargs}"
)
)

self.downloader = Downloader(
self.url,
Expand All @@ -224,7 +222,7 @@ def __init__(
resume_transfers=True,
override_target_file=False,
download_file_extension=".download",
**_parts_kwargs,
**self.parts_kwargs,
)

if extension and extension[0] != ".":
Expand Down Expand Up @@ -257,9 +255,9 @@ def mutate(self):
from multiurl.downloader import _canonicalize

s = []
_kwargs = {}
if self.parts is not None:
_kwargs = {"parts": self.parts}
_kwargs = dict(**self.parts_kwargs)
# if self.parts is not None:
# _kwargs = {"parts": self.parts}
urls, _ = _canonicalize(self.url, **_kwargs)

for url, parts in urls:
Expand All @@ -281,6 +279,32 @@ def mutate(self):
else:
return super().mutate()

def prepare(self):
    """Normalise ``self.url``/``self.parts`` and build ``self.parts_kwargs``.

    ``self.parts_kwargs`` holds the ``parts`` keyword argument to forward
    together with ``self.url``. It is left empty when the parts are already
    embedded in the url specification. Three input shapes are recognised:

    - a single ``[url, parts]`` pair: it is unpacked into ``self.url`` and
      ``self.parts``, and the parts are forwarded as a kwarg;
    - a sequence of ``[url, parts]`` pairs: the parts travel with the urls,
      so no ``parts`` kwarg is forwarded;
    - a plain url or a sequence of plain urls: ``self.parts`` is forwarded
      as a kwarg.

    Raises
    ------
    ValueError
        If no url is given, or if parts are specified both inside the url
        argument and via ``self.parts``.
    """
    self.parts_kwargs = {}
    urls = self.url

    if not isinstance(urls, (list, tuple)):
        urls = [urls]

    # Guard against an empty sequence: indexing urls[0] below would
    # otherwise fail with an uninformative IndexError.
    if not urls:
        raise ValueError("No url specified")

    # a single url as [url, parts] is not allowed by multiurl, so it is
    # unpacked here and the parts are passed on as a kwarg instead
    if (
        len(urls) == 2
        and isinstance(urls[0], str)
        and (urls[1] is None or isinstance(urls[1], (list, tuple)))
    ):
        if self.parts is not None:
            raise ValueError("Cannot specify parts both as arg and kwarg")
        self.url, self.parts = urls
        self.parts_kwargs = {"parts": self.parts}
    # each url is a [url, parts] pair: no parts kwarg must be used
    elif isinstance(urls[0], (list, tuple)):
        if self.parts is not None:
            raise ValueError("Cannot specify parts both as arg and kwarg")
    # each url is a str
    else:
        self.parts_kwargs = {"parts": self.parts}

def out_of_date(self, url, path, cache_data):
if SETTINGS.get("check-out-of-date-urls") is False:
return False
Expand All @@ -302,7 +326,7 @@ def out_of_date(self, url, path, cache_data):

class RequestIterStreamer:
"""Expose fixed chunk-based stream reader used in multiurl as a
stream supporting a generic read method
stream supporting a generic read method.
"""

def __init__(self, iter_content):
Expand Down Expand Up @@ -407,13 +431,13 @@ def __init__(
super().__init__(url, **kwargs)

if isinstance(self.url, (list, tuple)):
raise TypeError("only a single url is supported")
raise TypeError("Only a single url is supported")

from urllib.parse import urlparse

o = urlparse(self.url)
if o.scheme not in ("http", "https"):
raise NotImplementedError(f"streams are not supported for {o.scheme} URLs")
raise NotImplementedError(f"Streams are not supported for {o.scheme} urls")

def mutate(self):
from .stream import _from_source
Expand Down
42 changes: 42 additions & 0 deletions tests/grib/test_grib_url_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,48 @@ def test_grib_single_url_stream_parts(path, parts, expected_meta):
assert sum([1 for _ in ds]) == 0


@pytest.mark.parametrize(
    "parts,expected_meta",
    [
        ([(0, 150)], [("t", 1000)]),
        (
            None,
            [("t", 1000), ("u", 1000), ("v", 1000), ("t", 850), ("u", 850), ("v", 850)],
        ),
    ],
)
def test_grib_single_url_stream_parts_as_arg(parts, expected_meta):
    """Stream a single url given as a ``[url, parts]`` pair and check the fields."""
    url_with_parts = [earthkit_remote_test_data_file("examples/test6.grib"), parts]
    ds = from_source("url", url_with_parts, stream=True)

    # no fieldlist methods are available on a stream
    with pytest.raises(TypeError):
        len(ds)

    n_fields = 0
    for idx, field in enumerate(ds):
        assert field.metadata(("param", "level")) == expected_meta[idx], idx
        n_fields += 1

    assert n_fields == len(expected_meta)

    # the stream has been consumed: a second iteration yields nothing
    assert sum(1 for _ in ds) == 0


def test_grib_single_url_stream_parts_as_arg_invalid():
    """Giving parts both in the url argument and as a kwarg must raise ValueError."""
    with pytest.raises(ValueError):
        url_with_parts = [
            earthkit_remote_test_data_file("examples/test6.grib"),
            [(0, 150)],
        ]
        from_source("url", url_with_parts, parts=[(0, 160)], stream=True)


@pytest.mark.parametrize(
"parts1,parts2,expected_meta",
[
Expand Down
77 changes: 59 additions & 18 deletions tests/sources/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_url_source_tar():
assert len(ds) == 6


def test_part_url():
def test_parts_url():
ds = from_source(
"url",
"https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/temp.bufr",
Expand Down Expand Up @@ -111,6 +111,62 @@ def test_part_url():
assert f.read()[:4] == b"BUFR"


def test_parts_as_arg_url_1():
    """A single ``[url, parts]`` argument downloads only the requested byte range."""
    remote = "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/temp.bufr"
    ds = from_source("url", [remote, [(0, 4)]])

    # only the first 4 bytes were requested
    assert os.path.getsize(ds.path) == 4

    with open(ds.path, "rb") as f:
        content = f.read()
    assert content == b"BUFR"


def test_parts_as_arg_url_2():
    """A single ``[url, None]`` argument downloads more than the 4-byte header."""
    remote = "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/temp.bufr"
    ds = from_source("url", [remote, None])

    assert os.path.getsize(ds.path) > 4

    with open(ds.path, "rb") as f:
        header = f.read(4)
    assert header == b"BUFR"


def test_multi_url_parts_as_arg_invalid_1():
    """Parts given in a ``[url, parts]`` argument and as a kwarg must raise ValueError."""
    remote = "https://get.ecmwf.int/repository/test-data/earthkit-data/test-data/temp.bufr"
    url_with_parts = [remote, [(0, 4)]]

    with pytest.raises(ValueError):
        from_source("url", url_with_parts, parts=[(0, 5)])


def test_multi_url_parts_invalid():
    """Per-url parts combined with a ``parts`` kwarg must raise ValueError."""
    with pytest.raises(ValueError):
        urls_with_parts = [
            [earthkit_remote_test_data_file("examples/test6.grib"), [(240, 150)]],
            [earthkit_remote_test_data_file("examples/test.grib"), [(0, 526)]],
        ]
        from_source("url", urls_with_parts, parts=[(0, 240)])


@pytest.mark.skipif( # TODO: fix
sys.platform == "win32",
reason="file:// not working on Windows yet",
Expand Down Expand Up @@ -145,22 +201,7 @@ def test_url_netcdf_source_save():
assert os.path.exists(tmp.path)


def test_multi_url_parts_invalid():
    """Per-url parts combined with a ``parts`` kwarg must raise ValueError."""
    with pytest.raises(ValueError):
        urls_with_parts = [
            [earthkit_remote_test_data_file("examples/test6.grib"), [(240, 150)]],
            [earthkit_remote_test_data_file("examples/test.grib"), [(0, 526)]],
        ]
        from_source("url", urls_with_parts, parts=[(0, 240)])


if __name__ == "__main__":
test_part_url()
# from earthkit.data.testing import main
from earthkit.data.testing import main

# main(__file__)
main(__file__)

0 comments on commit 625df82

Please sign in to comment.