/*
 *  Copyright (c) 2015 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

Alex Converse's avatar
Alex Converse committed
11
12
#ifndef AOM_DSP_ANS_H_
#define AOM_DSP_ANS_H_

// An implementation of Asymmetric Numeral Systems
// http://arxiv.org/abs/1311.2540v2

#include <assert.h>
Yaowu Xu's avatar
Yaowu Xu committed
17
18
#include "./aom_config.h"
#include "aom/aom_integer.h"
19
20
#include "aom_dsp/prob.h"
#include "aom_ports/mem_ops.h"
21
22
23

// Selects the division primitive used by the coder:
//   1: fastdiv() from aom_dsp/divide.h (presumably a multiply-based divide;
//      its implementation is not visible from this header),
//   0: the native '/' and '%' operators.
#define ANS_DIVIDE_BY_MULTIPLY 1
#if ANS_DIVIDE_BY_MULTIPLY
#include "aom_dsp/divide.h"
// Stores the quotient of dividend / divisor in `quotient` and the matching
// remainder in `remainder`. The arguments are expanded more than once, so
// only pass expressions without side effects.
#define ANS_DIVREM(quotient, remainder, dividend, divisor) \
  do {                                                     \
    (quotient) = fastdiv((dividend), (divisor));           \
    (remainder) = (dividend) - (quotient) * (divisor);     \
  } while (0)
#define ANS_DIV(dividend, divisor) fastdiv((dividend), (divisor))
#else
// Stores dividend / divisor in `quotient` and dividend % divisor in
// `remainder`. The arguments are expanded more than once, so only pass
// expressions without side effects.
#define ANS_DIVREM(quotient, remainder, dividend, divisor) \
  do {                                                     \
    (quotient) = (dividend) / (divisor);                   \
    (remainder) = (dividend) % (divisor);                  \
  } while (0)
#define ANS_DIV(dividend, divisor) ((dividend) / (divisor))
#endif

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// Encoder context shared by the rABS, uABS and rANS write functions.
struct AnsCoder {
  uint8_t *buf;    // Output buffer; the write functions append bytes to it.
  int buf_offset;  // Index of the next byte to be written in buf.
  uint32_t state;  // Current ANS state; ans_write_init() sets it to l_base.
};

// Decoder context shared by the rABS, uABS and rANS read functions.
struct AnsDecoder {
  const uint8_t *buf;  // Input buffer; it is consumed from the end backwards.
  int buf_offset;      // Number of bytes of buf that have not been read yet.
  uint32_t state;      // Current ANS state; refilled from buf below l_base.
};

// Probability type of the binary (rABS / uABS) coders, out of 256.
typedef uint8_t AnsP8;
#define ans_p8_precision 256u
#define ans_p8_shift 8
// Probability type of the multi-symbol (rANS) coder, out of 1024.
typedef uint16_t AnsP10;
#define ans_p10_precision 1024u

#define rans_precision ans_p10_precision

#define l_base (ans_p10_precision * 4)  // l_base % precision must be 0
#define io_base 256
// Range I = { l_base, l_base + 1, ..., l_base * io_base - 1 }

// Prepares `ans` for encoding into `buf`: output starts at offset 0 and the
// state starts at l_base, the lower bound of the normalized range.
static INLINE void ans_write_init(struct AnsCoder *const ans,
                                  uint8_t *const buf) {
  ans->state = l_base;
  ans->buf_offset = 0;
  ans->buf = buf;
}

// Serializes the final coder state after the payload and returns the total
// number of bytes in the stream (payload plus serialized state).
//
// The state is stored relative to l_base in 1, 2 or 3 bytes. The two most
// significant bits of the stored value carry the length tag (0, 1 or 2);
// ans_read_init() reads that tag back from the last byte of the stream.
// If the state does not fit in 22 bits, nothing is written and the current
// offset is returned (this asserts in debug builds).
static INLINE int ans_write_end(struct AnsCoder *const ans) {
  uint32_t state;
  assert(ans->state >= l_base);
  assert(ans->state < l_base * io_base);
  state = ans->state - l_base;
  if (state < (1 << 6)) {
    ans->buf[ans->buf_offset] = (0x00 << 6) + state;
    return ans->buf_offset + 1;
  } else if (state < (1 << 14)) {
    mem_put_le16(ans->buf + ans->buf_offset, (0x01 << 14) + state);
    return ans->buf_offset + 2;
  } else if (state < (1 << 22)) {
    mem_put_le24(ans->buf + ans->buf_offset, (0x02 << 22) + state);
    return ans->buf_offset + 3;
  } else {
    assert(0 && "State is too large to be serialized");
    return ans->buf_offset;
  }
}

// rABS with descending spread
// p or p0 takes the place of l_s from the paper
// ans_p8_precision is m
//
// Encodes the binary symbol `val`, where p0 / 256 is the probability of a
// zero. At most one byte is flushed to the output before the state update.
static INLINE void rabs_desc_write(struct AnsCoder *ans, int val, AnsP8 p0) {
  const AnsP8 p1 = ans_p8_precision - p0;
  const unsigned freq = val ? p1 : p0;
  const unsigned renorm_limit = l_base / ans_p8_precision * io_base * freq;
  unsigned q, r;
  if (ans->state >= renorm_limit) {
    ans->buf[ans->buf_offset++] = ans->state % io_base;
    ans->state /= io_base;
  }
  ANS_DIVREM(q, r, ans->state, freq);
  // Ones occupy the low p1 slots of each 256-wide period, zeros the rest.
  ans->state = q * ans_p8_precision + r;
  if (!val) ans->state += p1;
}

// ANS_IMPL1 selects between two equivalent formulations of the rABS decoder
// state update; UNPREDICTABLE() only marks a branch as data dependent.
#define ANS_IMPL1 0
#define UNPREDICTABLE(x) x
// Decodes one binary symbol written by rabs_desc_write(). p0 / 256 is the
// probability of a zero and must match the value used by the encoder.
// Returns the decoded symbol (0 or 1).
static INLINE int rabs_desc_read(struct AnsDecoder *ans, AnsP8 p0) {
  int val;
#if ANS_IMPL1
  unsigned l_s;
#else
  unsigned quot, rem, x, xn;
#endif
  const AnsP8 p = ans_p8_precision - p0;
  // Refill the state, but never read before the start of the buffer. On
  // truncated input the state stays below l_base with buf_offset == 0, which
  // ans_reader_has_error() reports; this matches uabs_read() and rans_read().
  if (ans->state < l_base && ans->buf_offset > 0) {
    ans->state = ans->state * io_base + ans->buf[--ans->buf_offset];
  }
#if ANS_IMPL1
  val = ans->state % ans_p8_precision < p;
  l_s = val ? p : p0;
  ans->state = (ans->state / ans_p8_precision) * l_s +
               ans->state % ans_p8_precision - (!val * p);
#else
  x = ans->state;
  quot = x / ans_p8_precision;
  rem = x % ans_p8_precision;
  xn = quot * p;
  // Ones occupy the low p slots of each 256-wide period.
  val = rem < p;
  if (UNPREDICTABLE(val)) {
    ans->state = xn + rem;
  } else {
    // Equivalent to quot * p0 + rem - p.
    ans->state = x - xn - p;
  }
#endif
  return val;
}

// rABS with ascending spread
// p or p0 takes the place of l_s from the paper
// ans_p8_precision is m
//
// Encodes the binary symbol `val`, where p0 / 256 is the probability of a
// zero. At most one byte is flushed to the output before the state update.
static INLINE void rabs_asc_write(struct AnsCoder *ans, int val, AnsP8 p0) {
  const AnsP8 p1 = ans_p8_precision - p0;
  const unsigned freq = val ? p1 : p0;
  const unsigned renorm_limit = l_base / ans_p8_precision * io_base * freq;
  unsigned q, r;
  if (ans->state >= renorm_limit) {
    ans->buf[ans->buf_offset++] = ans->state % io_base;
    ans->state /= io_base;
  }
  ANS_DIVREM(q, r, ans->state, freq);
  // Zeros occupy the low p0 slots of each 256-wide period, ones the rest.
  ans->state = q * ans_p8_precision + r;
  if (val) ans->state += p0;
}

// Decodes one binary symbol written by rabs_asc_write(). p0 / 256 is the
// probability of a zero and must match the value used by the encoder.
// Returns the decoded symbol (0 or 1).
static INLINE int rabs_asc_read(struct AnsDecoder *ans, AnsP8 p0) {
  int val;
#if ANS_IMPL1
  unsigned l_s;
#else
  unsigned quot, rem, x, xn;
#endif
  const AnsP8 p = ans_p8_precision - p0;
  // Refill the state, but never read before the start of the buffer. On
  // truncated input the state stays below l_base with buf_offset == 0, which
  // ans_reader_has_error() reports; this matches uabs_read() and rans_read().
  if (ans->state < l_base && ans->buf_offset > 0) {
    ans->state = ans->state * io_base + ans->buf[--ans->buf_offset];
  }
#if ANS_IMPL1
  // Zeros occupy the low p0 slots of each 256-wide period, ones the rest.
  val = ans->state % ans_p8_precision >= p0;
  l_s = val ? p : p0;
  ans->state = (ans->state / ans_p8_precision) * l_s +
               ans->state % ans_p8_precision - (val ? p0 : 0);
#else
  x = ans->state;
  quot = x / ans_p8_precision;
  rem = x % ans_p8_precision;
  xn = quot * p;
  // Zeros occupy the low p0 slots of each 256-wide period, ones the rest.
  val = rem >= p0;
  if (UNPREDICTABLE(val)) {
    ans->state = xn + rem - p0;
  } else {
    // Equivalent to quot * p0 + rem.
    ans->state = x - xn;
  }
#endif
  return val;
}

// The generic rABS entry points map to the descending-spread variant.
#define rabs_read rabs_desc_read
#define rabs_write rabs_desc_write

// uABS with normalization
//
// Encodes the binary symbol `val`, where p0 / 256 is the probability of a
// zero. Low-order state bytes are flushed to the output until the state is
// below the renormalization limit for the symbol being coded.
static INLINE void uabs_write(struct AnsCoder *ans, int val, AnsP8 p0) {
  AnsP8 p = ans_p8_precision - p0;
  const unsigned l_s = val ? p : p0;
  while (ans->state >= l_base / ans_p8_precision * io_base * l_s) {
    ans->buf[ans->buf_offset++] = ans->state % io_base;
    ans->state /= io_base;
  }
  if (!val)
    ans->state = ANS_DIV(ans->state * ans_p8_precision, p0);
  else
    ans->state = ANS_DIV((ans->state + 1) * ans_p8_precision + p - 1, p) - 1;
}

// Decodes one binary symbol written by uabs_write(). p0 / 256 is the
// probability of a zero and must match the value used by the encoder.
// Returns the decoded symbol (0 or 1).
//
// The state is refilled from the buffer only while bytes remain; running out
// of input leaves a state that ans_reader_has_error() can detect.
static INLINE int uabs_read(struct AnsDecoder *ans, AnsP8 p0) {
  AnsP8 p = ans_p8_precision - p0;
  int s;
  unsigned xp, sp;
  unsigned state = ans->state;
  while (state < l_base && ans->buf_offset > 0) {
    state = state * io_base + ans->buf[--ans->buf_offset];
  }
  sp = state * p;
  xp = sp / ans_p8_precision;
  // The symbol is a one exactly when (sp + p) / 256 exceeds sp / 256, i.e.
  // when the low 8 bits of sp are at least 256 - p == p0.
  s = (sp & 0xFF) >= p0;
  if (UNPREDICTABLE(s))
    ans->state = xp;
  else
    ans->state = state - xp;
  return s;
}

static INLINE int uabs_read_bit(struct AnsDecoder *ans) {
  int s;
  unsigned state = ans->state;
232
  while (state < l_base && ans->buf_offset > 0) {
233
234
235
236
237
238
239
    state = state * io_base + ans->buf[--ans->buf_offset];
  }
  s = (int)(state & 1);
  ans->state = state >> 1;
  return s;
}

240
241
242
243
244
245
// Decodes a `bits`-wide unsigned literal, most significant bit first, one
// equiprobable bit at a time. `bits` must be below 31 so the result fits in
// a non-negative int. Returns the decoded value.
static INLINE int uabs_read_literal(struct AnsDecoder *ans, int bits) {
  int literal = 0, bit;
  assert(bits < 31);

  // TODO(aconverse): Investigate ways to read/write literals faster,
  // e.g. 8-bit chunks.
  for (bit = bits - 1; bit >= 0; bit--) literal |= uabs_read_bit(ans) << bit;

  return literal;
}

Alex Converse's avatar
Alex Converse committed
251
252
// TODO(aconverse): Replace trees with tokensets.
static INLINE int uabs_read_tree(struct AnsDecoder *ans,
Yaowu Xu's avatar
Yaowu Xu committed
253
                                 const aom_tree_index *tree,
Alex Converse's avatar
Alex Converse committed
254
                                 const AnsP8 *probs) {
Yaowu Xu's avatar
Yaowu Xu committed
255
  aom_tree_index i = 0;
Alex Converse's avatar
Alex Converse committed
256

clang-format's avatar
clang-format committed
257
  while ((i = tree[i + uabs_read(ans, probs[i >> 1])]) > 0) continue;
Alex Converse's avatar
Alex Converse committed
258
259
260
261

  return -i;
}

262
struct rans_sym {
263
264
  AnsP10 prob;
  AnsP10 cum_prob;  // not-inclusive
265
266
267
268
};

struct rans_dec_sym {
  uint8_t val;
269
270
  AnsP10 prob;
  AnsP10 cum_prob;  // not-inclusive
271
272
};

273
// This is now just a boring cdf. It starts with an explicit zero.
// TODO(aconverse): Remove starting zero.
typedef uint16_t rans_dec_lut[16];
276

277
static INLINE void rans_build_cdf_from_pdf(const AnsP10 token_probs[],
Alex Converse's avatar
Alex Converse committed
278
                                           rans_dec_lut cdf_tab) {
279
  int i;
Alex Converse's avatar
Alex Converse committed
280
  cdf_tab[0] = 0;
281
  for (i = 1; cdf_tab[i - 1] < rans_precision; ++i) {
Alex Converse's avatar
Alex Converse committed
282
    cdf_tab[i] = cdf_tab[i - 1] + token_probs[i - 1];
283
  }
284
  assert(cdf_tab[i - 1] == rans_precision);
285
286
}

clang-format's avatar
clang-format committed
287
// Returns the index of the largest of the first `num_syms` entries of
// `pdf_tab`. The lowest index wins ties. Returns -1 when num_syms is not
// positive.
static INLINE int ans_find_largest(const AnsP10 *const pdf_tab, int num_syms) {
  int largest_idx = -1;
  int largest_p = -1;
  int i;
  for (i = 0; i < num_syms; ++i) {
    int p = pdf_tab[i];
    if (p > largest_p) {
      largest_p = p;
      largest_idx = i;
    }
  }
  return largest_idx;
}

301
302
303
304
// Prepends a binary node to a pdf, producing in_syms + 1 output symbols.
//
// out_pdf[0] is node_prob rescaled from 8 to 10 bits. Each out_pdf[i + 1] is
// src_pdf[i] scaled by (256 - node_prob) / 256 with rounding, then clamped to
// [1, rans_precision - in_syms]. Finally the largest entries are adjusted so
// that the output sums to rans_precision. out_pdf must not alias src_pdf.
static INLINE void rans_merge_prob8_pdf(AnsP10 *const out_pdf,
                                        const AnsP8 node_prob,
                                        const AnsP10 *const src_pdf,
                                        int in_syms) {
  int i;
  // Probability mass still missing from (or, if negative, exceeding) the
  // required total.
  int adjustment = rans_precision;
  const int round_fact = ans_p8_precision >> 1;
  const AnsP8 p1 = ans_p8_precision - node_prob;
  const int out_syms = in_syms + 1;
  assert(src_pdf != out_pdf);

  out_pdf[0] = node_prob << (10 - 8);
  adjustment -= out_pdf[0];
  for (i = 0; i < in_syms; ++i) {
    int p = (p1 * src_pdf[i] + round_fact) >> ans_p8_shift;
    p = AOMMIN(p, (int)rans_precision - in_syms);
    p = AOMMAX(p, 1);
    out_pdf[i + 1] = p;
    adjustment -= p;
  }

  // Adjust probabilities so they sum to the total probability
  if (adjustment > 0) {
    i = ans_find_largest(out_pdf, out_syms);
    out_pdf[i] += adjustment;
  } else {
    // Take the excess from the currently largest entry, one unit at a time.
    while (adjustment < 0) {
      i = ans_find_largest(out_pdf, out_syms);
      --out_pdf[i];
      assert(out_pdf[i] > 0);
      adjustment++;
    }
  }
}

// rANS with normalization
// sym->prob takes the place of l_s from the paper
338
// ans_p10_precision is m
339
340
static INLINE void rans_write(struct AnsCoder *ans,
                              const struct rans_sym *const sym) {
341
342
  const AnsP10 p = sym->prob;
  while (ans->state >= l_base / rans_precision * io_base * p) {
343
344
345
346
    ans->buf[ans->buf_offset++] = ans->state % io_base;
    ans->state /= io_base;
  }
  ans->state =
347
      (ans->state / p) * rans_precision + ans->state % p + sym->cum_prob;
348
349
}

Alex Converse's avatar
Alex Converse committed
350
static INLINE void fetch_sym(struct rans_dec_sym *out, const rans_dec_lut cdf,
351
                             AnsP10 rem) {
Alex Converse's avatar
Alex Converse committed
352
353
354
355
356
357
358
  int i = 0;
  // TODO(skal): if critical, could be a binary search.
  // Or, better, an O(1) alias-table.
  while (rem >= cdf[i]) {
    ++i;
  }
  out->val = i - 1;
359
360
  out->prob = (AnsP10)(cdf[i] - cdf[i - 1]);
  out->cum_prob = (AnsP10)cdf[i - 1];
Alex Converse's avatar
Alex Converse committed
361
362
}

clang-format's avatar
clang-format committed
363
static INLINE int rans_read(struct AnsDecoder *ans, const rans_dec_lut tab) {
364
365
  unsigned rem;
  unsigned quo;
Alex Converse's avatar
Alex Converse committed
366
  struct rans_dec_sym sym;
367
  while (ans->state < l_base && ans->buf_offset > 0) {
368
369
    ans->state = ans->state * io_base + ans->buf[--ans->buf_offset];
  }
370
371
  quo = ans->state / rans_precision;
  rem = ans->state % rans_precision;
Alex Converse's avatar
Alex Converse committed
372
373
374
  fetch_sym(&sym, tab, rem);
  ans->state = quo * sym.prob + rem - sym.cum_prob;
  return sym.val;
375
376
377
}

static INLINE int ans_read_init(struct AnsDecoder *const ans,
clang-format's avatar
clang-format committed
378
                                const uint8_t *const buf, int offset) {
Alex Converse's avatar
Alex Converse committed
379
380
  unsigned x;
  if (offset < 1) return 1;
381
  ans->buf = buf;
Alex Converse's avatar
Alex Converse committed
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
  x = buf[offset - 1] >> 6;
  if (x == 0) {
    ans->buf_offset = offset - 1;
    ans->state = buf[offset - 1] & 0x3F;
  } else if (x == 1) {
    if (offset < 2) return 1;
    ans->buf_offset = offset - 2;
    ans->state = mem_get_le16(buf + offset - 2) & 0x3FFF;
  } else if (x == 2) {
    if (offset < 3) return 1;
    ans->buf_offset = offset - 3;
    ans->state = mem_get_le24(buf + offset - 3) & 0x3FFFFF;
  } else {
    // x == 3 implies this byte is a superframe marker
    return 1;
  }
  ans->state += l_base;
clang-format's avatar
clang-format committed
399
  if (ans->state >= l_base * io_base) return 1;
400
401
402
403
404
405
  return 0;
}

// Returns nonzero when the decoder state is back at l_base, the value the
// encoder started from in ans_write_init().
static INLINE int ans_read_end(struct AnsDecoder *const ans) {
  const int back_at_start = (l_base == ans->state);
  return back_at_start;
}
406
407
408
409

// Returns nonzero when the decoder has used up its input while the state is
// still below l_base, i.e. a refill was needed but no bytes were left.
static INLINE int ans_reader_has_error(const struct AnsDecoder *const ans) {
  const int input_exhausted = (ans->buf_offset == 0);
  const int state_underfull = (ans->state < l_base);
  return state_underfull && input_exhausted;
}
410
411
412
413
#undef ANS_DIVREM
#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus
Alex Converse's avatar
Alex Converse committed
414
#endif  // AOM_DSP_ANS_H_