
[Backend]: Support device backend registration for a wide range of third-party hardware #20349


Open · wants to merge 17 commits into base: master
39 changes: 37 additions & 2 deletions docs/source-pytorch/extensions/accelerator.rst
@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc

.. code-block:: python

import torch
import xpulib

from functools import lru_cache
from typing import Any, Dict, Union
from lightning.pytorch.accelerators.accelerator import Accelerator

from typing_extensions import override


class XPUAccelerator(Accelerator):
"""Support for a hypothetical XPU, optimized for large-scale machine learning."""

@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
ValueError:
If the selected device is not of type hypothetical XPU.
"""
if device.type != "xpu":
raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
if device.index is None:
device = torch.device("xpu", 0)
xpulib.set_device(device.index)

@override
def teardown(self) -> None:
xpulib.empty_cache()

@staticmethod
@override
def parse_devices(devices: Any) -> Any:
# Put parsing logic here how devices can be passed into the Trainer
# via the `devices` argument
return devices

@staticmethod
@override
def get_parallel_devices(devices: Any) -> Any:
# Here, convert the device indices to actual device objects
return [torch.device("xpu", idx) for idx in devices]

@staticmethod
@override
def auto_device_count() -> int:
# Return a value for auto-device selection when `Trainer(devices="auto")`
return xpulib.available_devices()

@staticmethod
@override
def is_available() -> bool:
return xpulib.is_available()

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
# Return optional device statistics for loggers
return {}

@staticmethod
@override
def get_device_type() -> str:
return "xpu"


Finally, add the XPUAccelerator to the Trainer:

.. code-block:: python

from lightning.pytorch import Trainer

from lightning.pytorch.strategies import DDPStrategy
accelerator = XPUAccelerator()
strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))
trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)


:doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
...

@classmethod
@override
def register_accelerators(cls, accelerator_registry):
accelerator_registry.register(
"xpu",
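
Once registered, the accelerator can typically be selected by its registered name rather than by passing an instance; a minimal usage sketch, assuming the registration above maps the name ``"xpu"`` to ``XPUAccelerator``:

.. code-block:: python

    trainer = Trainer(accelerator="xpu", devices=2)
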
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/accelerator.py
@@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
def get_parallel_devices(devices: Any) -> Any:
"""Gets parallel devices for the Accelerator."""

@staticmethod
@abstractmethod
def get_device_type() -> Any:
"""Get the device_type for the current Accelerator."""

@staticmethod
@abstractmethod
def auto_device_count() -> int:
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/cpu.py
@@ -50,6 +50,11 @@ def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices

@staticmethod
@override
def get_device_type() -> str:
return "cpu"

@staticmethod
@override
def auto_device_count() -> int:
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/cuda.py
@@ -55,6 +55,11 @@ def get_parallel_devices(devices: list[int]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@staticmethod
@override
def get_device_type() -> str:
return "cuda"

@staticmethod
@override
def auto_device_count() -> int:
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/mps.py
@@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]:
assert parsed_devices is not None
return [torch.device("mps", i) for i in range(len(parsed_devices))]

@staticmethod
@override
def get_device_type() -> str:
return "mps"

@staticmethod
@override
def auto_device_count() -> int:
5 changes: 5 additions & 0 deletions src/lightning/fabric/accelerators/xla.py
@@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
# accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy

@staticmethod
@override
def get_device_type() -> str:
return "xla"

@staticmethod
@override
# XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
13 changes: 12 additions & 1 deletion src/lightning/fabric/connector.py
@@ -142,6 +142,8 @@ def __init__(
self._accelerator_flag = self._choose_auto_accelerator()
elif self._accelerator_flag == "gpu":
self._accelerator_flag = self._choose_gpu_accelerator_backend()
elif isinstance(self._accelerator_flag, Accelerator):
pass # for 3rd party accelerator, just do nothing

self._set_parallel_devices_and_init_accelerator()

@@ -461,7 +463,12 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(
precision=self._precision_input, # type: ignore[arg-type]
device_type=self._accelerator_flag.get_device_type()
if isinstance(self._accelerator_flag, Accelerator)
else None,
)
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
raise ValueError(
@@ -492,7 +499,11 @@ def _check_and_init_precision(self) -> Precision:
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)

device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
if isinstance(self._accelerator_flag, Accelerator):
device = self._accelerator_flag.get_device_type()

return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]

raise RuntimeError("No precision set")
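
The mixed-precision device selection above can be read as the following standalone sketch (a hypothetical helper, not part of the connector; it assumes ``get_device_type()`` is available on the accelerator, as added by this PR):

.. code-block:: python

    from typing import Union

    from lightning.fabric.accelerators import Accelerator


    def resolve_amp_device(accelerator_flag: Union[str, Accelerator]) -> str:
        # A custom accelerator reports its own device type; the literal flags
        # "cpu" and "mps" pass through; everything else defaults to "cuda".
        if isinstance(accelerator_flag, Accelerator):
            return accelerator_flag.get_device_type()
        return accelerator_flag if accelerator_flag in ("cpu", "mps") else "cuda"
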
9 changes: 8 additions & 1 deletion src/lightning/fabric/plugins/precision/amp.py
@@ -51,7 +51,14 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = (
torch.amp.GradScaler(device=device)
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
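
For reference, the ``GradScaler`` selection above boils down to the following sketch (a hypothetical helper; it assumes that on older torch releases the backend module, e.g. ``torch.cuda``, exposes ``amp.GradScaler``, and that ``"cpu"`` falls back to the CUDA scaler because no CPU gradient scaler exists there):

.. code-block:: python

    import torch


    def pick_grad_scaler(device: str = "cuda:0"):
        backend = device.split(":")[0]
        if hasattr(torch.amp, "GradScaler"):
            # torch >= 2.4 ships a device-agnostic GradScaler
            return torch.amp.GradScaler(device=backend)
        # older releases: use the backend-specific scaler, mapping "cpu" -> "cuda"
        return getattr(torch, "cuda" if backend == "cpu" else backend).amp.GradScaler()
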
12 changes: 10 additions & 2 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -49,13 +49,19 @@ class FSDPPrecision(Precision):

"""

def __init__(
self,
precision: _PRECISION_INPUT,
scaler: Optional["ShardedGradScaler"] = None,
device_type: Optional[str] = None,
) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in FSDP."
f" `precision` must be one of: {supported_precision}."
)
self.device_type = device_type if device_type is not None else "cuda"

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

@@ -117,7 +123,9 @@ def module_init_context(self) -> AbstractContextManager:
@override
def forward_context(self) -> AbstractContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return torch.autocast(
self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
)
return self.tensor_init_context()

@override
14 changes: 13 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
@@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
device_ids = self._determine_ddp_device_ids()
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = self._create_stream_context(device_ids=device_ids)
with ctx:
return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

@@ -228,6 +228,18 @@ def _set_world_ranks(self) -> None:
def _determine_ddp_device_ids(self) -> Optional[list[int]]:
return None if self.root_device.type == "cpu" else [self.root_device.index]

def _create_stream_context(self, device_ids=None):
"""Create a stream context for the current device, if supported."""

torch_lib = getattr(torch, self.root_device.type)
# Check if the device type supports streams and has the necessary attributes.
if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
stream = torch_lib.Stream()
ctx = torch_lib.stream(stream)
else:
ctx = nullcontext()
return ctx


class _DDPBackwardSyncControl(_BackwardSyncControl):
@override
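
The new ``_create_stream_context`` generalizes the previous hard-coded ``torch.cuda.stream(torch.cuda.Stream())`` call by duck-typing the backend module; a minimal standalone sketch of the same pattern (a hypothetical function, shown only to illustrate the check):

.. code-block:: python

    from contextlib import nullcontext
    from typing import Optional

    import torch


    def stream_context(root_device: torch.device, device_ids: Optional[list] = None):
        torch_lib = getattr(torch, root_device.type, None)
        # Only backends exposing both ``Stream`` and ``stream`` get a side stream;
        # everything else (including the CPU case, where device_ids is None) is a no-op.
        if device_ids is not None and torch_lib is not None:
            if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream"):
                return torch_lib.stream(torch_lib.Stream())
        return nullcontext()
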
20 changes: 16 additions & 4 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -299,6 +299,12 @@ def __init__(

self._deepspeed_engine: Optional[DeepSpeedEngine] = None

if isinstance(self.accelerator, Accelerator):
self.device_type = self.accelerator.get_device_type()
else:
self.device_type = "cuda"
self.torch_lib = getattr(torch, self.device_type)

@property
def zero_stage_3(self) -> bool:
assert isinstance(self.config, dict)
@@ -510,7 +516,9 @@ def load_checkpoint(

optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())

if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache):
self.torch_lib.empty_cache()

_, client_state = engine.load_checkpoint(
path,
tag="checkpoint",
@@ -620,10 +628,14 @@ def _initialize_engine(

@override
def setup_environment(self) -> None:
from deepspeed.runtime.utils import get_accelerator

if (
not isinstance(self.accelerator, CUDAAccelerator)
) and self.accelerator.get_device_type() != get_accelerator().device_name(): # type: ignore[union-attr]
raise RuntimeError(
f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
" is used."
f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, "
f"but `{self.accelerator.__class__.__name__}` is used."
)
super().setup_environment()

5 changes: 4 additions & 1 deletion src/lightning/fabric/strategies/strategy.py
@@ -326,7 +326,10 @@ def load_checkpoint(
given, the full checkpoint will be returned.

"""
if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu":
getattr(torch, self.root_device.type.split(":")[0]).empty_cache()
else:
torch.cuda.empty_cache()
checkpoint = self.checkpoint_io.load_checkpoint(path)
if not state:
return checkpoint
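
The same device-type dispatch shows up in the cache-clearing path; a small standalone sketch (a hypothetical helper; it assumes the backend module, e.g. ``torch.cuda`` or ``torch.mps``, exposes ``empty_cache``):

.. code-block:: python

    import torch


    def empty_device_cache(device_type: str = "cuda") -> None:
        if device_type == "cpu":
            return  # nothing to release for CPU allocations
        torch_lib = getattr(torch, device_type, None)
        if torch_lib is not None and hasattr(torch_lib, "empty_cache"):
            torch_lib.empty_cache()
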
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/accelerator.py
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:

"""
raise NotImplementedError

@staticmethod
def get_device_type() -> str:
"""Get the device for the current process."""
raise NotImplementedError
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cpu.py
@@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
description=cls.__name__,
)

@staticmethod
@override
def get_device_type() -> str:
return "cpu"


# CPU device metrics
_CPU_VM_PERCENT = "cpu_vm_percent"
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
description=cls.__name__,
)

@staticmethod
@override
def get_device_type() -> str:
return "cuda"


def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
5 changes: 5 additions & 0 deletions src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
description=cls.__name__,
)

@staticmethod
@override
def get_device_type() -> str:
return "mps"


# device metrics
_VM_PERCENT = "M1_vm_percent"
9 changes: 8 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
@@ -51,7 +51,14 @@ def __init__(

self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = (
torch.amp.GradScaler(device=device)
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device