/* 8-bit quantized matrix-vector kernels introduced for dnn/vec_neon.h.
 *
 * These definitions were recovered from a patch whose line breaks had been
 * lost. The same patch also makes dnn/vec.h include "vec_neon.h" when either
 * __ARM_NEON__ or __ARM_NEON is defined, and marks the existing float
 * kernels in vec_neon.h `static inline`; those functions are only partially
 * visible in the patch and are therefore not reproduced here.
 *
 * The two kernels below are plain C: they use no NEON intrinsics.
 */
#include <stddef.h>

/* Defined alongside the 8-bit weight type below. */
#define DOT_PROD
typedef signed char qweight;

/* The accumulators are multiplied by SCALE before the integer products are
   added and by SCALE_1 afterwards, so each call effectively adds
   sum(w*xq)/(128*127) to every output. */
#define SCALE (128.f*127.f)
#define SCALE_1 (1.f/128.f/127.f)

/* Capacity of the on-stack buffer holding the quantized input vector. */
#define MAX_INPUTS 2048

/* Rounds 127*v to the nearest integer level (floor(.5+127*v)) and saturates
   the result to the signed 8-bit range [-128, 127]. NaN maps to 0.
   The rounding is done without libm so no <math.h> dependency is needed. */
static inline int quantize_input8(float v)
{
   double d = .5+127*v;
   int t;
   if (d != d) return 0;          /* NaN: converting it would be undefined */
   if (d >= 127.) return 127;     /* floor(d) >= 127 saturates high */
   if (d <= -128.) return -128;   /* floor(d) <= -128 saturates low */
   t = (int)d;                    /* truncates toward zero... */
   if (t > d) t--;                /* ...so step down for negative fractions */
   return t;
}

/* Loads the four quantized inputs at pos..pos+3, either from the
   pre-quantized cache or, when no cache is available, straight from xf. */
static inline void fetch_inputs4(int q[4], const signed char *cache, const float *xf, int pos)
{
   int k;
   for (k=0;k<4;k++)
      q[k] = (cache != NULL) ? cache[pos+k] : quantize_input8(xf[pos+k]);
}

/* y[0..7] += W*q for one 8x4 weight block stored row by row (w[4*r+c]).
   The dot product is formed in integer arithmetic, which is exact for
   8-bit operands, and only then added to the float accumulator. */
static inline void mac8x4(float * restrict y, const qweight *w, const int q[4])
{
   int r;
   for (r=0;r<8;r++)
      y[r] += (w[4*r]*q[0]+w[4*r+1]*q[1]+w[4*r+2]*q[2]+w[4*r+3]*q[3]);
}

/* Dense 8-bit matrix-vector accumulate: out += W*quantize(_x)/(128*127).
 *
 * out        rows accumulators, updated in place.
 * w          weights, consumed sequentially as 32-value blocks: for each
 *            group of 8 rows, one 8x4 block per group of 4 columns.
 * rows       number of outputs; the loops step by 8, so it is expected to
 *            be a multiple of 8 (otherwise out is accessed past rows).
 * cols       number of inputs; the loops step by 4, so it is expected to
 *            be a multiple of 4.
 * col_stride unused.
 * _x         cols float inputs; each is quantized with quantize_input8().
 *
 * Inputs are quantized once into a stack buffer when cols <= MAX_INPUTS.
 * Larger vectors do not fit in that buffer and are quantized on the fly
 * instead (same result, more work) rather than overflowing the stack.
 */
static inline void sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, int col_stride, const float *_x)
{
   int i, j;
   signed char x[MAX_INPUTS];
   const signed char *cache = NULL;
   (void)col_stride;
   for (i=0;i<rows;i++) out[i] *= SCALE;
   if (cols <= MAX_INPUTS)
   {
      for (i=0;i<cols;i++) x[i] = (signed char)quantize_input8(_x[i]);
      cache = x;
   }
   for (i=0;i<rows;i+=8)
   {
      for (j=0;j<cols;j+=4)
      {
         int q[4];
         fetch_inputs4(q, cache, _x, j);
         mac8x4(&out[i], w, q);
         w += 32;
      }
   }
   for (i=0;i<rows;i++) out[i] *= SCALE_1;
}

/* Block-sparse 8-bit matrix-vector accumulate.
 *
 * out   rows accumulators, updated in place.
 * w     weights of the stored 8x4 blocks only, 32 values per block, in the
 *       order the blocks are listed in idx.
 * rows  number of outputs; expected to be a multiple of 8.
 * cols  number of inputs in _x.
 * idx   for every group of 8 rows: the number of stored blocks, followed by
 *       that many column-block indices. Block index b covers inputs
 *       4*b .. 4*b+3, which must lie inside _x; this is not checked.
 * _x    cols float inputs; each is quantized with quantize_input8().
 *
 * As in the dense kernel, the quantized inputs are buffered on the stack
 * only when cols <= MAX_INPUTS and are otherwise quantized on demand.
 */
static inline void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, int cols, const int *idx, const float *_x)
{
   int i, j;
   signed char x[MAX_INPUTS];
   const signed char *cache = NULL;
   for (i=0;i<rows;i++) out[i] *= SCALE;
   if (cols <= MAX_INPUTS)
   {
      for (i=0;i<cols;i++) x[i] = (signed char)quantize_input8(_x[i]);
      cache = x;
   }
   for (i=0;i<rows;i+=8)
   {
      int colblocks;
      colblocks = *idx++;
      for (j=0;j<colblocks;j++)
      {
         int q[4];
         int pos;
         pos = 4 * (*idx++);
         fetch_inputs4(q, cache, _x, pos);
         mac8x4(&out[i], w, q);
         w += 32;
      }
   }
   for (i=0;i<rows;i++) out[i] *= SCALE_1;
}