
Commit 15595bf

enhance 3rd-party device support in mixed precision
1 parent 06a8d5b commit 15595bf

File tree

25 files changed: +187 -24 lines changed


docs/source-pytorch/extensions/accelerator.rst

Lines changed: 37 additions & 2 deletions
@@ -36,45 +36,79 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
 
 .. code-block:: python
 
+    import torch
     import xpulib
 
+    from functools import lru_cache
+    from typing import Any, Dict, Union
+    from lightning.pytorch.accelerators.accelerator import Accelerator
+
+    from typing_extensions import override
+
 
     class XPUAccelerator(Accelerator):
         """Support for a hypothetical XPU, optimized for large-scale machine learning."""
 
+        @override
+        def setup_device(self, device: torch.device) -> None:
+            """
+            Raises:
+                ValueError:
+                    If the selected device is not of type hypothetical XPU.
+            """
+            if device.type != "xpu":
+                raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
+            if device.index is None:
+                device = torch.device("xpu", 0)
+            xpulib.set_device(device.index)
+
+        @override
+        def teardown(self) -> None:
+            xpulib.empty_cache()
+
         @staticmethod
+        @override
         def parse_devices(devices: Any) -> Any:
            # Put parsing logic here how devices can be passed into the Trainer
            # via the `devices` argument
            return devices

         @staticmethod
+        @override
         def get_parallel_devices(devices: Any) -> Any:
             # Here, convert the device indices to actual device objects
             return [torch.device("xpu", idx) for idx in devices]
 
         @staticmethod
+        @override
         def auto_device_count() -> int:
             # Return a value for auto-device selection when `Trainer(devices="auto")`
             return xpulib.available_devices()
 
         @staticmethod
+        @override
         def is_available() -> bool:
             return xpulib.is_available()
 
         def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
             # Return optional device statistics for loggers
             return {}
 
+        @staticmethod
+        @override
+        def get_device() -> str:
+            return "xpu"
+
 
 Finally, add the XPUAccelerator to the Trainer:
 
 .. code-block:: python
 
     from lightning.pytorch import Trainer
-
+    from lightning.pytorch.strategies import DDPStrategy
     accelerator = XPUAccelerator()
-    trainer = Trainer(accelerator=accelerator, devices=2)
+    strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))
+    trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)
 
 
 :doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
     ...
 
     @classmethod
+    @override
    def register_accelerators(cls, accelerator_registry):
        accelerator_registry.register(
            "xpu",

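The doc change above is the user-facing half of this commit: a custom accelerator now also advertises its device type via `get_device()`. Below is a minimal sketch of how that hooks into mixed precision; it reuses the fictional `XPUAccelerator`/`xpulib` from the docs, and `precision="16-mixed"` is only an assumed way a user might request AMP on such a device. The connector changes further down consume `get_device()` instead of assuming CUDA.

    # Sketch only: XPUAccelerator/xpulib are the fictional example from the docs above.
    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DDPStrategy

    accelerator = XPUAccelerator()
    strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2))

    # With this commit, requesting mixed precision no longer hard-codes "cuda":
    # the connector asks accelerator.get_device() (here "xpu") for the autocast device.
    trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2, precision="16-mixed")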
src/lightning/fabric/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions
@@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
     def get_parallel_devices(devices: Any) -> Any:
         """Gets parallel devices for the Accelerator."""
 
+    @staticmethod
+    @abstractmethod
+    def get_device() -> Any:
+        """Get the device for the current Accelerator."""
+
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:

src/lightning/fabric/accelerators/cpu.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
         """Gets parallel devices for the Accelerator."""
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
+
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cpu"
 
     @staticmethod
     @override

src/lightning/fabric/accelerators/cuda.py

Lines changed: 5 additions & 0 deletions
@@ -55,6 +55,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         return [torch.device("cuda", i) for i in devices]
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cuda"
+
     @staticmethod
     @override
     def auto_device_count() -> int:

src/lightning/fabric/accelerators/mps.py

Lines changed: 5 additions & 0 deletions
@@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi
         assert parsed_devices is not None
         return [torch.device("mps", i) for i in range(len(parsed_devices))]
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "mps"
+
     @staticmethod
     @override
     def auto_device_count() -> int:

src/lightning/fabric/accelerators/xla.py

Lines changed: 5 additions & 0 deletions
@@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
         # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
         # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "xla"
+
     @staticmethod
     @override
     # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it

src/lightning/fabric/connector.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(
141141
self._accelerator_flag = self._choose_auto_accelerator()
142142
elif self._accelerator_flag == "gpu":
143143
self._accelerator_flag = self._choose_gpu_accelerator_backend()
144+
elif isinstance(self._accelerator_flag, Accelerator):
145+
pass # for 3rd party accelerator, just do nothing
144146

145147
self._set_parallel_devices_and_init_accelerator()
146148

@@ -461,7 +463,10 @@ def _check_and_init_precision(self) -> Precision:
461463
if isinstance(self.strategy, DeepSpeedStrategy):
462464
return DeepSpeedPrecision(self._precision_input) # type: ignore
463465
if isinstance(self.strategy, FSDPStrategy):
464-
return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type]
466+
return FSDPPrecision(
467+
precision=self._precision_input, # type: ignore[arg-type]
468+
device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
469+
)
465470
mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
466471
if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
467472
raise ValueError(
@@ -493,6 +498,8 @@ def _check_and_init_precision(self) -> Precision:
493498
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494499
)
495500
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
501+
if isinstance(self._accelerator_flag, Accelerator):
502+
device = self._accelerator_flag.get_device()
496503
return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type]
497504

498505
raise RuntimeError("No precision set")

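A condensed sketch (not the Lightning source) of the rule the connector now follows when it builds the precision plugin: a third-party `Accelerator` instance supplies its own device string through `get_device()`, everything else keeps the old cpu/cuda assumption.

    def _resolve_autocast_device(accelerator_flag) -> str:
        # accelerator_flag is either a string like "cpu"/"gpu"/"auto" or an Accelerator instance
        if hasattr(accelerator_flag, "get_device"):
            return accelerator_flag.get_device()  # e.g. "xpu" for a 3rd-party accelerator
        return "cpu" if accelerator_flag == "cpu" else "cuda"

    # _resolve_autocast_device("gpu") -> "cuda"
    # _resolve_autocast_device(XPUAccelerator()) -> "xpu"  (hypothetical accelerator)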
src/lightning/fabric/plugins/precision/amp.py

Lines changed: 9 additions & 1 deletion
@@ -50,7 +50,15 @@ def __init__(
 
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(
+                    torch,
+                    "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
+                    else device.split(":")[0]
+                ).amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device

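The torch < 2.4 branch above picks a `GradScaler` from the torch sub-module matching the device prefix. Below is a small sketch of that selection rule, returning the namespace as a string so it can be checked without the corresponding backend installed:

    def _scaler_namespace(device) -> str:
        # Mirrors the conditional in the diff: non-string devices and "cpu" fall back to torch.cuda,
        # any other prefix (e.g. "xpu") selects getattr(torch, prefix).amp.GradScaler.
        if not isinstance(device, str) or device.split(":")[0] == "cpu":
            return "cuda"
        return device.split(":")[0]

    assert _scaler_namespace("cuda:1") == "cuda"
    assert _scaler_namespace("cpu") == "cuda"   # unchanged legacy behaviour
    assert _scaler_namespace("xpu:0") == "xpu"  # 3rd-party device keeps its own namespace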
src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 7 additions & 2 deletions
@@ -48,13 +48,16 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device = device if device is not None else "cuda"
 
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
@@ -110,7 +113,9 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return self.tensor_init_context()
 
     @override

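A hedged sketch of what the new `device` argument buys: `forward_context()` opens autocast on the accelerator's device rather than the previously hard-coded "cuda". The device string must name an autocast-capable backend; "xpu" is only an example here.

    import contextlib
    import torch

    def _fsdp_forward_context(precision: str, device: str = "cuda"):
        # Same shape as FSDPPrecision.forward_context after this commit.
        if "mixed" in precision:
            dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.float16
            return torch.autocast(device, dtype=dtype)
        return contextlib.nullcontext()  # the real plugin returns its dtype context manager here

    # e.g. with _fsdp_forward_context("16-mixed", device="xpu"): ...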
src/lightning/fabric/strategies/ddp.py

Lines changed: 7 additions & 1 deletion
@@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = (
+            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
+                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
+            )
+            if device_ids is not None
+            else nullcontext()
+        )
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

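The stream trick above used to be CUDA-only. A sketch of the generalized dispatch, assuming the backend module registered under `torch.<device type>` exposes `Stream`/`stream` the way `torch.cuda` does:

    from contextlib import nullcontext
    import torch

    def _ddp_setup_ctx(root_device: torch.device, device_ids):
        backend = getattr(torch, root_device.type.split(":")[0])  # torch.cuda, torch.xpu, ...
        return backend.stream(backend.Stream()) if device_ids is not None else nullcontext()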
src/lightning/fabric/strategies/deepspeed.py

Lines changed: 8 additions & 4 deletions
@@ -506,7 +506,9 @@ def load_checkpoint(
 
         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
 
-        torch.cuda.empty_cache()
+        getattr(
+            torch, f"{self.root_device.type.split(':')[0]}"
+        ).empty_cache() if self.accelerator.get_device() != "cpu" else None  # type: ignore[union-attr]
         _, client_state = engine.load_checkpoint(
             path,
             tag="checkpoint",
@@ -616,10 +618,12 @@ def _initialize_engine(
 
     @override
     def setup_environment(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator):
+        from deepspeed.runtime.utils import get_accelerator
+        if (not isinstance(self.accelerator, CUDAAccelerator)) and \
+                self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
-                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
-                " is used."
+                f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
+                f" but `{self.accelerator.__class__.__name__}` is used."
             )
         super().setup_environment()

src/lightning/fabric/strategies/strategy.py

Lines changed: 3 additions & 1 deletion
@@ -325,7 +325,9 @@ def load_checkpoint(
             given, the full checkpoint will be returned.
 
         """
-        torch.cuda.empty_cache()
+        getattr(
+            torch, f"{self.root_device.type.split(':')[0]}"
+        ).empty_cache() if self.root_device.type != "cpu" else None
         checkpoint = self.checkpoint_io.load_checkpoint(path)
         if not state:
             return checkpoint

src/lightning/pytorch/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
 
         """
         raise NotImplementedError
+
+    @staticmethod
+    def get_device() -> str:
+        """Get the device for the current process."""
+        raise NotImplementedError

src/lightning/pytorch/accelerators/cpu.py

Lines changed: 5 additions & 0 deletions
@@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cpu"
+
 
 # CPU device metrics
 _CPU_VM_PERCENT = "cpu_vm_percent"

src/lightning/pytorch/accelerators/cuda.py

Lines changed: 5 additions & 0 deletions
@@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "cuda"
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]:  # pragma: no-cover
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

src/lightning/pytorch/accelerators/mps.py

Lines changed: 5 additions & 0 deletions
@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @staticmethod
+    @override
+    def get_device() -> str:
+        return "mps"
+
 
 # device metrics
 _VM_PERCENT = "M1_vm_percent"

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 9 additions & 1 deletion
@@ -50,7 +50,15 @@ def __init__(
 
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(
+                    torch,
+                    "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
+                    else device.split(":")[0]
+                ).amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 7 additions & 2 deletions
@@ -47,13 +47,16 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device = device if device is not None else "cuda"
 
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
@@ -119,7 +122,9 @@ def module_init_context(self) -> ContextManager:
     @override
     def forward_context(self) -> ContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return _DtypeContextManager(self._desired_input_dtype)
 
     @override

src/lightning/pytorch/strategies/ddp.py

Lines changed: 7 additions & 1 deletion
@@ -190,7 +190,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = (
+            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
+                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
+            )
+            if device_ids is not None
+            else nullcontext()
+        )
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
