Skip to content

Commit

Permalink
Apply changes done in h5::array_interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 authored and Wentzell committed Apr 24, 2024
1 parent 77b6ce2 commit d441083
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions c++/nda/h5.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ namespace nda {

static constexpr bool is_complex = is_complex_v<typename A::value_type>;

auto [L_tot, strides_h5] = h5::array_interface::get_L_tot_and_strides_h5(a.indexmap().strides().data(), A::rank, a.size());
auto [parent_shape, strides_h5] = h5::array_interface::get_parent_shape_and_h5_strides(a.indexmap().strides().data(), A::rank, a.size());

h5::array_interface::h5_array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)a.data(), A::rank, is_complex};
h5::array_interface::array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)a.data(), A::rank, is_complex};
for (int u = 0; u < A::rank; ++u) { // size of lhs may be size of the rhs vector + 1 if complex. Can not simply use =
v.slab.count[u] = a.shape()[u];
v.slab.stride[u] = strides_h5[u];
v.L_tot[u] = L_tot[u];
v.slab.count[u] = a.shape()[u];
v.slab.stride[u] = strides_h5[u];
v.parent_shape[u] = parent_shape[u];
}
h5::array_interface::write(g, name, v, compress);

Expand Down Expand Up @@ -230,20 +230,20 @@ namespace nda {
static constexpr bool is_complex = is_complex_v<typename A::value_type>;

// Dataset must already exist for the sliced h5_write
auto lt = h5::array_interface::get_h5_lengths_type(g, name);
if (is_complex != lt.has_complex_attribute)
auto ds_info = h5::array_interface::get_dataset_info(g, name);
if (is_complex != ds_info.has_complex_attribute)
NDA_RUNTIME_ERROR << "Error in sliced h5_write. Existing dataset and array must both be either complex or real";
auto const [sl, sh] = hyperslab_and_shape_from_slice<A::rank>(slice, lt.lengths, is_complex);
auto const [sl, sh] = hyperslab_and_shape_from_slice<A::rank>(slice, ds_info.lengths, is_complex);
if (sh != a.shape())
NDA_RUNTIME_ERROR << "Error in sliced h5_write. Shape of slice and Array shape incompatible" << "\n shape of slice : " << sh
<< "\n array : " << a.shape();

auto rank_in_file = lt.rank() - is_complex;
h5::array_interface::h5_array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)(a.data()), rank_in_file, is_complex};
v.slab.count = sl.count;
v.L_tot = sl.count;
auto rank_in_file = ds_info.rank() - is_complex;
h5::array_interface::array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)(a.data()), rank_in_file, is_complex};
v.slab.count = sl.count;
v.parent_shape = sl.count;

h5::array_interface::write_slice(g, name, v, lt, sl);
h5::array_interface::write_slice(g, name, v, sl);
}

template <MemoryArray A, typename... IRs>
Expand Down Expand Up @@ -276,11 +276,11 @@ namespace nda {

static constexpr bool is_complex = is_complex_v<typename A::value_type>;

auto lt = h5::array_interface::get_h5_lengths_type(g, name);
auto ds_info = h5::array_interface::get_dataset_info(g, name);

// Allow to read non-complex data into array<complex>
if constexpr (is_complex) {
if (!lt.has_complex_attribute) {
if (!ds_info.has_complex_attribute) {
array<double, A::rank> tmp;
h5_read(g, name, tmp);
a = tmp;
Expand All @@ -292,11 +292,11 @@ namespace nda {
auto slice_slab = h5::array_interface::hyperslab{};

if constexpr (slicing) {
auto const [sl, sh] = hyperslab_and_shape_from_slice<A::rank>(slice, lt.lengths, is_complex);
auto const [sl, sh] = hyperslab_and_shape_from_slice<A::rank>(slice, ds_info.lengths, is_complex);
slice_slab = sl;
shape = sh;
} else {
for (int u = 0; u < A::rank; ++u) shape[u] = lt.lengths[u]; // NB : correct for complex
for (int u = 0; u < A::rank; ++u) shape[u] = ds_info.lengths[u]; // NB : correct for complex
}

if constexpr (is_regular_v<A>) {
Expand All @@ -307,20 +307,20 @@ namespace nda {
<< "\n in view : " << a.shape();
}

auto rank_in_file = lt.rank() - is_complex;
auto rank_in_file = ds_info.rank() - is_complex;
if (!slicing && rank_in_file != A::rank)
NDA_RUNTIME_ERROR << "Error in h5_read: Rank mismatch. Array has rank " << A::rank << " while Dataset has rank " << rank_in_file;
h5::array_interface::h5_array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)(a.data()), rank_in_file, is_complex};
h5::array_interface::array_view v{h5::hdf5_type<get_value_t<A>>(), (void *)(a.data()), rank_in_file, is_complex};
if constexpr (slicing) {
v.slab.count = slice_slab.count;
v.L_tot = slice_slab.count;
v.slab.count = slice_slab.count;
v.parent_shape = slice_slab.count;
} else {
for (int u = 0; u < A::rank; ++u) {
v.slab.count[u] = shape[u];
v.L_tot[u] = shape[u];
v.slab.count[u] = shape[u];
v.parent_shape[u] = shape[u];
}
}
h5::array_interface::read(g, name, v, lt, slice_slab);
h5::array_interface::read(g, name, v, slice_slab);

} else { // generic unknown type to hdf5

Expand Down

0 comments on commit d441083

Please sign in to comment.