feat: Allowing to use zcollection without any dask cluster.
Thomas Zilio committed Nov 20, 2024
1 parent 7faa803 commit d4662f8
Showing 11 changed files with 711 additions and 333 deletions.
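
For orientation, here is a minimal usage sketch of what this commit enables: loading a collection without starting a `dask.distributed` cluster. The `open_collection` call, path, and filter below are illustrative assumptions; only the new `distributed` keyword on `load()` comes from this diff.

```python
import zcollection

# Open an existing collection (path and read mode are placeholders).
collection = zcollection.open_collection("/path/to/collection", mode="r")

# With distributed=False the partitions are read sequentially in the
# calling process, so no dask.distributed client has to be running.
# Delayed (dask-backed) arrays are also disabled in this mode.
zds = collection.load(
    filters=lambda keys: keys["year"] == 2019 and keys["month"] == 3,
    distributed=False,
)
```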
250 changes: 179 additions & 71 deletions zcollection/collection/__init__.py

Large diffs are not rendered by default.

85 changes: 61 additions & 24 deletions zcollection/collection/abc.py
@@ -553,6 +553,7 @@ def load(
filters: PartitionFilter = None,
indexer: Indexer | None = None,
selected_variables: Iterable[str] | None = None,
distributed: bool = True,
) -> dataset.Dataset | None:
"""Load the selected partitions.
@@ -564,6 +565,7 @@ def load(
indexer: The indexer to apply.
selected_variables: A list of variables to retain from the
collection. If None, all variables are kept.
distributed: Whether to use dask or not. Defaults to True.
Returns:
The dataset containing the selected partitions, or None if no
@@ -582,22 +584,42 @@ def load(
... filters=lambda keys: keys["year"] == 2019 and
... keys["month"] == 3 and keys["day"] % 2 == 0)
"""
client: dask.distributed.Client = dask_utils.get_client()
# Delayed has to be False if dask is disabled
if not distributed:
delayed = False

arrays: list[dataset.Dataset]
client: dask.distributed.Client

if indexer is None:
# No indexer, so the dataset is loaded directly for each
# selected partition.
selected_partitions = tuple(self.partitions(filters=filters))
if len(selected_partitions) == 0:
return None

# No indexer, so the dataset is loaded directly for each
# selected partition.
bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
self.partitions(filters=filters),
npartitions=dask_utils.dask_workers(client, cores_only=True))
arrays = bag.map(storage.open_zarr_group,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables).compute()
partitions = self.partitions(filters=filters)

if distributed:
client = dask_utils.get_client()
bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
partitions,
npartitions=dask_utils.dask_workers(client,
cores_only=True))
arrays = bag.map(
storage.open_zarr_group,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables).compute()
else:
arrays = [
storage.open_zarr_group(
dirname=partition,
delayed=delayed,
fs=self.fs,
selected_variables=selected_variables)
for partition in partitions
]
else:
# We're going to reuse the indexer variable, so ensure it is
# an iterable not a generator.
@@ -617,21 +639,36 @@ def load(
if len(args) == 0:
return None

bag = dask.bag.core.from_sequence(
args,
npartitions=dask_utils.dask_workers(client, cores_only=True))

# Finally, load the selected partitions and apply the indexer.
arrays = list(
itertools.chain.from_iterable(
bag.map(
_load_and_apply_indexer,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables,
).compute()))
if distributed:
client = dask_utils.get_client()
bag = dask.bag.core.from_sequence(
args,
npartitions=dask_utils.dask_workers(client,
cores_only=True))

arrays = list(
itertools.chain.from_iterable(
bag.map(
_load_and_apply_indexer,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables,
).compute()))
else:
arrays = list(
itertools.chain.from_iterable([
_load_and_apply_indexer(
args=a,
delayed=delayed,
fs=self.fs,
partition_handler=self.partitioning,
partition_properties=self.partition_properties,
selected_variables=selected_variables)
for a in args
]))

array: dataset.Dataset = arrays.pop(0)
if arrays:
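
The heart of the `load()` change is a dispatch between a `dask.bag` pipeline and a plain sequential loop over the same worker function. Below is a self-contained sketch of that pattern, under stated assumptions: `open_partition` stands in for `storage.open_zarr_group`, and the signature is illustrative rather than the library's.

```python
from __future__ import annotations

from typing import Callable, Iterable

import dask.bag


def load_partitions(
    partitions: Iterable[str],
    open_partition: Callable[[str], object],
    distributed: bool = True,
    npartitions: int | None = None,
) -> list[object]:
    """Open every partition, either through dask or in-process."""
    items = list(partitions)
    if distributed:
        # One bag item per partition; the work is spread over the dask
        # workers and the results are gathered back with compute().
        bag = dask.bag.from_sequence(items, npartitions=npartitions)
        return bag.map(open_partition).compute()
    # No cluster: perform exactly the same calls sequentially.
    return [open_partition(item) for item in items]
```

Calling `load_partitions(paths, my_open, distributed=False)` runs without any dask cluster being configured, which mirrors the behaviour the diff above adds to `load()`.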
13 changes: 10 additions & 3 deletions zcollection/collection/detail.py
@@ -394,6 +394,7 @@ def _insert(
fs: fsspec.AbstractFileSystem,
merge_callable: merging.MergeCallable | None,
partitioning_properties: PartitioningProperties,
distributed: bool = True,
**kwargs,
) -> None:
"""Insert or update a partition in the collection.
@@ -405,6 +406,7 @@ def _insert(
fs: The file system that the partition is stored on.
merge_callable: The merge callable.
partitioning_properties: The partitioning properties.
distributed: Whether to use dask or not. Defaults to True.
**kwargs: Additional keyword arguments to pass to the merge callable.
"""
partition: tuple[str, ...]
@@ -423,7 +425,8 @@
axis,
fs,
partitioning_properties.dim,
delayed=zds.delayed,
delayed=zds.delayed if distributed else False,
distributed=distributed,
merge_callable=merge_callable,
**kwargs)
return
@@ -434,7 +437,11 @@
zarr.storage.init_group(store=fs.get_mapper(dirname))

# The synchronization is done by the caller.
write_zarr_group(zds.isel(indexer), dirname, fs, sync.NoSync())
write_zarr_group(zds.isel(indexer),
dirname,
fs,
sync.NoSync(),
distributed=distributed)
except: # noqa: E722
# If the construction of the new dataset fails, the created
# partition is deleted, to guarantee the integrity of the
@@ -459,7 +466,7 @@ def _load_and_apply_indexer(
fs: The file system that the partition is stored on.
partition_handler: The partitioning handler.
partition_properties: The partitioning properties.
selected_variable: The selected variables to load.
selected_variables: The selected variables to load.
Returns:
The list of loaded datasets.
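
The `_insert` changes follow the same idea: the `distributed` flag is threaded down to the storage layer, and the dataset's `delayed` setting is forced off when dask is not used, since delayed mode relies on dask-backed arrays. A tiny, self-contained sketch of that guard (the helper name and the asserts are illustrative only):

```python
def resolve_delayed(dataset_delayed: bool, distributed: bool) -> bool:
    """Effective 'delayed' flag for a read or write.

    Delayed (dask-backed) arrays only make sense when dask is in use, so
    the flag is forced to False in non-distributed mode.
    """
    return dataset_delayed if distributed else False


# Illustrative combinations:
assert resolve_delayed(True, True) is True     # dask available: stay lazy
assert resolve_delayed(True, False) is False   # no dask: force eager arrays
assert resolve_delayed(False, False) is False  # already eager: unchanged
```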