5353
5454#include "datatypes.h"
5555
56+ #if defined(__SSE2__ )
57+ #include <tmmintrin.h>
58+ #endif
59+
60+ #if defined(_MSC_VER )
61+ #define ALIGNMENT (N ) __declspec(align(N))
62+ #else
63+ #define ALIGNMENT (N ) __attribute__((aligned(N)))
64+ #endif
65+
5666/*
5767 * bit_string is a buffer that is used to hold output strings, e.g.
5868 * for printing.
@@ -123,6 +133,9 @@ char *v128_bit_string(v128_t *x)
123133
124134void v128_copy_octet_string (v128_t * x , const uint8_t s [16 ])
125135{
136+ #if defined(__SSE2__ )
137+ _mm_storeu_si128 ((__m128i * )(x ), _mm_loadu_si128 ((const __m128i * )(s )));
138+ #else
126139#ifdef ALIGNMENT_32BIT_REQUIRED
127140 if ((((uint32_t )& s [0 ]) & 0x3 ) != 0 )
128141#endif
@@ -151,8 +164,67 @@ void v128_copy_octet_string(v128_t *x, const uint8_t s[16])
151164 v128_copy (x , v );
152165 }
153166#endif
167+ #endif /* defined(__SSE2__) */
154168}
155169
170+ #if defined(__SSSE3__ )
171+
172+ /* clang-format off */
173+
/*
 * Byte-permutation control masks for _mm_shuffle_epi8 (pshufb), used by the
 * SSSE3 shift routines below.  In a pshufb control byte, values 0..15 select
 * the source byte at that index, and any byte with its high bit set (0x80)
 * writes a zero into the result.  Row i of each table therefore moves the
 * vector's contents by i whole 32-bit words, zero-filling the vacated lanes.
 * 16-byte alignment keeps the tables loadable as __m128i.
 */
/* right_shift_masks[i]: move data i words toward HIGHER byte addresses. */
174+ ALIGNMENT (16 )
175+ static const uint8_t right_shift_masks [5 ][16 ] = {
176+ { 0u , 1u , 2u , 3u , 4u , 5u , 6u , 7u ,
177+ 8u , 9u , 10u , 11u , 12u , 13u , 14u , 15u },
178+ { 0x80 , 0x80 , 0x80 , 0x80 , 0u , 1u , 2u , 3u ,
179+ 4u , 5u , 6u , 7u , 8u , 9u , 10u , 11u },
180+ { 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 ,
181+ 0u , 1u , 2u , 3u , 4u , 5u , 6u , 7u },
182+ { 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 ,
183+ 0x80 , 0x80 , 0x80 , 0x80 , 0u , 1u , 2u , 3u },
/* row 4 is all-0x80: shuffling with it yields an all-zero vector */
184+ /* needed for bitvector_left_shift */
185+ { 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 ,
186+ 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 }
187+ };
188+
/* left_shift_masks[i]: move data i words toward LOWER byte addresses,
 * zero-filling from the top. */
189+ ALIGNMENT (16 )
190+ static const uint8_t left_shift_masks [4 ][16 ] = {
191+ { 0u , 1u , 2u , 3u , 4u , 5u , 6u , 7u ,
192+ 8u , 9u , 10u , 11u , 12u , 13u , 14u , 15u },
193+ { 4u , 5u , 6u , 7u , 8u , 9u , 10u , 11u ,
194+ 12u , 13u , 14u , 15u , 0x80 , 0x80 , 0x80 , 0x80 },
195+ { 8u , 9u , 10u , 11u , 12u , 13u , 14u , 15u ,
196+ 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 },
197+ { 12u , 13u , 14u , 15u , 0x80 , 0x80 , 0x80 , 0x80 ,
198+ 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 , 0x80 }
199+ };
200+
201+ /* clang-format on */
202+
/*
 * Shift the 128-bit value *x "left" by `shift` bits (SSSE3 fast path).
 *
 * NOTE(review): the lane semantics here look word-order dependent -- whole
 * 32-bit words move toward index 0 (via left_shift_masks) while the residual
 * bit shift within each word is a logical shift RIGHT.  This presumably
 * mirrors the scalar v128_left_shift fallback's convention; confirm against
 * the non-SSSE3 implementation before changing lane or shift direction.
 */
203+ void v128_left_shift (v128_t * x , int shift )
204+ {
/* a shift of 128 bits or more clears the whole value */
205+ if (shift > 127 ) {
206+ v128_set_to_zero (x );
207+ return ;
208+ }
209+
/* split into a whole-word shift count and a residual bit count */
210+ const int base_index = shift >> 5 ;
211+ const int bit_index = shift & 31 ;
212+
213+ __m128i mm = _mm_loadu_si128 ((const __m128i * )x );
214+ __m128i mm_shift_right = _mm_cvtsi32_si128 (bit_index );
215+ __m128i mm_shift_left = _mm_cvtsi32_si128 (32 - bit_index );
/* move whole 32-bit words by base_index positions, zero-filling */
216+ mm = _mm_shuffle_epi8 (mm , ((const __m128i * )left_shift_masks )[base_index ]);
217+
/* each word keeps its own bits >> bit_index, OR'd with the adjacent
 * word's spill-over bits (sll, then a 4-byte vector shift to align
 * word i+1 over word i).  When bit_index == 0 the sll count is 32,
 * which the psll/psrl instructions define to produce zero, so the
 * spill-over term vanishes as required. */
218+ __m128i mm1 = _mm_srl_epi32 (mm , mm_shift_right );
219+ __m128i mm2 = _mm_sll_epi32 (mm , mm_shift_left );
220+ mm2 = _mm_srli_si128 (mm2 , 4 );
221+ mm1 = _mm_or_si128 (mm1 , mm2 );
222+
223+ _mm_storeu_si128 ((__m128i * )x , mm1 );
224+ }
225+
226+ #else /* defined(__SSSE3__) */
227+
156228void v128_left_shift (v128_t * x , int shift )
157229{
158230 int i ;
@@ -179,6 +251,8 @@ void v128_left_shift(v128_t *x, int shift)
179251 x -> v32 [i ] = 0 ;
180252}
181253
254+ #endif /* defined(__SSSE3__) */
255+
182256/* functions manipulating bitvector_t */
183257
184258int bitvector_alloc (bitvector_t * v , unsigned long length )
@@ -190,6 +264,7 @@ int bitvector_alloc(bitvector_t *v, unsigned long length)
190264 (length + bits_per_word - 1 ) & ~(unsigned long )((bits_per_word - 1 ));
191265
192266 l = length / bits_per_word * bytes_per_word ;
267+ l = (l + 15ul ) & ~15ul ;
193268
194269 /* allocate memory, then set parameters */
195270 if (l == 0 ) {
@@ -225,6 +300,73 @@ void bitvector_set_to_zero(bitvector_t *x)
225300 memset (x -> word , 0 , x -> length >> 3 );
226301}
227302
303+ #if defined(__SSSE3__ )
304+
/*
 * Shift the bitvector *x "left" by `shift` bits (SSSE3 fast path), streaming
 * over the vector 128 bits at a time and zero-filling the tail.
 *
 * NOTE(review): the unaligned 16-byte loads from `from` rely on
 * bitvector_alloc rounding the word buffer up to a multiple of 16 bytes
 * (the `l = (l + 15ul) & ~15ul;` change in this patch) -- confirm every
 * allocation path guarantees that.
 */
305+ void bitvector_left_shift (bitvector_t * x , int shift )
306+ {
/* shifting by the full bit length (or more) clears everything */
307+ if ((uint32_t )shift >= x -> length ) {
308+ bitvector_set_to_zero (x );
309+ return ;
310+ }
311+
/* whole-word shift count, residual bit count, and vector (128-bit) count */
312+ const int base_index = shift >> 5 ;
313+ const int bit_index = shift & 31 ;
314+ const int vec_length = (x -> length + 127u ) >> 7 ;
/* `from` skips whole 128-bit blocks; the remaining 0..3 word offset is
 * handled below by the pshufb masks */
315+ const __m128i * from = ((const __m128i * )x -> word ) + (base_index >> 2 );
316+ __m128i * to = (__m128i * )x -> word ;
317+ __m128i * const end = to + vec_length ;
318+
/* when (base_index & 3) == 0 this selects right_shift_masks[4], the
 * all-zero-producing row, so the "next" contribution is suppressed */
319+ __m128i mm_right_shift_mask =
320+ ((const __m128i * )right_shift_masks )[4u - (base_index & 3u )];
321+ __m128i mm_left_shift_mask =
322+ ((const __m128i * )left_shift_masks )[base_index & 3u ];
323+ __m128i mm_shift_right = _mm_cvtsi32_si128 (bit_index );
324+ __m128i mm_shift_left = _mm_cvtsi32_si128 (32 - bit_index );
325+
/* prime the pipeline: per-word residual shifts of the first block
 * (sll by 32 when bit_index == 0 is defined to give zero) */
326+ __m128i mm_current = _mm_loadu_si128 (from );
327+ __m128i mm_current_r = _mm_srl_epi32 (mm_current , mm_shift_right );
328+ __m128i mm_current_l = _mm_sll_epi32 (mm_current , mm_shift_left );
329+
/* main loop: combine each block with its successor so bits spilling
 * across the 128-bit boundary are carried in */
330+ while ((end - from ) >= 2 ) {
331+ ++ from ;
332+ __m128i mm_next = _mm_loadu_si128 (from );
333+
334+ __m128i mm_next_r = _mm_srl_epi32 (mm_next , mm_shift_right );
335+ __m128i mm_next_l = _mm_sll_epi32 (mm_next , mm_shift_left );
/* pull the next block's low word in behind the current spill bits */
336+ mm_current_l = _mm_alignr_epi8 (mm_next_l , mm_current_l , 4 );
337+ mm_current = _mm_or_si128 (mm_current_r , mm_current_l );
338+
/* apply the 0..3 whole-word offset to the current block... */
339+ mm_current = _mm_shuffle_epi8 (mm_current , mm_left_shift_mask );
340+
341+ __m128i mm_temp_next = _mm_srli_si128 (mm_next_l , 4 );
342+ mm_temp_next = _mm_or_si128 (mm_next_r , mm_temp_next );
343+
/* ...and merge in the words the next block contributes to it */
344+ mm_temp_next = _mm_shuffle_epi8 (mm_temp_next , mm_right_shift_mask );
345+ mm_current = _mm_or_si128 (mm_temp_next , mm_current );
346+
347+ _mm_storeu_si128 (to , mm_current );
348+ ++ to ;
349+
350+ mm_current_r = mm_next_r ;
351+ mm_current_l = mm_next_l ;
352+ }
353+
/* epilogue: last source block has no successor, so only its own
 * shifted words (plus intra-block spill) are written */
354+ mm_current_l = _mm_srli_si128 (mm_current_l , 4 );
355+ mm_current = _mm_or_si128 (mm_current_r , mm_current_l );
356+
357+ mm_current = _mm_shuffle_epi8 (mm_current , mm_left_shift_mask );
358+
359+ _mm_storeu_si128 (to , mm_current );
360+ ++ to ;
361+
/* zero-fill the blocks vacated by the shift */
362+ while (to < end ) {
363+ _mm_storeu_si128 (to , _mm_setzero_si128 ());
364+ ++ to ;
365+ }
366+ }
367+
368+ #else /* defined(__SSSE3__) */
369+
228370void bitvector_left_shift (bitvector_t * x , int shift )
229371{
230372 int i ;
@@ -253,16 +395,82 @@ void bitvector_left_shift(bitvector_t *x, int shift)
253395 x -> word [i ] = 0 ;
254396}
255397
398+ #endif /* defined(__SSSE3__) */
399+
256400int srtp_octet_string_is_eq (uint8_t * a , uint8_t * b , int len )
257401{
258- uint8_t * end = b + len ;
259- uint8_t accumulator = 0 ;
260-
261402 /*
262403 * We use this somewhat obscure implementation to try to ensure the running
263404 * time only depends on len, even accounting for compiler optimizations.
264405 * The accumulator ends up zero iff the strings are equal.
265406 */
407+ uint8_t * end = b + len ;
408+ uint32_t accumulator = 0 ;
409+
410+ #if defined(__SSE2__ )
411+ __m128i mm_accumulator1 = _mm_setzero_si128 ();
412+ __m128i mm_accumulator2 = _mm_setzero_si128 ();
413+ for (int i = 0 , n = len >> 5 ; i < n ; ++ i , a += 32 , b += 32 ) {
414+ __m128i mm_a1 = _mm_loadu_si128 ((const __m128i * )a );
415+ __m128i mm_b1 = _mm_loadu_si128 ((const __m128i * )b );
416+ __m128i mm_a2 = _mm_loadu_si128 ((const __m128i * )(a + 16 ));
417+ __m128i mm_b2 = _mm_loadu_si128 ((const __m128i * )(b + 16 ));
418+ mm_a1 = _mm_xor_si128 (mm_a1 , mm_b1 );
419+ mm_a2 = _mm_xor_si128 (mm_a2 , mm_b2 );
420+ mm_accumulator1 = _mm_or_si128 (mm_accumulator1 , mm_a1 );
421+ mm_accumulator2 = _mm_or_si128 (mm_accumulator2 , mm_a2 );
422+ }
423+
424+ mm_accumulator1 = _mm_or_si128 (mm_accumulator1 , mm_accumulator2 );
425+
426+ if ((end - b ) >= 16 ) {
427+ __m128i mm_a1 = _mm_loadu_si128 ((const __m128i * )a );
428+ __m128i mm_b1 = _mm_loadu_si128 ((const __m128i * )b );
429+ mm_a1 = _mm_xor_si128 (mm_a1 , mm_b1 );
430+ mm_accumulator1 = _mm_or_si128 (mm_accumulator1 , mm_a1 );
431+ a += 16 ;
432+ b += 16 ;
433+ }
434+
435+ if ((end - b ) >= 8 ) {
436+ __m128i mm_a1 = _mm_loadl_epi64 ((const __m128i * )a );
437+ __m128i mm_b1 = _mm_loadl_epi64 ((const __m128i * )b );
438+ mm_a1 = _mm_xor_si128 (mm_a1 , mm_b1 );
439+ mm_accumulator1 = _mm_or_si128 (mm_accumulator1 , mm_a1 );
440+ a += 8 ;
441+ b += 8 ;
442+ }
443+
444+ mm_accumulator1 = _mm_or_si128 (
445+ mm_accumulator1 , _mm_unpackhi_epi64 (mm_accumulator1 , mm_accumulator1 ));
446+ mm_accumulator1 =
447+ _mm_or_si128 (mm_accumulator1 , _mm_srli_si128 (mm_accumulator1 , 4 ));
448+ accumulator = _mm_cvtsi128_si32 (mm_accumulator1 );
449+ #else
450+ uint32_t accumulator2 = 0 ;
451+ for (int i = 0 , n = len >> 3 ; i < n ; ++ i , a += 8 , b += 8 ) {
452+ uint32_t a_val1 , b_val1 ;
453+ uint32_t a_val2 , b_val2 ;
454+ memcpy (& a_val1 , a , sizeof (a_val1 ));
455+ memcpy (& b_val1 , b , sizeof (b_val1 ));
456+ memcpy (& a_val2 , a + 4 , sizeof (a_val2 ));
457+ memcpy (& b_val2 , b + 4 , sizeof (b_val2 ));
458+ accumulator |= a_val1 ^ b_val1 ;
459+ accumulator2 |= a_val2 ^ b_val2 ;
460+ }
461+
462+ accumulator |= accumulator2 ;
463+
464+ if ((end - b ) >= 4 ) {
465+ uint32_t a_val , b_val ;
466+ memcpy (& a_val , a , sizeof (a_val ));
467+ memcpy (& b_val , b , sizeof (b_val ));
468+ accumulator |= a_val ^ b_val ;
469+ a += 4 ;
470+ b += 4 ;
471+ }
472+ #endif
473+
266474 while (b < end )
267475 accumulator |= (* a ++ ^ * b ++ );
268476
@@ -272,9 +480,14 @@ int srtp_octet_string_is_eq(uint8_t *a, uint8_t *b, int len)
272480
/*
 * Zero `len` bytes at `s`, intended for scrubbing secrets.
 *
 * GCC/Clang path: a plain memset followed by an empty asm statement with a
 * "memory" clobber that takes `s` as an input; the barrier forces the
 * compiler to assume the buffer may be read afterwards, preventing the
 * (otherwise dead) memset from being optimized away.  Other compilers fall
 * back to a byte-wise store through a volatile pointer, which the compiler
 * may not elide either.
 *
 * NOTE(review): C23 memset_explicit / glibc-BSD explicit_bzero (or
 * SecureZeroMemory on Windows) express this intent directly -- consider
 * using them where available.
 */
273481void srtp_cleanse (void * s , size_t len )
274482{
483+ #if defined(__GNUC__ )
484+ memset (s , 0 , len );
/* compiler barrier: keeps the store above from being eliminated */
485+ __asm__ __volatile__("" : : "r" (s ) : "memory" );
486+ #else
275487 volatile unsigned char * p = (volatile unsigned char * )s ;
276488 while (len -- )
277489 * p ++ = 0 ;
490+ #endif
278491}
279492
280493void octet_string_set_to_zero (void * s , size_t len )
0 commit comments