Use joblib for parallel processing #50

Merged: 3 commits, Aug 28, 2024

Changes from all commits
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -15,10 +15,10 @@ classifiers = [
 ]
 dependencies = [
     "dask==2024.6.2",
+    "joblib==1.4.2",
     "loguru==0.7.2",
     "numpy==1.26.4",
     "scikit-image==0.24.0",
-    "tqdm",
     "zarr==2.18.2",
 ]
 description = "Convert stacks of images to chunked datasets"
src/stack_to_chunk/_array_helpers.py (3 changes: 3 additions & 0 deletions)
@@ -2,9 +2,11 @@
 import numpy as np
 import skimage.measure
 import zarr
+from joblib import delayed
 from loguru import logger


+@delayed  # type: ignore[misc]
 def _copy_slab(arr_zarr: zarr.Array, slab: da.Array, zstart: int, zend: int) -> None:
     """
     Copy a single slab of data to a zarr array.
@@ -31,6 +33,7 @@
     logger.info(f"Finished copying z={zstart} -> {zend-1}")


+@delayed  # type: ignore[misc]
 def _downsample_block(
     arr_in: zarr.Array, arr_out: zarr.Array, block_idx: tuple[int, int, int]
 ) -> None:
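A note on the new decorator: `joblib.delayed` wraps a function so that calling it returns a `(function, args, kwargs)` triple instead of executing the call; `Parallel` then consumes a list of such triples and runs them across workers. A minimal sketch of the pattern, using the documented `delayed(f)(*args)` call form (the `square` function and its arguments are illustrative, not from this repo):

```python
from joblib import Parallel, delayed


def square(x: int) -> int:
    return x * x


# delayed(square)(3) computes nothing; it returns the triple
# (square, (3,), {}) describing a call to be made later.
jobs = [delayed(square)(x) for x in range(5)]

# Parallel executes the recorded calls, here on two worker processes,
# and returns results in submission order.
results = Parallel(n_jobs=2)(jobs)
print(results)  # [0, 1, 4, 9, 16]
```

The diff applies `delayed` as a decorator instead, which records calls the same way; the `# type: ignore[misc]` comments silence mypy, which flags joblib's untyped decorator.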
src/stack_to_chunk/main.py (25 changes: 5 additions & 20 deletions)
@@ -2,14 +2,13 @@
 Main code for converting stacks to chunks.
 """

-from multiprocessing import Pool
 from pathlib import Path
 from typing import Any, Literal

 import numpy as np
-import tqdm
 import zarr
 from dask.array.core import Array
+from joblib import Parallel
 from loguru import logger
 from numcodecs import blosc
 from numcodecs.abc import Codec
@@ -225,17 +224,8 @@ def add_full_res_data(
         blosc_use_threads = blosc.use_threads
         blosc.use_threads = 0

-        # Use try/finally pattern to allow code coverage to be collected
-        if n_processes == 1:
-            for args in all_args:
-                _copy_slab(*args)
-        else:
-            p = Pool(n_processes)
-            try:
-                p.starmap(_copy_slab, all_args)
-            finally:
-                p.close()
-                p.join()
+        jobs = [_copy_slab(*args) for args in all_args]
+        Parallel(n_jobs=n_processes)(jobs)

         blosc.use_threads = blosc_use_threads
         logger.info("Finished full resolution copy to zarr.")
@@ -299,13 +289,8 @@ def add_downsample_level(self, level: int, *, n_processes: int) -> None:
         blosc_use_threads = blosc.use_threads
         blosc.use_threads = 0

-        # Use try/finally pattern to allow code coverage to be collected
-        p = Pool(n_processes)
-        try:
-            p.starmap(_downsample_block, tqdm.tqdm(all_args, total=len(all_args)))
-        finally:
-            p.close()
-            p.join()
+        jobs = [_downsample_block(*args) for args in all_args]
+        Parallel(n_jobs=n_processes, verbose=10)(jobs)

         self._add_level_metadata(level)
         blosc.use_threads = blosc_use_threads
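Here `verbose=10` takes over progress reporting from the removed `tqdm.tqdm` wrapper, which is consistent with `tqdm` dropping out of the imports and the dependency list: at non-zero verbosity, `Parallel` periodically reports how many tasks have completed and the elapsed time, writing to stderr at this level. A small sketch (the `task` function and its timing are illustrative, not from this repo):

```python
import time

from joblib import Parallel, delayed


def task(i: int) -> int:
    time.sleep(0.1)  # stand-in for copying or downsampling one block
    return i


# verbose=10 makes joblib write progress lines to stderr, e.g.
# [Parallel(n_jobs=2)]: Done  10 out of  20 | elapsed: ... remaining: ...
Parallel(n_jobs=2, verbose=10)(delayed(task)(i) for i in range(20))
```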