Skip to content

Commit

Permalink
Feat: Allow for template sources in the HDF5 file (#192)
Browse files Browse the repository at this point in the history
* Implement Mapping to allow templates in the schema
* Implement access to multiple different datasets with the key of the mapping
  • Loading branch information
martin-schlipf authored Feb 17, 2025
1 parent 8c06840 commit 5377898
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 20 deletions.
37 changes: 29 additions & 8 deletions src/py4vasp/_raw/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from py4vasp import exception, raw
from py4vasp._raw.definition import DEFAULT_FILE, DEFAULT_SOURCE, schema
from py4vasp._raw.schema import Length, Link, error_message
from py4vasp._raw.schema import Length, Link, Mapping, error_message
from py4vasp._util import convert


@contextlib.contextmanager
Expand Down Expand Up @@ -120,23 +121,43 @@ def _check_version(self, h5f, required, quantity):
raise exception.OutdatedVaspVersion(message)

def _get_datasets(self, h5f, data):
return {
field.name: self._get_dataset(h5f, getattr(data, field.name))
valid_indices = self._get_valid_indices(h5f, data)
result = {
field.name: self._get_dataset(h5f, getattr(data, field.name), valid_indices)
for field in dataclasses.fields(data)
if field.name != "valid_indices"
}
if valid_indices is not None:
result["valid_indices"] = valid_indices
return result

def _get_valid_indices(self, h5f, data):
    """Determine the indices that select the elements of a Mapping.

    Returns None when `data` is not a Mapping. Otherwise the dataset named by
    `data.valid_indices` is read via `_get_dataset`: a zero-dimensional result
    is interpreted as a count and turned into `range(count)`, anything else is
    converted entry by entry with `convert.text_to_string` into a tuple.
    """
    if isinstance(data, Mapping):
        indices = self._get_dataset(h5f, data.valid_indices)
        if indices.ndim != 0:
            # one label per element
            return tuple(convert.text_to_string(label) for label in indices)
        # a scalar only stores how many elements exist
        return range(indices)
    return None

def _get_dataset(self, h5f, key):
def _get_dataset(self, h5f, key, valid_indices=None):
if key is None:
return raw.VaspData(None)
if isinstance(key, Link):
return self.access(key.quantity, source=key.source)
if isinstance(key, Length):
dataset = h5f.get(key.dataset)
return len(dataset) if dataset else None
return self._parse_dataset(h5f.get(key))

def _parse_dataset(self, dataset):
result = raw.VaspData(dataset)
if key.format(0) == key or valid_indices is None:
return self._parse_dataset(h5f, key)
return [self._parse_dataset(h5f, key, index) for index in valid_indices]

def _parse_dataset(self, h5f, key, index=None):
if index is not None:
if isinstance(index, int):
index = index + 1 # convert to Fortran index
key = key.format(index)
result = raw.VaspData(h5f.get(key))
if _is_scalar(result):
result = result[()]
return result
Expand Down
29 changes: 29 additions & 0 deletions src/py4vasp/_raw/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dataclasses
import textwrap
from collections import abc

import numpy as np

Expand Down Expand Up @@ -159,6 +160,34 @@ class Length:
__str__ = lambda self: f"length({self.dataset})"


@dataclasses.dataclass
class Mapping(abc.Mapping):
    """Base dataclass whose fields can be selected by an index.

    The entries of `valid_indices` act as the keys of the mapping. Looking up
    one of them returns a copy of the instance reduced to that single index.
    """

    # Keys of the mapping; needs to support len(), iteration and .index().
    valid_indices: Sequence

    def __len__(self):
        """Return the number of valid indices."""
        return len(self.valid_indices)

    def __iter__(self):
        """Iterate over the valid indices, i.e., the keys of the mapping."""
        return iter(self.valid_indices)

    def __getitem__(self, key):
        """Return a copy of this instance restricted to the index `key`.

        Every field that holds a list is replaced by the list element at the
        position of `key` within `valid_indices`; all other fields are kept
        unchanged. The `valid_indices` of the result is `[key]`.

        Raises
        ------
        KeyError
            If `key` is not contained in `valid_indices`.
        """
        try:
            position = self.valid_indices.index(key)
        except ValueError:
            # .index() reports a missing entry with ValueError, but the mixin
            # methods of collections.abc.Mapping (`in`, `get`) only interpret
            # KeyError as "key not present".
            raise KeyError(key) from None
        selected = {
            name: value[position] if isinstance(value, list) else value
            for name, value in self._as_dict().items()
            if name != "valid_indices"
        }
        return dataclasses.replace(self, valid_indices=[key], **selected)

    def _as_dict(self):
        """Return a shallow dict of all fields whose value is not None."""
        return {
            field.name: getattr(self, field.name)
            for field in dataclasses.fields(self)
            if getattr(self, field.name) is not None
        }


def _parse_version(version):
return f"""version:
major: {version.major}
Expand Down
46 changes: 38 additions & 8 deletions tests/raw/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import dataclasses

import pytest
from util import VERSION, Complex, OptionalArgument, Simple, WithLength, WithLink
from util import (
VERSION,
Complex,
Mapping,
OptionalArgument,
Simple,
WithLength,
WithLink,
)

from py4vasp import raw
from py4vasp._raw.schema import Length, Link, Schema, Source
Expand All @@ -20,24 +30,34 @@ def make_data(path):
pointer = WithLink("baz_dataset", Link("simple", "default"))
version = raw.Version(1, 2, 3)
length = WithLength(Length("dataset"))
mapping = Mapping(
valid_indices="foo_mapping", common="common_data", variable="variable_data{}"
)
list_ = Mapping(
valid_indices="list_mapping", common="common", variable="variable_data_{}"
)
first = Complex(
Link("optional_argument", "default"),
Link("with_link", "default"),
Link("mapping", "default"),
Link("with_length", "default"),
)
second = Complex(
Link("optional_argument", name),
Link("with_link", "default"),
Link("mapping", "my_list"),
)
schema = Schema(VERSION)
schema.add(Simple, file=filename, foo=simple.foo, bar=simple.bar)
schema.add(OptionalArgument, name=name, mandatory=only_mandatory.mandatory)
schema.add(OptionalArgument, mandatory=both.mandatory, optional=both.optional)
schema.add(WithLink, required=version, baz=pointer.baz, simple=pointer.simple)
schema.add(WithLength, alias="alias_name", num_data=length.num_data)
schema.add(Complex, opt=first.opt, link=first.link, length=first.length)
schema.add(Complex, name=name, opt=second.opt, link=second.link)
schema.add(Simple, file=filename, **as_dict(simple))
schema.add(Simple, name="factory", file=filename, data_factory=make_data)
schema.add(OptionalArgument, name=name, **as_dict(only_mandatory))
schema.add(OptionalArgument, **as_dict(both))
schema.add(WithLink, required=version, **as_dict(pointer))
schema.add(WithLength, alias="alias_name", **as_dict(length))
schema.add(Mapping, **as_dict(mapping))
schema.add(Mapping, name="my_list", **as_dict(list_))
schema.add(Complex, **as_dict(first))
schema.add(Complex, name=name, **as_dict(second))
other_file_source = Source(simple, file=filename)
data_factory_source = Source(None, file=filename, data_factory=make_data)
alias_source = Source(length, alias_for="default")
Expand All @@ -47,6 +67,16 @@ def make_data(path):
"optional_argument": {"default": Source(both), name: Source(only_mandatory)},
"with_link": {"default": Source(pointer, required=version)},
"with_length": {"default": Source(length), "alias_name": alias_source},
"mapping": {"default": Source(mapping), "my_list": Source(list_)},
"complex": {"default": Source(first), name: Source(second)},
}
return schema, reference


def as_dict(dataclass):
    """Shallow-copy a dataclass instance into a dict, skipping None fields."""
    result = {}
    for field in dataclasses.fields(dataclass):
        value = getattr(dataclass, field.name)
        if value is None:
            continue
        result[field.name] = value
    return result
42 changes: 39 additions & 3 deletions tests/raw/test_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import py4vasp.raw as raw
from py4vasp import exception
from py4vasp._raw.definition import DEFAULT_FILE
from py4vasp._raw.schema import Mapping


@pytest.fixture
Expand All @@ -32,8 +33,9 @@ def mock_schema(complex_schema):


_mock_results = {}
EXAMPLE_ARRAY = np.zeros(3)
EXAMPLE_SCALAR = np.array(1)
EXAMPLE_ARRAY = np.zeros(4)
EXAMPLE_SCALAR = np.array(3)
EXAMPLE_INDICES = np.array((b"one", b"two", b"three"))


def mock_read_result(key):
Expand All @@ -42,6 +44,8 @@ def mock_read_result(key):
if "foo" in key:
mock.ndim = 0
mock.__array__ = MagicMock(return_value=EXAMPLE_SCALAR)
elif "list" in key:
mock = EXAMPLE_INDICES
else:
mock.__array__ = MagicMock(return_value=EXAMPLE_ARRAY)
_mock_results[key] = mock
Expand Down Expand Up @@ -102,6 +106,29 @@ def test_access_with_link(mock_access):
assert with_link.simple.bar[:] == reference.bar[:]


# Access the "mapping" quantity once with the default source (selection=None)
# and once with the "my_list" source. For the default source the expected
# indices are range(EXAMPLE_SCALAR); for "my_list" they are the decoded entries
# of EXAMPLE_INDICES. Every element obtained from the mapping must have length
# one, carry its own index as valid_indices, share the common data, and expose
# the variable data read from the template formatted with that index.
@pytest.mark.parametrize("selection", (None, "my_list"))
def test_access_mapping(mock_access, selection):
if selection is None:
expected_indices = range(EXAMPLE_SCALAR)
else:
expected_indices = tuple(index.decode() for index in EXAMPLE_INDICES)
quantity = "mapping"
mock_file, sources = mock_access
source = sources[quantity][selection or "default"]
with raw.access(quantity, selection=selection) as mapping:
assert len(mapping.valid_indices) == len(mapping)
assert all(np.atleast_1d(mapping.valid_indices == expected_indices))
check_single_file_access(mock_file, DEFAULT_FILE, source)
for index, element in mapping.items():
assert len(element) == 1
assert element.valid_indices == [index]
check_data(element.common, source.data.common)
if selection is None:
index = str(index + 1) # convert Python to Fortran index
variable = source.data.variable.format(index)
check_data(element.variable, variable)


def linked_quantity_reference(mock_access, file=None):
quantity = "simple"
mock_file, _ = mock_access
Expand Down Expand Up @@ -250,7 +277,16 @@ def expected_calls(source):

def expected_call(data, field):
key = getattr(data, field.name)
if isinstance(key, str):
if not isinstance(key, str):
return
if not isinstance(data, Mapping):
distinct_keys = {key}
elif data.valid_indices == "list_mapping":
distinct_keys = {key.format(index.decode()) for index in EXAMPLE_INDICES}
else:
# convert to Fortran index
distinct_keys = {key.format(index + 1) for index in range(EXAMPLE_SCALAR)}
for key in distinct_keys:
yield call(key)


Expand Down
27 changes: 26 additions & 1 deletion tests/raw/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import pytest
from util import VERSION, OptionalArgument, Simple, WithLength, WithLink
from util import VERSION, Mapping, OptionalArgument, Simple, WithLength, WithLink

from py4vasp import exception, raw
from py4vasp._raw.schema import Length, Link, Schema, Source
Expand Down Expand Up @@ -107,6 +107,19 @@ def make_data(source):
assert remove_version(schema.sources) == reference


def test_mapping():
    """Adding a Mapping to the schema registers it as the default source."""
    template = Mapping(
        valid_indices="valid_indices",
        common="common_data",
        variable="variable_data{}",
    )
    schema = Schema(VERSION)
    schema.add(
        Mapping,
        valid_indices=template.valid_indices,
        common=template.common,
        variable=template.variable,
    )
    expected = {"mapping": {"default": Source(template)}}
    actual = remove_version(schema.sources)
    assert actual == expected


def remove_version(sources):
version = sources.pop("version")
assert version == {"default": Source(VERSION)}
Expand Down Expand Up @@ -159,14 +172,26 @@ def test_complex_str(complex_schema):
num_data: length(dataset)
alias_name: *with_length-default
mapping:
default: &mapping-default
valid_indices: foo_mapping
common: common_data
variable: variable_data{}
my_list: &mapping-my_list
valid_indices: list_mapping
common: common
variable: variable_data_{}
complex:
default: &complex-default
opt: *optional_argument-default
link: *with_link-default
mapping: *mapping-default
length: *with_length-default
mandatory: &complex-mandatory
opt: *optional_argument-mandatory
link: *with_link-default
mapping: *mapping-my_list
"""
assert str(schema) == reference

Expand Down
8 changes: 8 additions & 0 deletions tests/raw/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses

from py4vasp import raw
from py4vasp._raw import schema

VERSION = raw.Version("major_dataset", "minor_dataset", "patch_dataset")

Expand Down Expand Up @@ -30,8 +31,15 @@ class WithLength:
num_data: int


# Test dataclass built on schema.Mapping: adds two string fields next to the
# valid_indices field inherited from the base class.
@dataclasses.dataclass
class Mapping(schema.Mapping):
# dataset name that does not depend on the index
common: str
# dataset name; the test fixtures pass a "{}" template here
variable: str


# Test dataclass that combines the other test dataclasses as fields.
@dataclasses.dataclass
class Complex:
opt: OptionalArgument
link: WithLink
mapping: Mapping
# optional field, defaults to None when not provided
length: WithLength = None

0 comments on commit 5377898

Please sign in to comment.