From 62b546436fc07035802eb998f61702ee2716db60 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 30 Oct 2023 00:08:53 -0400
Subject: [PATCH] Speed up general case for float matrix multiply

---
 dnn/vec_avx.h | 105 ++++++++++++++++++++++----------------------------
 1 file changed, 46 insertions(+), 59 deletions(-)

diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index b41f9862e..767d7e193 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -666,67 +666,54 @@ static inline mm256i_emu opus_mm256_dpbusds_epi32(mm256i_emu src, mm256i_emu a,
 #error "No optimizations in vec_avx.h. This should never happen. "
 #endif
 
-static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=16)
-   {
-      float *y;
-      __m256 vy0, vy8;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      vy8 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
-         vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-      _mm256_storeu_ps (&y[8], vy8);
-   }
-}
-
-static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=8)
-   {
-      float *y;
-      __m256 vy0;
-      y = &out[i];
-      vy0 = _mm256_setzero_ps();
-      for (j=0;j<cols;j++)
-      {
-         __m256 vxj;
-         __m256 vw;
-         vxj = _mm256_broadcast_ss(&x[j]);
-
-         vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
-         vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
-      }
-      _mm256_storeu_ps (&y[0], vy0);
-   }
-}
-
 static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
-   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
-   else {
-      int i, j;
-      for (i=0;i<rows;i++)
-      {
-         out[i] = 0;
-         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
-      }
-   }
+  int i, j;
+  i=0;
+  for (;i<rows-15;i+=16)
+  {
+     float *y;
+     __m256 vy0, vy8;
+     y = &out[i];
+     vy0 = _mm256_setzero_ps();
+     vy8 = _mm256_setzero_ps();
+     for (j=0;j<cols;j++)
+     {
+        __m256 vxj;
+        __m256 vw;
+        vxj = _mm256_broadcast_ss(&x[j]);
+
+        vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+        vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+
+        vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
+        vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
+     }
+     _mm256_storeu_ps (&y[0], vy0);
+     _mm256_storeu_ps (&y[8], vy8);
+  }
+  for (;i<rows-7;i+=8)
+  {
+     float *y;
+     __m256 vy0;
+     y = &out[i];
+     vy0 = _mm256_setzero_ps();
+     for (j=0;j<cols;j++)
+     {
+        __m256 vxj;
+        __m256 vw;
+        vxj = _mm256_broadcast_ss(&x[j]);
+
+        vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
+        vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
+     }
+     _mm256_storeu_ps (&y[0], vy0);
+  }
+  for (;i<rows;i++)
+  {
+    out[i] = 0;
+    for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
+  }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
-- 
GitLab