FP8 + FSDP2 + torch.compile examples for PyTorch Lightning and Fabric (#20440)

* Minimal transformer examples
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add tests for compile after fsdp2/tp
* Add README's
* Add docs
* Rename folder, add cross-reference
* Fix link
* Newline after code-block directive
* Update section name
* Fix reference
* Half standalone tests batch size
* Fix integration tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 333d1cf · commit 87565cb · 12 changed files with 592 additions and 33 deletions
@@ -0,0 +1,39 @@
## Distributed, Low-Precision Transformer Example

This example shows how to use `ModelParallelStrategy` in `Fabric` to train a Transformer model while minimizing memory usage, maximizing throughput, and distributing load across multiple GPUs.

### Training Large Models and Memory Requirements

One of the main challenges when training large models, like large language models (LLMs), is dealing with their memory footprint. LLMs can be so large that weights, activations, gradients, and optimizer state don't fit on a single GPU, so they need to be distributed across multiple GPUs, and possibly across multiple machines. There are multiple ways of distributing computation, among them fully-sharded data parallelism (FSDP) and tensor parallelism (TP).

An additional way of reducing memory requirements is to represent floating point numbers in weights and activations in low numerical precision, such as 16-bit (`bfloat16`) or 8-bit (`fp8`). This saves memory as well as memory bandwidth (fewer bytes transferred from device memory to GPU cores per unit time).

Roughly, reducing precision to `fp8` for linear layers can lead to a 2x reduction in memory requirements and a 1.6x improvement in throughput. Support for `fp8` weights and activations requires recent GPUs: Hopper, Ada Lovelace and above (e.g. H100, L4, L40).

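As a quick, optional check (this snippet is not part of the original example), you can query the compute capability of your GPU on a CUDA machine; Ada Lovelace corresponds to compute capability 8.9 and Hopper to 9.0:

```python
import torch

# Ada Lovelace is sm_89, Hopper is sm_90; earlier GPUs lack fp8 tensor cores
major, minor = torch.cuda.get_device_capability()
supported = (major, minor) >= (8, 9)
print(f"compute capability {major}.{minor} -> fp8 {'supported' if supported else 'not supported'}")
```
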
The introduction of tensor subclasses in PyTorch brought two new APIs that can be used in combination to achieve memory savings and distributed training (as well as inference):

- [torch ao](https://github.com/pytorch/ao) to execute linear layers in low numerical precision (`fp8` and other quantized formats)
- [dtensors](https://pytorch.org/docs/stable/distributed.tensor.html) to distribute models across GPUs, by combining TP and FSDP (referred to as FSDP2 in PyTorch)

Notably, `torch ao` introduces quantization and dequantization operations in the model, which may result in slowdowns if left unoptimized. Applying `torch.compile` after `torch ao` recovers performance by generating optimized kernels for those operations.

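As a rough single-GPU sketch of how these two steps compose (the toy `nn.Sequential` model is just a placeholder for illustration and needs an `fp8`-capable GPU; the actual example below adds FSDP2 through `Fabric`):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# toy model standing in for a transformer block; dimensions divisible by 16
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).to("cuda")

# 1) swap nn.Linear layers for float8 training variants (inserts quant/dequant ops)
convert_to_float8_training(model)

# 2) compile so the inserted quantization ops turn into optimized fused kernels
model = torch.compile(model)

loss = model(torch.randn(16, 1024, device="cuda")).sum()
loss.backward()
```
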
### Vanilla Transformer Example

This example shows how to train a vanilla Transformer model using `fp8` precision and the FSDP2 distributed strategy, and then optimize the resulting model through `torch.compile`.

Specifically, we employ the `ModelParallelStrategy` and pass it a `configure_model` function as its `parallelize_fn`, which distributes the model using the PyTorch DTensor API.
In the same function we also pass the model through the `torch ao` API (prior to FSDP2), as well as `torch.compile` (after FSDP2).

The resulting code follows the PyTorch API closely, while also taking advantage of the rest of PyTorch Lightning.

||
To execute the code directly just run: | ||
|
||
```bash | ||
python train.py | ||
``` | ||
|
||
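Fabric also supports scripts launched through `torchrun`; as an alternative (a suggested command, not part of the original example), you could run:

```bash
torchrun --nproc_per_node=4 train.py
```
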
### A Note on torch.compile

Note that PyTorch Lightning also supports calling `torch.compile` on a `LightningModule` and passing it to the `Trainer`.

While this works for simple cases, to get the most out of the combination of the latest PyTorch APIs for distribution, quantization, and compilation, we recommend invoking `torch.compile` at the end of the `configure_model` function, as shown in this example.

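For comparison, the simple route mentioned above looks roughly like the following sketch; `LitBoring` is a hypothetical, minimal `LightningModule` used only for illustration and is not part of this example:

```python
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class LitBoring(L.LightningModule):
    """Minimal stand-in LightningModule, only for illustration."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    # the "simple case": compile the whole LightningModule and hand it to the Trainer
    model = torch.compile(LitBoring())
    trainer = L.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, data)
```
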
@@ -0,0 +1 @@
torchao>=0.7.0
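To install this pinned dependency by hand, a command along these lines should work (adjust to your environment and CUDA build):

```bash
pip install "torchao>=0.7.0"
```
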
@@ -0,0 +1,100 @@
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.pytorch.demos import Transformer, WikiText2
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from tqdm import tqdm


def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    float8_config = Float8LinearConfig(
        # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly # noqa
        pad_inner_dim=True,
    )

    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # we skip the decoder because its vocabulary size is typically
        # not divisible by 16, as required by float8
        return fqn != "decoder"

    # swap eligible nn.Linear layers for float8 training variants
    convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

    # shard each transformer block individually, then the full model (FSDP2)
    for module in model.modules():
        if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)):
            fully_shard(module, mesh=device_mesh)

    fully_shard(model, mesh=device_mesh)

    # compile after sharding so optimized kernels are generated for the float8 ops
    return torch.compile(model)


def train():
    L.seed_everything(42)

    batch_size = 8
    micro_batch_size = 1

    max_steps = 100

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size)

    # instantiate on the meta device; parameters are materialized during fabric.setup
    with torch.device("meta"):
        model = Transformer(
            vocab_size=dataset.vocab_size,
            nlayers=16,
            nhid=4096,
            ninp=1024,
            nhead=32,
        )

    strategy = ModelParallelStrategy(data_parallel_size=4, tensor_parallel_size=1, parallelize_fn=configure_model)

    fabric = L.Fabric(precision="bf16-true", strategy=strategy)
    fabric.launch()

    model = fabric.setup(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    optimizer = fabric.setup_optimizers(optimizer)

    dataloader = fabric.setup_dataloaders(dataloader)

    iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader)

    steps = 0

    for i, batch in iterable:
        input, target = batch

        is_accumulating = i % (batch_size // micro_batch_size) != 0

        # skip gradient synchronization while accumulating micro-batches
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            output = model(input, target)
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)

        if not is_accumulating:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

        if fabric.is_global_zero:
            iterable.set_postfix_str(f"train_loss={loss.item():.2f}")

        if steps == max_steps:
            break

    fabric.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
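
If you want to verify that the float8 swap actually took place, a small, hypothetical helper like the one below (not part of the original script) could be called inside `configure_model` right after `convert_to_float8_training`:

```python
import torch


def count_float8_modules(model: torch.nn.Module) -> int:
    # count modules whose class name mentions Float8; torchao replaces eligible
    # nn.Linear layers with float8-aware variants (e.g. Float8Linear)
    return sum("Float8" in type(m).__name__ for m in model.modules())


# usage inside configure_model (hypothetical):
# print(f"float8 modules: {count_float8_modules(model)}")
```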