diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index b41f9862e424994854467306d9d7058bba1ba4dd..767d7e1935dc5716aba4b2492348aa2cd881c80e 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -666,67 +666,54 @@ static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a,
 #error "No optimizations in vec_avx.h. This should never happen. "
 #endif
 
-static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      float *y;
-      __m256 vy0, vy8;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      vy8 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
-         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-      _mm256_storeu_ps (&y[8], vy8);
-   }
-}
-
-static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=8)
-   {
-      float *y;
-      __m256 vy0;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-   }
-}
-
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
-   else {
-      int i, j;
-      for (i=0;i<rows;i++)
-      {
-         out[i] = 0;
-         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
-      }
-   }
+   int i, j;
+   i=0;
+   for (;i<rows-15;i+=16)
+   {
+      float *y;
+      __m256 vy0, vy8;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      vy8 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+      _mm256_storeu_ps (&y[8], vy8);
+   }
+   for (;i<rows-7;i+=8)
+   {
+      float *y;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+   for (;i<rows;i++)
+   {
+      out[i] = 0;
+      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+   }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
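
Note on the change: instead of dispatching to sgemv16x1()/sgemv8x1() only when rows is a multiple of 16 or 8 and otherwise falling back to a fully scalar loop, the merged sgemv() peels 16-wide and 8-wide AVX blocks off any row count and finishes with a scalar tail. The semantics are unchanged: weights is column-major with column j starting at weights[j*col_stride], and out[i] = sum_j weights[j*col_stride + i]*x[j]. Below is a minimal standalone sketch of those semantics for reference testing; sgemv_ref, main(), and the sizes chosen are illustrative assumptions, not part of the patch.

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical scalar reference with the same semantics as the merged
 * sgemv(): weights is stored column-major, column j starting at
 * weights[j*col_stride], and out[i] = sum_j weights[j*col_stride + i]*x[j]. */
static void sgemv_ref(float *out, const float *weights, int rows, int cols,
                      int col_stride, const float *x)
{
   int i, j;
   for (i=0;i<rows;i++)
   {
      out[i] = 0;
      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
   }
}

int main(void)
{
   /* rows=27 would exercise all three paths of the merged loop:
      one 16-wide block (rows 0-15), one 8-wide block (rows 16-23),
      and a 3-row scalar tail (rows 24-26). */
   int rows = 27, cols = 5, col_stride = 32;
   float *weights = calloc((size_t)cols*col_stride, sizeof(float));
   float x[5], out[27];
   int i, j;
   if (weights == NULL) return 1;
   for (j=0;j<cols;j++)
   {
      x[j] = (float)(j+1);
      for (i=0;i<rows;i++) weights[j*col_stride + i] = (float)(i - j);
   }
   sgemv_ref(out, weights, rows, cols, col_stride, x);
   /* Expected: out[0] = -40, out[26] = 350. */
   printf("out[0]=%g out[26]=%g\n", out[0], out[26]);
   free(weights);
   return 0;
}

Peeling the remainder in-line like the patched sgemv() does avoids keeping one kernel per alignment case and removes the all-scalar fallback for row counts that are not a multiple of 8.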