Skip to content

Commit

Permalink
Merge pull request asdf-format#1753 from braingram/all_your_base
Browse files Browse the repository at this point in the history
add option to control base array saving
  • Loading branch information
braingram authored May 10, 2024
2 parents 097dffa + 4d369dd commit 3693386
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 12 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
resolve to public classes (even if the class is implemented
in a private module). [#1654]

- Add options to control saving the base array when saving array views
controlled via ``AsdfConfig.default_array_save_base``,
``AsdfFile.set_array_save_base`` and
``SerializationContext.set_array_save_base`` [#1753]


3.2.0 (2024-04-05)
------------------
Expand Down
35 changes: 35 additions & 0 deletions asdf/_asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,41 @@ def get_array_compression_kwargs(self, arr):
""" """
return self._blocks._get_array_compression_kwargs(arr)

def set_array_save_base(self, arr, save_base):
"""
Set the ``save_base`` option for ``arr``. When ``arr`` is
written to a file, if ``save_base`` is ``True`` the base array
for ``arr`` will be saved.
Note that similar to other array options this setting is linked
to the base array if ``arr`` is a view.
Parameters
----------
arr : numpy.ndarray
save_base : bool or None
if ``None`` the ``default_array_save_base`` value from asdf
config will be used
"""
self._blocks._set_array_save_base(arr, save_base)

def get_array_save_base(self, arr):
"""
Returns the ``save_base`` option for ``arr``. When ``arr`` is
written to a file, if ``save_base`` is ``True`` the base array
for ``arr`` will be saved.
Parameters
----------
arr : numpy.ndarray
Returns
-------
save_base : bool or None
"""
return self._blocks._get_array_save_base(arr)

@classmethod
def _parse_header_line(cls, line):
"""
Expand Down
8 changes: 8 additions & 0 deletions asdf/_block/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,14 @@ def _get_array_compression_kwargs(self, arr):
def get_output_compressions(self):
return self.options.get_output_compressions()

def _set_array_save_base(self, data, save_base):
options = self.options.get_options(data)
options.save_base = save_base
self.options.set_options(data, options)

def _get_array_save_base(self, data):
return self.options.get_options(data).save_base

@contextlib.contextmanager
def options_context(self):
"""
Expand Down
17 changes: 16 additions & 1 deletion asdf/_block/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ class Options:
Storage and compression options useful when reading or writing ASDF blocks.
"""

def __init__(self, storage_type=None, compression_type=None, compression_kwargs=None):
def __init__(self, storage_type=None, compression_type=None, compression_kwargs=None, save_base=None):
if storage_type is None:
storage_type = get_config().all_array_storage or "internal"
if save_base is None:
save_base = get_config().default_array_save_base

self._storage_type = None
self._compression = None
self._compression_kwargs = None
Expand All @@ -18,6 +21,7 @@ def __init__(self, storage_type=None, compression_type=None, compression_kwargs=
self.compression_kwargs = compression_kwargs
self.compression = compression_type
self.storage_type = storage_type
self.save_base = save_base

@property
def storage_type(self):
Expand Down Expand Up @@ -61,5 +65,16 @@ def compression_kwargs(self, kwargs):
kwargs = {}
self._compression_kwargs = kwargs

@property
def save_base(self):
return self._save_base

@save_base.setter
def save_base(self, save_base):
if not (isinstance(save_base, bool) or save_base is None):
msg = "save_base must be a bool or None"
raise ValueError(msg)
self._save_base = save_base

def __copy__(self):
return type(self)(self._storage_type, self._compression, self._compression_kwargs)
28 changes: 17 additions & 11 deletions asdf/_core/_converters/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,39 @@ def to_yaml_tree(self, obj, tag, ctx):
result["strides"] = data._strides
return result

# sort out block writing options
if isinstance(obj, NDArrayType) and isinstance(obj._source, str):
# this is an external block, if we have no other settings, keep it as external
options = ctx._blocks.options.lookup_by_object(data)
if options is None:
options = Options("external")
else:
options = ctx._blocks.options.get_options(data)

# The ndarray-1.0.0 schema does not permit 0 valued strides.
# Perhaps we'll want to allow this someday, to efficiently
# represent an array of all the same value.
if any(stride == 0 for stride in data.strides):
data = np.ascontiguousarray(data)

# Use the base array if that option is set or if the option
# is unset and the AsdfConfig default is set
cfg = config.get_config()
if options.save_base or (options.save_base is None and cfg.default_array_save_base):
base = util.get_array_base(data)
else:
base = data

# The view computations that follow assume that the base array
# is contiguous. If not, we need to make a copy to avoid
# writing a nonsense view.
base = util.get_array_base(data)
if not base.flags.forc:
data = np.ascontiguousarray(data)
base = util.get_array_base(data)

shape = data.shape

# sort out block writing options
if isinstance(obj, NDArrayType) and isinstance(obj._source, str):
# this is an external block, if we have no other settings, keep it as external
options = ctx._blocks.options.lookup_by_object(data)
if options is None:
options = Options("external")
else:
options = ctx._blocks.options.get_options(data)

# possibly override options based on config settings
cfg = config.get_config()
if cfg.all_array_storage is not None:
options.storage_type = cfg.all_array_storage
if cfg.all_array_compression != "input":
Expand Down
21 changes: 21 additions & 0 deletions asdf/_tests/test_array_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,24 @@ def test_open_memmap_from_closed_file(tmp_path):

with pytest.raises(OSError, match=msg):
base2[:]


@pytest.mark.parametrize("default_array_save_base", [True, False])
@pytest.mark.parametrize("save_base", [True, False, None])
def test_views_save_base(tmp_path, default_array_save_base, save_base):
fn = tmp_path / "test.asdf"
arr = np.zeros(100, dtype="uint8")
tree = {"v": arr[:10]}
with asdf.config_context() as cfg:
cfg.default_array_save_base = default_array_save_base
af = asdf.AsdfFile(tree)
if save_base is not None:
af.set_array_save_base(af["v"], save_base)
af.write_to(fn)

with asdf.open(fn, copy_arrays=True) as af:
base = af["v"].base
if save_base or (save_base is None and default_array_save_base):
assert len(base) == 100
else:
assert len(base) == 10
14 changes: 14 additions & 0 deletions asdf/_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,17 @@ def test_config_repr():
assert "io_block_size: 9999" in repr(config)
assert "legacy_fill_schema_defaults: False" in repr(config)
assert "array_inline_threshold: 14" in repr(config)


@pytest.mark.parametrize("value", [True, False])
def test_get_set_default_array_save_base(value):
with asdf.config_context() as config:
config.default_array_save_base = value
assert config.default_array_save_base == value


@pytest.mark.parametrize("value", [1, None])
def test_invalid_set_default_array_save_base(value):
with asdf.config_context() as config:
with pytest.raises(ValueError, match="default_array_save_base must be a bool"):
config.default_array_save_base = value
35 changes: 35 additions & 0 deletions asdf/_tests/test_serialization_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,38 @@ def test_get_set_array_compression(block_access):
assert af.get_array_compression_kwargs(arr) == kwargs
assert context.get_array_compression(arr) == compression
assert context.get_array_compression_kwargs(arr) == kwargs


def test_get_set_array_save_base():
af = asdf.AsdfFile()
context = af._create_serialization_context()
arr = np.zeros(3)
cfg = asdf.get_config()
save_base = cfg.default_array_save_base
assert af.get_array_save_base(arr) == save_base
assert context.get_array_save_base(arr) == save_base

save_base = not save_base
context.set_array_save_base(arr, save_base)
assert af.get_array_save_base(arr) == save_base
assert context.get_array_save_base(arr) == save_base

save_base = not save_base
af.set_array_save_base(arr, save_base)
assert af.get_array_save_base(arr) == save_base
assert context.get_array_save_base(arr) == save_base

af.set_array_save_base(arr, None)
assert af.get_array_save_base(arr) is None
assert context.get_array_save_base(arr) is None


@pytest.mark.parametrize("value", [1, "true"])
def test_invalid_set_array_save_base(value):
af = asdf.AsdfFile()
context = af._create_serialization_context()
arr = np.zeros(3)
with pytest.raises(ValueError, match="save_base must be a bool or None"):
af.set_array_save_base(arr, value)
with pytest.raises(ValueError, match="save_base must be a bool or None"):
context.set_array_save_base(arr, value)
19 changes: 19 additions & 0 deletions asdf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DEFAULT_ALL_ARRAY_STORAGE = None
DEFAULT_ALL_ARRAY_COMPRESSION = "input"
DEFAULT_ALL_ARRAY_COMPRESSION_KWARGS = None
DEFAULT_DEFAULT_ARRAY_SAVE_BASE = True
DEFAULT_CONVERT_UNKNOWN_NDARRAY_SUBCLASSES = True


Expand All @@ -46,6 +47,7 @@ def __init__(self):
self._all_array_storage = DEFAULT_ALL_ARRAY_STORAGE
self._all_array_compression = DEFAULT_ALL_ARRAY_COMPRESSION
self._all_array_compression_kwargs = DEFAULT_ALL_ARRAY_COMPRESSION_KWARGS
self._default_array_save_base = DEFAULT_DEFAULT_ARRAY_SAVE_BASE
self._convert_unknown_ndarray_subclasses = DEFAULT_CONVERT_UNKNOWN_NDARRAY_SUBCLASSES

self._lock = threading.RLock()
Expand Down Expand Up @@ -391,6 +393,22 @@ def all_array_compression_kwargs(self, value):
raise ValueError(msg)
self._all_array_compression_kwargs = value

@property
def default_array_save_base(self):
"""
Option to control if when saving arrays the base array should be
saved (so views of the same array will refer to offsets/strides of the
same block).
"""
return self._default_array_save_base

@default_array_save_base.setter
def default_array_save_base(self, value):
if not isinstance(value, bool):
msg = "default_array_save_base must be a bool"
raise ValueError(msg)
self._default_array_save_base = value

@property
def validate_on_read(self):
"""
Expand Down Expand Up @@ -447,6 +465,7 @@ def __repr__(self):
f" all_array_storage: {self.all_array_storage}\n"
f" all_array_compression: {self.all_array_compression}\n"
f" all_array_compression_kwargs: {self.all_array_compression_kwargs}\n"
f" default_array_save_base: {self.default_array_save_base}\n"
f" convert_unknown_ndarray_subclasses: {self.convert_unknown_ndarray_subclasses}\n"
f" default_version: {self.default_version}\n"
f" io_block_size: {self.io_block_size}\n"
Expand Down
35 changes: 35 additions & 0 deletions asdf/extension/_serialization_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,41 @@ def get_array_compression_kwargs(self, arr):
""" """
return self._blocks._get_array_compression_kwargs(arr)

def set_array_save_base(self, arr, save_base):
"""
Set the ``save_base`` option for ``arr``. When ``arr`` is
written to a file, if ``save_base`` is ``True`` the base array
for ``arr`` will be saved.
Note that similar to other array options this setting is linked
to the base array if ``arr`` is a view.
Parameters
----------
arr : numpy.ndarray
save_base : bool or None
if ``None`` the ``default_array_save_base`` value from asdf
config will be used
"""
self._blocks._set_array_save_base(arr, save_base)

def get_array_save_base(self, arr):
"""
Returns the ``save_base`` option for ``arr``. When ``arr`` is
written to a file, if ``save_base`` is ``True`` the base array
for ``arr`` will be saved.
Parameters
----------
arr : numpy.ndarray
Returns
-------
save_base : bool
"""
return self._blocks._get_array_save_base(arr)


class ReadBlocksContext(SerializationContext):
"""
Expand Down
6 changes: 6 additions & 0 deletions docs/asdf/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ data being saved.
ff = AsdfFile(tree)
ff.write_to("test.asdf")

For circumstances where this is undesirable (such as saving
a small view of a large array) this can be disabled by setting
`asdf.config.AsdfConfig.default_array_save_base` (to set the default behavior)
or `asdf.AsdfFile.set_array_save_base` to control the behavior for
a specific array.

.. asdf:: test.asdf

Saving inline arrays
Expand Down
13 changes: 13 additions & 0 deletions docs/asdf/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ the currently active config:
all_array_storage: None
all_array_compression: input
all_array_compression_kwargs: None
default_array_save_base: True
convert_unknown_ndarray_subclasses: True
default_version: 1.5.0
io_block_size: -1
Expand All @@ -63,6 +64,7 @@ This allows for short-lived configuration changes that do not impact other code:
all_array_storage: None
all_array_compression: input
all_array_compression_kwargs: None
default_array_save_base: True
convert_unknown_ndarray_subclasses: True
default_version: 1.5.0
io_block_size: -1
Expand All @@ -75,6 +77,7 @@ This allows for short-lived configuration changes that do not impact other code:
all_array_storage: None
all_array_compression: input
all_array_compression_kwargs: None
default_array_save_base: True
convert_unknown_ndarray_subclasses: True
default_version: 1.5.0
io_block_size: -1
Expand Down Expand Up @@ -137,6 +140,16 @@ can be set for each array. See ``AsdfFile.set_array_compression`` for more detai

Defaults to ``None``.

.. _default_array_save_base:

default_array_save_base
-----------------------

Controls the default behavior asdf will follow when saving an array view.
If ``True`` (the default) the base array for the view will be saved in an ASDF
binary block. If ``False`` the data corresponding to the view will be saved in
an ASDF binary block.

.. _convert_unknown_ndarray_subclasses:

convert_unknown_ndarray_subclasses
Expand Down

0 comments on commit 3693386

Please sign in to comment.