diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index e2605ceca4670..a4ce4a006bfc1 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 from typing import Iterable
 
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
 from torch.optim import Optimizer
 
 from lightning.fabric.utilities.apply_func import move_data_to_device
@@ -28,7 +27,42 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
         _optimizer_to_device(opt, device)
 
 
-def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
-    """Moves the state of a single optimizer to the device."""
+def _optimizer_to_device_deprecated(optimizer: Optimizer, device: _DEVICE) -> None:
+    """Moves the state of a single optimizer to the device.
+
+    Deprecated because this appears to be unnecessary; see ``_optimizer_to_device`` below.
+
+    """
+    # Note the special logic for the 'step' parameter:
+    # 'step' needs to remain unmoved (possibly on the CPU), since that is where the optimizer expects it.
+    # See https://github.com/pytorch/pytorch/issues/74424 and
+    # _process_value_according_to_param_policy in torch/optim/optimizer.py:618
+    fused = False
+    with contextlib.suppress(Exception):
+        fused = optimizer.param_groups[0]["fused"]
+
     for p, v in optimizer.state.items():
-        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+        for key, val in v.items():
+            if key != "step" or fused:
+                v[key] = move_data_to_device(val, device)
+
+
+def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
+    """Moves the state of a single optimizer to the device.
+
+    In fact, it looks like we don't need this function at all: we can rely on ``Optimizer.load_state_dict`` to
+    do the right thing once it is given a correct prototype on the target device. For now this is a no-op,
+    under the assumption that we don't need to transfer the optimizer back to the CPU on teardown. See details
+    below.
+
+    """
+    pass
+
+    # To test for correct behaviour here we have created two tests:
+    # 1. tests/tests_fabric/utilities/test_optimizer.py to test Optimizer.load_state_dict with a prototype
+    # 2. tests/tests_pytorch/checkpointing/test_trainer_move_device.py to test higher-level checkpointing on
+    #    one device and resuming on a different device
+
+    # Details on how this function is called:
+    # 1st call is in Strategy.setup(), to initialize an empty optimizer (src/lightning/pytorch/strategies/strategy.py:158).
+    #    Note: Strategy.setup() first calls Strategy.setup_optimizers(), which eventually invokes
+    #    Model.configure_optimizers() on a model that has already been moved to the device. It therefore
+    #    creates a prototype optimizer on the target device and then relies on Optimizer.load_state_dict()
+    #    to transfer the state.
+    # 2nd call is when restoring a checkpoint, as part of Strategy.load_optimizer_state_dict() (strategy.py:377).
+    # Final call is in Strategy.teardown(), to move the optimizer back to the CPU (strategy.py:525).
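+    # A minimal sketch of the prototype pattern described above (illustrative only; names are hypothetical
+    # and this code is not part of the change):
+    #
+    #     model = model.to(device)                                   # model already on the target device
+    #     optimizer = torch.optim.Adam(model.parameters(), lr=0.1)   # "prototype" created on the target device
+    #     optimizer.load_state_dict(checkpoint["optimizer_state"])   # state tensors land on each param's device
+    #
+    # i.e. no explicit move of the optimizer state is required as long as the prototype's parameters already
+    # live on the target device.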
diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py
index 3aa78d507c346..d2273fe55fe1c 100644
--- a/tests/tests_fabric/utilities/test_optimizer.py
+++ b/tests/tests_fabric/utilities/test_optimizer.py
@@ -1,36 +1,151 @@
 import collections
+import copy
 import dataclasses
+from typing import Tuple
 
+import pytest
 import torch
-from lightning.fabric.utilities.optimizer import _optimizer_to_device
+import torch.nn as nn
+from lightning.fabric.utilities.optimizer import _optimizer_to_device_deprecated
 from torch import Tensor
+from torch.utils.data import DataLoader
 
 
-def test_optimizer_to_device():
+def create_optimizer_on_devices():
     @dataclasses.dataclass(frozen=True)
     class FooState:
         bar: int
 
-    class TestOptimizer(torch.optim.SGD):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.state["dummy"] = torch.tensor(0)
-            self.state["frozen"] = FooState(0)
-
-    layer = torch.nn.Linear(32, 2)
-    opt = TestOptimizer(layer.parameters(), lr=0.1)
-    _optimizer_to_device(opt, "cpu")
-    if torch.cuda.is_available():
-        _optimizer_to_device(opt, "cuda")
-        assert_opt_parameters_on_device(opt, "cuda")
-
-
-def assert_opt_parameters_on_device(opt, device: str):
-    for param in opt.state.values():
-        # Not sure there are any global tensors in the state dict
-        if isinstance(param, Tensor):
-            assert param.data.device.type == device
+    class CachedRandomTensorDataset(torch.utils.data.Dataset):
+        """Very low overhead torch dataset for training for a given number of steps."""
+
+        def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int, device: str) -> None:
+            self.x = torch.randn((batch_size, num_features), device=torch.device(device))
+            self.y = torch.randn((batch_size, num_responses), device=torch.device(device))
+            self.length = length
+
+        def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+            return self.x.clone(), self.y.clone()
+
+        def __len__(self) -> int:
+            return self.length
+
+    def simple_training(optimizer, model, dataset, loss_fn):
+        for input, target in dataset:
+            optimizer.zero_grad()
+            output = model(input)
+            loss = loss_fn(output, target)
+            loss.backward()
+            optimizer.step()
+
+    gpu_device = "cuda"
+    devices = ["cpu", gpu_device]
+    num_features = 32
+    num_responses = 2
+    batch_size = 16
+    optimizer_on_device = {}
+    for device in devices:
+        dataset = CachedRandomTensorDataset(batch_size, num_features, num_responses, batch_size * 16, device)
+        dataloader = DataLoader(dataset, batch_size=None, shuffle=False)
+        model = torch.nn.Linear(num_features, num_responses)
+        model.to(device=device)
+        fused_vals = [False] if device == "cpu" else [False, True]
+
+        for fused in fused_vals:
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.1, fused=fused)
+            simple_training(optimizer, model, dataloader, loss_fn=nn.MSELoss())
+            if fused:
+                optimizer_on_device[device + "_fused_" + str(fused)] = optimizer
+            else:
+                optimizer_on_device[device] = optimizer
+    return gpu_device, optimizer_on_device
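+
+
+# Note: create_optimizer_on_devices() returns ("cuda", mapping), where the mapping holds a small reference set
+# of trained optimizers keyed by device and fused variant, e.g. {"cpu": ..., "cuda": ..., "cuda_fused_True": ...}.
+# The tests below use these as ground truth for where the state tensors of each variant should live.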
+
+
+def test_optimizer_to_device_match_locations():
+    """Test the _optimizer_to_device_deprecated function by ensuring that the moved internal state matches what
+    is expected."""
+
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    gpu_device, optimizer_on_device = create_optimizer_on_devices()
+
+    # Test the _optimizer_to_device_deprecated function
+    # cpu --> gpu, fused=False from CPU
+    opt_to_gpu = copy.deepcopy(optimizer_on_device["cpu"])
+    _optimizer_to_device_deprecated(opt_to_gpu, gpu_device)
+    assert_opt_state_in_expected_location(opt_to_gpu, optimizer_on_device[gpu_device])
+
+    # gpu --> cpu, fused=False
+    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device])
+    _optimizer_to_device_deprecated(opt_to_cpu, "cpu")
+    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])
+
+    # gpu --> cpu, fused=True
+    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
+    _optimizer_to_device_deprecated(opt_to_cpu, "cpu")
+    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])
+
+
+def test_load_state_dict():
+    """Test that optimizer.load_state_dict() with a prototype on the target device is equivalent to moving the
+    optimizer onto the device manually."""
+
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    gpu_device, optimizer_on_device = create_optimizer_on_devices()
+
+    ######################################################################################
+    # GPU prototypes
+    ######################################################################################
+
+    # CPU -> GPU, fused=False
+    # Use load_state_dict with a GPU prototype, fused=False
+    opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
+    gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device])
+    gpu_prototype.load_state_dict(opt_cpu_dict)
+    assert_opt_state_in_expected_location(gpu_prototype, optimizer_on_device[gpu_device])
+
+    # CPU -> GPU, fused=True
+    # Use load_state_dict with a GPU prototype, fused=True
+    opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
+    gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
+    gpu_prototype.load_state_dict(opt_cpu_dict)
+    assert_opt_state_in_expected_location(
+        gpu_prototype, optimizer_on_device[gpu_device]
+    )  # fused=False from the CPU state dict overrides the prototype's fused=True
+
+    ######################################################################################
+    # CPU prototypes
+    ######################################################################################
+
+    # GPU, fused=False -> CPU
+    # Use load_state_dict with a CPU prototype, fused=False
+    opt_gpu_dict = optimizer_on_device[gpu_device].state_dict()
+    cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
+    cpu_prototype.load_state_dict(opt_gpu_dict)
+    assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])
+
+    # GPU, fused=True -> CPU
+    # Use load_state_dict with a CPU prototype, fused=True
+    opt_gpu_dict = optimizer_on_device[gpu_device + "_fused_True"].state_dict()
+    cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
+    cpu_prototype.load_state_dict(opt_gpu_dict)  # NOTE: arguably this should give an error / refuse fused=True
+    assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])
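+
+
+# Background (a summary of the PyTorch behaviour these tests rely on; see
+# _process_value_according_to_param_policy in torch/optim/optimizer.py): Optimizer.load_state_dict() casts each
+# state tensor onto the device (and, for floating-point tensors, the dtype) of the matching parameter in the
+# prototype, except for "step", which stays on the CPU unless the optimizer is fused or capturable.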
+
+
+def assert_opt_state_in_expected_location(opt, expected_opt):
+    opt_dict = opt.state_dict()
+    expected_opt_dict = expected_opt.state_dict()
+    for key, param in opt_dict["state"].items():
+        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
+            pytest.fail(f"Optimizer device mismatch for state[{key}]")
         elif isinstance(param, collections.abc.Mapping):
-            for subparam in param.values():
-                if isinstance(subparam, Tensor):
-                    assert param.data.device.type == device
+            for subkey, subparam in param.items():
+                if (
+                    isinstance(subparam, Tensor)
+                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
+                ):
+                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")
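+
+
+# Note: the helper above compares only the device *type* ("cpu" vs "cuda"), so tensors on different GPU indices
+# (cuda:0 vs cuda:1) would be treated as equivalent; that granularity is sufficient for these tests.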
diff --git a/tests/tests_pytorch/checkpointing/test_trainer_move_device.py b/tests/tests_pytorch/checkpointing/test_trainer_move_device.py
new file mode 100644
index 0000000000000..c34522ce733b7
--- /dev/null
+++ b/tests/tests_pytorch/checkpointing/test_trainer_move_device.py
@@ -0,0 +1,143 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import copy
+from typing import Any
+
+import lightning.pytorch as pl
+import pytest
+import torch
+from lightning.pytorch import Trainer
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+from lightning.pytorch.demos.boring_classes import BoringModel
+from torch import Tensor
+
+
+class TrainerStateChecker(Callback):
+    def __init__(
+        self,
+        optimizer_dict: dict,
+        model_dict: dict,
+        capture: bool,
+        target_device: str,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.optimizer_dict = optimizer_dict
+        self.model_dict = model_dict
+        self.capture = capture
+        self.target_device = target_device
+
+    def on_train_start(self, trainer, pl_module):
+        if not self.capture:
+            # Check model and optimizer device locations
+            assert trainer.model.device == self.model_dict[self.target_device].device
+            assert_opt_state_in_expected_location(trainer.optimizers[0], self.optimizer_dict[self.target_device])
+
+    def on_train_end(self, trainer, pl_module):
+        if self.capture:
+            # Capture the optimizer state before it is transferred back to the CPU
+            self.optimizer_dict[self.target_device] = copy.deepcopy(trainer.optimizers[0])
+            self.model_dict[self.target_device] = copy.deepcopy(trainer.model)
+
+
+def assert_opt_state_in_expected_location(opt, expected_opt):
+    opt_dict = opt.state_dict()
+    expected_opt_dict = expected_opt.state_dict()
+    for key, param in opt_dict["state"].items():
+        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
+            pytest.fail(f"Optimizer device mismatch for state[{key}]")
+        elif isinstance(param, collections.abc.Mapping):
+            for subkey, subparam in param.items():
+                if (
+                    isinstance(subparam, Tensor)
+                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
+                ):
+                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA GPU")
+def test_change_device(tmpdir):
+    """This test validates that a generated ModelCheckpoint can be moved to a different device."""
+
+    class ExtendedBoringModel(BoringModel):
+        def __init__(
+            self,
+            target_device: str,
+            **kwargs: Any,
+        ) -> None:
+            super().__init__(**kwargs)
+            self.target_device = target_device
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.01)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+        def validation_step(self, batch, batch_idx):
+            loss = self.step(batch)
+            self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+
+    # Train on different devices to build a profile of where the state tensors live for each device
+    devices = ["cpu", "gpu"]
+    optimizer_dict = {}
+    model_dict = {}
+    checkpoint_path = {}
+    for device in devices:
+        device_path = tmpdir.mkdir(device)
+        checkpoint_callback = ModelCheckpoint(
+            monitor="val_loss", dirpath=device_path, filename="{epoch:02d}", save_top_k=-1
+        )
+
+        tsc = TrainerStateChecker(
+            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=True, target_device=device
+        )
+        trainer = Trainer(
+            accelerator=device,
+            devices=1,
+            default_root_dir=device_path,
+            max_epochs=1,
+            limit_train_batches=12,
+            limit_val_batches=6,
+            limit_test_batches=12,
+            callbacks=[checkpoint_callback, tsc],
+            logger=False,
+        )
+        model = ExtendedBoringModel(device)
+        trainer.fit(model)
+        checkpoint_path[device] = checkpoint_callback.best_model_path
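+
+    # At this point checkpoint_path maps each accelerator to the best checkpoint produced on it,
+    # e.g. {"cpu": "<tmpdir>/cpu/epoch=00.ckpt", "gpu": "<tmpdir>/gpu/epoch=00.ckpt"} (paths illustrative).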
+
+    # Cross-load from checkpoint:
+    # that is, load the CPU checkpoint but target continuation on GPU, and vice versa.
+    # Expected state is checked via TrainerStateChecker, using the states captured above on GPU and CPU.
+    trainer_resume_dict = {}
+    for device_idx, device in enumerate(devices):
+        cross_device = devices[(device_idx + 1) % len(devices)]
+        tsc = TrainerStateChecker(
+            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=False, target_device=cross_device
+        )
+        trainer = pl.Trainer(
+            accelerator=cross_device,
+            devices=1,
+            default_root_dir=tmpdir,
+            max_epochs=3,
+            limit_train_batches=12,
+            limit_val_batches=12,
+            limit_test_batches=12,
+            enable_progress_bar=False,
+            callbacks=tsc,
+        )
+        model = ExtendedBoringModel(cross_device)
+        trainer.fit(model, ckpt_path=checkpoint_path[device])  # Load the checkpoint from the original device
+        trainer.test()
+        trainer_resume_dict[cross_device] = trainer
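+        # Note: with devices = ["cpu", "gpu"], cross_device simply swaps the pair, so the CPU checkpoint is
+        # resumed on GPU and vice versa, exercising both transfer directions.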