diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst
index 93dc467b02921..4ea3b639600a9 100644
--- a/docs/source-pytorch/extensions/accelerator.rst
+++ b/docs/source-pytorch/extensions/accelerator.rst
@@ -36,29 +36,57 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
 
 .. code-block:: python
 
+    import torch
     import xpulib
+    from functools import lru_cache
+    from typing import Any, Dict, Union
+    from lightning.pytorch.accelerators.accelerator import Accelerator
+
+    from typing_extensions import override
+
 
     class XPUAccelerator(Accelerator):
         """Support for a hypothetical XPU, optimized for large-scale machine learning."""
 
+        @override
+        def setup_device(self, device: torch.device) -> None:
+            """
+            Raises:
+                ValueError:
+                    If the selected device is not of type hypothetical XPU.
+            """
+            if device.type != "xpu":
+                raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.")
+            if device.index is None:
+                device = torch.device("xpu", 0)
+            xpulib.set_device(device.index)
+
+        @override
+        def teardown(self) -> None:
+            xpulib.empty_cache()
+
         @staticmethod
+        @override
         def parse_devices(devices: Any) -> Any:
             # Put parsing logic here how devices can be passed into the Trainer
             # via the `devices` argument
             return devices
 
         @staticmethod
+        @override
         def get_parallel_devices(devices: Any) -> Any:
             # Here, convert the device indices to actual device objects
             return [torch.device("xpu", idx) for idx in devices]
 
         @staticmethod
+        @override
         def auto_device_count() -> int:
             # Return a value for auto-device selection when `Trainer(devices="auto")`
             return xpulib.available_devices()
 
         @staticmethod
+        @override
         def is_available() -> bool:
             return xpulib.is_available()
@@ -66,15 +94,21 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc
             # Return optional device statistics for loggers
             return {}
 
+        @staticmethod
+        @override
+        def get_device_type() -> str:
+            return "xpu"
+
 
 Finally, add the XPUAccelerator to the Trainer:
 
 .. code-block:: python
 
     from lightning.pytorch import Trainer
-
+    from lightning.pytorch.strategies import DDPStrategy
 
     accelerator = XPUAccelerator()
-    trainer = Trainer(accelerator=accelerator, devices=2)
+    strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(list(range(2))))
+    trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2)
 
 :doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator.
 
@@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes,
         ...
 
         @classmethod
+        @override
         def register_accelerators(cls, accelerator_registry):
             accelerator_registry.register(
                 "xpu",
diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py
index 3a8aa85ad041d..1017fd5c368ba 100644
--- a/src/lightning/fabric/accelerators/accelerator.py
+++ b/src/lightning/fabric/accelerators/accelerator.py
@@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any:
     def get_parallel_devices(devices: Any) -> Any:
         """Gets parallel devices for the Accelerator."""
 
+    @staticmethod
+    @abstractmethod
+    def get_device_type() -> Any:
+        """Get the device_type for the current Accelerator."""
+
     @staticmethod
     @abstractmethod
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py
index 2997d1ada3352..ba9ff07c85d09 100644
--- a/src/lightning/fabric/accelerators/cpu.py
+++ b/src/lightning/fabric/accelerators/cpu.py
@@ -50,6 +50,11 @@ def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]:
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
 
+    @staticmethod
+    @override
+    def get_device_type() -> str:
+        return "cpu"
+
     @staticmethod
     @override
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index 5b8a4c2f80bed..d09d0a3fc0097 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -55,6 +55,11 @@ def get_parallel_devices(devices: list[int]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
         return [torch.device("cuda", i) for i in devices]
 
+    @staticmethod
+    @override
+    def get_device_type() -> str:
+        return "cuda"
+
     @staticmethod
     @override
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py
index b535ba57ed4cb..f8a4b68543dee 100644
--- a/src/lightning/fabric/accelerators/mps.py
+++ b/src/lightning/fabric/accelerators/mps.py
@@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.devi
         assert parsed_devices is not None
         return [torch.device("mps", i) for i in range(len(parsed_devices))]
 
+    @staticmethod
+    @override
+    def get_device_type() -> str:
+        return "mps"
+
     @staticmethod
     @override
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index d438197329939..6a74e207edaa3 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
         # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
 
+    @staticmethod
+    @override
+    def get_device_type() -> str:
+        return "xla"
+
     @staticmethod
     @override
     # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index b3289debbd522..2974c48c6fc19 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -142,6 +142,8 @@ def __init__(
             self._accelerator_flag = self._choose_auto_accelerator()
         elif self._accelerator_flag == "gpu":
             self._accelerator_flag = self._choose_gpu_accelerator_backend()
+        elif isinstance(self._accelerator_flag, Accelerator):
+            pass  # for 3rd party accelerator, just do nothing
 
         self._set_parallel_devices_and_init_accelerator()
 
@@ -461,7 +463,12 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(precision=self._precision_input)  # type: ignore[arg-type]
+            return FSDPPrecision(
+                precision=self._precision_input,  # type: ignore[arg-type]
+                device_type=self._accelerator_flag.get_device_type()
+                if isinstance(self._accelerator_flag, Accelerator)
+                else None,
+            )
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
@@ -492,7 +499,11 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_input == "16-mixed"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
+            if isinstance(self._accelerator_flag, Accelerator):
+                device = self._accelerator_flag.get_device_type()
+
+            return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index d5fc1f0c1cc2a..d89ff1597e4bc 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -51,7 +51,14 @@ def __init__(
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(
+                    torch,
+                    "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
+                ).amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 270a67e3a2338..3f1cbe8fa3e8b 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -49,13 +49,19 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self,
+        precision: _PRECISION_INPUT,
+        scaler: Optional["ShardedGradScaler"] = None,
+        device_type: Optional[str] = None,
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device_type = device_type if device_type is not None else "cuda"
 
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
@@ -117,7 +123,9 @@ def module_init_context(self) -> AbstractContextManager:
     @override
     def forward_context(self) -> AbstractContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return self.tensor_init_context()
 
     @override
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index ce47e4e403c34..9faa07b8b2f56 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -124,7 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
 
@@ -228,6 +228,18 @@ def _set_world_ranks(self) -> None:
     def _determine_ddp_device_ids(self) -> Optional[list[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 41820c1cc433f..355e55b576894 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -299,6 +299,12 @@ def __init__(
 
         self._deepspeed_engine: Optional[DeepSpeedEngine] = None
 
+        if isinstance(self.accelerator, Accelerator):
+            self.device_type = self.accelerator.get_device_type()
+        else:
+            self.device_type = "cuda"
+        self.torch_lib = getattr(torch, self.device_type)
+
     @property
     def zero_stage_3(self) -> bool:
         assert isinstance(self.config, dict)
@@ -510,7 +516,9 @@ def load_checkpoint(
 
         optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
 
-        torch.cuda.empty_cache()
+        if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache):
+            self.torch_lib.empty_cache()
+
         _, client_state = engine.load_checkpoint(
             path,
             tag="checkpoint",
@@ -620,10 +628,14 @@ def _initialize_engine(
 
     @override
     def setup_environment(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator):
+        from deepspeed.runtime.utils import get_accelerator
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device_type() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
-                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
-                " is used."
+                f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, "
+                f"but `{self.accelerator.__class__.__name__}` is used."
             )
         super().setup_environment()
 
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index e0100bb148dd3..17e589ccb9d71 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -326,7 +326,10 @@ def load_checkpoint(
             given, the full checkpoint will be returned.
""" - torch.cuda.empty_cache() + if isinstance(self.accelerator, Accelerator) and self.accelerator.get_device_type() != "cpu": + getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + else: + torch.cuda.empty_cache() checkpoint = self.checkpoint_io.load_checkpoint(path) if not state: return checkpoint diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 9238071178a80..48e7bcf160834 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """ raise NotImplementedError + + @staticmethod + def get_device_type() -> str: + """Get the device for the current process.""" + raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 525071cbb377f..bfe4551311aa4 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device_type() -> str: + return "cpu" + # CPU device metrics _CPU_VM_PERCENT = "cpu_vm_percent" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index a00b12a85a8dd..2ec1d34fa9e04 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device_type() -> str: + return "cuda" + def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. 
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py
index f7674989cc721..a82eb1df6e439 100644
--- a/src/lightning/pytorch/accelerators/mps.py
+++ b/src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @staticmethod
+    @override
+    def get_device_type() -> str:
+        return "mps"
+
 
 # device metrics
 _VM_PERCENT = "M1_vm_percent"
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..b6a6fd12771ca 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -51,7 +51,14 @@ def __init__(
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
+            scaler = (
+                torch.amp.GradScaler(device=device)
+                if _TORCH_GREATER_EQUAL_2_4
+                else getattr(
+                    torch,
+                    "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
+                ).amp.GradScaler()
+            )
         if scaler is not None and self.precision == "bf16-mixed":
             raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index f3bab3e915e91..1f5b92c8285ce 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -49,13 +49,19 @@ class FSDPPrecision(Precision):
 
     """
 
-    def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self,
+        precision: _PRECISION_INPUT,
+        scaler: Optional["ShardedGradScaler"] = None,
+        device_type: Optional[str] = None,
+    ) -> None:
         supported_precision = get_args(_PRECISION_INPUT)
         if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in FSDP."
                 f" `precision` must be one of: {supported_precision}."
             )
+        self.device_type = device_type if device_type is not None else "cuda"
 
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
@@ -127,7 +133,9 @@ def module_init_context(self) -> AbstractContextManager:
     @override
     def forward_context(self) -> AbstractContextManager:
         if "mixed" in self.precision:
-            return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
+            return torch.autocast(
+                self.device_type, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)
+            )
         return _DtypeContextManager(self._desired_input_dtype)
 
     @override
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index fd3f66ef42471..f64d7b761b36d 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -190,7 +190,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         device_ids = self.determine_ddp_device_ids()
         log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
@@ -228,7 +228,7 @@ def _register_ddp_hooks(self) -> None:
     def _enable_model_averaging(self) -> None:
         log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
-        if self._model_averaging_period is None:
+        if self._model_averaging_period is None:  # type: ignore[no-untyped-def]
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
             )
@@ -418,6 +418,18 @@ def teardown(self) -> None:
 
         super().teardown()
 
+    def _create_stream_context(self, device_ids=None):  # type: ignore[no-untyped-def]
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPForwardRedirection(_ForwardRedirection):
     @override
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index dabfde70242b9..45ca752b045e3 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -319,12 +319,22 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
+        try:
+            self.device_type = self.accelerator.get_device_type()  # type: ignore[union-attr]
+        except Exception:
+            self.device_type = "cuda"
+        self.torch_lib = getattr(torch, self.device_type)
+
     @override
     def setup_environment(self) -> None:
-        if not isinstance(self.accelerator, CUDAAccelerator):
+        from deepspeed.runtime.utils import get_accelerator
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device_type() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
-                f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
-                " is used."
+ f"The DeepSpeed strategy is only supported on {get_accelerator().device_name().upper()} GPUs, " + f"but `{self.accelerator.__class__.__name__}` is used." ) super().setup_environment() @@ -667,6 +677,9 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING + if hasattr(torch, self.device_type) and callable(self.torch_lib.empty_cache): + self.torch_lib.empty_cache() + _, client_state = self.deepspeed_engine.load_checkpoint( checkpoint_path, load_optimizer_states=is_fitting, diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 16b16a4927513..f205ff93cb3fb 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -364,7 +364,10 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: return self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: - torch.cuda.empty_cache() + if isinstance(self.accelerator, pl.accelerators.Accelerator) and self.accelerator.get_device_type() != "cpu": + getattr(torch, self.root_device.type.split(":")[0]).empty_cache() + else: + torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 1423c1aeeafe4..c4da6010883d8 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -142,6 +142,8 @@ def __init__( self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() + elif isinstance(self._accelerator_flag, Accelerator): + pass # for 3rd party accelerator, just do nothing self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes) self._set_parallel_devices_and_init_accelerator() @@ -302,13 +304,15 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cpu" - if self._strategy_flag.parallel_devices[0].type == "cuda": + elif self._strategy_flag.parallel_devices[0].type == "cuda": if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): raise MisconfigurationException( f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + else: + pass # 3rd party accelerator self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: @@ -458,12 +462,19 @@ def _check_strategy_and_fallback(self) -> None: strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag if ( - strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and self._accelerator_flag not in ("cuda", "gpu"): + (strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy) + and self._accelerator_flag not in ("cuda", "gpu") + and isinstance(self._accelerator_flag, str) + ): raise ValueError( f"The 
                 f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
                 f" {self._accelerator_flag}"
             )
+        if isinstance(self._accelerator_flag, Accelerator):
+            rank_zero_warn(
+                f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
+                f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
+            )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
                 f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
@@ -497,7 +508,12 @@ def _check_and_init_precision(self) -> Precision:
         if isinstance(self.strategy, DeepSpeedStrategy):
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
-            return FSDPPrecision(self._precision_flag)  # type: ignore[arg-type]
+            return FSDPPrecision(
+                precision=self._precision_flag,  # type: ignore[arg-type]
+                device_type=self._accelerator_flag.get_device_type()
+                if isinstance(self._accelerator_flag, Accelerator)
+                else None,
+            )
         if self._precision_flag in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_flag)  # type: ignore
         if self._precision_flag == "32-true":
@@ -521,6 +537,8 @@ def _check_and_init_precision(self) -> Precision:
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            if isinstance(self._accelerator_flag, Accelerator):
+                device = self._accelerator_flag.get_device_type()
             return MixedPrecision(self._precision_flag, device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py
index 8036a6f45b8a0..e7e6e6b790e57 100644
--- a/tests/tests_fabric/accelerators/test_registry.py
+++ b/tests/tests_fabric/accelerators/test_registry.py
@@ -45,6 +45,10 @@ def parse_devices(devices):
         def get_parallel_devices(devices):
             return ["foo"] * devices
 
+        @staticmethod
+        def get_device_type():
+            return "foo"
+
         @staticmethod
         def auto_device_count():
             return 3
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 1074789e71055..d73758e4ad939 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -179,6 +179,10 @@ def parse_devices(devices):
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
 
+    @staticmethod
+    def get_device_type() -> str:
+        return "cpu"
+
     @staticmethod
     def auto_device_count() -> int:
         return 1
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 3877d6c051017..57f022b978765 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -192,6 +192,10 @@ def parse_devices(devices):
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
 
+    @staticmethod
+    def get_device_type() -> str:
+        return "cpu"
+
     @staticmethod
     def auto_device_count() -> int:
         return 1
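
A minimal sketch of the dispatch pattern the changes above rely on: the string returned by ``Accelerator.get_device_type()`` is looked up as a submodule of ``torch`` and each capability is feature-checked before use. The helper names (``autocast_ctx``, ``empty_cache``, ``stream_ctx``) are hypothetical and only the ``torch`` calls are real APIs; this is an illustration, not code from the diff.

.. code-block:: python

    # Hypothetical helpers mirroring the device-agnostic checks in
    # `_create_stream_context`, `forward_context`, and `load_checkpoint`.
    from contextlib import nullcontext

    import torch


    def autocast_ctx(device_type: str, precision: str):
        """Build an autocast context from the accelerator-reported device type."""
        dtype = torch.bfloat16 if precision == "bf16-mixed" else torch.float16
        return torch.autocast(device_type, dtype=dtype)


    def empty_cache(device_type: str) -> None:
        """Free cached memory only if the backend module exposes `empty_cache`."""
        torch_lib = getattr(torch, device_type, None)
        if torch_lib is not None and callable(getattr(torch_lib, "empty_cache", None)):
            torch_lib.empty_cache()


    def stream_ctx(device_type: str, device_ids):
        """Use a backend stream when available, otherwise a no-op context."""
        torch_lib = getattr(torch, device_type, None)
        if device_ids is not None and hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream"):
            return torch_lib.stream(torch_lib.Stream())
        return nullcontext()


    # Example: falls back gracefully on a CPU-only machine.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with autocast_ctx(device_type, "bf16-mixed"):
        out = torch.ones(2, 2, device=device_type) @ torch.ones(2, 2, device=device_type)
    empty_cache(device_type)

Resolving the backend with ``getattr(torch, device_type)`` and probing for ``Stream``, ``stream``, and ``empty_cache`` lets CUDA, MPS, XPU-like third-party backends, and CPU all take the same code path instead of hard-coding ``torch.cuda``.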