Skip to content

Commit

Permalink
Tweak Dask-based Reader interface
Browse files Browse the repository at this point in the history
`.open` now takes an extra argument `idx` that
uniquely identifies a `RasterSource` within a given
load context. This can be used to label file-open
stages in the Dask graph.
  • Loading branch information
Kirill888 committed Jun 18, 2024
1 parent f00b769 commit dbf97a1
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _task_futures(
keys_out: list[Key] = []
for i_src, src in enumerate(layer):
idx = (i_time, *task.idx[1:], i_src)
rdr = dask_reader.open(src, ctx, layer_name=layer_name)
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
keys_out.append(fut.key)
dsk.update(fut.dask)
Expand Down
7 changes: 6 additions & 1 deletion odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
ctx: Any | None = None,
src: RasterSource | None = None,
layer_name: str = "",
idx: int = -1,
) -> None:
if env is None:
env = driver.capture_env()
Expand All @@ -63,6 +64,7 @@ def __init__(
self._ctx = ctx
self._src = src
self._layer_name = layer_name
self._src_idx = idx

def read(
self,
Expand All @@ -88,13 +90,16 @@ def read(
dask_key_name=(self._layer_name, *idx),
)

def open(self, src: RasterSource, ctx: Any, layer_name: str) -> "ReaderDaskAdaptor":
def open(
self, src: RasterSource, ctx: Any, layer_name: str, idx: int
) -> "ReaderDaskAdaptor":
return ReaderDaskAdaptor(
self._driver,
self._env,
ctx,
src,
layer_name=layer_name,
idx=idx,
)


Expand Down
2 changes: 1 addition & 1 deletion odc/loader/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_dask_reader_adaptor(dtype: str):
ctx = base_driver.new_load(gbox, chunks={"x": 64, "y": 64})

src = RasterSource("mem://", meta=meta)
rdr = driver.open(src, ctx, layer_name="aa")
rdr = driver.open(src, ctx, layer_name="aa", idx=0)

assert isinstance(rdr, ReaderDaskAdaptor)

Expand Down
1 change: 1 addition & 0 deletions odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def open(
ctx: Any,
*,
layer_name: str,
idx: int,
) -> "DaskRasterReader": ...


Expand Down

0 comments on commit dbf97a1

Please sign in to comment.