diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 9d26e508c61..c0fd4759362 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -79,6 +79,20 @@ tflm_cc_library(
     ],
 )
 
+tflm_cc_library(
+    name = "decode_test_helpers",
+    hdrs = [
+        "decode_test_helpers.h",
+    ],
+    deps = [
+        ":kernel_runner",
+        ":micro_ops",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
+
 tflm_cc_library(
     name = "decompress",
     srcs = [
@@ -239,6 +253,7 @@ tflm_kernel_cc_library(
         "decode.cc",
         "decode_state.cc",
         "decode_state_lut.cc",
+        "decode_state_prune.cc",
        "depth_to_space.cc",
        "depthwise_conv.cc",
        "depthwise_conv_common.cc",
@@ -332,6 +347,7 @@ tflm_kernel_cc_library(
         "conv.h",
         "decode_state.h",
         "decode_state_lut.h",
+        "decode_state_prune.h",
        "depthwise_conv.h",
        "dequantize.h",
        "ethosu.h",
@@ -648,12 +664,29 @@ tflm_cc_test(
     ],
 )
 
+tflm_cc_test(
+    name = "decode_state_prune_test",
+    srcs = [
+        "decode_state_prune_test.cc",
+    ],
+    deps = [
+        ":decode_test_helpers",
+        ":kernel_runner",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:debug_log",
+        "//tensorflow/lite/micro:op_resolvers",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)
+
 tflm_cc_test(
     name = "decode_test",
     srcs = [
         "decode_test.cc",
     ],
     deps = [
+        ":decode_test_helpers",
         ":kernel_runner",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/micro:debug_log",
diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc
index 62e9324995e..d9d9d728eb8 100644
--- a/tensorflow/lite/micro/kernels/Makefile.inc
+++ b/tensorflow/lite/micro/kernels/Makefile.inc
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
diff --git a/tensorflow/lite/micro/kernels/decode.cc b/tensorflow/lite/micro/kernels/decode.cc
index 778a516224c..1e06a3390ef 100644
--- a/tensorflow/lite/micro/kernels/decode.cc
+++ b/tensorflow/lite/micro/kernels/decode.cc
@@ -63,6 +63,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
 
+    TF_LITE_ENSURE(context, IsConstantTensor(input));
+    TF_LITE_ENSURE(context, IsConstantTensor(ancillary));
+
     if (DecodeState::Version(*ancillary) != 1) {
       MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
       status = kTfLiteError;
@@ -75,6 +78,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         dsp = DecodeState::CreateDecodeStateLUT(
             context, micro_context->GetAlternateProfiler());
         break;
+      case DecodeState::kDcmTypePrune:
+        dsp = DecodeState::CreateDecodeStatePrune(
+            context, micro_context->GetAlternateProfiler());
+        break;
       case DecodeState::kDcmTypeCustom:
         MicroPrintf("Custom decode type not yet supported");
         break;
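Note: the switch in Prepare() above dispatches purely on the ancillary tensor's Decode Common Metadata (DCM). As a reading aid, not part of the patch, here is a minimal sketch of the header bytes consumed, assuming only the offsets visible in this change: byte 0 is the decode type read by DecodeState::Type(), byte 1 is the DCM version read by DecodeState::Version(), and byte 4 is the type-specific version that DecodeStatePrune::Setup() checks. DcmView is a hypothetical illustration, not a TFLM type.

#include <cstdint>

// Hypothetical view over the first kDcmSizeInBytes (16) bytes of ancillary
// data; field offsets mirror the constants used by DecodeState.
struct DcmView {
  const uint8_t* bytes;
  uint8_t type() const { return bytes[0]; }           // 0 = LUT, 2 = Prune, 127 = custom
  uint8_t dcm_version() const { return bytes[1]; }    // must be 1
  uint8_t prune_version() const { return bytes[4]; }  // must be 1 for Prune
};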
diff --git a/tensorflow/lite/micro/kernels/decode_state.cc b/tensorflow/lite/micro/kernels/decode_state.cc
index a55b4b4148b..adcdf913be8 100644
--- a/tensorflow/lite/micro/kernels/decode_state.cc
+++ b/tensorflow/lite/micro/kernels/decode_state.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/lite/micro/kernels/decode_state.h"
 
 #include "tensorflow/lite/micro/kernels/decode_state_lut.h"
+#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
 #include "tensorflow/lite/micro/micro_context.h"
 
 namespace tflite {
@@ -33,4 +34,17 @@ DecodeState* DecodeState::CreateDecodeStateLUT(
   return dsp;
 }
 
+DecodeState* DecodeState::CreateDecodeStatePrune(
+    const TfLiteContext* context, MicroProfilerInterface* profiler) {
+  MicroContext* const micro_context = GetMicroContext(context);
+  void* buffer =
+      micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune));
+  if (buffer == nullptr) {
+    return nullptr;
+  }
+  DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler);
+
+  return dsp;
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/decode_state.h b/tensorflow/lite/micro/kernels/decode_state.h
index 3818781b9dc..baebfb5ea63 100644
--- a/tensorflow/lite/micro/kernels/decode_state.h
+++ b/tensorflow/lite/micro/kernels/decode_state.h
@@ -43,6 +43,8 @@ class DecodeState {
   static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
                                            MicroProfilerInterface* profiler);
+  static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
+                                             MicroProfilerInterface* profiler);
 
   static uint8_t Type(const TfLiteTensor& ancillary) {
     return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
   }
@@ -66,6 +68,7 @@ class DecodeState {
   // Decode Common Metadata constants
  public:
   static constexpr uint8_t kDcmTypeLUT = 0;
+  static constexpr uint8_t kDcmTypePrune = 2;
   static constexpr uint8_t kDcmTypeCustom = 127;
 
   static constexpr size_t kDcmSizeInBytes = 16;
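Note: CreateDecodeStatePrune() repeats the factory idiom used for the LUT state. TFLM has no heap, so the object is constructed with placement new inside a persistent arena buffer and is never destroyed. A standalone sketch of that idiom, with Arena and MyState as hypothetical stand-ins for MicroContext and the decode state:

#include <cstddef>
#include <new>

struct Arena {  // stand-in for MicroContext::AllocatePersistentBuffer()
  alignas(std::max_align_t) unsigned char pool[256];
  size_t used = 0;
  void* Allocate(size_t bytes) {
    if (used + bytes > sizeof(pool)) return nullptr;
    void* p = pool + used;
    used += bytes;  // note: a real allocator would also align the result
    return p;
  }
};

struct MyState {
  explicit MyState(int id) : id_(id) {}
  int id_;
};

MyState* Create(Arena& arena, int id) {
  void* buffer = arena.Allocate(sizeof(MyState));
  if (buffer == nullptr) return nullptr;
  return new (buffer) MyState(id);  // construct in place; never delete'd
}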
diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.cc b/tensorflow/lite/micro/kernels/decode_state_prune.cc
new file mode 100644
index 00000000000..aadfd8445ee
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_prune.cc
@@ -0,0 +1,206 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
+
+#include <algorithm>
+#include <cstddef>
+
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
+
+namespace tflite {
+
+TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input,
+                                     const TfLiteTensor& ancillary,
+                                     const TfLiteTensor& output) {
+  const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
+  if (ancillary_data[kDcmVersionOffset] != 1) {
+    MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
+    return kTfLiteError;
+  }
+
+  // resolve num_channels_, use_alternate_axis_, and zero points
+  if (output.quantization.type == kTfLiteAffineQuantization &&
+      output.quantization.params != nullptr) {
+    const TfLiteAffineQuantization* quantization =
+        reinterpret_cast<const TfLiteAffineQuantization*>(
+            output.quantization.params);
+    num_channels_ = quantization->scale->size;
+    if ((quantization->quantized_dimension == output.dims->size - 1) &&
+        num_channels_ > 1) {
+      use_alternate_axis_ = true;
+    } else if (quantization->quantized_dimension != 0) {
+      MicroPrintf("unsupported quantization axis %u",
+                  quantization->quantized_dimension);
+      return kTfLiteError;
+    }
+
+    TFLITE_DCHECK(num_channels_ ==
+                  static_cast<size_t>(quantization->zero_point->size));
+    bool has_non_zero_zp =
+        std::any_of(quantization->zero_point->data,
+                    quantization->zero_point->data + num_channels_,
+                    [](int zp) { return zp != 0; });
+
+    if (output.type != kTfLiteInt8) {
+      // make sure all zero points are 0 (zero)
+      TF_LITE_ENSURE_MSG(const_cast<TfLiteContext*>(context_),
+                         has_non_zero_zp == false,
+                         "All zero-points must be zero");
+    }
+
+    if (num_channels_ > 1 && has_non_zero_zp) {
+      // copy zero points
+      MicroContext* micro_context = GetMicroContext(context_);
+      const size_t bufsize = num_channels_ * sizeof(*zero_points_);
+      zero_points_ = static_cast<int8_t*>(
+          micro_context->AllocatePersistentBuffer(bufsize));
+      if (zero_points_ == nullptr) {
+        MicroPrintf("unable to allocate zero_points_");
+        return kTfLiteError;
+      }
+      std::copy_n(quantization->zero_point->data, num_channels_, zero_points_);
+    } else {
+      single_zero_point_ = quantization->zero_point->data[0];
+    }
+  }
+
+  compressed_indices_ = GetTensorData<uint8_t>(&input);
+  count_indices_ = NumElements(&output);
+  elements_per_channel_ =
+      use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
+  value_table_ = &ancillary_data[kDcmSizeInBytes];
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input,
+                                      const TfLiteEvalTensor& ancillary,
+                                      const TfLiteEvalTensor& output) {
+  void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
+  TFLITE_DCHECK(buffer != nullptr);
+
+  switch (output.type) {
+    case kTfLiteBool:
+      DecompressToBuffer<bool>(buffer);
+      break;
+    case kTfLiteFloat32:
+      DecompressToBuffer<float>(buffer);
+      break;
+    case kTfLiteInt8:
+      if (num_channels_ > 1 && zero_points_ != nullptr) {
+        DecompressToBufferPerChannelInt8(buffer);
+      } else {
+        DecompressToBuffer<int8_t>(buffer);
+      }
+      break;
+    case kTfLiteInt16:
+      DecompressToBuffer<int16_t>(buffer);
+      break;
+    case kTfLiteInt32:
+      DecompressToBuffer<int32_t>(buffer);
+      break;
+    case kTfLiteInt64:
+      DecompressToBuffer<int64_t>(buffer);
+      break;
+    default:
+      MicroPrintf("unsupported tensor type %s",
+                  TfLiteTypeGetName(output.type));
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+template <typename T>
+void DecodeStatePrune::DecompressToBuffer(void* vp) {
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  T* buffer = static_cast<T*>(vp);
+  const T* value_table = static_cast<const T*>(value_table_);
+  const size_t max_count = count_indices_;
+  const uint8_t* const indices = compressed_indices_;
+
+  for (size_t index = 0; index < max_count; index++) {
+    size_t shift = ~index & 0b111;
+    size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1;
+
+    if (is_not_zp) {
+      *buffer++ = *value_table++;
+    } else {
+      *buffer++ = single_zero_point_;
+    }
+  }
+}
+
+void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) {
+  TFLITE_DCHECK(zero_points_ != nullptr);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
+
+  int8_t* buffer = static_cast<int8_t*>(vp);
+  size_t current_offset = 0;
+  const uint8_t* const indices = compressed_indices_;
+  const int8_t* value_table = static_cast<const int8_t*>(value_table_);
+
+  if (use_alternate_axis_) {
+    const size_t max_channels = num_channels_;
+    size_t count = count_indices_;
+
+    while (count > 0) {
+      for (size_t channel = 0; channel < max_channels; channel++) {
+        const int8_t zp = zero_points_[channel];
+        size_t shift = ~current_offset & 0b111;
+        size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
+
+        if (is_not_zp) {
+          *buffer++ = *value_table++;
+        } else {
+          *buffer++ = zp;
+        }
+        current_offset++;
+      }
+      count -= max_channels;
+    }
+  } else {
+    const size_t max_count = elements_per_channel_;
+
+    for (size_t channel = 0; channel < num_channels_; channel++) {
+      size_t count = max_count;
+      const int8_t zp = zero_points_[channel];
+
+      while (count-- > 0) {
+        size_t shift = ~current_offset & 0b111;
+        size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
+
+        if (is_not_zp) {
+          *buffer++ = *value_table++;
+        } else {
+          *buffer++ = zp;
+        }
+        current_offset++;
+      }
+    }
+  }
+}
+
+template void DecodeStatePrune::DecompressToBuffer<int8_t>(void*);
+template void DecodeStatePrune::DecompressToBuffer<int16_t>(void*);
+template void DecodeStatePrune::DecompressToBuffer<int32_t>(void*);
+template void DecodeStatePrune::DecompressToBuffer<int64_t>(void*);
+
+}  // namespace tflite
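Note: the decompression loops above walk a bitmask MSB-first: bit i of the encoded stream tells whether output element i is the next entry of the packed value table (1) or the zero point (0). A self-contained sketch of the same index arithmetic, runnable on its own; all names here are local to the example:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // 0xA5 = 0b10100101: elements 0, 2, 5, 7 are kept; the rest were pruned.
  const uint8_t indices[] = {0xA5};
  const int8_t value_table[] = {10, 20, 30, 40};  // only the kept values
  const int8_t zero_point = -128;

  const int8_t* next_value = value_table;
  for (size_t i = 0; i < 8; i++) {
    const size_t shift = ~i & 0b111;  // 7, 6, ..., 0: most significant bit first
    const bool is_kept = (indices[i >> 3] >> shift) & 0b1;
    printf("%d ", is_kept ? *next_value++ : zero_point);
  }
  printf("\n");  // prints: 10 -128 20 -128 -128 30 -128 40
  return 0;
}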
diff --git a/tensorflow/lite/micro/kernels/decode_state_prune.h b/tensorflow/lite/micro/kernels/decode_state_prune.h
new file mode 100644
index 00000000000..f95df725578
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_prune.h
@@ -0,0 +1,70 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/micro/compatibility.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+
+namespace tflite {
+
+class DecodeStatePrune : public DecodeState {
+ public:
+  DecodeStatePrune() = delete;
+
+  DecodeStatePrune(const TfLiteContext* context,
+                   MicroProfilerInterface* profiler)
+      : DecodeState(context, profiler) {}
+
+  virtual TfLiteStatus Setup(const TfLiteTensor& input,
+                             const TfLiteTensor& ancillary,
+                             const TfLiteTensor& output) override;
+  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
+                              const TfLiteEvalTensor& ancillary,
+                              const TfLiteEvalTensor& output) override;
+
+ private:
+  // Prune Decode Common Metadata constants
+  static constexpr size_t kDcmVersionOffset = 4;
+
+ protected:
+  virtual ~DecodeStatePrune() = default;
+
+  template <typename T>
+  void DecompressToBuffer(void* buffer);
+
+  void DecompressToBufferPerChannelInt8(void* buffer);
+
+ protected:
+  const uint8_t* compressed_indices_ = nullptr;
+  size_t count_indices_ = 0;
+  size_t num_channels_ = 1;
+  size_t elements_per_channel_ = 0;    // computed from use_alternate_axis_
+  const void* value_table_ = nullptr;  // original non-pruned values
+  int8_t* zero_points_ = nullptr;      // quantized per-channel zero points
+  int8_t single_zero_point_ = 0;       // single channel zero point
+  bool use_alternate_axis_ = false;    // shape channel axis:
+                                       // false = first, true = last
+
+ private:
+  TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_PRUNE_H_
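Note: one subtlety in the fields above: with the default channel-first axis, each run of elements_per_channel_ consecutive outputs shares one zero point, whereas with use_alternate_axis_ (channel last) the channel cycles on every element. A sketch of just that index math; ChannelOf is a hypothetical helper, not part of the class:

#include <cstddef>

// Returns the quantization channel that flat output index `i` belongs to.
size_t ChannelOf(size_t i, size_t num_channels, size_t elements_per_channel,
                 bool use_alternate_axis) {
  // channel-first [C, ...]: blocks of elements_per_channel share a channel;
  // channel-last [..., C]: the channel advances with every element.
  return use_alternate_axis ? i % num_channels : i / elements_per_channel;
}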
diff --git a/tensorflow/lite/micro/kernels/decode_state_prune_test.cc b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc
new file mode 100644
index 00000000000..8d87e97017b
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_state_prune_test.cc
@@ -0,0 +1,581 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <initializer_list>
+#include <iterator>
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+#include "tensorflow/lite/micro/kernels/decode_test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace {
+
+//
+// Prune test data
+//
+constexpr int8_t kAncillaryDataPrune0[] = {
+    1,  2,  3,  4,  1,   // chan 0
+    2,  3,  4,  1,  2,   // chan 0
+    3,  4,  1,  2,  3,   // chan 0
+    4,  1,  2,  3,  4,   // chan 0
+    11, 12, 13, 14, 11,  // chan 1
+    12, 13, 14, 11, 12,  // chan 1
+    13, 14, 11, 12, 13,  // chan 1
+    14, 11, 12, 13, 14   // chan 1
+};
+constexpr int16_t kAncillaryDataPrune1[] = {
+    5,  6,  7,  8,  5,   // chan 0
+    6,  7,  8,  5,  6,   // chan 0
+    7,  8,  5,  6,  7,   // chan 0
+    8,  5,  6,  7,  8,   // chan 0
+    15, 16, 17, 18, 15,  // chan 1
+    16, 17, 18, 15, 16,  // chan 1
+    17, 18, 15, 16, 17,  // chan 1
+    18, 15, 16, 17, 18   // chan 1
+};
+constexpr float kAncillaryDataPrune2[] = {
+    9.0f,  10.0f, 11.0f, 12.0f,  // encoded byte 0
+    9.0f,  10.0f, 11.0f, 12.0f,  // encoded byte 1
+    9.0f,  10.0f, 11.0f, 12.0f,  // encoded byte 2
+    9.0f,  10.0f, 11.0f, 12.0f,  // encoded byte 3
+    9.0f,  10.0f, 11.0f, 12.0f,  // encoded byte 4
+    19.0f, 20.0f, 21.0f, 22.0f,  // encoded byte 5
+    19.0f, 20.0f, 21.0f, 22.0f,  // encoded byte 6
+    19.0f, 20.0f, 21.0f, 22.0f,  // encoded byte 7
+    19.0f, 20.0f, 21.0f, 22.0f,  // encoded byte 8
+    19.0f, 20.0f, 21.0f, 22.0f   // encoded byte 9
+};
+constexpr int8_t kAncillaryDataPrune3[] = {
+    13,  14,  15,  16,  13,   // chan 0
+    14,  15,  16,  13,  14,   // chan 0
+    15,  16,  13,  14,  15,   // chan 0
+    16,  13,  14,  15,  16,   // chan 0
+    113, 114, 115, 116, 113,  // chan 1
+    114, 115, 116, 113, 114,  // chan 1
+    115, 116, 113, 114, 115,  // chan 1
+    116, 113, 114, 115, 116   // chan 1
+};
+constexpr int8_t kAncillaryDataPrune4[] = {
+    17, 18, 19, 20, 17, 18, 19, 20, 17, 18,  // group 0
+    19, 20, 17, 18, 19, 20, 17, 18, 19, 20,  // group 0
+    21, 22, 23, 24, 21, 22, 23, 24, 21, 22,  // group 1
+    23, 24, 21, 22, 23, 24, 21, 22, 23, 24,  // group 1
+};
+constexpr int8_t kAncillaryDataPrune5[] = {
+    13, 14, 15, 16, 13,  // chan 0
+    14, 15, 16, 13, 14,  // chan 0
+    15, 16, 13, 14, 15,  // chan 0
+    16, 13, 14, 15, 16,  // chan 0
+    23, 24, 25, 26, 23,  // chan 0
+    24, 25, 26, 23, 24,  // chan 0
+    25, 26, 23, 24, 25,  // chan 0
+    26, 23, 24, 25, 26   // chan 0
+};
+
+constexpr uint8_t kDcmPrune[tflite::DecodeState::kDcmSizeInBytes] = {
+    tflite::DecodeState::kDcmTypePrune,  // type: Prune
+    1,                                   // DCM version: 1
+    0,                                   // reserved
+    0,                                   // reserved
+    1,                                   // Prune version: 1
+};
+
+// Align the tensor data the same as a Buffer in the TfLite schema.
+// Use 0x5A in byte 1 to check byte ordering in the low-level code.
+alignas(16) const uint8_t kEncodedPrune[] = {0xA5, 0x5A, 0xA5, 0xA5, 0xA5,
+                                             0xA5, 0xA5, 0xA5, 0xA5, 0xA5};
+
+// Tensor shapes as TfLiteIntArray
+constexpr int kEncodedShapePrune[] = {1, sizeof(kEncodedPrune)};
+constexpr int kOutputShapePrune[] = {4, 2, 5, 8, 1};    // 2 channels
+constexpr int kOutputShapePrune4[] = {4, 1, 2, 1, 40};  // 40 channels, alt-axis
+constexpr int kOutputShapePrune5[] = {4, 1, 8, 10, 1};  // 1 channel
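A reading aid for the expected outputs below: each encoded byte covers eight consecutive output elements, most significant bit first. 0xA5 = 0b10100101 keeps elements 0, 2, 5, 7 of its group; the 0x5A = 0b01011010 in byte 1 keeps elements 1, 3, 4, 6, which is why the second row of every kExpectPrune* array is shifted relative to its neighbors. A compile-time check of that claim (not part of the patch):

static_assert(((0xA5 >> 7) & 1) == 1 && ((0xA5 >> 5) & 1) == 1 &&
                  ((0xA5 >> 2) & 1) == 1 && ((0xA5 >> 0) & 1) == 1,
              "0xA5 keeps elements 0, 2, 5, 7 (MSB first)");
static_assert(((0x5A >> 6) & 1) == 1 && ((0x5A >> 4) & 1) == 1 &&
                  ((0x5A >> 3) & 1) == 1 && ((0x5A >> 1) & 1) == 1,
              "0x5A keeps elements 1, 3, 4, 6 (MSB first)");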
+// Quantization datum as TfLiteIntArray.
+constexpr int kZeroPointsPrune0[] = {2, -128, 0};
+constexpr int kZeroPointsPrune1[] = {2, 0, 0};
+constexpr int kZeroPointsPrune1_Invalid[] = {2, 0, -1};
+constexpr int kZeroPointsPrune3[] = {2, 0, 0};
+constexpr int kZeroPointsPrune4[] = {
+    40,  // size
+    0,   -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9,  -10, -11, -12, -13,
+    -14, -15, -16, -17, -18, -19, 0,   -1,  -2,  -3,  -4,  -5,  -6,  -7,
+    -8,  -9,  -10, -11, -12, -13, -14, -15, -16, -17, -18, -19,
+};
+constexpr int kZeroPointsPrune5[] = {1, -44};
+
+constexpr int8_t kExpectPrune0[] = {
+    1, -128, 2, -128, -128, 3, -128, 4,  // chan 0
+    -128, 1, -128, 2, 3, -128, 4, -128,  // chan 0
+    1, -128, 2, -128, -128, 3, -128, 4,  // chan 0
+    1, -128, 2, -128, -128, 3, -128, 4,  // chan 0
+    1, -128, 2, -128, -128, 3, -128, 4,  // chan 0
+    11, 0, 12, 0, 0, 13, 0, 14,          // chan 1
+    11, 0, 12, 0, 0, 13, 0, 14,          // chan 1
+    11, 0, 12, 0, 0, 13, 0, 14,          // chan 1
+    11, 0, 12, 0, 0, 13, 0, 14,          // chan 1
+    11, 0, 12, 0, 0, 13, 0, 14           // chan 1
+};
+constexpr int16_t kExpectPrune1[] = {
+    5, 0, 6, 0, 0, 7, 0, 8,      // chan 0
+    0, 5, 0, 6, 7, 0, 8, 0,      // chan 0
+    5, 0, 6, 0, 0, 7, 0, 8,      // chan 0
+    5, 0, 6, 0, 0, 7, 0, 8,      // chan 0
+    5, 0, 6, 0, 0, 7, 0, 8,      // chan 0
+    15, 0, 16, 0, 0, 17, 0, 18,  // chan 1
+    15, 0, 16, 0, 0, 17, 0, 18,  // chan 1
+    15, 0, 16, 0, 0, 17, 0, 18,  // chan 1
+    15, 0, 16, 0, 0, 17, 0, 18,  // chan 1
+    15, 0, 16, 0, 0, 17, 0, 18   // chan 1
+};
+constexpr float kExpectPrune2[] = {
+    9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f,   // encode byte 0
+    0.0f, 9.0f, 0.0f, 10.0f, 11.0f, 0.0f, 12.0f, 0.0f,   // encode byte 1
+    9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f,   // encode byte 2
+    9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f,   // encode byte 3
+    9.0f, 0.0f, 10.0f, 0.0f, 0.0f, 11.0f, 0.0f, 12.0f,   // encode byte 4
+    19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f,  // encode byte 5
+    19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f,  // encode byte 6
+    19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f,  // encode byte 7
+    19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f,  // encode byte 8
+    19.0f, 0.0f, 20.0f, 0.0f, 0.0f, 21.0f, 0.0f, 22.0f   // encode byte 9
+};
+constexpr int8_t kExpectPrune3[] = {
+    13, 0, 14, 0, 0, 15, 0, 16,      // chan 0
+    0, 13, 0, 14, 15, 0, 16, 0,      // chan 0
+    13, 0, 14, 0, 0, 15, 0, 16,      // chan 0
+    13, 0, 14, 0, 0, 15, 0, 16,      // chan 0
+    13, 0, 14, 0, 0, 15, 0, 16,      // chan 0
+    113, 0, 114, 0, 0, 115, 0, 116,  // chan 1
+    113, 0, 114, 0, 0, 115, 0, 116,  // chan 1
+    113, 0, 114, 0, 0, 115, 0, 116,  // chan 1
+    113, 0, 114, 0, 0, 115, 0, 116,  // chan 1
+    113, 0, 114, 0, 0, 115, 0, 116   // chan 1
+};
+constexpr int8_t kExpectPrune4[] = {
+    17, -1, 18, -3, -4, 19, -6, 20, -8, 17,       // group 0
+    -10, 18, 19, -13, 20, -15, 17, -17, 18, -19,  // group 0
+    0, 19, -2, 20, 17, -5, 18, -7, -8, 19,        // group 0
+    -10, 20, 17, -13, 18, -15, -16, 19, -18, 20,  // group 0
+    21, -1, 22, -3, -4, 23, -6, 24, 21, -9,       // group 1
+    22, -11, -12, 23, -14, 24, 21, -17, 22, -19,  // group 1
+    0, 23, -2, 24, 21, -5, 22, -7, -8, 23,        // group 1
+    -10, 24, 21, -13, 22, -15, -16, 23, -18, 24   // group 1
+};
+constexpr int8_t kExpectPrune5[] = {
+    13, -44, 14, -44, -44, 15, -44, 16, -44, 13,  // chan 0
+    -44, 14, 15, -44, 16, -44, 13, -44, 14, -44,  // chan 0
+    -44, 15, -44, 16, 13, -44, 14, -44, -44, 15,  // chan 0
+    -44, 16, 13, -44, 14, -44, -44, 15, -44, 16,  // chan 0
+    23, -44, 24, -44, -44, 25, -44, 26, 23, -44,  // chan 0
+    24, -44, -44, 25, -44, 26, 23, -44, 24, -44,  // chan 0
+    -44, 25, -44, 26, 23, -44, 24, -44, -44, 25,  // chan 0
+    -44, 26, 23, -44, 24, -44, -44, 25, -44, 26   // chan 0
+};
+
+}  // namespace
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+using tflite::testing::AncillaryData;
+using tflite::testing::TensorInDatum;
+using tflite::testing::TensorOutDatum;
+
+TF_LITE_MICRO_TEST(DecodePruneFloat) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) float output_data[std::size(kExpectPrune2)] = {};
+  alignas(16) const AncillaryData<float, std::size(kAncillaryDataPrune2)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune2}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  constexpr int kOutputZeroPointsData[] = {0};
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data,
+      *kOutputDims,
+      kTfLiteFloat32,
+      kOutputScales,
+      *kOutputZeroPoints,
+      0,
+      {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune2};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneInt8) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataPrune3)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  constexpr int kOutputZeroPointsData[] = {0};
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kOutputZeroPointsData);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
+      0, {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune3};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedInt8) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectPrune3)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataPrune3)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune3}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune3);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
+      0, {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune3};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedMixedZeroPointInt8) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectPrune0)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataPrune0)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune0}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune0);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
+      0, {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune0};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedSingleChannelInt8) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectPrune5)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataPrune5)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune5}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune5);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune5);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data, *kOutputDims, kTfLiteInt8, kOutputScales, *kOutputZeroPoints,
+      0, {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune5};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedAltAxisInt8) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int8_t output_data[std::size(kExpectPrune4)] = {};
+  alignas(16) const AncillaryData<int8_t, std::size(kAncillaryDataPrune4)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune4}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune4);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune4);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data,
+      *kOutputDims,
+      kTfLiteInt8,
+      kOutputScales,
+      *kOutputZeroPoints,
+      (kOutputDims->size - 1),
+      {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune4};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedInt16) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {};
+  alignas(16) const AncillaryData<int16_t, std::size(kAncillaryDataPrune1)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune1);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data,
+      *kOutputDims,
+      kTfLiteInt16,
+      kOutputScales,
+      *kOutputZeroPoints,
+      0,
+      {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune1};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE());
+}
+
+TF_LITE_MICRO_TEST(DecodePruneQuantizedInvalidZeroPointInt16) {
+  // Align the tensor data the same as a Buffer in the TfLite schema
+  alignas(16) int16_t output_data[std::size(kExpectPrune1)] = {};
+  alignas(16) const AncillaryData<int16_t, std::size(kAncillaryDataPrune1)>
+      kAncillaryData = {{kDcmPrune}, {kAncillaryDataPrune1}};
+
+  const TfLiteIntArray* const kEncodedDims =
+      tflite::testing::IntArrayFromInts(kEncodedShapePrune);
+  static const TensorInDatum kEncodeTID = {
+      kEncodedPrune,
+      *kEncodedDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kEncodes = {
+      &kEncodeTID,
+  };
+
+  constexpr int kAncillaryShape[] = {1, sizeof(kAncillaryData)};
+  const TfLiteIntArray* const kAncillaryDims =
+      tflite::testing::IntArrayFromInts(kAncillaryShape);
+  static const TensorInDatum kAncillaryTID = {
+      &kAncillaryData,
+      *kAncillaryDims,
+  };
+  static constexpr std::initializer_list<const TensorInDatum*> kAncillaries = {
+      &kAncillaryTID};
+
+  const TfLiteIntArray* const kOutputDims =
+      tflite::testing::IntArrayFromInts(kOutputShapePrune);
+  const TfLiteIntArray* const kOutputZeroPoints =
+      tflite::testing::IntArrayFromInts(kZeroPointsPrune1_Invalid);
+  const TfLiteFloatArray kOutputScales = {kOutputZeroPoints->size};
+  static const TensorOutDatum kTOD = {
+      output_data,
+      *kOutputDims,
+      kTfLiteInt16,
+      kOutputScales,
+      *kOutputZeroPoints,
+      0,
+      {},
+  };
+  static constexpr std::initializer_list<const TensorOutDatum*> kOutputs = {
+      &kTOD};
+
+  const std::initializer_list<const void*> kExpected = {kExpectPrune1};
+
+  tflite::testing::TestDecode<2, 1>(
+      kEncodes, kAncillaries, kOutputs, kExpected, tflite::Register_DECODE(),
+      kTfLiteError);
+}
+
+TF_LITE_MICRO_TESTS_END
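With the BUILD and Makefile entries from this patch in place, the new test should be runnable under both build systems. The Bazel label comes directly from the tflm_cc_test rule above; the Make target name assumes TFLM's usual test_kernel_<name> convention for kernel tests:

bazel test //tensorflow/lite/micro/kernels:decode_state_prune_test

make -f tensorflow/lite/micro/tools/make/Makefile test_kernel_decode_state_prune_test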
diff --git a/tensorflow/lite/micro/kernels/decode_test.cc b/tensorflow/lite/micro/kernels/decode_test.cc
index 593b608e2ec..9645193bd68 100644
--- a/tensorflow/lite/micro/kernels/decode_test.cc
+++ b/tensorflow/lite/micro/kernels/decode_test.cc
@@ -14,51 +14,19 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cstdint>
+#include <initializer_list>
 #include <iterator>
-#include <type_traits>
 
 #include "tensorflow/lite/core/c/common.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/kernels/decode_state.h"
-#include "tensorflow/lite/micro/kernels/kernel_runner.h"
-#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/kernels/decode_test_helpers.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 
-namespace tflite {
-namespace testing {
 namespace {
 
-struct TensorInDatum {
-  const void* const data;
-  const TfLiteIntArray& dims;
-};
-
-struct TensorOutDatum {
-  void* const data;
-  const TfLiteIntArray& dims;
-  const TfLiteType type;
-  const TfLiteFloatArray& scales;
-  const TfLiteIntArray& zero_points;
-  const int quantized_dimension;
-
-  // initialized by CreatePerChannelQuantizedTensor
-  const TfLiteAffineQuantization affine_quantization;
-};
-
-template <typename T, size_t N>
-struct AncillaryData {
-  AncillaryData() = delete;
-  AncillaryData(const uint8_t (&dcm)[tflite::DecodeState::kDcmSizeInBytes],
-                const T (&values)[N]) {
-    std::copy(std::begin(dcm), std::end(dcm), std::begin(dcm_));
-    std::copy(std::begin(values), std::end(values), std::begin(value_table_));
-  }
-
- private:
-  uint8_t dcm_[tflite::DecodeState::kDcmSizeInBytes];
-  T value_table_[N > 0 ? N : 1];  // assure not zero length
-};
-
+//
+// LUT test data
+//
 constexpr int kBitWidthLUT = 2;
 
 constexpr int8_t kAncillaryDataLUT0[] = {1, 2, 3, 4};
@@ -98,119 +66,11 @@ constexpr int kEncodedShapeLUT[] = {1, sizeof(kEncodedLUT)};
 constexpr int8_t kExpectLUT0[] = {1, 2, 3, 4, 4, 3, 2, 1};
 constexpr int16_t kExpectLUT1[] = {5, 6, 7, 8, 8, 7, 6, 5};
 
-template <typename T>
-TfLiteStatus CheckOutput(const TfLiteTensor& output,
-                         const void* const expected) {
-  const T* const expected_data = reinterpret_cast<const T*>(expected);
-  const T* const output_data = tflite::GetTensorData<T>(&output);
-
-  constexpr float kTolerance = 1e-5;
-  const size_t kOutputCount = tflite::NumElements(&output);
-  for (size_t i = 0; i < kOutputCount; i++) {
-    TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance);
-    TF_LITE_MICRO_CHECK_FAIL();
-  }
-
-  return kTfLiteOk;
-}
-
-template <size_t kNumInputs, size_t kNumOutputs>
-TfLiteStatus ExecuteDecodeTest(
-    TfLiteTensor* tensors, const TFLMRegistration& registration,
-    const std::initializer_list<const void*>& expected) {
-  int kInputArrayData[kNumInputs + 1] = {kNumInputs};
-  for (size_t i = 0; i < kNumInputs; i++) {
-    kInputArrayData[i + 1] = i;
-  }
-  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
-
-  int kOutputArrayData[kNumOutputs + 1] = {kNumOutputs};
-  for (size_t i = 0; i < kNumOutputs; i++) {
-    kOutputArrayData[i + 1] = i + kNumInputs;
-  }
-  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
-
-  micro::KernelRunner runner(registration, tensors, kNumInputs + kNumOutputs,
-                             inputs_array, outputs_array, nullptr);
-
-  if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
-    return kTfLiteError;
-  }
-
-  const TfLiteTensor* const output_tensors = &tensors[kNumInputs];
-  TfLiteStatus status = kTfLiteError;
-  for (size_t i = 0; i < kNumOutputs; i++) {
-    switch (output_tensors[i].type) {
-      case kTfLiteInt8:
-        status = CheckOutput<int8_t>(output_tensors[i], expected.begin()[i]);
-        break;
-      case kTfLiteInt16:
-        status = CheckOutput<int16_t>(output_tensors[i], expected.begin()[i]);
-        break;
-      default:
-        TF_LITE_MICRO_FAIL("unsupported tensor type in test");
-        break;
-    }
-  }
-
-  return status;
-}
-
-template <size_t kNumInputs, size_t kNumOutputs>
-void TestDecode(const std::initializer_list<const TensorInDatum*>& encodes,
-                const std::initializer_list<const TensorInDatum*>& ancillaries,
-                const std::initializer_list<const TensorOutDatum*>& outputs,
-                const std::initializer_list<const void*>& expected,
-                const TFLMRegistration& registration,
-                const TfLiteStatus expected_status = kTfLiteOk) {
-  TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};
-
-  for (size_t i = 0; i < kNumInputs; i += 2) {
-    const TensorInDatum& tid_encode = *encodes.begin()[i / 2];
-    tensors[i] = CreateTensor(tid_encode.data,
-                              const_cast<TfLiteIntArray*>(&tid_encode.dims),
-                              false, kTfLiteUInt8);
-    const TensorInDatum& tid_ancillary = *ancillaries.begin()[i / 2];
-    tensors[i + 1] = CreateTensor(
-        tid_ancillary.data, const_cast<TfLiteIntArray*>(&tid_ancillary.dims),
-        false, kTfLiteUInt8);
-  }
-  for (size_t i = 0; i < kNumOutputs; i++) {
-    const TensorOutDatum& tod = *outputs.begin()[i];
-    if (tod.scales.size == 0) {
-      tensors[i + kNumInputs] = CreateTensor(
-          tod.data, const_cast<TfLiteIntArray*>(&tod.dims), false, tod.type);
-    } else {
-      tensors[i + kNumInputs] = CreatePerChannelQuantizedTensor(
-          tod.data, const_cast<TfLiteIntArray*>(&tod.dims),
-          const_cast<TfLiteFloatArray*>(&tod.scales),
-          const_cast<TfLiteIntArray*>(&tod.zero_points),
-          const_cast<TfLiteAffineQuantization*>(&tod.affine_quantization),
-          tod.quantized_dimension, false, tod.type);
-    }
-  }
-
-  TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
-      tensors, registration, expected);
-  TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
-}
-
 }  // namespace
-}  // namespace testing
-}  // namespace tflite
 
 TF_LITE_MICRO_TESTS_BEGIN
 
 using tflite::testing::AncillaryData;
-using tflite::testing::kAncillaryDataLUT0;
-using tflite::testing::kAncillaryDataLUT1;
-using tflite::testing::kDcmLUT0;
-using tflite::testing::kDcmLUT1;
-using tflite::testing::kEncodedLUT;
-using tflite::testing::kEncodedShapeLUT;
-using tflite::testing::kExpectLUT0;
-using tflite::testing::kExpectLUT1;
-using tflite::testing::kOutputShapeLUT;
 using tflite::testing::TensorInDatum;
 using tflite::testing::TensorOutDatum;
diff --git a/tensorflow/lite/micro/kernels/decode_test_helpers.h b/tensorflow/lite/micro/kernels/decode_test_helpers.h
new file mode 100644
index 00000000000..96cc27cdecf
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decode_test_helpers.h
@@ -0,0 +1,175 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_TEST_HELPERS_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_TEST_HELPERS_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+
+#include "tensorflow/lite/core/c/common.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/decode_state.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/micro_common.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+struct TensorInDatum {
+  const void* const data;
+  const TfLiteIntArray& dims;
+};
+
+struct TensorOutDatum {
+  void* const data;
+  const TfLiteIntArray& dims;
+  const TfLiteType type;
+  const TfLiteFloatArray& scales;
+  const TfLiteIntArray& zero_points;
+  const int quantized_dimension;
+
+  // initialized by CreatePerChannelQuantizedTensor
+  const TfLiteAffineQuantization affine_quantization;
+};
+
+template <typename T, size_t N>
+struct AncillaryData {
+  AncillaryData() = delete;
+  AncillaryData(const uint8_t (&dcm)[tflite::DecodeState::kDcmSizeInBytes],
+                const T (&values)[N]) {
+    std::copy(std::begin(dcm), std::end(dcm), std::begin(dcm_));
+    std::copy(std::begin(values), std::end(values), std::begin(value_table_));
+  }
+
+ private:
+  uint8_t dcm_[tflite::DecodeState::kDcmSizeInBytes];
+  T value_table_[N > 0 ? N : 1];  // assure not zero length
+};
+
+template <typename T>
+TfLiteStatus CheckOutput(const TfLiteTensor& output,
+                         const void* const expected) {
+  const T* const expected_data = reinterpret_cast<const T*>(expected);
+  const T* const output_data = tflite::GetTensorData<T>(&output);
+
+  constexpr float kTolerance = 1e-5;
+  const size_t kOutputCount = tflite::NumElements(&output);
+  for (size_t i = 0; i < kOutputCount; i++) {
+    TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
+}
+
+template <size_t kNumInputs, size_t kNumOutputs>
+TfLiteStatus ExecuteDecodeTest(
+    TfLiteTensor* tensors, const TFLMRegistration& registration,
+    const std::initializer_list<const void*>& expected) {
+  int kInputArrayData[kNumInputs + 1] = {kNumInputs};
+  for (size_t i = 0; i < kNumInputs; i++) {
+    kInputArrayData[i + 1] = i;
+  }
+  TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
+
+  int kOutputArrayData[kNumOutputs + 1] = {kNumOutputs};
+  for (size_t i = 0; i < kNumOutputs; i++) {
+    kOutputArrayData[i + 1] = i + kNumInputs;
+  }
+  TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
+
+  micro::KernelRunner runner(registration, tensors, kNumInputs + kNumOutputs,
+                             inputs_array, outputs_array, nullptr);
+
+  if (runner.InitAndPrepare() != kTfLiteOk || runner.Invoke() != kTfLiteOk) {
+    return kTfLiteError;
+  }
+
+  const TfLiteTensor* const output_tensors = &tensors[kNumInputs];
+  TfLiteStatus status = kTfLiteError;
+  for (size_t i = 0; i < kNumOutputs; i++) {
+    switch (output_tensors[i].type) {
+      case kTfLiteInt8:
+        status = CheckOutput<int8_t>(output_tensors[i], expected.begin()[i]);
+        break;
+      case kTfLiteInt16:
+        status = CheckOutput<int16_t>(output_tensors[i], expected.begin()[i]);
+        break;
+      case kTfLiteFloat32:
+        status = CheckOutput<float>(output_tensors[i], expected.begin()[i]);
+        break;
+      default:
+        TF_LITE_MICRO_FAIL("unsupported tensor type in test");
+        break;
+    }
+  }
+
+  return status;
+}
+
+template <size_t kNumInputs, size_t kNumOutputs>
+void TestDecode(const std::initializer_list<const TensorInDatum*>& encodes,
+                const std::initializer_list<const TensorInDatum*>& ancillaries,
+                const std::initializer_list<const TensorOutDatum*>& outputs,
+                const std::initializer_list<const void*>& expected,
+                const TFLMRegistration& registration,
+                const TfLiteStatus expected_status = kTfLiteOk) {
+  TfLiteTensor tensors[kNumInputs + kNumOutputs] = {};
+
+  for (size_t i = 0; i < kNumInputs; i += 2) {
+    const TensorInDatum& tid_encode = *encodes.begin()[i / 2];
+    tensors[i] = CreateTensor(tid_encode.data,
+                              const_cast<TfLiteIntArray*>(&tid_encode.dims),
+                              false, kTfLiteUInt8);
+    // must be a const tensor
+    tensors[i].allocation_type = kTfLiteMmapRo;
+    const TensorInDatum& tid_ancillary = *ancillaries.begin()[i / 2];
+    tensors[i + 1] = CreateTensor(
+        tid_ancillary.data, const_cast<TfLiteIntArray*>(&tid_ancillary.dims),
+        false, kTfLiteUInt8);
+    // must be a const tensor
+    tensors[i + 1].allocation_type = kTfLiteMmapRo;
+  }
+  for (size_t i = 0; i < kNumOutputs; i++) {
+    const TensorOutDatum& tod = *outputs.begin()[i];
+    if (tod.scales.size == 0) {
+      tensors[i + kNumInputs] = CreateTensor(
+          tod.data, const_cast<TfLiteIntArray*>(&tod.dims), false, tod.type);
+    } else {
+      tensors[i + kNumInputs] = CreatePerChannelQuantizedTensor(
+          tod.data, const_cast<TfLiteIntArray*>(&tod.dims),
+          const_cast<TfLiteFloatArray*>(&tod.scales),
+          const_cast<TfLiteIntArray*>(&tod.zero_points),
+          const_cast<TfLiteAffineQuantization*>(&tod.affine_quantization),
+          tod.quantized_dimension, false, tod.type);
+    }
+  }
+
+  TfLiteStatus s = ExecuteDecodeTest<kNumInputs, kNumOutputs>(
+      tensors, registration, expected);
+  TF_LITE_MICRO_EXPECT_EQ(s, expected_status);
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_TEST_HELPERS_H_
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 46621194601..c3e1bbab3bf 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -389,6 +389,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_lut.cc \
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_common.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space.cc \