Commit 2affb3b0 authored by Angie Chiang's avatar Angie Chiang

Add gen_txb_cache() and its related functions

This function pre-generates the counts/magnitudes of each level map
so that we don't have to recalculate the counts/magnitudes
while doing the optimization.

Change-Id: Ifdfc89522cf2f2b9f3734d451324081f42b47cb0
parent 488f921c
......@@ -32,19 +32,52 @@ static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
/* clang-format on*/
};
// Counts how many neighbours of (row, col) have an absolute coefficient value
// strictly greater than `level`.
//
// tcoeffs    coefficient buffer, indexed as row * stride + col
// stride     row pitch; also used as the upper bound for BOTH the row and the
//            column index, so the block is treated as stride x stride here
// nb_offset  table of nb_num {row delta, column delta} neighbour offsets
//
// Neighbours that fall outside [0, stride) in either dimension are ignored.
static INLINE int get_level_count(const tran_low_t *tcoeffs, int stride,
                                  int row, int col, int level,
                                  int (*nb_offset)[2], int nb_num) {
  int above_level = 0;
  for (int n = 0; n < nb_num; ++n) {
    const int nb_row = row + nb_offset[n][0];
    const int nb_col = col + nb_offset[n][1];
    const int in_block =
        nb_row >= 0 && nb_col >= 0 && nb_row < stride && nb_col < stride;
    if (in_block) {
      const tran_low_t magnitude = abs(tcoeffs[nb_row * stride + nb_col]);
      if (magnitude > level) ++above_level;
    }
  }
  return above_level;
}
// Finds the largest absolute coefficient among the neighbours of (row, col)
// and how many neighbours share that value.
//
// Only neighbours whose row AND column offsets are both non-negative take
// part, and only when they lie inside [0, stride) in both dimensions.
//
// Outputs:
//   mag[0]  the largest absolute value seen (0 if none qualified)
//   mag[1]  number of qualifying neighbours equal to mag[0]; note that
//           zero-valued neighbours are counted as ties while mag[0] is 0
static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
                           int row, int col, int (*nb_offset)[2], int nb_num) {
  int peak = 0;
  int peak_hits = 0;
  for (int n = 0; n < nb_num; ++n) {
    const int d_row = nb_offset[n][0];
    const int d_col = nb_offset[n][1];
    const int nb_row = row + d_row;
    const int nb_col = col + d_col;
    if (nb_row < 0 || nb_col < 0 || nb_row >= stride || nb_col >= stride)
      continue;
    // Offsets with a negative component never contribute to the magnitude.
    if (d_row < 0 || d_col < 0) continue;
    const tran_low_t magnitude = abs(tcoeffs[nb_row * stride + nb_col]);
    if (magnitude > peak) {
      peak = magnitude;
      peak_hits = 1;
    } else if (magnitude == peak) {
      ++peak_hits;
    }
  }
  mag[0] = peak;
  mag[1] = peak_hits;
}
static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
int c, // raster order
int bwl, int level, int (*nb_offset)[2],
int nb_num) {
const int row = c >> bwl;
const int col = c - (row << bwl);
const int stride = 1 << bwl;
int stride, int row, int col, int level,
int (*nb_offset)[2], int nb_num) {
int count = 0;
*mag = 0;
for (int idx = 0; idx < nb_num; ++idx) {
int ref_row = row + nb_offset[idx][0];
int ref_col = col + nb_offset[idx][1];
int pos = (ref_row << bwl) + ref_col;
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
......@@ -55,34 +88,42 @@ static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
return count;
}
static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
int c, // raster order
const int bwl, const int level) {
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = level - 1;
int mag;
int count = get_level_count_mag(&mag, tcoeffs, c, bwl, level_minus_1,
base_ref_offset, BASE_CONTEXT_POSITION_NUM);
int ctx = (count + 1) >> 1;
mag = mag > level;
static INLINE int get_base_ctx_from_count_mag(int row, int col, int count,
int mag, int level) {
const int ctx = (count + 1) >> 1;
const int sig_mag = mag > level;
int ctx_idx = -1;
if (row == 0 && col == 0) {
ctx_idx = (ctx << 1) + mag;
ctx_idx = (ctx << 1) + sig_mag;
assert(ctx_idx < 8);
} else if (row == 0) {
ctx_idx = 8 + (ctx << 1) + mag;
ctx_idx = 8 + (ctx << 1) + sig_mag;
assert(ctx_idx < 18);
} else if (col == 0) {
ctx_idx = 8 + 10 + (ctx << 1) + mag;
ctx_idx = 8 + 10 + (ctx << 1) + sig_mag;
assert(ctx_idx < 28);
} else {
ctx_idx = 8 + 10 + 10 + (ctx << 1) + mag;
ctx_idx = 8 + 10 + 10 + (ctx << 1) + sig_mag;
assert(ctx_idx < COEFF_BASE_CONTEXTS);
}
return ctx_idx;
}
// Derives the base-level context index for the coefficient at raster
// position `c` in a block whose row pitch is (1 << bwl).
//
// The neighbour count is taken against (level - 1) using the base_ref_offset
// neighbourhood; the count and magnitude are then mapped to a context index
// by get_base_ctx_from_count_mag().
static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
                               int c,  // raster order
                               const int bwl, const int level) {
  const int row = c >> bwl;
  const int col = c - (row << bwl);
  int mag;
  const int count =
      get_level_count_mag(&mag, tcoeffs, 1 << bwl, row, col, level - 1,
                          base_ref_offset, BASE_CONTEXT_POSITION_NUM);
  return get_base_ctx_from_count_mag(row, col, count, mag, level);
}
#define BR_CONTEXT_POSITION_NUM 8 // Base range coefficient context
static int br_ref_offset[BR_CONTEXT_POSITION_NUM][2] = {
/* clang-format off*/
......@@ -95,17 +136,13 @@ static int br_level_map[9] = {
0, 0, 1, 1, 2, 2, 3, 3, 3,
};
static INLINE int get_level_ctx(const tran_low_t *tcoeffs,
const int c, // raster order
const int bwl) {
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = NUM_BASE_LEVELS;
int mag;
int count = get_level_count_mag(&mag, tcoeffs, c, bwl, level_minus_1,
br_ref_offset, BR_CONTEXT_POSITION_NUM);
#define BR_MAG_OFFSET 1
// TODO(angiebird): optimize this function by using a table to map from
// count/mag to ctx
static INLINE int get_br_ctx_from_count_mag(int row, int col, int count,
int mag) {
int offset = 0;
if (mag <= 1)
if (mag <= BR_MAG_OFFSET)
offset = 0;
else if (mag <= 3)
offset = 1;
......@@ -130,6 +167,21 @@ static INLINE int get_level_ctx(const tran_low_t *tcoeffs,
return 8 + ctx;
}
// Derives the base-range context index for the coefficient at raster
// position `c` in a block whose row pitch is (1 << bwl).
//
// Neighbours are counted against NUM_BASE_LEVELS using the br_ref_offset
// neighbourhood; the count and magnitude are then mapped to a context index
// by get_br_ctx_from_count_mag().
static INLINE int get_br_ctx(const tran_low_t *tcoeffs,
                             const int c,  // raster order
                             const int bwl) {
  const int row = c >> bwl;
  const int col = c - (row << bwl);
  int mag;
  const int count =
      get_level_count_mag(&mag, tcoeffs, 1 << bwl, row, col, NUM_BASE_LEVELS,
                          br_ref_offset, BR_CONTEXT_POSITION_NUM);
  return get_br_ctx_from_count_mag(row, col, count, mag);
}
#define SIG_REF_OFFSET_NUM 11
static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
{ -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 }, { -1, 0 },
......
......@@ -169,7 +169,7 @@ uint8_t av1_read_coeffs_txb(const AV1_COMMON *const cm, MACROBLOCKD *xd,
sign = aom_read_bit(r, ACCT_STR);
}
ctx = get_level_ctx(tcoeffs, scan[c], bwl);
ctx = get_br_ctx(tcoeffs, scan[c], bwl);
if (cm->fc->coeff_lps[tx_size][plane_type][ctx] == 0) exit(0);
......
......@@ -159,7 +159,7 @@ void av1_write_coeffs_txb(const AV1_COMMON *const cm, MACROBLOCKD *xd,
}
// level is above 1.
ctx = get_level_ctx(tcoeff, scan[c], bwl);
ctx = get_br_ctx(tcoeff, scan[c], bwl);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
aom_write(w, 1, cm->fc->coeff_lps[tx_size][plane_type][ctx]);
......@@ -357,7 +357,7 @@ int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,
int idx;
int ctx;
ctx = get_level_ctx(qcoeff, scan[c], bwl);
ctx = get_br_ctx(qcoeff, scan[c], bwl);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
......@@ -399,6 +399,115 @@ int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane,
return cost;
}
// Returns non-zero when |qc| reaches base level (base_idx + 1), i.e. when
// |qc| is strictly greater than base_idx.
static INLINE int has_base(tran_low_t qc, int base_idx) {
  return abs(qc) > base_idx;
}
// Pre-computes, for every non-zero coefficient among the first `eob` scan
// positions, its neighbour magnitude pair and its per-base-level neighbour
// counts (both over the base_ref_offset neighbourhood).
//
// base_count_arr[i][pos]  written only when |qcoeff[pos]| >= i + 1
// base_mag_arr[pos]       written only when qcoeff[pos] is non-zero
//
// Entries for coefficients that do not meet those conditions are left
// untouched.
static void gen_base_count_mag_arr(int (*base_count_arr)[MAX_TX_SQUARE],
                                   int (*base_mag_arr)[2],
                                   const tran_low_t *qcoeff, int stride,
                                   int eob, const int16_t *scan) {
  for (int scan_idx = 0; scan_idx < eob; ++scan_idx) {
    const int pos = scan[scan_idx];  // raster order
    if (has_base(qcoeff[pos], 0)) {
      const int row = pos / stride;
      const int col = pos % stride;
      get_mag(base_mag_arr[pos], qcoeff, stride, row, col, base_ref_offset,
              BASE_CONTEXT_POSITION_NUM);
      for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) {
        if (has_base(qcoeff[pos], base_idx)) {
          base_count_arr[base_idx][pos] =
              get_level_count(qcoeff, stride, row, col, base_idx,
                              base_ref_offset, BASE_CONTEXT_POSITION_NUM);
        }
      }
    }
  }
}
// Fills nz_count_arr[pos] with the result of get_nz_count() for each of the
// first `eob` positions of the scan order. Positions are visited in scan
// order and stored by raster index.
static void gen_nz_count_arr(int *nz_count_arr, const tran_low_t *qcoeff,
                             int stride, int eob,
                             const SCAN_ORDER *scan_order) {
  const int16_t *const scan = scan_order->scan;
  const int16_t *const iscan = scan_order->iscan;
  for (int scan_idx = 0; scan_idx < eob; ++scan_idx) {
    const int pos = scan[scan_idx];  // raster order
    nz_count_arr[pos] =
        get_nz_count(qcoeff, stride, pos / stride, pos % stride, iscan);
  }
}
// Converts the previously generated neighbour counts into contexts: for each
// of the first `eob` scan positions, nz_ctx_arr[pos][0] receives the result of
// get_nz_map_ctx_from_count(). Only element [0] of each pair is written here.
static void gen_nz_ctx_arr(int (*nz_ctx_arr)[2], int *nz_count_arr,
                           const tran_low_t *qcoeff, int bwl, int eob,
                           const SCAN_ORDER *scan_order) {
  const int16_t *const scan = scan_order->scan;
  const int16_t *const iscan = scan_order->iscan;
  for (int scan_idx = 0; scan_idx < eob; ++scan_idx) {
    const int pos = scan[scan_idx];  // raster order
    nz_ctx_arr[pos][0] = get_nz_map_ctx_from_count(nz_count_arr[pos], qcoeff,
                                                   pos, bwl, iscan);
  }
}
// Converts the pre-computed base counts/magnitudes into base-level contexts.
//
// For every base level i in [0, NUM_BASE_LEVELS) and each of the first `eob`
// scan positions whose coefficient satisfies |qcoeff[pos]| >= i + 1,
// base_ctx_arr[i][pos][0] receives
// get_base_ctx_from_count_mag(row, col, count, mag[0], i + 1).
// Only element [0] of each pair is written; entries for coefficients below
// the level are left untouched.
//
// qcoeff is read by has_base(), so it is a live parameter (the former
// "(void)qcoeff" unused-parameter marker was stale and has been dropped).
static void gen_base_ctx_arr(int (*base_ctx_arr)[MAX_TX_SQUARE][2],
                             int (*base_count_arr)[MAX_TX_SQUARE],
                             int (*base_mag_arr)[2], const tran_low_t *qcoeff,
                             int stride, int eob, const int16_t *scan) {
  for (int i = 0; i < NUM_BASE_LEVELS; ++i) {
    const int level = i + 1;
    for (int c = 0; c < eob; ++c) {
      const int coeff_idx = scan[c];  // raster order
      if (!has_base(qcoeff[coeff_idx], i)) continue;
      const int row = coeff_idx / stride;
      const int col = coeff_idx % stride;
      const int count = base_count_arr[i][coeff_idx];
      const int *mag = base_mag_arr[coeff_idx];
      base_ctx_arr[i][coeff_idx][0] =
          get_base_ctx_from_count_mag(row, col, count, mag[0], level);
    }
  }
}
// Returns non-zero when |qc| is at least NUM_BASE_LEVELS + 1, i.e. the
// coefficient extends past the base levels.
static INLINE int has_br(tran_low_t qc) {
  const int min_br_level = 1 + NUM_BASE_LEVELS;
  return abs(qc) >= min_br_level;
}
// Pre-computes the base-range neighbour count and magnitude pair for every
// coefficient among the first `eob` scan positions that passes has_br().
//
// br_count_arr[pos]  neighbours (br_ref_offset) with |value| > NUM_BASE_LEVELS
// br_mag_arr[pos]    magnitude pair produced by get_mag() over br_ref_offset
//
// Entries for coefficients that fail has_br() are left untouched.
static void gen_br_count_mag_arr(int *br_count_arr, int (*br_mag_arr)[2],
                                 const tran_low_t *qcoeff, int stride, int eob,
                                 const int16_t *scan) {
  for (int scan_idx = 0; scan_idx < eob; ++scan_idx) {
    const int pos = scan[scan_idx];  // raster order
    if (has_br(qcoeff[pos])) {
      const int row = pos / stride;
      const int col = pos % stride;
      br_count_arr[pos] =
          get_level_count(qcoeff, stride, row, col, NUM_BASE_LEVELS,
                          br_ref_offset, BR_CONTEXT_POSITION_NUM);
      get_mag(br_mag_arr[pos], qcoeff, stride, row, col, br_ref_offset,
              BR_CONTEXT_POSITION_NUM);
    }
  }
}
// Converts the pre-computed base-range counts/magnitudes into contexts.
//
// For each of the first `eob` scan positions whose coefficient passes
// has_br(), br_ctx_arr[pos][0] receives
// get_br_ctx_from_count_mag(row, col, count, mag[0]).
// Only element [0] of each pair is written; other entries are left untouched.
//
// qcoeff is read by has_br(), so it is a live parameter (the former
// "(void)qcoeff" unused-parameter marker was stale and has been dropped).
static void gen_br_ctx_arr(int (*br_ctx_arr)[2], const int *br_count_arr,
                           int (*br_mag_arr)[2], const tran_low_t *qcoeff,
                           int stride, int eob, const int16_t *scan) {
  for (int c = 0; c < eob; ++c) {
    const int coeff_idx = scan[c];  // raster order
    if (!has_br(qcoeff[coeff_idx])) continue;
    const int row = coeff_idx / stride;
    const int col = coeff_idx % stride;
    const int count = br_count_arr[coeff_idx];
    const int *mag = br_mag_arr[coeff_idx];
    br_ctx_arr[coeff_idx][0] =
        get_br_ctx_from_count_mag(row, col, count, mag[0]);
  }
}
static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
const aom_prob *dc_sign_prob,
int dc_sign_ctx) {
......@@ -428,6 +537,27 @@ static INLINE int get_golomb_cost(int abs_qc) {
}
}
// Populates txb_cache from txb_info by running each generator in turn:
// non-zero counts, non-zero contexts, base counts/magnitudes, base contexts,
// base-range counts/magnitudes and base-range contexts. Each context
// generator is handed the count array filled by the call just before it, so
// the call order must be kept.
// TODO(angiebird): add static once this function is called
void gen_txb_cache(TxbCache *txb_cache, TxbInfo *txb_info) {
  const SCAN_ORDER *const scan_order = txb_info->scan_order;
  const int16_t *const scan = scan_order->scan;
  const tran_low_t *const qcoeff = txb_info->qcoeff;
  const int stride = txb_info->stride;
  const int eob = txb_info->eob;

  // Non-zero map.
  gen_nz_count_arr(txb_cache->nz_count_arr, qcoeff, stride, eob, scan_order);
  gen_nz_ctx_arr(txb_cache->nz_ctx_arr, txb_cache->nz_count_arr, qcoeff,
                 txb_info->bwl, eob, scan_order);

  // Base levels.
  gen_base_count_mag_arr(txb_cache->base_count_arr, txb_cache->base_mag_arr,
                         qcoeff, stride, eob, scan);
  gen_base_ctx_arr(txb_cache->base_ctx_arr, txb_cache->base_count_arr,
                   txb_cache->base_mag_arr, qcoeff, stride, eob, scan);

  // Base range.
  gen_br_count_mag_arr(txb_cache->br_count_arr, txb_cache->br_mag_arr, qcoeff,
                       stride, eob, scan);
  gen_br_ctx_arr(txb_cache->br_ctx_arr, txb_cache->br_count_arr,
                 txb_cache->br_mag_arr, qcoeff, stride, eob, scan);
}
static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
TxbProbs *txb_probs) {
const TXB_CTX *txb_ctx = txb_info->txb_ctx;
......@@ -456,7 +586,7 @@ static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
}
if (abs_qc > NUM_BASE_LEVELS) {
int ctx = get_level_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
int ctx = get_br_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl);
cost += get_br_cost(abs_qc, ctx, txb_probs->coeff_lps);
cost += get_golomb_cost(abs_qc);
}
......@@ -639,7 +769,7 @@ void av1_update_and_record_txb_context(int plane, int block, int blk_row,
}
// level is above 1.
ctx = get_level_ctx(tcoeff, scan[c], bwl);
ctx = get_br_ctx(tcoeff, scan[c], bwl);
for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
if (level == (idx + 1 + NUM_BASE_LEVELS)) {
++td->counts->coeff_lps[tx_size][plane_type][ctx][1];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment