diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h index 4747bb41fbcf1e8618c5ef79b774253a45683b57..bc93313013b4ef57381e397aa4854cec3b5756d7 100644 --- a/dnn/vec_avx.h +++ b/dnn/vec_avx.h @@ -621,27 +621,27 @@ static inline void vec_sigmoid(float *y, const float *x, int N) #if defined(__AVXVNNI__) || defined(__AVX512VNNI__) +#define opus_mm256_dpbusds_epi32(src, a, b) _mm256_dpbusds_epi32(src, a, b) + #elif defined(__AVX2__) -static inline __m256i mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) { +static inline __m256i opus_mm256_dpbusds_epi32(__m256i src, __m256i a, __m256i b) { __m256i ones, tmp; ones = _mm256_set1_epi16(1); tmp = _mm256_maddubs_epi16(a, b); tmp = _mm256_madd_epi16(tmp, ones); return _mm256_add_epi32(src, tmp); } -#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b) #elif defined(__SSSE3__) -static inline mm256i_emu mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { +static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { mm256i_emu ones, tmp; ones = _mm256_set1_epi16(1); tmp = _mm256_maddubs_epi16(a, b); tmp = _mm256_madd_epi16(tmp, ones); return _mm256_add_epi32(src, tmp); } -#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b) #elif defined(__SSE2__) @@ -655,13 +655,12 @@ static inline __m128i mm_dpbusds_epi32(__m128i src, __m128i a, __m128i b) { return _mm_add_epi32(src, tmp); } -static inline mm256i_emu mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { +static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a, mm256i_emu b) { mm256i_emu res; res.hi = mm_dpbusds_epi32(src.hi, a.hi, b.hi); res.lo = mm_dpbusds_epi32(src.lo, a.lo, b.lo); return res; } -#define _mm256_dpbusds_epi32(src, a, b) mm256_dpbusds_epi32(src, a, b) #if defined(_MSC_VER) #pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 to get better performance") @@ -797,19 +796,19 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i __m256i vw; vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[*idx++]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; } #endif @@ -821,7 +820,7 @@ static inline void sparse_cgemv8x4(float *_out, const opus_int8 *w, const int *i pos = (*idx++); vxj = _mm256_set1_epi32(*(int*)&x[pos]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; } vout = _mm256_cvtepi32_ps(vy0); @@ -848,19 +847,19 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, __m256i vw; vxj = _mm256_set1_epi32(*(int*)&x[j]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[j+4]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[j+8]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; vxj = _mm256_set1_epi32(*(int*)&x[j+12]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; } #endif @@ -870,7 +869,7 @@ static inline void cgemv8x4(float *_out, const opus_int8 *w, const float *scale, __m256i vw; vxj = _mm256_set1_epi32(*(int*)&x[j]); vw = _mm256_loadu_si256((const __m256i *)w); - vy0 = _mm256_dpbusds_epi32(vy0, vxj, vw); + vy0 = opus_mm256_dpbusds_epi32(vy0, vxj, vw); w += 32; } vout = _mm256_cvtepi32_ps(vy0);