🐛 Bug

The splitter for Thunder as a Dynamo backend should send regions of code under a "no_grad" context manager to Thunder. Currently, it sends these computations to Inductor instead.
from thunder.dynamo import ThunderCompiler
import thunder
import torch

def f(x):
    with torch.no_grad():
        return x * x

jit_f = thunder.jit(f)
backend = ThunderCompiler()
compile_f = torch.compile(backend=backend)(f)
x = torch.randn(3, 3, requires_grad=True)
out = jit_f(x)  # Works with thunder.jit with a warning
out_1 = compile_f(x)  # Works with torch.compile but sends the computation to Inductor instead of Thunder
print(backend.subgraph_infos[0].split_graph_module.print_readable())
prints:
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3, 3]"):
        # No stacktrace found for following nodes
        inductor_1 = self.inductor_1(l_x_);  l_x_ = None
        return (inductor_1,)

    class inductor_1(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

            # File: <ipython-input-3-f5a278fc59c9>:7 in f, code: return x * x
            mul: "f32[3, 3]" = l_x_ * l_x_;  l_x_ = None

            # No stacktrace found for following nodes
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return mul
HF's Qwen 2 model added in #1406 creates a small Inductor region because Qwen2RotaryEmbedding.forward is decorated with a torch.no_grad.
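The Qwen 2 case reduces to the same pattern as the repro above: a module whose forward runs under torch.no_grad ends up in an Inductor submodule. A minimal sketch of that pattern; RotaryLike is a hypothetical stand-in shaped like Qwen2RotaryEmbedding, not the actual transformers code:

import torch
from thunder.dynamo import ThunderCompiler

class RotaryLike(torch.nn.Module):
    # Stand-in for a module whose forward is decorated with torch.no_grad,
    # mirroring Qwen2RotaryEmbedding.forward in HF transformers.
    @torch.no_grad()
    def forward(self, x):
        return torch.cos(x), torch.sin(x)

backend = ThunderCompiler()
model = RotaryLike()
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(3, 3))
# The decorated forward becomes its own submodule, currently handed to Inductor
# rather than Thunder.
print(backend.subgraph_infos[0].split_graph_module.print_readable())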