diff --git a/dnn/nnet.c b/dnn/nnet.c
index 012fc9bfa2fa8d7aa3b2ae2e1a80681defba35b6..5730c3f8f4e9c5dac1fb24611cf9239ca6ea2e86 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -224,13 +224,14 @@ void compute_gru2(const GRULayer *gru, float *state, const float *input)
    celt_assert(gru->reset_after);
    stride = 3*N;
    /* Compute update gate. */
+#ifdef USE_SU_BIAS
    for (i=0;i<3*N;i++)
-      zrh[i] = gru->bias[i];
-#if 1
-   sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input);
+      zrh[i] = gru->subias[i];
 #else
-   sgemv_accum(zrh, gru->input_weights, 3*N, M, stride, input);
+   for (i=0;i<3*N;i++)
+      zrh[i] = gru->bias[i];
 #endif
+   sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input);
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
    sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
diff --git a/dnn/vec.h b/dnn/vec.h
index 93504b621ed92a72cee63ed0d208902388f0f86a..8a873e687e512ef7a2de140d0e08674e2fa5f58d 100644
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -198,13 +198,16 @@ static inline void sparse_sgemv_accum16(float *out, const float *w, int rows, co
 
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
+
+#ifdef USE_SU_BIAS
+
 static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
 {
    int i, j;
-   signed char x[MAX_INPUTS];
+   unsigned char x[MAX_INPUTS];
    (void)col_stride;
    for (i=0;i<rows;i++) out[i] *= SCALE;
-   for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
+   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
    for (i=0;i<rows;i+=8)
    {
       for (j=0;j<cols;j+=4)
@@ -230,8 +233,6 @@ static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int co
    for (i=0;i<rows;i++) out[i] *= SCALE_1;
 }
 
-
-#ifdef USE_SU_BIAS
 static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
@@ -267,6 +268,39 @@ static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows,
    for (i=0;i<rows;i++) out[i] *= SCALE_1;
 }
 #else /*USE_SU_BIAS*/
+
+static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+   int i, j;
+   signed char x[MAX_INPUTS];
+   (void)col_stride;
+   for (i=0;i<rows;i++) out[i] *= SCALE;
+   for (i=0;i<cols;i++) x[i] = (int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j+=4)
+      {
+         float * restrict y;
+         float xj0, xj1, xj2, xj3;
+         xj0 = x[j+0];
+         xj1 = x[j+1];
+         xj2 = x[j+2];
+         xj3 = x[j+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+
 static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
    int i, j;
diff --git a/dnn/vec_avx.h b/dnn/vec_avx.h
index aaa67363a66e3f22e551e169124a641121d2ab5f..d75b83b2f922048cd53c05b217ac7c8d08738af9 100644
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -241,6 +241,88 @@ typedef float qweight;
 
 #define SCALE (128.f*127.f)
 #define SCALE_1 (1.f/128.f/127.f)
+#if 1
+static inline void sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+   __m256i ones;
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   int out[MAX_OUTPUTS];
+   (void)col_stride;
+   ones = _mm256_set1_epi16(1);
+   for (i=0;i<rows;i++) out[i] = SCALE*_out[i];
+   //for (i=0;i<cols;i++) x[i] = 127+floor(.5+127*_x[i]);
+   __m256 const127 = _mm256_set1_ps(127.f);
+   for (i=0;i<cols;i+=8) {
+       __m256 xf;
+       __m256i xi;
+       xf = _mm256_loadu_ps(&_x[i]);
+       //xf = _mm256_mul_ps(xf, const127);
+       //xf = _mm256_add_ps(xf, const127);
+       xf = _mm256_fmadd_ps(xf, const127, const127);
+       xi = _mm256_cvtps_epi32(xf);
+       xi = _mm256_packus_epi32(xi, _mm256_setzero_si256());
+       xi = _mm256_permute4x64_epi64(xi, 0xD8);
+       xi = _mm256_packus_epi16(xi, _mm256_setzero_si256());
+       xi = _mm256_permutevar8x32_epi32(xi, _mm256_setr_epi32(0,1, 0,0, 0,0, 0,0));
+       //xi = _mm256_permute4x64_epi64(xi, 0x);
+       _mm256_storeu_si256 ((__m256i *)&x[i], xi);
+   }
+   for (i=0;i<rows;i+=8)
+   {
+      int * restrict y;
+      __m256i vy0;
+      y = &out[i];
+      vy0 = _mm256_loadu_si256((const __m256i *)&y[0]);
+      for (j=0;j<cols;j+=4)
+      {
+         __m256i tmp;
+         __m256i vxj;
+         __m256i vw;
+         vxj = _mm256_set1_epi32(*(int*)&x[j]);
+         vw = _mm256_loadu_si256((const __m256i *)w); //_mm256_lddqu_si256?
+         tmp = _mm256_maddubs_epi16(vxj, vw); //swap?
+         tmp = _mm256_madd_epi16(tmp, ones);
+         vy0 = _mm256_add_epi32(vy0, tmp);
+         w += 32;
+      }
+      _mm256_storeu_si256 ((__m256i *)&y[0], vy0);
+   }
+   for (i=0;i<rows;i++) _out[i] = SCALE_1*out[i];
+}
+#else
+static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
+{
+   int i, j;
+   unsigned char x[MAX_INPUTS];
+   (void)col_stride;
+   for (i=0;i<rows;i++) out[i] *= SCALE;
+   for (i=0;i<cols;i++) x[i] = 127+(int)floor(.5+127*_x[i]);
+   for (i=0;i<rows;i+=8)
+   {
+      for (j=0;j<cols;j+=4)
+      {
+         float * restrict y;
+         float xj0, xj1, xj2, xj3;
+         xj0 = x[j+0];
+         xj1 = x[j+1];
+         xj2 = x[j+2];
+         xj3 = x[j+3];
+         y = &out[i];
+         y[0] += (w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += (w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += (w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += (w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += (w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += (w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += (w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += (w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
+         w += 32;
+      }
+   }
+   for (i=0;i<rows;i++) out[i] *= SCALE_1;
+}
+#endif
 static inline void sparse_sgemv_accum8x4(float *_out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
 {
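
Why compute_gru2 switches from gru->bias to gru->subias under USE_SU_BIAS: the unsigned quantizer in vec.h maps an activation x to 127+round(127*x), so each product w_q*x picks up a constant offset 127*w_q, and summing over a row adds 127*sum_j w_q[i][j] to the accumulator. Folding the negated offset into the bias restores the signed-path result. The sketch below is not part of the patch; it assumes subias[i] is exported as bias[i] minus the per-row sum of the dequantized weights (w_q/128), which is exactly the correction the identity requires, and all names in it are illustrative.

/* Standalone sketch (not from the patch): signed vs. unsigned activation
 * quantization for one output row, with the assumed "subias" correction. */
#include <math.h>
#include <stdio.h>

#define N_IN 8

int main(void)
{
   /* Quantized weights for one row: signed bytes, value ~ 128*w_float. */
   signed char w_q[N_IN] = { 64, -128, 32, 5, -77, 90, -3, 127 };
   float x[N_IN] = { 0.9f, -0.25f, 0.0f, 1.0f, -1.0f, 0.5f, -0.75f, 0.125f };
   float bias = 0.3f;
   int acc_s = 0, acc_u = 0;
   float sum_w = 0.f;
   for (int j = 0; j < N_IN; j++) {
      int xq = (int)floor(.5 + 127*x[j]);  /* signed path:   [-127, 127] */
      int xu = 127 + xq;                   /* unsigned path: [0, 254]    */
      acc_s += w_q[j]*xq;
      acc_u += w_q[j]*xu;
      sum_w += w_q[j]/128.f;               /* ~ sum of float weights     */
   }
   /* acc_u = acc_s + 127*sum_j w_q[j], so subtracting sum_w from the bias
    * ("subias") makes both paths agree after the 1/(128*127) rescale.    */
   float y_signed   = bias + acc_s/(128.f*127.f);
   float y_unsigned = (bias - sum_w) + acc_u/(128.f*127.f);
   printf("signed: %f  unsigned+subias: %f\n", y_signed, y_unsigned);
   return 0;
}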
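On the AVX2 side, the inner loop relies on _mm256_maddubs_epi16(vxj, vw), which multiplies the unsigned activation bytes by the signed weight bytes and adds adjacent product pairs with signed 16-bit saturation; _mm256_madd_epi16(tmp, ones) then widens each pair of 16-bit sums into one 32-bit lane, so every 32-byte weight block updates eight 32-bit accumulators. A scalar model of one such block is sketched below (not part of the patch); it assumes the same 8-row-by-4-column block layout used by the scalar sgemv_accum8x4 in vec.h, and keeps the saturation step explicit because the SIMD pair-sums can clip at the 16-bit limits while the scalar float path cannot.

/* Scalar model (not from the patch) of the AVX2 inner loop in vec_avx.h:
 * one iteration consumes 4 unsigned activation bytes and a 32-byte weight
 * block, producing 8 partial 32-bit sums, mirroring
 * _mm256_maddubs_epi16(vxj, vw) followed by _mm256_madd_epi16(tmp, ones). */
#include <stdint.h>
#include <stdio.h>

static int16_t sat16(int32_t v)
{
   /* _mm256_maddubs_epi16 saturates each pair-sum to signed 16 bits. */
   if (v > 32767) return 32767;
   if (v < -32768) return -32768;
   return (int16_t)v;
}

static void block8x4_model(int32_t y[8], const uint8_t x[4], const int8_t w[32])
{
   for (int i = 0; i < 8; i++) {
      /* u8*s8 products summed in adjacent pairs with saturation (maddubs),
       * then the two pair-sums added into a 32-bit lane (madd with ones). */
      int16_t p0 = sat16(w[4*i+0]*x[0] + w[4*i+1]*x[1]);
      int16_t p1 = sat16(w[4*i+2]*x[2] + w[4*i+3]*x[3]);
      y[i] += (int32_t)p0 + p1;
   }
}

int main(void)
{
   int32_t y[8] = {0};
   uint8_t x[4] = { 0, 127, 200, 254 };                   /* quantized activations */
   int8_t w[32];
   for (int k = 0; k < 32; k++) w[k] = (int8_t)(k - 16);  /* toy weight block */
   block8x4_model(y, x, w);
   for (int i = 0; i < 8; i++) printf("y[%d] = %d\n", i, y[i]);
   return 0;
}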