@@ -18,6 +18,7 @@ limitations under the License.
1818#include " tensorflow/lite/kernels/internal/portable_tensor_utils.h"
1919
2020#include < algorithm>
21+ #include < cassert>
2122#include < cmath>
2223#include < cstdint>
2324
@@ -92,23 +93,90 @@ void UnpackDenseInt4IntoInt8(const int8_t* src_buffer, int num_elements,
9293 }
9394}
9495
95- void PackInt8IntoDenseInt4 (const int8_t * src_buffer, int num_elements,
96- int8_t * dst_buffer) {
97- // num_elements means the number of elements regardless of packed or unpacked.
98- // For example, 3 elements means both
99- // 1) Packed: 3 int4's = 12 bit -> 16 bits (padded) = 2 bytes.
100- // stored in src_buffer[0] and src_buffer[1] (i = 0..1)
101- // 2) Unpacked: 3 int8's = 3 bytes.
102- // stored in dst_buffer[0], dst_buffer[1] and dst_buffer[2] (j = 0..2)
103- for (int i = 0 ; i < num_elements - 1 ; i += 2 ) {
104- dst_buffer[i / 2 ] = src_buffer[i] & 0x0F ;
105- dst_buffer[i / 2 ] |= src_buffer[i + 1 ] << 4 ;
// Unpacks densely packed signed 2-bit or 4-bit integers into one int8_t each.
//
// src_buffer:   packed input. The first element of each byte sits in the
//               least-significant bits, later elements in successively
//               higher bits.
// num_elements: number of logical elements, regardless of packed or unpacked
//               representation. For example, 3 elements with bit_width == 4
//               means both
//                 1) Packed: 3 int4's = 12 bits -> 16 bits (padded) = 2 bytes,
//                    read from src_buffer[0] and src_buffer[1].
//                 2) Unpacked: 3 int8's = 3 bytes, written to dst_buffer[0],
//                    dst_buffer[1] and dst_buffer[2].
//               Non-positive values leave dst_buffer untouched.
// bit_width:    width of each packed element; must be 2 or 4 (asserted). Any
//               other value is a no-op when asserts are compiled out.
// dst_buffer:   output; receives num_elements sign-extended values.
void UnpackPackedIntToInt8(const int8_t* src_buffer, int num_elements,
                           int bit_width, int8_t* dst_buffer) {
  assert(bit_width == 2 || bit_width == 4);
  // A negative count would otherwise yield a negative remainder below and
  // index outside of both buffers.
  if (num_elements <= 0) {
    return;
  }

  // Sign-extends the low `bits` bits of `field` into an int8_t. Working on an
  // unsigned field with the (v ^ s) - s idiom avoids left-shifting negative
  // values (undefined before C++20) and does not depend on how the compiler
  // right-shifts negative values.
  const auto sign_extend = [](unsigned field, int bits) -> int8_t {
    const int sign_bit = 1 << (bits - 1);
    const int value = static_cast<int>(field & ((1u << bits) - 1u));
    return static_cast<int8_t>((value ^ sign_bit) - sign_bit);
  };

  if (bit_width == 4) {
    const int full_bytes = num_elements / 2;
    for (int i = 0; i < full_bytes; ++i) {
      const unsigned packed = static_cast<uint8_t>(src_buffer[i]);
      dst_buffer[2 * i] = sign_extend(packed, 4);
      dst_buffer[2 * i + 1] = sign_extend(packed >> 4, 4);
    }

    // If the element count is odd, extract the final lower nibble.
    if (num_elements % 2 != 0) {
      const unsigned packed = static_cast<uint8_t>(src_buffer[full_bytes]);
      dst_buffer[num_elements - 1] = sign_extend(packed, 4);
    }
  } else if (bit_width == 2) {
    const int full_bytes = num_elements / 4;
    for (int i = 0; i < full_bytes; ++i) {
      const unsigned packed = static_cast<uint8_t>(src_buffer[i]);
      dst_buffer[4 * i] = sign_extend(packed, 2);
      dst_buffer[4 * i + 1] = sign_extend(packed >> 2, 2);
      dst_buffer[4 * i + 2] = sign_extend(packed >> 4, 2);
      dst_buffer[4 * i + 3] = sign_extend(packed >> 6, 2);
    }

    // Handle the 1..3 elements stored in a trailing, partially used byte.
    const int remaining_elements = num_elements % 4;
    if (remaining_elements > 0) {
      const unsigned packed = static_cast<uint8_t>(src_buffer[full_bytes]);
      int8_t* tail = dst_buffer + 4 * full_bytes;
      for (int i = 0; i < remaining_elements; ++i) {
        tail[i] = sign_extend(packed >> (2 * i), 2);
      }
    }
  }
}
108145
109- // Copy the final nibble if the buffer is odd-lengthed
110- if (num_elements % 2 != 0 ) {
111- dst_buffer[packed_size - 1 ] = src_buffer[num_elements - 1 ] & 0x0F ;
// Packs int8_t values into densely packed 2-bit or 4-bit fields.
//
// src_buffer:   unpacked input, one element per byte. Only the low bit_width
//               bits of each element are kept.
// num_elements: number of logical elements, regardless of packed or unpacked
//               representation. For example, 3 elements with bit_width == 4
//               means both
//                 1) Unpacked: 3 int8's = 3 bytes, read from src_buffer[0],
//                    src_buffer[1] and src_buffer[2].
//                 2) Packed: 3 int4's = 12 bits -> 16 bits (padded) = 2 bytes,
//                    written to dst_buffer[0] and dst_buffer[1].
//               Non-positive values leave dst_buffer untouched.
// bit_width:    width of each packed element; must be 2 or 4 (asserted). Any
//               other value is a no-op when asserts are compiled out.
// dst_buffer:   packed output. The first element of each byte is stored in
//               the least-significant bits; unused high bits of a trailing
//               partial byte are written as zero.
void PackInt8IntoDenseInt(const int8_t* src_buffer, int num_elements,
                          int bit_width, int8_t* dst_buffer) {
  assert(bit_width == 2 || bit_width == 4);
  // A negative count would otherwise yield a negative remainder below and
  // index outside of both buffers.
  if (num_elements <= 0) {
    return;
  }

  if (bit_width == 4) {
    const int full_bytes = num_elements / 2;
    for (int i = 0; i < full_bytes; ++i) {
      // Mask before shifting and work in unsigned arithmetic: left-shifting a
      // negative value is undefined behaviour before C++20.
      const unsigned low = static_cast<uint8_t>(src_buffer[2 * i]) & 0x0Fu;
      const unsigned high = static_cast<uint8_t>(src_buffer[2 * i + 1]) & 0x0Fu;
      dst_buffer[i] = static_cast<int8_t>(static_cast<uint8_t>(low | (high << 4)));
    }
    // If the element count is odd, pack the final nibble into the low half.
    if (num_elements % 2 != 0) {
      dst_buffer[full_bytes] =
          static_cast<int8_t>(src_buffer[num_elements - 1] & 0x0F);
    }
  } else if (bit_width == 2) {
    const int full_bytes = num_elements / 4;
    for (int i = 0; i < full_bytes; ++i) {
      const int8_t* group = src_buffer + 4 * i;
      const unsigned packed = (static_cast<uint8_t>(group[0]) & 0x03u) |
                              ((static_cast<uint8_t>(group[1]) & 0x03u) << 2) |
                              ((static_cast<uint8_t>(group[2]) & 0x03u) << 4) |
                              ((static_cast<uint8_t>(group[3]) & 0x03u) << 6);
      dst_buffer[i] = static_cast<int8_t>(static_cast<uint8_t>(packed));
    }
    // Handle the 1..3 elements that only partially fill a trailing byte.
    const int remaining_elements = num_elements % 4;
    if (remaining_elements > 0) {
      const int8_t* tail = src_buffer + 4 * full_bytes;
      unsigned packed = 0;
      for (int i = 0; i < remaining_elements; ++i) {
        packed |= (static_cast<uint8_t>(tail[i]) & 0x03u) << (2 * i);
      }
      dst_buffer[full_bytes] = static_cast<int8_t>(static_cast<uint8_t>(packed));
    }
  }
}
114182
0 commit comments