Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test optimizer to device #20062

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions src/lightning/fabric/utilities/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Iterable

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.optim import Optimizer

from lightning.fabric.utilities.apply_func import move_data_to_device
Expand All @@ -28,7 +27,42 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
_optimizer_to_device(opt, device)


def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
def _optimizer_to_device_deprecated(optimizer: Optimizer, device: _DEVICE) -> None:
"""Moves the state of a single optimizer to the device."""
"""Deprecated because it seems this is unnecessary."""
# Note special logic for 'step' parameter
# The 'step' parameter needs to remain unmoved (possibly on the CPU) since that is where the optimizer needs it.
# See https://github.com/pytorch/pytorch/issues/74424 and
# _process_value_according_to_param_policy in torch/optim/optimizer.py:618
fused = False
with contextlib.suppress(Exception):
fused = optimizer.param_groups[0]["fused"]

for p, v in optimizer.state.items():
optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
for key, val in v.items():
if key != "step" or fused:
v[key] = move_data_to_device(val, device)


def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
"""Moves the state of a single optimizer to the device.

In fact, it looks like we dont need this function but can rely on optimizer.load_state_dict to do the right thing
after given a correct prototype on the target device. For now we do nothing and assume that we don't care about
transferring the optimizer back to the CPU on teardown. See details below.

"""
pass

# To test for correct behaviour here we have created two tests:
# 1. tests/tests_fabric/utilities/test_optimizer.py to test Optimizer.load_state_dict with a prototype
# 2. tests/tests_pytorch/checkpointing/test_trainer_move_device.py to test higher level checkpointing on
# one device and resuming on a different device

# Details on how this function is called.
# 1st call is in Strategy.setup(), to initialize empty optimizer. src/lightning/pytorch/strategies/strategy.py: 158
# Note: Strategy.setup() first calls Strategy.setup_optimizers which eventually invokes Model.configure_optimizers()
# based on a model that has been moved to the device. Thus it essentially creates a prototype optimizer on the
# target device and then, eventually, relies on Optimizer.load_state_dict() to transfer the state.
# 2nd call when restoring checkpoint, as part of Strategy.load_optimizer_state_dict(). Source strategy.py: 377
# Final call in Strategy.teardown(), move optimizer back to CPU. Source strategy.py: 525
163 changes: 139 additions & 24 deletions tests/tests_fabric/utilities/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,151 @@
import collections
import copy
import dataclasses
from typing import Tuple

import pytest
import torch
from lightning.fabric.utilities.optimizer import _optimizer_to_device
import torch.nn as nn
from lightning.fabric.utilities.optimizer import _optimizer_to_device_deprecated
from torch import Tensor
from torch.utils.data import DataLoader


def test_optimizer_to_device():
def create_optimizer_on_devices():
@dataclasses.dataclass(frozen=True)
class FooState:
bar: int

class TestOptimizer(torch.optim.SGD):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.state["dummy"] = torch.tensor(0)
self.state["frozen"] = FooState(0)

layer = torch.nn.Linear(32, 2)
opt = TestOptimizer(layer.parameters(), lr=0.1)
_optimizer_to_device(opt, "cpu")
if torch.cuda.is_available():
_optimizer_to_device(opt, "cuda")
assert_opt_parameters_on_device(opt, "cuda")


def assert_opt_parameters_on_device(opt, device: str):
for param in opt.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, Tensor):
assert param.data.device.type == device
class CachedRandomTensorDataset(torch.utils.data.Dataset):
"""Very low overhead torch dataset for training for a given number of steps."""

def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int, device: str) -> None:
self.x = torch.randn((batch_size, num_features), device=torch.device(device))
self.y = torch.randn((batch_size, num_responses), device=torch.device(device))
self.length = length

def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return self.x.clone(), self.y.clone()

def __len__(self) -> int:
return self.length

def simple_training(optimizer, model, dataset, loss_fn):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()

gpu_device = "cuda"
devices = ["cpu", gpu_device]
num_features = 32
num_responses = 2
batch_size = 16
optimizer_on_device = {}
for device in devices:
dataset = CachedRandomTensorDataset(batch_size, num_features, num_responses, batch_size * 16, device)
dataloader = DataLoader(dataset, batch_size=None, shuffle=False)
model = torch.nn.Linear(num_features, num_responses)
model.to(device=device)
fused_vals = [False] if device == "cpu" else [False, True]

for fused in fused_vals:
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, fused=fused)
simple_training(optimizer, model, dataloader, loss_fn=nn.MSELoss())
if fused:
optimizer_on_device[device + "_fused_" + str(fused)] = optimizer
else:
optimizer_on_device[device] = optimizer
return gpu_device, optimizer_on_device


def test_optimizer_to_device_match_locations():
"""Test the _optimizer_to_device_deprecated function by ensuring that moving the internal state matches what is
expected."""

if not torch.cuda.is_available():
return

gpu_device, optimizer_on_device = create_optimizer_on_devices()

# Test _optimizer_to_device function
# Test cpu-->gpu, fused = False from CPU
opt_to_gpu = copy.deepcopy(optimizer_on_device["cpu"])
_optimizer_to_device_deprecated(opt_to_gpu, gpu_device)
assert_opt_state_in_expected_location(opt_to_gpu, optimizer_on_device[gpu_device])

# Test gpu-->cpu, fused = False
opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device])
_optimizer_to_device_deprecated(opt_to_cpu, "cpu")
assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])

# Test gpu-->cpu, fused = True
opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
_optimizer_to_device_deprecated(opt_to_cpu, "cpu")
assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])


def test_load_state_dict():
"""Test that optimizer.load_state_dict() with a model prototype on the target device is equivalent to moving the
optimizer onto the device manually."""

if not torch.cuda.is_available():
return

gpu_device, optimizer_on_device = create_optimizer_on_devices()

######################################################################################
# GPU prototypes
######################################################################################

# CPU -> GPU, fused=False
# Use from_dict with gpu prototype, fused = False
opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device])
gpu_prototype.load_state_dict(opt_cpu_dict)
print(opt_cpu_dict)
assert_opt_state_in_expected_location(gpu_prototype, optimizer_on_device[gpu_device])

# CPU -> GPU, fused=True
# Use from_dict with gpu prototype, fused = True
opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
gpu_prototype.load_state_dict(opt_cpu_dict)
assert_opt_state_in_expected_location(
gpu_prototype, optimizer_on_device[gpu_device]
) # fused=False from CPU, overrides prototype

######################################################################################
# CPU prototypes
######################################################################################

# GPU, fused=False -> CPU
# Use from_dict with cpu prototype, fused = False
opt_gpu_dict = optimizer_on_device[gpu_device].state_dict()
cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
cpu_prototype.load_state_dict(opt_gpu_dict)
assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])

# GPU, fused=True -> CPU
# Use from_dict with cpu prototype, fused = True
opt_gpu_dict = optimizer_on_device[gpu_device + "_fused_True"].state_dict()
cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
cpu_prototype.load_state_dict(opt_gpu_dict) # !!!!! This should give an error / refuse to allow fused = True
assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])


def assert_opt_state_in_expected_location(opt, expected_opt):
opt_dict = opt.state_dict()
expected_opt_dict = expected_opt.state_dict()
for key, param in opt_dict["state"].items():
if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
pytest.fail(f"Optimizer device mismatch for state[{key}]")
elif isinstance(param, collections.abc.Mapping):
for subparam in param.values():
if isinstance(subparam, Tensor):
assert param.data.device.type == device
for subkey, subparam in param.items():
if (
isinstance(subparam, Tensor)
and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
):
pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")
143 changes: 143 additions & 0 deletions tests/tests_pytorch/checkpointing/test_trainer_move_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
from typing import Any

import lightning.pytorch as pl
import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from torch import Tensor


class TrainerStateChecker(Callback):
def __init__(
self,
optimizer_dict: dict,
model_dict: dict,
capture: bool,
target_device: str,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.optimizer_dict = optimizer_dict
self.model_dict = model_dict
self.capture = capture
self.target_device = target_device

def on_train_start(self, trainer, pl_module):
if not self.capture:
# Check model and optimizer device locations
assert trainer.model.device == self.model_dict[self.target_device].device
assert_opt_state_in_expected_location(trainer.optimizers[0], self.optimizer_dict[self.target_device])

def on_train_end(self, trainer, pl_module):
if self.capture:
# Capture the optimizer state before it is transferred back to the cpu
self.optimizer_dict[self.target_device] = copy.deepcopy(trainer.optimizers[0])
self.model_dict[self.target_device] = copy.deepcopy(trainer.model)


def assert_opt_state_in_expected_location(opt, expected_opt):
opt_dict = opt.state_dict()
expected_opt_dict = expected_opt.state_dict()
for key, param in opt_dict["state"].items():
if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
pytest.fail(f"Optimizer device mismatch for state[{key}]")
elif isinstance(param, collections.abc.Mapping):
for subkey, subparam in param.items():
if (
isinstance(subparam, Tensor)
and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
):
pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")


def test_change_device(tmpdir):
"""This test validates that a generated ModelCheckpoint can be moved to a different device."""

class ExtendedBoringModel(BoringModel):
def __init__(
self,
target_device: str,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.target_device = target_device

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.01)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

def validation_step(self, batch, batch_idx):
loss = self.step(batch)
self.log("val_loss", loss, on_epoch=True, prog_bar=True)

# Train on different devices to create profile of where the state Tensors are located for each device
devices = ["cpu", "gpu"]
optimizer_dict = {}
model_dict = {}
checkpoint_path = {}
for device in devices:
device_path = tmpdir.mkdir(device)
checkpoint_callback = ModelCheckpoint(
monitor="val_loss", dirpath=device_path, filename="{epoch:02d}", save_top_k=-1
)

tsc = TrainerStateChecker(
optimizer_dict=optimizer_dict, model_dict=model_dict, capture=True, target_device=device
)
trainer = Trainer(
accelerator=device,
devices=1,
default_root_dir=device_path,
max_epochs=1,
limit_train_batches=12,
limit_val_batches=6,
limit_test_batches=12,
callbacks=[checkpoint_callback, tsc],
logger=False,
)
model = ExtendedBoringModel(device)
trainer.fit(model)
checkpoint_path[device] = checkpoint_callback.best_model_path

# Cross load from checkpoint
# That is, load CPU checkpoint, but target continuation on GPU, and vice versa
# Expected state is checked via TrainerStateChecker using the above trainers created on GPU and CPU devices
trainer_resume_dict = {}
for device_idx, device in enumerate(devices):
cross_device = devices[(device_idx + 1) % len(devices)]
tsc = TrainerStateChecker(
optimizer_dict=optimizer_dict, model_dict=model_dict, capture=False, target_device=cross_device
)
trainer = pl.Trainer(
accelerator=cross_device,
devices=1,
default_root_dir=tmpdir,
max_epochs=3,
limit_train_batches=12,
limit_val_batches=12,
limit_test_batches=12,
enable_progress_bar=False,
callbacks=tsc,
)
model = ExtendedBoringModel(cross_device)
trainer.fit(model, ckpt_path=checkpoint_path[device]) # Load checkpoint from original device
trainer.test()
trainer_resume_dict[cross_device] = trainer
Loading