diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py
index 9b2edcd9..63d62490 100644
--- a/src/scportrait/pipeline/segmentation/segmentation.py
+++ b/src/scportrait/pipeline/segmentation/segmentation.py
@@ -742,7 +742,6 @@ def _resolve_sharding(self, sharding_plan):
             local_hf = h5py.File(local_output, "r")
             local_hdf_labels = local_hf.get(self.DEFAULT_MASK_NAME)[:]
-            print(type(local_hdf_labels))
 
             shifted_map, edge_labels = shift_labels(
                 local_hdf_labels,
                 class_id_shift,
@@ -902,8 +901,9 @@ def _resolve_sharding(self, sharding_plan):
         if not self.deep_debug:
             self._cleanup_shards(sharding_plan)
 
-    def _initializer_function(self, gpu_id_list):
+    def _initializer_function(self, gpu_id_list, n_processes):
         current_process().gpu_id_list = gpu_id_list
+        current_process().n_processes = n_processes
 
     def _perform_segmentation(self, shard_list):
         # get GPU status
@@ -921,7 +921,7 @@ def _perform_segmentation(self, shard_list):
         with mp.get_context(self.context).Pool(
             processes=self.n_processes,
             initializer=self._initializer_function,
-            initargs=[self.gpu_id_list],
+            initargs=[self.gpu_id_list, self.n_processes],
         ) as pool:
             list(
                 tqdm(
diff --git a/src/scportrait/pipeline/segmentation/workflows.py b/src/scportrait/pipeline/segmentation/workflows.py
index 94ed8c44..2ad5d0cb 100644
--- a/src/scportrait/pipeline/segmentation/workflows.py
+++ b/src/scportrait/pipeline/segmentation/workflows.py
@@ -1353,6 +1353,9 @@ def _check_gpu_status(self):
         gpu_id_list = current.gpu_id_list
         cpu_id = int(cpu_name[cpu_name.find("-") + 1 :]) - 1
 
+        if cpu_id >= len(gpu_id_list):
+            cpu_id = (cpu_id % current.n_processes) % len(gpu_id_list)
+
         # track gpu_id and update GPU status
         self.gpu_id = gpu_id_list[cpu_id]
         self.status = "multi_GPU"