From 1c83154679abc2f08eac7005350b7086f357e53e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 22 Oct 2024 07:52:48 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/accelerators/cpu.py | 2 +-
 src/lightning/fabric/accelerators/cuda.py | 2 +-
 src/lightning/fabric/accelerators/mps.py | 2 +-
 src/lightning/fabric/accelerators/xla.py | 2 +-
 src/lightning/fabric/connector.py | 4 ++--
 src/lightning/fabric/plugins/precision/amp.py | 3 +--
 src/lightning/fabric/strategies/deepspeed.py | 6 ++++--
 src/lightning/pytorch/plugins/precision/amp.py | 3 +--
 src/lightning/pytorch/strategies/deepspeed.py | 6 ++++--
 .../pytorch/trainer/connectors/accelerator_connector.py | 2 +-
 tests/tests_fabric/test_connector.py | 2 +-
 .../trainer/connectors/test_accelerator_connector.py | 2 +-
 12 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py
index 8a0681d860be9..e019ea100ee8b 100644
--- a/src/lightning/fabric/accelerators/cpu.py
+++ b/src/lightning/fabric/accelerators/cpu.py
@@ -49,7 +49,7 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
-
+
     @staticmethod
     @override
     def get_device() -> str:
diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index 6e7a92b2a0eb0..420f645dc9cb6 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -59,7 +59,7 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "cuda"
-
+
     @staticmethod
     @override
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py
index 1840e39586250..29beb97fc9c9a 100644
--- a/src/lightning/fabric/accelerators/mps.py
+++ b/src/lightning/fabric/accelerators/mps.py
@@ -64,7 +64,7 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "mps"
-
+
     @staticmethod
     @override
     def auto_device_count() -> int:
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index 3826178599ea2..5055a29398f60 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -68,7 +68,7 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "xla"
-
+
     @staticmethod
     @override
     # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index 4b14734951a9d..3cf9ed681fe6b 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -464,9 +464,9 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(
-                precision=self._precision_input, # type: ignore[arg-type]
+                precision=self._precision_input,  # type: ignore[arg-type]
                 device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
-                )
+            )
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index 9c671cb0a4310..1e2f54e2c7270 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -55,8 +55,7 @@ def __init__(
             if _TORCH_GREATER_EQUAL_2_4
             else getattr(
                 torch,
-                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
-                else device.split(":")[0]
+                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
             ).amp.GradScaler()
         )
         if scaler is not None and self.precision == "bf16-mixed":
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 8956a73375de1..4e0493711bf70 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -619,8 +619,10 @@ def _initialize_engine(
     @override
     def setup_environment(self) -> None:
         from deepspeed.runtime.utils import get_accelerator
-        if (not isinstance(self.accelerator, CUDAAccelerator)) and \
-            self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
                 f"but `{self.accelerator.__class__.__name__}` is used."
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 421b6776b8f5c..eb17c33a902de 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -55,8 +55,7 @@ def __init__(
             if _TORCH_GREATER_EQUAL_2_4
             else getattr(
                 torch,
-                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
-                else device.split(":")[0]
+                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
             ).amp.GradScaler()
         )
         if scaler is not None and self.precision == "bf16-mixed":
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 6183948d689c7..76fa649666a1e 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -317,8 +317,10 @@ def __init__(
     @override
     def setup_environment(self) -> None:
         from deepspeed.runtime.utils import get_accelerator
-        if (not isinstance(self.accelerator, CUDAAccelerator)) and \
-            self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
                 f"but `{self.accelerator.__class__.__name__}` is used."
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 1e1bb9982301d..c4d7526b85691 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -508,7 +508,7 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(
-                precision=self._precision_flag, # type: ignore[arg-type]
+                precision=self._precision_flag,  # type: ignore[arg-type]
                 device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
             )
         if self._precision_flag in ("16-true", "bf16-true"):
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index fee4000cd904f..22a998962141b 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -178,7 +178,7 @@ def parse_devices(devices):
     @staticmethod
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
-
+
     @staticmethod
     def get_device() -> str:
         return "cpu"
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 5a2190c692723..621fd9106019b 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -191,7 +191,7 @@ def parse_devices(devices):
     @staticmethod
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
-
+
     @staticmethod
     def get_device() -> str:
         return "cpu"