Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

open_virtual_dataset with and without indexes #52

Merged
merged 16 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions virtualizarr/tests/test_xarray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Mapping

import numpy as np
import pytest
import xarray as xr
from xarray.core.indexes import Index

from virtualizarr import open_virtual_dataset
from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.zarr import ZArray

Expand Down Expand Up @@ -104,6 +109,7 @@ def test_concat_along_existing_dim(self):
ds2 = xr.Dataset({"a": (["x", "y"], marr2)})

result = xr.concat([ds1, ds2], dim="x")["a"]
assert result.indexes == {}

assert result.shape == (2, 20)
assert result.chunks == (1, 10)
Expand Down Expand Up @@ -150,6 +156,7 @@ def test_concat_along_new_dim(self):
ds2 = xr.Dataset({"a": (["x", "y"], marr2)})

result = xr.concat([ds1, ds2], dim="z")["a"]
assert result.indexes == {}

# xarray.concat adds new dimensions along axis=0
assert result.shape == (2, 5, 20)
Expand Down Expand Up @@ -201,6 +208,7 @@ def test_concat_dim_coords_along_existing_dim(self):
ds2 = xr.Dataset(coords=coords)

result = xr.concat([ds1, ds2], dim="t")["t"]
assert result.indexes == {}

assert result.shape == (40,)
assert result.chunks == (10,)
Expand All @@ -215,3 +223,45 @@ def test_concat_dim_coords_along_existing_dim(self):
assert result.data.zarray.fill_value == zarray.fill_value
assert result.data.zarray.order == zarray.order
assert result.data.zarray.zarr_format == zarray.zarr_format


@pytest.fixture
def netcdf4_file(tmpdir):
# Set up example xarray dataset
ds = xr.tutorial.open_dataset("air_temperature")

# Save it to disk as netCDF (in temporary directory)
filepath = f"{tmpdir}/air.nc"
ds.to_netcdf(filepath)

return filepath


class TestOpenVirtualDataseIndexes:
def test_no_indexes(self, netcdf4_file):
vds = open_virtual_dataset(netcdf4_file, indexes={})
assert vds.indexes == {}

def test_create_default_indexes(self, netcdf4_file):
vds = open_virtual_dataset(netcdf4_file, indexes=None)
ds = xr.open_dataset(netcdf4_file)
print(vds.indexes)
print(ds.indexes)
# TODO use xr.testing.assert_identical(vds.indexes, ds.indexes) instead once class supported by assertion comparison, see https://github.com/pydata/xarray/issues/5812
assert index_mappings_equal(vds.xindexes, ds.xindexes)


def index_mappings_equal(indexes1: Mapping[str, Index], indexes2: Mapping[str, Index]):
# Check if the mappings have the same keys
if set(indexes1.keys()) != set(indexes2.keys()):
return False

# Check if the values for each key are identical
for key in indexes1.keys():
index1 = indexes1[key]
index2 = indexes2[key]

if not index1.equals(index2):
return False

return True
63 changes: 38 additions & 25 deletions virtualizarr/xarray.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Literal, Optional, Union, overload
from typing import List, Literal, Mapping, Optional, Union, overload

import ujson # type: ignore
import xarray as xr
from xarray import register_dataset_accessor
from xarray.backends import BackendArray
from xarray.core.indexes import Index

import virtualizarr.kerchunk as kerchunk
from virtualizarr.kerchunk import KerchunkStoreRefs
Expand All @@ -20,8 +21,8 @@ def open_virtual_dataset(
filepath: str,
filetype: Optional[str] = None,
drop_variables: Optional[List[str]] = None,
indexes: Mapping[str, Index] | None = None,
virtual_array_class=ManifestArray,
indexes={},
) -> xr.Dataset:
"""
Open a file or store as an xarray Dataset wrapping virtualized zarr arrays.
Expand All @@ -38,27 +39,38 @@ def open_virtual_dataset(
If not provided will attempt to automatically infer the correct filetype from the the filepath's extension.
drop_variables: list[str], default is None
Variables in the file to drop before returning.
indexes : Mapping[str, Index], default is None
Default is None, which will read any 1D coordinate data to create in-memory Pandas indexes.
To avoid creating any indexes, pass indexes={}.
virtual_array_class
Virtual array class to use to represent the references to the chunks in each on-disk array.
Currently can only be ManifestArray, but once VirtualZarrArray is implemented the default should be changed to that.
"""

# this is the only place we actually always need to use kerchunk directly
ds_refs = kerchunk.read_kerchunk_references_from_file(
vds_refs = kerchunk.read_kerchunk_references_from_file(
filepath=filepath,
filetype=filetype,
)

ds = dataset_from_kerchunk_refs(
ds_refs,
if indexes is None:
# add default indexes by reading data from file
# TODO we are reading a bunch of stuff we know we won't need here, e.g. all of the data variables...
# TODO it would also be nice if we could somehow consolidate this with the reading of the kerchunk references
ds = xr.open_dataset(filepath)
indexes = ds.xindexes
ds.close()

vds = dataset_from_kerchunk_refs(
vds_refs,
drop_variables=drop_variables,
virtual_array_class=virtual_array_class,
indexes=indexes,
)

Comment on lines +54 to 72
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're literally opening the file twice here - once with kerchunk to read all the byte ranges, and then optionally once again to read in the values to use to create defaut pandas indexes with xarray.

Wondering if you have any thoughts on how hard it might be to consolidate these @jhamman ?

# TODO we should probably also use ds.set_close() to tell xarray how to close the file we opened
# TODO we should probably also use vds.set_close() to tell xarray how to close the file we opened

return ds
return vds


def dataset_from_kerchunk_refs(
Expand Down Expand Up @@ -86,14 +98,9 @@ def dataset_from_kerchunk_refs(

vars = {}
for var_name in var_names_to_keep:
# TODO abstract all this parsing into a function/method?
arr_refs = kerchunk.extract_array_refs(refs, var_name)
chunk_dict, zarray, zattrs = kerchunk.parse_array_refs(arr_refs)
manifest = ChunkManifest.from_kerchunk_chunk_dict(chunk_dict)
dims = zattrs["_ARRAY_DIMENSIONS"]

varr = virtual_array_class(zarray=zarray, chunkmanifest=manifest)
vars[var_name] = xr.Variable(data=varr, dims=dims, attrs=zattrs)
vars[var_name] = variable_from_kerchunk_refs(
refs, var_name, virtual_array_class
)

data_vars, coords = separate_coords(vars, indexes)

Expand All @@ -109,6 +116,20 @@ def dataset_from_kerchunk_refs(
return ds


def variable_from_kerchunk_refs(
refs: KerchunkStoreRefs, var_name: str, virtual_array_class
) -> xr.Variable:
"""Create a single xarray Variable by reading specific keys of a kerchunk references dict."""

arr_refs = kerchunk.extract_array_refs(refs, var_name)
chunk_dict, zarray, zattrs = kerchunk.parse_array_refs(arr_refs)
manifest = ChunkManifest.from_kerchunk_chunk_dict(chunk_dict)
dims = zattrs["_ARRAY_DIMENSIONS"]
varr = virtual_array_class(zarray=zarray, chunkmanifest=manifest)

return xr.Variable(data=varr, dims=dims, attrs=zattrs)


def separate_coords(
vars: dict[str, xr.Variable],
indexes={},
Expand All @@ -121,6 +142,7 @@ def separate_coords(

# this would normally come from CF decoding, let's hope the fact we're skipping that doesn't cause any problems...
coord_names: List[str] = []

# split data and coordinate variables (promote dimension coordinates)
data_vars = {}
coord_vars = {}
Expand All @@ -135,16 +157,7 @@ def separate_coords(
else:
data_vars[name] = var

# this is stolen from https://github.com/pydata/xarray/pull/8051
# needed otherwise xarray errors whilst trying to turn the KerchunkArrays for the 1D coordinate variables into indexes
# but it doesn't appear to work with `main` since #8107, which is why the workaround above is needed
# EDIT: actually even the workaround doesn't work - to avoid creating indexes I had to checkout xarray v2023.08.0, the last one before #8107 was merged
set_indexes = False
if set_indexes:
coords = coord_vars
else:
# explict Coordinates object with no index passed
coords = xr.Coordinates(coord_vars, indexes=indexes)
coords = xr.Coordinates(coord_vars, indexes=indexes)

return data_vars, coords

Expand Down
Loading