🐛 Fix Multi-GPU Support with torch.compile #923

Merged · 18 commits · Jun 16, 2025
45 changes: 45 additions & 0 deletions tests/models/test_feature_extractor.py
@@ -115,3 +115,48 @@ def test_full_inference(
    # ! else the output values will not exactly be the same (still < 1.0e-4
    # ! of epsilon though)
    assert np.mean(np.abs(features[:4] - _features)) < 1.0e-1


@pytest.mark.skipif(
    toolbox_env.running_on_ci() or not ON_GPU,
    reason="Local test on machine with GPU.",
)
def test_multi_gpu_feature_extraction(remote_sample: Callable, tmp_path: Path) -> None:
    """Local functionality test for feature extraction using multiple GPUs."""
    save_dir = tmp_path / "output"
    mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
    shutil.rmtree(save_dir, ignore_errors=True)

    # Use multiple GPUs
    device = select_device(on_gpu=ON_GPU)

    wsi_ioconfig = IOSegmentorConfig(
        input_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_input_shape=[224, 224],
        output_resolutions=[{"units": "mpp", "resolution": 0.5}],
        patch_output_shape=[224, 224],
        stride_shape=[224, 224],
    )

    model = TimmBackbone(backbone="UNI", pretrained=True)
    extractor = DeepFeatureExtractor(
        model=model,
        auto_generate_mask=True,
        batch_size=32,
        num_loader_workers=4,
        num_postproc_workers=4,
    )

    output_list = extractor.predict(
        [mini_wsi_svs],
        mode="wsi",
        device=device,
        ioconfig=wsi_ioconfig,
        crash_on_exception=True,
        save_dir=save_dir,
    )
    wsi_0_root_path = output_list[0][1]
    positions = np.load(f"{wsi_0_root_path}.position.npy")
    features = np.load(f"{wsi_0_root_path}.features.0.npy")
    assert len(positions.shape) == 2
    assert len(features.shape) == 2
9 changes: 9 additions & 0 deletions tiatoolbox/models/engine/semantic_segmentor.py
@@ -13,6 +13,7 @@
import joblib
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as torch_mp
import torch.utils.data as torch_data
import tqdm
@@ -1421,6 +1422,14 @@ def predict(  # noqa: PLR0913
            logger.warning("Unable to remove %s", self._cache_dir)

        self._memory_cleanup()
        from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

        if (
            device == "cuda"
            and torch.cuda.device_count() > 1
            and is_torch_compile_compatible()
        ):  # pragma: no cover
            dist.destroy_process_group()
Copilot AI (Jun 13, 2025):
Destroying the process group without verifying initialization may error if no group exists. Add if dist.is_initialized(): before calling destroy_process_group().

Suggested change:
-            dist.destroy_process_group()
+            if dist.is_initialized():
+                dist.destroy_process_group()



        return self._outputs

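For reference, a minimal sketch (not part of this PR) of how the guarded teardown suggested above could be factored into a small helper. The name cleanup_process_group is hypothetical, and the device/GPU-count condition mirrors the check used in predict; the PR's own check additionally calls is_torch_compile_compatible().

import torch
import torch.distributed as dist


def cleanup_process_group(device: str) -> None:
    """Hypothetical helper: destroy the single-process DDP group, if any."""
    # Only tear down when running multi-GPU CUDA inference and a process
    # group was actually initialised (the reviewer's suggested guard).
    if device == "cuda" and torch.cuda.device_count() > 1 and dist.is_initialized():
        dist.destroy_process_group()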
29 changes: 26 additions & 3 deletions tiatoolbox/models/models_abc.py
@@ -2,11 +2,16 @@

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001

@@ -51,12 +56,30 @@ def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
            The model after being moved to specified device.

    """
-    if device != "cpu":
+    torch_device = torch.device(device)
+
+    # Use DDP if multiple GPUs and not on CPU
+    if (
+        device == "cuda"
+        and torch.cuda.device_count() > 1
+        and is_torch_compile_compatible()
+    ):  # pragma: no cover
+        # This assumes a single-process DDP setup for inference
+        model = model.to(torch_device)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        dist.init_process_group(backend="nccl", rank=0, world_size=1)
+        model = DistributedDataParallel(model, device_ids=[torch_device.index])
+
+    elif device != "cpu":
         # DataParallel work only for cuda
         model = torch.nn.DataParallel(model)
+        model = model.to(torch_device)

-    torch_device = torch.device(device)
-    return model.to(torch_device)
+    else:
+        model = model.to(torch_device)
+
+    return model


class ModelABC(ABC, torch.nn.Module):
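To illustrate how the two changes fit together, here is a rough usage sketch (assumed, not taken from the PR): model_to prepares the model for multi-GPU inference, and the caller releases the process group afterwards, mirroring the cleanup added to semantic_segmentor.py. The torch.nn.Linear model is a stand-in for a real tiatoolbox architecture.

import torch
import torch.distributed as dist

from tiatoolbox.models.models_abc import model_to

# Stand-in model; any torch.nn.Module works here.
model = torch.nn.Linear(10, 2)
device = "cuda" if torch.cuda.is_available() else "cpu"

# With more than one CUDA device (and torch.compile support) model_to wraps
# the model in DistributedDataParallel after creating a single-process group
# (MASTER_ADDR/MASTER_PORT are set inside model_to); otherwise it falls back
# to DataParallel or a plain .to(device).
model = model_to(model=model, device=device)

# ... run inference ...

# Release the process group if one was created, as suggested in the review.
if dist.is_initialized():
    dist.destroy_process_group()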