Commit 2cdf0d85 authored by Alex Converse's avatar Alex Converse
Browse files

Specify ANS window size at initialization

Change-Id: Ia1757d580dd230d9e743b1f8c3e87df164008684
parent 251cf364
...@@ -34,6 +34,7 @@ struct AnsDecoder { ...@@ -34,6 +34,7 @@ struct AnsDecoder {
uint32_t state; uint32_t state;
#if ANS_MAX_SYMBOLS #if ANS_MAX_SYMBOLS
int symbols_left; int symbols_left;
int window_size;
#endif #endif
#if CONFIG_ACCOUNTING #if CONFIG_ACCOUNTING
Accounting *accounting; Accounting *accounting;
...@@ -134,6 +135,9 @@ static INLINE int rans_read(struct AnsDecoder *ans, const aom_cdf_prob *tab) { ...@@ -134,6 +135,9 @@ static INLINE int rans_read(struct AnsDecoder *ans, const aom_cdf_prob *tab) {
} }
static INLINE int ans_read_init(struct AnsDecoder *const ans, static INLINE int ans_read_init(struct AnsDecoder *const ans,
#if ANS_MAX_SYMBOLS
int window_size,
#endif
const uint8_t *const buf, int offset) { const uint8_t *const buf, int offset) {
unsigned x; unsigned x;
if (offset < 1) return 1; if (offset < 1) return 1;
...@@ -176,14 +180,19 @@ static INLINE int ans_read_init(struct AnsDecoder *const ans, ...@@ -176,14 +180,19 @@ static INLINE int ans_read_init(struct AnsDecoder *const ans,
ans->state += L_BASE; ans->state += L_BASE;
if (ans->state >= L_BASE * IO_BASE) return 1; if (ans->state >= L_BASE * IO_BASE) return 1;
#if ANS_MAX_SYMBOLS #if ANS_MAX_SYMBOLS
ans->symbols_left = ANS_MAX_SYMBOLS; ans->window_size = window_size;
ans->symbols_left = window_size;
#endif #endif
return 0; return 0;
} }
#if ANS_REVERSE #if ANS_REVERSE
static INLINE int ans_read_reinit(struct AnsDecoder *const ans) { static INLINE int ans_read_reinit(struct AnsDecoder *const ans) {
return ans_read_init(ans, ans->buf + ans->buf_offset, -ans->buf_offset); return ans_read_init(ans,
#if ANS_MAX_SYMBOLS
ans->window_size,
#endif
ans->buf + ans->buf_offset, -ans->buf_offset);
} }
#endif #endif
......
...@@ -68,13 +68,21 @@ typedef struct aom_dk_reader aom_reader; ...@@ -68,13 +68,21 @@ typedef struct aom_dk_reader aom_reader;
#endif #endif
static INLINE int aom_reader_init(aom_reader *r, const uint8_t *buffer, static INLINE int aom_reader_init(aom_reader *r, const uint8_t *buffer,
size_t size, aom_decrypt_cb decrypt_cb, size_t size,
#if CONFIG_ANS && ANS_MAX_SYMBOLS
size_t window_size,
#endif
aom_decrypt_cb decrypt_cb,
void *decrypt_state) { void *decrypt_state) {
#if CONFIG_ANS #if CONFIG_ANS
(void)decrypt_cb; (void)decrypt_cb;
(void)decrypt_state; (void)decrypt_state;
if (size > INT_MAX) return 1; if (size > INT_MAX) return 1;
return ans_read_init(r, buffer, (int)size); return ans_read_init(r,
#if ANS_MAX_SYMBOLS
(int)window_size,
#endif
buffer, (int)size);
#elif CONFIG_DAALA_EC #elif CONFIG_DAALA_EC
(void)decrypt_cb; (void)decrypt_cb;
(void)decrypt_state; (void)decrypt_state;
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include "aom/internal/aom_codec_internal.h" #include "aom/internal/aom_codec_internal.h"
void aom_buf_ans_alloc(struct BufAnsCoder *c, void aom_buf_ans_alloc(struct BufAnsCoder *c,
struct aom_internal_error_info *error, int size_hint) { struct aom_internal_error_info *error, int size) {
c->error = error; c->error = error;
c->size = size_hint; c->size = size;
AOM_CHECK_MEM_ERROR(error, c->buf, aom_malloc(c->size * sizeof(*c->buf))); AOM_CHECK_MEM_ERROR(error, c->buf, aom_malloc(c->size * sizeof(*c->buf)));
// Initialize to overfull to trigger the assert in write. // Initialize to overfull to trigger the assert in write.
c->offset = c->size + 1; c->offset = c->size + 1;
...@@ -30,6 +30,7 @@ void aom_buf_ans_free(struct BufAnsCoder *c) { ...@@ -30,6 +30,7 @@ void aom_buf_ans_free(struct BufAnsCoder *c) {
c->size = 0; c->size = 0;
} }
#if !ANS_MAX_SYMBOLS
void aom_buf_ans_grow(struct BufAnsCoder *c) { void aom_buf_ans_grow(struct BufAnsCoder *c) {
struct buffered_ans_symbol *new_buf = NULL; struct buffered_ans_symbol *new_buf = NULL;
int new_size = c->size * 2; int new_size = c->size * 2;
...@@ -40,6 +41,7 @@ void aom_buf_ans_grow(struct BufAnsCoder *c) { ...@@ -40,6 +41,7 @@ void aom_buf_ans_grow(struct BufAnsCoder *c) {
c->buf = new_buf; c->buf = new_buf;
c->size = new_size; c->size = new_size;
} }
#endif
void aom_buf_ans_flush(struct BufAnsCoder *const c) { void aom_buf_ans_flush(struct BufAnsCoder *const c) {
int offset; int offset;
......
...@@ -43,14 +43,24 @@ struct BufAnsCoder { ...@@ -43,14 +43,24 @@ struct BufAnsCoder {
int size; int size;
int offset; int offset;
int output_bytes; int output_bytes;
#if ANS_MAX_SYMBOLS
int window_size;
#endif
}; };
// Allocate a buffered ANS coder to store size symbols.
// When ANS_MAX_SYMBOLS is turned on, the size is the fixed size of each ANS
// partition.
// When ANS_MAX_SYMBOLS is turned off, size is merely an initial hint and the
// buffer will grow on demand
void aom_buf_ans_alloc(struct BufAnsCoder *c, void aom_buf_ans_alloc(struct BufAnsCoder *c,
struct aom_internal_error_info *error, int size_hint); struct aom_internal_error_info *error, int hint);
void aom_buf_ans_free(struct BufAnsCoder *c); void aom_buf_ans_free(struct BufAnsCoder *c);
#if !ANS_MAX_SYMBOLS
void aom_buf_ans_grow(struct BufAnsCoder *c); void aom_buf_ans_grow(struct BufAnsCoder *c);
#endif
void aom_buf_ans_flush(struct BufAnsCoder *const c); void aom_buf_ans_flush(struct BufAnsCoder *const c);
...@@ -64,30 +74,34 @@ static INLINE void buf_ans_write_init(struct BufAnsCoder *const c, ...@@ -64,30 +74,34 @@ static INLINE void buf_ans_write_init(struct BufAnsCoder *const c,
static INLINE void buf_uabs_write(struct BufAnsCoder *const c, uint8_t val, static INLINE void buf_uabs_write(struct BufAnsCoder *const c, uint8_t val,
AnsP8 prob) { AnsP8 prob) {
assert(c->offset <= c->size); assert(c->offset <= c->size);
#if !ANS_MAX_SYMBOLS
if (c->offset == c->size) { if (c->offset == c->size) {
aom_buf_ans_grow(c); aom_buf_ans_grow(c);
} }
#endif
c->buf[c->offset].method = ANS_METHOD_UABS; c->buf[c->offset].method = ANS_METHOD_UABS;
c->buf[c->offset].val_start = val; c->buf[c->offset].val_start = val;
c->buf[c->offset].prob = prob; c->buf[c->offset].prob = prob;
++c->offset; ++c->offset;
#if ANS_MAX_SYMBOLS #if ANS_MAX_SYMBOLS
if (c->offset == ANS_MAX_SYMBOLS) aom_buf_ans_flush(c); if (c->offset == c->size) aom_buf_ans_flush(c);
#endif #endif
} }
static INLINE void buf_rans_write(struct BufAnsCoder *const c, static INLINE void buf_rans_write(struct BufAnsCoder *const c,
const struct rans_sym *const sym) { const struct rans_sym *const sym) {
assert(c->offset <= c->size); assert(c->offset <= c->size);
#if !ANS_MAX_SYMBOLS
if (c->offset == c->size) { if (c->offset == c->size) {
aom_buf_ans_grow(c); aom_buf_ans_grow(c);
} }
#endif
c->buf[c->offset].method = ANS_METHOD_RANS; c->buf[c->offset].method = ANS_METHOD_RANS;
c->buf[c->offset].val_start = sym->cum_prob; c->buf[c->offset].val_start = sym->cum_prob;
c->buf[c->offset].prob = sym->prob; c->buf[c->offset].prob = sym->prob;
++c->offset; ++c->offset;
#if ANS_MAX_SYMBOLS #if ANS_MAX_SYMBOLS
if (c->offset == ANS_MAX_SYMBOLS) aom_buf_ans_flush(c); if (c->offset == c->size) aom_buf_ans_flush(c);
#endif #endif
} }
......
...@@ -2183,7 +2183,11 @@ static void setup_bool_decoder(const uint8_t *data, const uint8_t *data_end, ...@@ -2183,7 +2183,11 @@ static void setup_bool_decoder(const uint8_t *data, const uint8_t *data_end,
aom_internal_error(error_info, AOM_CODEC_CORRUPT_FRAME, aom_internal_error(error_info, AOM_CODEC_CORRUPT_FRAME,
"Truncated packet or corrupt tile length"); "Truncated packet or corrupt tile length");
if (aom_reader_init(r, data, read_size, decrypt_cb, decrypt_state)) if (aom_reader_init(r, data, read_size,
#if CONFIG_ANS && ANS_MAX_SYMBOLS
ANS_MAX_SYMBOLS,
#endif
decrypt_cb, decrypt_state))
aom_internal_error(error_info, AOM_CODEC_MEM_ERROR, aom_internal_error(error_info, AOM_CODEC_MEM_ERROR,
"Failed to allocate bool decoder %d", 1); "Failed to allocate bool decoder %d", 1);
} }
...@@ -4149,8 +4153,11 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data, ...@@ -4149,8 +4153,11 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data,
int j; int j;
#endif #endif
if (aom_reader_init(&r, data, partition_size, pbi->decrypt_cb, if (aom_reader_init(&r, data, partition_size,
pbi->decrypt_state)) #if CONFIG_ANS && ANS_MAX_SYMBOLS
ANS_MAX_SYMBOLS,
#endif
pbi->decrypt_cb, pbi->decrypt_state))
aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR, aom_internal_error(&cm->error, AOM_CODEC_MEM_ERROR,
"Failed to allocate bool decoder 0"); "Failed to allocate bool decoder 0");
......
...@@ -791,7 +791,8 @@ void av1_alloc_compressor_data(AV1_COMP *cpi) { ...@@ -791,7 +791,8 @@ void av1_alloc_compressor_data(AV1_COMP *cpi) {
CHECK_MEM_ERROR(cm, cpi->tile_tok[0][0], CHECK_MEM_ERROR(cm, cpi->tile_tok[0][0],
aom_calloc(tokens, sizeof(*cpi->tile_tok[0][0]))); aom_calloc(tokens, sizeof(*cpi->tile_tok[0][0])));
#if CONFIG_ANS #if CONFIG_ANS
aom_buf_ans_alloc(&cpi->buf_ans, &cm->error, tokens); aom_buf_ans_alloc(&cpi->buf_ans, &cm->error,
ANS_MAX_SYMBOLS ? ANS_MAX_SYMBOLS : tokens);
#endif // CONFIG_ANS #endif // CONFIG_ANS
} }
......
...@@ -35,7 +35,11 @@ TEST(AV1, TestAccounting) { ...@@ -35,7 +35,11 @@ TEST(AV1, TestAccounting) {
} }
aom_stop_encode(&bw); aom_stop_encode(&bw);
aom_reader br; aom_reader br;
aom_reader_init(&br, bw_buffer, bw.pos, NULL, NULL); aom_reader_init(&br, bw_buffer, bw.pos,
#if CONFIG_ANS && ANS_MAX_SYMBOLS
1 << 16,
#endif
NULL, NULL);
Accounting accounting; Accounting accounting;
aom_accounting_init(&accounting); aom_accounting_init(&accounting);
......
...@@ -26,6 +26,9 @@ namespace { ...@@ -26,6 +26,9 @@ namespace {
typedef std::vector<std::pair<uint8_t, bool> > PvVec; typedef std::vector<std::pair<uint8_t, bool> > PvVec;
const int kPrintStats = 0; const int kPrintStats = 0;
// When ANS is windowed use the window size, otherwise use a small value to
// exercise the buffer growth code
const int kBufAnsSize = ANS_MAX_SYMBOLS ? ANS_MAX_SYMBOLS : 100;
PvVec abs_encode_build_vals(int iters) { PvVec abs_encode_build_vals(int iters) {
PvVec ret; PvVec ret;
...@@ -49,7 +52,7 @@ PvVec abs_encode_build_vals(int iters) { ...@@ -49,7 +52,7 @@ PvVec abs_encode_build_vals(int iters) {
bool check_uabs(const PvVec &pv_vec, uint8_t *buf) { bool check_uabs(const PvVec &pv_vec, uint8_t *buf) {
BufAnsCoder a; BufAnsCoder a;
aom_buf_ans_alloc(&a, NULL, 100); aom_buf_ans_alloc(&a, NULL, kBufAnsSize);
buf_ans_write_init(&a, buf); buf_ans_write_init(&a, buf);
std::clock_t start = std::clock(); std::clock_t start = std::clock();
...@@ -62,7 +65,12 @@ bool check_uabs(const PvVec &pv_vec, uint8_t *buf) { ...@@ -62,7 +65,12 @@ bool check_uabs(const PvVec &pv_vec, uint8_t *buf) {
aom_buf_ans_free(&a); aom_buf_ans_free(&a);
bool okay = true; bool okay = true;
AnsDecoder d; AnsDecoder d;
if (ans_read_init(&d, buf, offset)) return false; if (ans_read_init(&d,
#if ANS_MAX_SYMBOLS
kBufAnsSize,
#endif
buf, offset))
return false;
start = std::clock(); start = std::clock();
for (PvVec::const_iterator it = pv_vec.begin(); it != pv_vec.end(); ++it) { for (PvVec::const_iterator it = pv_vec.begin(); it != pv_vec.end(); ++it) {
okay = okay && (uabs_read(&d, 256 - it->first) != 0) == it->second; okay = okay && (uabs_read(&d, 256 - it->first) != 0) == it->second;
...@@ -115,7 +123,7 @@ void rans_build_dec_tab(const struct rans_sym sym_tab[], ...@@ -115,7 +123,7 @@ void rans_build_dec_tab(const struct rans_sym sym_tab[],
bool check_rans(const std::vector<int> &sym_vec, const rans_sym *const tab, bool check_rans(const std::vector<int> &sym_vec, const rans_sym *const tab,
uint8_t *buf) { uint8_t *buf) {
BufAnsCoder a; BufAnsCoder a;
aom_buf_ans_alloc(&a, NULL, 100); aom_buf_ans_alloc(&a, NULL, kBufAnsSize);
buf_ans_write_init(&a, buf); buf_ans_write_init(&a, buf);
aom_cdf_prob dec_tab[kRansSymbols]; aom_cdf_prob dec_tab[kRansSymbols];
rans_build_dec_tab(tab, dec_tab); rans_build_dec_tab(tab, dec_tab);
...@@ -131,7 +139,12 @@ bool check_rans(const std::vector<int> &sym_vec, const rans_sym *const tab, ...@@ -131,7 +139,12 @@ bool check_rans(const std::vector<int> &sym_vec, const rans_sym *const tab,
aom_buf_ans_free(&a); aom_buf_ans_free(&a);
bool okay = true; bool okay = true;
AnsDecoder d; AnsDecoder d;
if (ans_read_init(&d, buf, offset)) return false; if (ans_read_init(&d,
#if ANS_MAX_SYMBOLS
kBufAnsSize,
#endif
buf, offset))
return false;
start = std::clock(); start = std::clock();
for (std::vector<int>::const_iterator it = sym_vec.begin(); for (std::vector<int>::const_iterator it = sym_vec.begin();
it != sym_vec.end(); ++it) { it != sym_vec.end(); ++it) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment