From e9f8402a7122ca03e894d161c50706053bf4fb83 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 31 Jul 2023 03:03:37 -0400
Subject: [PATCH] Handle float matrices with multiple of 8 rows

---
 dnn/nnet.c     |  2 +-
 dnn/vec.h      | 40 ++++++++++++++++++++++++++++++++++++++++
 dnn/vec_avx.h  | 36 ++++++++++++++++++++++++++++++++++++
 dnn/vec_neon.h | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/dnn/nnet.c b/dnn/nnet.c
index 1c0035d08..05b0ea909 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -78,7 +78,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in)
    N = linear->nb_outputs;
    if (linear->float_weights != NULL) {
      if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in);
-     else sgemv16x1(out, linear->float_weights, N, M, N, in);
+     else sgemv(out, linear->float_weights, N, M, N, in);
    } else if (linear->weights != NULL) {
      if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in);
      else cgemv8x4(out, linear->weights, linear->scale, N, M, in);
diff --git a/dnn/vec.h b/dnn/vec.h
index f6085cee9..5b6951bbb 100644
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -92,6 +92,46 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
    }
 }
 
+static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   OPUS_CLEAR(out, rows);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j++)
+      {
+         const float * restrict w;
+         float * restrict y;
+         float xj;
+         w = &weights[j*col_stride + i];
+         xj = x[j];
+         y = &out[i];
+         y[0] += w[0]*xj;
+         y[1] += w[1]*xj;
+         y[2] += w[2]*xj;
+         y[3] += w[3]*xj;
+         y[4] += w[4]*xj;
+         y[5] += w[5]*xj;
+         y[6] += w[6]*xj;
+         y[7] += w[7]*xj;
+      }
+   }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
 {
    int i, j;
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 77b3a0e0f..4747bb41f 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -701,6 +701,42 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
    }
 }
 
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   int i, j;
+   for (i=0;i<rows;i+=8)
+   {
+      float *y;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_setzero_ps();
+      for (j=0;j<cols;j++)
+      {
+         __m256 vxj;
+         __m256 vw;
+         vxj = _mm256_broadcast_ss(&x[j]);
+
+         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
 {
    int i, j;
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index 38c20d7bc..48e3eaa1a 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -239,6 +239,55 @@ static inline void sgemv16x1(float *out, const float *weights, int rows, int col
     }
 }
 
+static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+    int i, j;
+    for (i=0;i<rows;i+=8)
+    {
+    float * restrict y = &out[i];
+
+    /* keep y[0..15] in registers for duration of inner loop */
+
+    float32x4_t y0_3 = vdupq_n_f32(0);
+    float32x4_t y4_7 = vdupq_n_f32(0);
+
+    for (j=0;j<cols;j++)
+    {
+        const float * restrict w;
+        float32x4_t wvec0_3, wvec4_7;
+        float32x4_t xj;
+
+        w = &weights[j*col_stride + i];
+        wvec0_3 = vld1q_f32(&w[0]);
+        wvec4_7 = vld1q_f32(&w[4]);
+
+        xj = vld1q_dup_f32(&x[j]);
+
+        y0_3 = vmlaq_f32(y0_3, wvec0_3, xj);
+        y4_7 = vmlaq_f32(y4_7, wvec4_7, xj);
+    }
+
+    /* save y[0..15] back to memory */
+
+    vst1q_f32(&y[0], y0_3);
+    vst1q_f32(&y[4], y4_7);
+    }
+}
+
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+{
+   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
+   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
+   else {
+      int i, j;
+      for (i=0;i<rows;i++)
+      {
+         out[i] = 0;
+         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+      }
+   }
+}
+
 /* Temporarily use unoptimized version */
 static inline void sparse_sgemv8x4(float *out, const float *w, const int *idx, int rows, const float *x)
 {
-- 
GitLab