From 4ac2dd4eb2c8ac531f5b448a4cc9f5b193024aba Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Sat, 19 Oct 2024 02:00:03 +1100 Subject: [PATCH] enhance 3d-party devices in mix-precision --- .../source-pytorch/extensions/accelerator.rst | 39 ++++++++++++++++++- .../fabric/accelerators/accelerator.py | 5 +++ src/lightning/fabric/accelerators/cpu.py | 5 +++ src/lightning/fabric/accelerators/cuda.py | 5 +++ src/lightning/fabric/accelerators/mps.py | 5 +++ src/lightning/fabric/accelerators/xla.py | 5 +++ src/lightning/fabric/connector.py | 9 ++++- src/lightning/fabric/plugins/precision/amp.py | 10 ++++- .../fabric/plugins/precision/fsdp.py | 9 ++++- src/lightning/fabric/strategies/ddp.py | 8 +++- src/lightning/fabric/strategies/deepspeed.py | 12 ++++-- src/lightning/fabric/strategies/strategy.py | 4 +- .../pytorch/accelerators/accelerator.py | 5 +++ src/lightning/pytorch/accelerators/cpu.py | 5 +++ src/lightning/pytorch/accelerators/cuda.py | 5 +++ src/lightning/pytorch/accelerators/mps.py | 5 +++ .../pytorch/plugins/precision/amp.py | 10 ++++- .../pytorch/plugins/precision/fsdp.py | 9 ++++- src/lightning/pytorch/strategies/ddp.py | 8 +++- src/lightning/pytorch/strategies/deepspeed.py | 8 ++-- src/lightning/pytorch/strategies/strategy.py | 4 +- .../connectors/accelerator_connector.py | 24 ++++++++++-- .../accelerators/test_registry.py | 4 ++ tests/tests_fabric/test_connector.py | 4 ++ .../connectors/test_accelerator_connector.py | 4 ++ 25 files changed, 187 insertions(+), 24 deletions(-) diff --git a/docs/source-pytorch/extensions/accelerator.rst b/docs/source-pytorch/extensions/accelerator.rst index 93dc467b02921..dcedde8c6905c 100644 --- a/docs/source-pytorch/extensions/accelerator.rst +++ b/docs/source-pytorch/extensions/accelerator.rst @@ -36,29 +36,57 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc .. code-block:: python + import torch import xpulib + from functools import lru_cache + from typing import Any, Dict, Union + from lightning.pytorch.accelerators.accelerator import Accelerator + + from typing_extensions import override + class XPUAccelerator(Accelerator): """Support for a hypothetical XPU, optimized for large-scale machine learning.""" + @override + def setup_device(self, device: torch.device) -> None: + """ + Raises: + ValueError: + If the selected device is not of type hypothetical XPU. + """ + if device.type != "xpu": + raise ValueError(f"Device should be of type 'xpu', got '{device.type}' instead.") + if device.index is None: + device = torch.device("xpu", 0) + xpulib.set_device(device.index) + + @override + def teardown(self) -> None: + xpulib.empty_cache() + @staticmethod + @override def parse_devices(devices: Any) -> Any: # Put parsing logic here how devices can be passed into the Trainer # via the `devices` argument return devices @staticmethod + @override def get_parallel_devices(devices: Any) -> Any: # Here, convert the device indices to actual device objects return [torch.device("xpu", idx) for idx in devices] @staticmethod + @override def auto_device_count() -> int: # Return a value for auto-device selection when `Trainer(devices="auto")` return xpulib.available_devices() @staticmethod + @override def is_available() -> bool: return xpulib.is_available() @@ -66,15 +94,21 @@ Let's pretend we want to integrate the fictional XPU accelerator and we have acc # Return optional device statistics for loggers return {} + @staticmethod + @override + def get_device() -> str: + return "xpu" + Finally, add the XPUAccelerator to the Trainer: .. 
code-block:: python from lightning.pytorch import Trainer - + from lightning.pytorch.strategies import DDPStrategy accelerator = XPUAccelerator() - trainer = Trainer(accelerator=accelerator, devices=2) + strategy = DDPStrategy(parallel_devices=accelerator.get_parallel_devices(2)) + trainer = Trainer(accelerator=accelerator, strategy=strategy, devices=2) :doc:`Learn more about Strategies <../extensions/strategy>` and how they interact with the Accelerator. @@ -93,6 +127,7 @@ If you wish to switch to a custom accelerator from the CLI without code changes, ... @classmethod + @override def register_accelerators(cls, accelerator_registry): accelerator_registry.register( "xpu", diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py index 3a8aa85ad041d..84ef97e514bbc 100644 --- a/src/lightning/fabric/accelerators/accelerator.py +++ b/src/lightning/fabric/accelerators/accelerator.py @@ -46,6 +46,11 @@ def parse_devices(devices: Any) -> Any: def get_parallel_devices(devices: Any) -> Any: """Gets parallel devices for the Accelerator.""" + @staticmethod + @abstractmethod + def get_device() -> Any: + """Get the device for the current Accelerator.""" + @staticmethod @abstractmethod def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 1bcec1b2ac278..8a0681d860be9 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -49,6 +49,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices + + @staticmethod + @override + def get_device() -> str: + return "cpu" @staticmethod @override diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 4afc9be723fc2..6e7a92b2a0eb0 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -55,6 +55,11 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] + @staticmethod + @override + def get_device() -> str: + return "cuda" + @staticmethod @override def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index 75497169cda0f..1840e39586250 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -60,6 +60,11 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.devi assert parsed_devices is not None return [torch.device("mps", i) for i in range(len(parsed_devices))] + @staticmethod + @override + def get_device() -> str: + return "mps" + @staticmethod @override def auto_device_count() -> int: diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index 38d7380dc7905..3826178599ea2 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -64,6 +64,11 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]: # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`. 
# it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy + @staticmethod + @override + def get_device() -> str: + return "xla" + @staticmethod @override # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9fb66255830c6..4b14734951a9d 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -141,6 +141,8 @@ def __init__( self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() + elif isinstance(self._accelerator_flag, Accelerator): + pass # for 3rd party accelerator, just do nothing self._set_parallel_devices_and_init_accelerator() @@ -461,7 +463,10 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(precision=self._precision_input) # type: ignore[arg-type] + return FSDPPrecision( + precision=self._precision_input, # type: ignore[arg-type] + device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + ) mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true") if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported: raise ValueError( @@ -493,6 +498,8 @@ def _check_and_init_precision(self) -> Precision: else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + if isinstance(self._accelerator_flag, Accelerator): + device = self._accelerator_flag.get_device() return MixedPrecision(precision=self._precision_input, device=device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index c624e821af28c..9c671cb0a4310 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -50,7 +50,15 @@ def __init__( self.precision = precision if scaler is None and self.precision == "16-mixed": - scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler() + scaler = ( + torch.amp.GradScaler(device=device) + if _TORCH_GREATER_EQUAL_2_4 + else getattr( + torch, + "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" + else device.split(":")[0] + ).amp.GradScaler() + ) if scaler is not None and self.precision == "bf16-mixed": raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 179fc21cdd90d..43570373a39b1 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -48,13 +48,16 @@ class FSDPPrecision(Precision): """ - def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in FSDP." 
f" `precision` must be one of: {supported_precision}." ) + self.device = device if device is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -110,7 +113,9 @@ def module_init_context(self) -> ContextManager: @override def forward_context(self) -> ContextManager: if "mixed" in self.precision: - return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) + return torch.autocast( + self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + ) return self.tensor_init_context() @override diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c38780655ce6e..e7456fd6a8ca5 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -124,7 +124,13 @@ def setup_module(self, module: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self._determine_ddp_device_ids() # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + ctx = ( + getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( + getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() + ) + if device_ids is not None + else nullcontext() + ) with ctx: return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index e71b8e2db3d58..8956a73375de1 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -506,7 +506,9 @@ def load_checkpoint( optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values()) - torch.cuda.empty_cache() + getattr( + torch, f"{self.root_device.type.split(':')[0]}" + ).empty_cache() if self.accelerator.get_device() != "cpu" else None _, client_state = engine.load_checkpoint( path, tag="checkpoint", @@ -616,10 +618,12 @@ def _initialize_engine( @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + from deepspeed.runtime.utils import get_accelerator + if (not isinstance(self.accelerator, CUDAAccelerator)) and \ + self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr] raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs," + f"but `{self.accelerator.__class__.__name__}` is used." ) super().setup_environment() diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 6bfed6a270b68..244588d5e4124 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -325,7 +325,9 @@ def load_checkpoint( given, the full checkpoint will be returned. 
""" - torch.cuda.empty_cache() + getattr( + torch, f"{self.root_device.type.split(':')[0]}" + ).empty_cache() if self.root_device.type != "cpu" else None checkpoint = self.checkpoint_io.load_checkpoint(path) if not state: return checkpoint diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 0490c2d86431c..96a3941af97f3 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: """ raise NotImplementedError + + @staticmethod + def get_device() -> str: + """Get the device for the current process.""" + raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index 735312b363d11..ab6304053f314 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -80,6 +80,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cpu" + # CPU device metrics _CPU_VM_PERCENT = "cpu_vm_percent" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 6df3bc6b468ee..cfb85cb2c2990 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -113,6 +113,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "cuda" + def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. 
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 6efe6292de624..d8bda9dae8087 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -87,6 +87,11 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No description=cls.__name__, ) + @staticmethod + @override + def get_device() -> str: + return "mps" + # device metrics _VM_PERCENT = "M1_vm_percent" diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index e63ccd6912b63..421b6776b8f5c 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -50,7 +50,15 @@ def __init__( self.precision = precision if scaler is None and self.precision == "16-mixed": - scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler() + scaler = ( + torch.amp.GradScaler(device=device) + if _TORCH_GREATER_EQUAL_2_4 + else getattr( + torch, + "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" + else device.split(":")[0] + ).amp.GradScaler() + ) if scaler is not None and self.precision == "bf16-mixed": raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index e6c684967ed40..280defe04ff44 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -47,13 +47,16 @@ class FSDPPrecision(Precision): """ - def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None, device: Optional[str] = None + ) -> None: supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in FSDP." f" `precision` must be one of: {supported_precision}." 
) + self.device = device if device is not None else "cuda" from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler @@ -119,7 +122,9 @@ def module_init_context(self) -> ContextManager: @override def forward_context(self) -> ContextManager: if "mixed" in self.precision: - return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) + return torch.autocast( + self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16) + ) return _DtypeContextManager(self._desired_input_dtype) @override diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..c16310cd65245 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -190,7 +190,13 @@ def _setup_model(self, model: Module) -> DistributedDataParallel: device_ids = self.determine_ddp_device_ids() log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") # https://pytorch.org/docs/stable/notes/cuda.html#id5 - ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() + ctx = ( + getattr(torch, f"{self.root_device.type.split(':')[0]}").stream( + getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream() + ) + if device_ids is not None + else nullcontext() + ) with ctx: return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 1eaa5bab75fbe..6183948d689c7 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -316,10 +316,12 @@ def __init__( @override def setup_environment(self) -> None: - if not isinstance(self.accelerator, CUDAAccelerator): + from deepspeed.runtime.utils import get_accelerator + if (not isinstance(self.accelerator, CUDAAccelerator)) and \ + self.accelerator.get_device() != get_accelerator().device_name(): # type: ignore[union-attr] raise RuntimeError( - f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" - " is used." + f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs," + f"but `{self.accelerator.__class__.__name__}` is used." 
) super().setup_environment() diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 314007f497f59..6327170f31c46 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -363,7 +363,9 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: return self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: - torch.cuda.empty_cache() + getattr( + torch, f"{self.root_device.type.split(':')[0]}" + ).empty_cache() if self.root_device.type != "cpu" else None return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 06f3ee366bcaa..1e1bb9982301d 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -141,6 +141,8 @@ def __init__( self._accelerator_flag = self._choose_auto_accelerator() elif self._accelerator_flag == "gpu": self._accelerator_flag = self._choose_gpu_accelerator_backend() + elif isinstance(self._accelerator_flag, Accelerator): + pass # for 3rd party accelerator, just do nothing self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes) self._set_parallel_devices_and_init_accelerator() @@ -301,13 +303,15 @@ def _check_config_and_set_final_flags( f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cpu" - if self._strategy_flag.parallel_devices[0].type == "cuda": + elif self._strategy_flag.parallel_devices[0].type == "cuda": if self._accelerator_flag and self._accelerator_flag not in ("auto", "cuda", "gpu"): raise MisconfigurationException( f"GPU parallel_devices set through {self._strategy_flag.__class__.__name__} class," f" but accelerator set to {self._accelerator_flag}, please choose one device type" ) self._accelerator_flag = "cuda" + else: + pass # 3rd party accelerator self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: @@ -457,12 +461,19 @@ def _check_strategy_and_fallback(self) -> None: strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag if ( - strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and self._accelerator_flag not in ("cuda", "gpu"): + (strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy) + and self._accelerator_flag not in ("cuda", "gpu") + and isinstance(self._accelerator_flag, str) + ): raise ValueError( f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:" f" {self._accelerator_flag}" ) + if isinstance(self._accelerator_flag, Accelerator): + Warning( + f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`." + f" Please ensure it is compatible with the selected strategy `{strategy_flag}`." 
+ ) if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods(): raise ValueError( f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this" @@ -496,7 +507,10 @@ def _check_and_init_precision(self) -> Precision: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type] if isinstance(self.strategy, FSDPStrategy): - return FSDPPrecision(self._precision_flag) # type: ignore[arg-type] + return FSDPPrecision( + precision=self._precision_flag, # type: ignore[arg-type] + device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None, + ) if self._precision_flag in ("16-true", "bf16-true"): return HalfPrecision(self._precision_flag) # type: ignore if self._precision_flag == "32-true": @@ -520,6 +534,8 @@ def _check_and_init_precision(self) -> Precision: f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" + if isinstance(self._accelerator_flag, Accelerator): + device = self._accelerator_flag.get_device() return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type] raise RuntimeError("No precision set") diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index e8f39b6e83406..2540bde18ce7d 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -44,6 +44,10 @@ def parse_devices(devices): def get_parallel_devices(devices): return ["foo"] * devices + @staticmethod + def get_device(): + return "foo" + @staticmethod def auto_device_count(): return 3 diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 08d6dbb45ed91..fee4000cd904f 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -178,6 +178,10 @@ def parse_devices(devices): @staticmethod def get_parallel_devices(devices): return [torch.device("cpu")] * devices + + @staticmethod + def get_device() -> str: + return "cpu" @staticmethod def auto_device_count() -> int: diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 65c5777e28fed..5a2190c692723 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -191,6 +191,10 @@ def parse_devices(devices): @staticmethod def get_parallel_devices(devices): return [torch.device("cpu")] * devices + + @staticmethod + def get_device() -> str: + return "cpu" @staticmethod def auto_device_count() -> int:
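A few usage sketches for the behaviour introduced above (all third-party device names such as ``"xpu"`` are illustrative, not shipped accelerators).

When an ``Accelerator`` *instance* (rather than a string flag) is passed to ``Fabric``/``Trainer``, both connectors now ask it for its device type via the new ``get_device()`` hook and hand that string to ``MixedPrecision``. A minimal sketch mirroring the device selection in ``_check_and_init_precision`` (the ``XPUAccelerator`` mentioned in the comments is the fictional one from the docs change):

.. code-block:: python

    from typing import Union

    from lightning.pytorch.accelerators.accelerator import Accelerator


    def _amp_device(accelerator_flag: Union[str, Accelerator]) -> str:
        # String flags keep the previous behaviour: "cpu" stays on CPU and every
        # other string ("gpu", "cuda", ...) maps to "cuda".
        if isinstance(accelerator_flag, Accelerator):
            return accelerator_flag.get_device()
        return "cpu" if accelerator_flag == "cpu" else "cuda"


    # _amp_device("cuda")            -> "cuda"
    # _amp_device(XPUAccelerator())  -> "xpu", so MixedPrecision autocasts on "xpu"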
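The scaler selection added to both ``amp.py`` files reads: on torch >= 2.4 use the device-aware ``torch.amp.GradScaler``; on older versions fall back to the backend-specific scaler class (``torch.cuda.amp.GradScaler`` and friends), with CPU defaulting to the CUDA scaler. A standalone sketch of that decision, assuming the backend module exposes an ``amp.GradScaler`` the way ``torch.cuda`` does (the inline version check stands in for Lightning's ``_TORCH_GREATER_EQUAL_2_4`` flag):

.. code-block:: python

    import torch

    _TORCH_GE_2_4 = tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (2, 4)


    def _make_grad_scaler(device: str = "cuda"):
        """Pick the gradient scaler used for ``precision="16-mixed"``."""
        if _TORCH_GE_2_4:
            # Since torch 2.4 there is a single device-aware scaler.
            return torch.amp.GradScaler(device=device)
        # Older torch only ships per-backend scalers; CPU falls back to CUDA's,
        # matching the conditional in the hunks above.
        backend = device.split(":")[0]
        return getattr(torch, "cuda" if backend == "cpu" else backend).amp.GradScaler()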
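``FSDPPrecision`` gains an optional ``device`` argument so that ``forward_context()`` autocasts on the accelerator's device instead of assuming CUDA; the connectors pass ``accelerator.get_device()`` whenever a custom accelerator instance is used. A usage sketch with the patch applied (the ``"xpu"`` device type is hypothetical):

.. code-block:: python

    from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

    # Without a device the behaviour is unchanged: autocast targets "cuda".
    default_precision = FSDPPrecision(precision="bf16-mixed")
    assert default_precision.device == "cuda"

    # A third-party accelerator can redirect autocast to its own device type;
    # forward_context() then returns torch.autocast("xpu", dtype=torch.bfloat16).
    xpu_precision = FSDPPrecision(precision="bf16-mixed", device="xpu")
    assert xpu_precision.device == "xpu"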
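The ``load_checkpoint`` hunks (fabric and PyTorch strategies, plus the DeepSpeed variant keyed off ``accelerator.get_device()``) replace the hard-coded ``torch.cuda.empty_cache()`` with a lookup of the backend module derived from the device type. A standalone sketch of that pattern, with an extra ``hasattr`` guard added here for illustration only:

.. code-block:: python

    import torch


    def _empty_device_cache(device_type: str) -> None:
        # Map a device type ("cuda", "mps", "xpu", ...) to the matching torch
        # backend module and release its cached memory; CPU has nothing to clear.
        backend_name = device_type.split(":")[0]
        if backend_name == "cpu":
            return
        backend = getattr(torch, backend_name, None)
        if backend is not None and hasattr(backend, "empty_cache"):
            backend.empty_cache()


    # _empty_device_cache("cuda")  ->  torch.cuda.empty_cache()
    # _empty_device_cache("cpu")   ->  no-op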
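Finally, the DDP setup hunks generalize the CUDA side-stream trick (https://pytorch.org/docs/stable/notes/cuda.html#id5) from ``torch.cuda`` to whatever backend the root device belongs to. A sketch of the context selection, with an explicit fallback (added here for illustration) for backends that do not expose a stream API:

.. code-block:: python

    from contextlib import nullcontext

    import torch


    def _ddp_setup_context(root_device: torch.device, device_ids):
        # No device ids (e.g. CPU training) -> plain construction, no side stream.
        if device_ids is None:
            return nullcontext()
        backend = getattr(torch, root_device.type.split(":")[0], None)
        if backend is None or not hasattr(backend, "Stream"):
            # Backends without a stream API fall back to a no-op context.
            return nullcontext()
        # Stream-capable backends wrap DDP construction in a side stream,
        # mirroring torch.cuda.stream(torch.cuda.Stream()).
        return backend.stream(backend.Stream())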