
ThunderFX splitter should send no_grad regions to Thunder #1420

Open
IvanYashchuk opened this issue Nov 11, 2024 · 1 comment · May be fixed by #1463
Labels: thunderfx (for things that could be applicable to the dynamo+thunder frontend)


IvanYashchuk commented Nov 11, 2024

🐛 Bug

The splitter for Thunder as a Dynamo backend should send regions of code under a "no_grad" context manager to Thunder. Currently, it sends these computations to Inductor instead.

```python
from thunder.dynamo import ThunderCompiler
import thunder
import torch

def f(x):
    with torch.no_grad():
        return x * x

jit_f = thunder.jit(f)
backend = ThunderCompiler()
compile_f = torch.compile(backend=backend)(f)

x = torch.randn(3, 3, requires_grad=True)

out = jit_f(x)  # Works with thunder.jit, with a warning
out_1 = compile_f(x)  # Works with torch.compile but sends the computation to Inductor instead of Thunder

print(backend.subgraph_infos[0].split_graph_module.print_readable())
```

prints:

```python
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3, 3]"):
        # No stacktrace found for following nodes
        inductor_1 = self.inductor_1(l_x_);  l_x_ = None
        return (inductor_1,)

    class inductor_1(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

            # File: <ipython-input-3-f5a278fc59c9>:7 in f, code: return x * x
            mul: "f32[3, 3]" = l_x_ * l_x_;  l_x_ = None

            # No stacktrace found for following nodes
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return mul
```
HF's Qwen 2 model added in #1406 creates a small Inductor region because Qwen2RotaryEmbedding.forward is decorated with torch.no_grad.

@IvanYashchuk IvanYashchuk added the thunderfx for things that could be applicable to the dynamo+thunder frontend label Nov 11, 2024
IvanYashchuk (Collaborator, Author) commented:

Sending no_grad to Thunder was disabled in #1282 for a good reason (#1219). Maybe it's time to properly support the no_grad context in Thunder?
