Skip to content

Strided slices #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 238 additions & 0 deletions mdio/dataset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,244 @@ TEST(Dataset, isel) {
<< "Inline range should end at 5";
}

TEST(Dataset, iselWithStride) {
// Tests the integrity of data that is written with a strided slice.
std::string iselPath = "zarrs/acceptance";
{ // Scoping the dataset creation to ensure the variables are cleaned up
// before the testing.
auto json_vars = GetToyExample();
auto dataset = mdio::Dataset::from_json(json_vars, iselPath,
mdio::constants::kCreateClean);
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
auto ds = dataset.value();

mdio::RangeDescriptor<mdio::Index> desc1 = {"inline", 0, 256, 2};
auto sliceRes = ds.isel(desc1);
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
ds = sliceRes.value();

auto ilVarRes = ds.variables.get<mdio::dtypes::uint32_t>("inline");
ASSERT_TRUE(ilVarRes.status().ok()) << ilVarRes.status();
auto ilVar = ilVarRes.value();

auto ilDataRes = mdio::from_variable<mdio::dtypes::uint32_t>(ilVar);
ASSERT_TRUE(ilDataRes.status().ok()) << ilDataRes.status();
auto ilData = ilDataRes.value();

auto ilAccessor = ilData.get_data_accessor().data();
for (uint32_t i = 0; i < 128; i++) {
ilAccessor[i] = i * 2;
}

auto ilFut = ilVar.Write(ilData);

ASSERT_TRUE(ilFut.status().ok()) << ilFut.status();

// --- Begin new QC data generation for the "image" variable (float32) ---
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
auto imageVar = imageVarRes.value();

auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
auto imageData = imageDataRes.value();

auto imageAccessor = imageData.get_data_accessor().data();
for (uint32_t i = 0; i < 128; i++) {
for (uint32_t j = 0; j < 512; j++) {
for (uint32_t k = 0; k < 384; k++) {
imageAccessor[i * (512 * 384) + j * 384 + k] =
static_cast<float>(i * 2) + j * 0.1f + k * 0.01f;
}
}
}

auto imageWriteFut = imageVar.Write(imageData);
ASSERT_TRUE(imageWriteFut.status().ok()) << imageWriteFut.status();
} // end of scoping the dataset creation to ensure the variables are cleaned
// up before the testing.

auto reopenedDsFut = mdio::Dataset::Open(iselPath, mdio::constants::kOpen);
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
auto reopenedDs = reopenedDsFut.value();

auto inlineVarRes =
reopenedDs.variables.get<mdio::dtypes::uint32_t>("inline");
ASSERT_TRUE(inlineVarRes.status().ok()) << inlineVarRes.status();
auto inlineVar = inlineVarRes.value();

auto inlineDataFut = inlineVar.Read();
ASSERT_TRUE(inlineDataFut.status().ok()) << inlineDataFut.status();
auto inlineData = inlineDataFut.value();
auto inlineAccessor = inlineData.get_data_accessor().data();
for (uint32_t i = 0; i < 256; i++) {
if (i % 2 == 0) {
ASSERT_EQ(inlineAccessor[i], i) << "Expected inline value to be " << i
<< " but got " << inlineAccessor[i];
} else {
ASSERT_EQ(inlineAccessor[i], 0)
<< "Expected inline value to be 0 but got " << inlineAccessor[i];
}
}

auto imageVarResReopen =
reopenedDs.variables.get<mdio::dtypes::float32_t>("image");
ASSERT_TRUE(imageVarResReopen.status().ok()) << imageVarResReopen.status();
auto imageVarReopen = imageVarResReopen.value();

auto imageDataFut = imageVarReopen.Read();
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
auto imageDataFull = imageDataFut.value();
auto imageAccessorFull = imageDataFull.get_data_accessor().data();

// Instead of checking all 256x512x384 elements (which can be very time
// consuming), we check a few sample indices. For full "image" variable, for
// every full inline index i: if (i % 2 == 0): the expected value is i +
// j*0.1f + k*0.01f, otherwise NaN.
std::vector<uint32_t> sample_i = {0, 1, 2,
255}; // mix of even and odd indices
std::vector<uint32_t> sample_j = {0, 256, 511};
std::vector<uint32_t> sample_k = {0, 100, 383};

for (auto i : sample_i) {
for (auto j : sample_j) {
for (auto k : sample_k) {
size_t index = i * (512 * 384) + j * 384 + k;
float actual = imageAccessorFull[index];

if (i % 2 == 0) {
// For even indices, we expect a specific value
float expected = static_cast<float>(i) + j * 0.1f + k * 0.01f;
ASSERT_FLOAT_EQ(actual, expected)
<< "QC mismatch in image variable at (" << i << ", " << j << ", "
<< k << ")";
} else {
// For odd indices, we expect NaN
ASSERT_TRUE(std::isnan(actual))
<< "Expected NaN at (" << i << ", " << j << ", " << k
<< ") but got " << actual;
}
}
}
}
// --- End new QC check for the "image" variable ---
}

TEST(Dataset, iselWithStrideAndExistingData) {
std::string testPath = "zarrs/slice_scale_test";
float scaleFactor = 2.5f;

// --- Step 1: Initialize the entire image variable with QC values and Write
// it ---
{
// Create a new dataset
auto json_vars = GetToyExample();
auto dataset = mdio::Dataset::from_json(json_vars, testPath,
mdio::constants::kCreateClean);
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
auto ds = dataset.value();

// Get the "image" variable (expected to be float32_t type)
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
auto imageVar = imageVarRes.value();

auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
auto imageData = imageDataRes.value();
auto imageAccessor = imageData.get_data_accessor().data();

// Initialize the entire "image" variable with QC values.
// For this test, we assume dimensions 256 x 512 x 384.
for (uint32_t i = 0; i < 256; i++) {
for (uint32_t j = 0; j < 512; j++) {
for (uint32_t k = 0; k < 384; k++) {
imageAccessor[i * (512 * 384) + j * 384 + k] =
static_cast<float>(i) + j * 0.1f + k * 0.01f;
}
}
}

auto writeFut = imageVar.Write(imageData);
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
} // End of Step 1

// --- Step 2: Slice with stride of 2 and scale the values of the "image"
// variable ---
{
// Re-open the dataset for modifications.
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
auto ds = reopenedDsFut.value();

// Slice the dataset along the "inline" dimension using a stride of 2.
mdio::RangeDescriptor<mdio::Index> desc = {"inline", 0, 256, 2};
auto sliceRes = ds.isel(desc);
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
auto ds_slice = sliceRes.value();

// Get the "image" variable from the sliced dataset.
auto imageVarRes = ds_slice.variables.get<mdio::dtypes::float32_t>("image");
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
auto imageVar = imageVarRes.value();

auto imageDataFut = imageVar.Read();
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
auto imageData = imageDataFut.value();
auto imageAccessor = imageData.get_data_accessor().data();

// The sliced "image" now has dimensions 128 x 512 x 384 because we selected
// every 2nd index. Scale each element in the slice by 'scaleFactor'
for (uint32_t ii = 0; ii < 128;
ii++) { // 'ii' corresponds to original index i = ii * 2.
for (uint32_t j = 0; j < 512; j++) {
for (uint32_t k = 0; k < 384; k++) {
size_t index = ii * (512 * 384) + j * 384 + k;
imageAccessor[index] *= scaleFactor;
}
}
}
// Write the updated (scaled) data back to the dataset.
auto writeFut = imageVar.Write(imageData);
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
} // End of Step 2

// --- Step 3: Read the entire image variable and validate QC values ---
{
// Re-open the dataset for the final validation.
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
auto ds = reopenedDsFut.value();

auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
auto imageVar = imageVarRes.value();

auto imageReadFut = imageVar.Read();
ASSERT_TRUE(imageReadFut.status().ok()) << imageReadFut.status();
auto imageData = imageReadFut.value();
auto imageAccessor = imageData.get_data_accessor().data();

// Validate the values over the entire "image" variable.
// For even inline indices (i % 2 == 0) we expect the initial QC value
// scaled by 'scaleFactor'. For odd inline indices, the original QC values
// should remain.
for (uint32_t i = 0; i < 256; i++) {
for (uint32_t j = 0; j < 512; j++) {
for (uint32_t k = 0; k < 384; k++) {
size_t index = i * (512 * 384) + j * 384 + k;
float baseValue = static_cast<float>(i) + j * 0.1f + k * 0.01f;
float expected = (i % 2 == 0) ? baseValue * scaleFactor : baseValue;
auto val = imageAccessor[index];
ASSERT_FLOAT_EQ(val, expected)
<< "Mismatch at (" << i << ", " << j << ", " << k
<< "): expected " << expected << ", but got " << val;
}
}
}
} // End of Step 3
}

TEST(Dataset, selValue) {
std::string path = "zarrs/selTester.mdio";
auto dsRes = makePopulated(path);
Expand Down
23 changes: 2 additions & 21 deletions mdio/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,7 @@ class Variable {
* @brief Slices the Variable along the specified dimensions and returns the
* resulting sub-Variable. This slice is performed as a half open interval.
* Dimensions that are not described will remain fully intact.
* @pre The step of the descriptor object must be 1.
* @pre The start of the descriptor object must be less than the stop.
* @pre The start of the slice descriptor must be less than the stop.
* @post The resulting Variable will be sliced along the specified dimensions
* within it's domain. If the slice lay outside of the domain of the Variable,
* the slice will be clamped to the domain.
Expand Down Expand Up @@ -1054,7 +1053,6 @@ class Variable {
stop.reserve(numDescriptors);
step.reserve(numDescriptors);
// -1 Everything is ok
// -2 Error: Step is not 1
// >=0 Error: Start is greater than or equal to stop
int8_t preconditionStatus = -1;

Expand All @@ -1063,10 +1061,6 @@ class Variable {
size_t idx = 0;
((
[&] {
if (desc.step != 1) {
preconditionStatus = -2;
return -2;
}
auto clampedDesc = sliceInRange(desc);
if (clampedDesc.start > clampedDesc.stop) {
preconditionStatus = idx;
Expand All @@ -1086,10 +1080,7 @@ class Variable {
},
tuple_descs);

if (preconditionStatus == -2) {
return absl::InvalidArgumentError(
"Only step 1 is supported for slicing.");
} else if (preconditionStatus >= 0) {
if (preconditionStatus >= 0) {
mdio::RangeDescriptor<Index> err;
std::apply(
[&](const auto&... desc) {
Expand Down Expand Up @@ -1641,8 +1632,6 @@ struct LabeledArray {

tensorstore::DimensionIndexBuffer buffer;

bool preconditionStatus = true;

absl::Status overall_status = absl::OkStatus();
std::apply(
[&](const auto&... desc) {
Expand All @@ -1658,21 +1647,13 @@ struct LabeledArray {
overall_status = result; // Capture the error status
return; // Exit lambda on error
}
if (desc.step != 1) {
preconditionStatus = false;
}
dims[idx] = buffer[0];
}(),
idx++),
...);
},
tuple_descs);

if (!preconditionStatus) {
return absl::InvalidArgumentError(
"Only step 1 is supported for slicing.");
}

/// could be we can't slice a dimension
if (!overall_status.ok()) {
return overall_status;
Expand Down
26 changes: 1 addition & 25 deletions mdio/variable_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,30 +756,11 @@ TEST(Variable, outOfBoundsSlice) {
EXPECT_THAT(badDomain.dimensions().shape(), ::testing::ElementsAre(250, 500))
<< badDomain.dimensions();

mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
// var = mdio::Variable<>::Open(json_good,
// mdio::constants::kCreateClean).result(); auto illegal =
// var.value().slice(illegal_step); EXPECT_FALSE(illegal.status().ok()) <<
// "Step precondition was violated but still sliced";

// mdio::RangeDescriptor<mdio::Index> illegal_start_stop = {"x", 500, 0, 1};
// illegal = var.value().slice(illegal_start_stop);
// EXPECT_FALSE(illegal.status().ok()) << "Start stop precondition was
// violated but still sliced";

// mdio::RangeDescriptor<mdio::Index> same_idx = {"x", 1, 1, 1};
// auto legal = var.value().slice(same_idx);
// EXPECT_TRUE(legal.status().ok()) <<
// legal.status();mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0,
// 500, 2};
auto var1 =
mdio::Variable<>::Open(json_good, mdio::constants::kCreateClean).result();
auto illegal = var1.value().slice(illegal_step);
EXPECT_FALSE(illegal.status().ok())
<< "Step precondition was violated but still sliced";

mdio::RangeDescriptor<mdio::Index> illegal_start_stop = {"x", 500, 0, 1};
illegal = var1.value().slice(illegal_start_stop);
auto illegal = var1.value().slice(illegal_start_stop);
EXPECT_FALSE(illegal.status().ok())
<< "Start stop precondition was violated but still sliced";

Expand Down Expand Up @@ -895,11 +876,6 @@ TEST(VariableData, outOfBoundsSlice) {
auto outbounds = varData.slice(x_outbounds, y_inbounds);
EXPECT_FALSE(outbounds.status().ok())
<< "Slicing out of bounds should fail but did not";

mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
auto illegal = varData.slice(illegal_step);
EXPECT_FALSE(illegal.status().ok())
<< "Step precondition was violated but still sliced";
}

TEST(VariableSpec, open) {
Expand Down