Commit: netcdf3
ghidalgo3 committed Jul 15, 2024
1 parent 572f62d commit 653a260
Showing 2 changed files with 49 additions and 17 deletions.
kerchunk/netCDF3.py: 54 changes (41 additions, 13 deletions)
@@ -5,7 +5,7 @@
 from fsspec.implementations.reference import LazyReferenceMapper
 import fsspec
 
-from kerchunk.utils import _encode_for_JSON, inline_array
+from kerchunk.utils import _encode_for_JSON, inline_array, zarr_init_group_and_store
 
 try:
     from scipy.io._netcdf import ZERO, NC_VARIABLE, netcdf_file, netcdf_variable
@@ -31,6 +31,7 @@ def __init__(
         inline_threshold=100,
         max_chunk_size=0,
         out=None,
+        zarr_version=None,
         **kwargs,
     ):
         """
@@ -52,6 +53,9 @@ def __init__(
             This allows you to supply an fsspec.implementations.reference.LazyReferenceMapper
             to write out parquet as the references get filled, or some other dictionary-like class
             to customise how references get stored
+        zarr_version: int
+            The desired zarr spec version to target (currently 2 or 3). The default
+            of None will use the default zarr version.
         args, kwargs: passed to scipy superclass ``scipy.io.netcdf.netcdf_file``
         """
         assert kwargs.pop("mmap", False) is False
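
A minimal usage sketch of the new parameter (the input path here is hypothetical; leaving zarr_version as None keeps the library default):

    from kerchunk.netCDF3 import NetCDF3ToZarr

    # Build references targeting the zarr v3 spec instead of the default.
    refs = NetCDF3ToZarr("example.nc3", zarr_version=3).translate()
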
@@ -63,6 +67,7 @@ def __init__(
         self.chunks = {}
         self.threshold = inline_threshold
         self.max_chunk_size = max_chunk_size
+        self.zarr_version = zarr_version
         self.out = out or {}
         self.storage_options = storage_options
         self.fp = fsspec.open(filename, **(storage_options or {})).open()
@@ -164,10 +169,9 @@ def translate(self):
         Parameters
         ----------
         """
-        import zarr
 
         out = self.out
-        z = zarr.open(out, mode="w")
+        zroot, out = zarr_init_group_and_store(out, self.zarr_version)
         for dim, var in self.variables.items():
             if dim in self.chunks:
                 shape = self.chunks[dim][-1]
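
Note: zarr_init_group_and_store (from kerchunk.utils) replaces the direct zarr.open call. An assumed sketch of its behaviour, not kerchunk's actual implementation:

    import zarr

    def zarr_init_group_and_store(out, zarr_version=None):
        # Assumed behaviour: open a writable root group over ``out`` for the
        # requested spec version, returning both the group and the store so
        # the caller can write chunk references alongside zarr's metadata.
        root = zarr.open_group(out, mode="w", zarr_version=zarr_version)
        return root, out
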
@@ -191,17 +195,25 @@ def translate(self):
                 fill = float(fill)
             if fill is not None and var.data.dtype.kind == "i":
                 fill = int(fill)
-            arr = z.create_dataset(
+            arr = zroot.create_dataset(
                 name=dim,
                 shape=shape,
                 dtype=var.data.dtype,
                 fill_value=fill,
                 chunks=shape,
                 compression=None,
+                overwrite=True,
             )
-            part = ".".join(["0"] * len(shape)) or "0"
-            k = f"{dim}/{part}"
-            out[k] = [
+
+            if self.zarr_version == 3:
+                part = "/".join(["0"] * len(shape)) or "0"
+                key = f"data/root/{dim}/c{part}"
+            else:
+                part = ".".join(["0"] * len(shape)) or "0"
+
+                key = f"{dim}/{part}"
+
+            self.out[key] = [
                 self.filename,
                 int(self.chunks[dim][0]),
                 int(self.chunks[dim][1]),
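
The two branches differ only in how chunk keys are spelled: zarr v2 joins chunk indices with dots directly under the array name, while the experimental v3 layout nests keys under data/root/ with a "c" prefix and slash separators. A hypothetical helper mirroring the logic above:

    def first_chunk_key(name, ndim, zarr_version):
        # First-chunk key for an array with ``ndim`` axes, per spec version.
        sep = "/" if zarr_version == 3 else "."
        part = sep.join(["0"] * ndim) or "0"
        if zarr_version == 3:
            return f"data/root/{name}/c{part}"
        return f"{name}/{part}"

    first_chunk_key("lat", 2, 2)  # -> "lat/0.0"
    first_chunk_key("lat", 2, 3)  # -> "data/root/lat/c0/0"
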
@@ -245,13 +257,14 @@ def translate(self):
                 fill = float(fill)
             if fill is not None and base.kind == "i":
                 fill = int(fill)
-            arr = z.create_dataset(
+            arr = zroot.create_dataset(
                 name=name,
                 shape=shape,
                 dtype=base,
                 fill_value=fill,
                 chunks=(1,) + dtype.shape,
                 compression=None,
+                overwrite=True,
             )
             arr.attrs.update(
                 {
@@ -266,18 +279,33 @@ def translate(self):
 
             arr.attrs["_ARRAY_DIMENSIONS"] = list(var.dimensions)
 
-            suffix = (
-                ("." + ".".join("0" for _ in dtype.shape)) if dtype.shape else ""
-            )
+            if self.zarr_version == 3:
+                suffix = (
+                    ("/" + "/".join("0" for _ in dtype.shape))
+                    if dtype.shape
+                    else ""
+                )
+            else:
+                suffix = (
+                    ("." + ".".join("0" for _ in dtype.shape))
+                    if dtype.shape
+                    else ""
+                )
+
             for i in range(outer_shape):
-                out[f"{name}/{i}{suffix}"] = [
+                if self.zarr_version == 3:
+                    key = f"data/root/{name}/c{i}{suffix}"
+                else:
+                    key = f"{name}/{i}{suffix}"
+
+                self.out[key] = [
                     self.filename,
                     int(offset + i * dt.itemsize),
                     int(dtype.itemsize),
                 ]
 
             offset += dtype.itemsize
-        z.attrs.update(
+        zroot.attrs.update(
             {
                 k: v.decode() if isinstance(v, bytes) else str(v)
                 for k, v in self._attributes.items()
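
Each entry written to self.out is a [path, offset, length] triple pointing at raw bytes inside the source netCDF3 file. Roughly what a generated v2 reference set looks like (all values illustrative):

    out = {
        ".zgroup": '{"zarr_format": 2}',      # metadata written via zarr
        "lat/.zarray": "{...}",               # array metadata, also via zarr
        "lat/0": ["example.nc3", 1024, 3200], # [file, byte offset, length]
    }
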
kerchunk/tests/test_netcdf.py: 12 changes (8 additions, 4 deletions)
@@ -24,15 +24,17 @@
 )
 
 
-def test_one(m):
+@pytest.mark.parametrize("zarr_version", [2, 3])
+def test_one(m, zarr_version):
     m.pipe("data.nc3", bdata)
-    h = netCDF3.netcdf_recording_file("memory://data.nc3")
+    h = netCDF3.netcdf_recording_file("memory://data.nc3", zarr_version=zarr_version)
     out = h.translate()
     ds = xr.open_dataset(
         "reference://",
         engine="zarr",
         backend_kwargs={
             "consolidated": False,
+            "zarr_version": zarr_version,
             "storage_options": {"fo": out, "remote_protocol": "memory"},
         },
     )
@@ -76,16 +78,18 @@ def unlimited_dataset(tmpdir):
     return fn
 
 
-def test_unlimited(unlimited_dataset):
+@pytest.mark.parametrize("zarr_version", [2, 3])
+def test_unlimited(unlimited_dataset, zarr_version):
     fn = unlimited_dataset
     expected = xr.open_dataset(fn, engine="scipy")
-    h = netCDF3.NetCDF3ToZarr(fn)
+    h = netCDF3.NetCDF3ToZarr(fn, zarr_version=zarr_version)
     out = h.translate()
     ds = xr.open_dataset(
         "reference://",
         engine="zarr",
         backend_kwargs={
             "consolidated": False,
+            "zarr_version": zarr_version,
             "storage_options": {"fo": out},
         },
     )
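
For reference, the xr.open_dataset calls above mount the references through fsspec's reference filesystem; the same thing can be done directly (a sketch using the public fsspec API):

    import fsspec

    # "fo" takes the in-memory references dict produced by translate().
    fs = fsspec.filesystem("reference", fo=out, remote_protocol="memory")
    mapper = fs.get_mapper("")  # dict-like view consumable by zarr/xarray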
