Commit dab2ca9d authored by Yue Chen's avatar Yue Chen

new_multisymbol: use cdf-based costs of palette flags

The modification is only applicable to palette_y_mode and
palette_uv_mode. Welcome to make changes to other palette syntax.

Change-Id: I7bf0a49c06a3986475076fe291e26f4b783b8ab9
parent 568bf107
......@@ -280,6 +280,8 @@ struct macroblock {
[PALETTE_COLORS];
int palette_uv_color_cost[PALETTE_SIZES][PALETTE_COLOR_INDEX_CONTEXTS]
[PALETTE_COLORS];
int palette_y_mode_cost[PALETTE_BLOCK_SIZES][PALETTE_Y_MODE_CONTEXTS][2];
int palette_uv_mode_cost[PALETTE_UV_MODE_CONTEXTS][2];
#if CONFIG_CFL
// The rate associated with each alpha codeword
int cfl_cost[CFL_JOINT_SIGNS][CFL_PRED_PLANES][CFL_ALPHABET_SIZE];
......
......@@ -4790,6 +4790,43 @@ static void sum_intra_stats(FRAME_COUNTS *counts, MACROBLOCKD *xd,
update_cdf(fc->uv_mode_cdf[y_mode], uv_mode, UV_INTRA_MODES);
}
#if CONFIG_NEW_MULTISYMBOL
// TODO(anybody) We can add stats accumulation here to train entropy models for
// palette modes
static void update_palette_cdf(MACROBLOCKD *xd, const MODE_INFO *mi) {
FRAME_CONTEXT *fc = xd->tile_ctx;
const MB_MODE_INFO *const mbmi = &mi->mbmi;
const MODE_INFO *const above_mi = xd->above_mi;
const MODE_INFO *const left_mi = xd->left_mi;
const BLOCK_SIZE bsize = mbmi->sb_type;
const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
assert(bsize >= BLOCK_8X8 && bsize <= BLOCK_LARGEST);
const int block_palette_idx = bsize - BLOCK_8X8;
if (mbmi->mode == DC_PRED) {
const int n = pmi->palette_size[0];
int palette_y_mode_ctx = 0;
if (above_mi) {
palette_y_mode_ctx +=
(above_mi->mbmi.palette_mode_info.palette_size[0] > 0);
}
if (left_mi) {
palette_y_mode_ctx +=
(left_mi->mbmi.palette_mode_info.palette_size[0] > 0);
}
update_cdf(fc->palette_y_mode_cdf[block_palette_idx][palette_y_mode_ctx],
n > 0, 2);
}
if (mbmi->uv_mode == UV_DC_PRED) {
const int n = pmi->palette_size[1];
const int palette_uv_mode_ctx = (pmi->palette_size[0] > 0);
update_cdf(fc->palette_uv_mode_cdf[palette_uv_mode_ctx], n > 0, 2);
}
}
#endif
#if CONFIG_VAR_TX
static void update_txfm_count(MACROBLOCK *x, MACROBLOCKD *xd,
FRAME_COUNTS *counts, TX_SIZE tx_size, int depth,
......@@ -5105,6 +5142,10 @@ static void encode_superblock(const AV1_COMP *const cpi, ThreadData *td,
if (!dry_run) {
sum_intra_stats(td->counts, xd, mi, xd->above_mi, xd->left_mi,
frame_is_intra_only(cm), mi_row, mi_col);
#if CONFIG_NEW_MULTISYMBOL
if (av1_allow_palette(cm->allow_screen_content_tools, bsize))
update_palette_cdf(xd, mi);
#endif
}
// TODO(anybody) : remove this flag when PVQ supports pallete coding tool
......
......@@ -152,6 +152,29 @@ void av1_fill_mode_rates(AV1_COMMON *const cm, MACROBLOCK *x,
fc->palette_y_size_cdf[i], NULL);
av1_cost_tokens_from_cdf(x->palette_uv_size_cost[i],
fc->palette_uv_size_cdf[i], NULL);
for (j = 0; j < PALETTE_Y_MODE_CONTEXTS; ++j) {
#if CONFIG_NEW_MULTISYMBOL
av1_cost_tokens_from_cdf(x->palette_y_mode_cost[i][j],
fc->palette_y_mode_cdf[i][j], NULL);
#else
x->palette_y_mode_cost[i][j][0] =
av1_cost_bit(av1_default_palette_y_mode_prob[i][j], 0);
x->palette_y_mode_cost[i][j][1] =
av1_cost_bit(av1_default_palette_y_mode_prob[i][j], 1);
#endif
}
}
for (i = 0; i < PALETTE_UV_MODE_CONTEXTS; ++i) {
#if CONFIG_NEW_MULTISYMBOL
av1_cost_tokens_from_cdf(x->palette_uv_mode_cost[i],
fc->palette_uv_mode_cdf[i], NULL);
#else
x->palette_uv_mode_cost[i][0] =
av1_cost_bit(av1_default_palette_uv_mode_prob[i], 0);
x->palette_uv_mode_cost[i][1] =
av1_cost_bit(av1_default_palette_uv_mode_prob[i], 1);
#endif
}
for (i = 0; i < PALETTE_SIZES; ++i) {
......
......@@ -3366,9 +3366,7 @@ static int rd_pick_palette_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
dc_mode_cost +
x->palette_y_size_cost[bsize - BLOCK_8X8][k - PALETTE_MIN_SIZE] +
write_uniform_cost(k, color_map[0]) +
av1_cost_bit(
av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx],
1);
x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_ctx][1];
palette_mode_cost += av1_palette_color_cost_y(pmi,
#if CONFIG_PALETTE_DELTA_ENCODING
color_cache, n_cache,
......@@ -4503,9 +4501,7 @@ static int64_t rd_pick_intra_sby_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
}
if (try_palette && mbmi->mode == DC_PRED) {
this_rate +=
av1_cost_bit(av1_default_palette_y_mode_prob[bsize - BLOCK_8X8]
[palette_y_mode_ctx],
0);
x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_y_mode_ctx][0];
}
#if CONFIG_FILTER_INTRA
if (mbmi->mode == DC_PRED)
......@@ -6030,8 +6026,7 @@ static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
tokenonly_rd_stats.rate + dc_mode_cost +
x->palette_uv_size_cost[bsize - BLOCK_8X8][n - PALETTE_MIN_SIZE] +
write_uniform_cost(n, color_map[0]) +
av1_cost_bit(
av1_default_palette_uv_mode_prob[pmi->palette_size[0] > 0], 1);
x->palette_uv_mode_cost[pmi->palette_size[0] > 0][1];
this_rate += av1_palette_color_cost_uv(pmi,
#if CONFIG_PALETTE_DELTA_ENCODING
color_cache, n_cache,
......@@ -6469,8 +6464,7 @@ static int64_t rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
this_rate += av1_cost_bit(cpi->common.fc->filter_intra_probs[1], 0);
#endif // CONFIG_FILTER_INTRA
if (try_palette && mode == UV_DC_PRED)
this_rate += av1_cost_bit(
av1_default_palette_uv_mode_prob[pmi->palette_size[0] > 0], 0);
this_rate += x->palette_uv_mode_cost[pmi->palette_size[0] > 0][0];
#if CONFIG_PVQ
od_encode_rollback(&x->daala_enc, &buf);
......@@ -10364,8 +10358,7 @@ static void pick_filter_intra_interframe(
rate2 = rate_y + intra_mode_cost[mbmi->mode] + rate_uv +
x->intra_uv_mode_cost[mbmi->mode][mbmi->uv_mode];
if (try_palette && mbmi->mode == DC_PRED)
rate2 += av1_cost_bit(
av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx], 0);
rate2 += x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_ctx][0];
if (!xd->lossless[mbmi->segment_id]) {
// super_block_yrd above includes the cost of the tx_size in the
......@@ -11147,8 +11140,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
#endif // CONFIG_CB4X4
if (try_palette && mbmi->mode == DC_PRED) {
rate2 += av1_cost_bit(
av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx], 0);
rate2 += x->palette_y_mode_cost[bsize - BLOCK_8X8][palette_ctx][0];
}
if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment