Fix Attribute Error when accessing identifiers on partial dataset (#4413)

* no exception if identifiers are not downloaded

* always download identifiers

* Update tests/data/test_dataset.py

Co-authored-by: Matthew Silverman <[email protected]>

* add error message

---------

Co-authored-by: Matthew Silverman <[email protected]>
brownj85 and timmysilv authored Aug 1, 2023
1 parent f91811d commit 0905217
Showing 4 changed files with 89 additions and 5 deletions.
9 changes: 8 additions & 1 deletion pennylane/data/base/dataset.py
@@ -250,6 +250,7 @@ def identifiers(self) -> typing.Mapping[str, str]: # pylint: disable=function-r
return {
attr_name: getattr(self, attr_name)
for attr_name in self.info.get("identifiers", self.info.get("params", []))
if attr_name in self.bind
}

@property
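
For context, a minimal sketch of the fixed behavior, mirroring the new `test_identifiers_base_missing` test below (the in-memory `Dataset` constructor and `identifiers` keyword are used exactly as in those tests):

```python
from pennylane.data import Dataset

# Declare "x" and "y" as identifiers, but only populate "x",
# mimicking a partially downloaded dataset.
ds = Dataset(x="1", identifiers=("x", "y"))

# Previously this raised AttributeError for the missing "y";
# with the `if attr_name in self.bind` guard it is skipped instead.
print(ds.identifiers)  # {'x': '1'}
```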
@@ -323,7 +324,7 @@ def write(
values are "w-" (create, fail if file exists), "w" (create, overwrite existing),
and "a" (append existing, create if doesn't exist). Default is "w-".
attributes: Optional list of attributes to copy. If None, all attributes
will be copied.
will be copied. Note that identifiers will always be copied.
overwrite: Whether to overwrite attributes that already exist in this
dataset.
"""
@@ -337,6 +338,12 @@

hdf5.copy_all(self.bind, dest.bind, *attributes, on_conflict=on_conflict)

missing_identifiers = [
identifier for identifier in self.identifiers if not hasattr(dest, identifier)
]
if missing_identifiers:
hdf5.copy_all(self.bind, dest.bind, *missing_identifiers)

def _init_bind(
self, data_name: Optional[str] = None, identifiers: Optional[Tuple[str, ...]] = None
):
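
A sketch of the guarantee added here, mirroring the new `test_write_partial_always_copies_identifiers` test below:

```python
from pennylane.data import Dataset

ds = Dataset(x="a", y="b", data="Some data", identifiers=("x", "y"))
partial = Dataset()

# Only "data" is requested, but the identifiers "x" and "y" are
# copied as well, so the partial copy stays addressable.
ds.write(partial, attributes=["data"])

print(sorted(partial.list_attributes()))      # ['data', 'x', 'y']
print(partial.identifiers == ds.identifiers)  # True
```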
30 changes: 27 additions & 3 deletions pennylane/data/data_manager/__init__.py
@@ -97,6 +97,23 @@ def _download_dataset(
f.write(resp.content)


def _validate_attributes(data_struct: dict, data_name: str, attributes: typing.Iterable[str]):
"""Checks that ``attributes`` contains only valid attributes for the given
``data_name``. If any attributes do not exist, raise a ValueError."""
invalid_attributes = [
attr for attr in attributes if attr not in data_struct[data_name]["attributes"]
]
if not invalid_attributes:
return

if len(invalid_attributes) == 1:
values_err = f"'{invalid_attributes[0]}' is an invalid attribute for '{data_name}'"
else:
values_err = f"{invalid_attributes} are invalid attributes for '{data_name}'"

raise ValueError(f"{values_err}. Valid attributes are: {data_struct[data_name]['attributes']}")
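
A usage sketch of the new helper, using the hypothetical `data_struct` entry from the tests further down:

```python
from pennylane.data.data_manager import _validate_attributes

data_struct = {"my_dataset": {"attributes": ["x", "y", "z"]}}

# Valid attribute names pass silently.
_validate_attributes(data_struct, "my_dataset", ["x", "y"])

# An unknown attribute raises a descriptive ValueError.
try:
    _validate_attributes(data_struct, "my_dataset", ["x", "foo"])
except ValueError as exc:
    print(exc)
    # 'foo' is an invalid attribute for 'my_dataset'. Valid attributes are: ['x', 'y', 'z']
```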


def load( # pylint: disable=too-many-arguments
data_name: str,
attributes: Optional[typing.Iterable[str]] = None,
@@ -186,14 +203,18 @@ def load( # pylint: disable=too-many-arguments
>>> print(circuit())
-1.0791430411076344
"""
foldermap = _get_foldermap()
data_struct = _get_data_struct()

params = format_params(**params)

if attributes:
_validate_attributes(data_struct, data_name, attributes)

folder_path = Path(folder_path)
if cache_dir and not Path(cache_dir).is_absolute():
cache_dir = folder_path / cache_dir

foldermap = _get_foldermap()

data_paths = [data_path for _, data_path in foldermap.find(data_name, **params)]

dest_paths = [folder_path / data_path for data_path in data_paths]
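
Downstream, `load` now validates the requested attributes before fetching anything. A hypothetical call (the dataset name, parameters, and attribute names are illustrative, and the metadata lookup requires network access):

```python
import pennylane as qml

# A misspelled attribute now fails fast with a ValueError listing the
# valid attributes, before any download begins.
try:
    qml.data.load(
        "qchem",
        molname="H2",
        basis="STO-3G",
        bondlength=1.1,
        attributes=["molecule", "hamiltonain"],  # note the typo
    )
except ValueError as exc:
    print(exc)
```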
@@ -374,7 +395,9 @@ def load_interactive():
value = _interactive_request_single(node, param)
description[param] = value

attributes = _interactive_request_attributes(data_struct[data_name]["attributes"])
attributes = _interactive_request_attributes(
[attribute for attribute in data_struct[data_name]["attributes"] if attribute not in params]
)
force = input("Force download files? (Default is no) [y/N]: ") in ["y", "Y"]
dest_folder = Path(
input("Folder to download to? (Default is pwd, will download to /datasets subdirectory): ")
Expand All @@ -390,6 +413,7 @@ def load_interactive():
if approve not in ["Y", "", "y"]:
print("Aborting and not downloading!")
return None

return load(
data_name, attributes=attributes, folder_path=dest_folder, force=force, **description
)[0]
25 changes: 24 additions & 1 deletion tests/data/data_manager/test_dataset_access.py
@@ -24,7 +24,7 @@
import pennylane as qml
import pennylane.data.data_manager
from pennylane.data import Dataset
from pennylane.data.data_manager import DataPath, S3_URL
from pennylane.data.data_manager import DataPath, S3_URL, _validate_attributes

# pylint:disable=protected-access,redefined-outer-name

@@ -376,3 +376,26 @@ def test_download_dataset_escapes_url_partial(mock_download_partial, datapath, e
mock_download_partial.assert_called_once_with(
f"{S3_URL}/{escaped}", dest, attributes, overwrite=force
)


@pytest.mark.parametrize(
"attributes,msg",
[
(
["x", "y", "z", "foo"],
r"'foo' is an invalid attribute for 'my_dataset'. Valid attributes are: \['x', 'y', 'z'\]",
),
(
["x", "y", "z", "foo", "bar"],
r"\['foo', 'bar'\] are invalid attributes for 'my_dataset'. Valid attributes are: \['x', 'y', 'z'\]",
),
],
)
def test_validate_attributes_except(attributes, msg):
"""Test that ``_validate_attributes()`` raises a ValueError when passed
invalid attributes."""

data_struct = {"my_dataset": {"attributes": ["x", "y", "z"]}}

with pytest.raises(ValueError, match=msg):
_validate_attributes(data_struct, "my_dataset", attributes)
30 changes: 30 additions & 0 deletions tests/data/test_dataset.py
@@ -198,12 +198,25 @@ def test_identifiers_base(self, identifiers, expect):

assert ds.identifiers == expect

def test_identifiers_base_missing(self):
"""Test that identifiers whose attribute is missing on the
dataset will not be in the returned dict."""
ds = Dataset(x="1", identifiers=("x", "y"))

assert ds.identifiers == {"x": "1"}

def test_subclass_identifiers(self):
"""Test that dataset subclasses' identifiers can be set."""
ds = MyDataset(x="1", y="2", description="abc")

assert ds.identifiers == {"x": "1", "y": "2"}

def test_subclass_identifiers_missing(self):
"""Test that dataset subclasses' identifiers can be set."""
ds = MyDataset(x="1", description="abc")

assert ds.identifiers == {"x": "1"}

def test_attribute_info(self):
"""Test that attribute info can be set and accessed
on a dataset attribute."""
@@ -357,6 +370,23 @@ def test_write(self, tmp_path, mode):
assert ds_2.bind is not ds.bind
assert ds.attrs == ds_2.attrs

@pytest.mark.parametrize(
"attributes_arg,attributes_expect",
[
(["x"], ["x", "y"]),
(["x", "y", "data"], ["x", "y", "data"]),
(["data"], ["x", "y", "data"]),
],
)
def test_write_partial_always_copies_identifiers(self, attributes_arg, attributes_expect):
"""Test that ``write`` will always copy attributes that are identifiers."""
ds = Dataset(x="a", y="b", data="Some data", identifiers=("x", "y"))
ds_2 = Dataset()

ds.write(ds_2, attributes=attributes_arg)
assert set(ds_2.list_attributes()) == set(attributes_expect)
assert ds_2.identifiers == ds.identifiers

def test_init_subclass(self):
"""Test that __init_subclass__() does the following:
