diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h index 1de37e6d78ab47c18796b1503c092b6ddfc1bdae..c96f5c853431b15c06236b75583da4b6750a94d8 100644 --- a/dnn/vec_avx.h +++ b/dnn/vec_avx.h @@ -392,10 +392,8 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c __m256i ones; int i, j; unsigned char x[MAX_INPUTS]; - int out[MAX_OUTPUTS]; (void)col_stride; ones = _mm256_set1_epi16(1); - for (i=0;i<rows;i++) out[i] = SCALE*_out[i]; //for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]); __m256 const127 = _mm256_set1_ps(127.f); for (i=0;i<cols;i+=8) { @@ -415,10 +413,11 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c } for (i=0;i<rows;i+=8) { - int * restrict y; __m256i vy0; - y = &out[i]; - vy0 = _mm256_loadu_si256((const __m256i *)&y[0]); + __m256 vout; + vout = _mm256_loadu_ps(&_out[i]); + vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE)); + vy0 = _mm256_cvtps_epi32(vout); j=0; #if 1 /* Unrolling by 4 gives some gain, comment out if it does not. */ for (;j<cols-12;j+=16) @@ -464,9 +463,10 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c vy0 = _mm256_add_epi32(vy0, tmp); w += 32; } - _mm256_storeu_si256 ((__m256i *)&y[0], vy0); + vout = _mm256_cvtepi32_ps(vy0); + vout = _mm256_mul_ps(vout, _mm256_set1_ps(SCALE_1)); + _mm256_storeu_ps(&_out[i], vout); } - for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i]; } #else static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)