12 changes: 2 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -10,7 +10,7 @@
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt import ENABLED_FEATURES
 from torch_tensorrt._features import needs_not_tensorrt_rtx
-from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
+from torch_tensorrt._utils import is_tensorrt_version_supported
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -429,7 +429,7 @@ def index_nonbool_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
     # for thor and tensorrt_rtx, we don't support boolean indices, due to nonzero op not supported
-    if is_thor() or ENABLED_FEATURES.tensorrt_rtx:
+    if ENABLED_FEATURES.tensorrt_rtx:
         index = node.args[1]
         for ind in index:
             if ind is not None:
@@ -3621,18 +3621,10 @@ def aten_ops_full(
     )


-def nonzero_validator(
-    node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    return not is_thor()
-
-
 # currently nonzero is not supported for tensorrt_rtx
 # TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
-# TODO: apbose to remove the capability validator once thor bug resolve in NGC
 @dynamo_tensorrt_converter(
     torch.ops.aten.nonzero.default,
-    capability_validator=nonzero_validator,
     supports_dynamic_shapes=True,
     requires_output_allocator=True,
 )
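For readers skimming this diff, the converter-registration pattern the change relies on looks roughly like the sketch below: converters stay registered unconditionally, and a `capability_validator` rejects individual nodes the active backend cannot handle (here, boolean indices under `tensorrt_rtx`, where `aten.nonzero` is unavailable). This is a minimal illustration, not the file's actual code: the body of `index_nonbool_validator` is elided in the hunk above, so the dtype check, the import path of the decorator, and the `aten.index.Tensor` target are assumptions.

```python
from typing import Optional

import torch
from torch.fx.node import Node
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt.dynamo._settings import CompilationSettings
# import path assumed for illustration
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter


def index_nonbool_validator(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    # tensorrt_rtx cannot lower boolean index masks because aten.nonzero is
    # unsupported there, so reject such nodes and let them run in PyTorch instead.
    if ENABLED_FEATURES.tensorrt_rtx:
        for ind in node.args[1]:
            if ind is not None and ind.meta["val"].dtype == torch.bool:  # assumed check
                return False
    return True


@dynamo_tensorrt_converter(
    torch.ops.aten.index.Tensor,  # target op assumed; not shown in the hunk
    capability_validator=index_nonbool_validator,
    supports_dynamic_shapes=True,
)
def aten_ops_index(ctx, target, args, kwargs, name):
    ...  # dispatch to the usual index implementation
```

With the Thor-specific `nonzero_validator` removed, the `aten.nonzero` converter above no longer needs a validator at all; only the RTX gating in `index_nonbool_validator` remains.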
5 changes: 0 additions & 5 deletions tests/py/dynamo/conversion/test_arange_aten.py
@@ -5,15 +5,10 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
-)
 class TestArangeConverter(DispatchTestCase):
     @parameterized.expand(
         [
5 changes: 0 additions & 5 deletions tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -5,15 +5,10 @@
 import torch_tensorrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor() or is_tegra_platform(),
-    "Skipped on Thor and Tegra platforms",
-)
 class TestCumsumConverter(DispatchTestCase):
     @parameterized.expand(
         [
19 changes: 7 additions & 12 deletions tests/py/dynamo/conversion/test_index_aten.py
@@ -6,7 +6,6 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import ENABLED_FEATURES, Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase

@@ -114,8 +113,8 @@ def forward(self, input):
         ]
     )
     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_constant_bool_mask(self, _, index, input):
         class TestModule(torch.nn.Module):
@@ -149,8 +148,8 @@ def forward(self, x, index0):
     )

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_two_dim_ITensor_mask(self):
         class TestModule(nn.Module):
@@ -163,10 +162,6 @@ def forward(self, x, index0):
         index0 = torch.tensor([True, False])
         self.run_test(TestModule(), [input, index0], enable_passes=True)

-    @unittest.skipIf(
-        is_thor(),
-        "Skipped on Thor due to nonzero not supported",
-    )
     def test_index_zero_index_three_dim_ITensor(self):
         class TestModule(nn.Module):
             def forward(self, x, index0):
@@ -180,8 +175,8 @@ def forward(self, x, index0):
         self.run_test(TestModule(), [input, index0])

     @unittest.skipIf(
-        is_thor() or ENABLED_FEATURES.tensorrt_rtx,
-        "Skipped on Thor or tensorrt_rtx due to nonzero not supported",
+        ENABLED_FEATURES.tensorrt_rtx,
+        "Skipped on tensorrt_rtx due to nonzero not supported",
     )
     def test_index_zero_index_three_dim_mask_ITensor(self):
         class TestModule(nn.Module):
@@ -252,7 +247,7 @@ def forward(self, input):


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
     "nonzero is not supported for tensorrt_rtx",
 )
 class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
3 changes: 1 addition & 2 deletions tests/py/dynamo/conversion/test_nonzero_aten.py
@@ -6,13 +6,12 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
-from torch_tensorrt._utils import is_tegra_platform, is_thor

 from .harness import DispatchTestCase


 @unittest.skipIf(
-    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx or is_thor() or is_tegra_platform(),
+    torch_tensorrt.ENABLED_FEATURES.tensorrt_rtx,
     "nonzero is not supported for tensorrt_rtx",
 )
 class TestNonZeroConverter(DispatchTestCase):
5 changes: 0 additions & 5 deletions tests/py/dynamo/conversion/test_sym_size.py
@@ -4,15 +4,10 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt._utils import is_thor

 from .harness import DispatchTestCase


-@unittest.skipIf(
-    is_thor(),
-    "Skipped on Thor",
-)
 class TestSymSizeConverter(DispatchTestCase):
     @parameterized.expand(
         [
21 changes: 10 additions & 11 deletions tests/py/dynamo/models/test_export_kwargs_serde.py
@@ -1,6 +1,5 @@
 # type: ignore
 import os
-import tempfile
 import unittest

 import pytest
@@ -22,7 +21,7 @@

 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model():
+def test_custom_model(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -75,15 +74,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace():
+def test_custom_model_with_dynamo_trace(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -137,15 +136,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_dynamic():
+def test_custom_model_with_dynamo_trace_dynamic(tmpdir):
     class net(nn.Module):
         def __init__(self):
             super().__init__()
@@ -208,15 +207,15 @@ def forward(self, x, b=5, c=None, d=None):
     )

     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_kwarg_dynamic():
+def test_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
     ir = "dynamo"

     class net(nn.Module):
@@ -298,15 +297,15 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()


 @pytest.mark.unit
 @pytest.mark.critical
-def test_custom_model_with_dynamo_trace_kwarg_dynamic():
+def test_custom_model_with_dynamo_trace_kwarg_dynamic(tmpdir):
     ir = "dynamo"

     class net(nn.Module):
@@ -388,7 +387,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         msg=f"CustomKwargs Module TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Save the module
-    trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
+    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
     torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
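A note on the test changes above: replacing the shared `tempfile.gettempdir()` path with pytest's built-in `tmpdir` fixture gives each test its own temporary directory, created and managed by pytest, so parallel or repeated runs no longer collide on the same `compiled.ep` file. A minimal sketch of the pattern; the test and model here are illustrative, not part of this diff:

```python
import os

import torch
import torch_tensorrt as torchtrt


def test_save_roundtrip(tmpdir):  # pytest injects a unique per-test directory
    model = torch.nn.Linear(4, 4).eval().cuda()
    trt_gm = torchtrt.compile(model, ir="dynamo", inputs=[torch.randn(2, 4).cuda()])
    # Write the exported program into the per-test directory instead of a shared temp dir
    trt_ep_path = os.path.join(tmpdir, "compiled.ep")
    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
```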