Skip to content

Commit a61699e

Browse files
authored
Merge pull request #158 from TGSAI/slice_stride
Strided slices
2 parents 7190507 + 082ad62 commit a61699e

File tree

3 files changed

+241
-46
lines changed

3 files changed

+241
-46
lines changed

mdio/dataset_test.cc

+238
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,244 @@ TEST(Dataset, isel) {
411411
<< "Inline range should end at 5";
412412
}
413413

414+
TEST(Dataset, iselWithStride) {
415+
// Tests the integrity of data that is written with a strided slice.
416+
std::string iselPath = "zarrs/acceptance";
417+
{ // Scoping the dataset creation to ensure the variables are cleaned up
418+
// before the testing.
419+
auto json_vars = GetToyExample();
420+
auto dataset = mdio::Dataset::from_json(json_vars, iselPath,
421+
mdio::constants::kCreateClean);
422+
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
423+
auto ds = dataset.value();
424+
425+
mdio::RangeDescriptor<mdio::Index> desc1 = {"inline", 0, 256, 2};
426+
auto sliceRes = ds.isel(desc1);
427+
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
428+
ds = sliceRes.value();
429+
430+
auto ilVarRes = ds.variables.get<mdio::dtypes::uint32_t>("inline");
431+
ASSERT_TRUE(ilVarRes.status().ok()) << ilVarRes.status();
432+
auto ilVar = ilVarRes.value();
433+
434+
auto ilDataRes = mdio::from_variable<mdio::dtypes::uint32_t>(ilVar);
435+
ASSERT_TRUE(ilDataRes.status().ok()) << ilDataRes.status();
436+
auto ilData = ilDataRes.value();
437+
438+
auto ilAccessor = ilData.get_data_accessor().data();
439+
for (uint32_t i = 0; i < 128; i++) {
440+
ilAccessor[i] = i * 2;
441+
}
442+
443+
auto ilFut = ilVar.Write(ilData);
444+
445+
ASSERT_TRUE(ilFut.status().ok()) << ilFut.status();
446+
447+
// --- Begin new QC data generation for the "image" variable (float32) ---
448+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
449+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
450+
auto imageVar = imageVarRes.value();
451+
452+
auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
453+
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
454+
auto imageData = imageDataRes.value();
455+
456+
auto imageAccessor = imageData.get_data_accessor().data();
457+
for (uint32_t i = 0; i < 128; i++) {
458+
for (uint32_t j = 0; j < 512; j++) {
459+
for (uint32_t k = 0; k < 384; k++) {
460+
imageAccessor[i * (512 * 384) + j * 384 + k] =
461+
static_cast<float>(i * 2) + j * 0.1f + k * 0.01f;
462+
}
463+
}
464+
}
465+
466+
auto imageWriteFut = imageVar.Write(imageData);
467+
ASSERT_TRUE(imageWriteFut.status().ok()) << imageWriteFut.status();
468+
} // end of scoping the dataset creation to ensure the variables are cleaned
469+
// up before the testing.
470+
471+
auto reopenedDsFut = mdio::Dataset::Open(iselPath, mdio::constants::kOpen);
472+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
473+
auto reopenedDs = reopenedDsFut.value();
474+
475+
auto inlineVarRes =
476+
reopenedDs.variables.get<mdio::dtypes::uint32_t>("inline");
477+
ASSERT_TRUE(inlineVarRes.status().ok()) << inlineVarRes.status();
478+
auto inlineVar = inlineVarRes.value();
479+
480+
auto inlineDataFut = inlineVar.Read();
481+
ASSERT_TRUE(inlineDataFut.status().ok()) << inlineDataFut.status();
482+
auto inlineData = inlineDataFut.value();
483+
auto inlineAccessor = inlineData.get_data_accessor().data();
484+
for (uint32_t i = 0; i < 256; i++) {
485+
if (i % 2 == 0) {
486+
ASSERT_EQ(inlineAccessor[i], i) << "Expected inline value to be " << i
487+
<< " but got " << inlineAccessor[i];
488+
} else {
489+
ASSERT_EQ(inlineAccessor[i], 0)
490+
<< "Expected inline value to be 0 but got " << inlineAccessor[i];
491+
}
492+
}
493+
494+
auto imageVarResReopen =
495+
reopenedDs.variables.get<mdio::dtypes::float32_t>("image");
496+
ASSERT_TRUE(imageVarResReopen.status().ok()) << imageVarResReopen.status();
497+
auto imageVarReopen = imageVarResReopen.value();
498+
499+
auto imageDataFut = imageVarReopen.Read();
500+
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
501+
auto imageDataFull = imageDataFut.value();
502+
auto imageAccessorFull = imageDataFull.get_data_accessor().data();
503+
504+
// Instead of checking all 256x512x384 elements (which can be very time
505+
// consuming), we check a few sample indices. For full "image" variable, for
506+
// every full inline index i: if (i % 2 == 0): the expected value is i +
507+
// j*0.1f + k*0.01f, otherwise NaN.
508+
std::vector<uint32_t> sample_i = {0, 1, 2,
509+
255}; // mix of even and odd indices
510+
std::vector<uint32_t> sample_j = {0, 256, 511};
511+
std::vector<uint32_t> sample_k = {0, 100, 383};
512+
513+
for (auto i : sample_i) {
514+
for (auto j : sample_j) {
515+
for (auto k : sample_k) {
516+
size_t index = i * (512 * 384) + j * 384 + k;
517+
float actual = imageAccessorFull[index];
518+
519+
if (i % 2 == 0) {
520+
// For even indices, we expect a specific value
521+
float expected = static_cast<float>(i) + j * 0.1f + k * 0.01f;
522+
ASSERT_FLOAT_EQ(actual, expected)
523+
<< "QC mismatch in image variable at (" << i << ", " << j << ", "
524+
<< k << ")";
525+
} else {
526+
// For odd indices, we expect NaN
527+
ASSERT_TRUE(std::isnan(actual))
528+
<< "Expected NaN at (" << i << ", " << j << ", " << k
529+
<< ") but got " << actual;
530+
}
531+
}
532+
}
533+
}
534+
// --- End new QC check for the "image" variable ---
535+
}
536+
537+
TEST(Dataset, iselWithStrideAndExistingData) {
538+
std::string testPath = "zarrs/slice_scale_test";
539+
float scaleFactor = 2.5f;
540+
541+
// --- Step 1: Initialize the entire image variable with QC values and Write
542+
// it ---
543+
{
544+
// Create a new dataset
545+
auto json_vars = GetToyExample();
546+
auto dataset = mdio::Dataset::from_json(json_vars, testPath,
547+
mdio::constants::kCreateClean);
548+
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
549+
auto ds = dataset.value();
550+
551+
// Get the "image" variable (expected to be float32_t type)
552+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
553+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
554+
auto imageVar = imageVarRes.value();
555+
556+
auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
557+
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
558+
auto imageData = imageDataRes.value();
559+
auto imageAccessor = imageData.get_data_accessor().data();
560+
561+
// Initialize the entire "image" variable with QC values.
562+
// For this test, we assume dimensions 256 x 512 x 384.
563+
for (uint32_t i = 0; i < 256; i++) {
564+
for (uint32_t j = 0; j < 512; j++) {
565+
for (uint32_t k = 0; k < 384; k++) {
566+
imageAccessor[i * (512 * 384) + j * 384 + k] =
567+
static_cast<float>(i) + j * 0.1f + k * 0.01f;
568+
}
569+
}
570+
}
571+
572+
auto writeFut = imageVar.Write(imageData);
573+
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
574+
} // End of Step 1
575+
576+
// --- Step 2: Slice with stride of 2 and scale the values of the "image"
577+
// variable ---
578+
{
579+
// Re-open the dataset for modifications.
580+
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
581+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
582+
auto ds = reopenedDsFut.value();
583+
584+
// Slice the dataset along the "inline" dimension using a stride of 2.
585+
mdio::RangeDescriptor<mdio::Index> desc = {"inline", 0, 256, 2};
586+
auto sliceRes = ds.isel(desc);
587+
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
588+
auto ds_slice = sliceRes.value();
589+
590+
// Get the "image" variable from the sliced dataset.
591+
auto imageVarRes = ds_slice.variables.get<mdio::dtypes::float32_t>("image");
592+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
593+
auto imageVar = imageVarRes.value();
594+
595+
auto imageDataFut = imageVar.Read();
596+
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
597+
auto imageData = imageDataFut.value();
598+
auto imageAccessor = imageData.get_data_accessor().data();
599+
600+
// The sliced "image" now has dimensions 128 x 512 x 384 because we selected
601+
// every 2nd index. Scale each element in the slice by 'scaleFactor'
602+
for (uint32_t ii = 0; ii < 128;
603+
ii++) { // 'ii' corresponds to original index i = ii * 2.
604+
for (uint32_t j = 0; j < 512; j++) {
605+
for (uint32_t k = 0; k < 384; k++) {
606+
size_t index = ii * (512 * 384) + j * 384 + k;
607+
imageAccessor[index] *= scaleFactor;
608+
}
609+
}
610+
}
611+
// Write the updated (scaled) data back to the dataset.
612+
auto writeFut = imageVar.Write(imageData);
613+
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
614+
} // End of Step 2
615+
616+
// --- Step 3: Read the entire image variable and validate QC values ---
617+
{
618+
// Re-open the dataset for the final validation.
619+
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
620+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
621+
auto ds = reopenedDsFut.value();
622+
623+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
624+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
625+
auto imageVar = imageVarRes.value();
626+
627+
auto imageReadFut = imageVar.Read();
628+
ASSERT_TRUE(imageReadFut.status().ok()) << imageReadFut.status();
629+
auto imageData = imageReadFut.value();
630+
auto imageAccessor = imageData.get_data_accessor().data();
631+
632+
// Validate the values over the entire "image" variable.
633+
// For even inline indices (i % 2 == 0) we expect the initial QC value
634+
// scaled by 'scaleFactor'. For odd inline indices, the original QC values
635+
// should remain.
636+
for (uint32_t i = 0; i < 256; i++) {
637+
for (uint32_t j = 0; j < 512; j++) {
638+
for (uint32_t k = 0; k < 384; k++) {
639+
size_t index = i * (512 * 384) + j * 384 + k;
640+
float baseValue = static_cast<float>(i) + j * 0.1f + k * 0.01f;
641+
float expected = (i % 2 == 0) ? baseValue * scaleFactor : baseValue;
642+
auto val = imageAccessor[index];
643+
ASSERT_FLOAT_EQ(val, expected)
644+
<< "Mismatch at (" << i << ", " << j << ", " << k
645+
<< "): expected " << expected << ", but got " << val;
646+
}
647+
}
648+
}
649+
} // End of Step 3
650+
}
651+
414652
TEST(Dataset, selValue) {
415653
std::string path = "zarrs/selTester.mdio";
416654
auto dsRes = makePopulated(path);

mdio/variable.h

+2-21
Original file line numberDiff line numberDiff line change
@@ -1024,8 +1024,7 @@ class Variable {
10241024
* @brief Slices the Variable along the specified dimensions and returns the
10251025
* resulting sub-Variable. This slice is performed as a half open interval.
10261026
* Dimensions that are not described will remain fully intact.
1027-
* @pre The step of the descriptor object must be 1.
1028-
* @pre The start of the descriptor object must be less than the stop.
1027+
* @pre The start of the slice descriptor must be less than the stop.
10291028
* @post The resulting Variable will be sliced along the specified dimensions
10301029
* within it's domain. If the slice lay outside of the domain of the Variable,
10311030
* the slice will be clamped to the domain.
@@ -1054,7 +1053,6 @@ class Variable {
10541053
stop.reserve(numDescriptors);
10551054
step.reserve(numDescriptors);
10561055
// -1 Everything is ok
1057-
// -2 Error: Step is not 1
10581056
// >=0 Error: Start is greater than or equal to stop
10591057
int8_t preconditionStatus = -1;
10601058

@@ -1063,10 +1061,6 @@ class Variable {
10631061
size_t idx = 0;
10641062
((
10651063
[&] {
1066-
if (desc.step != 1) {
1067-
preconditionStatus = -2;
1068-
return -2;
1069-
}
10701064
auto clampedDesc = sliceInRange(desc);
10711065
if (clampedDesc.start > clampedDesc.stop) {
10721066
preconditionStatus = idx;
@@ -1086,10 +1080,7 @@ class Variable {
10861080
},
10871081
tuple_descs);
10881082

1089-
if (preconditionStatus == -2) {
1090-
return absl::InvalidArgumentError(
1091-
"Only step 1 is supported for slicing.");
1092-
} else if (preconditionStatus >= 0) {
1083+
if (preconditionStatus >= 0) {
10931084
mdio::RangeDescriptor<Index> err;
10941085
std::apply(
10951086
[&](const auto&... desc) {
@@ -1641,8 +1632,6 @@ struct LabeledArray {
16411632

16421633
tensorstore::DimensionIndexBuffer buffer;
16431634

1644-
bool preconditionStatus = true;
1645-
16461635
absl::Status overall_status = absl::OkStatus();
16471636
std::apply(
16481637
[&](const auto&... desc) {
@@ -1658,21 +1647,13 @@ struct LabeledArray {
16581647
overall_status = result; // Capture the error status
16591648
return; // Exit lambda on error
16601649
}
1661-
if (desc.step != 1) {
1662-
preconditionStatus = false;
1663-
}
16641650
dims[idx] = buffer[0];
16651651
}(),
16661652
idx++),
16671653
...);
16681654
},
16691655
tuple_descs);
16701656

1671-
if (!preconditionStatus) {
1672-
return absl::InvalidArgumentError(
1673-
"Only step 1 is supported for slicing.");
1674-
}
1675-
16761657
/// could be we can't slice a dimension
16771658
if (!overall_status.ok()) {
16781659
return overall_status;

mdio/variable_test.cc

+1-25
Original file line numberDiff line numberDiff line change
@@ -756,30 +756,11 @@ TEST(Variable, outOfBoundsSlice) {
756756
EXPECT_THAT(badDomain.dimensions().shape(), ::testing::ElementsAre(250, 500))
757757
<< badDomain.dimensions();
758758

759-
mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
760-
// var = mdio::Variable<>::Open(json_good,
761-
// mdio::constants::kCreateClean).result(); auto illegal =
762-
// var.value().slice(illegal_step); EXPECT_FALSE(illegal.status().ok()) <<
763-
// "Step precondition was violated but still sliced";
764-
765-
// mdio::RangeDescriptor<mdio::Index> illegal_start_stop = {"x", 500, 0, 1};
766-
// illegal = var.value().slice(illegal_start_stop);
767-
// EXPECT_FALSE(illegal.status().ok()) << "Start stop precondition was
768-
// violated but still sliced";
769-
770-
// mdio::RangeDescriptor<mdio::Index> same_idx = {"x", 1, 1, 1};
771-
// auto legal = var.value().slice(same_idx);
772-
// EXPECT_TRUE(legal.status().ok()) <<
773-
// legal.status();mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0,
774-
// 500, 2};
775759
auto var1 =
776760
mdio::Variable<>::Open(json_good, mdio::constants::kCreateClean).result();
777-
auto illegal = var1.value().slice(illegal_step);
778-
EXPECT_FALSE(illegal.status().ok())
779-
<< "Step precondition was violated but still sliced";
780761

781762
mdio::RangeDescriptor<mdio::Index> illegal_start_stop = {"x", 500, 0, 1};
782-
illegal = var1.value().slice(illegal_start_stop);
763+
auto illegal = var1.value().slice(illegal_start_stop);
783764
EXPECT_FALSE(illegal.status().ok())
784765
<< "Start stop precondition was violated but still sliced";
785766

@@ -895,11 +876,6 @@ TEST(VariableData, outOfBoundsSlice) {
895876
auto outbounds = varData.slice(x_outbounds, y_inbounds);
896877
EXPECT_FALSE(outbounds.status().ok())
897878
<< "Slicing out of bounds should fail but did not";
898-
899-
mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
900-
auto illegal = varData.slice(illegal_step);
901-
EXPECT_FALSE(illegal.status().ok())
902-
<< "Step precondition was violated but still sliced";
903879
}
904880

905881
TEST(VariableSpec, open) {

0 commit comments

Comments
 (0)