diff --git a/libcelt/cwrs.c b/libcelt/cwrs.c index c4b5fe7c4afa3d12ebae58b935013dd628b8de6c..247d5fd0625fd2ce36a030e3a8a6190e211069c6 100644 --- a/libcelt/cwrs.c +++ b/libcelt/cwrs.c @@ -264,128 +264,83 @@ static inline void uprev64(celt_uint64_t *_ui,int _n,celt_uint64_t _ui0){ /*Returns the number of ways of choosing _m elements from a set of size _n with replacement when a sign bit is needed for each unique element. - On input, _u should be initialized to column (_m-1) of U(n,m). - On exit, _u will be initialized to column _m of U(n,m).*/ -celt_uint32_t ncwrs_unext32(int _n,celt_uint32_t *_ui){ - celt_uint32_t ret; - celt_uint32_t ui0; - celt_uint32_t ui1; - int j; - ret=ui0=2; - celt_assert(_n>=2); - j=1; do { - ui1=_ui[j]+_ui[j-1]+ui0; - _ui[j-1]=ui0; - ui0=ui1; - ret+=ui0; - } while (++j<_n); - _ui[j-1]=ui0; - return ret; -} - -celt_uint64_t ncwrs_unext64(int _n,celt_uint64_t *_ui){ - celt_uint64_t ret; - celt_uint64_t ui0; - celt_uint64_t ui1; - int j; - ret=ui0=1; - celt_assert(_n>=2); - j=1; do { - ui1=_ui[j]+_ui[j-1]+ui0; - _ui[j-1]=ui0; - ui0=ui1; - ret+=ui0; - } while (++j<_n); - _ui[j-1]=ui0; - return ret<<=1; -} - -/*Returns the number of ways of choosing _m elements from a set of size _n with - replacement when a sign bit is needed for each unique element. - _u: On exit, _u[i] contains U(i+1,_m).*/ + _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/ celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){ - celt_uint32_t ret; celt_uint32_t um2; int k; - /*If _m==0, _u[] should be set to zero and the return should be 1.*/ - celt_assert(_m>0); - /*We'll overflow our buffer unless _n>=2.*/ - celt_assert(_n>=2); - um2=_u[0]=1; - if(_m<=6){ - if(_m<2){ - k=1; - do _u[k]=1; - while(++k<_n); - } - else{ - k=1; - do _u[k]=(k<<1)+1; - while(++k<_n); - for(k=2;k<_m;k++)unext32(_u,_n,1); - } + int len; + len=_m+2; + _u[0]=0; + _u[1]=um2=1; + if(_n<=6){ + /*If _n==0, _u[0] should be 1 and the rest should be 0.*/ + /*If _n==1, _u[i] should be 1 for i>1.*/ + celt_assert(_n>=2); + /*If _m==0, the following do-while loop will overflow the buffer.*/ + celt_assert(_m>0); + k=2; + do _u[k]=(k<<1)-1; + while(++k<len); + for(k=2;k<_n;k++)unext32(_u+2,_m,(k<<1)+1); } else{ celt_uint32_t um1; celt_uint32_t n2m1; - _u[1]=n2m1=um1=(_m<<1)-1; - for(k=2;k<_n;k++){ + _u[2]=n2m1=um1=(_n<<1)-1; + for(k=3;k<len;k++){ /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/ - _u[k]=um2=imusdiv32even(n2m1,um1,um2,k)+um2; - if(++k>=_n)break; - _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k>>1)+um1; + _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2; + if(++k>=len)break; + _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1; } } - ret=1; - k=1; - do ret+=_u[k]; - while(++k<_n); - return ret<<1; + return _u[_m]+_u[_m+1]; } celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u){ int k; - CELT_MEMSET(_u,0,_n); - if(_m<=0)return 1; - if(_n<=0)return 0; - for(k=1;k<_m;k++)unext64(_u,_n,1); - return ncwrs_unext64(_n,_u); + int len; + len=_m+2; + _u[0]=0; + /*If _n==0, _u[0] should be 1 and the rest should be 0.*/ + /*If _n==1, _u[i] should be 1 for i>1.*/ + celt_assert(_n>=2); + k=1; + do _u[k]=(k<<1)-1; + while(++k<len); + for(k=2;k<_n;k++)unext64(_u+2,_m,(k<<1)+1); + /*TODO: For large _n, an imusdiv64 could make this O(_m) instead of + O(_n*_m), but would require an INV_TABLE twice as large, as well as lots + of 64x64->64 bit multiplies.*/ + return _u[_m]+_u[_m+1]; } /*Returns the _i'th combination of _m elements chosen from a set of size _n with associated sign bits. _y: Returns the vector of pulses. - _u: Must contain entries [1..._n] of column _m of U() on input. + _u: Must contain entries [0..._m+1] of row _n of U() on input. Its contents will be destructively modified.*/ -void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y, - celt_uint32_t *_u){ - celt_uint32_t p; - celt_uint32_t q; - int j; - int k; +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){ + int j; + int k; celt_assert(_n>0); - p=_nc; - q=0; j=0; k=_m; do{ - int s; - int yj; - p-=q; - q=_u[_n-j-1]; - p-=q; + celt_uint32_t p; + int s; + int yj; + p=_u[k+1]; s=_i>=p; if(s)_i-=p; yj=k; - while(q>_i){ - uprev32(_u,_n-j,--k>0); - p=q; - q=_u[_n-j-1]; - } - _i-=q; + p=_u[k]; + while(p>_i)p=_u[--k]; + _i-=p; yj-=k; _y[j]=yj-(yj<<1&-s); + uprev32(_u,k+2,0); } while(++j<_n); } @@ -393,36 +348,28 @@ void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y, /*Returns the _i'th combination of _m elements chosen from a set of size _n with associated sign bits. _y: Returns the vector of pulses. - _u: Must contain entries [1..._n] of column _m of U() on input. + _u: Must contain entries [0..._m+1] of row _n of U() on input. Its contents will be destructively modified.*/ -void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y, - celt_uint64_t *_u){ - celt_uint64_t p; - celt_uint64_t q; - int j; - int k; +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u){ + int j; + int k; celt_assert(_n>0); - p=_nc; - q=0; j=0; k=_m; do{ - int s; - int yj; - p-=q; - q=_u[_n-j-1]; - p-=q; + celt_uint64_t p; + int s; + int yj; + p=_u[k+1]; s=_i>=p; if(s)_i-=p; yj=k; - while(q>_i){ - uprev64(_u,_n-j,--k>0); - p=q; - q=_u[_n-j-1]; - } - _i-=q; + p=_u[k]; + while(p>_i)p=_u[--k]; + _i-=p; yj-=k; _y[j]=yj-(yj<<1&-s); + uprev64(_u,k+2,0); } while(++j<_n); } @@ -433,32 +380,26 @@ void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y, _nc: Returns V(_n,_m).*/ celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, celt_uint32_t *_u){ - celt_uint32_t nc; celt_uint32_t i; int j; int k; /*We can't unroll the first two iterations of the loop unless _n>=2.*/ celt_assert(_n>=2); - nc=1; i=_y[_n-1]<0; _u[0]=0; for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1; k=abs(_y[_n-1]); j=_n-2; - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; while(j-->0){ unext32(_u,_m+2,0); - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; } - /*If _m==0, nc should not be doubled.*/ - celt_assert(_m>0); - *_nc=nc<<1; + *_nc=_u[_m]+_u[_m+1]; return i; } @@ -468,32 +409,26 @@ celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, _nc: Returns V(_n,_m).*/ celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y, celt_uint64_t *_u){ - celt_uint64_t nc; celt_uint64_t i; int j; int k; /*We can't unroll the first two iterations of the loop unless _n>=2.*/ celt_assert(_n>=2); - nc=1; i=_y[_n-1]<0; _u[0]=0; for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1; k=abs(_y[_n-1]); j=_n-2; - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; while(j-->0){ unext64(_u,_m+2,0); - nc+=_u[_m]; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; } - /*If _m==0, nc should not be doubled.*/ - celt_assert(_m>0); - *_nc=nc<<1; + *_nc=_u[_m]+_u[_m+1]; return i; } @@ -526,7 +461,7 @@ int get_required_bits(int N, int K, int frac) { VARDECL(celt_uint64_t,u); SAVE_STACK; - ALLOC(u,N,celt_uint64_t); + ALLOC(u,K+2,celt_uint64_t); nbits = log2_frac64(ncwrs_u64(N,K,u), frac); RESTORE_STACK; } else { @@ -564,21 +499,17 @@ void encode_pulses(int *_y, int N, int K, ec_enc *enc) static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){ VARDECL(celt_uint32_t,u); - celt_uint32_t nc; SAVE_STACK; - ALLOC(u,_n,celt_uint32_t); - nc=ncwrs_u32(_n,_m,u); - cwrsi32(_n,_m,ec_dec_uint(_dec,nc),nc,_y,u); + ALLOC(u,_m+2,celt_uint32_t); + cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u); RESTORE_STACK; } static inline void decode_pulse64(int _n,int _m,int *_y,ec_dec *_dec){ VARDECL(celt_uint64_t,u); - celt_uint64_t nc; SAVE_STACK; - ALLOC(u,_n,celt_uint64_t); - nc=ncwrs_u64(_n,_m,u); - cwrsi64(_n,_m,ec_dec_uint64(_dec,nc),nc,_y,u); + ALLOC(u,_m+2,celt_uint64_t); + cwrsi64(_n,_m,ec_dec_uint64(_dec,ncwrs_u64(_n,_m,u)),_y,u); RESTORE_STACK; } diff --git a/libcelt/cwrs.h b/libcelt/cwrs.h index 0a81c6116e479ecbec858fae6c21737b7686ae09..5c0d50a0f128b9e9a8615410bbc7d14d51f5ce7f 100644 --- a/libcelt/cwrs.h +++ b/libcelt/cwrs.h @@ -46,8 +46,7 @@ int fits_in64(int _n, int _m); /* 32-bit versions */ celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u); -void cwrsi32(int _n,int _m,celt_uint32_t _i,celt_uint32_t _nc,int *_y, - celt_uint32_t *_u); +void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u); celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, celt_uint32_t *_u); @@ -55,8 +54,7 @@ celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, /* 64-bit versions */ celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u); -void cwrsi64(int _n,int _m,celt_uint64_t _i,celt_uint64_t _nc,int *_y, - celt_uint64_t *_u); +void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u); celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y, celt_uint64_t *_u); diff --git a/tests/cwrs32-test.c b/tests/cwrs32-test.c index 6b193e39dd072eab62748063b9692cc392d521c1..e97ba64437142456cfb71bc3f626881e8e768c72 100644 --- a/tests/cwrs32-test.c +++ b/tests/cwrs32-test.c @@ -13,7 +13,7 @@ int main(int _argc,char **_argv){ for(n=2;n<=NMAX;n++){ int m; for(m=1;m<=MMAX;m++){ - celt_uint32_t uu[NMAX]; + celt_uint32_t uu[MMAX+2]; celt_uint32_t inc; celt_uint32_t nc; celt_uint32_t i; @@ -21,12 +21,12 @@ int main(int _argc,char **_argv){ inc=nc/10000; if(inc<1)inc=1; for(i=0;i<nc;i+=inc){ - celt_uint32_t u[NMAX>MMAX+2?NMAX:MMAX+2]; + celt_uint32_t u[MMAX+2]; int y[NMAX]; celt_uint32_t v; int k; - memcpy(u,uu,n*sizeof(*u)); - cwrsi32(n,m,i,nc,y,u); + memcpy(u,uu,(m+2)*sizeof(*u)); + cwrsi32(n,m,i,y,u); /*printf("%6u of %u:",i,nc); for(k=0;k<n;k++)printf(" %+3i",y[k]); printf(" ->");*/ diff --git a/tests/cwrs64-test.c b/tests/cwrs64-test.c index e699debdf064defa18798b3536505aa9f000fcb1..548fa223048f4262887ab092e7245f533e68b35c 100644 --- a/tests/cwrs64-test.c +++ b/tests/cwrs64-test.c @@ -14,7 +14,7 @@ int main(int _argc,char **_argv){ for(n=2;n<=NMAX;n+=3){ int m; for(m=1;m<=MMAX;m++){ - celt_uint64_t uu[NMAX]; + celt_uint64_t uu[MMAX+2]; celt_uint64_t inc; celt_uint64_t nc; celt_uint64_t i; @@ -24,12 +24,12 @@ int main(int _argc,char **_argv){ if(inc<1)inc=1; /*printf("%d/%d: %llu",n,m, nc);*/ for(i=0;i<nc;i+=inc){ - celt_uint64_t u[NMAX>MMAX+2?NMAX:MMAX+2]; + celt_uint64_t u[MMAX+2]; int y[NMAX]; celt_uint64_t v; int k; - memcpy(u,uu,n*sizeof(*u)); - cwrsi64(n,m,i,nc,y,u); + memcpy(u,uu,(m+2)*sizeof(*u)); + cwrsi64(n,m,i,y,u); /*printf("%llu of %llu:",i,nc); for(k=0;k<n;k++)printf(" %+3i",y[k]); printf(" ->");*/