Skip to content

Commit 5227dba

Browse files
committed
Sync from upstream TF.
1 parent 70955d0 commit 5227dba

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

tensorflow/lite/kernels/internal/portable_tensor_utils.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
1919

2020
#include <algorithm>
21+
#include <cassert>
2122
#include <cmath>
2223
#include <cstdint>
2324

@@ -92,6 +93,56 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
9293
}
9394
}
9495

96+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
97+
int bit_width, int8_t* dst_buffer) {
98+
assert(bit_width == 2 || bit_width == 4);
99+
if (bit_width == 4) {
100+
// num_elements means the number of elements regardless of packed or
101+
// unpacked. For example, 3 elements means both
102+
// 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
103+
// stored in src_buffer[0] and src_buffer[1] (i = 0..1)
104+
// 2) Unpacked: 3 int8's = 3 bytes.
105+
//. stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
106+
for (int i = 0; i < num_elements / 2; i++) {
107+
int8_t byte = src_buffer[i];
108+
// Shift left first so that sign is properly extended when shifted right
109+
int8_t lower = static_cast<int8_t>(byte << 4) >> 4;
110+
int8_t higher = byte >> 4;
111+
dst_buffer[2 * i] = lower;
112+
dst_buffer[2 * i + 1] = higher;
113+
}
114+
115+
// If the buffer size is odd, extract the final lower nibble.
116+
if (num_elements % 2 != 0) {
117+
dst_buffer[num_elements - 1] =
118+
static_cast<int8_t>(src_buffer[num_elements / 2] << 4) >> 4;
119+
}
120+
} else if (bit_width == 2) {
121+
for (int i = 0; i < num_elements / 4; i++) {
122+
int8_t byte = src_buffer[i];
123+
// Shift left first so that sign is properly extended when shifted right
124+
int8_t val1 = static_cast<int8_t>(byte << 6) >> 6;
125+
int8_t val2 = static_cast<int8_t>((byte << 4) & 0xFF) >> 6;
126+
int8_t val3 = static_cast<int8_t>((byte << 2) & 0xFF) >> 6;
127+
int8_t val4 = byte >> 6;
128+
dst_buffer[4 * i] = val1;
129+
dst_buffer[4 * i + 1] = val2;
130+
dst_buffer[4 * i + 2] = val3;
131+
dst_buffer[4 * i + 3] = val4;
132+
}
133+
134+
// Handle the remaining elements.
135+
int remaining_elements = num_elements % 4;
136+
if (remaining_elements > 0) {
137+
int8_t byte = src_buffer[num_elements / 4];
138+
for (int i = 0; i < remaining_elements; i++) {
139+
dst_buffer[num_elements - remaining_elements + i] =
140+
static_cast<int8_t>((byte << (6 - 2 * i)) & 0xFF) >> 6;
141+
}
142+
}
143+
}
144+
}
145+
95146
void PackInt8IntoDenseInt4(const int8_t* src_buffer, int num_elements,
96147
int8_t* dst_buffer) {
97148
// num_elements means the number of elements regardless of packed or unpacked.

tensorflow/lite/kernels/internal/portable_tensor_utils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,23 @@ void ApplySignbitToVector(const float* __restrict__ vector, int v_size,
618618
void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
619619
int8_t* dst_buffer);
620620

621+
// Unpack or inflate `src_buffer` by taking each byte and splitting it into
622+
// multiple elements into `dst_buffer`. Supports 2-bit and 4-bit packed integers
623+
// Parameters:
624+
// src_buffer : Densely packed buffer containing int2 or int4 values.
625+
// num_elements : Number of unpacked elements to be read from the buffer.
626+
// This should be equal to the size of `dst_buffer`.
627+
// bit_width : The bit width of the packed elements (either 2 or 4).
628+
// dst_buffer : Buffer to unpack into. Should be allocated by the caller.
629+
// Size should be at least `num_elements`.
630+
// Notes:
631+
// For 4-bit unpacking: e.g., `src_buffer = {0x12, 0x34};` (num_elements = 4)
632+
// will return `dst_buffer = {0x02, 0x01, 0x04, 0x03}`.
633+
// For 2-bit unpacking: e.g., `src_buffer = {0x12};` (num_elements = 4)
634+
// will return `dst_buffer = {0x02, 0x00, 0x01, 0x00}` (sign extended).
635+
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
636+
int bit_width, int8_t* dst_buffer);
637+
621638
// Pack `src_buffer` into a densely packed buffer of int4 values.
622639
// Parameters:
623640
// src_buffer : Buffer containing int4 values stored in int8 memory.

tensorflow/lite/tools/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
3434
else:
3535
# This file is part of tflite_runtime package.
36-
from tflite_runtime import schema_py_generated as schema_fb
36+
from tflite_micro.tensorflow.lite_runtime import schema_py_generated as schema_fb
3737

3838
# A CSS description for making the visualizer
3939
_CSS = """

0 commit comments

Comments
 (0)