From 4b3c5398823ae49bdacd1521549a0782de0ca59f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 26 Sep 2024 14:06:49 +0200 Subject: [PATCH 01/32] Preparations for new transformations module --- src/scippnexus/nxtransformations.py | 39 ++++---- src/scippnexus/transformations.py | 138 ++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 23 deletions(-) create mode 100644 src/scippnexus/transformations.py diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index c8382cbf..f3b575ab 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -12,10 +12,7 @@ from .base import Group, NexusStructureError, NXobject, ScippIndex from .field import Field, depends_on_to_relative_path - - -class TransformationError(NexusStructureError): - pass +from .transformations import Transform, TransformationError class NXtransformations(NXobject): @@ -26,18 +23,6 @@ class Transformation: def __init__(self, obj: Field | Group): # could be an NXlog self._obj = obj - @property - def sizes(self) -> dict: - return self._obj.sizes - - @property - def dims(self) -> tuple[str, ...]: - return self._obj.dims - - @property - def shape(self) -> tuple[int, ...]: - return self._obj.shape - @property def attrs(self): return self._obj.attrs @@ -219,6 +204,7 @@ def maybe_transformation( """ if (transformation_type := obj.attrs.get('transformation_type')) is None: return value + # return Transform(obj, value) transform = Transformation(obj).make_transformation( value, transformation_type=transformation_type, select=sel ) @@ -278,8 +264,6 @@ class follows the paths and resolves the chain of transformations. 
class ChainError(KeyError): """Raised when a transformation chain cannot be resolved.""" - pass - @dataclass class Entry: name: str @@ -337,7 +321,9 @@ def __getitem__(self, path: str) -> TransformationChainResolver: ) return node if len(remainder) == 0 else node[remainder[0]] - def resolve_depends_on(self) -> sc.DataArray | sc.Variable | None: + def resolve_depends_on( + self, depends_on: str | None = None + ) -> sc.DataArray | sc.Variable | None: """ Resolve the depends_on attribute of a transformation chain. @@ -348,7 +334,7 @@ def resolve_depends_on(self) -> sc.DataArray | sc.Variable | None: """ if 'resolved_depends_on' in self.value: depends_on = self.value['resolved_depends_on'] - else: + elif depends_on is None: depends_on = self.value.get('depends_on') if depends_on is None: return None @@ -389,6 +375,7 @@ def compute_positions( *, store_position: str = 'position', store_transform: str | None = None, + transformations: sc.DataGroup | None = None, ) -> sc.DataGroup: """ Recursively compute positions from depends_on attributes as well as the @@ -422,6 +409,9 @@ def compute_positions( Name used to store result of resolving each depends_on chain. store_transform: If not None, store the resolved transformation chain in this field. + transformations: + Optional data group containing transformation chains. If not provided, the + transformations are looked up in the input data group. Returns ------- @@ -430,7 +420,9 @@ def compute_positions( """ # Create resolver at root level, since any depends_on chain may lead to a parent, # i.e., we cannot use a resolver at the level of each chain's entry point. - resolver = TransformationChainResolver.from_root(dg) + # TODO need to be able to set root, would be better to construct resolver outside, + # see we can navigate to correct path? 
+ resolver = TransformationChainResolver.from_root(transformations or dg) return _with_positions( dg, store_position=store_position, @@ -480,7 +472,7 @@ def _with_positions( transform = None if 'depends_on' in dg: try: - transform = resolver.resolve_depends_on() + transform = resolver.resolve_depends_on(dg['depends_on']) except TransformationChainResolver.ChainError as e: warnings.warn( UserWarning(f'depends_on chain references missing node:\n{e}'), @@ -491,7 +483,8 @@ def _with_positions( if store_transform is not None: out[store_transform] = transform for name, value in dg.items(): - if isinstance(value, sc.DataGroup): + # Do not descend into groups that are not in the resolver. + if isinstance(value, sc.DataGroup) and name in resolver.value: value = _with_positions( value, store_position=store_position, diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py new file mode 100644 index 00000000..934d1c24 --- /dev/null +++ b/src/scippnexus/transformations.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# @author Simon Heybrock +from __future__ import annotations + +import warnings +from dataclasses import dataclass + +import numpy as np +import scipp as sc +from scipp.scipy import interpolate + +from .base import Group, NexusStructureError, NXobject, ScippIndex +from .field import Field, depends_on_to_relative_path +from .file import File +import h5py +from typing import Any + + +class TransformationError(NexusStructureError): + pass + + +# Plan: +# - helper to find an load al nxtransformations +# - do not auto-conv to scipp transform, add special classes for trans and rot +# - consider skipping trans load as part of group load in favor of separate load? +# or just remove code that follows chains? +# - avoid storing depends_on as weird coord! 
use dedicated data structure +# - have raw Transform (with vector and offset) +# - translate into scipp transform when building transform + +# TODO +# - remove resolved_depends_on mechanism +# - remove Transformation + + +def _parse_offset(obj: Field | Group) -> sc.Variable | None: + if (offset := obj.attrs.get('offset')) is None: + return None + if (offset_units := obj.attrs.get('offset_units')) is None: + raise TransformationError( + f"Found {offset=} but no corresponding 'offset_units' " + f"attribute at {obj.name}" + ) + return sc.spatial.translation(value=offset, unit=offset_units) + + +def _parse_value( + obj: Field | Group, + value: sc.Variable | sc.DataArray | sc.DataGroup, +) -> sc.Variable | sc.DataArray: + if isinstance(value, sc.DataGroup) and ( + isinstance(value.get('value'), sc.DataArray) + ): + # Some NXlog groups are split into value, alarm, and connection_status + # sublogs. We only care about the value. + value = value['value'] + if not isinstance(value, sc.Variable | sc.DataArray): + raise TransformationError(f"Failed to load transformation value at {obj.name}") + return value + + +class Transform: + def __init__( + self, obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup + ): + self.offset = _parse_offset(obj) + self.vector = sc.vector(value=obj.attrs.get('vector')) + # TODO This is annoying... what if we keep it, load transform independently, + # and index before returning? + # TODO Change NXobject.__getitem__ to never descend into NXtransformations? + self.depends_on = depends_on_to_relative_path( + obj.attrs.get('depends_on'), obj.parent.name + ) + self.transformation_type = obj.attrs.get('transformation_type') + if self.transformation_type not in ['translation', 'rotation']: + raise TransformationError( + f"{self.transformation_type=} attribute at {obj.name}," + " expected 'translation' or 'rotation'." 
+ ) + self.value = _parse_value(obj, value) + + # TODO can cache this + def build(self) -> sc.Variable | sc.DataArray: + try: + t = self.value * self.vector + v = t if isinstance(t, sc.Variable) else t.data + if self.transformation_type == 'translation': + v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) + elif self.transformation_type == 'rotation': + v = sc.spatial.rotations_from_rotvecs(v) + if isinstance(t, sc.Variable): + t = v + else: + t.data = v + if self.offset is None: + return t + if self.transformation_type == 'translation': + return t * self.offset.to(unit=t.unit, copy=False) + return t * self.offset + except (sc.DimensionError, sc.UnitError, TransformationError): + # TODO We should probably try to return some other data structure and + # also insert offset and other attributes. + return self.value + + +def find_transformation_groups(filename: str) -> list[str]: + transforms: list[str] = [] + + def _collect_transforms(name: str, group: h5py.Group) -> None: + if group.attrs.get('NX_class') == 'NXtransformations': + transforms.append(name) + + with h5py.File(filename, 'r') as f: + # TODO This is slow! No need to visit everything, just recurse groups? 
+ f.visititems(_collect_transforms) + return transforms + + +def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: + if '/' not in path: + dg[path] = value + else: + first, remainder = path.split('/', maxsplit=1) + if first not in dg: + dg[first] = sc.DataGroup() + _set_recursive(dg[first], remainder, value) + + +def load_transformations(filename: str) -> sc.DataGroup: + groups = find_transformation_groups(filename) + with File(filename, mode='r') as f: + transforms = sc.DataGroup({group: f[group][()] for group in groups}) + dg = sc.DataGroup() + for path, value in transforms.items(): + _set_recursive(dg, path, value) + return dg From 433e17eb5b9c1e99ba0dc6c537f5e9636fd025eb Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 08:39:28 +0200 Subject: [PATCH 02/32] Make maybe_transformation injectable --- src/scippnexus/base.py | 29 ++++++++++++++++++++++------- src/scippnexus/field.py | 7 ++++++- src/scippnexus/file.py | 9 +++++++-- src/scippnexus/nxtransformations.py | 5 ++--- src/scippnexus/typing.py | 12 ++++++++++++ 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index 717dfe1f..8a2c47f0 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -19,7 +19,7 @@ from ._common import to_child_select from .attrs import Attrs from .field import Field -from .typing import H5Dataset, H5Group, ScippIndex +from .typing import H5Dataset, H5Group, MaybeTransformation, ScippIndex def asvariable(obj: Any | sc.Variable) -> sc.Variable: @@ -29,8 +29,6 @@ def asvariable(obj: Any | sc.Variable) -> sc.Variable: class NexusStructureError(Exception): """Invalid or unsupported class and field structure in Nexus.""" - pass - def is_dataset(obj: H5Group | H5Dataset) -> bool: """Return true if the object is an h5py.Dataset or equivalent. @@ -202,11 +200,17 @@ class Group(Mapping): # interpretation of the file, but need to cache information. 
An earlier version # of ScippNexus used such a mechanism without caching, which was very slow. - def __init__(self, group: H5Group, definitions: dict[str, type] | None = None): + def __init__( + self, + group: H5Group, + definitions: dict[str, type] | None = None, + maybe_transformation: MaybeTransformation = None, + ): self._group = group self._definitions = {} if definitions is None else definitions self._lazy_children = None self._lazy_nexus = None + self._maybe_transformation = maybe_transformation @property def nx_class(self) -> type | None: @@ -258,9 +262,15 @@ def _children(self) -> dict[str, Field | Group]: def _read_children(self) -> dict[str, Field | Group]: def _make_child(obj: H5Dataset | H5Group) -> Field | Group: if is_dataset(obj): - return Field(obj, parent=self) + return Field( + obj, parent=self, _maybe_transformation=self._maybe_transformation + ) else: - return Group(obj, definitions=self._definitions) + return Group( + obj, + definitions=self._definitions, + maybe_transformation=self._maybe_transformation, + ) items = {name: _make_child(obj) for name, obj in self._group.items()} # In the case of NXevent_data, the `cue_` fields are unusable, since @@ -409,7 +419,12 @@ def isclass(x): # For a time-dependent transformation in NXtransformations, an NXlog may # take the place of the `value` field. In this case, we need to read the # properties of the NXlog group to make the actual transformation. 
- from .nxtransformations import maybe_resolve, maybe_transformation + from .nxtransformations import maybe_resolve + + if self._maybe_transformation is not None: + maybe_transformation = self._maybe_transformation + else: + from .nxtransformations import maybe_transformation if ( isinstance(dg, sc.DataGroup) diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index 601ff475..d349ffa4 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -19,6 +19,7 @@ from ._cache import cached_property from .attrs import Attrs +from .typing import MaybeTransformation if TYPE_CHECKING: from .base import Group @@ -84,6 +85,7 @@ class Field: sizes: dict[str, int] | None = None dtype: sc.DType | None = None errors: H5Dataset | None = None + _maybe_transformation: MaybeTransformation = None @cached_property def attrs(self) -> dict[str, Any]: @@ -135,7 +137,10 @@ def __getitem__(self, select: ScippIndex) -> Any | sc.Variable: : Loaded data. """ - from .nxtransformations import maybe_transformation + if self._maybe_transformation is not None: + maybe_transformation = self._maybe_transformation + else: + from .nxtransformations import maybe_transformation index = to_plain_index(self.dims, select) if isinstance(index, int | slice): diff --git a/src/scippnexus/file.py b/src/scippnexus/file.py index 939cec50..fa5d2751 100644 --- a/src/scippnexus/file.py +++ b/src/scippnexus/file.py @@ -13,7 +13,7 @@ Group, base_definitions, ) -from .typing import Definitions +from .typing import Definitions, MaybeTransformation class File(AbstractContextManager, Group): @@ -22,6 +22,7 @@ def __init__( name: str | os.PathLike[str] | io.BytesIO | h5py.Group, *args, definitions: Definitions | DefaultDefinitionsType = DefaultDefinitions, + maybe_transformation: MaybeTransformation = None, **kwargs, ): """Context manager for NeXus files, similar to h5py.File. 
@@ -50,7 +51,11 @@ def __init__( else: self._file = h5py.File(name, *args, **kwargs) self._manage_file_context = True - super().__init__(self._file, definitions=definitions) + super().__init__( + self._file, + definitions=definitions, + maybe_transformation=maybe_transformation, + ) def __enter__(self): if self._manage_file_context: diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index f3b575ab..ae0706a5 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -10,9 +10,9 @@ import scipp as sc from scipp.scipy import interpolate -from .base import Group, NexusStructureError, NXobject, ScippIndex +from .base import Group, NXobject, ScippIndex from .field import Field, depends_on_to_relative_path -from .transformations import Transform, TransformationError +from .transformations import TransformationError class NXtransformations(NXobject): @@ -204,7 +204,6 @@ def maybe_transformation( """ if (transformation_type := obj.attrs.get('transformation_type')) is None: return value - # return Transform(obj, value) transform = Transformation(obj).make_transformation( value, transformation_type=transformation_type, select=sel ) diff --git a/src/scippnexus/typing.py b/src/scippnexus/typing.py index 339071c7..c51f3661 100644 --- a/src/scippnexus/typing.py +++ b/src/scippnexus/typing.py @@ -6,6 +6,8 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Protocol +import scipp as sc + class H5Base(Protocol): @property @@ -66,6 +68,8 @@ def visititems(self, func: Callable) -> None: class ellipsis(Enum): Ellipsis = "..." 
+ from .base import Field, Group + else: ellipsis = type(Ellipsis) @@ -76,3 +80,11 @@ class ellipsis(Enum): ) Definitions = Mapping[str, type] + +MaybeTransformation = ( + Callable[ + ['Field | Group', sc.Variable | sc.DataArray | sc.DataGroup, ScippIndex], + sc.Variable | sc.DataArray | sc.DataGroup, + ] + | None +) From 135c28b78da68b923212ff1883eba953fb6e1ce2 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 10:43:04 +0200 Subject: [PATCH 03/32] Fix new and old functionality --- src/scippnexus/nxtransformations.py | 27 +++++++++++++++++---------- src/scippnexus/transformations.py | 27 +++++++++++++++++---------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index ae0706a5..a628077e 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -12,7 +12,7 @@ from .base import Group, NXobject, ScippIndex from .field import Field, depends_on_to_relative_path -from .transformations import TransformationError +from .transformations import Transform, TransformationError class NXtransformations(NXobject): @@ -320,9 +320,7 @@ def __getitem__(self, path: str) -> TransformationChainResolver: ) return node if len(remainder) == 0 else node[remainder[0]] - def resolve_depends_on( - self, depends_on: str | None = None - ) -> sc.DataArray | sc.Variable | None: + def resolve_depends_on(self) -> sc.DataArray | sc.Variable | None: """ Resolve the depends_on attribute of a transformation chain. 
@@ -333,7 +331,7 @@ def resolve_depends_on( """ if 'resolved_depends_on' in self.value: depends_on = self.value['resolved_depends_on'] - elif depends_on is None: + else: depends_on = self.value.get('depends_on') if depends_on is None: return None @@ -347,18 +345,25 @@ def get_chain( ) -> list[sc.DataArray | sc.Variable]: if depends_on == '.': return [] + new_style_transform = False if isinstance(depends_on, str): node = self[depends_on] - transform = node.value.copy(deep=False) + if isinstance(node.value, Transform): + transform = node.value.build().copy(deep=False) + depends_on = node.value.depends_on + new_style_transform = True + else: + transform = node.value.copy(deep=False) + depends_on = '.' node = node.parent else: # Fake node, resolved_depends_on is recursive so this is actually ignored. node = self transform = depends_on - depends_on = '.' + depends_on = '.' if transform.dtype in (sc.DType.translation3, sc.DType.affine_transform3): transform = transform.to(unit='m', copy=False) - if isinstance(transform, sc.DataArray): + if not new_style_transform and isinstance(transform, sc.DataArray): if (attr := transform.coords.pop('resolved_depends_on', None)) is not None: depends_on = attr.value elif (attr := transform.coords.pop('depends_on', None)) is not None: @@ -471,7 +476,7 @@ def _with_positions( transform = None if 'depends_on' in dg: try: - transform = resolver.resolve_depends_on(dg['depends_on']) + transform = resolver.resolve_depends_on() except TransformationChainResolver.ChainError as e: warnings.warn( UserWarning(f'depends_on chain references missing node:\n{e}'), @@ -482,7 +487,9 @@ def _with_positions( if store_transform is not None: out[store_transform] = transform for name, value in dg.items(): - # Do not descend into groups that are not in the resolver. + # If the resolver was constructed from an external tree of transformations it + # will not contain groups that do not contain any transformations or depends_on + # field. 
Do not descend into such groups. if isinstance(value, sc.DataGroup) and name in resolver.value: value = _with_positions( value, diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 934d1c24..4752154e 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -3,18 +3,15 @@ # @author Simon Heybrock from __future__ import annotations -import warnings -from dataclasses import dataclass +from typing import Any -import numpy as np +import h5py import scipp as sc -from scipp.scipy import interpolate -from .base import Group, NexusStructureError, NXobject, ScippIndex +from .base import Group, NexusStructureError, ScippIndex from .field import Field, depends_on_to_relative_path from .file import File -import h5py -from typing import Any +from .typing import H5Base class TransformationError(NexusStructureError): @@ -108,8 +105,8 @@ def build(self) -> sc.Variable | sc.DataArray: def find_transformation_groups(filename: str) -> list[str]: transforms: list[str] = [] - def _collect_transforms(name: str, group: h5py.Group) -> None: - if group.attrs.get('NX_class') == 'NXtransformations': + def _collect_transforms(name: str, obj: H5Base) -> None: + if name.endswith('/depends_on') or 'transformation_type' in obj.attrs: transforms.append(name) with h5py.File(filename, 'r') as f: @@ -128,9 +125,19 @@ def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: _set_recursive(dg[first], remainder, value) +def _maybe_transformation( + obj: Field | Group, + value: sc.Variable | sc.DataArray | sc.DataGroup, + sel: ScippIndex, +) -> sc.Variable | sc.DataArray | sc.DataGroup: + if obj.attrs.get('transformation_type') is None: + return value + return Transform(obj, value) + + def load_transformations(filename: str) -> sc.DataGroup: groups = find_transformation_groups(filename) - with File(filename, mode='r') as f: + with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: transforms = 
sc.DataGroup({group: f[group][()] for group in groups}) dg = sc.DataGroup() for path, value in transforms.items(): From 336c685d742665aa5c9f726ad8824589d91d908f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 10:44:32 +0200 Subject: [PATCH 04/32] Add new tests --- src/scippnexus/transformations.py | 4 -- tests/transformations_test.py | 77 +++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) create mode 100644 tests/transformations_test.py diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 4752154e..9603ce50 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -27,10 +27,6 @@ class TransformationError(NexusStructureError): # - have raw Transform (with vector and offset) # - translate into scipp transform when building transform -# TODO -# - remove resolved_depends_on mechanism -# - remove Transformation - def _parse_offset(obj: Field | Group) -> sc.Variable | None: if (offset := obj.attrs.get('offset')) is None: diff --git a/tests/transformations_test.py b/tests/transformations_test.py new file mode 100644 index 00000000..9419d041 --- /dev/null +++ b/tests/transformations_test.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +import pytest +from ess.reduce import data +from scipp.testing import assert_identical + +import scippnexus as snx +from scippnexus import transformations + + +def test_find_transformation_groups_finds_expected_groups() -> None: + filename = data.loki_tutorial_sample_run_60250() + paths = transformations.find_transformation_groups(filename) + assert paths == [ + 'entry/instrument/larmor_detector/depends_on', + 'entry/instrument/larmor_detector/transformations/trans_1', + 'entry/instrument/monitor_1/depends_on', + 'entry/instrument/monitor_1/transformations/trans_3', + 'entry/instrument/monitor_2/depends_on', + 
'entry/instrument/monitor_2/transformations/trans_4', + 'entry/instrument/source/depends_on', + 'entry/instrument/source/transformations/trans_2', + ] + + +def test_load_transformations_loads_as_flat_datagroup() -> None: + filename = data.loki_tutorial_sample_run_60250() + dg = transformations.load_transformations(filename) + assert list(dg) == ['entry'] + entry = dg['entry'] + assert list(entry) == ['instrument'] + instrument = entry['instrument'] + assert list(instrument) == ['larmor_detector', 'monitor_1', 'monitor_2', 'source'] + for group in instrument.values(): + assert list(group) == ['depends_on', 'transformations'] + + +def test_find_transformations_bifrost() -> None: + filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' + transformations.find_transformation_groups(filename) + + +def test_load_transformations_bifrost() -> None: + filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' + transformations.load_transformations(filename) + + +def test_scippnexus_can_parse_transformation_chain() -> None: + filename = data.loki_tutorial_sample_run_60250() + transforms = transformations.load_transformations(filename) + dg = snx.load(filename) + result = snx.compute_positions( + dg, + store_position='position', + store_transform='transform', + transformations=transforms, + ) + detector = result['entry']['instrument']['larmor_detector'] + assert 'position' in detector['larmor_detector_events'].coords + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_positions_consistent_with_separate_load() -> None: + filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' + filename = '/home/simon/instruments/bifrost/268227_00021671.hdf' + transforms = transformations.load_transformations(filename) + dg = snx.load(filename) + expected = snx.compute_positions( + dg, store_position='position', store_transform='transform' + ) + result = snx.compute_positions( + dg, + store_position='position', + 
store_transform='transform', + transformations=transforms, + ) + assert_identical(result, expected) From 2ff9723ee7075b1666b8774c5177217299146462 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:12:23 +0200 Subject: [PATCH 05/32] Remove try-except --- src/scippnexus/transformations.py | 35 +++++++++++++------------------ 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 9603ce50..75a211d5 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -76,26 +76,21 @@ def __init__( # TODO can cache this def build(self) -> sc.Variable | sc.DataArray: - try: - t = self.value * self.vector - v = t if isinstance(t, sc.Variable) else t.data - if self.transformation_type == 'translation': - v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) - elif self.transformation_type == 'rotation': - v = sc.spatial.rotations_from_rotvecs(v) - if isinstance(t, sc.Variable): - t = v - else: - t.data = v - if self.offset is None: - return t - if self.transformation_type == 'translation': - return t * self.offset.to(unit=t.unit, copy=False) - return t * self.offset - except (sc.DimensionError, sc.UnitError, TransformationError): - # TODO We should probably try to return some other data structure and - # also insert offset and other attributes. 
- return self.value + t = self.value * self.vector + v = t if isinstance(t, sc.Variable) else t.data + if self.transformation_type == 'translation': + v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) + elif self.transformation_type == 'rotation': + v = sc.spatial.rotations_from_rotvecs(v) + if isinstance(t, sc.Variable): + t = v + else: + t.data = v + if self.offset is None: + return t + if self.transformation_type == 'translation': + return t * self.offset.to(unit=t.unit, copy=False) + return t * self.offset def find_transformation_groups(filename: str) -> list[str]: From e3274d0c4c1f8994806eb7c464534948fac10bd7 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:15:14 +0200 Subject: [PATCH 06/32] Rename --- src/scippnexus/transformations.py | 5 ++--- tests/transformations_test.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 75a211d5..4bc5b593 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -93,7 +93,7 @@ def build(self) -> sc.Variable | sc.DataArray: return t * self.offset -def find_transformation_groups(filename: str) -> list[str]: +def find_transformations(filename: str) -> list[str]: transforms: list[str] = [] def _collect_transforms(name: str, obj: H5Base) -> None: @@ -101,7 +101,6 @@ def _collect_transforms(name: str, obj: H5Base) -> None: transforms.append(name) with h5py.File(filename, 'r') as f: - # TODO This is slow! No need to visit everything, just recurse groups? 
f.visititems(_collect_transforms) return transforms @@ -127,7 +126,7 @@ def _maybe_transformation( def load_transformations(filename: str) -> sc.DataGroup: - groups = find_transformation_groups(filename) + groups = find_transformations(filename) with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: transforms = sc.DataGroup({group: f[group][()] for group in groups}) dg = sc.DataGroup() diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 9419d041..266efd79 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -10,7 +10,7 @@ def test_find_transformation_groups_finds_expected_groups() -> None: filename = data.loki_tutorial_sample_run_60250() - paths = transformations.find_transformation_groups(filename) + paths = transformations.find_transformations(filename) assert paths == [ 'entry/instrument/larmor_detector/depends_on', 'entry/instrument/larmor_detector/transformations/trans_1', @@ -37,7 +37,7 @@ def test_load_transformations_loads_as_flat_datagroup() -> None: def test_find_transformations_bifrost() -> None: filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' - transformations.find_transformation_groups(filename) + transformations.find_transformations(filename) def test_load_transformations_bifrost() -> None: From 38988501304efc7d47be329f27d21ae0151d1fe2 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:35:33 +0200 Subject: [PATCH 07/32] Make tests runnable without extra data --- src/scippnexus/transformations.py | 12 +++++++-- tests/transformations_test.py | 42 ++++++++----------------------- 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 4bc5b593..71521e3f 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -3,6 +3,7 @@ # @author Simon Heybrock from __future__ import annotations +import warnings from typing 
import Any import h5py @@ -64,7 +65,7 @@ def __init__( # and index before returning? # TODO Change NXobject.__getitem__ to never descend into NXtransformations? self.depends_on = depends_on_to_relative_path( - obj.attrs.get('depends_on'), obj.parent.name + obj.attrs['depends_on'], obj.parent.name ) self.transformation_type = obj.attrs.get('transformation_type') if self.transformation_type not in ['translation', 'rotation']: @@ -122,7 +123,14 @@ def _maybe_transformation( ) -> sc.Variable | sc.DataArray | sc.DataGroup: if obj.attrs.get('transformation_type') is None: return value - return Transform(obj, value) + try: + return Transform(obj, value) + except KeyError as e: + warnings.warn( + UserWarning(f'Invalid transformation, missing attribute {e}'), + stacklevel=2, + ) + return value def load_transformations(filename: str) -> sc.DataGroup: diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 266efd79..67e9fbb3 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -1,15 +1,17 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) import pytest -from ess.reduce import data -from scipp.testing import assert_identical +import scipp as sc import scippnexus as snx from scippnexus import transformations +externalfile = pytest.importorskip('externalfile') + +@pytest.mark.externalfile() def test_find_transformation_groups_finds_expected_groups() -> None: - filename = data.loki_tutorial_sample_run_60250() + filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') paths = transformations.find_transformations(filename) assert paths == [ 'entry/instrument/larmor_detector/depends_on', @@ -23,8 +25,9 @@ def test_find_transformation_groups_finds_expected_groups() -> None: ] +@pytest.mark.externalfile() def test_load_transformations_loads_as_flat_datagroup() -> None: - filename = data.loki_tutorial_sample_run_60250() + filename = 
externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') dg = transformations.load_transformations(filename) assert list(dg) == ['entry'] entry = dg['entry'] @@ -35,34 +38,11 @@ def test_load_transformations_loads_as_flat_datagroup() -> None: assert list(group) == ['depends_on', 'transformations'] -def test_find_transformations_bifrost() -> None: - filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' - transformations.find_transformations(filename) - - -def test_load_transformations_bifrost() -> None: - filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' - transformations.load_transformations(filename) - - -def test_scippnexus_can_parse_transformation_chain() -> None: - filename = data.loki_tutorial_sample_run_60250() - transforms = transformations.load_transformations(filename) - dg = snx.load(filename) - result = snx.compute_positions( - dg, - store_position='position', - store_transform='transform', - transformations=transforms, - ) - detector = result['entry']['instrument']['larmor_detector'] - assert 'position' in detector['larmor_detector_events'].coords - - @pytest.mark.filterwarnings("ignore::UserWarning") +@pytest.mark.externalfile() def test_positions_consistent_with_separate_load() -> None: - filename = '/home/simon/instruments/bifrost/BIFROST_20240905T122604.h5' - filename = '/home/simon/instruments/bifrost/268227_00021671.hdf' + # The Bifrost instrument has complex transformation chains so this is a good test. 
+ filename = externalfile.get_path('2023/BIFROST_873855_00000015.hdf') transforms = transformations.load_transformations(filename) dg = snx.load(filename) expected = snx.compute_positions( @@ -74,4 +54,4 @@ def test_positions_consistent_with_separate_load() -> None: store_transform='transform', transformations=transforms, ) - assert_identical(result, expected) + assert sc.identical(result, expected) From cafcc0a35079aca64e5c45060abb84493f89a423 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:36:59 +0200 Subject: [PATCH 08/32] whitespace --- src/scippnexus/transformations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 71521e3f..2728a82d 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -127,8 +127,7 @@ def _maybe_transformation( return Transform(obj, value) except KeyError as e: warnings.warn( - UserWarning(f'Invalid transformation, missing attribute {e}'), - stacklevel=2, + UserWarning(f'Invalid transformation, missing attribute {e}'), stacklevel=2 ) return value From 05da5b494ed71854248e281665285a705335f573 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:41:11 +0200 Subject: [PATCH 09/32] Do not auto-conv to nested layout --- src/scippnexus/transformations.py | 35 +++++++++++++++++-------------- tests/transformations_test.py | 2 ++ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 2728a82d..ea149cf9 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -106,14 +106,17 @@ def _collect_transforms(name: str, obj: H5Base) -> None: return transforms -def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: - if '/' not in path: - dg[path] = value - else: - first, remainder = path.split('/', maxsplit=1) - if first not in dg: - dg[first] = 
sc.DataGroup() - _set_recursive(dg[first], remainder, value) +def load_transformations(filename: str) -> sc.DataGroup: + groups = find_transformations(filename) + with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: + return sc.DataGroup({group: f[group][()] for group in groups}) + + +def as_nested(dg: sc.DataGroup) -> sc.DataGroup: + out = sc.DataGroup() + for path, value in dg.items(): + _set_recursive(out, path, value) + return out def _maybe_transformation( @@ -132,11 +135,11 @@ def _maybe_transformation( return value -def load_transformations(filename: str) -> sc.DataGroup: - groups = find_transformations(filename) - with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: - transforms = sc.DataGroup({group: f[group][()] for group in groups}) - dg = sc.DataGroup() - for path, value in transforms.items(): - _set_recursive(dg, path, value) - return dg +def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: + if '/' not in path: + dg[path] = value + else: + first, remainder = path.split('/', maxsplit=1) + if first not in dg: + dg[first] = sc.DataGroup() + _set_recursive(dg[first], remainder, value) diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 67e9fbb3..3d09d669 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -29,6 +29,7 @@ def test_find_transformation_groups_finds_expected_groups() -> None: def test_load_transformations_loads_as_flat_datagroup() -> None: filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') dg = transformations.load_transformations(filename) + dg = transformations.as_nested(dg) assert list(dg) == ['entry'] entry = dg['entry'] assert list(entry) == ['instrument'] @@ -44,6 +45,7 @@ def test_positions_consistent_with_separate_load() -> None: # The Bifrost instrument has complex transformation chains so this is a good test. 
filename = externalfile.get_path('2023/BIFROST_873855_00000015.hdf') transforms = transformations.load_transformations(filename) + transforms = transformations.as_nested(transforms) dg = snx.load(filename) expected = snx.compute_positions( dg, store_position='position', store_transform='transform' From da613cb9e300d9549eefc876f0bbd115b508151f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 11:49:37 +0200 Subject: [PATCH 10/32] Some cleanup and docs --- src/scippnexus/transformations.py | 95 ++++++++++++++++++------------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index ea149cf9..8ca6598b 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -19,51 +19,12 @@ class TransformationError(NexusStructureError): pass -# Plan: -# - helper to find an load al nxtransformations -# - do not auto-conv to scipp transform, add special classes for trans and rot -# - consider skipping trans load as part of group load in favor of separate load? -# or just remove code that follows chains? -# - avoid storing depends_on as weird coord! 
use dedicated data structure -# - have raw Transform (with vector and offset) -# - translate into scipp transform when building transform - - -def _parse_offset(obj: Field | Group) -> sc.Variable | None: - if (offset := obj.attrs.get('offset')) is None: - return None - if (offset_units := obj.attrs.get('offset_units')) is None: - raise TransformationError( - f"Found {offset=} but no corresponding 'offset_units' " - f"attribute at {obj.name}" - ) - return sc.spatial.translation(value=offset, unit=offset_units) - - -def _parse_value( - obj: Field | Group, - value: sc.Variable | sc.DataArray | sc.DataGroup, -) -> sc.Variable | sc.DataArray: - if isinstance(value, sc.DataGroup) and ( - isinstance(value.get('value'), sc.DataArray) - ): - # Some NXlog groups are split into value, alarm, and connection_status - # sublogs. We only care about the value. - value = value['value'] - if not isinstance(value, sc.Variable | sc.DataArray): - raise TransformationError(f"Failed to load transformation value at {obj.name}") - return value - - class Transform: def __init__( self, obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup ): self.offset = _parse_offset(obj) self.vector = sc.vector(value=obj.attrs.get('vector')) - # TODO This is annoying... what if we keep it, load transform independently, - # and index before returning? - # TODO Change NXobject.__getitem__ to never descend into NXtransformations? self.depends_on = depends_on_to_relative_path( obj.attrs['depends_on'], obj.parent.name ) @@ -75,7 +36,6 @@ def __init__( ) self.value = _parse_value(obj, value) - # TODO can cache this def build(self) -> sc.Variable | sc.DataArray: t = self.value * self.vector v = t if isinstance(t, sc.Variable) else t.data @@ -107,12 +67,41 @@ def _collect_transforms(name: str, obj: H5Base) -> None: def load_transformations(filename: str) -> sc.DataGroup: + """ + Load transformations and depends_on fields from a NeXus file. 
+ + Parameters + ---------- + filename: + The path to the NeXus file. + + Returns + ------- + : + A flat DataGroup with the transformations and depends_on fields. + """ groups = find_transformations(filename) with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: return sc.DataGroup({group: f[group][()] for group in groups}) def as_nested(dg: sc.DataGroup) -> sc.DataGroup: + """ + Convert a flat DataGroup with paths as keys to a nested DataGroup. + + This is useful when loading transformations from a NeXus file, where the + paths are used as keys to represent the structure of the NeXus file. + + Parameters + ---------- + dg: + The flat DataGroup to convert. + + Returns + ------- + : + The nested DataGroup. + """ out = sc.DataGroup() for path, value in dg.items(): _set_recursive(out, path, value) @@ -143,3 +132,29 @@ def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: if first not in dg: dg[first] = sc.DataGroup() _set_recursive(dg[first], remainder, value) + + +def _parse_offset(obj: Field | Group) -> sc.Variable | None: + if (offset := obj.attrs.get('offset')) is None: + return None + if (offset_units := obj.attrs.get('offset_units')) is None: + raise TransformationError( + f"Found {offset=} but no corresponding 'offset_units' " + f"attribute at {obj.name}" + ) + return sc.spatial.translation(value=offset, unit=offset_units) + + +def _parse_value( + obj: Field | Group, + value: sc.Variable | sc.DataArray | sc.DataGroup, +) -> sc.Variable | sc.DataArray: + if isinstance(value, sc.DataGroup) and ( + isinstance(value.get('value'), sc.DataArray) + ): + # Some NXlog groups are split into value, alarm, and connection_status + # sublogs. We only care about the value. 
+ value = value['value'] + if not isinstance(value, sc.Variable | sc.DataArray): + raise TransformationError(f"Failed to load transformation value at {obj.name}") + return value From 4bbc7509e694661791f8d97fced0d22d54a5b196 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 13:08:05 +0200 Subject: [PATCH 11/32] Add apply_to_transformations --- src/scippnexus/base.py | 12 +++++- src/scippnexus/transformations.py | 61 +++++++++++++++++++++++-------- tests/transformations_test.py | 42 +++++++++++++++++---- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index 8a2c47f0..c590fd8c 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -246,11 +246,19 @@ def unit(self) -> sc.Unit | None: @property def parent(self) -> Group: - return Group(self._group.parent, definitions=self._definitions) + return Group( + self._group.parent, + definitions=self._definitions, + maybe_transformation=self._maybe_transformation, + ) @cached_property def file(self) -> Group: - return Group(self._group.file, definitions=self._definitions) + return Group( + self._group.file, + definitions=self._definitions, + maybe_transformation=self._maybe_transformation, + ) @property def _children(self) -> dict[str, Field | Group]: diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 8ca6598b..fc94d9a8 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) # @author Simon Heybrock from __future__ import annotations import warnings -from typing import Any +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Literal import h5py import scipp as sc @@ -19,22 +21,37 @@ class TransformationError(NexusStructureError): 
pass +@dataclass class Transform: - def __init__( - self, obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup - ): - self.offset = _parse_offset(obj) - self.vector = sc.vector(value=obj.attrs.get('vector')) - self.depends_on = depends_on_to_relative_path( - obj.attrs['depends_on'], obj.parent.name - ) - self.transformation_type = obj.attrs.get('transformation_type') + name: str + transformation_type: Literal['translation', 'rotation'] + value: sc.DataArray | sc.Variable + vector: sc.Variable + depends_on: str + offset: sc.Variable | None + + def __post_init__(self): if self.transformation_type not in ['translation', 'rotation']: raise TransformationError( - f"{self.transformation_type=} attribute at {obj.name}," + f"{self.transformation_type=} attribute at {self.name}," " expected 'translation' or 'rotation'." ) - self.value = _parse_value(obj, value) + + @staticmethod + def from_object( + obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup + ) -> Transform: + depends_on = depends_on_to_relative_path( + obj.attrs['depends_on'], obj.parent.name + ) + return Transform( + name=obj.name, + transformation_type=obj.attrs.get('transformation_type'), + value=_parse_value(obj, value), + vector=sc.vector(value=obj.attrs.get('vector')), + depends_on=depends_on, + offset=_parse_offset(obj), + ) def build(self) -> sc.Variable | sc.DataArray: t = self.value * self.vector @@ -59,7 +76,7 @@ def find_transformations(filename: str) -> list[str]: def _collect_transforms(name: str, obj: H5Base) -> None: if name.endswith('/depends_on') or 'transformation_type' in obj.attrs: - transforms.append(name) + transforms.append(f'/{name}') with h5py.File(filename, 'r') as f: f.visititems(_collect_transforms) @@ -85,6 +102,19 @@ def load_transformations(filename: str) -> sc.DataGroup: return sc.DataGroup({group: f[group][()] for group in groups}) +def apply_to_transformations( + dg: sc.DataGroup, func: Callable[[Transform], Transform] +) -> sc.DataGroup: + def 
apply_nested(node: Any) -> Any: + if isinstance(node, sc.DataGroup): + return node.apply(apply_nested) + if isinstance(node, Transform): + return func(node) + return node + + return dg.apply(apply_nested) + + def as_nested(dg: sc.DataGroup) -> sc.DataGroup: """ Convert a flat DataGroup with paths as keys to a nested DataGroup. @@ -116,7 +146,7 @@ def _maybe_transformation( if obj.attrs.get('transformation_type') is None: return value try: - return Transform(obj, value) + return Transform.from_object(obj, value) except KeyError as e: warnings.warn( UserWarning(f'Invalid transformation, missing attribute {e}'), stacklevel=2 @@ -128,6 +158,7 @@ def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: if '/' not in path: dg[path] = value else: + path = path.lstrip('/') first, remainder = path.split('/', maxsplit=1) if first not in dg: dg[first] = sc.DataGroup() diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 3d09d669..7f443d0e 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -14,14 +14,14 @@ def test_find_transformation_groups_finds_expected_groups() -> None: filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') paths = transformations.find_transformations(filename) assert paths == [ - 'entry/instrument/larmor_detector/depends_on', - 'entry/instrument/larmor_detector/transformations/trans_1', - 'entry/instrument/monitor_1/depends_on', - 'entry/instrument/monitor_1/transformations/trans_3', - 'entry/instrument/monitor_2/depends_on', - 'entry/instrument/monitor_2/transformations/trans_4', - 'entry/instrument/source/depends_on', - 'entry/instrument/source/transformations/trans_2', + '/entry/instrument/larmor_detector/depends_on', + '/entry/instrument/larmor_detector/transformations/trans_1', + '/entry/instrument/monitor_1/depends_on', + '/entry/instrument/monitor_1/transformations/trans_3', + '/entry/instrument/monitor_2/depends_on', + 
'/entry/instrument/monitor_2/transformations/trans_4', + '/entry/instrument/source/depends_on', + '/entry/instrument/source/transformations/trans_2', ] @@ -39,6 +39,32 @@ def test_load_transformations_loads_as_flat_datagroup() -> None: assert list(group) == ['depends_on', 'transformations'] +@pytest.mark.externalfile() +def test_apply_to_transformations() -> None: + filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') + dg = transformations.load_transformations(filename) + + def gather_names(t: transformations.Transform) -> transformations.Transform: + applied_to.append(t.name) + return t + + paths = [ + '/entry/instrument/larmor_detector/transformations/trans_1', + '/entry/instrument/monitor_1/transformations/trans_3', + '/entry/instrument/monitor_2/transformations/trans_4', + '/entry/instrument/source/transformations/trans_2', + ] + + applied_to = [] + transformations.apply_to_transformations(dg, gather_names) + assert applied_to == paths + + dg = transformations.as_nested(dg) + applied_to = [] + transformations.apply_to_transformations(dg, gather_names) + assert applied_to == paths + + @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.externalfile() def test_positions_consistent_with_separate_load() -> None: From 6309d0d457fe2ce36c19bbaacbbc1029d196aa26 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Fri, 27 Sep 2024 13:21:43 +0200 Subject: [PATCH 12/32] Document module --- docs/api-reference/index.md | 12 ++++++++++++ src/scippnexus/transformations.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index 7baa789d..9b065c2e 100644 --- a/docs/api-reference/index.md +++ b/docs/api-reference/index.md @@ -115,3 +115,15 @@ create_class load ``` + + +## Submodules + +```{eval-rst} +.. 
autosummary:: + :toctree: ../generated/modules + :template: module-template.rst + :recursive: + + transformations +``` diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index fc94d9a8..15d3e4a1 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -1,6 +1,36 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) # @author Simon Heybrock +""" +Utilities for loading and working with NeXus transformations. + +Transformation chains in NeXus files can be non-local and can thus be challenging to +work with. Additionally, values of transformations can be time-dependent, with each +chain link potentially having a different time-dependent value. In practice the user is +interested in the position and orientation of a component at a specific time or time +range. This may involve evaluating the transformation chain at a specific time, or +applying some heuristic to determine if the changes in the transformation value are +significant or just noise. In combination, the above means that we need to remain +flexible in how we handle transformations, preserving all necessary information from +the source files. This module is therefore structured as follows: + +1. :py:class:`Transform` is a dataclass representing a transformation. The raw `value` + dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to + facilitate further processing such as computing the mean or variance. +2. :py:func:`load_transformations` loads transformations from a NeXus file into a flat + :py:class:`scipp.DataGroup`. It can optionally be followed by + :py:func:`as_nested` to convert the flat structure to a nested one. +3. :py:func:`apply_to_transformations` applies a function to each transformation in a + :py:class:`scipp.DataGroup`. We imagine that this can be used to + - Evaluate the transformation at a specific time. 
+ - Apply filters to remove noise, to avoid having to deal with very small time + intervals when processing data. + +By keeping the loaded transformations in a simple and modifiable format, we can +furthermore manually update the transformations with information from other sources, +such as streamed NXlog values received from a data acquisition system. +""" + from __future__ import annotations import warnings From 25832c0e1943d40ffd27762602f981d04b896911 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 10:47:59 +0200 Subject: [PATCH 13/32] Revert "Make maybe_transformation injectable" This reverts commit 433e17eb5b9c1e99ba0dc6c537f5e9636fd025eb. --- src/scippnexus/base.py | 29 +++++++---------------------- src/scippnexus/field.py | 7 +------ src/scippnexus/file.py | 9 ++------- src/scippnexus/nxtransformations.py | 3 ++- src/scippnexus/typing.py | 12 ------------ 5 files changed, 12 insertions(+), 48 deletions(-) diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index c590fd8c..83fba57c 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -19,7 +19,7 @@ from ._common import to_child_select from .attrs import Attrs from .field import Field -from .typing import H5Dataset, H5Group, MaybeTransformation, ScippIndex +from .typing import H5Dataset, H5Group, ScippIndex def asvariable(obj: Any | sc.Variable) -> sc.Variable: @@ -29,6 +29,8 @@ def asvariable(obj: Any | sc.Variable) -> sc.Variable: class NexusStructureError(Exception): """Invalid or unsupported class and field structure in Nexus.""" + pass + def is_dataset(obj: H5Group | H5Dataset) -> bool: """Return true if the object is an h5py.Dataset or equivalent. @@ -200,17 +202,11 @@ class Group(Mapping): # interpretation of the file, but need to cache information. An earlier version # of ScippNexus used such a mechanism without caching, which was very slow. 
- def __init__( - self, - group: H5Group, - definitions: dict[str, type] | None = None, - maybe_transformation: MaybeTransformation = None, - ): + def __init__(self, group: H5Group, definitions: dict[str, type] | None = None): self._group = group self._definitions = {} if definitions is None else definitions self._lazy_children = None self._lazy_nexus = None - self._maybe_transformation = maybe_transformation @property def nx_class(self) -> type | None: @@ -270,15 +266,9 @@ def _children(self) -> dict[str, Field | Group]: def _read_children(self) -> dict[str, Field | Group]: def _make_child(obj: H5Dataset | H5Group) -> Field | Group: if is_dataset(obj): - return Field( - obj, parent=self, _maybe_transformation=self._maybe_transformation - ) + return Field(obj, parent=self) else: - return Group( - obj, - definitions=self._definitions, - maybe_transformation=self._maybe_transformation, - ) + return Group(obj, definitions=self._definitions) items = {name: _make_child(obj) for name, obj in self._group.items()} # In the case of NXevent_data, the `cue_` fields are unusable, since @@ -427,12 +417,7 @@ def isclass(x): # For a time-dependent transformation in NXtransformations, an NXlog may # take the place of the `value` field. In this case, we need to read the # properties of the NXlog group to make the actual transformation. 
- from .nxtransformations import maybe_resolve - - if self._maybe_transformation is not None: - maybe_transformation = self._maybe_transformation - else: - from .nxtransformations import maybe_transformation + from .nxtransformations import maybe_resolve, maybe_transformation if ( isinstance(dg, sc.DataGroup) diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index d349ffa4..601ff475 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -19,7 +19,6 @@ from ._cache import cached_property from .attrs import Attrs -from .typing import MaybeTransformation if TYPE_CHECKING: from .base import Group @@ -85,7 +84,6 @@ class Field: sizes: dict[str, int] | None = None dtype: sc.DType | None = None errors: H5Dataset | None = None - _maybe_transformation: MaybeTransformation = None @cached_property def attrs(self) -> dict[str, Any]: @@ -137,10 +135,7 @@ def __getitem__(self, select: ScippIndex) -> Any | sc.Variable: : Loaded data. """ - if self._maybe_transformation is not None: - maybe_transformation = self._maybe_transformation - else: - from .nxtransformations import maybe_transformation + from .nxtransformations import maybe_transformation index = to_plain_index(self.dims, select) if isinstance(index, int | slice): diff --git a/src/scippnexus/file.py b/src/scippnexus/file.py index fa5d2751..939cec50 100644 --- a/src/scippnexus/file.py +++ b/src/scippnexus/file.py @@ -13,7 +13,7 @@ Group, base_definitions, ) -from .typing import Definitions, MaybeTransformation +from .typing import Definitions class File(AbstractContextManager, Group): @@ -22,7 +22,6 @@ def __init__( name: str | os.PathLike[str] | io.BytesIO | h5py.Group, *args, definitions: Definitions | DefaultDefinitionsType = DefaultDefinitions, - maybe_transformation: MaybeTransformation = None, **kwargs, ): """Context manager for NeXus files, similar to h5py.File. 
@@ -51,11 +50,7 @@ def __init__( else: self._file = h5py.File(name, *args, **kwargs) self._manage_file_context = True - super().__init__( - self._file, - definitions=definitions, - maybe_transformation=maybe_transformation, - ) + super().__init__(self._file, definitions=definitions) def __enter__(self): if self._manage_file_context: diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index a628077e..52ab6310 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -10,7 +10,7 @@ import scipp as sc from scipp.scipy import interpolate -from .base import Group, NXobject, ScippIndex +from .base import Group, NexusStructureError, NXobject, ScippIndex from .field import Field, depends_on_to_relative_path from .transformations import Transform, TransformationError @@ -204,6 +204,7 @@ def maybe_transformation( """ if (transformation_type := obj.attrs.get('transformation_type')) is None: return value + # return Transform(obj, value) transform = Transformation(obj).make_transformation( value, transformation_type=transformation_type, select=sel ) diff --git a/src/scippnexus/typing.py b/src/scippnexus/typing.py index c51f3661..339071c7 100644 --- a/src/scippnexus/typing.py +++ b/src/scippnexus/typing.py @@ -6,8 +6,6 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Protocol -import scipp as sc - class H5Base(Protocol): @property @@ -68,8 +66,6 @@ def visititems(self, func: Callable) -> None: class ellipsis(Enum): Ellipsis = "..." 
- from .base import Field, Group - else: ellipsis = type(Ellipsis) @@ -80,11 +76,3 @@ class ellipsis(Enum): ) Definitions = Mapping[str, type] - -MaybeTransformation = ( - Callable[ - ['Field | Group', sc.Variable | sc.DataArray | sc.DataGroup, ScippIndex], - sc.Variable | sc.DataArray | sc.DataGroup, - ] - | None -) From 86fe7eccd5d3674b2bade0270da1c60f59c6846f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 11:19:03 +0200 Subject: [PATCH 14/32] Resolve chains and load as transformations.Transform --- src/scippnexus/base.py | 14 +-- src/scippnexus/nxtransformations.py | 179 ++++++---------------------- 2 files changed, 37 insertions(+), 156 deletions(-) diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index 83fba57c..a9e66a74 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -242,19 +242,11 @@ def unit(self) -> sc.Unit | None: @property def parent(self) -> Group: - return Group( - self._group.parent, - definitions=self._definitions, - maybe_transformation=self._maybe_transformation, - ) + return Group(self._group.parent, definitions=self._definitions) @cached_property def file(self) -> Group: - return Group( - self._group.file, - definitions=self._definitions, - maybe_transformation=self._maybe_transformation, - ) + return Group(self._group.file, definitions=self._definitions) @property def _children(self) -> dict[str, Field | Group]: @@ -424,7 +416,7 @@ def isclass(x): and (depends_on := dg.get('depends_on')) is not None ): if (resolved := maybe_resolve(self['depends_on'], depends_on)) is not None: - dg['resolved_depends_on'] = resolved + dg['resolved_transformations'] = resolved return maybe_transformation(self, value=dg, sel=sel) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 52ab6310..6a6683ea 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -10,111 +10,18 @@ import scipp as sc from scipp.scipy import interpolate -from 
.base import Group, NexusStructureError, NXobject, ScippIndex -from .field import Field, depends_on_to_relative_path -from .transformations import Transform, TransformationError +from .base import Group, NXobject, ScippIndex +from .field import Field +from .transformations import Transform + +# TODO skip loading?! +# TODO convert depends_on to absolute path! class NXtransformations(NXobject): """Group of transformations.""" -class Transformation: - def __init__(self, obj: Field | Group): # could be an NXlog - self._obj = obj - - @property - def attrs(self): - return self._obj.attrs - - @property - def name(self): - return self._obj.name - - @property - def offset(self): - if (offset := self.attrs.get('offset')) is None: - return None - if (offset_units := self.attrs.get('offset_units')) is None: - raise TransformationError( - f"Found {offset=} but no corresponding 'offset_units' " - f"attribute at {self.name}" - ) - return sc.spatial.translation(value=offset, unit=offset_units) - - @property - def vector(self) -> sc.Variable: - if self.attrs.get('vector') is None: - raise TransformationError('A transformation needs a vector attribute.') - return sc.vector(value=self.attrs.get('vector')) - - def __getitem__(self, select: ScippIndex): - transformation_type = self.attrs.get('transformation_type') - # According to private communication with Tobias Richter, NeXus allows 0-D or - # shape=[1] for single values. It is unclear how and if this could be - # distinguished from a scan of length 1. - value = self._obj[select] - return self.make_transformation( - value, transformation_type=transformation_type, select=select - ) - - def make_transformation( - self, - value: sc.Variable | sc.DataArray, - *, - transformation_type: str, - select: ScippIndex, - ): - try: - if isinstance(value, sc.DataGroup) and ( - isinstance(value.get('value'), sc.DataArray) - ): - # Some NXlog groups are split into value, alarm, and connection_status - # sublogs. We only care about the value. 
- value = value['value'] - if isinstance(value, sc.DataGroup): - return value - t = value * self.vector - v = t if isinstance(t, sc.Variable) else t.data - if transformation_type == 'translation': - v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) - elif transformation_type == 'rotation': - v = sc.spatial.rotations_from_rotvecs(v) - else: - raise TransformationError( - f"{transformation_type=} attribute at {self.name}," - " expected 'translation' or 'rotation'." - ) - if isinstance(t, sc.Variable): - t = v - else: - t.data = v - if (offset := self.offset) is None: - transform = t - else: - offset = sc.vector(value=offset.values, unit=offset.unit) - offset = sc.spatial.translation(value=offset.value, unit=offset.unit) - if transformation_type == 'translation': - offset = offset.to(unit=t.unit, copy=False) - transform = t * offset - if (depends_on := self.attrs.get('depends_on')) is not None: - if not isinstance(transform, sc.DataArray): - transform = sc.DataArray(transform) - transform.coords['depends_on'] = sc.scalar( - depends_on_to_relative_path(depends_on, self._obj.parent.name) - ) - return transform - except (sc.DimensionError, sc.UnitError, TransformationError) as e: - msg = ( - f"Failed to convert {self.name} into a transformation: {e} " - "Falling back to returning underlying value." - ) - warnings.warn(msg, stacklevel=2) - # TODO We should probably try to return some other data structure and - # also insert offset and other attributes. 
- return value - - def _interpolate_transform(transform, xnew): # scipy can't interpolate with a single value if transform.sizes["time"] == 1: @@ -156,6 +63,8 @@ def combine_transformations( ) total_transform = None for transform in chain: + if transform.dtype in (sc.DType.translation3, sc.DType.affine_transform3): + transform = transform.to(unit='m', copy=False) if total_transform is None: total_transform = transform elif isinstance(total_transform, sc.DataArray) and isinstance( @@ -202,52 +111,33 @@ def maybe_transformation( Instead we use the presence of the attribute 'transformation_type' to identify transformation fields. """ - if (transformation_type := obj.attrs.get('transformation_type')) is None: + if obj.attrs.get('transformation_type') is None: + return value + try: + return Transform.from_object(obj, value) + except KeyError as e: + warnings.warn( + UserWarning(f'Invalid transformation, missing attribute {e}'), stacklevel=2 + ) return value - # return Transform(obj, value) - transform = Transformation(obj).make_transformation( - value, transformation_type=transformation_type, select=sel - ) - # When loading a subgroup of a file there can be transformation chains - # that lead outside the loaded group. In this case we cannot resolve the - # chain after loading, so we try to resolve it directly. 
- return assign_resolved(obj, transform) - - -def assign_resolved( - obj: Field | Group, - transform: sc.DataArray | sc.Variable, - force_resolve: bool = False, -) -> sc.DataArray | sc.Variable: - """Add resolved_depends_on coord to a transformation if resolve is performed.""" - if ( - isinstance(transform, sc.DataArray) - and (depends_on := transform.coords.get('depends_on')) is not None - ): - if ( - resolved := maybe_resolve( - obj, depends_on.value, force_resolve=force_resolve - ) - ) is not None: - transform.coords["resolved_depends_on"] = sc.scalar(resolved) - return transform def maybe_resolve( - obj: Field | Group, depends_on: str, force_resolve: bool = False + obj: Field | Group, depends_on: str ) -> sc.DataArray | sc.Variable | None: """Conditionally resolve a depend_on attribute.""" - relative = depends_on_to_relative_path(depends_on, obj.parent.name) - if (force_resolve or relative.startswith('..')) and depends_on != '.': - try: - target = obj.parent[depends_on] - resolved = target[()] - except Exception: # noqa: S110 - # Catchall since resolving not strictly necessary, we should not - # fail the rest of the loading process. - pass - else: - return assign_resolved(target, resolved, force_resolve=True) + transforms = sc.DataGroup() + parent = obj.parent + try: + while depends_on != '.': + transform = parent[depends_on] + parent = transform.parent + depends_on = transform.attrs['depends_on'] + transforms[transform.name] = transform[()] + except KeyError as e: + warnings.warn(UserWarning(f'{obj.name=} missing {e}'), stacklevel=2) + return None + return transforms class TransformationChainResolver: @@ -330,10 +220,7 @@ def resolve_depends_on(self) -> sc.DataArray | sc.Variable | None: : The resolved position in meter, or None if no depends_on was found. 
""" - if 'resolved_depends_on' in self.value: - depends_on = self.value['resolved_depends_on'] - else: - depends_on = self.value.get('depends_on') + depends_on = self.value.get('depends_on') if depends_on is None: return None # Note that transformations have to be applied in "reverse" order, i.e., @@ -477,8 +364,10 @@ def _with_positions( transform = None if 'depends_on' in dg: try: - transform = resolver.resolve_depends_on() - except TransformationChainResolver.ChainError as e: + chain = list(dg['resolved_transformations'].values()) + # TODO chain should be correct as is, but could add consistency check + transform = combine_transformations([t.build() for t in chain]) + except KeyError as e: warnings.warn( UserWarning(f'depends_on chain references missing node:\n{e}'), stacklevel=2, From cc0c0b97f4fd93f9182b9df30ad2ff19acbbbc11 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 11:47:09 +0200 Subject: [PATCH 15/32] Add class DependsOn --- src/scippnexus/field.py | 28 ++++++++++++++--------- src/scippnexus/nxtransformations.py | 35 +++++++++++++---------------- src/scippnexus/transformations.py | 8 +++---- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index 601ff475..f0d5c820 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -24,14 +24,22 @@ from .base import Group -def depends_on_to_relative_path(depends_on: str, parent_path: str) -> str: - """Replace depends_on paths with relative paths. +@dataclass +class DependsOn: + """ + Represents a depends_on field in a NeXus file. + + The parent (the full path within the NeXus file) is stored as the value may be + relative or absolute, so having the path available after loading is essential. 
+ """ + + parent: str + value: str - After loading we will generally not have the same root so absolute paths - cannot be resolved after loading.""" - if depends_on.startswith('/'): - return posixpath.relpath(depends_on, parent_path) - return depends_on + def absolute_path(self) -> str | None: + if self.value == '.': + return None + return posixpath.normpath(posixpath.join(self.parent, self.value)) def _is_time(obj): @@ -170,10 +178,8 @@ def __getitem__(self, select: ScippIndex) -> Any | sc.Variable: strings = self.dataset.asstr(encoding='latin-1')[index] _warn_latin1_decode(self.dataset, strings, str(e)) variable.values = np.asarray(strings).flatten() - if self.dataset.name.endswith('depends_on') and variable.ndim == 0: - variable.value = depends_on_to_relative_path( - variable.value, self.dataset.parent.name - ) + if self.dataset.name.endswith('/depends_on') and variable.ndim == 0: + return DependsOn(parent=self.dataset.parent.name, value=variable.value) elif variable.values.flags["C_CONTIGUOUS"]: # On versions of h5py prior to 3.2, a TypeError occurs in some cases # where h5py cannot broadcast data with e.g. shape (20, 1) to a buffer diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 6a6683ea..7f2439b4 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -11,7 +11,7 @@ from scipp.scipy import interpolate from .base import Group, NXobject, ScippIndex -from .field import Field +from .field import DependsOn, Field from .transformations import Transform # TODO skip loading?! 
@@ -123,11 +123,12 @@ def maybe_transformation( def maybe_resolve( - obj: Field | Group, depends_on: str + obj: Field | Group, depends_on: DependsOn ) -> sc.DataArray | sc.Variable | None: """Conditionally resolve a depend_on attribute.""" transforms = sc.DataGroup() parent = obj.parent + depends_on = depends_on.value try: while depends_on != '.': transform = parent[depends_on] @@ -310,16 +311,11 @@ def compute_positions( : New data group with added positions. """ - # Create resolver at root level, since any depends_on chain may lead to a parent, - # i.e., we cannot use a resolver at the level of each chain's entry point. - # TODO need to be able to set root, would be better to construct resolver outside, - # see we can navigate to correct path? - resolver = TransformationChainResolver.from_root(transformations or dg) return _with_positions( dg, store_position=store_position, store_transform=store_transform, - resolver=resolver, + transformations=transformations, ) @@ -358,14 +354,19 @@ def _with_positions( *, store_position: str, store_transform: str | None = None, - resolver: TransformationChainResolver, + transformations: sc.DataGroup | None = None, ) -> sc.DataGroup: out = sc.DataGroup() transform = None - if 'depends_on' in dg: + transformations = transformations or dg.get('resolved_transformations', {}) + if (depends_on := dg.get('depends_on')) is not None: + path = depends_on.absolute_path() try: - chain = list(dg['resolved_transformations'].values()) - # TODO chain should be correct as is, but could add consistency check + chain = [] + while path is not None: + transform = transformations[path] + chain.append(transform) + path = transform.depends_on.absolute_path() transform = combine_transformations([t.build() for t in chain]) except KeyError as e: warnings.warn( @@ -377,15 +378,9 @@ def _with_positions( if store_transform is not None: out[store_transform] = transform for name, value in dg.items(): - # If the resolver was constructed from an external tree 
of transformations it - # will not contain groups that do not contain any transformations or depends_on - # field. Do not descend into such groups. - if isinstance(value, sc.DataGroup) and name in resolver.value: + if isinstance(value, sc.DataGroup): value = _with_positions( - value, - store_position=store_position, - store_transform=store_transform, - resolver=resolver[name], + value, store_position=store_position, store_transform=store_transform ) elif ( isinstance(value, sc.DataArray) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 15d3e4a1..494768d7 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -42,7 +42,7 @@ import scipp as sc from .base import Group, NexusStructureError, ScippIndex -from .field import Field, depends_on_to_relative_path +from .field import DependsOn, Field from .file import File from .typing import H5Base @@ -57,7 +57,7 @@ class Transform: transformation_type: Literal['translation', 'rotation'] value: sc.DataArray | sc.Variable vector: sc.Variable - depends_on: str + depends_on: DependsOn offset: sc.Variable | None def __post_init__(self): @@ -71,9 +71,7 @@ def __post_init__(self): def from_object( obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup ) -> Transform: - depends_on = depends_on_to_relative_path( - obj.attrs['depends_on'], obj.parent.name - ) + depends_on = DependsOn(parent=obj.parent.name, value=obj.attrs['depends_on']) return Transform( name=obj.name, transformation_type=obj.attrs.get('transformation_type'), From 189583034e7f627f2dbbe8437b38efcb64f2f24a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 11:48:44 +0200 Subject: [PATCH 16/32] Remove unused TransformationChainResolver --- src/scippnexus/nxtransformations.py | 126 +--------------------------- tests/nxtransformations_test.py | 99 +--------------------- 2 files changed, 2 insertions(+), 223 deletions(-) diff --git a/src/scippnexus/nxtransformations.py 
b/src/scippnexus/nxtransformations.py index 7f2439b4..01cd1d9f 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) # @author Simon Heybrock from __future__ import annotations import warnings -from dataclasses import dataclass import numpy as np import scipp as sc @@ -15,7 +14,6 @@ from .transformations import Transform # TODO skip loading?! -# TODO convert depends_on to absolute path! class NXtransformations(NXobject): @@ -141,128 +139,6 @@ def maybe_resolve( return transforms -class TransformationChainResolver: - """ - Resolve a chain of transformations, given depends_on attributes with absolute or - relative paths. - - A `depends_on` field serves as an entry point into a chain of transformations. - It points to another entry, based on an absolute or relative path. The target - entry may have a `depends_on` attribute pointing to the next transform. This - class follows the paths and resolves the chain of transformations. 
- """ - - class ChainError(KeyError): - """Raised when a transformation chain cannot be resolved.""" - - @dataclass - class Entry: - name: str - value: sc.DataGroup - - def __init__(self, stack: list[TransformationChainResolver.Entry]): - self._stack = stack - - @staticmethod - def from_root(dg: sc.DataGroup) -> TransformationChainResolver: - return TransformationChainResolver( - [TransformationChainResolver.Entry(name='', value=dg)] - ) - - @property - def name(self) -> str: - return '/'.join([e.name for e in self._stack]) - - @property - def root(self) -> TransformationChainResolver: - return TransformationChainResolver(self._stack[0:1]) - - @property - def parent(self) -> TransformationChainResolver: - if len(self._stack) == 1: - raise TransformationChainResolver.ChainError( - "Transformation depends on node beyond root" - ) - return TransformationChainResolver(self._stack[:-1]) - - @property - def value(self) -> sc.DataGroup: - return self._stack[-1].value - - def __getitem__(self, path: str) -> TransformationChainResolver: - base, *remainder = path.split('/', maxsplit=1) - if base == '': - node = self.root - elif base == '.': - node = self - elif base == '..': - node = self.parent - else: - try: - child = self._stack[-1].value[base] - except KeyError: - raise TransformationChainResolver.ChainError( - f"{base} not found in {self.name}" - ) from None - node = TransformationChainResolver( - [ - *self._stack, - TransformationChainResolver.Entry(name=base, value=child), - ] - ) - return node if len(remainder) == 0 else node[remainder[0]] - - def resolve_depends_on(self) -> sc.DataArray | sc.Variable | None: - """ - Resolve the depends_on attribute of a transformation chain. - - Returns - ------- - : - The resolved position in meter, or None if no depends_on was found. 
- """ - depends_on = self.value.get('depends_on') - if depends_on is None: - return None - # Note that transformations have to be applied in "reverse" order, i.e., - # simply taking math.prod(chain) would be wrong, even if we could - # ignore potential time-dependence. - return combine_transformations(self.get_chain(depends_on)) - - def get_chain( - self, depends_on: str | sc.DataArray | sc.Variable - ) -> list[sc.DataArray | sc.Variable]: - if depends_on == '.': - return [] - new_style_transform = False - if isinstance(depends_on, str): - node = self[depends_on] - if isinstance(node.value, Transform): - transform = node.value.build().copy(deep=False) - depends_on = node.value.depends_on - new_style_transform = True - else: - transform = node.value.copy(deep=False) - depends_on = '.' - node = node.parent - else: - # Fake node, resolved_depends_on is recursive so this is actually ignored. - node = self - transform = depends_on - depends_on = '.' - if transform.dtype in (sc.DType.translation3, sc.DType.affine_transform3): - transform = transform.to(unit='m', copy=False) - if not new_style_transform and isinstance(transform, sc.DataArray): - if (attr := transform.coords.pop('resolved_depends_on', None)) is not None: - depends_on = attr.value - elif (attr := transform.coords.pop('depends_on', None)) is not None: - depends_on = attr.value - # If transform is time-dependent then we keep it is a DataArray, otherwise - # we convert it to a Variable. 
- transform = transform if 'time' in transform.coords else transform.data - return [transform, *node.get_chain(depends_on)] - - def compute_positions( dg: sc.DataGroup, *, diff --git a/tests/nxtransformations_test.py b/tests/nxtransformations_test.py index 5a36d2dd..ea987aaa 100644 --- a/tests/nxtransformations_test.py +++ b/tests/nxtransformations_test.py @@ -5,7 +5,7 @@ from scipp.testing import assert_identical import scippnexus as snx -from scippnexus.nxtransformations import NXtransformations, TransformationChainResolver +from scippnexus.nxtransformations import NXtransformations def make_group(group: h5py.Group) -> snx.Group: @@ -513,108 +513,11 @@ def test_label_slice_transformations(h5root): ) -def test_TransformationChainResolver_path_handling(): - tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) - assert tree['a']['b']['c'].value == 1 - assert tree['a/b/c'].value == 1 - assert tree['/a/b/c'].value == 1 - assert tree['a']['../a/b/c'].value == 1 - assert tree['a/b']['../../a/b/c'].value == 1 - assert tree['a/b']['./c'].value == 1 - - -def test_TransformationChainResolver_name(): - tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) - assert tree['a']['b']['c'].name == '/a/b/c' - assert tree['a/b/c'].name == '/a/b/c' - assert tree['/a/b/c'].name == '/a/b/c' - assert tree['a']['../a/b/c'].name == '/a/b/c' - assert tree['a/b']['../../a/b/c'].name == '/a/b/c' - assert tree['a/b']['./c'].name == '/a/b/c' - - -def test_TransformationChainResolver_raises_ChainError_if_child_does_not_exists(): - tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) - with pytest.raises(TransformationChainResolver.ChainError): - tree['a']['b']['d'] - - -def test_TransformationChainResolver_raises_ChainError_if_path_leads_beyond_root(): - tree = TransformationChainResolver.from_root({'a': {'b': {'c': 1}}}) - with pytest.raises(TransformationChainResolver.ChainError): - tree['..'] - with 
pytest.raises(TransformationChainResolver.ChainError): - tree['a']['../..'] - with pytest.raises(TransformationChainResolver.ChainError): - tree['../a'] - - origin = sc.vector([0, 0, 0], unit='m') shiftX = sc.spatial.translation(value=[1, 0, 0], unit='m') rotZ = sc.spatial.rotations_from_rotvecs(sc.vector([0, 0, 90], unit='deg')) -def test_resolve_depends_on_dot(): - tree = TransformationChainResolver.from_root({'depends_on': '.'}) - assert sc.identical(tree.resolve_depends_on() * origin, origin) - - -def test_resolve_depends_on_child(): - transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver.from_root( - {'depends_on': 'child', 'child': transform} - ) - expected = sc.vector([1, 0, 0], unit='m') - assert sc.identical(tree.resolve_depends_on() * origin, expected) - - -def test_resolve_depends_on_grandchild(): - transform = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver.from_root( - {'depends_on': 'child/grandchild', 'child': {'grandchild': transform}} - ) - expected = sc.vector([1, 0, 0], unit='m') - assert sc.identical(tree.resolve_depends_on() * origin, expected) - - -def test_resolve_depends_on_child1_depends_on_child2(): - transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('child2')}) - transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver.from_root( - {'depends_on': 'child1', 'child1': transform1, 'child2': transform2} - ) - # Note order - expected = transform2.data * transform1.data - assert sc.identical(tree.resolve_depends_on(), expected) - - -def test_resolve_depends_on_grandchild1_depends_on_grandchild2(): - transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('grandchild2')}) - transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver.from_root( - { - 'depends_on': 'child/grandchild1', - 'child': {'grandchild1': 
transform1, 'grandchild2': transform2}, - } - ) - expected = transform2.data * transform1.data - assert sc.identical(tree.resolve_depends_on(), expected) - - -def test_resolve_depends_on_grandchild1_depends_on_child2(): - transform1 = sc.DataArray(shiftX, coords={'depends_on': sc.scalar('../child2')}) - transform2 = sc.DataArray(rotZ, coords={'depends_on': sc.scalar('.')}) - tree = TransformationChainResolver.from_root( - { - 'depends_on': 'child1/grandchild1', - 'child1': {'grandchild1': transform1}, - 'child2': transform2, - } - ) - expected = transform2.data * transform1.data - assert sc.identical(tree.resolve_depends_on(), expected) - - def test_compute_positions(h5root): instrument = snx.create_class(h5root, 'instrument', snx.NXinstrument) detector = create_detector(instrument) From 9f794c650e311aec84e6774b84fe1ffa20c6aeb5 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 11:52:20 +0200 Subject: [PATCH 17/32] Remove unused arg --- src/scippnexus/base.py | 2 +- src/scippnexus/field.py | 4 ++-- src/scippnexus/nxtransformations.py | 6 ++---- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index a9e66a74..88499428 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -418,7 +418,7 @@ def isclass(x): if (resolved := maybe_resolve(self['depends_on'], depends_on)) is not None: dg['resolved_transformations'] = resolved - return maybe_transformation(self, value=dg, sel=sel) + return maybe_transformation(self, value=dg) def _warn_fallback(self, e: Exception) -> None: msg = ( diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index f0d5c820..de1c0da0 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -169,7 +169,7 @@ def __getitem__(self, select: ScippIndex) -> Any | sc.Variable: # If the variable is empty, return early if np.prod(shape) == 0: variable = self._maybe_datetime(variable) - return maybe_transformation(self, value=variable, sel=select) 
+ return maybe_transformation(self, value=variable) if self.dtype == sc.DType.string: try: @@ -205,7 +205,7 @@ def __getitem__(self, select: ScippIndex) -> Any | sc.Variable: else: return variable.value variable = self._maybe_datetime(variable) - return maybe_transformation(self, value=variable, sel=select) + return maybe_transformation(self, value=variable) def _maybe_datetime(self, variable: sc.Variable) -> sc.Variable: if _is_time(variable): diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 01cd1d9f..19454c3a 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -9,7 +9,7 @@ import scipp as sc from scipp.scipy import interpolate -from .base import Group, NXobject, ScippIndex +from .base import Group, NXobject from .field import DependsOn, Field from .transformations import Transform @@ -95,9 +95,7 @@ def combine_transformations( def maybe_transformation( - obj: Field | Group, - value: sc.Variable | sc.DataArray | sc.DataGroup, - sel: ScippIndex, + obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup ) -> sc.Variable | sc.DataArray | sc.DataGroup: """ Return a loaded field, possibly modified if it is a transformation. 
From 916f959a95b8b6c193b013211762ad92980af9c1 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 11:56:37 +0200 Subject: [PATCH 18/32] Simplify --- src/scippnexus/nxtransformations.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 19454c3a..80c3a779 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -231,16 +231,13 @@ def _with_positions( transformations: sc.DataGroup | None = None, ) -> sc.DataGroup: out = sc.DataGroup() - transform = None transformations = transformations or dg.get('resolved_transformations', {}) if (depends_on := dg.get('depends_on')) is not None: - path = depends_on.absolute_path() try: chain = [] - while path is not None: - transform = transformations[path] - chain.append(transform) - path = transform.depends_on.absolute_path() + while (path := depends_on.absolute_path()) is not None: + chain.append(transformations[path]) + depends_on = chain[-1].depends_on transform = combine_transformations([t.build() for t in chain]) except KeyError as e: warnings.warn( From e89b5960f4553dd44dfd80ca0653432cdd9072bc Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 12:30:29 +0200 Subject: [PATCH 19/32] Update tests --- src/scippnexus/nxtransformations.py | 7 +- src/scippnexus/transformations.py | 30 ++---- tests/nxtransformations_test.py | 138 +++++++++++++--------------- 3 files changed, 75 insertions(+), 100 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 80c3a779..4ddfe3cc 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -132,7 +132,9 @@ def maybe_resolve( depends_on = transform.attrs['depends_on'] transforms[transform.name] = transform[()] except KeyError as e: - warnings.warn(UserWarning(f'{obj.name=} missing {e}'), stacklevel=2) + warnings.warn( + 
UserWarning(f'depends_on chain references missing node {e}'), stacklevel=2 + ) return None return transforms @@ -178,7 +180,8 @@ def compute_positions( If not None, store the resolved transformation chain in this field. transformations: Optional data group containing transformation chains. If not provided, the - transformations are looked up in the input data group. + transformations are looked up in the 'resolved_transformations' subgroups of the + input data group. Returns ------- diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 494768d7..1e8cae67 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -33,7 +33,6 @@ from __future__ import annotations -import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Literal @@ -41,7 +40,7 @@ import h5py import scipp as sc -from .base import Group, NexusStructureError, ScippIndex +from .base import Group, NexusStructureError from .field import DependsOn, Field from .file import File from .typing import H5Base @@ -55,7 +54,7 @@ class TransformationError(NexusStructureError): class Transform: name: str transformation_type: Literal['translation', 'rotation'] - value: sc.DataArray | sc.Variable + value: sc.Variable | sc.DataArray | sc.DataGroup vector: sc.Variable depends_on: DependsOn offset: sc.Variable | None @@ -75,8 +74,8 @@ def from_object( return Transform( name=obj.name, transformation_type=obj.attrs.get('transformation_type'), - value=_parse_value(obj, value), - vector=sc.vector(value=obj.attrs.get('vector')), + value=_parse_value(value), + vector=sc.vector(value=obj.attrs['vector']), depends_on=depends_on, offset=_parse_offset(obj), ) @@ -126,7 +125,7 @@ def load_transformations(filename: str) -> sc.DataGroup: A flat DataGroup with the transformations and depends_on fields. 
""" groups = find_transformations(filename) - with File(filename, mode='r', maybe_transformation=_maybe_transformation) as f: + with File(filename, mode='r') as f: return sc.DataGroup({group: f[group][()] for group in groups}) @@ -166,20 +165,6 @@ def as_nested(dg: sc.DataGroup) -> sc.DataGroup: return out -def _maybe_transformation( - obj: Field | Group, - value: sc.Variable | sc.DataArray | sc.DataGroup, - sel: ScippIndex, -) -> sc.Variable | sc.DataArray | sc.DataGroup: - if obj.attrs.get('transformation_type') is None: - return value - try: - return Transform.from_object(obj, value) - except KeyError as e: - warnings.warn( - UserWarning(f'Invalid transformation, missing attribute {e}'), stacklevel=2 - ) - return value def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: @@ -205,15 +190,12 @@ def _parse_offset(obj: Field | Group) -> sc.Variable | None: def _parse_value( - obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup, -) -> sc.Variable | sc.DataArray: +) -> sc.Variable | sc.DataArray | sc.DataGroup: if isinstance(value, sc.DataGroup) and ( isinstance(value.get('value'), sc.DataArray) ): # Some NXlog groups are split into value, alarm, and connection_status # sublogs. We only care about the value. 
value = value['value'] - if not isinstance(value, sc.Variable | sc.DataArray): - raise TransformationError(f"Failed to load transformation value at {obj.name}") return value diff --git a/tests/nxtransformations_test.py b/tests/nxtransformations_test.py index ea987aaa..db2d6b35 100644 --- a/tests/nxtransformations_test.py +++ b/tests/nxtransformations_test.py @@ -32,9 +32,7 @@ def create_detector(group): def test_Transformation_with_single_value(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) value = sc.scalar(6.5, unit='mm') offset = sc.spatial.translation(value=[1, 2, 3], unit='mm') @@ -49,19 +47,16 @@ def test_Transformation_with_single_value(h5root): value.attrs['offset_units'] = str(offset.unit) value.attrs['vector'] = vector.value - expected = sc.DataArray(data=expected, coords={'depends_on': sc.scalar('.')}) detector = make_group(detector) - depends_on = detector['depends_on'][()] + depends_on = detector['depends_on'][()].value assert depends_on == 'transformations/t1' - t = detector[depends_on][()] + t = detector[depends_on][()].build() assert_identical(t, expected) def test_time_independent_Transformation_with_length_0(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) value = sc.array(dims=['dim_0'], values=[], unit='mm') offset = sc.spatial.translation(value=[1, 2, 3], unit='mm') @@ -76,20 +71,25 @@ def test_time_independent_Transformation_with_length_0(h5root): value.attrs['offset_units'] = str(offset.unit) value.attrs['vector'] = vector.value - expected = 
sc.DataArray(data=expected, coords={'depends_on': sc.scalar('.')}) detector = make_group(detector) - depends_on = detector['depends_on'][()] + depends_on = detector['depends_on'][()].value assert depends_on == 'transformations/t1' - t = detector[depends_on][()] + t = detector[depends_on][()].build() assert_identical(t, expected) -def test_depends_on_absolute_path_to_sibling_group_resolved_to_relative_path(h5root): +def test_depends_on_absolute_path_to_sibling_group_resolved_correctly(h5root): det1 = snx.create_class(h5root, 'det1', NXtransformations) snx.create_field(det1, 'depends_on', sc.scalar('/det2/transformations/t1')) + depends_on = make_group(det1)['depends_on'][()] + assert depends_on.absolute_path() == '/det2/transformations/t1' + +def test_depends_on_relative_path_to_sibling_group_resolved_correctly(h5root): + det1 = snx.create_class(h5root, 'det1', NXtransformations) + snx.create_field(det1, 'depends_on', sc.scalar('../det2/transformations/t1')) depends_on = make_group(det1)['depends_on'][()] - assert depends_on == '../det2/transformations/t1' + assert depends_on.absolute_path() == '/det2/transformations/t1' def test_depends_on_relative_path_unchanged(h5root): @@ -97,12 +97,10 @@ def test_depends_on_relative_path_unchanged(h5root): snx.create_field(det1, 'depends_on', sc.scalar('transformations/t1')) depends_on = make_group(det1)['depends_on'][()] - assert depends_on == 'transformations/t1' + assert depends_on.value == 'transformations/t1' -def test_depends_on_attr_absolute_path_to_sibling_group_resolved_to_relative_path( - h5root, -): +def test_depends_on_attr_absolute_path_to_sibling_group_preserved(h5root): det1 = snx.create_class(h5root, 'det1', NXtransformations) transformations = snx.create_class(det1, 'transformations', NXtransformations) t1 = snx.create_field(transformations, 't1', sc.scalar(0.1, unit='cm')) @@ -111,7 +109,7 @@ def test_depends_on_attr_absolute_path_to_sibling_group_resolved_to_relative_pat t1.attrs['vector'] = [0, 0, 1] loaded 
= make_group(det1)['transformations/t1'][()] - assert loaded.coords['depends_on'].value == '../../det2/transformations/t2' + assert loaded.depends_on.value == '/det2/transformations/t2' def test_depends_on_attr_relative_path_unchanged(h5root): @@ -123,17 +121,15 @@ def test_depends_on_attr_relative_path_unchanged(h5root): t1.attrs['vector'] = [0, 0, 1] loaded = make_group(det)['transformations/t1'][()] - assert loaded.coords['depends_on'].value == '.' + assert loaded.depends_on.value == '.' t1.attrs['depends_on'] = 't2' loaded = make_group(det)['transformations/t1'][()] - assert loaded.coords['depends_on'].value == 't2' + assert loaded.depends_on.value == 't2' def test_chain_with_single_values_and_different_unit(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) value = sc.scalar(6.5, unit='mm') offset = sc.spatial.translation(value=[1, 2, 3], unit='mm') @@ -157,19 +153,17 @@ def test_chain_with_single_values_and_different_unit(h5root): detector = make_group(h5root['detector_0']) loaded = detector[()] depends_on = loaded['depends_on'] - assert depends_on == 'transformations/t1' + assert depends_on.value == 'transformations/t1' transforms = loaded['transformations'] - assert_identical(transforms['t1'].data, t1) - assert transforms['t1'].coords['depends_on'].value == 't2' - assert_identical(transforms['t2'].data, t2) - assert transforms['t2'].coords['depends_on'].value == '.' + assert_identical(transforms['t1'].build(), t1) + assert transforms['t1'].depends_on.value == 't2' + assert_identical(transforms['t2'].build(), t2) + assert transforms['t2'].depends_on.value == '.' 
def test_Transformation_with_multiple_values(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) log = sc.DataArray( sc.array(dims=['time'], values=[1.1, 2.2], unit='m'), @@ -190,18 +184,15 @@ def test_Transformation_with_multiple_values(h5root): value.attrs['vector'] = vector.value expected = t * offset - expected.coords['depends_on'] = sc.scalar('.') detector = make_group(detector) depends_on = detector['depends_on'][()] - assert depends_on == 'transformations/t1' - assert_identical(detector[depends_on][()], expected) + assert depends_on.value == 'transformations/t1' + assert_identical(detector[depends_on.absolute_path()][()].build(), expected) def test_time_dependent_transform_uses_value_sublog(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) log = sc.DataArray( sc.array(dims=['time'], values=[1.1, 2.2], unit='m'), @@ -229,18 +220,15 @@ def test_time_dependent_transform_uses_value_sublog(h5root): value.attrs['vector'] = vector.value expected = t * offset - expected.coords['depends_on'] = sc.scalar('.') detector = make_group(detector) depends_on = detector['depends_on'][()] - assert depends_on == 'transformations/t1' - assert_identical(detector[depends_on][()], expected) + assert depends_on.value == 'transformations/t1' + assert_identical(detector[depends_on.absolute_path()][()].build(), expected) def test_chain_with_multiple_values(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + 
snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) log = sc.DataArray( sc.array(dims=['time'], values=[1.1, 2.2], unit='m'), @@ -267,21 +255,17 @@ def test_chain_with_multiple_values(h5root): value2.attrs['vector'] = vector.value expected1 = t * offset - expected1.coords['depends_on'] = sc.scalar('t2') expected2 = t - expected2.coords['depends_on'] = sc.scalar('.') detector = make_group(detector)[()] depends_on = detector['depends_on'] - assert depends_on == 'transformations/t1' - assert_identical(detector['transformations']['t1'], expected1) - assert_identical(detector['transformations']['t2'], expected2) + assert depends_on.value == 'transformations/t1' + assert_identical(detector['transformations']['t1'].build(), expected1) + assert_identical(detector['transformations']['t2'].build(), expected2) def test_chain_with_multiple_values_and_different_time_unit(h5root): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) # Making sure to not use nanoseconds since that is used internally and may thus # mask bugs. @@ -312,19 +296,17 @@ def test_chain_with_multiple_values_and_different_time_unit(h5root): value2.attrs['vector'] = vector.value expected1 = t * offset - expected1.coords['depends_on'] = sc.scalar('t2') t2 = t.copy() t2.coords['time'] = t2.coords['time'].to(unit='ms') expected2 = t2 - expected2.coords['depends_on'] = sc.scalar('.') detector = make_group(detector) loaded = detector[...] 
depends_on = loaded['depends_on'] - assert depends_on == 'transformations/t1' - assert_identical(loaded['transformations']['t1'], expected1) - assert_identical(loaded['transformations']['t2'], expected2) + assert depends_on.value == 'transformations/t1' + assert_identical(loaded['transformations']['t1'].build(), expected1) + assert_identical(loaded['transformations']['t2'].build(), expected2) @pytest.mark.filterwarnings( @@ -334,9 +316,7 @@ def test_broken_time_dependent_transformation_returns_datagroup_but_sets_up_depe h5root, ): detector = create_detector(h5root) - snx.create_field( - detector, 'depends_on', sc.scalar('/detector_0/transformations/t1') - ) + snx.create_field(detector, 'depends_on', sc.scalar('transformations/t1')) transformations = snx.create_class(detector, 'transformations', NXtransformations) log = sc.DataArray( sc.array(dims=['time'], values=[1.1, 2.2], unit='m'), @@ -364,11 +344,11 @@ def test_broken_time_dependent_transformation_returns_datagroup_but_sets_up_depe # Due to the way NXtransformations works, vital information is stored in the # attributes. DataGroup does currently not support attributes, so this information # is mostly useless until that is addressed. - t1 = t['t1'] + t1 = t['t1'].value assert isinstance(t1, sc.DataGroup) assert t1.keys() == {'time', 'value'} - assert loaded['depends_on'] == 'transformations/t1' - assert_identical(loaded['transformations']['t1'], t1) + assert loaded['depends_on'].value == 'transformations/t1' + assert_identical(loaded['transformations']['t1'].value, t1) def write_translation( @@ -392,10 +372,11 @@ def test_nxtransformations_group_single_item(h5root): transformations = snx.create_class(h5root, 'transformations', NXtransformations) write_translation(transformations, 't1', value, offset, vector) + transformations['t1'].attrs['depends_on'] = '.' 
loaded = make_group(h5root)['transformations'][()] assert set(loaded.keys()) == {'t1'} - assert sc.identical(loaded['t1'], expected) + assert sc.identical(loaded['t1'].build(), expected) def test_nxtransformations_group_two_independent_items(h5root): @@ -406,6 +387,7 @@ def test_nxtransformations_group_two_independent_items(h5root): vector = sc.vector(value=[0, 1, 1]) t = value * vector write_translation(transformations, 't1', value, offset, vector) + transformations['t1'].attrs['depends_on'] = '.' expected1 = ( sc.spatial.translations(dims=t.dims, values=t.values, unit=t.unit) * offset ) @@ -413,14 +395,15 @@ def test_nxtransformations_group_two_independent_items(h5root): value = value * 0.1 t = value * vector write_translation(transformations, 't2', value, offset, vector) + transformations['t2'].attrs['depends_on'] = '.' expected2 = ( sc.spatial.translations(dims=t.dims, values=t.values, unit=t.unit) * offset ) loaded = make_group(h5root)['transformations'][()] assert set(loaded.keys()) == {'t1', 't2'} - assert sc.identical(loaded['t1'], expected1) - assert sc.identical(loaded['t2'], expected2) + assert sc.identical(loaded['t1'].build(), expected1) + assert sc.identical(loaded['t2'].build(), expected2) def test_nxtransformations_group_single_chain(h5root): @@ -431,6 +414,7 @@ def test_nxtransformations_group_single_chain(h5root): vector = sc.vector(value=[0, 1, 1]) t = value * vector write_translation(transformations, 't1', value, offset, vector) + transformations['t1'].attrs['depends_on'] = '.' 
expected1 = ( sc.spatial.translations(dims=t.dims, values=t.values, unit=t.unit) * offset ) @@ -445,9 +429,9 @@ def test_nxtransformations_group_single_chain(h5root): loaded = make_group(h5root)['transformations'][()] assert set(loaded.keys()) == {'t1', 't2'} - assert_identical(loaded['t1'], expected1) - assert_identical(loaded['t2'].data, expected2) - assert loaded['t2'].coords['depends_on'].value == 't1' + assert_identical(loaded['t1'].build(), expected1) + assert_identical(loaded['t2'].build(), expected2) + assert loaded['t2'].depends_on.value == 't1' def test_slice_transformations(h5root): @@ -468,11 +452,13 @@ def test_slice_transformations(h5root): value1.attrs['offset'] = offset.values value1.attrs['offset_units'] = str(offset.unit) value1.attrs['vector'] = vector.value + value1.attrs['depends_on'] = '.' expected = t * offset assert sc.identical( - make_group(h5root)['transformations']['time', 1:3]['t1'], expected['time', 1:3] + make_group(h5root)['transformations']['time', 1:3]['t1'].build(), + expected['time', 1:3], ) @@ -494,6 +480,7 @@ def test_label_slice_transformations(h5root): value1.attrs['offset'] = offset.values value1.attrs['offset_units'] = str(offset.unit) value1.attrs['vector'] = vector.value + value1.attrs['depends_on'] = '.' 
expected = t * offset @@ -503,7 +490,7 @@ def test_label_slice_transformations(h5root): sc.scalar(22, unit='s').to(unit='ns') : sc.scalar(44, unit='s').to( unit='ns' ), - ]['t1'], + ]['t1'].build(), expected[ 'time', sc.datetime('1970-01-01T00:00:22', unit='ns') : sc.datetime( @@ -742,7 +729,8 @@ def test_compute_positions_warns_if_depends_on_is_dead_link(h5root): detector = create_detector(instrument) snx.create_field(detector, 'depends_on', sc.scalar('transform')) root = make_group(h5root) - loaded = root[()] + with pytest.warns(UserWarning, match='depends_on chain references missing node'): + loaded = root[()] with pytest.warns(UserWarning, match='depends_on chain references missing node'): snx.compute_positions(loaded) @@ -830,11 +818,13 @@ def test_compute_transformation_warns_if_transformation_missing_vector_attr( value1 = snx.create_class(transformations, 't1', snx.NXlog) snx.create_field(value1, 'time', _log.coords['time'] - sc.epoch(unit='s')) snx.create_field(value1, 'value', _log.data) - value1.attrs['depends_on'] = 't2' + value1.attrs['depends_on'] = '.' 
value1.attrs['transformation_type'] = 'rotation' value1.attrs['offset'] = offset1.values value1.attrs['offset_units'] = str(offset1.unit) root = make_group(h5root) - with pytest.warns(UserWarning, match='transformation needs a vector attribute'): + with pytest.warns( + UserWarning, match="Invalid transformation, missing attribute 'vector'" + ): root[()] From 30e81f81957612285da4de49747a57310ea35d54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:31:09 +0000 Subject: [PATCH 20/32] Apply automatic formatting --- src/scippnexus/transformations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 1e8cae67..0eb5bafa 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -165,8 +165,6 @@ def as_nested(dg: sc.DataGroup) -> sc.DataGroup: return out - - def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: if '/' not in path: dg[path] = value From 2a29c8c990152912771b0401862f655266fffaaa Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 14:02:14 +0200 Subject: [PATCH 21/32] grammar --- src/scippnexus/field.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index de1c0da0..395d2280 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -27,9 +27,9 @@ @dataclass class DependsOn: """ - Represents a depends_on field in a NeXus file. + Represents a depends_on reference in a NeXus file. - The parent (the full path within the NeXus file) is stored as the value may be + The parent (the full path within the NeXus file) is stored, as the value may be relative or absolute, so having the path available after loading is essential. 
""" From 08fceef7aaa57532bdb797fac1094d1d89f6341d Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 14:42:08 +0200 Subject: [PATCH 22/32] Cleanup --- src/scippnexus/nxtransformations.py | 14 ++++-- src/scippnexus/transformations.py | 68 ----------------------------- 2 files changed, 10 insertions(+), 72 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 4ddfe3cc..1bb1ab4f 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -9,15 +9,18 @@ import scipp as sc from scipp.scipy import interpolate -from .base import Group, NXobject +from .base import Group, NXobject, base_definitions_dict from .field import DependsOn, Field from .transformations import Transform -# TODO skip loading?! - class NXtransformations(NXobject): - """Group of transformations.""" + """ + Group of transformations. + + Currently all transformations in the group are loaded. This may lead to redundant + loads as transformations are also loaded by following depends_on chains. 
+ """ def _interpolate_transform(transform, xnew): @@ -268,3 +271,6 @@ def _with_positions( value = value.assign_coords({store_position: transform * offset}) out[name] = value return out + + +base_definitions_dict['NXtransformations'] = NXtransformations diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 0eb5bafa..735a08c5 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -37,13 +37,10 @@ from dataclasses import dataclass from typing import Any, Literal -import h5py import scipp as sc from .base import Group, NexusStructureError from .field import DependsOn, Field -from .file import File -from .typing import H5Base class TransformationError(NexusStructureError): @@ -98,37 +95,6 @@ def build(self) -> sc.Variable | sc.DataArray: return t * self.offset -def find_transformations(filename: str) -> list[str]: - transforms: list[str] = [] - - def _collect_transforms(name: str, obj: H5Base) -> None: - if name.endswith('/depends_on') or 'transformation_type' in obj.attrs: - transforms.append(f'/{name}') - - with h5py.File(filename, 'r') as f: - f.visititems(_collect_transforms) - return transforms - - -def load_transformations(filename: str) -> sc.DataGroup: - """ - Load transformations and depends_on fields from a NeXus file. - - Parameters - ---------- - filename: - The path to the NeXus file. - - Returns - ------- - : - A flat DataGroup with the transformations and depends_on fields. - """ - groups = find_transformations(filename) - with File(filename, mode='r') as f: - return sc.DataGroup({group: f[group][()] for group in groups}) - - def apply_to_transformations( dg: sc.DataGroup, func: Callable[[Transform], Transform] ) -> sc.DataGroup: @@ -142,40 +108,6 @@ def apply_nested(node: Any) -> Any: return dg.apply(apply_nested) -def as_nested(dg: sc.DataGroup) -> sc.DataGroup: - """ - Convert a flat DataGroup with paths as keys to a nested DataGroup. 
- - This is useful when loading transformations from a NeXus file, where the - paths are used as keys to represent the structure of the NeXus file. - - Parameters - ---------- - dg: - The flat DataGroup to convert. - - Returns - ------- - : - The nested DataGroup. - """ - out = sc.DataGroup() - for path, value in dg.items(): - _set_recursive(out, path, value) - return out - - -def _set_recursive(dg: sc.DataGroup, path: str, value: Any) -> None: - if '/' not in path: - dg[path] = value - else: - path = path.lstrip('/') - first, remainder = path.split('/', maxsplit=1) - if first not in dg: - dg[first] = sc.DataGroup() - _set_recursive(dg[first], remainder, value) - - def _parse_offset(obj: Field | Group) -> sc.Variable | None: if (offset := obj.attrs.get('offset')) is None: return None From 03b70ac698c9454dda8a4bc8e6c7db5083ea3c7a Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 14:48:25 +0200 Subject: [PATCH 23/32] Remove unused --- src/scippnexus/transformations.py | 16 +----- tests/transformations_test.py | 85 ------------------------------- 2 files changed, 1 insertion(+), 100 deletions(-) delete mode 100644 tests/transformations_test.py diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py index 735a08c5..1597b68f 100644 --- a/src/scippnexus/transformations.py +++ b/src/scippnexus/transformations.py @@ -33,9 +33,8 @@ from __future__ import annotations -from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Literal +from typing import Literal import scipp as sc @@ -95,19 +94,6 @@ def build(self) -> sc.Variable | sc.DataArray: return t * self.offset -def apply_to_transformations( - dg: sc.DataGroup, func: Callable[[Transform], Transform] -) -> sc.DataGroup: - def apply_nested(node: Any) -> Any: - if isinstance(node, sc.DataGroup): - return node.apply(apply_nested) - if isinstance(node, Transform): - return func(node) - return node - - return dg.apply(apply_nested) - - def 
_parse_offset(obj: Field | Group) -> sc.Variable | None: if (offset := obj.attrs.get('offset')) is None: return None diff --git a/tests/transformations_test.py b/tests/transformations_test.py deleted file mode 100644 index 7f443d0e..00000000 --- a/tests/transformations_test.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) -import pytest -import scipp as sc - -import scippnexus as snx -from scippnexus import transformations - -externalfile = pytest.importorskip('externalfile') - - -@pytest.mark.externalfile() -def test_find_transformation_groups_finds_expected_groups() -> None: - filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') - paths = transformations.find_transformations(filename) - assert paths == [ - '/entry/instrument/larmor_detector/depends_on', - '/entry/instrument/larmor_detector/transformations/trans_1', - '/entry/instrument/monitor_1/depends_on', - '/entry/instrument/monitor_1/transformations/trans_3', - '/entry/instrument/monitor_2/depends_on', - '/entry/instrument/monitor_2/transformations/trans_4', - '/entry/instrument/source/depends_on', - '/entry/instrument/source/transformations/trans_2', - ] - - -@pytest.mark.externalfile() -def test_load_transformations_loads_as_flat_datagroup() -> None: - filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') - dg = transformations.load_transformations(filename) - dg = transformations.as_nested(dg) - assert list(dg) == ['entry'] - entry = dg['entry'] - assert list(entry) == ['instrument'] - instrument = entry['instrument'] - assert list(instrument) == ['larmor_detector', 'monitor_1', 'monitor_2', 'source'] - for group in instrument.values(): - assert list(group) == ['depends_on', 'transformations'] - - -@pytest.mark.externalfile() -def test_apply_to_transformations() -> None: - filename = externalfile.get_path('2023/LOKI_60322-2022-03-02_2205_fixed.nxs') - dg = 
transformations.load_transformations(filename) - - def gather_names(t: transformations.Transform) -> transformations.Transform: - applied_to.append(t.name) - return t - - paths = [ - '/entry/instrument/larmor_detector/transformations/trans_1', - '/entry/instrument/monitor_1/transformations/trans_3', - '/entry/instrument/monitor_2/transformations/trans_4', - '/entry/instrument/source/transformations/trans_2', - ] - - applied_to = [] - transformations.apply_to_transformations(dg, gather_names) - assert applied_to == paths - - dg = transformations.as_nested(dg) - applied_to = [] - transformations.apply_to_transformations(dg, gather_names) - assert applied_to == paths - - -@pytest.mark.filterwarnings("ignore::UserWarning") -@pytest.mark.externalfile() -def test_positions_consistent_with_separate_load() -> None: - # The Bifrost instrument has complex transformation chains so this is a good test. - filename = externalfile.get_path('2023/BIFROST_873855_00000015.hdf') - transforms = transformations.load_transformations(filename) - transforms = transformations.as_nested(transforms) - dg = snx.load(filename) - expected = snx.compute_positions( - dg, store_position='position', store_transform='transform' - ) - result = snx.compute_positions( - dg, - store_position='position', - store_transform='transform', - transformations=transforms, - ) - assert sc.identical(result, expected) From 95d6c86aafb0fcccb3ab4240b802085c6b2bbd4c Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 14:53:50 +0200 Subject: [PATCH 24/32] Move to common file --- docs/api-reference/index.md | 14 +--- src/scippnexus/nxtransformations.py | 112 +++++++++++++++++++++++++- src/scippnexus/transformations.py | 117 ---------------------------- 3 files changed, 111 insertions(+), 132 deletions(-) delete mode 100644 src/scippnexus/transformations.py diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index 9b065c2e..ebc86ef7 100644 --- a/docs/api-reference/index.md +++ 
b/docs/api-reference/index.md @@ -114,16 +114,4 @@ create_field create_class load -``` - - -## Submodules - -```{eval-rst} -.. autosummary:: - :toctree: ../generated/modules - :template: module-template.rst - :recursive: - - transformations -``` +``` \ No newline at end of file diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 1bb1ab4f..552f51cd 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -1,17 +1,50 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) # @author Simon Heybrock +""" +Utilities for loading and working with NeXus transformations. + +Transformation chains in NeXus files can be non-local and can thus be challenging to +work with. Additionally, values of transformations can be time-dependent, with each +chain link potentially having a different time-dependent value. In practice the user is +interested in the position and orientation of a component at a specific time or time +range. This may involve evaluating the transformation chain at a specific time, or +applying some heuristic to determine if the changes in the transformation value are +significant or just noise. In combination, the above means that we need to remain +flexible in how we handle transformations, preserving all necessary information from +the source files. Therefore: + +1. :py:class:`Transform` is a dataclass representing a transformation. The raw `value` + dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to + facilitate further processing such as computing the mean or variance. +2. Loading a :py:class:`Group` will follow depends_on chains and place them in a + subgroup 'resolved_transformations'. This is done by :py:func:`maybe_resolve`. +3. :py:func:`compute_positions` computes component positions (and transformations). 
By + making this an explicit separate step, transformations can be applied to the + 'resolved_transformations' subgroup before doing so. We imagine that this can be used + to + + - Evaluate the transformation at a specific time. + - Apply filters to remove noise, to avoid having to deal with very small time + intervals when processing data. + +By keeping the loaded transformations in a simple and modifiable format, we can +furthermore manually update the transformations with information from other sources, +such as streamed NXlog values received from a data acquisition system. +""" + from __future__ import annotations import warnings +from dataclasses import dataclass +from typing import Literal import numpy as np import scipp as sc from scipp.scipy import interpolate -from .base import Group, NXobject, base_definitions_dict +from .base import Group, NexusStructureError, NXobject, base_definitions_dict from .field import DependsOn, Field -from .transformations import Transform class NXtransformations(NXobject): @@ -23,6 +56,81 @@ class NXtransformations(NXobject): """ +class TransformationError(NexusStructureError): + pass + + +@dataclass +class Transform: + name: str + transformation_type: Literal['translation', 'rotation'] + value: sc.Variable | sc.DataArray | sc.DataGroup + vector: sc.Variable + depends_on: DependsOn + offset: sc.Variable | None + + def __post_init__(self): + if self.transformation_type not in ['translation', 'rotation']: + raise TransformationError( + f"{self.transformation_type=} attribute at {self.name}," + " expected 'translation' or 'rotation'." 
+ ) + + @staticmethod + def from_object( + obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup + ) -> Transform: + depends_on = DependsOn(parent=obj.parent.name, value=obj.attrs['depends_on']) + return Transform( + name=obj.name, + transformation_type=obj.attrs.get('transformation_type'), + value=_parse_value(value), + vector=sc.vector(value=obj.attrs['vector']), + depends_on=depends_on, + offset=_parse_offset(obj), + ) + + def build(self) -> sc.Variable | sc.DataArray: + t = self.value * self.vector + v = t if isinstance(t, sc.Variable) else t.data + if self.transformation_type == 'translation': + v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) + elif self.transformation_type == 'rotation': + v = sc.spatial.rotations_from_rotvecs(v) + if isinstance(t, sc.Variable): + t = v + else: + t.data = v + if self.offset is None: + return t + if self.transformation_type == 'translation': + return t * self.offset.to(unit=t.unit, copy=False) + return t * self.offset + + +def _parse_offset(obj: Field | Group) -> sc.Variable | None: + if (offset := obj.attrs.get('offset')) is None: + return None + if (offset_units := obj.attrs.get('offset_units')) is None: + raise TransformationError( + f"Found {offset=} but no corresponding 'offset_units' " + f"attribute at {obj.name}" + ) + return sc.spatial.translation(value=offset, unit=offset_units) + + +def _parse_value( + value: sc.Variable | sc.DataArray | sc.DataGroup, +) -> sc.Variable | sc.DataArray | sc.DataGroup: + if isinstance(value, sc.DataGroup) and ( + isinstance(value.get('value'), sc.DataArray) + ): + # Some NXlog groups are split into value, alarm, and connection_status + # sublogs. We only care about the value. 
+ value = value['value'] + return value + + def _interpolate_transform(transform, xnew): # scipy can't interpolate with a single value if transform.sizes["time"] == 1: diff --git a/src/scippnexus/transformations.py b/src/scippnexus/transformations.py deleted file mode 100644 index 1597b68f..00000000 --- a/src/scippnexus/transformations.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) -# @author Simon Heybrock -""" -Utilities for loading and working with NeXus transformations. - -Transformation chains in NeXus files can be non-local and can thus be challenging to -work with. Additionally, values of transformations can be time-dependent, with each -chain link potentially having a different time-dependent value. In practice the user is -interested in the position and orientation of a component at a specific time or time -range. This may involve evaluating the transformation chain at a specific time, or -applying some heuristic to determine if the changes in the transformation value are -significant or just noise. In combination, the above means that we need to remain -flexible in how we handle transformations, preserving all necessary information from -the source files. This module is therefore structured as follows: - -1. :py:class:`Transform` is a dataclass representing a transformation. The raw `value` - dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to - facilitate further processing such as computing the mean or variance. -2. :py:func:`load_transformations` loads transformations from a NeXus file into a flat - :py:class:`scipp.DataGroup`. It can optionally be followed by - :py:func:`as_nested` to convert the flat structure to a nested one. -3. :py:func:`apply_to_transformations` applies a function to each transformation in a - :py:class:`scipp.DataGroup`. 
We imagine that this can be used to - - Evaluate the transformation at a specific time. - - Apply filters to remove noise, to avoid having to deal with very small time - intervals when processing data. - -By keeping the loaded transformations in a simple and modifiable format, we can -furthermore manually update the transformations with information from other sources, -such as streamed NXlog values received from a data acquisition system. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Literal - -import scipp as sc - -from .base import Group, NexusStructureError -from .field import DependsOn, Field - - -class TransformationError(NexusStructureError): - pass - - -@dataclass -class Transform: - name: str - transformation_type: Literal['translation', 'rotation'] - value: sc.Variable | sc.DataArray | sc.DataGroup - vector: sc.Variable - depends_on: DependsOn - offset: sc.Variable | None - - def __post_init__(self): - if self.transformation_type not in ['translation', 'rotation']: - raise TransformationError( - f"{self.transformation_type=} attribute at {self.name}," - " expected 'translation' or 'rotation'." 
- ) - - @staticmethod - def from_object( - obj: Field | Group, value: sc.Variable | sc.DataArray | sc.DataGroup - ) -> Transform: - depends_on = DependsOn(parent=obj.parent.name, value=obj.attrs['depends_on']) - return Transform( - name=obj.name, - transformation_type=obj.attrs.get('transformation_type'), - value=_parse_value(value), - vector=sc.vector(value=obj.attrs['vector']), - depends_on=depends_on, - offset=_parse_offset(obj), - ) - - def build(self) -> sc.Variable | sc.DataArray: - t = self.value * self.vector - v = t if isinstance(t, sc.Variable) else t.data - if self.transformation_type == 'translation': - v = sc.spatial.translations(dims=v.dims, values=v.values, unit=v.unit) - elif self.transformation_type == 'rotation': - v = sc.spatial.rotations_from_rotvecs(v) - if isinstance(t, sc.Variable): - t = v - else: - t.data = v - if self.offset is None: - return t - if self.transformation_type == 'translation': - return t * self.offset.to(unit=t.unit, copy=False) - return t * self.offset - - -def _parse_offset(obj: Field | Group) -> sc.Variable | None: - if (offset := obj.attrs.get('offset')) is None: - return None - if (offset_units := obj.attrs.get('offset_units')) is None: - raise TransformationError( - f"Found {offset=} but no corresponding 'offset_units' " - f"attribute at {obj.name}" - ) - return sc.spatial.translation(value=offset, unit=offset_units) - - -def _parse_value( - value: sc.Variable | sc.DataArray | sc.DataGroup, -) -> sc.Variable | sc.DataArray | sc.DataGroup: - if isinstance(value, sc.DataGroup) and ( - isinstance(value.get('value'), sc.DataArray) - ): - # Some NXlog groups are split into value, alarm, and connection_status - # sublogs. We only care about the value. 
- value = value['value'] - return value From 048e3ecb000f18b1a4fde3d3e342193e120ad82f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 14:58:26 +0200 Subject: [PATCH 25/32] Docs and cleanup --- docs/api-reference/index.md | 2 +- src/scippnexus/base.py | 6 ++++-- src/scippnexus/nxtransformations.py | 12 ++++++++---- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index ebc86ef7..7baa789d 100644 --- a/docs/api-reference/index.md +++ b/docs/api-reference/index.md @@ -114,4 +114,4 @@ create_field create_class load -``` \ No newline at end of file +``` diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index 88499428..0a5ae56f 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -409,13 +409,15 @@ def isclass(x): # For a time-dependent transformation in NXtransformations, an NXlog may # take the place of the `value` field. In this case, we need to read the # properties of the NXlog group to make the actual transformation. - from .nxtransformations import maybe_resolve, maybe_transformation + from .nxtransformations import maybe_transformation, parse_depends_on_chain if ( isinstance(dg, sc.DataGroup) and (depends_on := dg.get('depends_on')) is not None ): - if (resolved := maybe_resolve(self['depends_on'], depends_on)) is not None: + if ( + resolved := parse_depends_on_chain(self['depends_on'], depends_on) + ) is not None: dg['resolved_transformations'] = resolved return maybe_transformation(self, value=dg) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 552f51cd..622d00ae 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -18,7 +18,8 @@ dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to facilitate further processing such as computing the mean or variance. 2. 
Loading a :py:class:`Group` will follow depends_on chains and place them in a - subgroup 'resolved_transformations'. This is done by :py:func:`maybe_resolve`. + subgroup 'resolved_transformations'. This is done by + :py:func:`parse_depends_on_chain`. 3. :py:func:`compute_positions` computes component positions (and transformations). By making this an explicit separate step, transformations can be applied to the 'resolved_transformations' subgroup before doing so. We imagine that this can be used @@ -62,6 +63,8 @@ class TransformationError(NexusStructureError): @dataclass class Transform: + """In-memory component translation or rotation as described by NXtransformations.""" + name: str transformation_type: Literal['translation', 'rotation'] value: sc.Variable | sc.DataArray | sc.DataGroup @@ -91,6 +94,7 @@ def from_object( ) def build(self) -> sc.Variable | sc.DataArray: + """Convert the raw transform into a rotation or translation matrix.""" t = self.value * self.vector v = t if isinstance(t, sc.Variable) else t.data if self.transformation_type == 'translation': @@ -229,10 +233,10 @@ def maybe_transformation( return value -def maybe_resolve( +def parse_depends_on_chain( obj: Field | Group, depends_on: DependsOn -) -> sc.DataArray | sc.Variable | None: - """Conditionally resolve a depend_on attribute.""" +) -> sc.DataGroup | None: + """Follow a depends_on chain and return the transformations.""" transforms = sc.DataGroup() parent = obj.parent depends_on = depends_on.value From 00ee3ccb1bc0c3fb5b596f1f7f18f9d1a649c08b Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Mon, 30 Sep 2024 15:06:00 +0200 Subject: [PATCH 26/32] Forward the transformations --- src/scippnexus/nxtransformations.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 622d00ae..3c0994b9 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -349,12 
+349,12 @@ def _with_positions( transformations: sc.DataGroup | None = None, ) -> sc.DataGroup: out = sc.DataGroup() - transformations = transformations or dg.get('resolved_transformations', {}) if (depends_on := dg.get('depends_on')) is not None: + registry = transformations or dg.get('resolved_transformations', {}) try: chain = [] while (path := depends_on.absolute_path()) is not None: - chain.append(transformations[path]) + chain.append(registry[path]) depends_on = chain[-1].depends_on transform = combine_transformations([t.build() for t in chain]) except KeyError as e: @@ -369,7 +369,10 @@ def _with_positions( for name, value in dg.items(): if isinstance(value, sc.DataGroup): value = _with_positions( - value, store_position=store_position, store_transform=store_transform + value, + store_position=store_position, + store_transform=store_transform, + transformations=transformations, ) elif ( isinstance(value, sc.DataArray) From 26c716d6fe8cd9d4169b7ad59421d85563d2915f Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 1 Oct 2024 08:46:02 +0200 Subject: [PATCH 27/32] Store chain in DependsOn subclass instead of adding new subgroup --- src/scippnexus/__init__.py | 19 ++++++++++++- src/scippnexus/base.py | 11 ++------ src/scippnexus/nxtransformations.py | 44 +++++++++++++++++++---------- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/scippnexus/__init__.py b/src/scippnexus/__init__.py index 9a0bd33e..5038d604 100644 --- a/src/scippnexus/__init__.py +++ b/src/scippnexus/__init__.py @@ -18,8 +18,25 @@ create_class, create_field, ) -from .field import Attrs, Field +from .field import Attrs, DependsOn, Field from .file import File from ._load import load from .nexus_classes import * from .nxtransformations import compute_positions, zip_pixel_offsets + +__all__ = [ + 'Attrs', + 'DependsOn', + 'Field', + 'File', + 'Group', + 'NexusStructureError', + 'NXobject', + 'base_definitions', + 'compute_positions', + 'create_class', + 'create_field', + 
'zip_pixel_offsets', + 'load', + 'typing', +] diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index 0a5ae56f..b3dd218a 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -411,14 +411,9 @@ def isclass(x): # properties of the NXlog group to make the actual transformation. from .nxtransformations import maybe_transformation, parse_depends_on_chain - if ( - isinstance(dg, sc.DataGroup) - and (depends_on := dg.get('depends_on')) is not None - ): - if ( - resolved := parse_depends_on_chain(self['depends_on'], depends_on) - ) is not None: - dg['resolved_transformations'] = resolved + if isinstance(dg, sc.DataGroup) and 'depends_on' in dg: + if (chain := parse_depends_on_chain(self, dg['depends_on'])) is not None: + dg['depends_on'] = chain return maybe_transformation(self, value=dg) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 3c0994b9..76d1f2dc 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -17,13 +17,12 @@ 1. :py:class:`Transform` is a dataclass representing a transformation. The raw `value` dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to facilitate further processing such as computing the mean or variance. -2. Loading a :py:class:`Group` will follow depends_on chains and place them in a - subgroup 'resolved_transformations'. This is done by - :py:func:`parse_depends_on_chain`. +2. Loading a :py:class:`Group` will follow depends_on chains and store them as an + attribute of thr depends_on field. This is done by :py:func:`parse_depends_on_chain`. 3. :py:func:`compute_positions` computes component positions (and transformations). By making this an explicit separate step, transformations can be applied to the - 'resolved_transformations' subgroup before doing so. We imagine that this can be used - to + transformations stored by thr depends_on field before doing so. 
We imagine that this + can be used to - Evaluate the transformation at a specific time. - Apply filters to remove noise, to avoid having to deal with very small time @@ -37,7 +36,7 @@ from __future__ import annotations import warnings -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal import numpy as np @@ -233,25 +232,38 @@ def maybe_transformation( return value +@dataclass +class TransformationChain(DependsOn): + transformations: sc.DataGroup = field(default_factory=sc.DataGroup) + + def compute(self) -> sc.Variable | sc.DataArray: + dg = compute_positions( + sc.DataGroup(depends_on=self), + store_position='position', + store_transform='transform', + transformations=self.transformations, + ) + return dg['transform'] + + def parse_depends_on_chain( - obj: Field | Group, depends_on: DependsOn -) -> sc.DataGroup | None: + parent: Field | Group, depends_on: DependsOn +) -> TransformationChain | None: """Follow a depends_on chain and return the transformations.""" - transforms = sc.DataGroup() - parent = obj.parent + chain = TransformationChain(depends_on.parent, depends_on.value) depends_on = depends_on.value try: while depends_on != '.': transform = parent[depends_on] parent = transform.parent depends_on = transform.attrs['depends_on'] - transforms[transform.name] = transform[()] + chain.transformations[transform.name] = transform[()] except KeyError as e: warnings.warn( UserWarning(f'depends_on chain references missing node {e}'), stacklevel=2 ) return None - return transforms + return chain def compute_positions( @@ -295,8 +307,7 @@ def compute_positions( If not None, store the resolved transformation chain in this field. transformations: Optional data group containing transformation chains. If not provided, the - transformations are looked up in the 'resolved_transformations' subgroups of the - input data group. + transformations are looked up in the chains stored within the depends_on field. 
Returns ------- @@ -350,7 +361,10 @@ def _with_positions( ) -> sc.DataGroup: out = sc.DataGroup() if (depends_on := dg.get('depends_on')) is not None: - registry = transformations or dg.get('resolved_transformations', {}) + if isinstance(depends_on, TransformationChain): + registry = transformations or depends_on.transformations + else: + registry = transformations or sc.DataGroup() try: chain = [] while (path := depends_on.absolute_path()) is not None: From 020435ded9bce8d3f9cd6c6c3a47c32dcafa569e Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 1 Oct 2024 08:47:46 +0200 Subject: [PATCH 28/32] Expose TransformationChain --- src/scippnexus/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/scippnexus/__init__.py b/src/scippnexus/__init__.py index 5038d604..f0ea5fa1 100644 --- a/src/scippnexus/__init__.py +++ b/src/scippnexus/__init__.py @@ -22,7 +22,7 @@ from .file import File from ._load import load from .nexus_classes import * -from .nxtransformations import compute_positions, zip_pixel_offsets +from .nxtransformations import compute_positions, zip_pixel_offsets, TransformationChain __all__ = [ 'Attrs', @@ -30,13 +30,14 @@ 'Field', 'File', 'Group', - 'NexusStructureError', 'NXobject', + 'NexusStructureError', + 'TransformationChain', 'base_definitions', 'compute_positions', 'create_class', 'create_field', - 'zip_pixel_offsets', 'load', 'typing', + 'zip_pixel_offsets', ] From 6e94e00050ed1f30c081c0dd2ce3b01fa330ddac Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Tue, 1 Oct 2024 08:57:51 +0200 Subject: [PATCH 29/32] Refactor --- src/scippnexus/nxtransformations.py | 46 +++++++++++++---------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 76d1f2dc..46861ae8 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -36,7 +36,7 @@ from __future__ import annotations import warnings 
-from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Literal import numpy as np @@ -237,13 +237,20 @@ class TransformationChain(DependsOn): transformations: sc.DataGroup = field(default_factory=sc.DataGroup) def compute(self) -> sc.Variable | sc.DataArray: - dg = compute_positions( - sc.DataGroup(depends_on=self), - store_position='position', - store_transform='transform', - transformations=self.transformations, - ) - return dg['transform'] + depends_on = self + try: + chain = [] + while (path := depends_on.absolute_path()) is not None: + chain.append(self.transformations[path]) + depends_on = chain[-1].depends_on + transform = combine_transformations([t.build() for t in chain]) + except KeyError as e: + warnings.warn( + UserWarning(f'depends_on chain references missing node:\n{e}'), + stacklevel=2, + ) + else: + return transform def parse_depends_on_chain( @@ -360,23 +367,12 @@ def _with_positions( transformations: sc.DataGroup | None = None, ) -> sc.DataGroup: out = sc.DataGroup() - if (depends_on := dg.get('depends_on')) is not None: - if isinstance(depends_on, TransformationChain): - registry = transformations or depends_on.transformations - else: - registry = transformations or sc.DataGroup() - try: - chain = [] - while (path := depends_on.absolute_path()) is not None: - chain.append(registry[path]) - depends_on = chain[-1].depends_on - transform = combine_transformations([t.build() for t in chain]) - except KeyError as e: - warnings.warn( - UserWarning(f'depends_on chain references missing node:\n{e}'), - stacklevel=2, - ) - else: + if (chain := dg.get('depends_on')) is not None: + if not isinstance(chain, TransformationChain): + chain = TransformationChain(chain.parent, chain.value) + if transformations is not None: + chain = replace(chain, transformations=transformations) + if (transform := chain.compute()) is not None: out[store_position] = transform * sc.vector([0, 0, 0], unit='m') if 
store_transform is not None:
                 out[store_transform] = transform

From 5d492b77e7b26fb6d6ce60474345885801509d75 Mon Sep 17 00:00:00 2001
From: Simon Heybrock
Date: Tue, 1 Oct 2024 09:09:25 +0200
Subject: [PATCH 30/32] Docstring

---
 src/scippnexus/nxtransformations.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py
index 46861ae8..5a5b53d6 100644
--- a/src/scippnexus/nxtransformations.py
+++ b/src/scippnexus/nxtransformations.py
@@ -234,6 +234,14 @@ def maybe_transformation(
 
 @dataclass
 class TransformationChain(DependsOn):
+    """
+    Represents a chain of transformations referenced by a depends_on field.
+
+    Loading a group with a depends_on field will try to follow the chain and store the
+    transformations as an additional attribute of the in-memory representation of the
+    depends_on field.
+    """
+
     transformations: sc.DataGroup = field(default_factory=sc.DataGroup)
 
     def compute(self) -> sc.Variable | sc.DataArray:

From b990c9a65c9b5785c1c7dad29c0ae6de4e33a288 Mon Sep 17 00:00:00 2001
From: Simon Heybrock <12912489+SimonHeybrock@users.noreply.github.com>
Date: Mon, 7 Oct 2024 06:11:39 +0200
Subject: [PATCH 31/32] Update src/scippnexus/nxtransformations.py

Co-authored-by: Neil Vaytet <39047984+nvaytet@users.noreply.github.com>
---
 src/scippnexus/nxtransformations.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py
index 5a5b53d6..52d0ac78 100644
--- a/src/scippnexus/nxtransformations.py
+++ b/src/scippnexus/nxtransformations.py
@@ -21,7 +21,7 @@
    attribute of thr depends_on field. This is done by :py:func:`parse_depends_on_chain`.
 3. :py:func:`compute_positions` computes component positions (and transformations). By
    making this an explicit separate step, transformations can be applied to the
-   transformations stored by thr depends_on field before doing so. 
We imagine that this + transformations stored by the depends_on field before doing so. We imagine that this can be used to - Evaluate the transformation at a specific time. From 54b2b688e1e8e102a1d2b8d43b14859c353ca847 Mon Sep 17 00:00:00 2001 From: Simon Heybrock <12912489+SimonHeybrock@users.noreply.github.com> Date: Mon, 7 Oct 2024 06:15:30 +0200 Subject: [PATCH 32/32] Update src/scippnexus/nxtransformations.py Co-authored-by: Neil Vaytet <39047984+nvaytet@users.noreply.github.com> --- src/scippnexus/nxtransformations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 52d0ac78..5a4c9e23 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -18,7 +18,7 @@ dataset is preserved (instead of directly converting to, e.g., a rotation matrix) to facilitate further processing such as computing the mean or variance. 2. Loading a :py:class:`Group` will follow depends_on chains and store them as an - attribute of thr depends_on field. This is done by :py:func:`parse_depends_on_chain`. + attribute of the depends_on field. This is done by :py:func:`parse_depends_on_chain`. 3. :py:func:`compute_positions` computes component positions (and transformations). By making this an explicit separate step, transformations can be applied to the transformations stored by the depends_on field before doing so. We imagine that this