Commit abf5c8ed authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Merge branch 'cwrs_speedup' (derf's cwrs changes)

Conflicts:
	libcelt/cwrs.c
parents 1dab60cc d910274f
......@@ -92,6 +92,75 @@ int fits_in32(int _n, int _m)
}
}
#define MASK32 (0xFFFFFFFF)
/*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/
static const unsigned INV_TABLE[128]={
0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7,
0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF,
0xF0F0F0F1,0x286BCA1B,0x3CF3CF3D,0xE9BD37A7,
0xC28F5C29,0x684BDA13,0x4F72C235,0xBDEF7BDF,
0x3E0F83E1,0x8AF8AF8B,0x914C1BAD,0x96F96F97,
0xC18F9C19,0x2FA0BE83,0xA4FA4FA5,0x677D46CF,
0x1A1F58D1,0xFAFAFAFB,0x8C13521D,0x586FB587,
0xB823EE09,0xA08AD8F3,0xC10C9715,0xBEFBEFBF,
0xC0FC0FC1,0x07A44C6B,0xA33F128D,0xE327A977,
0xC7E3F1F9,0x962FC963,0x3F2B3885,0x613716AF,
0x781948B1,0x2B2E43DB,0xFCFCFCFD,0x6FD0EB67,
0xFA3F47E9,0xD2FD2FD3,0x3F4FD3F5,0xD4E25B9F,
0x5F02A3A1,0xBF5A814B,0x7C32B16D,0xD3431B57,
0xD8FD8FD9,0x8D28AC43,0xDA6C0965,0xDB195E8F,
0x0FDBC091,0x61F2A4BB,0xDCFDCFDD,0x46FDD947,
0x56BE69C9,0xEB2FDEB3,0x26E978D5,0xEFDFBF7F,
0x0FE03F81,0xC9484E2B,0xE133F84D,0xE1A8C537,
0x077975B9,0x70586723,0xCD29C245,0xFAA11E6F,
0x0FE3C071,0x08B51D9B,0x8CE2CABD,0xBF937F27,
0xA8FE53A9,0x592FE593,0x2C0685B5,0x2EB11B5F,
0xFCD1E361,0x451AB30B,0x72CFE72D,0xDB35A717,
0xFB74A399,0xE80BFA03,0x0D516325,0x1BCB564F,
0xE02E4851,0xD962AE7B,0x10F8ED9D,0x95AEDD07,
0xE9DC0589,0xA18A4473,0xEA53FA95,0xEE936F3F,
0x90948F41,0xEAFEAFEB,0x3D137E0D,0xEF46C0F7,
0x028C1979,0x791064E3,0xC04FEC05,0xE115062F,
0x32385831,0x6E68575B,0xA10D387D,0x6FECF2E7,
0x3FB47F69,0xED4BFB53,0x74FED775,0xDB43BB1F,
0x87654321,0x9BA144CB,0x478BBCED,0xBFB912D7,
0x1FDCD759,0x14B2A7C3,0xCB125CE5,0x437B2E0F,
0x10FEF011,0xD2B3183B,0x386CAB5D,0xEF6AC0C7,
0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF
};
/*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact.
_a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result
fits in 32 bits, but currently the table for multiplicative inverses is only
valid for _d<128.*/
static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b,
celt_uint32_t _c,celt_uint32_t _d){
return (_a*_b-_c)*INV_TABLE[_d]&MASK32;
}
/*Computes (_a*_b-_c)/_d when the quotient is known to be exact.
_d does not actually have to be even, but imusdiv32odd will be faster when
it's odd, so you should use that instead.
_a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the
table for multiplicative inverses is only valid for _d<256).
_b and _c may be arbitrary so long as the arbitrary precision reuslt fits in
32 bits.*/
static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b,
celt_uint32_t _c,celt_uint32_t _d){
unsigned inv;
int mask;
int shift;
int one;
shift=EC_ILOG(_d^_d-1);
inv=INV_TABLE[_d-1>>shift];
shift--;
one=1<<shift;
mask=one-1;
return (_a*(_b>>shift)-(_c>>shift)+
(_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32;
}
/*Computes the next row/column of any recurrence that obeys the relation
u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1].
_ui0 is the base case for the new row/column.*/
......@@ -145,119 +214,127 @@ celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){
/*Returns the number of ways of choosing _m elements from a set of size _n with
replacement when a sign bit is needed for each unique element.
On exit, _u will be initialized to column _m of U(n,m).*/
_u: On exit, _u[i] contains U(i+1,_m).*/
celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){
int k;
CELT_MEMSET(_u,0,_n);
if(_m<=0)return 1;
if(_n<=0)return 0;
for(k=1;k<_m;k++)unext32(_u,_n,2);
return ncwrs_unext32(_n,_u);
celt_uint32_t ret;
celt_uint32_t um2;
int k;
/*If _m==0, _u[] should be set to zero and the return should be 1.*/
celt_assert(_m>0);
/*We'll overflow our buffer unless _n>=2.*/
celt_assert(_n>=2);
um2=_u[0]=1;
if(_m<=6){
if(_m<2){
k=1;
do _u[k]=1;
while(++k<_n);
}
else{
k=1;
do _u[k]=(k<<1)+1;
while(++k<_n);
for(k=2;k<_m;k++)unext32(_u,_n,1);
}
}
else{
celt_uint32_t um1;
celt_uint32_t n2m1;
_u[1]=n2m1=um1=(_m<<1)-1;
for(k=2;k<_n;k++){
/*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/
_u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2;
if(++k>=_n)break;
_u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1;
}
}
ret=1;
k=1;
do ret+=_u[k];
while(++k<_n);
return ret<<1;
}
/*Returns the _i'th combination of _m elements chosen from a set of size _n
with associated sign bits.
_x: Returns the combination with elements sorted in ascending order.
_s: Returns the associated sign bits.
_u: Temporary storage already initialized to column _m of U(n,m).
Its contents will be overwritten.*/
void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,celt_uint32_t *_u){
int j;
int k;
for(k=j=0;k<_m;k++){
celt_uint32_t p;
celt_uint32_t t;
p=_u[_n-j-1];
if(k>0){
t=p>>1;
if(t<=_i||_s[k-1])_i+=t;
}
while(p<=_i){
_i-=p;
j++;
p=_u[_n-j-1];
_y: Returns the vector of pulses.
_u: Must contain entries [1..._n] of column _m of U() on input.
Its contents will be destructively modified.*/
void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
celt_uint32_t *_u){
celt_uint32_t p;
celt_uint32_t q;
int j;
int k;
celt_assert(_n>0);
p=_nc;
q=0;
j=0;
k=_m;
do{
int s;
int yj;
p-=q;
q=_u[_n-j-1];
p-=q;
s=_i>=p;
if(s)_i-=p;
yj=k;
while(q>_i){
uprev32(_u,_n-j,--k>0);
p=q;
q=_u[_n-j-1];
}
t=p>>1;
_s[k]=_i>=t;
_x[k]=j;
if(_s[k])_i-=t;
uprev32(_u,_n-j,2);
_i-=q;
yj-=k;
_y[j]=yj-(yj<<1&-s);
}
while(++j<_n);
}
/*Returns the index of the given combination of _m elements chosen from a set
of size _n with associated sign bits.
_x: The combination with elements sorted in ascending order.
_s: The associated sign bits.
_u: Temporary storage already initialized to column _m of U(n,m).
Its contents will be overwritten.*/
celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
_y: The vector of pulses, whose sum of absolute values must be _m.
_nc: Returns V(_n,_m).*/
celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
celt_uint32_t *_u){
celt_uint32_t nc;
celt_uint32_t i;
int j;
int k;
i=0;
for(k=j=0;k<_m;k++){
celt_uint32_t p;
p=_u[_n-j-1];
if(k>0)p>>=1;
while(j<_x[k]){
i+=p;
j++;
p=_u[_n-j-1];
}
if((k==0||_x[k]!=_x[k-1])&&_s[k])i+=p>>1;
uprev32(_u,_n-j,2);
/*We can't unroll the first two iterations of the loop unless _n>=2.*/
celt_assert(_n>=2);
nc=1;
i=_y[_n-1]<0;
_u[0]=0;
for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1;
k=abs(_y[_n-1]);
j=_n-2;
nc+=_u[_m];
i+=_u[k];
k+=abs(_y[j]);
if(_y[j]<0)i+=_u[k+1];
while(j-->0){
unext32(_u,_m+2,0);
nc+=_u[_m];
i+=_u[k];
k+=abs(_y[j]);
if(_y[j]<0)i+=_u[k+1];
}
/*If _m==0, nc should not be doubled.*/
celt_assert(_m>0);
*_nc=nc<<1;
return i;
}
/*Converts a combination _x of _m unit pulses with associated sign bits _s into
a pulse vector _y of length _n.
_y: Returns the vector of pulses.
_x: The combination with elements sorted in ascending order. _x[_m] = -1
_s: The associated sign bits.*/
void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s){
int k;
const int signs[2]={1,-1};
CELT_MEMSET(_y, 0, _n);
k=0; do {
_y[_x[k]]+=signs[_s[k]];
} while (++k<_m);
}
/*Converts a pulse vector vector _y of length _n into a combination of _m unit
pulses with associated sign bits _s.
_x: Returns the combination with elements sorted in ascending order.
_s: Returns the associated sign bits.
_y: The vector of pulses, whose sum of absolute values must be _m.*/
void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y){
int j;
int k;
for(k=j=0;j<_n;j++){
if(_y[j]){
int n;
int s;
n=abs(_y[j]);
s=_y[j]<0;
do {
_x[k]=j;
_s[k]=s;
k++;
} while (--n>0);
}
}
}
static inline void encode_comb32(int _n,int _m,const int *_x,const int *_s,
ec_enc *_enc){
static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){
VARDECL(celt_uint32_t,u);
celt_uint32_t nc;
celt_uint32_t i;
SAVE_STACK;
ALLOC(u,_n,celt_uint32_t);
nc=ncwrs_u32(_n,_m,u);
i=icwrs32(_n,_m,_x,_s,u);
ALLOC(u,_m+2,celt_uint32_t);
i=icwrs32(_n,_m,&nc,_y,u);
ec_enc_uint(_enc,i,nc);
RESTORE_STACK;
}
......@@ -283,21 +360,13 @@ int get_required_bits(int N, int K, int frac)
void encode_pulses(int *_y, int N, int K, ec_enc *enc)
{
VARDECL(int, comb);
VARDECL(int, signs);
SAVE_STACK;
ALLOC(comb, K, int);
ALLOC(signs, K, int);
pulse2comb(N, K, comb, signs, _y);
if (K==0) {
} else if (N==1)
{
ec_enc_bits(enc, _y[0]<0, 1);
} else if(fits_in32(N,K))
{
encode_comb32(N, K, comb, signs, enc);
encode_pulse32(N, K, _y, enc);
} else {
int i;
int count=0;
......@@ -309,25 +378,20 @@ void encode_pulses(int *_y, int N, int K, ec_enc *enc)
encode_pulses(_y, split, count, enc);
encode_pulses(_y+split, N-split, K-count, enc);
}
RESTORE_STACK;
}
static inline void decode_comb32(int _n,int _m,int *_x,int *_s,ec_dec *_dec){
static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){
VARDECL(celt_uint32_t,u);
celt_uint32_t nc;
SAVE_STACK;
ALLOC(u,_n,celt_uint32_t);
cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_x,_s,u);
nc=ncwrs_u32(_n,_m,u);
cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u);
RESTORE_STACK;
}
void decode_pulses(int *_y, int N, int K, ec_dec *dec)
{
VARDECL(int, comb);
VARDECL(int, signs);
SAVE_STACK;
ALLOC(comb, K, int);
ALLOC(signs, K, int);
if (K==0) {
int i;
for (i=0;i<N;i++)
......@@ -341,8 +405,7 @@ void decode_pulses(int *_y, int N, int K, ec_dec *dec)
_y[0] = -K;
} else if(fits_in32(N,K))
{
decode_comb32(N, K, comb, signs, dec);
comb2pulse(N, K, _y, comb, signs);
decode_pulse32(N, K, _y, dec);
} else {
int split;
int count = ec_dec_uint(dec,K+1);
......@@ -350,5 +413,4 @@ void decode_pulses(int *_y, int N, int K, ec_dec *dec)
decode_pulses(_y, split, count, dec);
decode_pulses(_y+split, N-split, K-count, dec);
}
RESTORE_STACK;
}
......@@ -46,28 +46,22 @@ int fits_in64(int _n, int _m);
/* 32-bit versions */
celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u);
void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_x,int *_s,
void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y,
celt_uint32_t *_u);
celt_uint32_t icwrs32(int _n,int _m,const int *_x,const int *_s,
celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y,
celt_uint32_t *_u);
/* 64-bit versions */
celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u);
celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_u);
void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_x,int *_s,
void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y,
celt_uint64_t *_u);
celt_uint64_t icwrs64(int _n,int _m,const int *_x,const int *_s,
celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y,
celt_uint64_t *_u);
void comb2pulse(int _n,int _m,int * restrict _y,const int *_x,const int *_s);
void pulse2comb(int _n,int _m,int *_x,int *_s,const int *_y);
int get_required_bits(int N, int K, int frac);
void encode_pulses(int *_y, int N, int K, ec_enc *enc);
......
......@@ -21,33 +21,24 @@ int main(int _argc,char **_argv){
inc=nc/10000;
if(inc<1)inc=1;
for(i=0;i<nc;i+=inc){
celt_uint32_t u[NMAX];
int x[MMAX];
int s[MMAX];
int x2[MMAX];
int s2[MMAX];
celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2];
int y[NMAX];
celt_uint32_t v;
int k;
memcpy(u,uu,n*sizeof(*u));
cwrsi32(n,m,i,x,s,u);
/*printf("%6u of %u:",i,nc);*/
/*for(k=0;k<m;k++){
printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
}
cwrsi32(n,m,i,nc,y,u);
/*printf("%6u of %u:",i,nc);
for(k=0;k<n;k++)printf(" %+3i",y[k]);
printf(" ->");*/
memcpy(u,uu,n*sizeof(*u));
if(icwrs32(n,m,x,s,u)!=i){
if(icwrs32(n,m,&v,y,u)!=i){
fprintf(stderr,"Combination-index mismatch.\n");
return 1;
}
comb2pulse(n,m,y,x,s);
/*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
printf("\n");*/
pulse2comb(n,m,x2,s2,y);
for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
fprintf(stderr,"Pulse-combination mismatch.\n");
return 1;
if(v!=nc){
fprintf(stderr,"Combination count mismatch.\n");
return 2;
}
/*printf(" %6u\n",i);*/
}
/*printf("\n");*/
}
......
......@@ -24,33 +24,24 @@ int main(int _argc,char **_argv){
if(inc<1)inc=1;
/*printf("%d/%d: %llu",n,m, nc);*/
for(i=0;i<nc;i+=inc){
celt_uint64_t u[NMAX];
int x[MMAX];
int s[MMAX];
int x2[MMAX];
int s2[MMAX];
celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2];
int y[NMAX];
celt_uint64_t v;
int k;
memcpy(u,uu,n*sizeof(*u));
cwrsi64(n,m,i,x,s,u);
cwrsi64(n,m,i,nc,y,u);
/*printf("%llu of %llu:",i,nc);
for(k=0;k<m;k++){
printf(" %c%i",k>0&&x[k]==x[k-1]?' ':s[k]?'-':'+',x[k]);
}
for(k=0;k<n;k++)printf(" %+3i",y[k]);
printf(" ->");*/
memcpy(u,uu,n*sizeof(*u));
if(icwrs64(n,m,x,s,u)!=i){
if(icwrs64(n,m,&v,y,u)!=i){
fprintf(stderr,"Combination-index mismatch.\n");
return 1;
}
comb2pulse(n,m,y,x,s);
/*for(j=0;j<n;j++)printf(" %c%i",y[j]?y[j]<0?'-':'+':' ',abs(y[j]));
printf("\n");*/
pulse2comb(n,m,x2,s2,y);
for(k=0;k<m;k++)if(x[k]!=x2[k]||s[k]!=s2[k]){
fprintf(stderr,"Pulse-combination mismatch.\n");
return 1;
if(v!=nc){
fprintf(stderr,"Combination count mismatch.\n");
return 2;
}
/*printf(" %6llu\n",i);*/
}
/*printf("\n");*/
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment