Skip to content

Commit

Permalink
Merge pull request #50 from HiPCTProject/joblib
Browse files (browse the repository at this point in the history)
Use joblib for parallel processing
  • Loading branch information
dstansby authored Aug 28, 2024
2 parents efcc2dc + b12150c commit d9fc9ae
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ classifiers = [
]
dependencies = [
"dask==2024.6.2",
"joblib==1.4.2",
"loguru==0.7.2",
"numpy==1.26.4",
"scikit-image==0.24.0",
"tqdm",
"zarr==2.18.2",
]
description = "Convert stacks of images to chunked datasets"
Expand Down
3 changes: 3 additions & 0 deletions src/stack_to_chunk/_array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
import skimage.measure
import zarr
from joblib import delayed
from loguru import logger


@delayed # type: ignore[misc]
def _copy_slab(arr_zarr: zarr.Array, slab: da.Array, zstart: int, zend: int) -> None:
"""
Copy a single slab of data to a zarr array.
Expand All @@ -31,6 +33,7 @@ def _copy_slab(arr_zarr: zarr.Array, slab: da.Array, zstart: int, zend: int) ->
logger.info(f"Finished copying z={zstart} -> {zend-1}")


@delayed # type: ignore[misc]
def _downsample_block(
arr_in: zarr.Array, arr_out: zarr.Array, block_idx: tuple[int, int, int]
) -> None:
Expand Down
25 changes: 5 additions & 20 deletions src/stack_to_chunk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
Main code for converting stacks to chunks.
"""

from multiprocessing import Pool
from pathlib import Path
from typing import Any, Literal

import numpy as np
import tqdm
import zarr
from dask.array.core import Array
from joblib import Parallel
from loguru import logger
from numcodecs import blosc
from numcodecs.abc import Codec
Expand Down Expand Up @@ -225,17 +224,8 @@ def add_full_res_data(
blosc_use_threads = blosc.use_threads
blosc.use_threads = 0

# Use try/finally pattern to allow code coverage to be collected
if n_processes == 1:
for args in all_args:
_copy_slab(*args)
else:
p = Pool(n_processes)
try:
p.starmap(_copy_slab, all_args)
finally:
p.close()
p.join()
jobs = [_copy_slab(*args) for args in all_args]
Parallel(n_jobs=n_processes)(jobs)

blosc.use_threads = blosc_use_threads
logger.info("Finished full resolution copy to zarr.")
Expand Down Expand Up @@ -299,13 +289,8 @@ def add_downsample_level(self, level: int, *, n_processes: int) -> None:
blosc_use_threads = blosc.use_threads
blosc.use_threads = 0

# Use try/finally pattern to allow code coverage to be collected
p = Pool(n_processes)
try:
p.starmap(_downsample_block, tqdm.tqdm(all_args, total=len(all_args)))
finally:
p.close()
p.join()
jobs = [_downsample_block(*args) for args in all_args]
Parallel(n_jobs=n_processes, verbose=10)(jobs)

self._add_level_metadata(level)
blosc.use_threads = blosc_use_threads
Expand Down

0 comments on commit d9fc9ae

Please sign in to comment.