diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py
index cd33f0a5a..9ceb549be 100644
--- a/tests/models/test_feature_extractor.py
+++ b/tests/models/test_feature_extractor.py
@@ -115,3 +115,48 @@ def test_full_inference(
     # ! else the output values will not exactly be the same (still < 1.0e-4
     # ! of epsilon though)
     assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1
+
+
+@pytest.mark.skipif(
+    toolbox_env.running_on_ci() or not ON_GPU,
+    reason="Local test on machine with GPU.",
+)
+def test_multi_gpu_feature_extraction(remote_sample: Callable, tmp_path: Path) -> None:
+    """Local functionality test for feature extraction using multiple GPUs."""
+    save_dir = tmp_path / "output"
+    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
+    shutil.rmtree(save_dir, ignore_errors=True)
+
+    # Use multiple GPUs
+    device = select_device(on_gpu=ON_GPU)
+
+    wsi_ioconfig = IOSegmentorConfig(
+        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
+        patch_input_shape=[224, 224],
+        output_resolutions=[{"units": "mpp", "resolution": 0.5}],
+        patch_output_shape=[224, 224],
+        stride_shape=[224, 224],
+    )
+
+    model = TimmBackbone(backbone="UNI", pretrained=True)
+    extractor = DeepFeatureExtractor(
+        model=model,
+        auto_generate_mask=True,
+        batch_size=32,
+        num_loader_workers=4,
+        num_postproc_workers=4,
+    )
+
+    output_list = extractor.predict(
+        [mini_wsi_svs],
+        mode="wsi",
+        device=device,
+        ioconfig=wsi_ioconfig,
+        crash_on_exception=True,
+        save_dir=save_dir,
+    )
+    wsi_0_root_path = output_list[0][1]
+    positions = np.load(f"{wsi_0_root_path}.position.npy")
+    features = np.load(f"{wsi_0_root_path}.features.0.npy")
+    assert len(positions.shape) == 2
+    assert len(features.shape) == 2
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index fe0c3e02b..2befcf0b1 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -13,6 +13,7 @@
 import joblib
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as torch_mp
 import torch.utils.data as torch_data
 import tqdm
@@ -1421,6 +1422,14 @@ def predict(  # noqa: PLR0913
             logger.warning("Unable to remove %s", self._cache_dir)
 
         self._memory_cleanup()
+        from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
+
+        if (
+            device == "cuda"
+            and torch.cuda.device_count() > 1
+            and is_torch_compile_compatible()
+        ):
+            dist.destroy_process_group()
 
         return self._outputs
 
diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py
index a8a8f7262..cf640e96f 100644
--- a/tiatoolbox/models/models_abc.py
+++ b/tiatoolbox/models/models_abc.py
@@ -2,11 +2,16 @@
 
 from __future__ import annotations
 
+import os
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Callable
 
 import torch
 import torch._dynamo
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+
+from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
 
 torch._dynamo.config.suppress_errors = True  # skipcq: PYL-W0212  # noqa: SLF001
 
@@ -51,12 +56,32 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
             The model after being moved to specified device.
 
     """
-    if device != "cpu":
+    torch_device = torch.device(device)
+
+    # Use DDP if multiple GPUs and not on CPU
+    if (
+        device == "cuda"
+        and torch.cuda.device_count() > 1
+        and is_torch_compile_compatible()
+    ):
+        # This assumes a single-process DDP setup for inference
+        model = model.to(torch_device)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        dist.init_process_group(
+            backend="nccl", rank=0, world_size=torch.cuda.device_count()
+        )
+        model = DistributedDataParallel(model, device_ids=[torch_device.index])
+
+    elif device != "cpu":
         # DataParallel work only for cuda
         model = torch.nn.DataParallel(model)
+        model = model.to(torch_device)
 
-    torch_device = torch.device(device)
-    return model.to(torch_device)
+    else:
+        model = model.to(torch_device)
+
+    return model
 
 
 class ModelABC(ABC, torch.nn.Module):