From 2b4652f9f6b59b3c95a31da3b4348f5d4eea068e Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@jmvalin.ca>
Date: Sat, 26 Dec 2020 03:20:20 -0500
Subject: [PATCH] WIP: cleanup

---
 dnn/vec_avx.h | 72 +++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 59 insertions(+), 13 deletions(-)

diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index 9dda9f7a2..e6e7c2e15 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -35,7 +35,7 @@
 #include <immintrin.h>
 
 #ifdef __AVX2__
-static __m256 exp8_approx(__m256 X)
+static inline __m256 exp8_approx(__m256 X)
 {
    const __m256 K0 = _mm256_set1_ps(0.99992522f);
    const __m256 K1 = _mm256_set1_ps(0.69583354f);
@@ -60,7 +60,7 @@ static __m256 exp8_approx(__m256 X)
 #else
 #define _mm256_fmadd_ps(a,b,c) _mm256_add_ps(_mm256_mul_ps(a, b), c)
 #define _mm_fmadd_ps(a,b,c) _mm_add_ps(_mm_mul_ps(a, b), c)
-static __m128 exp4_approx(__m128 X)
+static inline __m128 exp4_approx(__m128 X)
 {
    const __m128 K0 = _mm_set1_ps(0.99992522f);
    const __m128 K1 = _mm_set1_ps(0.69583354f);
@@ -82,7 +82,7 @@ static __m128 exp4_approx(__m128 X)
    Y = _mm_castsi128_ps(_mm_and_si128(mask, _mm_add_epi32(I, _mm_castps_si128(Y))));
    return Y;
 }
-static __m256 exp8_approx(__m256 X)
+static inline __m256 exp8_approx(__m256 X)
 {
    __m256 Y;
    __m128 Xhi, Xlo, Yhi, Ylo;
@@ -96,7 +96,7 @@ static __m256 exp8_approx(__m256 X)
 }
 #endif
 
-static float celt_exp(float x)
+static inline float celt_exp(float x)
 {
    float out[8];
    __m256 X, Y;
@@ -106,7 +106,7 @@ static float celt_exp(float x)
    return out[0];
 }
 
-static void softmax(float *y, const float *x, int N)
+static inline void softmax(float *y, const float *x, int N)
 {
     int i;
     for (i=0;i<N-7;i+=8)
@@ -120,7 +120,7 @@ static void softmax(float *y, const float *x, int N)
         y[i] = celt_exp(x[i]);
 }
 
-static void vec_tanh(float *y, const float *x, int N)
+static inline void vec_tanh(float *y, const float *x, int N)
 {
     int i;
     for (i=0;i<N-7;i+=8)
@@ -142,7 +142,7 @@ static void vec_tanh(float *y, const float *x, int N)
     }
 }
 
-static void vec_sigmoid(float *y, const float *x, int N)
+static inline void vec_sigmoid(float *y, const float *x, int N)
 {
     int i;
     for (i=0;i<N-7;i+=8)
@@ -163,7 +163,7 @@ static void vec_sigmoid(float *y, const float *x, int N)
     }
 }
 
-static void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+static inline void sgemv_accum16(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=16)
@@ -189,7 +189,7 @@ static void sgemv_accum16(float *out, const float *weights, int rows, int cols,
       _mm256_storeu_ps (&y[8], vy8);
    }
 }
-static void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
+static inline void sparse_sgemv_accum16(float *out, const float *weights, int rows, const int *idx, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=16)
@@ -222,7 +222,17 @@ static void sparse_sgemv_accum16(float *out, const float *weights, int rows, con
 }
 
 #ifdef DOT_PROD
-static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
+
+#define USE_SU_BIAS
+
+#define MAX_INPUTS (2048)
+
+
+#define SCALE (128.f*127.f)
+#define SCALE_1 (1.f/128.f/127.f)
+
+#if 0
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=8)
@@ -247,9 +257,45 @@ static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows,
       _mm256_storeu_ps (&y[0], vy0);
    }
 }
-
 #else
-static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
+{
+   int i, j;
+   unsigned x[MAX_INPUTS];
+   for (i=0;i<rows;i++) out[i] *= SCALE;
+   for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      int colblocks;
+      colblocks = *idx++;
+      for (j=0;j<colblocks;j++)
+      {
+         int pos;
+         float * restrict y;
+         int xj0, xj1, xj2, xj3;
+         pos = 4 * (*idx++);
+         xj0 = x[pos+0];
+         xj1 = x[pos+1];
+         xj2 = x[pos+2];
+         xj3 = x[pos+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+#endif
+
+#else /*DOT_PROD*/
+static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, int ignore, const int *idx, const float *x)
 {
    int i, j;
    (void)ignore;
@@ -288,6 +334,6 @@ static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows,
       _mm256_storeu_ps (&y[0], vy0);
    }
 }
-#endif
+#endif /*DOT_PROD*/
 
 #endif /*VEC_AVX_H*/
-- 
GitLab