diff --git a/kerchunk/netCDF3.py b/kerchunk/netCDF3.py
index d43b6b97..c5986b3e 100644
--- a/kerchunk/netCDF3.py
+++ b/kerchunk/netCDF3.py
@@ -5,7 +5,7 @@
 from fsspec.implementations.reference import LazyReferenceMapper
 import fsspec
 
-from kerchunk.utils import _encode_for_JSON, inline_array
+from kerchunk.utils import _encode_for_JSON, inline_array, zarr_init_group_and_store
 
 try:
     from scipy.io._netcdf import ZERO, NC_VARIABLE, netcdf_file, netcdf_variable
@@ -31,6 +31,7 @@ def __init__(
         inline_threshold=100,
         max_chunk_size=0,
         out=None,
+        zarr_version=None,
         **kwargs,
     ):
         """
@@ -52,6 +53,9 @@ def __init__(
             This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
             to write out parquet as the references get filled, or some other dictionary-like class
             to customise how references get stored
+        zarr_version: int
+            The zarr spec version to target (currently 2 or 3). If None, the
+            default version of the installed zarr library is used.
         args, kwargs: passed to scipy superclass ``scipy.io.netcdf.netcdf_file``
         """
         assert kwargs.pop("mmap", False) is False
@@ -63,6 +67,7 @@ def __init__(
         self.chunks = {}
         self.threshold = inline_threshold
         self.max_chunk_size = max_chunk_size
+        self.zarr_version = zarr_version
         self.out = out or {}
         self.storage_options = storage_options
         self.fp = fsspec.open(filename, **(storage_options or {})).open()
@@ -164,10 +169,9 @@ def translate(self):
         Parameters
         ----------
         """
-        import zarr
 
         out = self.out
-        z = zarr.open(out, mode="w")
+        zroot, out = zarr_init_group_and_store(out, self.zarr_version)
         for dim, var in self.variables.items():
             if dim in self.chunks:
                 shape = self.chunks[dim][-1]
@@ -191,17 +195,23 @@ def translate(self):
                 fill = float(fill)
             if fill is not None and var.data.dtype.kind == "i":
                 fill = int(fill)
-            arr = z.create_dataset(
+            arr = zroot.create_dataset(
                 name=dim,
                 shape=shape,
                 dtype=var.data.dtype,
                 fill_value=fill,
                 chunks=shape,
                 compression=None,
+                overwrite=True,
             )
-            part = ".".join(["0"] * len(shape)) or "0"
-            k = f"{dim}/{part}"
-            out[k] = [
+            if self.zarr_version == 3:
+                # v3 chunk keys live under data/root/ and use "/" separators
+                part = "/".join(["0"] * len(shape)) or "0"
+                key = f"data/root/{dim}/c{part}"
+            else:
+                part = ".".join(["0"] * len(shape)) or "0"
+                key = f"{dim}/{part}"
+            self.out[key] = [
                 self.filename,
                 int(self.chunks[dim][0]),
                 int(self.chunks[dim][1]),
@@ -245,13 +257,14 @@ def translate(self):
                     fill = float(fill)
                 if fill is not None and base.kind == "i":
                     fill = int(fill)
-                arr = z.create_dataset(
+                arr = zroot.create_dataset(
                     name=name,
                     shape=shape,
                     dtype=base,
                     fill_value=fill,
                     chunks=(1,) + dtype.shape,
                     compression=None,
+                    overwrite=True,
                 )
                 arr.attrs.update(
                     {
@@ -266,18 +279,33 @@ def translate(self):
 
                 arr.attrs["_ARRAY_DIMENSIONS"] = list(var.dimensions)
 
-                suffix = (
-                    ("." + ".".join("0" for _ in dtype.shape)) if dtype.shape else ""
-                )
+                if self.zarr_version == 3:
+                    suffix = (
+                        ("/" + "/".join("0" for _ in dtype.shape))
+                        if dtype.shape
+                        else ""
+                    )
+                else:
+                    suffix = (
+                        ("." + ".".join("0" for _ in dtype.shape))
+                        if dtype.shape
+                        else ""
+                    )
+
                 for i in range(outer_shape):
-                    out[f"{name}/{i}{suffix}"] = [
+                    if self.zarr_version == 3:
+                        key = f"data/root/{name}/c{i}{suffix}"
+                    else:
+                        key = f"{name}/{i}{suffix}"
+
+                    self.out[key] = [
                         self.filename,
                         int(offset + i * dt.itemsize),
                         int(dtype.itemsize),
                     ]
                     offset += dtype.itemsize
-        z.attrs.update(
+        zroot.attrs.update(
             {
                 k: v.decode() if isinstance(v, bytes) else str(v)
                 for k, v in self._attributes.items()
diff --git a/kerchunk/tests/test_netcdf.py b/kerchunk/tests/test_netcdf.py
index 43b6021b..b79793a7 100644
--- a/kerchunk/tests/test_netcdf.py
+++ b/kerchunk/tests/test_netcdf.py
@@ -24,15 +24,17 @@
 )
 
 
-def test_one(m):
+@pytest.mark.parametrize("zarr_version", [2, 3])
+def test_one(m, zarr_version):
     m.pipe("data.nc3", bdata)
-    h = netCDF3.netcdf_recording_file("memory://data.nc3")
+    h = netCDF3.netcdf_recording_file("memory://data.nc3", zarr_version=zarr_version)
     out = h.translate()
     ds = xr.open_dataset(
         "reference://",
         engine="zarr",
         backend_kwargs={
             "consolidated": False,
+            "zarr_version": zarr_version,
             "storage_options": {"fo": out, "remote_protocol": "memory"},
         },
     )
@@ -76,16 +78,18 @@ def unlimited_dataset(tmpdir):
     return fn
 
 
-def test_unlimited(unlimited_dataset):
+@pytest.mark.parametrize("zarr_version", [2, 3])
+def test_unlimited(unlimited_dataset, zarr_version):
     fn = unlimited_dataset
     expected = xr.open_dataset(fn, engine="scipy")
-    h = netCDF3.NetCDF3ToZarr(fn)
+    h = netCDF3.NetCDF3ToZarr(fn, zarr_version=zarr_version)
     out = h.translate()
     ds = xr.open_dataset(
         "reference://",
         engine="zarr",
         backend_kwargs={
             "consolidated": False,
+            "zarr_version": zarr_version,
             "storage_options": {"fo": out},
         },
     )