Skip to content

Commit

Permalink
sqme: keep track of optional base fs in driver
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 25, 2024
1 parent a5887e7 commit f34ddca
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def _from_zarr_spec(
if fs is not None:
chunk_store = fs.get_mapper(target)

# TODO: deal with coordinates being loaded at open time.
#
# When a chunk store is supplied, xarray will try to load index coords
# (i.e. coords whose name == dim)

md_store = fsspec.filesystem("reference", fo=spec_doc).get_mapper("")
xx = xr.open_dataset(
md_store,
Expand Down Expand Up @@ -173,21 +178,33 @@ def __init__(

def with_env(self, env: dict[str, Any]) -> "Context":
    """
    Build a new ``Context`` for the supplied environment.

    :param env: environment mapping; it is only checked to be a ``dict``.
    :returns: a ``Context`` rebuilt from this instance's ``geobox``,
       ``chunks`` and ``driver``.
    """
    assert isinstance(env, dict)
    # Single return path: hand the driver over as well, so the new context
    # does not lose it (an earlier driver-less return made this unreachable).
    return Context(self.geobox, self.chunks, driver=self.driver)

@property
def fs(self) -> fsspec.AbstractFileSystem | None:
    """
    Filesystem exposed by the attached driver, ``None`` when no driver is set.
    """
    drv = self.driver
    return None if drv is None else drv.fs


class XrSource:
"""
RasterSource -> xr.DataArray|xr.Dataset
"""

def __init__(self, src: RasterSource, chunks: Any | None = None) -> None:
def __init__(
self,
src: RasterSource,
chunks: Any | None = None,
fs: fsspec.AbstractFileSystem | None = None,
) -> None:
driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data
self._spec: ZarrSpecFs | None = None
self._ds: xr.Dataset | None = None
self._xx: xr.DataArray | None = None
self._src = src
self._chunks = chunks
self._fs = fs

if isinstance(driver_data, xr.DataArray):
self._xx = driver_data
Expand All @@ -207,14 +224,22 @@ def __init__(self, src: RasterSource, chunks: Any | None = None) -> None:
def spec(self) -> ZarrSpecFs | None:
return self._spec

def base(self, regen_coords: bool = False) -> xr.Dataset | None:
def base(
self,
regen_coords: bool = False,
refresh: bool = False,
) -> xr.Dataset | None:
if refresh and self._spec:
self._ds = None

if self._ds is not None:
return self._ds
if self._spec is None:
return None
self._ds = _from_zarr_spec(
self._spec,
regen_coords=regen_coords,
fs=self._fs,
target=self._src.uri,
chunks=self._chunks,
)
Expand All @@ -223,11 +248,12 @@ def base(self, regen_coords: bool = False) -> xr.Dataset | None:
def resolve(
self,
regen_coords: bool = False,
refresh: bool = False,
) -> xr.DataArray:
if self._xx is not None:
return self._xx

src_ds = self.base(regen_coords=regen_coords)
src_ds = self.base(regen_coords=regen_coords, refresh=refresh)
if src_ds is None:
raise ValueError("Failed to interpret driver data")

Expand Down Expand Up @@ -264,8 +290,7 @@ class XrMemReader:
"""

def __init__(self, src: RasterSource, ctx: Context) -> None:
    """
    Set up the reader for ``src``.

    :param src: raster source, wrapped in an ``XrSource`` with ``chunks=None``.
    :param ctx: load context; its ``fs`` (possibly ``None``) is passed on to
       the ``XrSource``.
    """
    # Construct the source exactly once, already carrying the context's
    # filesystem, instead of building an fs-less instance and replacing it.
    self._src = XrSource(src, chunks=None, fs=ctx.fs)
    # Kept on the instance: methods outside this view may still read it.
    self._ctx = ctx

def read(
self,
Expand Down Expand Up @@ -306,8 +331,13 @@ def __init__(
layer_name: str = "",
idx: int = -1,
) -> None:
self._src = XrSource(src, chunks={}) if src is not None else None
self._ctx = ctx
if src is not None:
assert ctx is not None
_src = XrSource(src, chunks={}, fs=ctx.fs)
else:
_src = None

self._src = _src
self._layer_name = layer_name
self._idx = idx

Expand Down Expand Up @@ -354,13 +384,15 @@ def __init__(
self,
src: xr.Dataset | None = None,
template: RasterGroupMetadata | None = None,
fs: fsspec.AbstractFileSystem | None = None,
) -> None:
if src is not None and template is None:
template = raster_group_md(src)
if template is None:
template = RasterGroupMetadata({}, {}, {}, [])
self.src = src
self.template = template
self.fs = fs

def new_load(
self,
Expand Down

0 comments on commit f34ddca

Please sign in to comment.