Adam optimizer is slower after loading model from checkpoint #19955

Closed
radomirgr opened this issue Jun 7, 2024 · 27 comments · Fixed by #20019 · May be fixed by #20062

@radomirgr commented Jun 7, 2024

Bug description

When resuming training from a checkpoint, I noticed a drop in GPU utilization. I found that Adam performs a CUDA sync after the model is restored from a checkpoint. This is a problem if you have many optimizers in your network.

The Adam implementation assumes that the step component of the optimizer state is a CPU tensor. The assumption is made here, and that code is executed in Adam here.

The problem is that Lightning moves the entire optimizer state to the GPU here.

My current workaround is:

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        print("training_step")
        optimizer = self.optimizers()
        for _, vv in optimizer.state.items():
            if "step" in vv and vv["step"].device.type == "cuda":
                vv["step"] = vv["step"].cpu()

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os
from typing import Any, Tuple

import lightning.pytorch as plight
import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

num_features = 6875
num_responses = 7
batch_size = 32768


class CachedRandomTensorDataset(torch.utils.data.Dataset):
    """Very low overhead torch dataset for training for a given number of steps"""

    def __init__(self, batch_size: int, num_features: int, num_responses: int, length: int) -> None:
        self.x = torch.randn((batch_size, num_features))
        self.y = torch.randn((batch_size, num_responses))
        self.length = length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.x.clone(), self.y.clone()

    def __len__(self) -> int:
        return self.length


dataset = CachedRandomTensorDataset(
    num_features=num_features,
    num_responses=num_responses,
    length=1013,
    batch_size=batch_size,
)

train_dataloader = DataLoader(dataset, batch_size=None, pin_memory=False, num_workers=0, shuffle=False)


class MLP(nn.Module):

    def __init__(
        self,
        in_dim,
        hidden_dim,
        out_dim,
    ):
        super().__init__()
        self.layers = len(hidden_dim)
        self.LinearClass = nn.Linear
        self.activation_fn = nn.ReLU()
        module_dict = {}
        for i in range(self.layers):
            layer_input_size = in_dim if i == 0 else hidden_dim[i - 1]
            module_dict[f"layer_{i}"] = nn.Linear(layer_input_size, hidden_dim[i])
        module_dict["last_linear"] = nn.Linear(hidden_dim[-1], out_dim)
        self.module_dict = nn.ModuleDict(module_dict)

    def forward(self, x):
        for i in range(self.layers):
            x = self.module_dict[f"layer_{i}"](x)
            x = self.activation_fn(x)
        yhat = self.module_dict["last_linear"](x)
        return yhat


class TestNetwork(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        num_it: int,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.automatic_optimization = False
        self.model = model
        self.mse = nn.MSELoss()
        self.num_it = num_it

    def configure_optimizers(self, name=None):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        print("training_step")
        optimizer = self.optimizers()

        for _ in range(self.num_it):
            torch.cuda.nvtx.range_push("it step")
            x, y = batch
            yhat = self.model.forward(x)
            loss = self.mse(yhat, y)

            optimizer.zero_grad()
            self.manual_backward(loss)
            torch.cuda.nvtx.range_push("optimizer")
            optimizer.step()
            torch.cuda.nvtx.range_pop()

            torch.cuda.nvtx.range_pop()


train_model = TestNetwork(
    MLP(
        num_features,
        [2048, 1024, 512, 256],
        num_responses,
    ),
    200,
)

trainer_max_steps = 200
checkpoint_name = "debug3"
checkpoint_dir = "./model_checkpoint"
ckpt_path = f"{checkpoint_dir}/{checkpoint_name}-step={trainer_max_steps}.ckpt"

if os.path.isfile(ckpt_path):
    print("training from checkpoint")
    trainer_max_steps = trainer_max_steps + 1
else:
    print("training new model")
    ckpt_path = None


checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=10,
    monitor="step",
    mode="max",
    filename=checkpoint_name + "-{step:02d}",
    every_n_train_steps=100,
)


# TRAINER CREATION
trainer = plight.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    max_steps=trainer_max_steps,
    max_epochs=1,
    log_every_n_steps=50,
    logger=[],
    enable_progress_bar=True,
    enable_checkpointing=True,
    enable_model_summary=True,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=None,
    callbacks=[checkpoint_callback],
)

torch.cuda.set_sync_debug_mode(1)

trainer.fit(
    train_model,
    train_dataloader,
    ckpt_path=ckpt_path,
)

Error messages and logs


Below are some nsys traces:
[nsys trace screenshots]

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA A100-SXM4-80GB
    • available: True
    • version: 12.1
  • Lightning:
    • gpytorch: 1.11
    • lightning: 2.2.5
    • lightning-utilities: 0.11.2
    • pytorch-lightning: 2.2.5
    • torch: 2.3.1
    • torchinfo: 1.8.0
    • torchmetrics: 1.3.1
    • torchtyping: 0.1.4
    • torchvision: 0.18.0
    • torchviz: 0.0.2
  • System:

More info

No response

cc @Borda

@radomirgr radomirgr added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jun 7, 2024
@awaelchli (Contributor)

Hey @radomirgr
Thanks for the investigation.

The Adam implementation assumes that the step component of the optimizer state is a CPU tensor. The assumption is made here, and that code is executed in Adam here.

These links might have pointed to an earlier version but now they don't seem to show the place that you meant. Could you show me where in the PyTorch code this assumption is made?

I don't remember exactly why we needed the optimizer_to_device function.

@radomirgr (Author)

Here are the screenshots:

[screenshots of the relevant PyTorch Adam code]

optimizer_to_device is needed because torch optimizers don't have a .to(device) method, and you need to move the optimizer state to the GPU yourself. There is an issue for that here: pytorch/pytorch#8741

It might be solved by adding if param._grad is not None: into the code, but I'm not sure.

@awaelchli awaelchli added optimization performance help wanted Open to be worked on and removed needs triage Waiting to be triaged by maintainers labels Jun 21, 2024
@janeyx99 (Contributor)

For performance reasons, PyTorch intentionally places these scalar tensors on the CPU unless compile/capturable is needed. Executing Python math is faster and more precise than calling into a kernel, and here we want the calculations with step to be fast.

Is there a reason lightning moves everything to GPU?

@corwinjoy (Contributor) commented Jun 21, 2024

I can confirm this issue. What happens during a checkpoint is that the optimizer param state is stored (including its CPU or GPU location). But then, when Lightning reloads the state, it forces everything onto the GPU:
pytorch-lightning/src/lightning/fabric/utilities/optimizer.py:32

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    for p, v in optimizer.state.items():
        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)

This causes a problem because the Adam optimizer explicitly expects 'step' to be on the CPU, @janeyx99:
torch/optim/adam.py:103

                if len(state) == 0:
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state['step'] = (
                        torch.zeros((), dtype=_get_scalar_dtype(is_fused=group['fused']), device=p.device)
                        if group['capturable'] or group['fused']
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )

When I run the above example code (to resume after a checkpoint) under nvidia nsight I can see that it forces many copies of step from the GPU to the CPU where the algorithm expects it:

'step' parameter on GPU:
nsys profile --stats=true /home/cjoy/src/adam_gpu/.venv/bin/python /home/cjoy/src/adam_gpu/src/test.py

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)  Min (ns)   Max (ns)   StdDev (ns)            Operation          
 --------  ---------------  -----  -----------  --------  --------  ----------  ------------  ----------------------------
     61.1      133,934,385  4,094     32,714.8   1,344.0       992  18,373,539     698,332.9  [CUDA memcpy Device-to-Host]
     38.0       83,249,648     44  1,892,037.5     607.5       415  67,803,752  10,226,160.4  [CUDA memcpy Host-to-Device]
      0.9        1,964,303  2,000        982.2     991.0       416       1,857         169.2  [CUDA memset]        

I see a total of 4094 copies from the device to the host. In contrast, if after a checkpoint restore we leave 'step' on the CPU we get only 74 copies:

'step' parameter on CPU:
nsys profile --stats=true /home/cjoy/src/adam_gpu/.venv/bin/python /home/cjoy/src/adam_gpu/src/test.py

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)    Med (ns)  Min (ns)   Max (ns)   StdDev (ns)            Operation          
 --------  ---------------  -----  -----------  --------  --------  ----------  ------------  ----------------------------
     60.7      131,468,535     74  1,776,601.8   1,488.0       992  18,964,694   5,054,946.2  [CUDA memcpy Device-to-Host]
     38.4       83,193,746     34  2,446,874.9     815.5       416  67,839,898  11,619,734.7  [CUDA memcpy Host-to-Device]
      0.9        1,935,397  2,000        967.7     991.0       415       4,704         186.0  [CUDA memset]               

This large number of transfers doesn't take a long time if you have a monopoly on the device. But, if you are sharing a device all these transfers can be a bottleneck. (These copies are forcing stream synchronization events). Tracing via tensorboard the underlying operation that is forcing this transfer is aten::_local_scalar_dense but I was having trouble getting stack tracing to work to see where this happens in the Adam algorithm. (I guess this is happening during _get_value(step) as mentioned above: https://github.com/pytorch/pytorch/blob/1c75ddff3576a0bd7ed664476c49020c35875ab5/torch/optim/adam.py#L417)

Basically, the pytorch lightning logic that blindly forces params onto the device is incorrect. Different algorithms may have different needs. Essentially, pytorch lightning is messing with the internal state of the model and making incorrect assumptions.
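
To make the cost concrete, here is a minimal sketch (not Lightning code, just an illustration): reading a scalar tensor that lives on the GPU, which is roughly what _get_value/.item() does with 'step', forces a device-to-host copy, while reading a CPU scalar does not:

import torch

torch.cuda.set_sync_debug_mode(1)  # warn on implicit CPU-GPU synchronizations

step_cpu = torch.tensor(5.0)                 # how non-fused Adam hosts `step` by default
step_gpu = torch.tensor(5.0, device="cuda")  # what we end up with after _optimizer_to_device

step_cpu.item()  # plain CPU read, no synchronization
step_gpu.item()  # aten::_local_scalar_dense -> Device-to-Host copy plus a sync warning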

@corwinjoy (Contributor) commented Jun 21, 2024

One idea for a fix would be to add special handling based on the optimizer class, but it's a bit ugly.
Replace:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:

With:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    if isinstance(optimizer, Adam):
        # Special logic for Adam optimizer
        # The 'step' parameter needs to remain on the CPU since that is where the optimizer needs it.
        for p, v in optimizer.state.items():
            for key, val in v.items():
                if key != 'step':
                    v[key] = move_data_to_device(val, device)
    else:
        for p, v in optimizer.state.items():
            optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)

A better idea could be to push for optimizers to have a 'to' method to map them onto the device. This has been discussed in torch before along with how awkward it is to map optimizers to devices but the request doesn't seem to have much traction.
pytorch/pytorch#8741

Maybe there is a way to copy construct the optimizer and get the correct device assignments? But, I don't see how to do it.

A third idea could be to look at the params dictionary and see whether the tensor was on the CPU or GPU, but I think that would get very flaky for remappings. E.g. you might start training on a CPU but then resume on a GPU.

As an aside, @radomirgr, another option for the Adam optimizer might be to use the Adam parameter fused=True. Then it expects all the params to be on the GPU. In theory I think this could work, but when I tried it I still saw a bunch of forced copies from the GPU to CPU, and I'm not sure why.
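
For completeness, a quick sketch (plain torch, no Lightning) of what fused=True changes: the fused path hosts 'step' on the same device as the params, so in principle no 'step' transfer should be needed:

import torch

model = torch.nn.Linear(4, 4, device="cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
model(torch.randn(2, 4, device="cuda")).sum().backward()
optimizer.step()

state = next(iter(optimizer.state.values()))
print(state["step"].device)  # cuda:0 -> fused Adam keeps `step` on the device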

@janeyx99 (Contributor) commented Jun 24, 2024

I'm coming into this naively, but it looks like an equivalent of the _optimizer_to_device function is calling load_state_dict(...) on a new optimizer whose parameters are already on the device. More concretely, do the following when restoring from a checkpoint:

...
model.load_state_dict(checkpointed_model)
model.to(device="cuda")  # device could be anything here

# so now all params are on the desired target device

optimizer = torch.optim.AdamW(model.parameters(), ...)
optimizer.load_state_dict(checkpointed_optim)

# this should correctly set up step on CPU and move the proper state to CUDA

Is there a reason the above would not be viable?

Tangentially, using fused=True would bypass this problem as it expects the step to be on CUDA, so @corwinjoy I am surprised to find that there are still forced copies from GPU to CPU. Are you on the latest torch nightly or an older version? Maybe these syncs have to do with the LRScheduler/lr.

@corwinjoy (Contributor)

@janeyx99 So, as I understand it, the reason for the function _optimizer_to_device is that after checkpointing we may need to resume on a different device. So, we may start training on the CPU but then want to resume on the GPU. Or, we might start training on GPU0 but then need to resume on GPU1. So, this function supports remapping the device, as I understand it. In the main load from checkpoint function I actually think it does optimizer.load_state_dict(checkpointed_optim) but then later does this remapping. (The remapping is needed because the tensor locations in the checkpoint may not be where we need the tensors to be.)

In addition, I also agree with you that fused=True should bypass this problem, but it doesn't in the version of torch I am using. Here I am using the most recent release from PyPI, torch==2.3.1. I'm not sure why the extra copies are happening; tensorboard stack generation seems to be broken in the latest version of torch, so I haven't been able to trace it.

Anyway, so that's why _optimizer_to_device exists, as I understand it. Therefore, it needs to be able to do device remapping more intelligently.

@janeyx99 (Contributor)

Yes, I understand the need to load on distinct devices, but my code snippet should still work for that. As long as one creates an optimizer referencing parameters that are on the desired device (CUDA1 or CUDA or even CPU), load_state_dict should automatically move that state to the corresponding device. The code that does that is https://github.com/pytorch/pytorch/blob/main/torch/optim/optimizer.py#L727.
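
Here is a rough sketch of what I mean (non-fused, non-capturable Adam; illustrative, not Lightning code): state loaded from a CPU checkpoint is moved to match the device of the parameters the new optimizer was built with, while step stays on CPU:

import torch

def make(device):
    model = torch.nn.Linear(4, 4, device=device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, 4, device=device)).sum().backward()
    opt.step()  # populate exp_avg / exp_avg_sq / step
    return opt

ckpt = make("cpu").state_dict()  # state produced on CPU, as in a checkpoint

cuda_opt = make("cuda")          # fresh optimizer whose params already live on the target device
cuda_opt.load_state_dict(ckpt)

state = next(iter(cuda_opt.state.values()))
print(state["exp_avg"].device)   # cuda:0 -> moved to match the param
print(state["step"].device)      # cpu    -> deliberately left on CPU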

It feels that doing both a checkpoint + then a move is redundant.

For the fused=True still having copies--once you get more details, please feel free to open an issue in pytorch/pytorch!

@corwinjoy (Contributor)

@janeyx99 Thanks! That's actually an interesting idea. I think my caveat here is that we cannot create the optimizer directly, since we (generically) have only the base Optimizer class (and the concrete class is loaded via pickle). But I think we could use your idea, something like this:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    sd = optimizer.state_dict()
    for p, v in sd.items():
        sd[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
        
    # Special logic, use load_state_dict method which can correctly migrate the desired tensor device state
    optimizer.load_state_dict(sd)

What do you think? Unless you had some other way to do this? If so, I would like to see that code since I don't understand how that would work generically.

@corwinjoy (Contributor)

OK. Doing further testing, unfortunately, the idea of using load_state_dict does not work. The special logic in there merely leaves the 'step' parameter as-is if we are not using 'fused=True'. So, no matter what, it seems we have to add special logic for the 'step' parameter to this routine. I have put in a PR to do this (#20019) in the simplest way I could and added a link to the related PyTorch issue.
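
To illustrate the "leaves 'step' as-is" behavior, here is a small sketch (plain torch, non-fused Adam assumed): if 'step' has already been pushed to the GPU before load_state_dict is called, it stays there rather than being moved back to the CPU:

import torch

def make():
    model = torch.nn.Linear(4, 4, device="cuda")
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(2, 4, device="cuda")).sum().backward()
    opt.step()
    return opt

sd = make().state_dict()
# emulate what _optimizer_to_device currently does: move everything, including `step`, to the GPU
for s in sd["state"].values():
    for k, v in s.items():
        s[k] = v.cuda()

opt = make()
opt.load_state_dict(sd)
print(next(iter(opt.state.values()))["step"].device)  # cuda:0 -> `step` is left where it was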

@janeyx99 (Contributor)

Hm, maybe I am not understanding the use case correctly. I thought the optimizer_to_device function attempts to move all the optimizer state to the device that the parameters are on. So if the desired device is denoted DEVICE, what I would expect when calling _optimizer_to_device(optimizer, DEVICE) is that every state entry in the optimizer except step goes to DEVICE, while step is left on the previous device, which, in your use case, should be CPU.

Here is an explicit way to rewrite the optimizer_to_device function, but I am confused about why the input optimizer is already, but incorrectly, populated:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    mismatching_sd = optimizer.state_dict()
    # gather the actual parameters -- is it correct to assume that these are already on DEVICE?
    params = [p for group in optimizer.param_groups for p in group["params"]]
    optimizer_with_matching_state = optimizer.__class__(params)
    optimizer_with_matching_state.load_state_dict(mismatching_sd)  # this should move the mismatching state to DEVICE without touching step

    # load the state back into the original optimizer
    optimizer.load_state_dict(optimizer_with_matching_state.state_dict())

So the above should work with any optimizer generically, but it is very roundabout because it is confusing to me why there is an optimizer input with mismatching state in the first place.

Instead, what I would expect in a use case is for the optimizer to be correctly loaded during checkpointing through load_state_dict, without needing this move to device function at all. The code for that would look more like my previous comment.

@corwinjoy (Contributor)

@janeyx99 I'm still a bit new to all this, but here is what I see in the stack trace when debugging a restore from checkpoint (as per the above code). You have to look at the second call to _optimizer_to_device because the first is not used.

Stack:
_optimizer_to_device, optimizer.py:32
load_optimizer_state_dict, strategy.py:377
restore_optimizers, checkpoint_connector.py:383
restore_optimizers_and_schedulers, checkpoint_connector.py:368
restore_training_state, checkpoint_connector.py:298
_run, trainer.py:977
_fit_impl, trainer.py:579
_call_and_handle_interrupt, call.py:47
fit, trainer.py:543
<module>, test.py:163

...
    def restore_optimizers(self) -> None:
        """Restores the optimizer states from the pre-loaded checkpoint."""
        if not self._loaded_checkpoint:
            return

        # restore the optimizers
        self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        optimizer_states = checkpoint["optimizer_states"]
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)
            _optimizer_to_device(optimizer, self.root_device)

    def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    ...

Looking at the tensors from the checkpoint, they do have the right locations before _optimizer_to_device is called. That is, from the pickle, step is on the CPU and the other entries in optimizer.state[param] are on the CPU. But, looking at the function load_optimizer_state_dict in strategy.py, there is a potential remapping that can happen based on the device strategy (seen here as self.root_device). So, e.g., training may be started on a CPU-only machine but then we may want to resume on a GPU-enabled device. So, I believe the point of _optimizer_to_device is to be able to move an optimizer onto a device. As for the code you gave above, I don't understand it. For example, I don't see how it would be able to move an optimizer that originally has all CPU tensors to an optimizer with (some) GPU tensors.

Also, knowing that optimizer.load_state_dict has some move capability, maybe the correct thing to do here is rewrite things at the higher level. That is, rewrite load_optimizer_state_dict.

@awaelchli (Contributor)

Hey everyone
Great discussion!
I also want to leave a couple remarks.

  1. Not sure if you've found this, but here is the original PR where I added this function (it was named differently 3 years ago): Fix Adagrad optimizer not working with DDP/GPU #7277. It's entirely possible that it was just a naive thing to do in the first place. But this should give some context as to why we thought it was needed.

  2. Beyond that first point, if we're not 100% convinced what this function is there for, a simple approach could be to remove all calls to it in the code base, submit a PR to the repo and then we'll let the entire test suite run. This will make certain tests fail and then we can understand the edge cases.

  3. I quickly looked and found that Fabric (under src/lightning/fabric) does not use this function at all. One thing to try would be to replicate your minimal repro (nice that you provided this, thanks!) using Lightning Fabric instead of the Trainer, to show that loading happens as expected without a performance regression, and that resuming a cpu-trained checkpoint on GPU (or vice versa) works as expected.

@janeyx99 (Contributor)

Ah, thanks @awaelchli and @corwinjoy for the context.

I see the original problem this function sought to solve was that the model parameters shifted under a created optimizer, causing the mismatch in devices for parameter and optimizer state. Here, the solution should not be to move the optimizer, but to wait til the model has been moved to its final location and then to create the optimizer. If that's not possible, reloading the state dict into a new optimizer with the final parameters would also work. I would suggest the cleaner solution of maintaining the invariant that the optimizer should be created after the model is done being modified, to ensure that the latest parameters are what get optimized. Without this invariant, it's easy to get into a wild goose chase of problems like this that crop up due to mismatch.

@corwinjoy The reason it works is because load_state_dict will move state to match the parameter that is passed into the optimizer--there is already code in there to cast/move state appropriately for each optimizer, so the work should not need to be duplicated. Feel free to follow up if you have more questions--I am increasingly convinced that the spot-solution of patching the function for this issue is at best only a temporary one.

@corwinjoy (Contributor)

In order to move the discussion forward, I have created a PR where this function is simply disabled to see what tests fail. It is at #20036

Before we move forward, I believe we should agree on what the behavior here should be. I think that the test for the function in tests/tests_fabric/utilities/test_optimizer.py should be changed to create the optimizer on the CPU and then on the GPU. The parameters should then be compared to check that moving CPU->GPU gives the same locations as creating the optimizer on the GPU.

@janeyx99
I would love to see a fix using load_state_dict but I don't see how to do it for two reasons:

  1. load_state_dict matches the device that is in the parameters section of the optimizer, not the model location. So just reloading the model onto the GPU does not help us, load_state_dict will continue to use the last location the optimizer (not the model) was running.
  2. I feel that rebuilding the optimizer from the latest parameter state extracted from the model will also not work. The optimizer contains a history of the parameters tried. For many optimizers (such as a Bayesian optimizer) this history is quite important and only using the latest state will not work correctly.

But, I would be happy to be proven wrong. If there is a clean way to use load_state_dict I would like to see it. But, I think it will be tricky. Maybe use the model as a prototype?

@janeyx99 (Contributor) commented Jul 2, 2024

I don't understand point 2 -> why would using the latest state not work correctly or be different from the current implementation? When one reloads from a checkpoint, one has to start with a fresh optimizer instance, no? I still think the strongest, sturdiest solution to push for is to ensure that the optimizer state is generated with the latest parameters.

@corwinjoy (Contributor)

Because the optimizer doesn't just hold a single set of parameters. Instead, it holds an array of parameters indicating model parameters that were tried. So, the optimizer restore holds an array of parameters that need to be converted.

@janeyx99 (Contributor) commented Jul 8, 2024

Is there any reason the optimizer needs to hold these parameters for some time before moving them? Regardless, the parameters are only referenced in the optimizer, as the optimizer object only holds tensors for the state. In this use case, it looks like the old state on the checkpointed device is never used, so it should really never be created. Instead, my understanding is that state should only be created for the parameters that matter, which would be the latest set of parameters.

Regardless, I think your approach of removing all accesses to the _optimizer_to_device function is a good place to start--then we can talk about actual problems we don't expect in the code.

@corwinjoy (Contributor)

OK. To answer these questions I have submitted the following PR with an improved test for _optimizer_to_device.
#20062

  1. Looking at the tests from simply disabling _optimizer_to_device we don't see a lot of errors. But, I think that is incorrect. I think existing tests don't properly test restoring a checkpoint and resuming on a different device.
  2. In the above PR I have added test cases for going from CPU-->GPU and GPU-->CPU both with and without the fused=True flag to more carefully delineate the cases.
  3. I feel like the best approach may be to specialize _optimizer_to_device even if some special logic is required. This is what is shown in the above PR.
  4. Using Optimizer.load_state_dict() could work in theory, and I have some test cases for this, but I see two problems:
    a. To use it, you would need a prototype of the Optimizer on the target device. This prototype needs to contain params on the correct target device as previously noted in _process_value_according_to_param_policy. (https://github.com/pytorch/pytorch/blob/e836ee19554bec725810ab682e51a64d8869fcf0/torch/optim/optimizer.py#L727).
    I don't see how you create this cleanly since all you may have is a pickle from the source device.
    b. It seems that Optimizer.load_state_dict() may be broken when loading the fused=True case; see the above PR, where it tries to allow fused=True on the CPU, which is incorrect.

I'm hoping this code makes it clearer what is going on and where the issue lies.
Also, @janeyx99, in terms of the optimizer, the history I am talking about in the optimizer is the parameter 'state' history. So, e.g. from optimizer.state_dict() in the above PR:

{'state': 
   {0: {'step': tensor(256.), 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)}, 
    1: {'step': tensor(256.), 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)}},
 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, '...}]} 

Also see the docs for state_dict() (https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.state_dict.html#torch.optim.Optimizer.state_dict)
We should assume that the state may contain history over a number of steps that the optimizer may need. So, it is not just a matter of only using the latest state / set of parameters. Instead, every element in the Optimizer.state array needs to go to the correct device to facilitate optimizer calculations. Here the problem shown in this bug report is that the current function places Optimizer.state[0..1]['step'] on the GPU rather than the CPU which causes problems with Adam and other algorithms.

@awaelchli (Contributor)

I think existing tests don't properly test restoring a checkpoint and resuming on a different device.

We should have them, it's just probably hard to find them haha. But we can also add such a test.

So you'd suggest not removing the function? That's ok with me if PyTorch's load_state_dict is not enough. After this, we should probably try to push for the optimizer.to() feature request (or something similar) on the PyTorch side, as mentioned before.

@awaelchli awaelchli added this to the 2.4 milestone Jul 16, 2024
@janeyx99 (Contributor)

  1. In the above PR I have added test cases for going from CPU-->GPU and GPU-->CPU both with and without the fused=True flag to more carefully delineate the cases.

I commented on the PR above as well, but here's how we test this use case in PyTorch, in case it helps: https://github.com/pytorch/pytorch/blob/main/test/test_optim.py#L1545-L1574

  1. Using Optimizer.load_state_dict() could work in theory, and I have some test cases for this, but I see two problems:
    a. To use it, you would need a prototype of the Optimizer on the target device. This prototype needs to contain params on the correct target device as previously noted in _process_value_according_to_param_policy. (pytorch/pytorch@e836ee1/torch/optim/optimizer.py#L727).
    I don't see how you create this cleanly since all you may have is a pickle from the source device.

Here is how I imagine checkpointing should go:

  1. Load the pickle for the nn module (on CPU)
  2. Load the model state dict into a model.
  3. Move the model to desired device (say GPU)
  4. Now create an optimizer with model.parameters(), no earlier!
  5. Load the pickle for the optim (on CPU)
  6. Call optim.load_state_dict(optim_pickled_dict) => this should move the CPU state (what you called "history") properly to GPU.
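
In code, a rough sketch of those steps (the model and file names here are just placeholders):

import torch

# 1-3: load the model checkpoint on CPU, restore the weights, then move the model
model = torch.nn.Linear(4, 4)  # stand-in for the real module
model.load_state_dict(torch.load("model.ckpt", map_location="cpu"))
model.to("cuda")

# 4: only now create the optimizer, so it references parameters on the final device
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 5-6: load the optimizer checkpoint on CPU; load_state_dict moves per-param state
# to CUDA while leaving `step` on CPU (non-fused, non-capturable case)
optimizer.load_state_dict(torch.load("optim.ckpt", map_location="cpu"))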

b. It seems that Optimizer.load_state_dict() may be broken when loading the fused=True case; see the above PR, where it tries to allow fused=True on the CPU, which is incorrect.

Could you link me to the specific part that is broken? Is it a test failure or something in the code? By the way, we now have fused adam, adamw, sgd, and adagrad on CPU! So that could be related.

I'm hoping this code makes it clearer what is going on and where the issue lies. Also, @janeyx99, in terms of the optimizer, the history I am talking about in the optimizer is the parameter 'state' history. So, e.g. from optimizer.state_dict() in the above PR:

{'state': 
   {0: {'step': tensor(256.), 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)}, 
    1: {'step': tensor(256.), 'exp_avg': tensor(...), 'exp_avg_sq': tensor(...)}},
 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, '...}]} 

Also see the docs for state_dict() (pytorch.org/docs/stable/generated/torch.optim.Optimizer.state_dict.html) We should assume that the state may contain history over a number of steps that the optimizer may need. So, it is not just a matter of only using the latest state / set of parameters. Instead, every element in the Optimizer.state array needs to go to the correct device to facilitate optimizer calculations. Here the problem shown in this bug report is that the current function places Optimizer.state[0..1]['step'] on the GPU rather than the CPU which causes problems with Adam and other algorithms.

Ah yes, thank you for clarifying that by history you mean the param state! I agree! This is precisely what load_state_dict should handle.

@janeyx99 (Contributor)

@corwinjoy @awaelchli I still believe load_state_dict should be sufficient for this use case. I've tried addressing the concerns above--please point me to where specialization is needed beyond the use case I delineated above. Thank you both for the detailed discussion.

@corwinjoy (Contributor)

@awaelchli @janeyx99 OK. I have done further investigation and added additional comments + tests to #20062
These include specific tests for _optimizer_to_device behavior, as well as tests confirming that checkpoints are correctly moved between devices. As @janeyx99 points out, some of these may be redundant, but I'm not really sure about that.
I think it is now closer to a PR we can use. It took me a while to trace and understand the pytorch-lightning logic but I now see that:

  1. The way that checkpoint restores work is they do in fact create an optimizer prototype on the device and then call Optimizer.load_state_dict() to transfer the state to the device. So, for construction, a manual transfer may be unnecessary.

But, just eliminating this function does create a couple potential problems that I am not sure I understand and would like a review on.

  1. This function is also called by Strategy.teardown(). The idea is to transfer the optimizer back from the GPU to the CPU. Maybe we don't really care about this, since when the fit is complete the optimizer will either be checkpointed or discarded. The question is whether users really depend on this final transfer behavior, but I would guess not.
  2. In the tests I added, test_load_state_dict, it is possible to use a [GPU, fused=True] state dictionary to populate a CPU optimizer with the value of fused=True. This is technically an invalid object state and the constructor will reject this but load_state_dict allows it. This is probably a bug but I guess it works in our favor since then checkpoints with fused=True can more easily be moved CPU<-->GPU. But, it could potentially cause erroneous behavior.

Anyway, I guess we could replace this function with a no-op with some comments explaining the behavior because I think it is rather non-obvious. Here is what I have in the PR right now for _optimizer_to_device. This explains what I said above in more detail:

def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device.

    In fact, it looks like we don't need this function but can rely on optimizer.load_state_dict to do the right thing
    once given a correct prototype on the target device. For now we do nothing and assume that we don't care about
    transferring the optimizer back to the CPU on teardown. See details below.

    """
    pass

    # To test for correct behaviour here we have created two tests:
    # 1. tests/tests_fabric/utilities/test_optimizer.py to test Optimizer.load_state_dict with a prototype
    # 2. tests/tests_pytorch/checkpointing/test_trainer_move_device.py to test higher level checkpointing on
    #    one device and resuming on a different device

    # Details on how this function is called.
    # 1st call is in Strategy.setup(), to initialize empty optimizer. src/lightning/pytorch/strategies/strategy.py: 158
    # Note: Strategy.setup() first calls Strategy.setup_optimizers which eventually invokes Model.configure_optimizers()
    # based on a model that has been moved to the device. Thus it essentially creates a prototype optimizer on the
    # target device and then, eventually, relies on Optimizer.load_state_dict() to transfer the state.
    # 2nd call when restoring checkpoint, as part of Strategy.load_optimizer_state_dict(). Source strategy.py: 377
    # Final call in Strategy.teardown(), move optimizer back to CPU. Source strategy.py: 525

@corwinjoy (Contributor)

Eventually, maybe we can rip out _optimizer_to_device but I would favor a more conservative approach at first.

@awaelchli (Contributor) commented Jul 17, 2024

This function is also called by Strategy.teardown(). The idea is to transfer the optimizer back from the GPU to the CPU. Maybe we don't really care about this, since when the fit is complete the optimizer will either be checkpointed or discarded. The question is whether users really depend on this final transfer behavior, but I would guess not.

There was a desire in the past to have the Trainer leave no memory behind and shut down cleanly, so that other workflows after fit() could use all the memory. I think we can keep that separate from what you are fixing.

I'm ok with the plan of first adding the step-specific fix; longer term I'd also prefer removing this function.
I'll set aside time to review your PR.

A big no from me to leaving optimizer_to_device as a no-op just for documentation's sake. These comments are great but could go elsewhere and be encapsulated by tests :)

@awaelchli (Contributor)

I merged the fix and opened #20165 so we can work on removing the function in the future.

@corwinjoy (Contributor)

@awaelchli - thanks so much for the improved and very nice tests! I think this helps clarify the behavior we want. Also, thanks for merging the interim fix so we can see improved performance as we work toward removing the function. Also FYI @radomirgr.
