From d9178c643c12e0e785df8b7056bd78f48938ea08 Mon Sep 17 00:00:00 2001 From: Ran Tao Date: Mon, 17 Jun 2024 15:30:44 +0200 Subject: [PATCH] input of batchnorm1d can also be (N,C) --- tests/converter_tests/test_converters.py | 5 +++++ torch2trt/converters/native_converters.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/converter_tests/test_converters.py b/tests/converter_tests/test_converters.py index de9e0a55..c67fe260 100644 --- a/tests/converter_tests/test_converters.py +++ b/tests/converter_tests/test_converters.py @@ -203,6 +203,11 @@ def test_batch_norm_nd(nd, with_conv): inputs = [torch.randn(*input_size).cuda()] cross_validate(module, inputs, fp16_mode=False, tol=1e-1) + if nd == 1: + input_size = [2, 3] + inputs = [torch.randn(*input_size).cuda()] + cross_validate(module, inputs, fp16_mode=False, tol=1e-1) + @pytest.mark.parametrize("dim", [1, -1]) def test_cat(dim): diff --git a/torch2trt/converters/native_converters.py b/torch2trt/converters/native_converters.py index e1dcd145..660fedbd 100644 --- a/torch2trt/converters/native_converters.py +++ b/torch2trt/converters/native_converters.py @@ -156,7 +156,7 @@ def convert_batch_norm(ctx): bias = bias.detach().cpu().numpy() - running_mean.detach().cpu().numpy() * scale power = np.ones_like(scale) - if ndim == 1: + if ndim == 1 or ndim == 0: # reshape to 2D layer = ctx.network.add_shuffle(input_trt) @@ -171,7 +171,7 @@ def convert_batch_norm(ctx): layer = ctx.network.add_scale_nd(scale_input, trt.ScaleMode.CHANNEL, bias, scale, power, 1) - if ndim == 1: + if ndim == 1 or ndim == 0: # reshape back to 1D layer = ctx.network.add_shuffle(layer.get_output(0)) if len(input.shape) == 2: