diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
index e150d9466cb..06da5bbab22 100644
--- a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
+++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
@@ -49,12 +49,8 @@ def call_operator(self, op, args, kwargs, meta):
             )
 
         # convolution with bias and activation is int16
-        # The bias is assumed to be quantized with the same quantization parameters as
-        # as the output of the convolution
         bias = args[2]
-        assert (
-            meta.data["output_qparams"][0].dtype == bias.data.dtype
-        ), "Bias needs to have same type as quantized output type"
+
        no_bias_args = list(args)
        no_bias_args[2] = None
        # split up to convolution + bias
@@ -79,46 +75,30 @@ def call_operator(self, op, args, kwargs, meta):
         # The conv will get the output int48 scaled to int32 in serialization step.
         # To be able to add the bias we need to first scale (cast?) the output to int32.
         # The resulting i32 sum will then need to be scaled back to the output dtype.
-
-        # calculate common rescale factor from convolution output and bias quantization
         output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
         conv_output_scale = output_qparams.scale
 
-        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
-        bias_scale = bias_qparams.scale
-        common_scale = max(bias_scale, conv_output_scale)
-
-        # calculate how we can rescale bias and conv to a common scale and maximize the output range
-        bias_rescale_factor = bias_scale / common_scale
-        conv_rescale_factor = conv_output_scale / common_scale
+        bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
+        per_channel_quant = bias_qparams.per_channel
 
-        # Either of conv output or bias now covers the full int16 range and the other one a smaller range.
-        # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
-        # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit
-        # and then one for the sign bit.
-        bits_left_to_shift = 14
+        if per_channel_quant:
+            bias_scale = bias_qparams.get_scale_per_channel()
+        else:
+            bias_scale = [bias_qparams.get_scale_per_tensor()]
 
-        # update rescale factors
-        bias_rescale_factor *= 1 << bits_left_to_shift
-        conv_rescale_factor *= 1 << bits_left_to_shift
+        conv_rescale_factors = [1.0] * len(bias_scale)
+        final_output_scale = [b / conv_output_scale for b in bias_scale]
 
         conv_output = super().call_operator(
             exir_ops.backend.tosa.RESCALE.default,
-            (convolution, torch.int32, [conv_rescale_factor], 0, 0),
-            {},
-            new_meta,
-        )
-
-        bias_rescaled = super().call_operator(
-            exir_ops.backend.tosa.RESCALE.default,
-            (channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
+            (convolution, torch.int32, conv_rescale_factors, 0, 0),
             {},
             new_meta,
         )
 
         add = super().call_operator(
             exir_ops.edge.aten.add.Tensor,
-            (conv_output, bias_rescaled),
+            (conv_output, channel_bias),
             {},
             new_meta,
         )
@@ -128,7 +108,7 @@ def call_operator(self, op, args, kwargs, meta):
             (
                 add,
                 output_dtype,
-                [(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
+                final_output_scale,
                 0,
                 0,
             ),
diff --git a/backends/arm/_passes/rewrite_conv2d_pass.py b/backends/arm/_passes/rewrite_conv2d_pass.py
index c46cfb3b205..52feba5f8b9 100644
--- a/backends/arm/_passes/rewrite_conv2d_pass.py
+++ b/backends/arm/_passes/rewrite_conv2d_pass.py
@@ -237,8 +237,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 pad[3],
                 dilation[1],
             )
-
-            if bias is None:
+            has_bias = bias is not None
+            if not has_bias:
                 bias = self._add_bias(graph_module, node, weight)
 
             if self._is_depthwise_conv2d(node):
@@ -278,14 +278,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            if (
                tosa_node_fake_tensor.dtype == torch.int32
                and input_fake_tensor.dtype == torch.int8
-            ) or (
-                tosa_node_fake_tensor.dtype == torch.int32
-                and input_fake_tensor.dtype == torch.int16
            ):
                output_rescale = self.insert_output_rescale(graph_module, tosa_op)
                node.replace_all_uses_with(output_rescale)
-                if input_fake_tensor.dtype == torch.int16:
-                    tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
+            elif (
+                tosa_node_fake_tensor.dtype == torch.int32
+                and input_fake_tensor.dtype == torch.int16
+            ):
+                has_bias = len(node.meta["input_qparams"]) > 2
+                if not has_bias:
+                    output_rescale = self.insert_output_rescale(graph_module, tosa_op)
+                    node.replace_all_uses_with(output_rescale)
+                else:
+                    node.replace_all_uses_with(tosa_op)
+                tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
            else:
                node.replace_all_uses_with(tosa_op)
 
diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py
index bb1c8ec51cd..36ab233bdb6 100644
--- a/backends/arm/quantizer/quantization_config.py
+++ b/backends/arm/quantizer/quantization_config.py
@@ -182,9 +182,12 @@ def _derive_qparams_fn(
            raise ValueError(
                "Input activation and weight QuantizationConfig must be specified."
            )
-        if self.input_activation.dtype == self.weight.dtype == torch.int8:
-            # This is the default int8 quantization which uses the derived quantization
-            # calculated from the activation and weight scale
+
+        if (self.input_activation.dtype == self.weight.dtype == torch.int8) or (
+            self.input_activation.dtype == torch.int16
+            and self.weight.dtype == torch.int8
+        ):
+
            input_act = node.args[0]
            weight = node.args[1]
 
@@ -209,13 +212,6 @@ def _derive_qparams_fn(
                ch_axis=ch_axis,
            )
            return quantization_spec  # type: ignore[return-value]
-        elif (
-            self.input_activation.dtype == torch.int16
-            and self.weight.dtype == torch.int8
-        ):
-            # In case the activation is quantized to int16, the bias needs to be
-            # added after the convolution, so use the output quantization for this case.
-            return self.output_activation
        else:
            raise NotImplementedError(
                f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented"
diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py
index 4029fcef54e..952befeeffa 100644
--- a/backends/arm/test/ops/test_linear.py
+++ b/backends/arm/test/ops/test_linear.py
@@ -274,10 +274,6 @@ def get_symmetric_a16w8_linear_quantizer(
 
 test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT
 
-# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377
-for k in list(test_data_all_16a8w.keys()):
-    if "large_rand" in k:
-        test_data_all_16a8w.pop(k)
 
 
 @common.parametrize("test_data", test_data_all_16a8w)
@@ -311,7 +307,19 @@ def test_linear_16a8w_tosa_INT(test_data: torch.Tensor):
     pipeline.run()
 
 
-@common.parametrize("test_data", test_data_all_16a8w)
+x_fails = {}
+for test_name in [
+    "model_linear_rank4_zeros",
+    "model_linear_rank4_negative_ones",
+    "model_linear_rank4_negative_large_rand",
+]:
+    for set_per_chan in ["True", "False"]:
+        x_fails[test_name + ",per_channel_quant={}".format(set_per_chan)] = (
+            "MLETORCH-1452: AssertionError: Output 0 does not match reference output."
+        )
+
+
+@common.parametrize("test_data", test_data_all_16a8w, x_fails)
 @common.XfailIfNoCorstone300
 def test_linear_16a8w_u55_INT16(test_data: torch.Tensor):
     """Test linear operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
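
Note on the new rescale arithmetic (illustrative only, not part of the patch): the int32 convolution accumulator and the int32 bias both carry the derived scale input_scale * weight_scale[ch], so after the add a single per-channel rescale of bias_scale[ch] / conv_output_scale maps the sum into the int16 output domain, which is why the old common_scale / bits_left_to_shift bookkeeping and the separate bias RESCALE can be dropped. The sketch below walks through that arithmetic with made-up scales and values; every name and number in it is assumed for illustration.

# Minimal sketch of the per-channel rescale math (assumed values, not the pass itself).
input_scale = 0.05                    # int16 activation scale (assumed)
weight_scales = [0.01, 0.02]          # per-channel int8 weight scales (assumed)
conv_output_scale = 0.004             # int16 output scale (assumed)

bias_scale = [input_scale * w for w in weight_scales]             # derived bias scales
final_output_scale = [b / conv_output_scale for b in bias_scale]  # one rescale per channel

acc_int32 = [1200, -340]              # conv accumulators per channel (assumed)
bias_int32 = [25, -7]                 # quantized bias per channel (assumed)

out_int16 = []
for ch in range(len(acc_int32)):
    # A RESCALE with factor 1.0 only casts the accumulator to int32; the bias is
    # added exactly in int32 and the sum is rescaled once into the output domain.
    summed = acc_int32[ch] + bias_int32[ch]
    out = round(summed * final_output_scale[ch])
    out_int16.append(max(-32768, min(32767, out)))

print(out_int16)                      # quantized int16 outputs for the two channels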