From d98d8ae087844c616c1f187265ae9d20dea54207 Mon Sep 17 00:00:00 2001 From: "Timothy B. Terriberry" <tterribe@xiph.org> Date: Tue, 26 May 2009 09:09:27 -0400 Subject: [PATCH] CWRS clean-ups and optimizations. Adds specialized O(N*log(K)) versions of cwrsi() and O(N) versions of icwrs() for N={3,4,5}, which allows them to operate all the way up to the theoretical pulse limit without serious performance degradation. Also substantially reduces the computation time and stack usage of get_required_bits(). On x86-64, this gives a 2% speed-up for 256 sample frames, and almost a 16% speed-up for 64 sample frames. --- libcelt/cwrs.c | 790 +++++++++++++++++++++++++++++++++++--------- libcelt/cwrs.h | 21 -- tests/cwrs32-test.c | 147 ++++++++- tests/cwrs64-test.c | 50 --- 4 files changed, 767 insertions(+), 241 deletions(-) delete mode 100644 tests/cwrs64-test.c diff --git a/libcelt/cwrs.c b/libcelt/cwrs.c index 2733e0fc9..4b0f36a52 100644 --- a/libcelt/cwrs.c +++ b/libcelt/cwrs.c @@ -29,15 +29,6 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* Functions for encoding and decoding pulse vectors. - These are based on the function - U(n,m) = U(n-1,m) + U(n,m-1) + U(n-1,m-1), - U(n,1) = U(1,m) = 2, - which counts the number of ways of placing m pulses in n dimensions, where - at least one pulse lies in dimension 0.
- For more details, see: http://people.xiph.org/~tterribe/notes/cwrs.html -*/ - #ifdef HAVE_CONFIG_H #include "config.h" #endif @@ -80,28 +71,9 @@ int log2_frac(ec_uint32 val, int frac) else return l-1<<frac; } -int fits_in32(int _n, int _m) -{ - static const celt_int16_t maxN[15] = { - 32767, 32767, 32767, 1476, 283, 109, 60, 40, - 29, 24, 20, 18, 16, 14, 13}; - static const celt_int16_t maxM[15] = { - 32767, 32767, 32767, 32767, 1172, 238, 95, 53, - 36, 27, 22, 18, 16, 15, 13}; - if (_n>=14) - { - if (_m>=14) - return 0; - else - return _n <= maxN[_m]; - } else { - return _m <= maxM[_n]; - } -} - #define MASK32 (0xFFFFFFFF) -/*INV_TABLE[i] holds the multiplicative inverse of (2*i-1) mod 2**32.*/ +/*INV_TABLE[i] holds the multiplicative inverse of (2*i+1) mod 2**32.*/ static const celt_uint32_t INV_TABLE[128]={ 0x00000001,0xAAAAAAAB,0xCCCCCCCD,0xB6DB6DB7, 0x38E38E39,0xBA2E8BA3,0xC4EC4EC5,0xEEEEEEEF, @@ -137,12 +109,12 @@ static const celt_uint32_t INV_TABLE[128]={ 0x0E64C149,0x9A020A33,0xE6B41C55,0xFEFEFEFF }; -/*Computes (_a*_b-_c)/(2*_d-1) when the quotient is known to be exact. +/*Computes (_a*_b-_c)/(2*_d+1) when the quotient is known to be exact. _a, _b, _c, and _d may be arbitrary so long as the arbitrary precision result fits in 32 bits, but currently the table for multiplicative inverses is only valid for _d<128.*/ static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b, - celt_uint32_t _c,celt_uint32_t _d){ + celt_uint32_t _c,int _d){ return (_a*_b-_c)*INV_TABLE[_d]&MASK32; } @@ -150,16 +122,18 @@ static inline celt_uint32_t imusdiv32odd(celt_uint32_t _a,celt_uint32_t _b, _d does not actually have to be even, but imusdiv32odd will be faster when it's odd, so you should use that instead. _a and _d are assumed to be small (e.g., _a*_d fits in 32 bits; currently the - table for multiplicative inverses is only valid for _d<256). + table for multiplicative inverses is only valid for _d<=256). 
_b and _c may be arbitrary so long as the arbitrary precision reuslt fits in 32 bits.*/ static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b, - celt_uint32_t _c,celt_uint32_t _d){ + celt_uint32_t _c,int _d){ celt_uint32_t inv; - int mask; - int shift; - int one; + int mask; + int shift; + int one; + celt_assert(_d>0); shift=EC_ILOG(_d^_d-1); + celt_assert(_d<=256); inv=INV_TABLE[_d-1>>shift]; shift--; one=1<<shift; @@ -168,13 +142,262 @@ static inline celt_uint32_t imusdiv32even(celt_uint32_t _a,celt_uint32_t _b, (_a*(_b&mask)+one-(_c&mask)>>shift)-1)*inv&MASK32; } +/*Compute floor(sqrt(_val)) with exact arithmetic. + This has been tested on all possible 32-bit inputs.*/ +static unsigned isqrt32(celt_uint32_t _val){ + unsigned b; + unsigned g; + int bshift; + /*Uses the second method from + http://www.azillionmonkeys.com/qed/sqroot.html + The main idea is to search for the largest binary digit b such that + (g+b)*(g+b) <= _val, and add it to the solution g.*/ + g=0; + bshift=EC_ILOG(_val)-1>>1; + b=1U<<bshift; + for(;bshift>=0;bshift--){ + celt_uint32_t t; + t=((celt_uint32_t)g<<1)+b<<bshift; + if(t<=_val){ + g+=b; + _val-=t; + } + b>>=1; + } + return g; +} + +/*Compute floor(sqrt(_val)) with exact arithmetic. + This has been tested on all possible 36-bit inputs.*/ +static celt_uint32_t isqrt36(celt_uint64_t _val){ + celt_uint32_t val32; + celt_uint32_t b; + celt_uint32_t g; + int bshift; + g=0; + b=0x20000; + for(bshift=18;bshift-->13;){ + celt_uint64_t t; + t=((celt_uint64_t)g<<1)+b<<bshift; + if(t<=_val){ + g+=b; + _val-=t; + } + b>>=1; + } + val32=(celt_uint32_t)_val; + for(;bshift>=0;bshift--){ + celt_uint32_t t; + t=(g<<1)+b<<bshift; + if(t<=val32){ + g+=b; + val32-=t; + } + b>>=1; + } + return g; +} + +/*Although derived separately, the pulse vector coding scheme is equivalent to + a Pyramid Vector Quantizer \cite{Fis86}. 
+ Some additional notes about an early version appear at + http://people.xiph.org/~tterribe/notes/cwrs.html, but the codebook ordering + and the definitions of some terms have evolved since that was written. + + The conversion from a pulse vector to an integer index (encoding) and back + (decoding) is governed by two related functions, V(N,K) and U(N,K). + + V(N,K) = the number of combinations, with replacement, of N items, taken K + at a time, when a sign bit is added to each item taken at least once (i.e., + the number of N-dimensional unit pulse vectors with K pulses). + One way to compute this is via + V(N,K) = K>0 ? sum(k=1...K,2**k*choose(N,k)*choose(K-1,k-1)) : 1, + where choose() is the binomial function. + A table of values for N<10 and K<10 looks like: + V[10][10] = { + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {1, 2, 2, 2, 2, 2, 2, 2, 2, 2}, + {1, 4, 8, 12, 16, 20, 24, 28, 32, 36}, + {1, 6, 18, 38, 66, 102, 146, 198, 258, 326}, + {1, 8, 32, 88, 192, 360, 608, 952, 1408, 1992}, + {1, 10, 50, 170, 450, 1002, 1970, 3530, 5890, 9290}, + {1, 12, 72, 292, 912, 2364, 5336, 10836, 20256, 35436}, + {1, 14, 98, 462, 1666, 4942, 12642, 28814, 59906, 115598}, + {1, 16, 128, 688, 2816, 9424, 27008, 68464, 157184, 332688}, + {1, 18, 162, 978, 4482, 16722, 53154, 148626, 374274, 864146} + }; + + U(N,K) = the number of such combinations wherein N-1 objects are taken at + most K-1 at a time. + This is given by + U(N,K) = sum(k=0...K-1,V(N-1,k)) + = K>0 ? (V(N-1,K-1) + V(N,K-1))/2 : 0. + The latter expression also makes clear that U(N,K) is half the number of such + combinations wherein the first object is taken at least once. + Although it may not be clear from either of these definitions, U(N,K) is the + natural function to work with when enumerating the pulse vector codebooks, + not V(N,K). + U(N,K) is not well-defined for N=0, but with the extension + U(0,K) = K>0 ? 
0 : 1, + the function becomes symmetric: U(N,K) = U(K,N), with a similar table: + U[10][10] = { + {1, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {0, 1, 3, 5, 7, 9, 11, 13, 15, 17}, + {0, 1, 5, 13, 25, 41, 61, 85, 113, 145}, + {0, 1, 7, 25, 63, 129, 231, 377, 575, 833}, + {0, 1, 9, 41, 129, 321, 681, 1289, 2241, 3649}, + {0, 1, 11, 61, 231, 681, 1683, 3653, 7183, 13073}, + {0, 1, 13, 85, 377, 1289, 3653, 8989, 19825, 40081}, + {0, 1, 15, 113, 575, 2241, 7183, 19825, 48639, 108545}, + {0, 1, 17, 145, 833, 3649, 13073, 40081, 108545, 265729} + }; + + With this extension, V(N,K) may be written in terms of U(N,K): + V(N,K) = U(N,K) + U(N,K+1) + for all N>=0, K>=0. + Thus U(N,K+1) represents the number of combinations where the first element + is positive or zero, and U(N,K) represents the number of combinations where + it is negative. + With a large enough table of U(N,K) values, we could write O(N) encoding + and O(min(N*log(K),N+K)) decoding routines, but such a table would be + prohibitively large for small embedded devices (K may be as large as 32767 + for small N, and N may be as large as 200). + + Both functions obey the same recurrence relation: + V(N,K) = V(N-1,K) + V(N,K-1) + V(N-1,K-1), + U(N,K) = U(N-1,K) + U(N,K-1) + U(N-1,K-1), + for all N>0, K>0, with different initial conditions at N=0 or K=0. + This allows us to construct a row of one of the tables above given the + previous row or the next row. + Thus we can derive O(NK) encoding and decoding routines with O(K) memory + using only addition and subtraction. + + When encoding, we build up from the U(2,K) row and work our way forwards. + When decoding, we need to start at the U(N,K) row and work our way backwards, + which requires a means of computing U(N,K). 
+ U(N,K) may be computed from two previous values with the same N: + U(N,K) = ((2*N-1)*U(N,K-1) - U(N,K-2))/(K-1) + U(N,K-2) + for all N>1, and since U(N,K) is symmetric, a similar relation holds for two + previous values with the same K: + U(N,K>1) = ((2*K-1)*U(N-1,K) - U(N-2,K))/(N-1) + U(N-2,K) + for all K>1. + This allows us to construct an arbitrary row of the U(N,K) table by starting + with the first two values, which are constants. + This saves roughly 2/3 the work in our O(NK) decoding routine, but costs O(K) + multiplications. + Similar relations can be derived for V(N,K), but are not used here. + + For N>0 and K>0, U(N,K) and V(N,K) take on the form of an (N-1)-degree + polynomial for fixed N. + The first few are + U(1,K) = 1, + U(2,K) = 2*K-1, + U(3,K) = (2*K-2)*K+1, + U(4,K) = (((4*K-6)*K+8)*K-3)/3, + U(5,K) = ((((2*K-4)*K+10)*K-8)*K+3)/3, + and + V(1,K) = 2, + V(2,K) = 4*K, + V(3,K) = 4*K*K+2, + V(4,K) = 8*(K*K+2)*K/3, + V(5,K) = ((4*K*K+20)*K*K+6)/3, + for all K>0. + This allows us to derive O(N) encoding and O(N*log(K)) decoding routines for + small N (and indeed decoding is also O(N) for N<3). + + @ARTICLE{Fis86, + author="Thomas R. Fischer", + title="A Pyramid Vector Quantizer", + journal="IEEE Transactions on Information Theory", + volume="IT-32", + number=4, + pages="568--583", + month=Jul, + year=1986 + }*/ + +/*Determines if V(N,K) fits in a 32-bit unsigned integer. 
+ N and K are themselves limited to 15 bits.*/ +int fits_in32(int _n, int _k) +{ + static const celt_int16_t maxN[15] = { + 32767, 32767, 32767, 1476, 283, 109, 60, 40, + 29, 24, 20, 18, 16, 14, 13}; + static const celt_int16_t maxK[15] = { + 32767, 32767, 32767, 32767, 1172, 238, 95, 53, + 36, 27, 22, 18, 16, 15, 13}; + if (_n>=14) + { + if (_k>=14) + return 0; + else + return _n <= maxN[_k]; + } else { + return _k <= maxK[_n]; + } +} + +/*Compute U(1,_k).*/ +static inline unsigned ucwrs1(int _k){ + return _k?1:0; +} + +/*Compute V(1,_k).*/ +static inline unsigned ncwrs1(int _k){ + return _k?2:1; +} + +/*Compute U(2,_k). + Note that this may be called with _k=32768 (maxK[2]+1).*/ +static inline unsigned ucwrs2(unsigned _k){ + return _k?_k+(_k-1):0; +} + +/*Compute V(2,_k).*/ +static inline celt_uint32_t ncwrs2(int _k){ + return _k?4*(celt_uint32_t)_k:1; +} + +/*Compute U(3,_k). + Note that this may be called with _k=32768 (maxK[3]+1).*/ +static inline celt_uint32_t ucwrs3(unsigned _k){ + return _k?(2*(celt_uint32_t)_k-2)*_k+1:0; +} + +/*Compute V(3,_k).*/ +static inline celt_uint32_t ncwrs3(int _k){ + return _k?2*(2*(unsigned)_k*(celt_uint32_t)_k+1):1; +} + +/*Compute U(4,_k).*/ +static inline celt_uint32_t ucwrs4(int _k){ + return _k?imusdiv32odd(2*_k,(2*_k-3)*(celt_uint32_t)_k+4,3,1):0; +} + +/*Compute V(4,_k).*/ +static inline celt_uint32_t ncwrs4(int _k){ + return _k?((_k*(celt_uint32_t)_k+2)*_k)/3<<3:1; +} + +/*Compute U(5,_k).*/ +static inline celt_uint32_t ucwrs5(int _k){ + return _k?(((((_k-2)*(unsigned)_k+5)*(celt_uint32_t)_k-4)*_k)/3<<1)+1:0; +} + +/*Compute V(5,_k).*/ +static inline celt_uint32_t ncwrs5(int _k){ + return _k?(((_k*(unsigned)_k+5)*(celt_uint32_t)_k*_k)/3<<2)+2:1; +} + /*Computes the next row/column of any recurrence that obeys the relation u[i][j]=u[i-1][j]+u[i][j-1]+u[i-1][j-1]. 
_ui0 is the base case for the new row/column.*/ -static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){ +static inline void unext(celt_uint32_t *_ui,unsigned _len,celt_uint32_t _ui0){ celt_uint32_t ui1; - int j; - /* doing a do-while would overrun the array if we had less than 2 samples */ + unsigned j; + /*This do-while will overrun the array if we don't have storage for at least + 2 values.*/ j=1; do { ui1=UADD32(UADD32(_ui[j],_ui[j-1]),_ui0); _ui[j-1]=_ui0; @@ -186,10 +409,11 @@ static inline void unext32(celt_uint32_t *_ui,int _len,celt_uint32_t _ui0){ /*Computes the previous row/column of any recurrence that obeys the relation u[i-1][j]=u[i][j]-u[i][j-1]-u[i-1][j-1]. _ui0 is the base case for the new row/column.*/ -static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){ +static inline void uprev(celt_uint32_t *_ui,unsigned _n,celt_uint32_t _ui0){ celt_uint32_t ui1; - int j; - /* doing a do-while would overrun the array if we had less than 2 samples */ + unsigned j; + /*This do-while will overrun the array if we don't have storage for at least + 2 values.*/ j=1; do { ui1=USUB32(USUB32(_ui[j],_ui[j-1]),_ui0); _ui[j-1]=_ui0; @@ -198,179 +422,430 @@ static inline void uprev32(celt_uint32_t *_ui,int _n,celt_uint32_t _ui0){ _ui[j-1]=_ui0; } -/*Returns the number of ways of choosing _m elements from a set of size _n with - replacement when a sign bit is needed for each unique element. - _u: On exit, _u[i] contains U(_n,i) for i in [0..._m+1].*/ -celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u){ +/*Compute V(_n,_k), as well as U(_n,0..._k+1). 
+ _u: On exit, _u[i] contains U(_n,i) for i in [0..._k+1].*/ +static celt_uint32_t ncwrs_urow(unsigned _n,unsigned _k,celt_uint32_t *_u){ celt_uint32_t um2; - int k; - int len; - len=_m+2; + unsigned len; + unsigned k; + len=_k+2; + /*We require storage at least 3 values (e.g., _k>0).*/ + celt_assert(len>=3); _u[0]=0; _u[1]=um2=1; - if(_n<=6 || _m>255){ + if(_n<=6 || _k>255){ /*If _n==0, _u[0] should be 1 and the rest should be 0.*/ /*If _n==1, _u[i] should be 1 for i>1.*/ celt_assert(_n>=2); - /*If _m==0, the following do-while loop will overflow the buffer.*/ - celt_assert(_m>0); + /*If _k==0, the following do-while loop will overflow the buffer.*/ + celt_assert(_k>0); k=2; do _u[k]=(k<<1)-1; while(++k<len); - for(k=2;k<_n;k++) - unext32(_u+1,_m+1,1); + for(k=2;k<_n;k++)unext(_u+1,_k+1,1); } else{ celt_uint32_t um1; celt_uint32_t n2m1; _u[2]=n2m1=um1=(_n<<1)-1; for(k=3;k<len;k++){ - /*U(n,m) = ((2*n-1)*U(n,m-1)-U(n,m-2))/(m-1) + U(n,m-2)*/ + /*U(N,K) = ((2*N-1)*U(N,K-1)-U(N,K-2))/(K-1) + U(N,K-2)*/ _u[k]=um2=imusdiv32even(n2m1,um1,um2,k-1)+um2; if(++k>=len)break; _u[k]=um1=imusdiv32odd(n2m1,um2,um1,k-1>>1)+um1; } } - return _u[_m]+_u[_m+1]; + return _u[_k]+_u[_k+1]; } +/*Returns the _i'th combination of _k elements (at most 32767) chosen from a + set of size 1 with associated sign bits. + _y: Returns the vector of pulses.*/ +static inline void cwrsi1(int _k,celt_uint32_t _i,int *_y){ + int s; + s=-(int)_i; + _y[0]=_k+s^s; +} -/*Returns the _i'th combination of _m elements chosen from a set of size _n +/*Returns the _i'th combination of _k elements (at most 32767) chosen from a + set of size 2 with associated sign bits. 
+ _y: Returns the vector of pulses.*/ +static inline void cwrsi2(int _k,celt_uint32_t _i,int *_y){ + celt_uint32_t p; + int s; + int yj; + p=ucwrs2(_k+1U); + s=-(_i>=p); + _i-=p&s; + yj=_k; + _k=_i+1>>1; + p=ucwrs2(_k); + _i-=p; + yj-=_k; + _y[0]=yj+s^s; + cwrsi1(_k,_i,_y+1); +} + +/*Returns the _i'th combination of _k elements (at most 32767) chosen from a + set of size 3 with associated sign bits. + _y: Returns the vector of pulses.*/ +static void cwrsi3(int _k,celt_uint32_t _i,int *_y){ + celt_uint32_t p; + int s; + int yj; + p=ucwrs3(_k+1U); + s=-(_i>=p); + _i-=p&s; + yj=_k; + /*Finds the maximum _k such that ucwrs3(_k)<=_i (tested for all + _i<2147418113=U(3,32768)).*/ + _k=_i>0?isqrt32(2*_i-1)+1>>1:0; + p=ucwrs3(_k); + _i-=p; + yj-=_k; + _y[0]=yj+s^s; + cwrsi2(_k,_i,_y+1); +} + +/*Returns the _i'th combination of _k elements (at most 1172) chosen from a set + of size 4 with associated sign bits. + _y: Returns the vector of pulses.*/ +static void cwrsi4(int _k,celt_uint32_t _i,int *_y){ + celt_uint32_t p; + int s; + int yj; + int kl; + int kr; + p=ucwrs4(_k+1); + s=-(_i>=p); + _i-=p&s; + yj=_k; + /*We could solve a cubic for k here, but the form of the direct solution does + not lend itself well to exact integer arithmetic. + Instead we do a binary search on U(4,K).*/ + kl=0; + kr=_k; + for(;;){ + _k=kl+kr>>1; + p=ucwrs4(_k); + if(p<_i){ + if(_k>=kr)break; + kl=_k+1; + } + else if(p>_i)kr=_k-1; + else break; + } + _i-=p; + yj-=_k; + _y[0]=yj+s^s; + cwrsi3(_k,_i,_y+1); +} + +/*Returns the _i'th combination of _k elements (at most 238) chosen from a set + of size 5 with associated sign bits. 
+ _y: Returns the vector of pulses.*/ +static void cwrsi5(int _k,celt_uint32_t _i,int *_y){ + celt_uint32_t p; + int s; + int yj; + p=ucwrs5(_k+1); + s=-(_i>=p); + _i-=p&s; + yj=_k; + /*Finds the maximum _k such that ucwrs5(_k)<=_i (tested for all + _i<2157192969=U(5,239)).*/ + if(_i>=0x2AAAAAA9UL)_k=isqrt32(2*isqrt36(10+6*(celt_uint64_t)_i)-7)+1>>1; + else _k=_i>0?isqrt32(2*(celt_uint32_t)isqrt32(10+6*_i)-7)+1>>1:0; + p=ucwrs5(_k); + _i-=p; + yj-=_k; + _y[0]=yj+s^s; + cwrsi4(_k,_i,_y+1); +} + +/*Returns the _i'th combination of _k elements chosen from a set of size _n with associated sign bits. _y: Returns the vector of pulses. - _u: Must contain entries [0..._m+1] of row _n of U() on input. + _u: Must contain entries [0..._k+1] of row _n of U() on input. Its contents will be destructively modified.*/ -void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u){ +static void cwrsi(int _n,int _k,celt_uint32_t _i,int *_y,celt_uint32_t *_u){ int j; - int k; celt_assert(_n>0); j=0; - k=_m; do{ celt_uint32_t p; int s; int yj; - p=_u[k+1]; - s=_i>=p; - if(s)_i-=p; - yj=k; - p=_u[k]; - while(p>_i)p=_u[--k]; + p=_u[_k+1]; + s=-(_i>=p); + _i-=p&s; + yj=_k; + p=_u[_k]; + while(p>_i)p=_u[--_k]; _i-=p; - yj-=k; - _y[j]=yj-(yj<<1&-s); - uprev32(_u,k+2,0); + yj-=_k; + _y[j]=yj+s^s; + uprev(_u,_k+2,0); } while(++j<_n); } -/*Returns the index of the given combination of _m elements chosen from a set +/*Returns the index of the given combination of K elements chosen from a set + of size 1 with associated sign bits. + _y: The vector of pulses, whose sum of absolute values is K. + _k: Returns K.*/ +static inline celt_uint32_t icwrs1(const int *_y,int *_k){ + *_k=abs(_y[0]); + return _y[0]<0; +} + +/*Returns the index of the given combination of K elements chosen from a set + of size 2 with associated sign bits. + _y: The vector of pulses, whose sum of absolute values is K. 
+ _k: Returns K.*/ +static inline celt_uint32_t icwrs2(const int *_y,int *_k){ + celt_uint32_t i; + int k; + i=icwrs1(_y+1,&k); + i+=ucwrs2(k); + k+=abs(_y[0]); + if(_y[0]<0)i+=ucwrs2(k+1U); + *_k=k; + return i; +} + +/*Returns the index of the given combination of K elements chosen from a set + of size 3 with associated sign bits. + _y: The vector of pulses, whose sum of absolute values is K. + _k: Returns K.*/ +static inline celt_uint32_t icwrs3(const int *_y,int *_k){ + celt_uint32_t i; + int k; + i=icwrs2(_y+1,&k); + i+=ucwrs3(k); + k+=abs(_y[0]); + if(_y[0]<0)i+=ucwrs3(k+1U); + *_k=k; + return i; +} + +/*Returns the index of the given combination of K elements chosen from a set + of size 4 with associated sign bits. + _y: The vector of pulses, whose sum of absolute values is K. + _k: Returns K.*/ +static inline celt_uint32_t icwrs4(const int *_y,int *_k){ + celt_uint32_t i; + int k; + i=icwrs3(_y+1,&k); + i+=ucwrs4(k); + k+=abs(_y[0]); + if(_y[0]<0)i+=ucwrs4(k+1); + *_k=k; + return i; +} + +/*Returns the index of the given combination of K elements chosen from a set + of size 5 with associated sign bits. + _y: The vector of pulses, whose sum of absolute values is K. + _k: Returns K.*/ +static inline celt_uint32_t icwrs5(const int *_y,int *_k){ + celt_uint32_t i; + int k; + i=icwrs4(_y+1,&k); + i+=ucwrs5(k); + k+=abs(_y[0]); + if(_y[0]<0)i+=ucwrs5(k+1); + *_k=k; + return i; +} + +/*Returns the index of the given combination of K elements chosen from a set of size _n with associated sign bits. - _y: The vector of pulses, whose sum of absolute values must be _m. - _nc: Returns V(_n,_m).*/ -celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, + _y: The vector of pulses, whose sum of absolute values must be _k. 
+ _nc: Returns V(_n,_k).*/ +celt_uint32_t icwrs(int _n,int _k,celt_uint32_t *_nc,const int *_y, celt_uint32_t *_u){ celt_uint32_t i; int j; int k; /*We can't unroll the first two iterations of the loop unless _n>=2.*/ celt_assert(_n>=2); - i=_y[_n-1]<0; _u[0]=0; - for(k=1;k<=_m+1;k++)_u[k]=(k<<1)-1; - k=abs(_y[_n-1]); + for(k=1;k<=_k+1;k++)_u[k]=(k<<1)-1; + i=icwrs1(_y+_n-1,&k); j=_n-2; i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; while(j-->0){ - unext32(_u,_m+2,0); + unext(_u,_k+2,0); i+=_u[k]; k+=abs(_y[j]); if(_y[j]<0)i+=_u[k+1]; } - *_nc=_u[_m]+_u[_m+1]; + *_nc=_u[k]+_u[k+1]; return i; } -static inline void encode_pulse32(int _n,int _m,const int *_y,ec_enc *_enc){ - VARDECL(celt_uint32_t,u); - celt_uint32_t nc; - celt_uint32_t i; - SAVE_STACK; - ALLOC(u,_m+2,celt_uint32_t); - i=icwrs32(_n,_m,&nc,_y,u); - ec_enc_uint(_enc,i,nc); - RESTORE_STACK; + +/*Computes get_required_bits when splitting is required. + _left_bits and _right_bits must contain the required bits for the left and + right sides of the split, respectively (which themselves may require + splitting).*/ +static void get_required_split_bits(celt_int16_t *_bits, + const celt_int16_t *_left_bits,const celt_int16_t *_right_bits, + int _n,int _maxk,int _frac){ + int k; + for(k=_maxk;k-->0;){ + /*If we've reached a k where everything fits in 32 bits, evaluate the + remaining required bits directly.*/ + if(fits_in32(_n,k)){ + get_required_bits(_bits,_n,k+1,_frac); + break; + } + else{ + int worst_bits; + int i; + /*Due to potentially recursive splitting, it's difficult to derive an + analytic expression for the location of the worst-case split index. 
+ We simply check them all.*/ + worst_bits=0; + for(i=0;i<=k;i++){ + int split_bits; + split_bits=_left_bits[i]+_right_bits[k-i]; + if(split_bits>worst_bits)worst_bits=split_bits; + } + _bits[k]=log2_frac(k+1,_frac)+worst_bits; + } + } } -int get_required_bits32(int N, int K, int frac) -{ - int nbits; - VARDECL(celt_uint32_t,u); - SAVE_STACK; - ALLOC(u,K+2,celt_uint32_t); - nbits = log2_frac(ncwrs_u32(N,K,u), frac); - RESTORE_STACK; - return nbits; +/*Computes get_required_bits for a pair of N values. + _n1 and _n2 must either be equal or two consecutive integers. + Returns the buffer used to store the required bits for _n2, which is either + _bits1 if _n1==_n2 or _bits2 if _n1+1==_n2.*/ +static celt_int16_t *get_required_bits_pair(celt_int16_t *_bits1, + celt_int16_t *_bits2,celt_int16_t *_tmp,int _n1,int _n2,int _maxk,int _frac){ + celt_int16_t *tmp2; + /*If we only need a single set of required bits...*/ + if(_n1==_n2){ + /*Stop recursing if everything fits.*/ + if(fits_in32(_n1,_maxk-1))get_required_bits(_bits1,_n1,_maxk,_frac); + else{ + _tmp=get_required_bits_pair(_bits2,_tmp,_bits1, + _n1>>1,_n1+1>>1,_maxk,_frac); + get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac); + } + return _bits1; + } + /*Otherwise we need two distinct sets...*/ + celt_assert(_n1+1==_n2); + /*Stop recursing if everything fits.*/ + if(fits_in32(_n2,_maxk-1)){ + get_required_bits(_bits1,_n1,_maxk,_frac); + get_required_bits(_bits2,_n2,_maxk,_frac); + } + /*Otherwise choose an evaluation order that doesn't require extra buffers.*/ + else if(_n1&1){ + /*This special case isn't really needed, but can save some work.*/ + if(fits_in32(_n1,_maxk-1)){ + tmp2=get_required_bits_pair(_tmp,_bits1,_bits2, + _n2>>1,_n2>>1,_maxk,_frac); + get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac); + get_required_bits(_bits1,_n1,_maxk,_frac); + } + else{ + _tmp=get_required_bits_pair(_bits2,_tmp,_bits1, + _n1>>1,_n1+1>>1,_maxk,_frac); + 
get_required_split_bits(_bits1,_bits2,_tmp,_n1,_maxk,_frac); + get_required_split_bits(_bits2,_tmp,_tmp,_n2,_maxk,_frac); + } + } + else{ + /*There's no need to special case _n1 fitting by itself, since _n2 requires + us to recurse for both values anyway.*/ + tmp2=get_required_bits_pair(_tmp,_bits1,_bits2, + _n2>>1,_n2+1>>1,_maxk,_frac); + get_required_split_bits(_bits2,_tmp,tmp2,_n2,_maxk,_frac); + get_required_split_bits(_bits1,_tmp,_tmp,_n1,_maxk,_frac); + } + return _bits2; } -void get_required_bits(celt_int16_t *bits,int N, int MAXK, int frac) -{ - int k; - /*We special case k==0 below, since fits_in32 could reject it for large N.*/ - celt_assert(MAXK>0); - if(fits_in32(N,MAXK-1)){ - bits[0]=0; - /*This could be sped up one heck of a lot if we didn't recompute u in - ncwrs_u32 every time.*/ - for(k=1;k<MAXK;k++)bits[k]=get_required_bits32(N,k,frac); - } - else{ - VARDECL(celt_int16_t,n1bits); - VARDECL(celt_int16_t,_n2bits); - celt_int16_t *n2bits; +void get_required_bits(celt_int16_t *_bits,int _n,int _maxk,int _frac){ + int k; + /*_maxk==0 => there's nothing to do.*/ + celt_assert(_maxk>0); + if(fits_in32(_n,_maxk-1)){ + _bits[0]=0; + if(_maxk>1){ + VARDECL(celt_uint32_t,u); SAVE_STACK; - ALLOC(n1bits,MAXK,celt_int16_t); - ALLOC(_n2bits,MAXK,celt_int16_t); - get_required_bits(n1bits,(N+1)/2,MAXK,frac); - if(N&1){ - n2bits=_n2bits; - get_required_bits(n2bits,N/2,MAXK,frac); - }else{ - n2bits=n1bits; - } - bits[0]=0; - for(k=1;k<MAXK;k++){ - if(fits_in32(N,k))bits[k]=get_required_bits32(N,k,frac); - else{ - int worst_bits; - int i; - worst_bits=0; - for(i=0;i<=k;i++){ - int split_bits; - split_bits=n1bits[i]+n2bits[k-i]; - if(split_bits>worst_bits)worst_bits=split_bits; - } - bits[k]=log2_frac(k+1,frac)+worst_bits; - } - } + ALLOC(u,_maxk+1U,celt_uint32_t); + ncwrs_urow(_n,_maxk-1,u); + for(k=1;k<_maxk;k++)_bits[k]=log2_frac(u[k]+u[k+1],_frac); RESTORE_STACK; - } + } + } + else{ + VARDECL(celt_int16_t,n1bits); + VARDECL(celt_int16_t,n2bits_buf); + celt_int16_t 
*n2bits; + SAVE_STACK; + ALLOC(n1bits,_maxk,celt_int16_t); + ALLOC(n2bits_buf,_maxk,celt_int16_t); + n2bits=get_required_bits_pair(n1bits,n2bits_buf,_bits, + _n>>1,_n+1>>1,_maxk,_frac); + get_required_split_bits(_bits,n1bits,n2bits,_n,_maxk,_frac); + RESTORE_STACK; + } } +static inline void encode_pulses32(int _n,int _k,const int *_y,ec_enc *_enc){ + celt_uint32_t i; + switch(_n){ + case 1:{ + i=icwrs1(_y,&_k); + celt_assert(ncwrs1(_k)==2); + ec_enc_bits(_enc,i,1); + }break; + case 2:{ + i=icwrs2(_y,&_k); + ec_enc_uint(_enc,i,ncwrs2(_k)); + }break; + case 3:{ + i=icwrs3(_y,&_k); + ec_enc_uint(_enc,i,ncwrs3(_k)); + }break; + case 4:{ + i=icwrs4(_y,&_k); + ec_enc_uint(_enc,i,ncwrs4(_k)); + }break; + case 5:{ + i=icwrs5(_y,&_k); + ec_enc_uint(_enc,i,ncwrs5(_k)); + }break; + default:{ + VARDECL(celt_uint32_t,u); + celt_uint32_t nc; + SAVE_STACK; + ALLOC(u,_k+2U,celt_uint32_t); + i=icwrs(_n,_k,&nc,_y,u); + ec_enc_uint(_enc,i,nc); + RESTORE_STACK; + }break; + } +} + void encode_pulses(int *_y, int N, int K, ec_enc *enc) { if (K==0) { - } else if (N==1) - { - ec_enc_bits(enc, _y[0]<0, 1); } else if(fits_in32(N,K)) { - encode_pulse32(N, K, _y, enc); + encode_pulses32(N, K, _y, enc); } else { int i; int count=0; @@ -384,12 +859,24 @@ void encode_pulses(int *_y, int N, int K, ec_enc *enc) } } -static inline void decode_pulse32(int _n,int _m,int *_y,ec_dec *_dec){ - VARDECL(celt_uint32_t,u); - SAVE_STACK; - ALLOC(u,_m+2,celt_uint32_t); - cwrsi32(_n,_m,ec_dec_uint(_dec,ncwrs_u32(_n,_m,u)),_y,u); - RESTORE_STACK; +static inline void decode_pulses32(int _n,int _k,int *_y,ec_dec *_dec){ + switch(_n){ + case 1:{ + celt_assert(ncwrs1(_k)==2); + cwrsi1(_k,ec_dec_bits(_dec,1),_y); + }break; + case 2:cwrsi2(_k,ec_dec_uint(_dec,ncwrs2(_k)),_y);break; + case 3:cwrsi3(_k,ec_dec_uint(_dec,ncwrs3(_k)),_y);break; + case 4:cwrsi4(_k,ec_dec_uint(_dec,ncwrs4(_k)),_y);break; + case 5:cwrsi5(_k,ec_dec_uint(_dec,ncwrs5(_k)),_y);break; + default:{ + VARDECL(celt_uint32_t,u); + SAVE_STACK; + 
ALLOC(u,_k+2U,celt_uint32_t); + cwrsi(_n,_k,ec_dec_uint(_dec,ncwrs_urow(_n,_k,u)),_y,u); + RESTORE_STACK; + } + } } void decode_pulses(int *_y, int N, int K, ec_dec *dec) @@ -398,16 +885,9 @@ void decode_pulses(int *_y, int N, int K, ec_dec *dec) int i; for (i=0;i<N;i++) _y[i] = 0; - } else if (N==1) - { - int s = ec_dec_bits(dec, 1); - if (s==0) - _y[0] = K; - else - _y[0] = -K; } else if(fits_in32(N,K)) { - decode_pulse32(N, K, _y, dec); + decode_pulses32(N, K, _y, dec); } else { int split; int count = ec_dec_uint(dec,K+1); diff --git a/libcelt/cwrs.h b/libcelt/cwrs.h index eabdf75e1..6da087536 100644 --- a/libcelt/cwrs.h +++ b/libcelt/cwrs.h @@ -38,29 +38,8 @@ int log2_frac(ec_uint32 val, int frac); -/* Returns log of an integer with fractional accuracy */ -int log2_frac64(ec_uint64 val, int frac); /* Whether the CWRS codebook will fit into 32 bits */ int fits_in32(int _n, int _m); -/* Whether the CWRS codebook will fit into 64 bits */ -int fits_in64(int _n, int _m); - -/* 32-bit versions */ -celt_uint32_t ncwrs_u32(int _n,int _m,celt_uint32_t *_u); - -void cwrsi32(int _n,int _m,celt_uint32_t _i,int *_y,celt_uint32_t *_u); - -celt_uint32_t icwrs32(int _n,int _m,celt_uint32_t *_nc,const int *_y, - celt_uint32_t *_u); - -/* 64-bit versions */ -celt_uint64_t ncwrs_u64(int _n,int _m,celt_uint64_t *_u); - -void cwrsi64(int _n,int _m,celt_uint64_t _i,int *_y,celt_uint64_t *_u); - -celt_uint64_t icwrs64(int _n,int _m,celt_uint64_t *_nc,const int *_y, - celt_uint64_t *_u); - void get_required_bits(celt_int16_t *bits, int N, int K, int frac); diff --git a/tests/cwrs32-test.c b/tests/cwrs32-test.c index 22f40eadc..7332ca5bb 100644 --- a/tests/cwrs32-test.c +++ b/tests/cwrs32-test.c @@ -3,52 +3,169 @@ #endif #include <stdio.h> -#include "cwrs.h" #include <string.h> -#include "../libcelt/cwrs.c" #include "../libcelt/rangeenc.c" #include "../libcelt/rangedec.c" #include "../libcelt/entenc.c" #include "../libcelt/entdec.c" #include "../libcelt/entcode.c" +#include 
"../libcelt/cwrs.c" + +#define NMAX (14) +#define KMAX (32767) + +static const int kmax[15]={ + 32767,32767,32767,32767, 1172, + 238, 95, 53, 36, 27, + 22, 18, 16, 15, 13 +}; -#define NMAX (10) -#define MMAX (9) int main(int _argc,char **_argv){ int n; for(n=2;n<=NMAX;n++){ - int m; - for(m=1;m<=MMAX;m++){ - celt_uint32_t uu[MMAX+2]; + int dk; + int k; + dk=kmax[n]>7?kmax[n]/7:1; + k=1-dk; + do{ + celt_uint32_t uu[KMAX+2U]; celt_uint32_t inc; celt_uint32_t nc; celt_uint32_t i; - nc=ncwrs_u32(n,m,uu); + k=kmax[n]-dk<k?kmax[n]:k+dk; + printf("Testing CWRS with N=%i, K=%i...\n",n,k); + nc=ncwrs_urow(n,k,uu); inc=nc/10000; if(inc<1)inc=1; for(i=0;i<nc;i+=inc){ - celt_uint32_t u[MMAX+2]; + celt_uint32_t u[KMAX+2U]; int y[NMAX]; + int yy[5]; celt_uint32_t v; - memcpy(u,uu,(m+2)*sizeof(*u)); - cwrsi32(n,m,i,y,u); + celt_uint32_t ii; + int kk; + int j; + memcpy(u,uu,(k+2U)*sizeof(*u)); + cwrsi(n,k,i,y,u); /*printf("%6u of %u:",i,nc); - for(k=0;k<n;k++)printf(" %+3i",y[k]); + for(j=0;j<n;j++)printf(" %+3i",y[j]); printf(" ->");*/ - if(icwrs32(n,m,&v,y,u)!=i){ - fprintf(stderr,"Combination-index mismatch.\n"); + ii=icwrs(n,k,&v,y,u); + if(ii!=i){ + fprintf(stderr,"Combination-index mismatch (%lu!=%lu).\n", + (long)ii,(long)i); return 1; } if(v!=nc){ - fprintf(stderr,"Combination count mismatch.\n"); + fprintf(stderr,"Combination count mismatch (%lu!=%lu).\n", + (long)v,(long)nc); return 2; } + if(n==2){ + cwrsi2(k,i,yy); + for(j=0;j<2;j++)if(yy[j]!=y[j]){ + fprintf(stderr,"N=2 pulse vector mismatch ({%i,%i}!={%i,%i}).\n", + yy[0],yy[1],y[0],y[1]); + return 3; + } + ii=icwrs2(yy,&kk); + if(ii!=i){ + fprintf(stderr,"N=2 combination-index mismatch (%lu!=%lu).\n", + (long)ii,(long)i); + return 4; + } + if(kk!=k){ + fprintf(stderr,"N=2 pulse count mismatch (%i,%i).\n",kk,k); + return 5; + } + v=ncwrs2(k); + if(v!=nc){ + fprintf(stderr,"N=2 combination count mismatch (%lu,%lu).\n", + (long)v,(long)nc); + return 6; + } + } + else if(n==3){ + cwrsi3(k,i,yy); + 
for(j=0;j<3;j++)if(yy[j]!=y[j]){ + fprintf(stderr,"N=3 pulse vector mismatch " + "({%i,%i,%i}!={%i,%i,%i}).\n",yy[0],yy[1],yy[2],y[0],y[1],y[2]); + return 7; + } + ii=icwrs3(yy,&kk); + if(ii!=i){ + fprintf(stderr,"N=3 combination-index mismatch (%lu!=%lu).\n", + (long)ii,(long)i); + return 8; + } + if(kk!=k){ + fprintf(stderr,"N=3 pulse count mismatch (%i!=%i).\n",kk,k); + return 9; + } + v=ncwrs3(k); + if(v!=nc){ + fprintf(stderr,"N=3 combination count mismatch (%lu!=%lu).\n", + (long)v,(long)nc); + return 10; + } + } + else if(n==4){ + cwrsi4(k,i,yy); + for(j=0;j<4;j++)if(yy[j]!=y[j]){ + fprintf(stderr,"N=4 pulse vector mismatch " + "({%i,%i,%i,%i}!={%i,%i,%i,%i}.\n", + yy[0],yy[1],yy[2],yy[3],y[0],y[1],y[2],y[3]); + return 11; + } + ii=icwrs4(yy,&kk); + if(ii!=i){ + fprintf(stderr,"N=4 combination-index mismatch (%lu!=%lu).\n", + (long)ii,(long)i); + return 12; + } + if(kk!=k){ + fprintf(stderr,"N=4 pulse count mismatch (%i!=%i).\n",kk,k); + return 13; + } + v=ncwrs4(k); + if(v!=nc){ + fprintf(stderr,"N=4 combination count mismatch (%lu!=%lu).\n", + (long)v,(long)nc); + return 14; + } + } + else if(n==5){ + cwrsi5(k,i,yy); + for(j=0;j<5;j++)if(yy[j]!=y[j]){ + fprintf(stderr,"N=5 pulse vector mismatch " + "({%i,%i,%i,%i,%i}!={%i,%i,%i,%i,%i}).\n", + yy[0],yy[1],yy[2],yy[3],yy[4],y[0],y[1],y[2],y[3],y[4]); + return 15; + } + ii=icwrs5(yy,&kk); + if(ii!=i){ + fprintf(stderr,"N=5 combination-index mismatch (%lu!=%lu).\n", + (long)ii,(long)i); + return 16; + } + if(kk!=k){ + fprintf(stderr,"N=5 pulse count mismatch (%i!=%i).\n",kk,k); + return 17; + } + v=ncwrs5(k); + if(v!=nc){ + fprintf(stderr,"N=5 combination count mismatch (%lu!=%lu).\n", + (long)v,(long)nc); + return 18; + } + } /*printf(" %6u\n",i);*/ } /*printf("\n");*/ } + while(k<kmax[n]); } return 0; } diff --git a/tests/cwrs64-test.c b/tests/cwrs64-test.c deleted file mode 100644 index 548fa2230..000000000 --- a/tests/cwrs64-test.c +++ /dev/null @@ -1,50 +0,0 @@ -#ifdef HAVE_CONFIG_H -#include "config.h" 
-#endif - -#include <stdio.h> -#include "cwrs.h" -#include <string.h> - -#define NMAX (32) -#define MMAX (16) - -int main(int _argc,char **_argv){ - int n; - for(n=2;n<=NMAX;n+=3){ - int m; - for(m=1;m<=MMAX;m++){ - celt_uint64_t uu[MMAX+2]; - celt_uint64_t inc; - celt_uint64_t nc; - celt_uint64_t i; - nc=ncwrs_u64(n,m,uu); - /*Testing all cases just wouldn't work!*/ - inc=nc/1000; - if(inc<1)inc=1; - /*printf("%d/%d: %llu",n,m, nc);*/ - for(i=0;i<nc;i+=inc){ - celt_uint64_t u[MMAX+2]; - int y[NMAX]; - celt_uint64_t v; - int k; - memcpy(u,uu,(m+2)*sizeof(*u)); - cwrsi64(n,m,i,y,u); - /*printf("%llu of %llu:",i,nc); - for(k=0;k<n;k++)printf(" %+3i",y[k]); - printf(" ->");*/ - if(icwrs64(n,m,&v,y,u)!=i){ - fprintf(stderr,"Combination-index mismatch.\n"); - return 1; - } - if(v!=nc){ - fprintf(stderr,"Combination count mismatch.\n"); - return 2; - } - /*printf(" %6llu\n",i);*/ - } - /*printf("\n");*/ - } - } - return 0; -} -- GitLab