Skip to content

Commit

Permalink
use more executors
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 28, 2024
1 parent ec05978 commit f582143
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
from torch.utils import _pytree as pytree

import thunder
from thunder.core.proxies import SubclassTensorProxy
from thunder.tests.framework import instantiate
from thunder.tests.framework import instantiate, TorchExecutor, TorchCompileCatExecutor, nvFuserExecutor, DynamoThunderExecutor
from thunder.tests.make_tensor import make_tensor

TORCHAO_AVAILABLE = package_available("torchao")

if TYPE_CHECKING:
from typing import Any
from thunder.core.symbol import BoundSymbol


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -243,6 +241,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
@instantiate(
dtypes=(thunder.core.dtypes.float32,),
devicetypes=(thunder.core.devices.DeviceType.CUDA,),
executors=(TorchExecutor, TorchCompileCatExecutor, nvFuserExecutor, DynamoThunderExecutor),
decorators=(
pytest.mark.skipif(
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
Expand All @@ -267,4 +266,8 @@ def test_torchao_float8_linear(executor, device, _):
jitted = executor.make_callable(fp8_model)
actual = jitted(x)

torch.testing.assert_close(actual, expected)
if executor == DynamoThunderExecutor:
with pytest.raises(AssertionError):
torch.testing.assert_close(actual, expected)
else:
torch.testing.assert_close(actual, expected)

0 comments on commit f582143

Please sign in to comment.