Commit 5a88172c authored by Cheng Chen's avatar Cheng Chen

Change comp_group index context and save sending comp_group

Extend context model for comp_group_idx.
Save sending comp_group_idx when masked_compound is not allowed.

Change-Id: Ia7ae53958c9e1c8fe07be4b14a425d9b8648082d
parent 2ef24ea2
......@@ -1504,16 +1504,20 @@ static const aom_cdf_prob
static const aom_cdf_prob
default_comp_group_idx_cdfs[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)] = {
{ AOM_ICDF(29491), AOM_ICDF(32768), 0 },
{ AOM_ICDF(24576), AOM_ICDF(32768), 0 },
{ AOM_ICDF(16384), AOM_ICDF(32768), 0 },
{ AOM_ICDF(8192), AOM_ICDF(32768), 0 },
{ AOM_ICDF(24576), AOM_ICDF(32768), 0 },
{ AOM_ICDF(16384), AOM_ICDF(32768), 0 },
{ AOM_ICDF(13107), AOM_ICDF(32768), 0 },
{ AOM_ICDF(13107), AOM_ICDF(32768), 0 },
};
static const aom_prob default_compound_idx_probs[COMP_INDEX_CONTEXTS] = {
192, 128, 64, 192, 128, 64
};
static const aom_prob default_comp_group_idx_probs[COMP_GROUP_IDX_CONTEXTS] = {
192, 128, 64
192, 128, 64, 192, 128, 64, 128
};
#endif // CONFIG_JNT_COMP
......
......@@ -592,7 +592,7 @@ typedef enum ATTRIBUTE_PACKED {
#if CONFIG_JNT_COMP
#define COMP_INDEX_CONTEXTS 6
#define COMP_GROUP_IDX_CONTEXTS 3
#define COMP_GROUP_IDX_CONTEXTS 7
#endif // CONFIG_JNT_COMP
#define NMV_CONTEXTS 3
......
......@@ -133,14 +133,14 @@ static INLINE int get_comp_group_idx_context(const MACROBLOCKD *xd) {
if (has_second_ref(above_mbmi))
above_ctx = above_mbmi->comp_group_idx;
else if (above_mbmi->ref_frame[0] == ALTREF_FRAME)
above_ctx = 1;
above_ctx = 3;
}
if (left_mi) {
const MB_MODE_INFO *left_mbmi = &left_mi->mbmi;
if (has_second_ref(left_mbmi))
left_ctx = left_mbmi->comp_group_idx;
else if (left_mbmi->ref_frame[0] == ALTREF_FRAME)
left_ctx = 1;
left_ctx = 3;
}
return above_ctx + left_ctx;
......
......@@ -2189,12 +2189,19 @@ static void read_inter_block_mode_info(AV1Decoder *const pbi,
mbmi->interinter_compound_type = COMPOUND_AVERAGE;
// read idx to indicate current compound inter prediction mode group
int masked_compound_used = is_any_masked_compound_used(bsize);
masked_compound_used = masked_compound_used && cm->allow_masked_compound;
if (has_second_ref(mbmi)) {
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
mbmi->comp_group_idx = aom_read_symbol(
r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
if (xd->counts)
++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
if (masked_compound_used) {
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
mbmi->comp_group_idx = aom_read_symbol(
r, ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
if (xd->counts)
++xd->counts->comp_group_idx[ctx_comp_group_idx][mbmi->comp_group_idx];
} else {
mbmi->comp_group_idx = 0;
}
if (mbmi->comp_group_idx == 0) {
const int comp_index_ctx = get_comp_index_context(cm, xd);
......
......@@ -1454,10 +1454,17 @@ static void pack_inter_mode_mvs(AV1_COMP *cpi, const int mi_row,
// First write idx to indicate current compound inter prediction mode group
// Group A (0): jnt_comp, compound_average
// Group B (1): interintra, compound_segment, wedge
int masked_compound_used = is_any_masked_compound_used(bsize);
masked_compound_used = masked_compound_used && cm->allow_masked_compound;
if (has_second_ref(mbmi)) {
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
aom_write_symbol(w, mbmi->comp_group_idx,
ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
if (masked_compound_used) {
assert(mbmi->comp_group_idx == 0);
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
aom_write_symbol(w, mbmi->comp_group_idx,
ec_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2);
}
if (mbmi->comp_group_idx == 0) {
if (mbmi->compound_idx)
......
......@@ -5661,7 +5661,6 @@ static int cost_mv_ref(const MACROBLOCK *const x, PREDICTION_MODE mode,
}
}
#if !CONFIG_JNT_COMP
static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
COMPOUND_TYPE comp_type) {
(void)bsize;
......@@ -5672,7 +5671,6 @@ static int get_interinter_compound_type_bits(BLOCK_SIZE bsize,
default: assert(0); return 0;
}
}
#endif
typedef struct {
int eobs;
......@@ -8142,12 +8140,16 @@ static int64_t handle_inter_mode(
#if CONFIG_JNT_COMP
if (is_comp_pred) {
if (mbmi->compound_idx == 0) {
mbmi->comp_group_idx = 0;
const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
int masked_compound_used = is_any_masked_compound_used(bsize);
masked_compound_used = masked_compound_used && cm->allow_masked_compound;
if (masked_compound_used) {
const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
rd_stats->rate += x->comp_group_idx_cost[comp_group_idx_ctx][0];
}
const int comp_index_ctx = get_comp_index_context(cm, xd);
rd_stats->rate += x->comp_idx_cost[comp_index_ctx][0];
rd_stats->rate += x->comp_idx_cost[comp_index_ctx][mbmi->compound_idx];
}
}
#endif // CONFIG_JNT_COMP
......@@ -8316,26 +8318,27 @@ static int64_t handle_inter_mode(
best_rd_cur = INT64_MAX;
mbmi->interinter_compound_type = cur_type;
#if CONFIG_JNT_COMP
const int ctx_comp_group_idx = get_comp_group_idx_context(xd);
if (cur_type == COMPOUND_AVERAGE) {
mbmi->comp_group_idx = 0;
rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][0];
const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
int masked_type_cost = 0;
if (masked_compound_used) {
if (cur_type == COMPOUND_AVERAGE) {
masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
const int comp_index_ctx = get_comp_index_context(cm, xd);
rs2 += x->comp_idx_cost[comp_index_ctx][1];
} else {
mbmi->comp_group_idx = 1;
rs2 = x->comp_group_idx_cost[ctx_comp_group_idx][1];
int masked_type_cost = 0;
if (masked_compound_used) {
if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
masked_type_cost +=
x->compound_type_cost[bsize]
[mbmi->interinter_compound_type - 1];
const int comp_index_ctx = get_comp_index_context(cm, xd);
masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
} else {
masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
masked_type_cost +=
x->compound_type_cost[bsize][mbmi->interinter_compound_type - 1];
}
rs2 += masked_type_cost;
} else {
const int comp_index_ctx = get_comp_index_context(cm, xd);
masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
}
rs2 = av1_cost_literal(get_interinter_compound_type_bits(
bsize, mbmi->interinter_compound_type)) +
masked_type_cost;
#else
int masked_type_cost = 0;
if (masked_compound_used) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment