Merge pull request #1 from melissalinkert/fix-sharding
Fix preset sharding options and add tests
sbesson authored Jul 26, 2024
2 parents e8e2b5b + 1c817c1 commit a0ecef2
Showing 3 changed files with 146 additions and 16 deletions.
build.gradle (2 additions & 0 deletions)
@@ -59,6 +59,8 @@ dependencies {

test {
useJUnit()

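// give the test JVM extra heap; the conversion tests below load large arrays (rationale assumed, not stated in the PR)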
maxHeapSize = "2g"
}

jar {
src/main/java/com/glencoesoftware/zarr/Convert.java (85 additions & 16 deletions)
@@ -37,13 +37,15 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.Callable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ch.qos.logback.classic.Level;
import picocli.CommandLine;
import picocli.CommandLine.Option;
import picocli.CommandLine.Parameters;
@@ -61,6 +63,7 @@ public class Convert implements Callable<Integer> {

private String inputLocation;
private String outputLocation;
private String logLevel = "INFO";
private boolean writeV2;

private ShardConfiguration shardConfig;
@@ -90,6 +93,26 @@ public void setOutput(String output) {
outputLocation = output;
}

/**
* Set the slf4j logging level. Defaults to "INFO".
*
* @param level logging level
*/
@Option(
names = {"--log-level", "--debug"},
arity = "0..1",
description = "Change logging level; valid values are " +
"OFF, ERROR, WARN, INFO, DEBUG, TRACE and ALL. " +
"(default: ${DEFAULT-VALUE})",
defaultValue = "INFO",
fallbackValue = "DEBUG"
)
public void setLogLevel(String level) {
if (level != null) {
logLevel = level;
}
}
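
// Usage sketch (command name hypothetical): with arity "0..1" and
// fallbackValue "DEBUG", both invocations below raise the log level, and a
// bare --debug is equivalent to --log-level DEBUG:
//   zarr-convert --log-level TRACE input.zarr output.zarr
//   zarr-convert --debug input.zarr output.zarr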

@Option(
names = "--write-v2",
description = "Read v3, write v2",
@@ -134,6 +157,10 @@ public void setCompression(String[] compression) {

@Override
public Integer call() throws Exception {
ch.qos.logback.classic.Logger root = (ch.qos.logback.classic.Logger)
LoggerFactory.getLogger(Logger.ROOT_LOGGER_NAME);
root.setLevel(Level.toLevel(logLevel));

if (writeV2) {
convertToV2();
}
@@ -151,6 +178,7 @@ public void convertToV3() throws Exception {
Path inputPath = Paths.get(inputLocation);

// get the root-level attributes
LOGGER.debug("opening v2 root group: {}", inputPath);
ZarrGroup reader = ZarrGroup.open(inputPath);
Map<String, Object> attributes = reader.getAttributes();

@@ -163,6 +191,7 @@
// but this doesn't seem to actually create the group
// separating the group creation and attribute writing into
// two calls seems to work correctly
LOGGER.debug("opening v3 root group: {}", outputLocation);
FilesystemStore outputStore = new FilesystemStore(outputLocation);
Group outputRootGroup = Group.create(outputStore.resolve());
outputRootGroup.setAttributes(attributes);
@@ -175,9 +204,11 @@

for (String seriesGroupKey : groupKeys) {
if (seriesGroupKey.indexOf("/") > 0) {
LOGGER.debug("skipping v2 group key: {}", seriesGroupKey);
continue;
}
Path seriesPath = inputPath.resolve(seriesGroupKey);
LOGGER.debug("opening v2 group: {}", seriesPath);
ZarrGroup seriesGroup = ZarrGroup.open(seriesPath);
LOGGER.info("opened {}", seriesPath);

@@ -190,13 +221,16 @@ public void convertToV3() throws Exception {
Set<String> columnKeys = seriesGroup.getGroupKeys();
// "pass through" if this is not HCS
if (columnKeys.size() == 0) {
LOGGER.debug("no column group keys (likely not HCS)");
columnKeys.add("");
}
for (String columnKey : columnKeys) {
if (columnKey.indexOf("/") > 0) {
LOGGER.debug("skipping v2 column group key: {}", columnKey);
continue;
}
Path columnPath = columnKey.isEmpty() ? seriesPath : seriesPath.resolve(columnKey);
LOGGER.debug("opening v2 group: {}", columnPath);
ZarrGroup column = ZarrGroup.open(columnPath);

if (!columnKey.isEmpty()) {
@@ -208,14 +242,15 @@ public void convertToV3() throws Exception {
Set<String> fieldKeys = column.getGroupKeys();
// "pass through" if this is not HCS
if (fieldKeys.size() == 0) {
LOGGER.debug("no field group keys");
fieldKeys.add("");
}

for (String fieldKey : fieldKeys) {
Path fieldPath = fieldKey.isEmpty() ? columnPath : columnPath.resolve(fieldKey);
LOGGER.debug("opening v2 field group: {}", fieldPath);
ZarrGroup field = ZarrGroup.open(fieldPath);

Map<String, Object> fieldAttributes = field.getAttributes();
if (!fieldKey.isEmpty()) {
Group outputFieldGroup = Group.create(outputStore.resolve(seriesGroupKey, columnKey, fieldKey));
@@ -239,12 +274,16 @@ public void convertToV3() throws Exception {

for (int res=0; res<totalResolutions; res++) {
String resolutionPath = fieldPath + "/" + res;
LOGGER.debug("opening v2 array: {}", resolutionPath);

ZarrArray tile = field.openArray("/" + res);
LOGGER.info("opened array {}", resolutionPath);
int[] originalChunkSizes = tile.getChunks();
int[] shape = tile.getShape();

int[] chunkSizes = new int[originalChunkSizes.length];
System.arraycopy(originalChunkSizes, 0, chunkSizes, 0, chunkSizes.length);

int[] gridPosition = new int[] {0, 0, 0, 0, 0};
int tileX = chunkSizes[chunkSizes.length - 2];
int tileY = chunkSizes[chunkSizes.length - 1];
@@ -257,22 +296,31 @@ public void convertToV3() throws Exception {
if (shardConfig != null) {
switch (shardConfig) {
case SINGLE:
// single shard covering the whole image
// internal chunk sizes remain the same as in input data
chunkSizes = shape;
break;
case CHUNK:
// exactly one shard per chunk
// no changes needed
break;
case SUPERCHUNK:
// each shard covers 2x2 chunks
chunkSizes[4] *= 2;
chunkSizes[3] *= 2;
break;
case CUSTOM:
// TODO
break;
}

if (chunkAndShardCompatible(originalChunkSizes, chunkSizes, shape)) {
codecBuilder = codecBuilder.withSharding(originalChunkSizes);
}
else {
LOGGER.warn("Skipping sharding due to incompatible sizes");
chunkSizes = originalChunkSizes;
}
}
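// Worked example (numbers illustrative): for shape {1, 1, 1, 4096, 4096} and
// chunks {1, 1, 1, 1024, 1024}, SUPERCHUNK requests 2048x2048 shards;
// 4096 % 2048 == 0 and 2048 % 1024 == 0, so sharding is applied with the
// original 1024x1024 chunks nested inside each shard. A 4100-wide image would
// fail the divisibility check and fall back to plain 1024x1024 chunks.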
if (codecs != null) {
for (String codecName : codecs) {
@@ -292,19 +340,21 @@ else if (codecName.equals("blosc")) {
}
final CodecBuilder builder = codecBuilder;

StoreHandle v3ArrayHandle = outputStore.resolve(seriesGroupKey, columnKey, fieldKey, String.valueOf(res));
LOGGER.debug("opening v3 array: {}", v3ArrayHandle);
Array outputArray = Array.create(v3ArrayHandle,
Array.metadataBuilder()
.withShape(Utils.toLongArray(shape))
.withDataType(getV3Type(type))
.withChunkShape(chunkSizes) // if sharding is used, this will be the shard size
.withFillValue(255)
.withCodecs(c -> builder)
.build()
);

for (int t=0; t<shape[0]; t+=originalChunkSizes[0]) {
for (int c=0; c<shape[1]; c+=originalChunkSizes[1]) {
for (int z=0; z<shape[2]; z+=originalChunkSizes[2]) {
// copy each chunk, keeping the original chunk sizes
for (int y=0; y<shape[4]; y+=tileY) {
for (int x=0; x<shape[3]; x+=tileX) {
@@ -313,8 +363,10 @@ else if (codecName.equals("blosc")) {
gridPosition[2] = z;
gridPosition[1] = c;
gridPosition[0] = t;
LOGGER.debug("copying chunk of size {} at position {}",
Arrays.toString(originalChunkSizes), Arrays.toString(gridPosition));
Object bytes = tile.read(originalChunkSizes, gridPosition);
outputArray.write(Utils.toLongArray(gridPosition), NetCDF_Util.createArrayWithGivenStorage(bytes, originalChunkSizes));
}
}
}
@@ -530,6 +582,23 @@ private DataType getV2Type(dev.zarr.zarrjava.v3.DataType v3) {
throw new IllegalArgumentException(v3.toString());
}

/**
* Check that the desired chunk, shard, and shape are compatible with each other.
* In each dimension, the chunk size must evenly divide into the shard size,
* which must evenly divide into the shape.
*/
private boolean chunkAndShardCompatible(int[] chunkSize, int[] shardSize, int[] shape) {
for (int d=0; d<shape.length; d++) {
if (shape[d] % shardSize[d] != 0) {
return false;
}
if (shardSize[d] % chunkSize[d] != 0) {
return false;
}
}
return true;
}
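
For reference, the divisibility rule can be exercised on its own; a minimal standalone sketch (class and method names are illustrative, not part of this PR), using the image and chunk sizes from the test below:

public class ShardCheckDemo {
    // same rule as chunkAndShardCompatible: shards must tile the shape
    // exactly, and chunks must tile each shard exactly
    static boolean compatible(int[] chunk, int[] shard, int[] shape) {
        for (int d = 0; d < shape.length; d++) {
            if (shape[d] % shard[d] != 0 || shard[d] % chunk[d] != 0) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[] chunk = {1, 1, 1, 1024, 1024};
        int[] shape = {1, 1, 1, 10240, 10240};
        // true: 10240 % 2048 == 0 and 2048 % 1024 == 0
        System.out.println(compatible(chunk, new int[] {1, 1, 1, 2048, 2048}, shape));
        // false: 10240 % 3072 != 0
        System.out.println(compatible(chunk, new int[] {1, 1, 1, 3072, 3072}, shape));
    }
}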

public static void main(String[] args) {
CommandLine.call(new Convert(), args);
}
src/test/java/com/glencoesoftware/zarr/test/ConversionTest.java (59 additions & 0 deletions)
@@ -316,6 +316,65 @@ public void testCodecs() throws Exception {
}
}

/**
* Test different sharding options
*/
@Test
public void testSharding() throws Exception {
input = fake("sizeX", "10240", "sizeY", "10240");
assertBioFormats2Raw();

String[] shardOptions = new String[] {
"SINGLE", "CHUNK", "SUPERCHUNK"
};
int[][] shardSizes = new int[][] {
{1, 1, 1, 10240, 10240},
{1, 1, 1, 1024, 1024},
{1, 1, 1, 2048, 2048}
};
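// shard sizes above assume bioformats2raw's default 1024x1024 chunk size:
// SINGLE     -> one shard spanning the full 10240x10240 plane
// CHUNK      -> shard size equals the chunk size (1024x1024)
// SUPERCHUNK -> each shard covers 2x2 chunks (2048x2048)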

for (int opt=0; opt<shardOptions.length; opt++) {
// first convert v2 produced by bioformats2raw to v3
Path v3Output = tmp.newFolder().toPath().resolve("v3-test");
Convert v3Converter = new Convert();
v3Converter.setInput(output.toString());
v3Converter.setOutput(v3Output.toString());

v3Converter.setSharding(shardOptions[opt]);
v3Converter.convertToV3();

// check the shard size (stored as the chunk shape) in the v3 array

Store store = new FilesystemStore(v3Output);
Array resolution = Array.open(store.resolve("0", "0"));

int[] shardSize = shardSizes[opt];
Assert.assertArrayEquals(resolution.metadata.chunkShape(), shardSize);

// now convert v3 back to v2
Path roundtripOutput = tmp.newFolder().toPath().resolve("v2-roundtrip-test");
Convert v2Converter = new Convert();
v2Converter.setInput(v3Output.toString());
v2Converter.setOutput(roundtripOutput.toString());
v2Converter.setWriteV2(true);
v2Converter.convertToV2();

Path originalOMEXML = output.resolve("OME").resolve("METADATA.ome.xml");
Path roundtripOMEXML = roundtripOutput.resolve("OME").resolve("METADATA.ome.xml");

// make sure the OME-XML is present and not changed
Assert.assertEquals(Files.readAllLines(originalOMEXML), Files.readAllLines(roundtripOMEXML));

// since the image is small, make sure all pixels are identical at every resolution
for (int r=0; r<7; r++) {
ZarrArray original = ZarrGroup.open(output.resolve("0")).openArray(String.valueOf(r));
ZarrArray roundtrip = ZarrGroup.open(roundtripOutput.resolve("0")).openArray(String.valueOf(r));

compareZarrArrays(original, roundtrip);
}
}
}

private void compareZarrArrays(ZarrArray original, ZarrArray roundtrip) throws Exception {
Assert.assertArrayEquals(original.getShape(), roundtrip.getShape());

