From 54abdb6f5d671a1fc9756e475012c365aae71f59 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Tue, 6 Jul 2021 17:05:07 -0400
Subject: [PATCH] Sparse matrix indexing optimization

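The sparse index table now stores each nonzero block's index pre-multiplied
by 4 (j*4 instead of j), so the sparse_sgemv_accum8x4 kernels no longer have
to compute 4*idx inside the inner loop.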
---
 dnn/training_tf2/dump_lpcnet.py |  2 +-
 dnn/vec.h                       |  6 +++---
 dnn/vec_avx.h                   | 10 +++++-----
 dnn/vec_neon.h                  |  2 +-
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py
index 730ecb750..9dcdba470 100755
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -80,7 +80,7 @@ def printSparseVector(f, A, name):
             qblock = AQ[j*4:(j+1)*4, i*8:(i+1)*8]
             if np.sum(np.abs(block)) > 1e-10:
                 nb_nonzero = nb_nonzero + 1
-                idx = np.append(idx, j)
+                idx = np.append(idx, j*4)
                 vblock = qblock.transpose((1,0)).reshape((-1,))
                 W0 = np.concatenate([W0, block.reshape((-1,))])
                 W = np.concatenate([W, vblock])
diff --git a/dnn/vec.h b/dnn/vec.h
index f93200723..ae6049fab 100644
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -250,7 +250,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          int xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
@@ -318,7 +318,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          int xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
@@ -357,7 +357,7 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
          int pos;
          float * restrict y;
          float xj0, xj1, xj2, xj3;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          xj0 = x[pos+0];
          xj1 = x[pos+1];
          xj2 = x[pos+2];
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index df02dca36..f18c771ae 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -508,7 +508,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows
          __m256i vxj;
          __m256i vw;
          int pos;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          vxj = _mm256_set1_epi32(*(int*)&x[pos]);
          vw = _mm256_loadu_si256((const __m256i *)w); //_mm256_lddqu_si256?
          tmp = _mm256_maddubs_epi16(vxj, vw); //swap?
@@ -544,19 +544,19 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *weights, int
          __m256 vxj;
          __m256 vw;
          id = *idx++;
-         vxj = _mm256_broadcast_ss(&x[4*id]);
+         vxj = _mm256_broadcast_ss(&x[id]);
          vw = _mm256_loadu_ps(&weights[0]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
 
-         vxj = _mm256_broadcast_ss(&x[4*id+1]);
+         vxj = _mm256_broadcast_ss(&x[id+1]);
          vw = _mm256_loadu_ps(&weights[8]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
 
-         vxj = _mm256_broadcast_ss(&x[4*id+2]);
+         vxj = _mm256_broadcast_ss(&x[id+2]);
          vw = _mm256_loadu_ps(&weights[16]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
 
-         vxj = _mm256_broadcast_ss(&x[4*id+3]);
+         vxj = _mm256_broadcast_ss(&x[id+3]);
          vw = _mm256_loadu_ps(&weights[24]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
 
diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h
index a964f7516..1a4a4ce5f 100644
--- a/dnn/vec_neon.h
+++ b/dnn/vec_neon.h
@@ -333,7 +333,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows
       for (j=0;j<colblocks;j++)
       {
          int pos;
-         pos = 4 * (*idx++);
+         pos = (*idx++);
          int8x16_t vw0, vw1, vx;
          vx = (int8x16_t)vld1q_dup_s32((int*)&x[pos]);
          vw0 = vld1q_s8(w);
-- 
GitLab