Skip to content

Commit 0e0a3b9

Browse files
committed
Sync from upstream TF.
1 parent 70955d0 commit 0e0a3b9

File tree

3 files changed

+111
-22
lines changed

3 files changed

+111
-22
lines changed

tensorflow/lite/kernels/internal/portable_tensor_utils.cc

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
1919

2020
#include <algorithm>
21+
#include <cassert>
2122
#include <cmath>
2223
#include <cstdint>
2324

@@ -92,23 +93,90 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
9293
}
9394
}
9495

95-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
96-
int8_t* dst_buffer) {
97-
// num_elements means the number of elements regardless of packed or unpacked.
98-
// For example, 3 elements means both
99-
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
100-
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
101-
// 2) Unpacked: 3 int8's = 3 bytes.
102-
// stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
103-
for (int i = 0; i < num_elements - 1; i += 2) {
104-
dst_buffer[i / 2] = src_buffer[i] & 0x0F;
105-
dst_buffer[i / 2] |= src_buffer[i + 1] << 4;
96+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
97+
int bit_width, int8_t* dst_buffer) {
98+
assert(bit_width == 2 || bit_width == 4);
99+
if (bit_width == 4) {
100+
// num_elements means the number of elements regardless of packed or
101+
// unpacked. For example, 3 elements means both
102+
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
103+
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
104+
// 2) Unpacked: 3 int8's = 3 bytes.
105+
// stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
106+
for (int i = 0; i < num_elements / 2; i++) {
107+
int8_t byte = src_buffer[i];
108+
// Shift left first so that sign is properly extended when shifted right
109+
int8_t lower = static_cast<int8_t>(byte << 4) >> 4;
110+
int8_t higher = byte >> 4;
111+
dst_buffer[2 * i] = lower;
112+
dst_buffer[2 * i + 1] = higher;
113+
}
114+
115+
// If the buffer size is odd, extract the final lower nibble.
116+
if (num_elements % 2 != 0) {
117+
dst_buffer[num_elements - 1] =
118+
static_cast<int8_t>(src_buffer[num_elements / 2] << 4) >> 4;
119+
}
120+
} else if (bit_width == 2) {
121+
for (int i = 0; i < num_elements / 4; i++) {
122+
int8_t byte = src_buffer[i];
123+
// Shift left first so that sign is properly extended when shifted right
124+
int8_t val1 = static_cast<int8_t>(byte << 6) >> 6;
125+
int8_t val2 = static_cast<int8_t>((byte << 4) & 0xFF) >> 6;
126+
int8_t val3 = static_cast<int8_t>((byte << 2) & 0xFF) >> 6;
127+
int8_t val4 = byte >> 6;
128+
dst_buffer[4 * i] = val1;
129+
dst_buffer[4 * i + 1] = val2;
130+
dst_buffer[4 * i + 2] = val3;
131+
dst_buffer[4 * i + 3] = val4;
132+
}
133+
134+
// Handle the remaining elements.
135+
int remaining_elements = num_elements % 4;
136+
if (remaining_elements > 0) {
137+
int8_t byte = src_buffer[num_elements / 4];
138+
for (int i = 0; i < remaining_elements; i++) {
139+
dst_buffer[num_elements - remaining_elements + i] =
140+
static_cast<int8_t>((byte << (6 - 2 * i)) & 0xFF) >> 6;
141+
}
142+
}
106143
}
107-
auto packed_size = (num_elements + 1) / 2;
144+
}
108145

109-
// Copy the final nibble if the buffer is odd-lengthed
110-
if (num_elements % 2 != 0) {
111-
dst_buffer[packed_size - 1] = src_buffer[num_elements - 1] & 0x0F;
146+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
147+
int bit_width, int8_t* dst_buffer) {
148+
assert(bit_width == 2 || bit_width == 4);
149+
if (bit_width == 4) {
150+
// num_elements means the number of elements regardless of packed or
151+
// unpacked. For example, 3 elements means both
152+
// 1) Unpacked: 3 int8's = 3 bytes.
153+
// stored in src_buffer[0], src_buffer[1] and src_buffer[2] (j = 0..2)
154+
// 2) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
155+
// stored in dst_buffer[0] and dst_buffer[1] (i = 0..1)
156+
for (int i = 0; i < num_elements / 2; ++i) {
157+
dst_buffer[i] = (src_buffer[2 * i] & 0x0F) | (src_buffer[2 * i + 1] << 4);
158+
}
159+
// If the buffer size is odd, pack the final nibble.
160+
if (num_elements % 2 != 0) {
161+
dst_buffer[num_elements / 2] = src_buffer[num_elements - 1] & 0x0F;
162+
}
163+
} else if (bit_width == 2) {
164+
for (int i = 0; i < num_elements / 4; ++i) {
165+
dst_buffer[i] = (src_buffer[4 * i] & 0x03) |
166+
((src_buffer[4 * i + 1] & 0x03) << 2) |
167+
((src_buffer[4 * i + 2] & 0x03) << 4) |
168+
((src_buffer[4 * i + 3] & 0x03) << 6);
169+
}
170+
// Handle the remaining elements.
171+
int remaining_elements = num_elements % 4;
172+
if (remaining_elements > 0) {
173+
int8_t packed_val = 0;
174+
for (int i = 0; i < remaining_elements; ++i) {
175+
packed_val |= (src_buffer[num_elements - remaining_elements + i] & 0x03)
176+
<< (i * 2);
177+
}
178+
dst_buffer[num_elements / 4] = packed_val;
179+
}
112180
}
113181
}
114182

tensorflow/lite/kernels/internal/portable_tensor_utils.h

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,20 +618,41 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
618618
void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
619619
int8_t* dst_buffer);
620620

621-
// Pack `src_buffer` into a densely packed buffer of int4 values.
621+
// Unpack or inflate `src_buffer` by splitting each byte into multiple
622+
// sign-extended int8 elements in `dst_buffer`. Supports 2-bit and 4-bit ints.
622623
// Parameters:
623-
// src_buffer : Buffer containing int4 values stored in int8 memory.
624+
// src_buffer : Densely packed buffer containing int2 or int4 values.
625+
// num_elements : Number of unpacked elements to be read from the buffer.
626+
// This should be equal to the size of `dst_buffer`.
627+
// bit_width : The bit width of the packed elements (either 2 or 4).
628+
// dst_buffer : Buffer to unpack into. Should be allocated by the caller.
629+
// Size should be at least `num_elements`.
630+
// Notes:
631+
// For 4-bit unpacking: e.g., `src_buffer = {0x12, 0x34};` (num_elements = 4)
632+
// will return `dst_buffer = {0x02, 0x01, 0x04, 0x03}`.
633+
// For 2-bit unpacking: e.g., `src_buffer = {0x12};` (num_elements = 4)
634+
// will return `dst_buffer = {0xFE, 0x00, 0x01, 0x00}`, i.e. {-2, 0, 1, 0}
// (each 2-bit field is sign extended, so 0b10 becomes -2).
635+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
636+
int bit_width, int8_t* dst_buffer);
637+
638+
// Pack `src_buffer` into a densely packed buffer of int2 or int4 values.
639+
// Parameters:
640+
// src_buffer : Buffer containing int2 or int4 values stored in int8
641+
// memory.
624642
// num_elements : Number of unpacked elements to read from `src_buffer` (one
625643
// per int8). It need not be a multiple of the number of
626644
// elements per packed byte; any unused high bits of the
627645
// last byte in `dst_buffer` are set to zero.
646+
// bit_width : The bit width of the packed elements (either 2 or 4).
628647
// dst_buffer : Buffer to pack into. Should be allocated by the caller.
629648
// Size should be at least `ceil(num_elements * bit_width / 8)` bytes.
630649
// Notes:
631-
// For example, given `src_buffer = {0x02, 0x01, 0x04, 0x03}`, calling this
632-
// function will return `dst_buffer = {0x12, 0x34}`.
633-
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
634-
int8_t* dst_buffer);
650+
// For 4-bit packing: e.g., given `src_buffer = {0x02, 0x01, 0x04, 0x03}`,
651+
// calling this function will return `dst_buffer = {0x12, 0x34}`.
652+
// For 2-bit packing: e.g., given `src_buffer = {0x00, 0x01, 0x00, 0x02}`,
653+
// calling this function will return `dst_buffer = {0x84}`.
654+
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
655+
int bit_width, int8_t* dst_buffer);
635656
} // namespace tensor_utils
636657

637658
} // namespace tflite

tensorflow/lite/tools/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
3434
else:
3535
# This file is part of tflite_runtime package.
36-
from tflite_runtime import schema_py_generated as schema_fb
36+
from tflite_micro.tensorflow.lite_runtime import schema_py_generated as schema_fb
3737

3838
# A CSS description for making the visualizer
3939
_CSS = """

0 commit comments

Comments
 (0)