Test optimizer to device #20062

Draft
wants to merge 4 commits into base: master
Changes from 3 commits
23 changes: 22 additions & 1 deletion src/lightning/fabric/utilities/optimizer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Iterable

from lightning_utilities.core.apply_func import apply_to_collection
@@ -28,7 +29,27 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
        _optimizer_to_device(opt, device)


def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
def _optimizer_to_device_old(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    for p, v in optimizer.state.items():
        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)


def _optimizer_to_device_fancy(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    # Note special logic for the 'step' parameter:
    # it needs to remain unmoved (possibly on the CPU) since that is where the optimizer needs it.
    # See https://github.com/pytorch/pytorch/issues/74424 and
    # _process_value_according_to_param_policy in torch/optim/optimizer.py:618
    fused = False
    with contextlib.suppress(Exception):
        fused = optimizer.param_groups[0]["fused"]

    for p, v in optimizer.state.items():
        for key, val in v.items():
            if key != "step" or fused:
                v[key] = move_data_to_device(val, device)


def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    pass  # Rely on optimizer.load_state_dict to do the right thing
141 changes: 118 additions & 23 deletions tests/tests_fabric/utilities/test_optimizer.py
@@ -1,36 +1,131 @@
import collections
import copy
import dataclasses
from typing import Tuple

import pytest
import torch
import torch.nn as nn
from lightning.fabric.utilities.optimizer import _optimizer_to_device
from torch import Tensor
from torch.utils.data import DataLoader


def test_optimizer_to_device():
def test_optimizer_to_device_match_locations():
    @dataclasses.dataclass(frozen=True)
    class FooState:
        bar: int

    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)
            self.state["frozen"] = FooState(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)
    _optimizer_to_device(opt, "cpu")
    if torch.cuda.is_available():
        _optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")


def assert_opt_parameters_on_device(opt, device: str):
    for param in opt.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, Tensor):
            assert param.data.device.type == device
    class CachedRandomTensorDataset(torch.utils.data.Dataset):
        """Very low overhead torch dataset for training for a given number of steps."""

        def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int, device: str) -> None:
            self.x = torch.randn((batch_size, num_features), device=torch.device(device))
            self.y = torch.randn((batch_size, num_responses), device=torch.device(device))
            self.length = length

        def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.x.clone(), self.y.clone()

        def __len__(self) -> int:
            return self.length

    def simple_training(optimizer, model, dataset, loss_fn):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

    if not torch.cuda.is_available():
        return

    gpu_device = "cuda"
    devices = ["cpu", gpu_device]

    num_features = 32
    num_responses = 2
    batch_size = 16

    optimizer_on_device = {}
    for device in devices:
        dataset = CachedRandomTensorDataset(batch_size, num_features, num_responses, batch_size * 16, device)
        dataloader = DataLoader(dataset, batch_size=None, shuffle=False)
        model = torch.nn.Linear(num_features, num_responses)
        model.to(device=device)
        fused_vals = [False] if device == "cpu" else [False, True]

        for fused in fused_vals:
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1, fused=fused)
            simple_training(optimizer, model, dataloader, loss_fn=nn.MSELoss())
            if fused:
                optimizer_on_device[device + "_fused_" + str(fused)] = optimizer
            else:
                optimizer_on_device[device] = optimizer

    # Test _optimizer_to_device function
    # Test cpu-->gpu, fused = False from CPU
    opt_to_gpu = copy.deepcopy(optimizer_on_device["cpu"])
    _optimizer_to_device(opt_to_gpu, gpu_device)
    assert_opt_state_in_expected_location(opt_to_gpu, optimizer_on_device[gpu_device])

    # Test gpu-->cpu, fused = False
    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device])
    _optimizer_to_device(opt_to_cpu, "cpu")
    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])

    # Test gpu-->cpu, fused = True
    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
    _optimizer_to_device(opt_to_cpu, "cpu")
    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])

    # Try from_dict
    # These all pretend that we have an appropriate prototype; I don't think we can actually do this since
    # all we may have is a CPU pickle

Contributor:

The test we have for pytorch load_state_dict being able to read a CPU checkpoint into an appropriate GPU optimizer is here: https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L1545-L1574

The code above is also how I expect checkpointing to happen, without the need for an explicit move to device.
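
As a rough sketch of that behavior (not part of this diff; it assumes a CUDA device and a recent torch where 'step' is stored as a CPU tensor), load_state_dict already casts the loaded state onto each parameter's device:

import torch

# CPU optimizer with populated state (exp_avg, exp_avg_sq, step)
model_cpu = torch.nn.Linear(4, 2)
opt_cpu = torch.optim.Adam(model_cpu.parameters(), lr=0.1)
model_cpu(torch.randn(8, 4)).sum().backward()
opt_cpu.step()

# GPU "prototype" optimizer; loading the CPU state dict moves the state for us
model_gpu = torch.nn.Linear(4, 2).cuda()
opt_gpu = torch.optim.Adam(model_gpu.parameters(), lr=0.1)
opt_gpu.load_state_dict(opt_cpu.state_dict())

for state in opt_gpu.state.values():
    assert state["exp_avg"].device.type == "cuda"
    assert state["step"].device.type == "cpu"  # non-fused optimizers keep 'step' on the CPU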

Contributor Author:

I have updated this test to be more explicit about what is going on; please take a look and see if it makes sense, since the test you linked doesn't look as thorough as far as I can tell.

Contributor:

Yea, the case we test is moving from CPU to GPU, and I see you test more combinations.

    # For now, this is a future idea for _optimizer_to_device

    if False:
        # GPU prototypes
        # Use from_dict with gpu prototype, fused = False
        opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
        gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device])
        gpu_prototype.load_state_dict(opt_cpu_dict)
        print(opt_cpu_dict)
        assert_opt_state_in_expected_location(gpu_prototype, optimizer_on_device[gpu_device])

        # Use from_dict with gpu prototype, fused = True
        opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
        gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
        gpu_prototype.load_state_dict(opt_cpu_dict)
        assert_opt_state_in_expected_location(
            gpu_prototype, optimizer_on_device[gpu_device]
        )  # fused=False from CPU, overrides prototype

        # CPU prototypes
        # Use from_dict with cpu prototype, fused = False
        opt_gpu_dict = optimizer_on_device[gpu_device].state_dict()
        cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
        cpu_prototype.load_state_dict(opt_gpu_dict)
        assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])

        # Use from_dict with cpu prototype, fused = True
        opt_gpu_dict = optimizer_on_device[gpu_device + "_fused_True"].state_dict()
        cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
        cpu_prototype.load_state_dict(opt_gpu_dict)  # This should give an error / refuse to allow fused = True

Contributor:

FYI, for older versions of torch this would indeed not be allowed and would break. But since torch 2.4, there is a fused CPU Adam(W)/SGD/Adagrad, so fused=True on CPU for these optimizers would be valid.
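
For reference, a quick sketch of that (not part of this diff; whether fused=True is accepted for CPU parameters depends on the installed torch version):

import torch

model = torch.nn.Linear(4, 2)  # CPU parameters
try:
    opt = torch.optim.Adam(model.parameters(), lr=0.1, fused=True)
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # succeeds on torch versions that ship a fused CPU kernel
except RuntimeError:
    print("fused=True is not supported on CPU for this torch version")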

        assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])


def assert_opt_state_in_expected_location(opt, expected_opt):
    opt_dict = opt.state_dict()
    expected_opt_dict = expected_opt.state_dict()
    for key, param in opt_dict["state"].items():
        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
            pytest.fail(f"Optimizer device mismatch for state[{key}]")
        elif isinstance(param, collections.abc.Mapping):
            for subparam in param.values():
                if isinstance(subparam, Tensor):
                    assert param.data.device.type == device
            for subkey, subparam in param.items():
                if (
                    isinstance(subparam, Tensor)
                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
                ):
                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")
140 changes: 140 additions & 0 deletions tests/tests_pytorch/checkpointing/test_trainer_move_device.py
@@ -0,0 +1,140 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
from typing import Any

import lightning.pytorch as pl
import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from torch import Tensor


class TrainerStateChecker(Callback):
    def __init__(
        self,
        optimizer_dict: dict,
        model_dict: dict,
        capture: bool,
        target_device: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.optimizer_dict = optimizer_dict
        self.model_dict = model_dict
        self.capture = capture
        self.target_device = target_device

    def on_train_start(self, trainer, pl_module):
        if not self.capture:
            # Check model and optimizer device locations
            assert trainer.model.device == self.model_dict[self.target_device].device
            assert_opt_state_in_expected_location(trainer.optimizers[0], self.optimizer_dict[self.target_device])

    def on_train_end(self, trainer, pl_module):
        if self.capture:
            # Capture the optimizer state before it is transferred back to the cpu
            self.optimizer_dict[self.target_device] = copy.deepcopy(trainer.optimizers[0])
            self.model_dict[self.target_device] = copy.deepcopy(trainer.model)


def assert_opt_state_in_expected_location(opt, expected_opt):
    opt_dict = opt.state_dict()
    expected_opt_dict = expected_opt.state_dict()
    for key, param in opt_dict["state"].items():
        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
            pytest.fail(f"Optimizer device mismatch for state[{key}]")
        elif isinstance(param, collections.abc.Mapping):
            for subkey, subparam in param.items():
                if (
                    isinstance(subparam, Tensor)
                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
                ):
                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")


def test_change_device(tmpdir):
    """This test validates that a generated ModelCheckpoint can be moved to a different device."""

    class ExtendedBoringModel(BoringModel):
        def __init__(
            self,
            target_device: str,
            **kwargs: Any,
        ) -> None:
            super().__init__(**kwargs)
            self.target_device = target_device

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.01)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def validation_step(self, batch, batch_idx):
            loss = self.step(batch)
            self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    devices = ["cpu", "gpu"]
    optimizer_dict = {}
    model_dict = {}
    checkpoint_path = {}
    for device in devices:
        device_path = tmpdir.mkdir(device)
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss", dirpath=device_path, filename="{epoch:02d}", save_top_k=-1
        )

        tsc = TrainerStateChecker(
            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=True, target_device=device
        )
        trainer = Trainer(
            accelerator=device,
            devices=1,
            default_root_dir=device_path,
            max_epochs=1,
            limit_train_batches=12,
            limit_val_batches=6,
            limit_test_batches=12,
            callbacks=[checkpoint_callback, tsc],
            logger=False,
        )
        model = ExtendedBoringModel(device)
        trainer.fit(model)
        checkpoint_path[device] = checkpoint_callback.best_model_path

    # cross load from checkpoint
    trainer_resume_dict = {}
    for device_idx, device in enumerate(devices):
        cross_device = devices[(device_idx + 1) % len(devices)]
        tsc = TrainerStateChecker(
            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=False, target_device=cross_device
        )
        trainer = pl.Trainer(
            accelerator=cross_device,
            devices=1,
            default_root_dir=tmpdir,
            max_epochs=3,
            limit_train_batches=12,
            limit_val_batches=12,
            limit_test_batches=12,
            enable_progress_bar=False,
            callbacks=tsc,
        )
        model = ExtendedBoringModel(cross_device)
        trainer.fit(model, ckpt_path=checkpoint_path[device])  # Load checkpoint from original device
        trainer.test()
        trainer_resume_dict[cross_device] = trainer