From 537789813c024cc02e79609bc129125738ebc12c Mon Sep 17 00:00:00 2001 From: Martin Schlipf Date: Mon, 17 Feb 2025 14:12:07 +0100 Subject: [PATCH] Feat: Allow for template sources in the HDF5 file (#192) * Implement Mapping to allow templates in the schema * Implement access to multiple different datasets with the key of the mapping --- src/py4vasp/_raw/access.py | 37 +++++++++++++++++++++++------- src/py4vasp/_raw/schema.py | 29 ++++++++++++++++++++++++ tests/raw/conftest.py | 46 +++++++++++++++++++++++++++++++------- tests/raw/test_access.py | 42 +++++++++++++++++++++++++++++++--- tests/raw/test_schema.py | 27 +++++++++++++++++++++- tests/raw/util.py | 8 +++++++ 6 files changed, 169 insertions(+), 20 deletions(-) diff --git a/src/py4vasp/_raw/access.py b/src/py4vasp/_raw/access.py index e74862a9..c461abcd 100644 --- a/src/py4vasp/_raw/access.py +++ b/src/py4vasp/_raw/access.py @@ -9,7 +9,8 @@ from py4vasp import exception, raw from py4vasp._raw.definition import DEFAULT_FILE, DEFAULT_SOURCE, schema -from py4vasp._raw.schema import Length, Link, error_message +from py4vasp._raw.schema import Length, Link, Mapping, error_message +from py4vasp._util import convert @contextlib.contextmanager @@ -120,12 +121,26 @@ def _check_version(self, h5f, required, quantity): raise exception.OutdatedVaspVersion(message) def _get_datasets(self, h5f, data): - return { - field.name: self._get_dataset(h5f, getattr(data, field.name)) + valid_indices = self._get_valid_indices(h5f, data) + result = { + field.name: self._get_dataset(h5f, getattr(data, field.name), valid_indices) for field in dataclasses.fields(data) + if field.name != "valid_indices" } + if valid_indices is not None: + result["valid_indices"] = valid_indices + return result + + def _get_valid_indices(self, h5f, data): + if not isinstance(data, Mapping): + return None + valid_indices = self._get_dataset(h5f, data.valid_indices) + if valid_indices.ndim == 0: + return range(valid_indices) + else: + return 
tuple(convert.text_to_string(index) for index in valid_indices) - def _get_dataset(self, h5f, key): + def _get_dataset(self, h5f, key, valid_indices=None): if key is None: return raw.VaspData(None) if isinstance(key, Link): @@ -133,10 +148,16 @@ def _get_dataset(self, h5f, key): if isinstance(key, Length): dataset = h5f.get(key.dataset) return len(dataset) if dataset else None - return self._parse_dataset(h5f.get(key)) - - def _parse_dataset(self, dataset): - result = raw.VaspData(dataset) + if key.format(0) == key or valid_indices is None: + return self._parse_dataset(h5f, key) + return [self._parse_dataset(h5f, key, index) for index in valid_indices] + + def _parse_dataset(self, h5f, key, index=None): + if index is not None: + if isinstance(index, int): + index = index + 1 # convert to Fortran index + key = key.format(index) + result = raw.VaspData(h5f.get(key)) if _is_scalar(result): result = result[()] return result diff --git a/src/py4vasp/_raw/schema.py b/src/py4vasp/_raw/schema.py index fa4deb80..4a5cd86a 100644 --- a/src/py4vasp/_raw/schema.py +++ b/src/py4vasp/_raw/schema.py @@ -4,6 +4,7 @@ import dataclasses import textwrap +from collections import abc import numpy as np @@ -159,6 +160,34 @@ class Length: __str__ = lambda self: f"length({self.dataset})" +@dataclasses.dataclass +class Mapping(abc.Mapping): + valid_indices: Sequence + + def __len__(self): + return len(self.valid_indices) + + def __iter__(self): + return iter(self.valid_indices) + + def __getitem__(self, key): + index = self.valid_indices.index(key) + elements = { + key: value[index] if isinstance(value, list) else value + for key, value in self._as_dict().items() + if key != "valid_indices" + } + return dataclasses.replace(self, valid_indices=[key], **elements) + + def _as_dict(self): + # shallow copy of dataclass to dictionary + return { + field.name: getattr(self, field.name) + for field in dataclasses.fields(self) + if getattr(self, field.name) is not None + } + + def 
_parse_version(version): return f"""version: major: {version.major} diff --git a/tests/raw/conftest.py b/tests/raw/conftest.py index c9a1ad1a..95810cd3 100644 --- a/tests/raw/conftest.py +++ b/tests/raw/conftest.py @@ -1,7 +1,17 @@ # Copyright © VASP Software GmbH, # Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import dataclasses + import pytest -from util import VERSION, Complex, OptionalArgument, Simple, WithLength, WithLink +from util import ( + VERSION, + Complex, + Mapping, + OptionalArgument, + Simple, + WithLength, + WithLink, +) from py4vasp import raw from py4vasp._raw.schema import Length, Link, Schema, Source @@ -20,24 +30,34 @@ def make_data(path): pointer = WithLink("baz_dataset", Link("simple", "default")) version = raw.Version(1, 2, 3) length = WithLength(Length("dataset")) + mapping = Mapping( + valid_indices="foo_mapping", common="common_data", variable="variable_data{}" + ) + list_ = Mapping( + valid_indices="list_mapping", common="common", variable="variable_data_{}" + ) first = Complex( Link("optional_argument", "default"), Link("with_link", "default"), + Link("mapping", "default"), Link("with_length", "default"), ) second = Complex( Link("optional_argument", name), Link("with_link", "default"), + Link("mapping", "my_list"), ) schema = Schema(VERSION) - schema.add(Simple, file=filename, foo=simple.foo, bar=simple.bar) - schema.add(OptionalArgument, name=name, mandatory=only_mandatory.mandatory) - schema.add(OptionalArgument, mandatory=both.mandatory, optional=both.optional) - schema.add(WithLink, required=version, baz=pointer.baz, simple=pointer.simple) - schema.add(WithLength, alias="alias_name", num_data=length.num_data) - schema.add(Complex, opt=first.opt, link=first.link, length=first.length) - schema.add(Complex, name=name, opt=second.opt, link=second.link) + schema.add(Simple, file=filename, **as_dict(simple)) schema.add(Simple, name="factory", file=filename, data_factory=make_data) + 
schema.add(OptionalArgument, name=name, **as_dict(only_mandatory)) + schema.add(OptionalArgument, **as_dict(both)) + schema.add(WithLink, required=version, **as_dict(pointer)) + schema.add(WithLength, alias="alias_name", **as_dict(length)) + schema.add(Mapping, **as_dict(mapping)) + schema.add(Mapping, name="my_list", **as_dict(list_)) + schema.add(Complex, **as_dict(first)) + schema.add(Complex, name=name, **as_dict(second)) other_file_source = Source(simple, file=filename) data_factory_source = Source(None, file=filename, data_factory=make_data) alias_source = Source(length, alias_for="default") @@ -47,6 +67,16 @@ def make_data(path): "optional_argument": {"default": Source(both), name: Source(only_mandatory)}, "with_link": {"default": Source(pointer, required=version)}, "with_length": {"default": Source(length), "alias_name": alias_source}, + "mapping": {"default": Source(mapping), "my_list": Source(list_)}, "complex": {"default": Source(first), name: Source(second)}, } return schema, reference + + +def as_dict(dataclass): + # shallow copy of dataclass to dictionary + return { + field.name: getattr(dataclass, field.name) + for field in dataclasses.fields(dataclass) + if getattr(dataclass, field.name) is not None + } diff --git a/tests/raw/test_access.py b/tests/raw/test_access.py index 10b48af8..0837e1b9 100644 --- a/tests/raw/test_access.py +++ b/tests/raw/test_access.py @@ -11,6 +11,7 @@ import py4vasp.raw as raw from py4vasp import exception from py4vasp._raw.definition import DEFAULT_FILE +from py4vasp._raw.schema import Mapping @pytest.fixture @@ -32,8 +33,9 @@ def mock_schema(complex_schema): _mock_results = {} -EXAMPLE_ARRAY = np.zeros(3) -EXAMPLE_SCALAR = np.array(1) +EXAMPLE_ARRAY = np.zeros(4) +EXAMPLE_SCALAR = np.array(3) +EXAMPLE_INDICES = np.array((b"one", b"two", b"three")) def mock_read_result(key): @@ -42,6 +44,8 @@ def mock_read_result(key): if "foo" in key: mock.ndim = 0 mock.__array__ = MagicMock(return_value=EXAMPLE_SCALAR) + elif "list" in 
key: + mock = EXAMPLE_INDICES else: mock.__array__ = MagicMock(return_value=EXAMPLE_ARRAY) _mock_results[key] = mock @@ -102,6 +106,29 @@ def test_access_with_link(mock_access): assert with_link.simple.bar[:] == reference.bar[:] +@pytest.mark.parametrize("selection", (None, "my_list")) +def test_access_mapping(mock_access, selection): + if selection is None: + expected_indices = range(EXAMPLE_SCALAR) + else: + expected_indices = tuple(index.decode() for index in EXAMPLE_INDICES) + quantity = "mapping" + mock_file, sources = mock_access + source = sources[quantity][selection or "default"] + with raw.access(quantity, selection=selection) as mapping: + assert len(mapping.valid_indices) == len(mapping) + assert all(np.atleast_1d(mapping.valid_indices == expected_indices)) + check_single_file_access(mock_file, DEFAULT_FILE, source) + for index, element in mapping.items(): + assert len(element) == 1 + assert element.valid_indices == [index] + check_data(element.common, source.data.common) + if selection is None: + index = str(index + 1) # convert Python to Fortran index + variable = source.data.variable.format(index) + check_data(element.variable, variable) + + def linked_quantity_reference(mock_access, file=None): quantity = "simple" mock_file, _ = mock_access @@ -250,7 +277,16 @@ def expected_calls(source): def expected_call(data, field): key = getattr(data, field.name) - if isinstance(key, str): + if not isinstance(key, str): + return + if not isinstance(data, Mapping): + distinct_keys = {key} + elif data.valid_indices == "list_mapping": + distinct_keys = {key.format(index.decode()) for index in EXAMPLE_INDICES} + else: + # convert to Fortran index + distinct_keys = {key.format(index + 1) for index in range(EXAMPLE_SCALAR)} + for key in distinct_keys: yield call(key) diff --git a/tests/raw/test_schema.py b/tests/raw/test_schema.py index 0449f1c1..f225113b 100644 --- a/tests/raw/test_schema.py +++ b/tests/raw/test_schema.py @@ -1,7 +1,7 @@ # Copyright © VASP Software 
GmbH, # Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0) import pytest -from util import VERSION, OptionalArgument, Simple, WithLength, WithLink +from util import VERSION, Mapping, OptionalArgument, Simple, WithLength, WithLink from py4vasp import exception, raw from py4vasp._raw.schema import Length, Link, Schema, Source @@ -107,6 +107,19 @@ def make_data(source): assert remove_version(schema.sources) == reference +def test_mapping(): + mapping = Mapping("valid_indices", "common_data", "variable_data{}") + schema = Schema(VERSION) + schema.add( + Mapping, + valid_indices=mapping.valid_indices, + common=mapping.common, + variable=mapping.variable, + ) + reference = {"mapping": {"default": Source(mapping)}} + assert remove_version(schema.sources) == reference + + def remove_version(sources): version = sources.pop("version") assert version == {"default": Source(VERSION)} @@ -159,14 +172,26 @@ def test_complex_str(complex_schema): num_data: length(dataset) alias_name: *with_length-default +mapping: + default: &mapping-default + valid_indices: foo_mapping + common: common_data + variable: variable_data{} + my_list: &mapping-my_list + valid_indices: list_mapping + common: common + variable: variable_data_{} + complex: default: &complex-default opt: *optional_argument-default link: *with_link-default + mapping: *mapping-default length: *with_length-default mandatory: &complex-mandatory opt: *optional_argument-mandatory link: *with_link-default + mapping: *mapping-my_list """ assert str(schema) == reference diff --git a/tests/raw/util.py b/tests/raw/util.py index 83391b85..b0bedf2c 100644 --- a/tests/raw/util.py +++ b/tests/raw/util.py @@ -3,6 +3,7 @@ import dataclasses from py4vasp import raw +from py4vasp._raw import schema VERSION = raw.Version("major_dataset", "minor_dataset", "patch_dataset") @@ -30,8 +31,15 @@ class WithLength: num_data: int +@dataclasses.dataclass +class Mapping(schema.Mapping): + common: str + variable: str + + 
@dataclasses.dataclass class Complex: opt: OptionalArgument link: WithLink + mapping: Mapping length: WithLength = None