Test optimizer to device #20062 (Draft)

corwinjoy wants to merge 4 commits into Lightning-AI:master from corwinjoy:test_optimizer_to_device
Commits (changes shown from 3 of 4 commits):
- 17bded8 Add tests showing GPU to CPU copies (corwinjoy)
- cdd1556 Merge branch 'Lightning-AI:master' into test_optimizer_to_device (corwinjoy)
- b700bc1 Add explicit test for loading checkpoint and running on new device (corwinjoy)
- 34c5d97 Add further checkpoint tests and replace _optimizer_to_device with pass (corwinjoy)
Changed file (the fabric optimizer-utility test module), diff hunk @@ -1,36 +1,131 @@:
import collections
import copy
import dataclasses
from typing import Tuple

import pytest
import torch
import torch.nn as nn
from lightning.fabric.utilities.optimizer import _optimizer_to_device
from torch import Tensor
from torch.utils.data import DataLoader


def test_optimizer_to_device_match_locations():
    @dataclasses.dataclass(frozen=True)
    class FooState:
        bar: int

    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Extra state entries: a plain tensor and a frozen dataclass exercise non-parameter state
            self.state["dummy"] = torch.tensor(0)
            self.state["frozen"] = FooState(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)
    _optimizer_to_device(opt, "cpu")  # smoke test: moving custom state entries must not raise

    class CachedRandomTensorDataset(torch.utils.data.Dataset):
        """Very low overhead torch dataset for training for a given number of steps."""

        def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int, device: str) -> None:
            self.x = torch.randn((batch_size, num_features), device=torch.device(device))
            self.y = torch.randn((batch_size, num_responses), device=torch.device(device))
            self.length = length

        def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.x.clone(), self.y.clone()

        def __len__(self) -> int:
            return self.length

    def simple_training(optimizer, model, dataset, loss_fn):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

    if not torch.cuda.is_available():
        return

    gpu_device = "cuda"
    devices = ["cpu", gpu_device]

    num_features = 32
    num_responses = 2
    batch_size = 16

    optimizer_on_device = {}
    for device in devices:
        dataset = CachedRandomTensorDataset(batch_size, num_features, num_responses, batch_size * 16, device)
        dataloader = DataLoader(dataset, batch_size=None, shuffle=False)
        model = torch.nn.Linear(num_features, num_responses)
        model.to(device=device)
        # Fused Adam is only exercised on the GPU here; see the review note below on fused CPU support
        fused_vals = [False] if device == "cpu" else [False, True]

        for fused in fused_vals:
            optimizer = torch.optim.Adam(model.parameters(), lr=0.1, fused=fused)
            simple_training(optimizer, model, dataloader, loss_fn=nn.MSELoss())
            if fused:
                optimizer_on_device[device + "_fused_" + str(fused)] = optimizer
            else:
                optimizer_on_device[device] = optimizer

    # Test the _optimizer_to_device function

    # cpu --> gpu, fused=False
    opt_to_gpu = copy.deepcopy(optimizer_on_device["cpu"])
    _optimizer_to_device(opt_to_gpu, gpu_device)
    assert_opt_state_in_expected_location(opt_to_gpu, optimizer_on_device[gpu_device])

    # gpu --> cpu, fused=False
    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device])
    _optimizer_to_device(opt_to_cpu, "cpu")
    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])

    # gpu --> cpu, fused=True
    opt_to_cpu = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
    _optimizer_to_device(opt_to_cpu, "cpu")
    assert_opt_state_in_expected_location(opt_to_cpu, optimizer_on_device["cpu"])

    # Try load_state_dict with a prototype optimizer on the target device.
    # These cases all pretend that we have an appropriate prototype; in practice we may not be able
    # to build one, since all we may have is a CPU pickle.
    # For now, this is a future idea for _optimizer_to_device.

    if False:
        # GPU prototypes
        # load_state_dict onto a GPU prototype, fused=False
        opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
        gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device])
        gpu_prototype.load_state_dict(opt_cpu_dict)
        assert_opt_state_in_expected_location(gpu_prototype, optimizer_on_device[gpu_device])

        # load_state_dict onto a GPU prototype, fused=True
        opt_cpu_dict = optimizer_on_device["cpu"].state_dict()
        gpu_prototype = copy.deepcopy(optimizer_on_device[gpu_device + "_fused_True"])
        gpu_prototype.load_state_dict(opt_cpu_dict)
        assert_opt_state_in_expected_location(
            gpu_prototype, optimizer_on_device[gpu_device]
        )  # fused=False from the CPU state dict overrides the prototype

        # CPU prototypes
        # load_state_dict onto a CPU prototype, fused=False
        opt_gpu_dict = optimizer_on_device[gpu_device].state_dict()
        cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
        cpu_prototype.load_state_dict(opt_gpu_dict)
        assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])

        # load_state_dict onto a CPU prototype, fused=True
        opt_gpu_dict = optimizer_on_device[gpu_device + "_fused_True"].state_dict()
        cpu_prototype = copy.deepcopy(optimizer_on_device["cpu"])
        cpu_prototype.load_state_dict(opt_gpu_dict)  # This should give an error / refuse to allow fused=True
        assert_opt_state_in_expected_location(cpu_prototype, optimizer_on_device["cpu"])

Review comment (on the load_state_dict line above):
FYI, for older versions of torch this should indeed not be allowed / would break. But since torch 2.4, there is a fused CPU Adam(W)/SGD/Adagrad, so fused=True on CPU for these optimizers would be valid.
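
As a minimal sketch (my illustration, not part of this PR), a test could feature-detect fused CPU support rather than hard-coding fused_vals per device. The helper name supports_fused_cpu_adam is hypothetical, and the sketch relies on exactly the construction-time error the comment above describes:

import torch


def supports_fused_cpu_adam() -> bool:
    # torch >= 2.4 accepts fused=True for CPU parameters; older versions
    # raise at construction time, as the review comment above notes.
    try:
        param = torch.nn.Parameter(torch.zeros(1))  # a CPU parameter
        torch.optim.Adam([param], lr=0.1, fused=True)
        return True
    except (RuntimeError, ValueError, TypeError):
        return False


# e.g.: fused_vals = [False, True] if device != "cpu" or supports_fused_cpu_adam() else [False]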

def assert_opt_state_in_expected_location(opt, expected_opt):
    opt_dict = opt.state_dict()
    expected_opt_dict = expected_opt.state_dict()
    for key, param in opt_dict["state"].items():
        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
            pytest.fail(f"Optimizer device mismatch for state[{key}]")
        elif isinstance(param, collections.abc.Mapping):
            for subkey, subparam in param.items():
                if (
                    isinstance(subparam, Tensor)
                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
                ):
                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")
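
For reference, the utility under test conceptually just relocates every tensor held in the optimizer state. A rough sketch of that idea (my paraphrase, not the actual Fabric implementation, which traverses collections more generally and must leave non-tensor entries such as the FooState above untouched):

import torch


def optimizer_to_device_sketch(optimizer: torch.optim.Optimizer, device: str) -> None:
    # Move tensors in the optimizer state to the target device, leaving
    # non-tensor entries (ints, frozen dataclasses, ...) as-is.
    for key, value in optimizer.state.items():
        if isinstance(value, torch.Tensor):
            optimizer.state[key] = value.to(device)
        elif isinstance(value, dict):
            for subkey, subvalue in value.items():
                if isinstance(subvalue, torch.Tensor):
                    value[subkey] = subvalue.to(device)

As the fused=True cases above illustrate, a blind move like this can leave state somewhere the optimizer implementation does not expect, which is exactly what these tests probe.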
New file: tests/tests_pytorch/checkpointing/test_trainer_move_device.py (140 additions)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
from typing import Any

import lightning.pytorch as pl
import pytest
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from torch import Tensor


class TrainerStateChecker(Callback):
    def __init__(
        self,
        optimizer_dict: dict,
        model_dict: dict,
        capture: bool,
        target_device: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.optimizer_dict = optimizer_dict
        self.model_dict = model_dict
        self.capture = capture
        self.target_device = target_device

    def on_train_start(self, trainer, pl_module):
        if not self.capture:
            # Check model and optimizer device locations
            assert trainer.model.device == self.model_dict[self.target_device].device
            assert_opt_state_in_expected_location(trainer.optimizers[0], self.optimizer_dict[self.target_device])

    def on_train_end(self, trainer, pl_module):
        if self.capture:
            # Capture the optimizer state before it is transferred back to the CPU
            self.optimizer_dict[self.target_device] = copy.deepcopy(trainer.optimizers[0])
            self.model_dict[self.target_device] = copy.deepcopy(trainer.model)


def assert_opt_state_in_expected_location(opt, expected_opt):
    opt_dict = opt.state_dict()
    expected_opt_dict = expected_opt.state_dict()
    for key, param in opt_dict["state"].items():
        if isinstance(param, Tensor) and param.data.device.type != expected_opt_dict["state"][key].device.type:
            pytest.fail(f"Optimizer device mismatch for state[{key}]")
        elif isinstance(param, collections.abc.Mapping):
            for subkey, subparam in param.items():
                if (
                    isinstance(subparam, Tensor)
                    and subparam.data.device.type != expected_opt_dict["state"][key][subkey].device.type
                ):
                    pytest.fail(f"Optimizer device mismatch for state[{key}][{subkey}]")


def test_change_device(tmpdir):
    """Validate that a checkpoint generated on one device can be loaded and trained on a different device."""

    class ExtendedBoringModel(BoringModel):
        def __init__(
            self,
            target_device: str,
            **kwargs: Any,
        ) -> None:
            super().__init__(**kwargs)
            self.target_device = target_device

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.01)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def validation_step(self, batch, batch_idx):
            loss = self.step(batch)
            self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    devices = ["cpu", "gpu"]
    optimizer_dict = {}
    model_dict = {}
    checkpoint_path = {}
    for device in devices:
        device_path = tmpdir.mkdir(device)
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss", dirpath=device_path, filename="{epoch:02d}", save_top_k=-1
        )

        tsc = TrainerStateChecker(
            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=True, target_device=device
        )
        trainer = Trainer(
            accelerator=device,
            devices=1,
            default_root_dir=device_path,
            max_epochs=1,
            limit_train_batches=12,
            limit_val_batches=6,
            limit_test_batches=12,
            callbacks=[checkpoint_callback, tsc],
            logger=False,
        )
        model = ExtendedBoringModel(device)
        trainer.fit(model)
        checkpoint_path[device] = checkpoint_callback.best_model_path

    # Cross-load each checkpoint on the other device
    trainer_resume_dict = {}
    for device_idx, device in enumerate(devices):
        cross_device = devices[(device_idx + 1) % len(devices)]
        tsc = TrainerStateChecker(
            optimizer_dict=optimizer_dict, model_dict=model_dict, capture=False, target_device=cross_device
        )
        trainer = pl.Trainer(
            accelerator=cross_device,
            devices=1,
            default_root_dir=tmpdir,
            max_epochs=3,
            limit_train_batches=12,
            limit_val_batches=12,
            limit_test_batches=12,
            enable_progress_bar=False,
            callbacks=tsc,
        )
        model = ExtendedBoringModel(cross_device)
        trainer.fit(model, ckpt_path=checkpoint_path[device])  # load the checkpoint written on the original device
        trainer.test()
        trainer_resume_dict[cross_device] = trainer
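
For context on the mechanics the cross-load exercises: relocating a checkpoint ultimately rests on torch.load's map_location. A minimal illustration (the file name is hypothetical; "optimizer_states" is the key Lightning uses for the list of optimizer state dicts in its checkpoints):

import torch

# Illustrative only: remap a checkpoint written on GPU onto CPU storage at load time.
ckpt = torch.load("epoch=00.ckpt", map_location="cpu")
optimizer_state = ckpt["optimizer_states"][0]  # one state dict per configured optimizer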
Review comment:
The test we have for PyTorch's load_state_dict being able to read a CPU checkpoint into an appropriate GPU optimizer is here: https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L1545-L1574. The code above is also how I expect checkpointing to happen, without the need for an explicit move to device.
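
A minimal self-contained sketch of the behavior that comment relies on (assuming a CUDA device is available; the model and names are illustrative): PyTorch's Optimizer.load_state_dict casts loaded state tensors to the device of the matching parameters, so CPU state lands on the GPU without an explicit move.

import torch

model_cpu = torch.nn.Linear(4, 2)
opt_cpu = torch.optim.Adam(model_cpu.parameters(), lr=0.1)
model_cpu(torch.randn(1, 4)).sum().backward()
opt_cpu.step()  # populate exp_avg / exp_avg_sq state on the CPU

model_gpu = torch.nn.Linear(4, 2).cuda()
opt_gpu = torch.optim.Adam(model_gpu.parameters(), lr=0.1)
opt_gpu.load_state_dict(opt_cpu.state_dict())  # state tensors follow the parameters' device

first_param = next(iter(model_gpu.parameters()))
assert opt_gpu.state[first_param]["exp_avg"].device.type == "cuda"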
Reply (PR author):
I have updated this test to be more explicit about what is going on. Please take a look and see if it makes sense, since the test you linked doesn't look as thorough, as far as I can tell.
Reply (reviewer):
Yeah, the case we test is moving from CPU to GPU, and I see you test more combinations.