From e9f8402a7122ca03e894d161c50706053bf4fb83 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 31 Jul 2023 03:03:37 -0400
Subject: [PATCH] Handle float matrices with multiple of 8 rows

compute_linear() previously hard-wired the float-weights path to
sgemv16x1(), which requires the number of rows to be a multiple of 16.
Add an sgemv8x1() kernel (generic C, AVX, NEON) for matrices whose row
count is a multiple of 8, and an sgemv() dispatcher that picks the
widest applicable kernel and otherwise falls back to a plain loop.

---
 dnn/nnet.c     |  2 +-
 dnn/vec.h      | 40 ++++++++++++++++++++++++++++++++++++++++
 dnn/vec_avx.h  | 36 ++++++++++++++++++++++++++++++++++++
 dnn/vec_neon.h | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/dnn/nnet.c b/dnn/nnet.c
index 1c0035d08..05b0ea909 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -78,7 +78,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in)
    N = linear->nb_outputs;
    if (linear->float_weights != NULL) {
       if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
-      else sgemv16x1(out, linear->float_weights, N, M, N, in);
+      else sgemv(out, linear->float_weights, N, M, N, in);
    } else if (linear->weights != NULL) {
       if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
       else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
diff --git a/dnn/vec.h b/dnn/vec.h
index f6085cee9..5b6951bbb 100644
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -92,6 +92,46 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
    }
 }
 
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   OPUS_CLEAR(out, rows);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j++)
+      {
+         const float * restrict w;
+         float * restrict y;
+         float xj;
+         w = &weights[j*col_stride + i];
+         xj = x[j];
+         y = &out[i];
+         y[0] += w[0]*xj;
+         y[1] += w[1]*xj;
+         y[2] += w[2]*xj;
+         y[3] += w[3]*xj;
+         y[4] += w[4]*xj;
+         y[5] += w[5]*xj;
+         y[6] += w[6]*xj;
+         y[7] += w[7]*xj;
+      }
+   }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
 {
    int i, j;
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 77b3a0e0f..4747bb41f 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -701,6 +701,42 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
    }
 }
 
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=8)
+   {
+      float *y;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
 {
    int i, j;
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index 38c20d7bc..48e3eaa1a 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -239,6 +239,55 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
    }
 }
 
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=8)
+   {
+      float * restrict y = &out[i];
+
+      /* keep y[0..7] in registers for duration of inner loop */
+
+      float32x4_t y0_3 = vdupq_n_f32(0);
+      float32x4_t y4_7 = vdupq_n_f32(0);
+
+      for (j=0;j<cols;j++)
+      {
+         const float * restrict w;
+         float32x4_t wvec0_3, wvec4_7;
+         float32x4_t xj;
+
+         w = &weights[j*col_stride + i];
+         wvec0_3 = vld1q_f32(&w[0]);
+         wvec4_7 = vld1q_f32(&w[4]);
+
+         xj = vld1q_dup_f32(&x[j]);
+
+         y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
+         y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
+      }
+
+      /* save y[0..7] back to memory */
+
+      vst1q_f32(&y[0], y0_3);
+      vst1q_f32(&y[4], y4_7);
+   }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 /* Temporarily use unoptimized version */
 static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
 {
--
GitLab
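
Not part of the patch: a minimal sketch of the dispatch rule the new
sgemv() applies, useful for sanity-checking which kernel a given row
count selects. The sgemv_path() helper and main() harness below are
hypothetical, written against the bit masks in the diff above.

#include <stdio.h>

/* Mirrors the dispatch in sgemv(): row counts that are a multiple of
 * 16 take the sgemv16x1 kernel, remaining multiples of 8 take the new
 * sgemv8x1 kernel, and anything else takes the generic scalar loop. */
static const char *sgemv_path(int rows)
{
   if ((rows & 0xf) == 0) return "sgemv16x1";
   else if ((rows & 0x7) == 0) return "sgemv8x1";
   else return "generic";
}

int main(void)
{
   int sizes[3] = {32, 24, 7};
   int i;
   /* expected: rows=32 -> sgemv16x1, rows=24 -> sgemv8x1, rows=7 -> generic */
   for (i = 0; i < 3; i++)
      printf("rows=%d -> %s\n", sizes[i], sgemv_path(sizes[i]));
   return 0;
}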