@@ -18,6 +18,7 @@ limitations under the License.
1818#include " tensorflow/lite/kernels/internal/portable_tensor_utils.h"
1919
2020#include < algorithm>
21+ #include < cassert>
2122#include < cmath>
2223#include < cstdint>
2324
@@ -92,6 +93,56 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
9293 }
9394}
9495
96+ void UnpackPackedIntToInt8 (const int8_t * src_buffer, int num_elements,
97+ int bit_width, int8_t * dst_buffer) {
98+ assert (bit_width == 2 || bit_width == 4 );
99+ if (bit_width == 4 ) {
100+ // num_elements means the number of elements regardless of packed or
101+ // unpacked. For example, 3 elements means both
102+ // 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
103+ // stored in src_buffer[0] and src_buffer[1] (i = 0..1)
104+ // 2) Unpacked: 3 int8's = 3 bytes.
105+ // . stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
106+ for (int i = 0 ; i < num_elements / 2 ; i++) {
107+ int8_t byte = src_buffer[i];
108+ // Shift left first so that sign is properly extended when shifted right
109+ int8_t lower = static_cast <int8_t >(byte << 4 ) >> 4 ;
110+ int8_t higher = byte >> 4 ;
111+ dst_buffer[2 * i] = lower;
112+ dst_buffer[2 * i + 1 ] = higher;
113+ }
114+
115+ // If the buffer size is odd, extract the final lower nibble.
116+ if (num_elements % 2 != 0 ) {
117+ dst_buffer[num_elements - 1 ] =
118+ static_cast <int8_t >(src_buffer[num_elements / 2 ] << 4 ) >> 4 ;
119+ }
120+ } else if (bit_width == 2 ) {
121+ for (int i = 0 ; i < num_elements / 4 ; i++) {
122+ int8_t byte = src_buffer[i];
123+ // Shift left first so that sign is properly extended when shifted right
124+ int8_t val1 = static_cast <int8_t >(byte << 6 ) >> 6 ;
125+ int8_t val2 = static_cast <int8_t >((byte << 4 ) & 0xFF ) >> 6 ;
126+ int8_t val3 = static_cast <int8_t >((byte << 2 ) & 0xFF ) >> 6 ;
127+ int8_t val4 = byte >> 6 ;
128+ dst_buffer[4 * i] = val1;
129+ dst_buffer[4 * i + 1 ] = val2;
130+ dst_buffer[4 * i + 2 ] = val3;
131+ dst_buffer[4 * i + 3 ] = val4;
132+ }
133+
134+ // Handle the remaining elements.
135+ int remaining_elements = num_elements % 4 ;
136+ if (remaining_elements > 0 ) {
137+ int8_t byte = src_buffer[num_elements / 4 ];
138+ for (int i = 0 ; i < remaining_elements; i++) {
139+ dst_buffer[num_elements - remaining_elements + i] =
140+ static_cast <int8_t >((byte << (6 - 2 * i)) & 0xFF ) >> 6 ;
141+ }
142+ }
143+ }
144+ }
145+
95146void PackInt8IntoDenseInt4 (const int8_t * src_buffer, int num_elements,
96147 int8_t * dst_buffer) {
97148 // num_elements means the number of elements regardless of packed or unpacked.
0 commit comments