Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 30, 2024
1 parent 8475ff7 commit 4a029a2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
6 changes: 4 additions & 2 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
reason="Requires capability >= 8.9 and torchao",
),
pytest.mark.parametrize("bias", (True, False))
pytest.mark.parametrize("bias", (True, False)),
),
)
def test_torchao_float8_linear(executor, device, dtype, bias):
Expand Down Expand Up @@ -294,7 +294,9 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
torch.testing.assert_close(actual, expected)
return

if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor):
if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
):
pytest.xfail("numerical error")
torch.testing.assert_close(actual, expected)

Expand Down
33 changes: 19 additions & 14 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,26 @@ def __post_init__(self) -> None:
if len(self.computation_trace.bound_symbols) > 6:
maybe_unpack_C0_bsym = self.computation_trace.bound_symbols[4]
maybe_unpack_C1_bsym = self.computation_trace.bound_symbols[5]
is_backward_trace = maybe_unpack_C0_bsym.args and maybe_unpack_C1_bsym.args and (
maybe_unpack_C0_bsym.sym.id,
maybe_unpack_C1_bsym.sym.id,
getattr(maybe_unpack_C0_bsym.args[0], "name", ""),
getattr(maybe_unpack_C1_bsym.args[0], "name", ""),
) == (
prims.PrimIDs.UNPACK_SEQUENCE,
prims.PrimIDs.UNPACK_SEQUENCE,
"C0",
"C1",
is_backward_trace = (
maybe_unpack_C0_bsym.args
and maybe_unpack_C1_bsym.args
and (
maybe_unpack_C0_bsym.sym.id,
maybe_unpack_C1_bsym.sym.id,
getattr(maybe_unpack_C0_bsym.args[0], "name", ""),
getattr(maybe_unpack_C1_bsym.args[0], "name", ""),
)
== (
prims.PrimIDs.UNPACK_SEQUENCE,
prims.PrimIDs.UNPACK_SEQUENCE,
"C0",
"C1",
)
)
if is_backward_trace:
self.flat_trace_args, _ = tree_flatten((maybe_unpack_C0_bsym.output, maybe_unpack_C1_bsym.output))
if not is_backward_trace:
self.flat_trace_args, _ = tree_flatten(
(self.computation_trace.args, self.computation_trace.kwargs)
)
self.flat_trace_args, _ = tree_flatten((self.computation_trace.args, self.computation_trace.kwargs))
for arg in self.flat_trace_args:
if isinstance(arg, SubclassTensorProxy):
self.subclass_proxy_to_flatten.add(variableify(arg))
Expand Down Expand Up @@ -679,6 +682,8 @@ def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx:

computation_trace_with_subclass_tensor_proxy_output = from_trace(trace)
computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)"))
computation_trace_with_subclass_tensor_proxy_output.set_provenance(
TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
)
warn_tensor_subclass_support()
return computation_trace_with_subclass_tensor_proxy_output

0 comments on commit 4a029a2

Please sign in to comment.