
Commit 1c83154
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 22, 2024
1 parent 4ac2dd4 commit 1c83154
Showing 12 changed files with 19 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/cpu.py
The only change here is whitespace on the blank line between methods (trailing whitespace stripped by the pre-commit hooks):

```diff
@@ -49,7 +49,7 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_cpu_cores(devices)
         return [torch.device("cpu")] * devices
-
+
     @staticmethod
     @override
     def get_device() -> str:
```
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/cuda.py
Same whitespace-only fix on the blank line:

```diff
@@ -59,7 +59,7 @@ def get_parallel_devices(devices: List[int]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "cuda"
-
+
     @staticmethod
     @override
     def auto_device_count() -> int:
```
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/mps.py
Same whitespace-only fix on the blank line:

```diff
@@ -64,7 +64,7 @@ def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "mps"
-
+
     @staticmethod
     @override
     def auto_device_count() -> int:
```
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/xla.py
Same whitespace-only fix on the blank line:

```diff
@@ -68,7 +68,7 @@ def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]:
     @override
     def get_device() -> str:
         return "xla"
-
+
     @staticmethod
     @override
     # XLA's multiprocessing will pop the TPU_NUM_DEVICES key, so we need to cache it
```
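For context on the four accelerator hunks above: each accelerator class exposes a static `get_device()` that returns its canonical device string ("cpu", "cuda", "mps", "xla"). A minimal usage sketch, assuming only the API visible in these diffs (the availability-based choice of class is illustrative, not Lightning's own selection logic):

```python
import torch

from lightning.fabric.accelerators.cpu import CPUAccelerator
from lightning.fabric.accelerators.cuda import CUDAAccelerator

# Pick an accelerator class and turn its device string into a torch.device.
accelerator_cls = CUDAAccelerator if torch.cuda.is_available() else CPUAccelerator
device = torch.device(accelerator_cls.get_device())  # cuda or cpu
print(device)
```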
4 changes: 2 additions & 2 deletions src/lightning/fabric/connector.py
Formatting-only changes (whitespace before the inline comment and the closing-paren indent):

```diff
@@ -464,9 +464,9 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(
-                precision=self._precision_input, # type: ignore[arg-type]
+                precision=self._precision_input,  # type: ignore[arg-type]
                 device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
-                )
+            )
         mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
         if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
             raise ValueError(
```
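Functionally, the call above forwards the accelerator's device string to `FSDPPrecision` only when the flag is a concrete `Accelerator`, and `None` otherwise. A hedged sketch of that guard in isolation (`StubAccelerator` is a made-up stand-in, not Lightning code):

```python
from typing import Optional


class StubAccelerator:
    """Made-up stand-in mimicking an Accelerator with a static get_device()."""

    @staticmethod
    def get_device() -> str:
        return "cuda"


def resolve_precision_device(accelerator_flag: object) -> Optional[str]:
    # Mirrors the connector logic: only query get_device() on a real
    # accelerator instance; string flags like "auto" fall back to None.
    if isinstance(accelerator_flag, StubAccelerator):
        return accelerator_flag.get_device()
    return None


assert resolve_precision_device(StubAccelerator()) == "cuda"
assert resolve_precision_device("auto") is None
```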
3 changes: 1 addition & 2 deletions src/lightning/fabric/plugins/precision/amp.py
The two-line ternary is collapsed onto one line (with a trailing comma), no behavior change:

```diff
@@ -55,8 +55,7 @@ def __init__(
             if _TORCH_GREATER_EQUAL_2_4
             else getattr(
                 torch,
-                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
-                else device.split(":")[0]
+                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
             ).amp.GradScaler()
         )
         if scaler is not None and self.precision == "bf16-mixed":
```
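The ternary above decides which `torch` namespace supplies `amp.GradScaler` when torch is older than 2.4 (the `_TORCH_GREATER_EQUAL_2_4` branch above the hunk handles newer versions without this fallback): "cuda" when the device is not a string or has a "cpu" prefix, otherwise the device's own prefix, so "xpu:0" selects `torch.xpu.amp.GradScaler`. A standalone sketch of just that selection logic:

```python
def scaler_namespace(device) -> str:
    # Default to "cuda"; a non-CPU device string such as "xpu:0"
    # selects its own backend namespace ("xpu").
    if not isinstance(device, str) or device.split(":")[0] == "cpu":
        return "cuda"
    return device.split(":")[0]


assert scaler_namespace(None) == "cuda"
assert scaler_namespace("cpu") == "cuda"
assert scaler_namespace("cuda:1") == "cuda"
assert scaler_namespace("xpu:0") == "xpu"
```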
6 changes: 4 additions & 2 deletions src/lightning/fabric/strategies/deepspeed.py
The backslash continuation is rewritten as a parenthesized condition; only the line wrapping changes:

```diff
@@ -619,8 +619,10 @@ def _initialize_engine(
     @override
     def setup_environment(self) -> None:
         from deepspeed.runtime.utils import get_accelerator
-        if (not isinstance(self.accelerator, CUDAAccelerator)) and \
-                self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
                 f"but `{self.accelerator.__class__.__name__}` is used."
```
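The guard itself is unchanged: a `CUDAAccelerator` always passes, and any other accelerator must report the same device string as DeepSpeed's own accelerator, or the strategy raises. A truth-table sketch of that condition with illustrative inputs:

```python
def deepspeed_accelerator_ok(is_cuda_accelerator: bool, accel_device: str, ds_device_name: str) -> bool:
    # The strategy raises on the negation of this: CUDA accelerators always
    # pass; anything else must match DeepSpeed's reported device name.
    return is_cuda_accelerator or accel_device == ds_device_name


assert deepspeed_accelerator_ok(True, "cuda", "cuda")
assert deepspeed_accelerator_ok(False, "npu", "npu")
assert not deepspeed_accelerator_ok(False, "mps", "cuda")
```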
3 changes: 1 addition & 2 deletions src/lightning/pytorch/plugins/precision/amp.py
The same ternary collapse as in the fabric plugin:

```diff
@@ -55,8 +55,7 @@ def __init__(
             if _TORCH_GREATER_EQUAL_2_4
             else getattr(
                 torch,
-                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu"
-                else device.split(":")[0]
+                "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
             ).amp.GradScaler()
         )
         if scaler is not None and self.precision == "bf16-mixed":
```
6 changes: 4 additions & 2 deletions src/lightning/pytorch/strategies/deepspeed.py
The same condition rewrap as in the fabric strategy:

```diff
@@ -317,8 +317,10 @@ def __init__(
     @override
     def setup_environment(self) -> None:
         from deepspeed.runtime.utils import get_accelerator
-        if (not isinstance(self.accelerator, CUDAAccelerator)) and \
-                self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
+
+        if (
+            not isinstance(self.accelerator, CUDAAccelerator)
+        ) and self.accelerator.get_device() != get_accelerator().device_name():  # type: ignore[union-attr]
             raise RuntimeError(
                 f"The DeepSpeed strategy is only supported on {get_accelerator().device_name()} GPUs,"
                 f"but `{self.accelerator.__class__.__name__}` is used."
```
The same formatting-only fix lands in a second connector (the file name is not shown in this view):

```diff
@@ -508,7 +508,7 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_flag)  # type: ignore[arg-type]
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(
-                precision=self._precision_flag, # type: ignore[arg-type]
+                precision=self._precision_flag,  # type: ignore[arg-type]
                 device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None,
             )
         if self._precision_flag in ("16-true", "bf16-true"):
```
2 changes: 1 addition & 1 deletion tests/tests_fabric/test_connector.py
Whitespace-only fix on the blank line:

```diff
@@ -178,7 +178,7 @@ def parse_devices(devices):
     @staticmethod
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
-
+
     @staticmethod
     def get_device() -> str:
         return "cpu"
```
The same whitespace-only fix in a second test file (the file name is not shown in this view):

```diff
@@ -191,7 +191,7 @@ def parse_devices(devices):
     @staticmethod
     def get_parallel_devices(devices):
         return [torch.device("cpu")] * devices
-
+
    @staticmethod
    def get_device() -> str:
        return "cpu"
```
