
🐛 Fix Multi-GPU Support with torch.compile #923

Open · wants to merge 9 commits into develop
45 changes: 45 additions & 0 deletions tests/models/test_feature_extractor.py
@@ -115,3 +115,48 @@ def test_full_inference(
    # ! else the output values will not exactly be the same (still < 1.0e-4
    # ! of epsilon though)
    assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1


@pytest.mark.skipif(
    toolbox_env.running_on_ci() or not ON_GPU,
    reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(remote_sample: Callable, tmp_path: Path) -> None:
    """Local functionality test for feature extraction using multiple GPUs."""
    save_dir = tmp_path / "output"
    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
    shutil.rmtree(save_dir, ignore_errors=True)

    # select_device returns "cuda" when a GPU is available; model_to then
    # decides whether to distribute the model across all visible GPUs
    device = select_device(on_gpu=ON_GPU)

    wsi_ioconfig = IOSegmentorConfig(
        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_input_shape=[224, 224],
        output_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_output_shape=[224, 224],
        stride_shape=[224, 224],
    )

    model = TimmBackbone(backbone="UNI", pretrained=True)
    extractor = DeepFeatureExtractor(
        model=model,
        auto_generate_mask=True,
        batch_size=32,
        num_loader_workers=4,
        num_postproc_workers=4,
    )

    output_list = extractor.predict(
        [mini_wsi_svs],
        mode="wsi",
        device=device,
        ioconfig=wsi_ioconfig,
        crash_on_exception=True,
        save_dir=save_dir,
    )
    wsi_0_root_path = output_list[0][1]
    positions = np.load(f"{wsi_0_root_path}.position.npy")
    features = np.load(f"{wsi_0_root_path}.features.0.npy")
    assert len(positions.shape) == 2
    assert len(features.shape) == 2
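
A usage note on the saved arrays: they are expected to be row-aligned, i.e. `positions[i]` holds the patch coordinates for the feature vector `features[i]`. A minimal sketch of consuming them (the `root` prefix is hypothetical; in practice it comes from `output_list` as in the test above):

```python
import numpy as np

root = "output/0"  # hypothetical prefix returned by extractor.predict()
positions = np.load(f"{root}.position.npy")   # (N, ...) patch coordinates
features = np.load(f"{root}.features.0.npy")  # (N, D) feature vectors

# Both arrays share the leading dimension N, one row per patch
assert positions.shape[0] == features.shape[0]
for coords, vec in zip(positions[:3], features[:3]):
    print(coords, vec.shape)
```
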
9 changes: 9 additions & 0 deletions tiatoolbox/models/engine/semantic_segmentor.py
@@ -13,6 +13,7 @@
import joblib
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data as torch_data
import tqdm
@@ -1421,6 +1422,14 @@ def predict( # noqa: PLR0913
            logger.warning("Unable to remove %s", self._cache_dir)

        self._memory_cleanup()
        # Local import of the compile-compatibility check shared with model_to
        from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

        # Tear down the single-process DDP group created in model_to so that
        # subsequent predict() calls can initialize a fresh one.
        if (
            device == "cuda"
            and torch.cuda.device_count() > 1
            and is_torch_compile_compatible()
            and dist.is_initialized()
        ):
            dist.destroy_process_group()

        return self._outputs
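
For context, stock `torch.distributed` ships introspection helpers that make this teardown easy to verify; a small sketch using only standard PyTorch calls (the helper name `process_group_active` is made up here):

```python
import torch.distributed as dist


def process_group_active() -> bool:
    """Return True if a default process group is currently initialized."""
    return dist.is_available() and dist.is_initialized()


# After predict() returns, the group created in model_to should be gone:
# assert not process_group_active()
```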

31 changes: 28 additions & 3 deletions tiatoolbox/models/models_abc.py
@@ -2,11 +2,16 @@

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001

@@ -51,12 +56,32 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
            The model after being moved to the specified device.

    """
    torch_device = torch.device(device)

    # Use DDP when multiple GPUs are available and torch.compile is supported
    if (
        device == "cuda"
        and torch.cuda.device_count() > 1
        and is_torch_compile_compatible()
    ):
        # This assumes a single-process DDP setup for inference
        model = model.to(torch_device)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group(
            backend="nccl", rank=0, world_size=torch.cuda.device_count()
        )
        # torch.device("cuda") carries no index; fall back to the current device
        device_index = torch_device.index
        if device_index is None:
            device_index = torch.cuda.current_device()
        model = DistributedDataParallel(model, device_ids=[device_index])
    elif device != "cpu":
        # DataParallel only works for CUDA devices
        model = torch.nn.DataParallel(model)
        model = model.to(torch_device)
    else:
        model = model.to(torch_device)

    return model


class ModelABC(ABC, torch.nn.Module):
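
Taken together, the new `model_to` path can be smoke-tested in isolation. A minimal sketch, assuming a CUDA machine where `is_torch_compile_compatible()` returns `True`:

```python
import torch
from torch.nn.parallel import DistributedDataParallel

from tiatoolbox.models.models_abc import model_to

# Any small module suffices; only the wrapper type is being checked
model = torch.nn.Linear(8, 2)
moved = model_to(model, device="cuda")

if torch.cuda.device_count() > 1:
    # Multi-GPU machines should get the DDP wrapper from the new branch
    assert isinstance(moved, DistributedDataParallel)
else:
    # A single GPU falls through to the existing DataParallel branch
    assert isinstance(moved, torch.nn.DataParallel)
```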