From a1079c2ce33e198d582fabcfab79085726ff83cd Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin <jmvalin@amazon.com> Date: Thu, 8 Jul 2021 13:20:15 -0400 Subject: [PATCH] Again, same conversion as 3206cec, for NEON --- dnn/vec_neon.h | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/dnn/vec_neon.h b/dnn/vec_neon.h index 1a4a4ce5f..b01d0eb21 100644 --- a/dnn/vec_neon.h +++ b/dnn/vec_neon.h @@ -291,15 +291,15 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c { int i, j; signed char x[MAX_INPUTS]; - int out[MAX_OUTPUTS]; + const float32x4_t scale = vdupq_n_f32(SCALE); + const float32x4_t scale_1 = vdupq_n_f32(SCALE_1); (void)col_stride; - for (i=0;i<rows;i++) out[i] = (int)floor(.5+SCALE*_out[i]); for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]); for (i=0;i<rows;i+=8) { int32x4_t acc0, acc1; - acc0 = vld1q_s32(&out[i]); - acc1 = vld1q_s32(&out[i+4]); + acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i]))); + acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4]))); for (j=0;j<cols;j+=4) { int8x16_t vw0, vw1, vx; @@ -310,25 +310,24 @@ static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int c acc1 = vdotprod(acc1, vw1, vx); w += 32; } - vst1q_s32(&out[i], acc0); - vst1q_s32(&out[i+4], acc1); + vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0))); + vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1))); } - for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i]; } static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x) { int i, j; signed char x[MAX_INPUTS]; - int out[MAX_OUTPUTS]; - for (i=0;i<rows;i++) out[i] = (int)floor(.5+SCALE*_out[i]); + const float32x4_t scale = vdupq_n_f32(SCALE); + const float32x4_t scale_1 = vdupq_n_f32(SCALE_1); for (i=0;i<cols;i++) x[i] = floor(.5+127*_x[i]); for (i=0;i<rows;i+=8) { int colblocks; int32x4_t acc0, acc1; - acc0 = vld1q_s32(&out[i]); - acc1 = vld1q_s32(&out[i+4]); + acc0 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i]))); + acc1 = vcvtnq_s32_f32(vmulq_f32(scale, vld1q_f32(&_out[i+4]))); colblocks = *idx++; for (j=0;j<colblocks;j++) { @@ -342,8 +341,7 @@ static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows acc1 = vdotprod(acc1, vw1, vx); w += 32; } - vst1q_s32(&out[i], acc0); - vst1q_s32(&out[i+4], acc1); + vst1q_f32(&_out[i], vmulq_f32(scale_1, vcvtq_f32_s32(acc0))); + vst1q_f32(&_out[i+4], vmulq_f32(scale_1, vcvtq_f32_s32(acc1))); } - for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i]; } -- GitLab