Skip to content

Commit

Permalink
Add type normalization for copy_data and add tensorstore/zarr array t…
Browse files Browse the repository at this point in the history
…ests (#33)

Co-authored-by: Juan Nunez-Iglesias <[email protected]>
  • Loading branch information
Kaltzisp and jni authored Jul 15, 2022
1 parent 692b090 commit cd52d04
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/zarpaint/_copy_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from magicgui import magic_factory
from ._normalize import normalize_dtype


@magic_factory
Expand All @@ -14,5 +15,5 @@ def copy_data(
ndim_dst = dst_data.ndim
slice_ = napari_viewer.dims.current_step
slicing = slice_[:ndim_dst - ndim_src]
dst_data[slicing] = src_data
dst_data[slicing] = src_data.astype(normalize_dtype(target_layer.dtype))
target_layer.refresh()
80 changes: 80 additions & 0 deletions src/zarpaint/_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np

_np_uints = {
8: np.uint8,
16: np.uint16,
32: np.uint32,
64: np.uint64,
}

_np_ints = {
8: np.int8,
16: np.int16,
32: np.int32,
64: np.int64,
}

_np_floats = {
16: np.float16,
32: np.float32,
64: np.float64,
}

_np_complex = {
64: np.complex64,
128: np.complex128,
}

_np_kinds = {
'uint': _np_uints,
'int': _np_ints,
'float': _np_floats,
'complex': _np_complex,
}

def _normalize_str_by_bit_depth(dtype_str, kind):
if not any(str.isdigit(c) for c in dtype_str): # Python 'int' or 'float'
return np.dtype(kind).type
bit_dict = _np_kinds[kind]
if '128' in dtype_str:
return bit_dict[128]
if '8' in dtype_str:
return bit_dict[8]
if '16' in dtype_str:
return bit_dict[16]
if '32' in dtype_str:
return bit_dict[32]
if '64' in dtype_str:
return bit_dict[64]

def normalize_dtype(dtype_spec):
"""Return a proper NumPy type given ~any duck array dtype.
Parameters
----------
dtype_spec : numpy dtype, numpy type, torch dtype, tensorstore dtype, etc
A type that can be interpreted as a NumPy numeric data type, e.g.
'uint32', np.uint8, torch.float32, etc.
Returns
-------
dtype : numpy.dtype
The corresponding dtype.
Notes
-----
half-precision floats are not supported.
"""
dtype_str = str(dtype_spec)
if 'uint' in dtype_str:
return _normalize_str_by_bit_depth(dtype_str, 'uint')
if 'int' in dtype_str:
return _normalize_str_by_bit_depth(dtype_str, 'int')
if 'float' in dtype_str:
return _normalize_str_by_bit_depth(dtype_str, 'float')
if 'complex' in dtype_str:
return _normalize_str_by_bit_depth(dtype_str, 'complex')
if 'bool' in dtype_str:
return np.bool_
# If we don't find one of the named dtypes, return the dtype_spec
# unchanged. This allows NumPy big endian types to work. See
# https://github.com/napari/napari/issues/3421
else:
return dtype_spec
36 changes: 34 additions & 2 deletions src/zarpaint/_tests/test_copy_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from napari.layers import Labels
from zarpaint import copy_data
import numpy as np

import tensorstore as ts
from zarpaint import open_tensorstore
import zarr

def test_copy_data(make_napari_viewer):
viewer = make_napari_viewer()
Expand All @@ -11,4 +13,34 @@ def test_copy_data(make_napari_viewer):
widget = copy_data()
widget(viewer, labels_layer1, labels_layer2)
np.testing.assert_array_equal(labels_layer2.data[0], 0)
np.testing.assert_array_equal(labels_layer2.data[1], labels_layer1.data)
np.testing.assert_array_equal(labels_layer2.data[1], labels_layer1.data)


def test_copy_data_tensorstore(make_napari_viewer, tmp_path):
viewer = make_napari_viewer()
labels_layer1 = viewer.add_labels(np.random.randint(0, 2**23, size=(10, 20, 30)))
array2 = open_tensorstore(tmp_path/"example.zarr", shape=(2, 10, 20, 30), chunks=(1, 1, 20, 30))
labels_layer2 = viewer.add_labels(array2)
viewer.dims.set_point(axis=0, value=1)
widget = copy_data()
widget(viewer, labels_layer1, labels_layer2)
np.testing.assert_array_equal(labels_layer2.data[0], 0)
np.testing.assert_array_equal(labels_layer2.data[1], labels_layer1.data)


def test_copy_data_zarr(make_napari_viewer, tmp_path):
viewer = make_napari_viewer()
labels_layer1 = viewer.add_labels(np.random.randint(0, 2**23, size=(10, 20, 30)))
array2 = zarr.open(
str(tmp_path/"example.zarr"),
mode='w',
shape=(2, 10, 20, 30),
dtype=np.uint32,
chunks=(1, 1, 20, 30),
)
labels_layer2 = viewer.add_labels(array2)
viewer.dims.set_point(axis=0, value=1)
widget = copy_data()
widget(viewer, labels_layer1, labels_layer2)
np.testing.assert_array_equal(labels_layer2.data[0], 0)
np.testing.assert_array_equal(labels_layer2.data[1], labels_layer1.data)

0 comments on commit cd52d04

Please sign in to comment.