diff --git a/dnn/nnet.c b/dnn/nnet.c index cf7d4e10178e5a88ffb2319e718fe94c199a66cc..28ebf26ef8b478b1efb1306de7e72abe28dbf542 100644 --- a/dnn/nnet.c +++ b/dnn/nnet.c @@ -39,6 +39,8 @@ #include "nnet.h" #include "nnet_data.h" +#define SOFTMAX_HACK + #ifdef __AVX2__ #include <immintrin.h> static __m256 exp8_approx(__m256 X) @@ -340,6 +342,10 @@ void compute_activation(float *output, float *input, int N, int activation) for (i=0;i<N;i++) output[i] = relu(input[i]); } else if (activation == ACTIVATION_SOFTMAX) { +#ifdef SOFTMAX_HACK + for (i=0;i<N;i++) + output[i] = input[i]; +#else float sum = 0; softmax(output, input, N); for (i=0;i<N;i++) { @@ -348,6 +354,7 @@ void compute_activation(float *output, float *input, int N, int activation) sum = 1.f/(sum+1e-30); for (i=0;i<N;i++) output[i] = sum*output[i]; +#endif } else { celt_assert(activation == ACTIVATION_LINEAR); for (i=0;i<N;i++) @@ -619,12 +626,24 @@ int sample_from_pdf(const float *pdf, int N, float exp_boost, float pdf_floor) float tmp[DUAL_FC_OUT_SIZE]; celt_assert(N <= DUAL_FC_OUT_SIZE); sum = 0; +#ifdef SOFTMAX_HACK + for (i=0;i<N;i++) + { + tmp[i] = pdf[i] * (1.f+exp_boost); + } + softmax(tmp, tmp, N); + for (i=0;i<N;i++) + { + sum += tmp[i]; + } +#else /* Decrease the temperature of the sampling. */ for (i=0;i<N;i++) { tmp[i] = pow(pdf[i], 1.f+exp_boost); sum += tmp[i]; } +#endif norm = 1.f/sum; /* Convert tmp to a CDF while subtracting the floor */ tmp[0] = MAX16(0, norm*tmp[0] - pdf_floor);