diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 63690825f2b8c2259ffe03f8c8734820d0eefa40..e4b8f04305e2049fbe336a5cdd134120d891599c 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -627,6 +627,32 @@ static inline void vec_sigmoid(float *y, const float *x, int N)
 
 #endif
 
+#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+
+#elif defined(__AVX2__)
+
+static inline __m256i mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) {
+   __m256i ones, tmp;
+   ones = _mm256_set1_epi16(1);
+   tmp = _mm256_maddubs_epi16(a, b);
+   tmp = _mm256_madd_epi16(tmp, ones);
+   return _mm256_add_epi32(src, tmp);
+}
+#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b)
+
+#elif defined(__SSSE3__)
+
+static inline mm256i_emu mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) {
+   mm256i_emu ones, tmp;
+   ones = _mm256_set1_epi16(1);
+   tmp = _mm256_maddubs_epi16(a, b);
+   tmp = _mm256_madd_epi16(tmp, ones);
+   return _mm256_add_epi32(src, tmp);
+}
+#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b)
+
+#elif defined(__SSE2__)
+#endif
 
 static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
@@ -696,10 +722,8 @@ static inline void sparse_sgemv8x4(float *out, const float *weights, const int *
 
 static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *idx, const float *scale, int rows, int cols, const float *_x)
 {
-   __m256i ones;
    int i, j;
    unsigned char x[MAX_INPUTS];
-   ones = _mm256_set1_epi16(1);
    /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
    vector_ps_to_epi8(x, _x, cols);
    for (i=0;i<rows;i+=8)
@@ -713,47 +737,35 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
 #if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
       for (;j<colblocks-3;j+=4)
       {
-         __m256i tmp;
          __m256i vxj;
          __m256i vw;
          vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[*idx++]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
       }
 #endif
       for (;j<colblocks;j++)
       {
-         __m256i tmp;
          __m256i vxj;
          __m256i vw;
          int pos;
          pos = (*idx++);
          vxj = _mm256_set1_epi32(*(int*)&x[pos]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
       }
       vout = _mm256_cvtepi32_ps(vy0);
@@ -763,10 +775,8 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i
 }
 static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, int rows, int cols, const float *_x)
 {
-   __m256i ones;
    int i, j;
    unsigned char x[MAX_INPUTS];
-   ones = _mm256_set1_epi16(1);
    /*for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);*/
    vector_ps_to_epi8(x, _x, cols);
    for (i=0;i<rows;i+=8)
@@ -778,45 +788,33 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale,
 #if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */
       for (;j<cols-12;j+=16)
       {
-         __m256i tmp;
          __m256i vxj;
          __m256i vw;
          vxj = _mm256_set1_epi32(*(int*)&x[j]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[j+4]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[j+8]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
          vxj = _mm256_set1_epi32(*(int*)&x[j+12]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
      }
 #endif
      for (;j<cols;j+=4)
      {
-         __m256i tmp;
          __m256i vxj;
          __m256i vw;
          vxj = _mm256_set1_epi32(*(int*)&x[j]);
          vw = _mm256_loadu_si256((const __m256i *)w);
-         tmp = _mm256_maddubs_epi16(vxj, vw);
-         tmp = _mm256_madd_epi16(tmp, ones);
-         vy0 = _mm256_add_epi32(vy0, tmp);
+         vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw);
          w += 32;
      }
      vout = _mm256_cvtepi32_ps(vy0);
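
Note (not part of the patch): on AVX-VNNI/AVX512-VNNI targets the macro is left undefined so the loops pick up the native _mm256_dpbusds_epi32 intrinsic; elsewhere the maddubs/madd pair emulates it, which keeps the hot loops identical across ISAs. The emulation is not bit-exact with VPDPBUSDS in corner cases: _mm256_maddubs_epi16 saturates the intermediate 16-bit pair sums, and the final _mm256_add_epi32 wraps rather than saturating, so the two agree only when operand magnitudes keep the intermediates inside int16 range. Below is a minimal standalone sanity check of the AVX2 fallback against the scalar definition; the file name, dpbusds_emu helper, and test values are mine, chosen small enough that no saturation can occur (compile with -mavx2).

/* dpbusds_check.c -- hypothetical harness mirroring mm256_dpbusds_epi32() above. */
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

static __m256i dpbusds_emu(__m256i src, __m256i a, __m256i b) {
   __m256i ones, tmp;
   ones = _mm256_set1_epi16(1);
   tmp = _mm256_maddubs_epi16(a, b);   /* u8*s8 products -> saturated s16 pair sums */
   tmp = _mm256_madd_epi16(tmp, ones); /* s16 pair sums -> s32 sums of 4 products */
   return _mm256_add_epi32(src, tmp);  /* wrapping (non-saturating) accumulate */
}

int main(void) {
   uint8_t a[32];
   int8_t b[32];
   int32_t acc[8], ref[8];
   __m256i vy;
   int i, k, errors = 0;
   /* Small magnitudes so the 16-bit intermediates cannot saturate. */
   for (i=0;i<32;i++) { a[i] = (uint8_t)(3*i); b[i] = (int8_t)(i - 16); }
   for (i=0;i<8;i++) acc[i] = ref[i] = 1000*i;
   /* Scalar VPDPBUSDS definition: each s32 lane accumulates four u8*s8 products. */
   for (i=0;i<8;i++)
      for (k=0;k<4;k++)
         ref[i] += (int32_t)a[4*i+k] * (int32_t)b[4*i+k];
   vy = dpbusds_emu(_mm256_loadu_si256((const __m256i *)acc),
                    _mm256_loadu_si256((const __m256i *)a),
                    _mm256_loadu_si256((const __m256i *)b));
   _mm256_storeu_si256((__m256i *)acc, vy);
   for (i=0;i<8;i++) errors += (acc[i] != ref[i]);
   printf("%s\n", errors ? "MISMATCH" : "ok");
   return errors;
}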