Skip to content

Commit 37fb41f

Browse files
TFLM-botveblush
andauthored
Automated sync from github.com/tensorflow/tensorflow (#3214)
* Sync from upstream TF. * Manual change from cl/819821231 * Fix * Copyright --------- Co-authored-by: Esun Kim <[email protected]>
1 parent 7e1eaa0 commit 37fb41f

File tree

13 files changed

+142
-32
lines changed

13 files changed

+142
-32
lines changed

python/tflite_micro/numpy_utils.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
5858
case kTfLiteInt4:
5959
// TODO(b/246806634): NPY_INT4 currently doesn't exist
6060
return NPY_BYTE;
61+
case kTfLiteInt2:
62+
// TODO(b/246806634): NPY_INT2 currently doesn't exist
63+
return NPY_BYTE;
6164
case kTfLiteInt8:
6265
return NPY_INT8;
6366
case kTfLiteInt64:

tensorflow/compiler/mlir/lite/core/c/tflite_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ typedef enum {
6464
kTfLiteUInt16 = 17,
6565
kTfLiteInt4 = 18,
6666
kTfLiteBFloat16 = 19,
67+
kTfLiteInt2 = 20,
6768
} TfLiteType;
6869
// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType)
6970

tensorflow/compiler/mlir/lite/schema/schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ enum TensorType : byte {
5959
UINT16 = 16,
6060
INT4 = 17,
6161
BFLOAT16 = 18,
62+
INT2 = 19,
6263
}
6364

6465
// Custom quantization parameters for experimenting with new quantization

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
10881088
case TensorType_INT4:
10891089
*type = kTfLiteInt4;
10901090
return kTfLiteOk;
1091+
case TensorType_INT2:
1092+
*type = kTfLiteInt2;
1093+
return kTfLiteOk;
10911094
default:
10921095
*type = kTfLiteNoType;
10931096
TF_LITE_REPORT_ERROR(error_reporter,

tensorflow/lite/core/c/common.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ const char* TfLiteTypeGetName(TfLiteType type) {
509509
return "VARIANT";
510510
case kTfLiteInt4:
511511
return "INT4";
512+
case kTfLiteInt2:
513+
return "INT2";
512514
}
513515
return "Unknown type";
514516
}

tensorflow/lite/kernels/internal/portable_tensor_utils.cc

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
1919

2020
#include <algorithm>
21+
#include <cassert>
2122
#include <cmath>
2223
#include <cstdint>
2324

@@ -92,23 +93,90 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
9293
}
9394
}
9495

95-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
96-
int8_t* dst_buffer) {
97-
// num_elements means the number of elements regardless of packed or unpacked.
98-
// For example, 3 elements means both
99-
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
100-
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
101-
// 2) Unpacked: 3 int8's = 3 bytes.
102-
// stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
103-
for (int i = 0; i < num_elements - 1; i += 2) {
104-
dst_buffer[i / 2] = src_buffer[i] & 0x0F;
105-
dst_buffer[i / 2] |= src_buffer[i + 1] << 4;
96+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
97+
int bit_width, int8_t* dst_buffer) {
98+
assert(bit_width == 2 || bit_width == 4);
99+
if (bit_width == 4) {
100+
// num_elements means the number of elements regardless of packed or
101+
// unpacked. For example, 3 elements means both
102+
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
103+
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
104+
// 2) Unpacked: 3 int8's = 3 bytes.
105+
//. stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
106+
for (int i = 0; i < num_elements / 2; i++) {
107+
int8_t byte = src_buffer[i];
108+
// Shift left first so that sign is properly extended when shifted right
109+
int8_t lower = static_cast<int8_t>(byte << 4) >> 4;
110+
int8_t higher = byte >> 4;
111+
dst_buffer[2 * i] = lower;
112+
dst_buffer[2 * i + 1] = higher;
113+
}
114+
115+
// If the buffer size is odd, extract the final lower nibble.
116+
if (num_elements % 2 != 0) {
117+
dst_buffer[num_elements - 1] =
118+
static_cast<int8_t>(src_buffer[num_elements / 2] << 4) >> 4;
119+
}
120+
} else if (bit_width == 2) {
121+
for (int i = 0; i < num_elements / 4; i++) {
122+
int8_t byte = src_buffer[i];
123+
// Shift left first so that sign is properly extended when shifted right
124+
int8_t val1 = static_cast<int8_t>(byte << 6) >> 6;
125+
int8_t val2 = static_cast<int8_t>((byte << 4) & 0xFF) >> 6;
126+
int8_t val3 = static_cast<int8_t>((byte << 2) & 0xFF) >> 6;
127+
int8_t val4 = byte >> 6;
128+
dst_buffer[4 * i] = val1;
129+
dst_buffer[4 * i + 1] = val2;
130+
dst_buffer[4 * i + 2] = val3;
131+
dst_buffer[4 * i + 3] = val4;
132+
}
133+
134+
// Handle the remaining elements.
135+
int remaining_elements = num_elements % 4;
136+
if (remaining_elements > 0) {
137+
int8_t byte = src_buffer[num_elements / 4];
138+
for (int i = 0; i < remaining_elements; i++) {
139+
dst_buffer[num_elements - remaining_elements + i] =
140+
static_cast<int8_t>((byte << (6 - 2 * i)) & 0xFF) >> 6;
141+
}
142+
}
106143
}
107-
auto packed_size = (num_elements + 1) / 2;
144+
}
108145

109-
// Copy the final nibble if the buffer is odd-lengthed
110-
if (num_elements % 2 != 0) {
111-
dst_buffer[packed_size - 1] = src_buffer[num_elements - 1] & 0x0F;
146+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
147+
int bit_width, int8_t* dst_buffer) {
148+
assert(bit_width == 2 || bit_width == 4);
149+
if (bit_width == 4) {
150+
// num_elements means the number of elements regardless of packed or
151+
// unpacked. For example, 3 elements means both
152+
// 1) Unpacked: 3 int8's = 3 bytes.
153+
// stored in src_buffer[0], src_buffer[1] and src_buffer[2] (j = 0..2)
154+
// 2) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
155+
// stored in dst_buffer[0] and dst_buffer[1] (i = 0..1)
156+
for (int i = 0; i < num_elements / 2; ++i) {
157+
dst_buffer[i] = (src_buffer[2 * i] & 0x0F) | (src_buffer[2 * i + 1] << 4);
158+
}
159+
// If the buffer size is odd, pack the final nibble.
160+
if (num_elements % 2 != 0) {
161+
dst_buffer[num_elements / 2] = src_buffer[num_elements - 1] & 0x0F;
162+
}
163+
} else if (bit_width == 2) {
164+
for (int i = 0; i < num_elements / 4; ++i) {
165+
dst_buffer[i] = (src_buffer[4 * i] & 0x03) |
166+
((src_buffer[4 * i + 1] & 0x03) << 2) |
167+
((src_buffer[4 * i + 2] & 0x03) << 4) |
168+
((src_buffer[4 * i + 3] & 0x03) << 6);
169+
}
170+
// Handle the remaining elements.
171+
int remaining_elements = num_elements % 4;
172+
if (remaining_elements > 0) {
173+
int8_t packed_val = 0;
174+
for (int i = 0; i < remaining_elements; ++i) {
175+
packed_val |= (src_buffer[num_elements - remaining_elements + i] & 0x03)
176+
<< (i * 2);
177+
}
178+
dst_buffer[num_elements / 4] = packed_val;
179+
}
112180
}
113181
}
114182

tensorflow/lite/kernels/internal/portable_tensor_utils.h

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,20 +618,41 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
618618
void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
619619
int8_t* dst_buffer);
620620

621-
// Pack `src_buffer` into a densely packed buffer of int4 values.
621+
// Unpack or inflate `src_buffer` by taking each byte and splitting it into
622+
// multiple elements into `dst_buffer`. Supports 2-bit and 4-bit packed integers
622623
// Parameters:
623-
// src_buffer : Buffer containing int4 values stored in int8 memory.
624+
// src_buffer : Densely packed buffer containing int2 or int4 values.
625+
// num_elements : Number of unpacked elements to be read from the buffer.
626+
// This should be equal to the size of `dst_buffer`.
627+
// bit_width : The bit width of the packed elements (either 2 or 4).
628+
// dst_buffer : Buffer to unpack into. Should be allocated by the caller.
629+
// Size should be at least `num_elements`.
630+
// Notes:
631+
// For 4-bit unpacking: e.g., `src_buffer = {0x12, 0x34};` (num_elements = 4)
632+
// will return `dst_buffer = {0x02, 0x01, 0x04, 0x03}`.
633+
// For 2-bit unpacking: e.g., `src_buffer = {0x12};` (num_elements = 4)
634+
// will return `dst_buffer = {0x02, 0x00, 0x01, 0x00}` (sign extended).
635+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
636+
int bit_width, int8_t* dst_buffer);
637+
638+
// Pack `src_buffer` into a densely packed buffer of int2 or int4 values.
639+
// Parameters:
640+
// src_buffer : Buffer containing int2 or int4 values stored in int8
641+
// memory.
624642
// num_elements : Number of elements stored in the buffer. Note that this can
625643
// be smaller than the size of `src_buffer` by 1 if it's odd,
626644
// in which case the last nibble in `src_buffer` is ignored.
627645
// This should be equal to the size of `dst_buffer`.
646+
// bit_width : The bit width of the packed elements (either 2 or 4).
628647
// dst_buffer : Buffer to pack into. Should be allocated by the caller.
629648
// Size should be at least `num_elements`.
630649
// Notes:
631-
// For example, given `src_buffer = {0x02, 0x01, 0x04, 0x03}`, calling this
632-
// function will return `dst_buffer = {0x12, 0x34}`.
633-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
634-
int8_t* dst_buffer);
650+
// For 4-bit packing: e.g., given `src_buffer = {0x02, 0x01, 0x04, 0x03}`,
651+
// calling this function will return `dst_buffer = {0x12, 0x34}`.
652+
// For 2-bit packing: e.g., given `src_buffer = {0x00, 0x01, 0x00, 0x02}`,
653+
// calling this function will return `dst_buffer = {0x84}`.
654+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
655+
int bit_width, int8_t* dst_buffer);
635656
} // namespace tensor_utils
636657

637658
} // namespace tflite

tensorflow/lite/micro/tools/layer_by_layer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ TfLiteStatus ConvertTensorType(TfLiteType type, TensorTypes& tensor_type) {
120120
case kTfLiteInt4:
121121
tensor_type = TensorTypes_INT4;
122122
return kTfLiteOk;
123+
case kTfLiteInt2:
124+
tensor_type = TensorTypes_INT2;
125+
return kTfLiteOk;
123126
case kTfLiteNoType:
124127
MicroPrintf("Unsupported data type %d in tensor\n", tensor_type);
125128
return kTfLiteError;

tensorflow/lite/micro/tools/layer_by_layer_schema.fbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum TensorTypes : byte {
3535
UINT16 = 16,
3636
INT4 = 17,
3737
BFLOAT16 = 18,
38+
INT2 = 19,
3839
}
3940

4041
table TensorData {

tensorflow/lite/micro/tools/layer_by_layer_schema_generated.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ enum TensorTypes : int8_t {
5959
TensorTypes_UINT16 = 16,
6060
TensorTypes_INT4 = 17,
6161
TensorTypes_BFLOAT16 = 18,
62+
TensorTypes_INT2 = 19,
6263
TensorTypes_MIN = TensorTypes_FLOAT32,
63-
TensorTypes_MAX = TensorTypes_BFLOAT16
64+
TensorTypes_MAX = TensorTypes_INT2
6465
};
6566

66-
inline const TensorTypes (&EnumValuesTensorTypes())[19] {
67+
inline const TensorTypes (&EnumValuesTensorTypes())[20] {
6768
static const TensorTypes values[] = {
6869
TensorTypes_FLOAT32,
6970
TensorTypes_FLOAT16,
@@ -83,13 +84,14 @@ inline const TensorTypes (&EnumValuesTensorTypes())[19] {
8384
TensorTypes_UINT32,
8485
TensorTypes_UINT16,
8586
TensorTypes_INT4,
86-
TensorTypes_BFLOAT16
87+
TensorTypes_BFLOAT16,
88+
TensorTypes_INT2
8789
};
8890
return values;
8991
}
9092

9193
inline const char * const *EnumNamesTensorTypes() {
92-
static const char * const names[20] = {
94+
static const char * const names[21] = {
9395
"FLOAT32",
9496
"FLOAT16",
9597
"INT32",
@@ -109,13 +111,14 @@ inline const char * const *EnumNamesTensorTypes() {
109111
"UINT16",
110112
"INT4",
111113
"BFLOAT16",
114+
"INT2",
112115
nullptr
113116
};
114117
return names;
115118
}
116119

117120
inline const char *EnumNameTensorTypes(TensorTypes e) {
118-
if (::flatbuffers::IsOutRange(e, TensorTypes_FLOAT32, TensorTypes_BFLOAT16)) return "";
121+
if (::flatbuffers::IsOutRange(e, TensorTypes_FLOAT32, TensorTypes_INT2)) return "";
119122
const size_t index = static_cast<size_t>(e);
120123
return EnumNamesTensorTypes()[index];
121124
}

0 commit comments

Comments
 (0)