Commit 44c63350 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

optimisations: Another bunch of simplifications to alg_quant(), mainly to

remove unnecessary copying and some conditional branches.
parent 05974935
......@@ -54,6 +54,23 @@ static inline int find_max16(celt_word16_t *x, int len)
}
#endif
#ifndef OVERRIDE_FIND_MAX32
static inline int find_max32(celt_word32_t *x, int len)
{
celt_word32_t max_corr=-VERY_LARGE16;
int i, id = 0;
for (i=0;i<len;i++)
{
if (x[i] > max_corr)
{
id = i;
max_corr = x[i];
}
}
return id;
}
#endif
#ifndef FIXED_POINT
......
......@@ -93,30 +93,19 @@ static void mix_pitch_and_residual(int *iy, celt_norm_t *X, int N, int K, const
RESTORE_STACK;
}
/** All the info necessary to keep track of a hypothesis during the search */
struct NBest {
celt_word32_t score;
int sign;
int pos;
celt_word32_t xy;
celt_word32_t yy;
celt_word32_t yp;
};
void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *P, ec_enc *enc)
{
VARDECL(celt_norm_t, _y);
VARDECL(celt_norm_t, _ny);
VARDECL(int, _iy);
VARDECL(int, _iny);
VARDECL(celt_norm_t, y);
VARDECL(int, iy);
VARDECL(int, signx);
celt_norm_t *y, *ny;
int *iy, *iny;
int i, j;
VARDECL(celt_word32_t, scores);
int i, j, is;
celt_word16_t s;
int pulsesLeft;
celt_word32_t sum;
celt_word32_t xy, yy, yp;
struct NBest nbest;
celt_word32_t Rpp=0, Rxp=0;
celt_word16_t Rpp;
#ifdef FIXED_POINT
int yshift;
#endif
......@@ -126,17 +115,11 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
yshift = 14-EC_ILOG(K);
#endif
ALLOC(_y, N, celt_norm_t);
ALLOC(_ny, N, celt_norm_t);
ALLOC(_iy, N, int);
ALLOC(_iny, N, int);
ALLOC(y, N, celt_norm_t);
ALLOC(iy, N, int);
ALLOC(signx, N, int);
ALLOC(scores, N, celt_word32_t);
y = _y;
ny = _ny;
iy = _iy;
iny = _iny;
for (j=0;j<N;j++)
{
if (X[j]>0)
......@@ -145,13 +128,12 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
signx[j]=-1;
}
sum = 0;
for (j=0;j<N;j++)
{
Rpp = MAC16_16(Rpp, P[j],P[j]);
Rxp = MAC16_16(Rxp, X[j],P[j]);
sum = MAC16_16(sum, P[j],P[j]);
}
Rpp = ROUND16(Rpp, NORM_SHIFT);
Rxp = ROUND16(Rxp, NORM_SHIFT);
Rpp = ROUND16(sum, NORM_SHIFT);
celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");
......@@ -165,6 +147,9 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
while (pulsesLeft > 0)
{
int pulsesAtOnce=1;
int sign;
celt_word32_t Rxy, Ryy, Ryp;
celt_word32_t g;
/* Decide on how many pulses to find at once */
pulsesAtOnce = pulsesLeft/N;
......@@ -172,31 +157,31 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
pulsesAtOnce = 1;
/*printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);*/
nbest.score = -VERY_LARGE32;
for (j=0;j<N;j++)
/* Choose between fast and accurate strategy depending on where we are in the search */
if (pulsesLeft>1)
{
int sign;
/*fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);*/
celt_word32_t Rxy, Ryy, Ryp;
celt_word32_t score;
celt_word32_t g;
celt_word16_t s;
/* Select sign based on X[j] alone */
sign = signx[j];
s = SHL16(sign*pulsesAtOnce, yshift);
/* Updating the sums of the new pulse(s) */
Rxy = xy + MULT16_16(s,X[j]);
Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
Ryp = yp + MULT16_16(s, P[j]);
if (pulsesLeft>1)
for (j=0;j<N;j++)
{
score = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
} else
/* Select sign based on X[j] alone */
sign = signx[j];
s = SHL16(sign*pulsesAtOnce, yshift);
/* Temporary sums of the new pulse(s) */
Rxy = xy + MULT16_16(s,X[j]);
Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
Ryp = yp + MULT16_16(s, P[j]);
scores[j] = MULT32_32_Q31(MULT16_16(ROUND16(Rxy,14),ABS16(ROUND16(Rxy,14))), celt_rcp(SHR32(Ryy,12)));
}
} else {
for (j=0;j<N;j++)
{
/* Select sign based on X[j] alone */
sign = signx[j];
s = SHL16(sign*pulsesAtOnce, yshift);
/* Temporary sums of the new pulse(s) */
Rxy = xy + MULT16_16(s,X[j]);
Ryy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
Ryp = yp + MULT16_16(s, P[j]);
/* Compute the gain such that ||p + g*y|| = 1 */
g = MULT16_32_Q15(
celt_sqrt(MULT16_16(ROUND16(Ryp,14),ROUND16(Ryp,14)) + Ryy -
......@@ -206,54 +191,23 @@ void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, const celt_norm_t *
/* Knowing that gain, what's the error: (x-g*y)^2
(result is negated and we discard x^2 because it's constant) */
/* score = 2.f*g*Rxy - 1.f*g*g*Ryy*NORM_SCALING_1;*/
score = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
scores[j] = 2*MULT16_32_Q14(ROUND16(Rxy,14),g)
- MULT16_32_Q14(EXTRACT16(MULT16_32_Q14(ROUND16(Ryy,14),g)),g);
}
if (score>nbest.score)
{
nbest.score = score;
nbest.pos = j;
nbest.sign = sign;
nbest.xy = Rxy;
nbest.yy = Ryy;
nbest.yp = Ryp;
}
}
celt_assert2(nbest.score > -VERY_LARGE32, "Could not find any match in VQ codebook. Something got corrupted somewhere.");
/* Only now that we've made the final choice, update ny/iny and others */
{
int n;
int is;
celt_norm_t s;
is = nbest.sign*pulsesAtOnce;
s = SHL16(is, yshift);
for (n=0;n<N;n++)
ny[n] = y[n];
ny[nbest.pos] += s;
for (n=0;n<N;n++)
iny[n] = iy[n];
iny[nbest.pos] += is;
xy = nbest.xy;
yy = nbest.yy;
yp = nbest.yp;
}
/* Swap ny/iny with y/iy */
{
celt_norm_t *tmp_ny;
int *tmp_iny;
tmp_ny = ny;
ny = y;
y = tmp_ny;
tmp_iny = iny;
iny = iy;
iy = tmp_iny;
}
j = find_max32(scores, N);
is = signx[j]*pulsesAtOnce;
s = SHL16(is, yshift);
/* Updating the sums of the new pulse(s) */
xy = xy + MULT16_16(s,X[j]);
yy = yy + 2*MULT16_16(s,y[j]) + MULT16_16(s,s);
yp = yp + MULT16_16(s, P[j]);
/* Only now that we've made the final choice, update y/iy */
y[j] += s;
iy[j] += is;
pulsesLeft -= pulsesAtOnce;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment