ThunderFX fails with FP8 and Activation Checkpointing #1424

Open
mpatel31415 opened this issue Nov 12, 2024 · 3 comments · May be fixed by #1473
Labels
mixology (Issues that the mixology team has surfaced), thunderfx (for things that could be applicable to the dynamo+thunder frontend), TransformerEngine

Comments

@mpatel31415
Contributor

mpatel31415 commented Nov 12, 2024

🐛 Bug

When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP and FP8, we get KeyError: 'scaling_fwd'. This might also be an issue with Transformer Engine, so I'm happy to move this issue to TE if needed.

Full traceback:

[rank7]: Traceback (most recent call last):
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 974, in <module>
7: [rank7]: CLI(benchmark_main)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 96, in CLI
7: [rank7]: return _run_component(components, init)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/jsonargparse/_cli.py", line 204, in _run_component
7: [rank7]: return component(**cfg)
7: [rank7]: ^^^^^^^^^^^^^^^^
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 871, in benchmark_main
7: [rank7]: benchmark.train()
7: [rank7]: File "/workspace/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 765, in train
7: [rank7]: loss.backward()
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 624, in backward
7: [rank7]: torch.autograd.backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 347, in backward
7: [rank7]: _engine_run_backward(
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
7: [rank7]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 307, in apply
7: [rank7]: return user_fn(self, *args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 600, in wrapper
7: [rank7]: outputs = fn(ctx, *args)
7: [rank7]: ^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 115, in backward
7: [rank7]: grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
7: [rank7]: return func(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "thunder.backward_fn_13", line 28, in backward_fn
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
7: [rank7]: return self._call_impl(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
7: [rank7]: return forward_call(*args, **kwargs)
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 205, in forward
7: [rank7]: weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 273, in get_fp8_weight_version_compat
7: [rank7]: weight_fp8 = self.get_fp8_workspace(
7: [rank7]: ^^^^^^^^^^^^^^^^^^^^^^^
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/module/base.py", line 1086, in get_fp8_workspace
7: [rank7]: out.quantize_(tensor, noop_flag=skip_update_flag)
7: [rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/tensor/float8_tensor.py", line 642, in quantize_
7: [rank7]: fp8_meta = dst._fp8_meta[fp8_meta_key]
7: [rank7]: ~~~~~~~~~~~~~^^^^^^^^^^^^^^
7: [rank7]: KeyError: 'scaling_fwd'

To Reproduce

Please use:
1 node(s), each with 8 GPUs.
Image "INTERNAL_IMAGE:pjnl_20241107"
Training script:
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.2 \
    --distributed_mode fsdp \
    --shard_mode zero2 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode fp8-delayed-te \
    --micro_batch_size 1

Environment

system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.98.001
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.8
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gita9b4989
libraries.pip.torchao 0.6.1
libraries.pip.torchmetrics 1.5.1
libraries.pip.torchvision 0.19.0a0+d23a6e1

mpatel31415 changed the title from "Dynamo + Thunder fails with FP8" to "ThunderFX fails with FP8" on Nov 12, 2024
IvanYashchuk added the TransformerEngine and mixology (Issues that the mixology team has surfaced) labels on Nov 12, 2024
tfogal added the thunderfx (for things that could be applicable to the dynamo+thunder frontend) label on Nov 15, 2024
@kshitij12345
Collaborator

kshitij12345 commented Nov 19, 2024

This seems to be caused by an interaction between TransformerEngine and activation checkpointing.

Minimal Repro

import torch
import torch.utils.checkpoint

import thunder
from thunder.dynamo import ThunderCompiler
from thunder.executors.transformer_engineex import transformer_engine_ex

def checkpointed_fn(x):
    # The weight is computed inside the checkpointed region, so it is recomputed in backward
    y = x.cos()
    return torch.nn.functional.linear(x, y)

def fn(x):
    return torch.utils.checkpoint.checkpoint(checkpointed_fn, x)

# Compile through ThunderFX with only the TransformerEngine executor enabled
backend = ThunderCompiler(executors=[transformer_engine_ex,])
x = torch.randn(16, 16, device='cuda', requires_grad=True)
o = torch.compile(fn, backend=backend)(x)

# Print the final forward and backward traces of the single Thunder subgraph
assert len(backend.subgraph_infos) == 1
subgraph_info = backend.subgraph_infos[0]
tfn = subgraph_info.thunder_compiled_fns[0]
print(thunder.last_traces(tfn)[-1])
print(thunder.last_backward_traces(tfn)[-1])

o.sum().backward()  # KeyError: 'scaling_fwd'

This happens because in the forward we call torch.nn.functional.linear, but in the backward we call te_functional_linear_backward (without ever calling TE's forward).
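A toy model of this ordering dependency may help. It is a minimal pure-Python sketch, assuming (hypothetically) a TE-style module that populates its FP8 metadata under the key 'scaling_fwd' only when its own forward runs. `ToyTELinear`, `plain_linear`, `matched_step`, and `mismatched_step` are illustrative names, not Transformer Engine's API.

```python
# Hypothetical toy model of the failure mode; not Transformer Engine's API.


def plain_linear(x, w):
    """Plain (non-TE) scalar 'linear': never touches any TE metadata."""
    return x * w


class ToyTELinear:
    """Toy TE-style linear whose backward relies on state created by its forward."""

    def __init__(self):
        # FP8 metadata starts empty; only forward() populates it.
        self.fp8_meta = {}

    def forward(self, x, w):
        # The TE-style forward sets up the forward scaling state.
        self.fp8_meta["scaling_fwd"] = {"scale": 1.0}
        return x * w

    def backward(self, grad_out, x, w):
        # Raises KeyError('scaling_fwd') if forward() was never called.
        scale = self.fp8_meta["scaling_fwd"]["scale"]
        return grad_out * w * scale, grad_out * x * scale


def matched_step(x, w):
    te = ToyTELinear()
    te.forward(x, w)  # forward and backward use the same implementation
    return te.backward(1.0, x, w)


def mismatched_step(x, w):
    te = ToyTELinear()
    plain_linear(x, w)  # forward used the plain op...
    return te.backward(1.0, x, w)  # ...but backward assumes TE's forward ran
```

Here `matched_step(2.0, 3.0)` returns `(3.0, 2.0)`, while `mismatched_step(2.0, 3.0)` raises `KeyError: 'scaling_fwd'`. In the real traceback the lookup fails inside TE's forward as it is replayed from backward_fn, so the toy only models the general "backward depends on state the forward should have set up" problem.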

Forward Graph

def computation(l_x_):
  # l_x_: "cuda:0 f32[16, 16]"
  t4 = torch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
    # t4 = ltorch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
      # t4 = prims.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
  getitem = torch.nn.functional.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
    # getitem = ltorch.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
      # getitem = prims.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
  del t4
  return {'output': getitem, 'flat_args': [l_x_], 'flat_output': (getitem,)}, ((l_x_,), ())

Backward Graph

def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, = C0
  clear_mutable_collection(C0)
  del C0
  t6 = torch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
    # t6 = ltorch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
      # t6 = prims.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
  (_, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
  del t6
  (t19, t20, _) = te_functional_linear_backward((16, 16), (16, 16), None, ctx_te_1, (t10, t11, t12, t13, t14, None), t0)
  del ctx_te_1, t10, t11, t12, t13, t14, t0
  t21 = torch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
    # t21 = ltorch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
      # t21 = prims.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
  del l_x_
  t22 = torch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
    # t22 = ltorch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
      # t22 = prims.neg(t21)  # t22: "cuda:0 f32[16, 16]"
  del t21
  t23 = torch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
    # t23 = ltorch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
      # t23 = prims.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
  del t20, t22
  t24 = torch.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
    # t24 = ltorch.add(t19, t23, alpha=1)  # t24: "cuda:0 f32[16, 16]"
      # t24 = prims.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
  del t19, t23
  te_sync_fp8_meta_bwd()
  return (t24,)

@kiya00 do you know why this could be happening? Thanks!

@kiya00
Collaborator

kiya00 commented Nov 19, 2024

This happens because in the forward we are calling torch.nn.functional.linear but in the backward, we are calling te_functional_linear_backward (without ever calling the TE's forward).

@register_backward(
    "activation_checkpoint",
)
def _backward_checkpoint(
    function,
    args,
    kwargs,
    *grad_outputs,
) -> tuple[None | TensorLike, ...]:
    from thunder.core.transforms import vjp

    _, grads = vjp(function)(args, grad_outputs, **kwargs)
    return grads

Checkpointing uses vjp. Is the te_linear_0 in the backward trace the original torch.nn.functional.linear?
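Conceptually, taking the vjp of the checkpointed function means the backward recomputes the region from its saved inputs and then backpropagates through it. That is the shape of the backward graph in the minimal repro: only l_x_ is saved, cos is recomputed, linear's backward runs, and the result is chained through cos. A minimal scalar sketch in plain Python, assuming x * cos(x) as a stand-in for linear(x, x.cos()); the function names are illustrative, not Thunder's API.

```python
import math


def checkpointed_fn(x):
    # Scalar analogue of linear(x, x.cos()): input times a weight computed from x.
    w = math.cos(x)
    return x * w


def linear_vjp(a, w, g):
    # VJP of out = a * w: returns (grad w.r.t. input a, grad w.r.t. weight w).
    return g * w, g * a


def checkpoint_forward(x):
    # Run the region but save only its input; the intermediate w is discarded.
    return checkpointed_fn(x), (x,)


def checkpoint_backward(saved, g):
    (x,) = saved
    # Recompute the intermediate (like t6 = cos(l_x_) in the backward trace).
    w = math.cos(x)
    grad_a, grad_w = linear_vjp(x, w, g)
    # Chain rule through w = cos(x), with dw/dx = -sin(x) (like t21..t24).
    return grad_a + grad_w * (-math.sin(x))
```

For example, `out, saved = checkpoint_forward(0.7)` followed by `checkpoint_backward(saved, 1.0)` gives cos(0.7) - 0.7 * sin(0.7). Because the linear is re-executed during the backward, whichever executor claims that recomputed op there need not be the one that ran it in the forward.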

@kiya00
Collaborator

kiya00 commented Nov 19, 2024

result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)

The input trace is:

@torch.no_grad()
@no_autocast
def flat_func(*flat_args):
  # flat_args: "Collection"
  t0, = flat_args
  t1 = ltorch.cos(t0)  # t1: "cuda:0 f32[16, 16]"
    # t1 = prims.cos(t0)  # t1: "cuda:0 f32[16, 16]"
  t2 = ltorch.linear(t0, t1, None)  # t2: "cuda:0 f32[16, 16]"
    # t2 = prims.linear(t0, t1, None)  # t2: "cuda:0 f32[16, 16]"
  return (t2,)

And after L2819, the linear appears to have become te_linear_0 in backward_fn:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, = saved_for_backward
  t0, = cotangents
  _torch_fx_graph_module_GraphModule___new___<locals>_GraphModuleImpl_0, C1, C2, \
  = C0
  l_x_, = C1
  # C2 (empty dict)
  t6 = prims.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
  t7 = ltorch.view(t0, (-1, 16))  # t7: "cuda:0 f32[16, 16]"
    # t7 = ltorch.reshape(t0, (-1, 16))  # t7: "cuda:0 f32[16, 16]"
      # t7 = prims.reshape(t0, (16, 16))  # t7: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t7)
  (_, _) = prims.shape(t7)
  (_, _) = prims.shape(t7)
  _ = ltorch.dim(t1)
  (_, _) = prims.shape(t1)
  (_, _) = prims.shape(t1)
  t8 = ltorch.view(a, (-1, 16))  # t8: "cuda:0 f32[16, 16]"
    # t8 = ltorch.reshape(a, (-1, 16))  # t8: "cuda:0 f32[16, 16]"
      # t8 = prims.reshape(a, (16, 16))  # t8: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t8)
  (_, _) = prims.shape(t8)
  (_, _) = prims.shape(t8)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  t9 = ltorch.view(a, (-1, 16))  # t9: "cuda:0 f32[16, 16]"
    # t9 = ltorch.reshape(a, (-1, 16))  # t9: "cuda:0 f32[16, 16]"
      # t9 = prims.reshape(a, (16, 16))  # t9: "cuda:0 f32[16, 16]"
  _ = ltorch.dim(t9)
  (_, _) = prims.shape(t9)
  (_, _) = prims.shape(t9)
  _ = ltorch.dim(w)
  (_, _) = prims.shape(w)
  (_, _) = prims.shape(w)
  (t15, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
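The mismatch can also be checked mechanically on the printed traces. The following is a hypothetical helper, not part of Thunder. It assumes op names appear in the trace text exactly as in the traces above (te_linear_N for TE's forward, te_functional_linear_backward for its backward) and flags a backward trace that uses TE's linear backward when the forward trace contains no TE linear.

```python
import re

# Hypothetical helper: the patterns match op names as they appear in the printed traces.
TE_FORWARD = re.compile(r"\bte_linear_\d+\(")
TE_BACKWARD = re.compile(r"\bte_functional_linear_backward\(")


def te_forward_backward_mismatch(forward_src: str, backward_src: str) -> bool:
    """True if the backward trace calls TE's linear backward but the forward trace has no TE linear."""
    uses_te_backward = TE_BACKWARD.search(backward_src) is not None
    has_te_forward = TE_FORWARD.search(forward_src) is not None
    return uses_te_backward and not has_te_forward
```

In the minimal repro one could pass the printed text of `thunder.last_traces(tfn)[-1]` and `thunder.last_backward_traces(tfn)[-1]`. With the traces above, the forward only contains torch.nn.functional.linear while the backward calls te_functional_linear_backward, so the helper would report a mismatch.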

kshitij12345 changed the title from "ThunderFX fails with FP8" to "ThunderFX fails with FP8 and Activation Checkpointing" on Nov 25, 2024
5 participants