diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h index 80b0229096af507496a2fb594637213cd906f52e..e11fae04ed7df3d6e11922458fd847ec60261fe0 100644 --- a/dnn/vec_avx.h +++ b/dnn/vec_avx.h @@ -218,3 +218,45 @@ static void sparse_sgemv_accum16(float *out, const float *weights, int rows, con } } +#ifdef DOT_PROD +#else +static void sparse_sgemv_accum8x4(float *out, const float *weights, int rows, const int *idx, const float *x) +{ + int i, j; + for (i=0;i<rows;i+=8) + { + float * restrict y; + int cols; + __m256 vy0; + y = &out[i]; + vy0 = _mm256_loadu_ps(&y[0]); + cols = *idx++; + for (j=0;j<cols;j++) + { + int id; + __m256 vxj; + __m256 vw; + id = *idx++; + vxj = _mm256_broadcast_ss(&x[4*id]); + vw = _mm256_loadu_ps(&weights[0]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + vxj = _mm256_broadcast_ss(&x[4*id+1]); + vw = _mm256_loadu_ps(&weights[8]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + vxj = _mm256_broadcast_ss(&x[4*id+2]); + vw = _mm256_loadu_ps(&weights[16]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + vxj = _mm256_broadcast_ss(&x[4*id+3]); + vw = _mm256_loadu_ps(&weights[24]); + vy0 = _mm256_fmadd_ps(vw, vxj, vy0); + + weights += 32; + } + _mm256_storeu_ps (&y[0], vy0); + } +} +#endif +