From 62b546436fc07035802eb998f61702ee2716db60 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 30 Oct 2023 00:08:53 -0400
Subject: [PATCH] Speed up general case for float matrix multiply

---
 dnn/vec_avx.h | 105 ++++++++++++++++++++++----------------------------
 1 file changed, 46 insertions(+), 59 deletions(-)

diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index b41f9862e..767d7e193 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -666,67 +666,54 @@ static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a,
 #error "No optimizations in vec_avx.h. This should never happen. "
 #endif
 
-static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      float *y;
-      __m256 vy0, vy8;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      vy8 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
-         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-      _mm256_storeu_ps (&y[8], vy8);
-   }
-}
-
-static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=8)
-   {
-      float *y;
-      __m256 vy0;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-   }
-}
-
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
-   else {
-      int i, j;
-      for (i=0;i<rows;i++)
-      {
-         out[i] = 0;
-         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
-      }
-   }
+   int i, j;
+   i=0;
+   for (;i<rows-15;i+=16)
+   {
+      float *y;
+      __m256 vy0, vy8;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      vy8 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+      _mm256_storeu_ps (&y[8], vy8);
+   }
+   for (;i<rows-7;i+=8)
+   {
+      float *y;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+   for (;i<rows;i++)
+   {
+      out[i] = 0;
+      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+   }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
-- 
GitLab
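
Note (not part of the patch): below is a minimal standalone harness one could use to sanity-check the new remainder handling. The kernel body follows the patch (declarations lightly condensed) under the hypothetical name sgemv_merged so the file builds without the rest of the tree; the test sizes, fill values, and tolerance are illustrative assumptions. rows=21 exercises all three loops' boundaries: the 16-wide loop covers rows 0..15, the 8-wide loop is skipped (only 5 rows remain), and the scalar tail covers rows 16..20. Build with something like: gcc -O2 -mavx2 -mfma test_sgemv.c -o test_sgemv

/* Standalone sanity check for the merged sgemv() remainder handling.
 * The kernel is copied from the patch; the harness is illustrative only. */
#include <immintrin.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

static inline void sgemv_merged(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
{
   int i, j;
   i=0;
   /* 16-row blocks: two 8-wide FMA accumulators per block */
   for (;i<rows-15;i+=16)
   {
      float *y = &out[i];
      __m256 vy0 = _mm256_setzero_ps();
      __m256 vy8 = _mm256_setzero_ps();
      for (j=0;j<cols;j++)
      {
         __m256 vxj = _mm256_broadcast_ss(&x[j]);
         __m256 vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
      }
      _mm256_storeu_ps(&y[0], vy0);
      _mm256_storeu_ps(&y[8], vy8);
   }
   /* 8-row blocks for whatever 16-row blocks could not cover */
   for (;i<rows-7;i+=8)
   {
      float *y = &out[i];
      __m256 vy0 = _mm256_setzero_ps();
      for (j=0;j<cols;j++)
      {
         __m256 vxj = _mm256_broadcast_ss(&x[j]);
         __m256 vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
      }
      _mm256_storeu_ps(&y[0], vy0);
   }
   /* scalar tail for the last rows&7 rows */
   for (;i<rows;i++)
   {
      out[i] = 0;
      for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
   }
}

int main(void)
{
   /* rows=21: 16-wide loop does rows 0..15, 8-wide loop is skipped
      (only 5 rows left), scalar tail does rows 16..20. */
   int rows = 21, cols = 7, col_stride = rows;
   float *w = malloc(sizeof(float)*rows*cols);
   float *x = malloc(sizeof(float)*cols);
   float *y = malloc(sizeof(float)*rows);
   int i, j;
   for (i=0;i<rows*cols;i++) w[i] = (float)(i%13) - 6.f;
   for (j=0;j<cols;j++) x[j] = 0.5f*j - 1.f;
   sgemv_merged(y, w, rows, cols, col_stride, x);
   for (i=0;i<rows;i++)
   {
      float ref = 0;
      for (j=0;j<cols;j++) ref += w[j*col_stride + i]*x[j];
      if (fabsf(y[i] - ref) > 1e-4f)
      {
         printf("mismatch at row %d: %g vs %g\n", i, y[i], ref);
         return 1;
      }
   }
   printf("sgemv matches scalar reference for rows=%d\n", rows);
   free(w); free(x); free(y);
   return 0;
}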