Commit a8c1d85e authored by Jingning Han's avatar Jingning Han

Vectorize motion vector probability models

This commit converts the scalar motion vector probability model
into a vector (per-context) format so that entropy coding can be
conditioned on the predicted motion vector.

Change-Id: I09a17ed4d01efa49640c2882efbf78913b32556e
parent efc115f4
......@@ -73,7 +73,11 @@ typedef struct frame_contexts {
aom_prob comp_ref_prob[REF_CONTEXTS];
struct tx_probs tx_probs;
aom_prob skip_probs[SKIP_CONTEXTS];
#if CONFIG_REF_MV
nmv_context nmvc[NMV_CONTEXTS];
#else
nmv_context nmvc;
#endif
#if CONFIG_MISC_FIXES
struct segmentation_probs seg;
#endif
......@@ -108,7 +112,11 @@ typedef struct FRAME_COUNTS {
unsigned int comp_ref[REF_CONTEXTS][2];
struct tx_counts tx;
unsigned int skip[SKIP_CONTEXTS][2];
#if CONFIG_REF_MV
nmv_context_counts mv[NMV_CONTEXTS];
#else
nmv_context_counts mv;
#endif
#if CONFIG_MISC_FIXES
struct seg_counts seg;
#endif
......
......@@ -179,7 +179,45 @@ void av1_inc_mv(const MV *mv, nmv_context_counts *counts, const int usehp) {
/* End-of-frame adaptation of the motion-vector probability models.
 * Each probability in the active frame context (cm->fc) is re-estimated by
 * merging the previous frame context's probabilities with the symbol counts
 * accumulated while coding the current frame (cm->counts.mv).
 * With CONFIG_REF_MV one independent nmv_context is adapted per MV entropy
 * context (NMV_CONTEXTS of them); otherwise a single model is adapted.
 * allow_hp: nonzero when 1/8-pel (high-precision) MV components are coded;
 * only then are the hp/class0_hp probabilities adapted. */
void av1_adapt_mv_probs(AV1_COMMON *cm, int allow_hp) {
int i, j;
#if CONFIG_REF_MV
int idx;
/* Adapt every MV entropy context independently. */
for (idx = 0; idx < NMV_CONTEXTS; ++idx) {
nmv_context *fc = &cm->fc->nmvc[idx];
const nmv_context *pre_fc =
&cm->frame_contexts[cm->frame_context_idx].nmvc[idx];
const nmv_context_counts *counts = &cm->counts.mv[idx];
/* Joint (row/col zero-ness) distribution. */
aom_tree_merge_probs(av1_mv_joint_tree, pre_fc->joints, counts->joints,
fc->joints);
/* i = 0: row component, i = 1: column component. */
for (i = 0; i < 2; ++i) {
nmv_component *comp = &fc->comps[i];
const nmv_component *pre_comp = &pre_fc->comps[i];
const nmv_component_counts *c = &counts->comps[i];
comp->sign = mode_mv_merge_probs(pre_comp->sign, c->sign);
aom_tree_merge_probs(av1_mv_class_tree, pre_comp->classes, c->classes,
comp->classes);
aom_tree_merge_probs(av1_mv_class0_tree, pre_comp->class0, c->class0,
comp->class0);
/* Per-bit probabilities for the integer offset within a class. */
for (j = 0; j < MV_OFFSET_BITS; ++j)
comp->bits[j] = mode_mv_merge_probs(pre_comp->bits[j], c->bits[j]);
/* Fractional-pel distributions (class0 has a per-offset table). */
for (j = 0; j < CLASS0_SIZE; ++j)
aom_tree_merge_probs(av1_mv_fp_tree, pre_comp->class0_fp[j],
c->class0_fp[j], comp->class0_fp[j]);
aom_tree_merge_probs(av1_mv_fp_tree, pre_comp->fp, c->fp, comp->fp);
/* High-precision (1/8-pel) bits are only coded when allow_hp. */
if (allow_hp) {
comp->class0_hp = mode_mv_merge_probs(pre_comp->class0_hp,
c->class0_hp);
comp->hp = mode_mv_merge_probs(pre_comp->hp, c->hp);
}
}
}
#else
/* Single-model path: same merge sequence as above for one nmv_context. */
nmv_context *fc = &cm->fc->nmvc;
const nmv_context *pre_fc = &cm->frame_contexts[cm->frame_context_idx].nmvc;
const nmv_context_counts *counts = &cm->counts.mv;
/* NOTE(review): the diff view elides the middle of this branch here;
 * the visible tail mirrors the CONFIG_REF_MV loop body above. */
......@@ -212,6 +250,15 @@ void av1_adapt_mv_probs(AV1_COMMON *cm, int allow_hp) {
comp->hp = mode_mv_merge_probs(pre_comp->hp, c->hp);
}
}
#endif
}
/* Reset the motion-vector probability model(s) to the built-in defaults.
 * With CONFIG_REF_MV there is one independent nmv_context per MV entropy
 * context; all NMV_CONTEXTS of them are initialized from the same
 * default_nmv_context.  Without it, a single scalar model is reset.
 * (The stale pre-change one-line definition left by the diff rendering is
 * removed; two definitions of the same function would not compile.) */
void av1_init_mv_probs(AV1_COMMON *cm) {
#if CONFIG_REF_MV
  int i;
  for (i = 0; i < NMV_CONTEXTS; ++i)
    cm->fc->nmvc[i] = default_nmv_context;
#else
  cm->fc->nmvc = default_nmv_context;
#endif
}
......@@ -132,6 +132,8 @@ typedef uint8_t PREDICTION_MODE;
#define SKIP_CONTEXTS 3
#if CONFIG_REF_MV
#define NMV_CONTEXTS 2
#define NEWMV_MODE_CONTEXTS 7
#define ZEROMV_MODE_CONTEXTS 2
#define REFMV_MODE_CONTEXTS 9
......
......@@ -324,6 +324,19 @@ static INLINE void lower_mv_precision(MV *mv, int allow_hp) {
}
#if CONFIG_REF_MV
/* Select the motion-vector entropy-coding context for a block.
 * Returns 1 when the top candidate on the reference MV stack is
 * high-confidence (weight >= REF_CAT_LEVEL) and its coded MV lies within
 * one full pel (|delta| < 8 in 1/8-pel units) of its predicted MV on both
 * axes; returns 0 otherwise.
 * ref_mv_count: number of valid entries in ref_mv_stack.
 * Fix: test ref_mv_count first so ref_mv_stack[0] is never read when the
 * stack is empty (the original read the possibly-uninitialized entry
 * before checking the count). */
static INLINE int av1_nmv_ctx(const uint8_t ref_mv_count,
                              const CANDIDATE_MV *ref_mv_stack) {
  if (ref_mv_count > 0 &&
      ref_mv_stack[0].weight >= REF_CAT_LEVEL) {
    if (abs(ref_mv_stack[0].this_mv.as_mv.row -
            ref_mv_stack[0].pred_mv.as_mv.row) < 8 &&
        abs(ref_mv_stack[0].this_mv.as_mv.col -
            ref_mv_stack[0].pred_mv.as_mv.col) < 8)
      return 1;
  }
  return 0;
}
static int8_t av1_ref_frame_type(const MV_REFERENCE_FRAME *const rf) {
if (rf[1] > INTRA_FRAME)
return rf[0] + ALTREF_FRAME;
......
......@@ -418,6 +418,39 @@ void av1_accumulate_frame_counts(AV1_COMMON *cm, FRAME_COUNTS *counts,
for (i = 0; i < SKIP_CONTEXTS; i++)
for (j = 0; j < 2; j++) cm->counts.skip[i][j] += counts->skip[i][j];
#if CONFIG_REF_MV
for (m = 0; m < NMV_CONTEXTS; ++m) {
for (i = 0; i < MV_JOINTS; i++)
cm->counts.mv[m].joints[i] += counts->mv[m].joints[i];
for (k = 0; k < 2; k++) {
nmv_component_counts *comps = &cm->counts.mv[m].comps[k];
nmv_component_counts *comps_t = &counts->mv[m].comps[k];
for (i = 0; i < 2; i++) {
comps->sign[i] += comps_t->sign[i];
comps->class0_hp[i] += comps_t->class0_hp[i];
comps->hp[i] += comps_t->hp[i];
}
for (i = 0; i < MV_CLASSES; i++)
comps->classes[i] += comps_t->classes[i];
for (i = 0; i < CLASS0_SIZE; i++) {
comps->class0[i] += comps_t->class0[i];
for (j = 0; j < MV_FP_SIZE; j++)
comps->class0_fp[i][j] += comps_t->class0_fp[i][j];
}
for (i = 0; i < MV_OFFSET_BITS; i++)
for (j = 0; j < 2; j++)
comps->bits[i][j] += comps_t->bits[i][j];
for (i = 0; i < MV_FP_SIZE; i++)
comps->fp[i] += comps_t->fp[i];
}
}
#else
for (i = 0; i < MV_JOINTS; i++)
cm->counts.mv.joints[i] += counts->mv.joints[i];
......@@ -444,6 +477,7 @@ void av1_accumulate_frame_counts(AV1_COMMON *cm, FRAME_COUNTS *counts,
for (i = 0; i < MV_FP_SIZE; i++) comps->fp[i] += comps_t->fp[i];
}
#endif
for (i = 0; i < EXT_TX_SIZES; i++) {
int j;
......
......@@ -2237,8 +2237,9 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data,
av1_diff_update_prob(&r, &cm->kf_y_prob[k][j][i]);
#endif
} else {
#if !CONFIG_REF_MV
nmv_context *const nmvc = &fc->nmvc;
#endif
read_inter_mode_probs(fc, &r);
if (cm->interp_filter == SWITCHABLE) read_switchable_interp_probs(fc, &r);
......@@ -2263,7 +2264,12 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data,
av1_diff_update_prob(&r, &fc->partition_prob[j][i]);
#endif
#if CONFIG_REF_MV
for (i = 0; i < NMV_CONTEXTS; ++i)
read_mv_probs(&fc->nmvc[i], cm->allow_high_precision_mv, &r);
#else
read_mv_probs(nmvc, cm->allow_high_precision_mv, &r);
#endif
read_ext_tx_probs(fc, &r);
}
......@@ -2303,7 +2309,14 @@ static void debug_check_frame_counts(const AV1_COMMON *const cm) {
sizeof(cm->counts.comp_ref)));
assert(!memcmp(&cm->counts.tx, &zero_counts.tx, sizeof(cm->counts.tx)));
assert(!memcmp(cm->counts.skip, zero_counts.skip, sizeof(cm->counts.skip)));
#if CONFIG_REF_MV
assert(!memcmp(&cm->counts.mv[0], &zero_counts.mv[0],
sizeof(cm->counts.mv[0])));
assert(!memcmp(&cm->counts.mv[1], &zero_counts.mv[1],
sizeof(cm->counts.mv[0])));
#else
assert(!memcmp(&cm->counts.mv, &zero_counts.mv, sizeof(cm->counts.mv)));
#endif
assert(!memcmp(cm->counts.intra_ext_tx, zero_counts.intra_ext_tx,
sizeof(cm->counts.intra_ext_tx)));
assert(!memcmp(cm->counts.inter_ext_tx, zero_counts.inter_ext_tx,
......
......@@ -556,10 +556,21 @@ static INLINE int assign_mv(AV1_COMMON *cm, MACROBLOCKD *xd,
switch (mode) {
case NEWMV: {
FRAME_COUNTS *counts = xd->counts;
#if !CONFIG_REF_MV
nmv_context_counts *const mv_counts = counts ? &counts->mv : NULL;
#endif
for (i = 0; i < 1 + is_compound; ++i) {
#if CONFIG_REF_MV
int nmv_ctx = av1_nmv_ctx(xd->ref_mv_count[mbmi->ref_frame[i]],
xd->ref_mv_stack[mbmi->ref_frame[i]]);
nmv_context_counts *const mv_counts =
counts ? &counts->mv[nmv_ctx] : NULL;
read_mv(r, &mv[i].as_mv, &ref_mv[i].as_mv, &cm->fc->nmvc[nmv_ctx],
mv_counts, allow_hp);
#else
read_mv(r, &mv[i].as_mv, &ref_mv[i].as_mv, &cm->fc->nmvc, mv_counts,
allow_hp);
#endif
ret = ret && is_mv_valid(&mv[i].as_mv);
#if CONFIG_REF_MV
pred_mv[i].as_int = ref_mv[i].as_int;
......@@ -772,7 +783,10 @@ static void read_inter_block_mode_info(AV1Decoder *const pbi,
}
mi->mbmi.mode = b_mode;
#if CONFIG_REF_MV
mbmi->pred_mv[0].as_int = mi->bmi[3].pred_mv[0].as_int;
mbmi->pred_mv[1].as_int = mi->bmi[3].pred_mv[1].as_int;
#endif
mbmi->mv[0].as_int = mi->bmi[3].as_mv[0].as_int;
mbmi->mv[1].as_int = mi->bmi[3].as_mv[1].as_int;
} else {
......
......@@ -403,7 +403,9 @@ static void write_ref_frames(const AV1_COMMON *cm, const MACROBLOCKD *xd,
static void pack_inter_mode_mvs(AV1_COMP *cpi, const MODE_INFO *mi,
aom_writer *w) {
AV1_COMMON *const cm = &cpi->common;
#if !CONFIG_REF_MV
const nmv_context *nmvc = &cm->fc->nmvc;
#endif
const MACROBLOCK *const x = &cpi->td.mb;
const MACROBLOCKD *const xd = &x->e_mbd;
const struct segmentation *const seg = &cm->seg;
......@@ -502,19 +504,33 @@ static void pack_inter_mode_mvs(AV1_COMP *cpi, const MODE_INFO *mi,
#endif
write_inter_mode(cm, w, b_mode, mode_ctx);
if (b_mode == NEWMV) {
for (ref = 0; ref < 1 + is_compound; ++ref)
for (ref = 0; ref < 1 + is_compound; ++ref) {
#if CONFIG_REF_MV
int nmv_ctx =
av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[ref]],
mbmi_ext->ref_mv_stack[mbmi->ref_frame[ref]]);
const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
#endif
av1_encode_mv(cpi, w, &mi->bmi[j].as_mv[ref].as_mv,
&mbmi_ext->ref_mvs[mbmi->ref_frame[ref]][0].as_mv,
nmvc, allow_hp);
}
}
}
}
} else {
if (mode == NEWMV) {
for (ref = 0; ref < 1 + is_compound; ++ref)
av1_encode_mv(cpi, w, &mbmi->mv[ref].as_mv,
for (ref = 0; ref < 1 + is_compound; ++ref) {
#if CONFIG_REF_MV
int nmv_ctx =
av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[ref]],
mbmi_ext->ref_mv_stack[mbmi->ref_frame[ref]]);
const nmv_context *nmvc = &cm->fc->nmvc[nmv_ctx];
#endif
av1_encode_mv(cpi, w, &mbmi->mv[ref].as_mv,
&mbmi_ext->ref_mvs[mbmi->ref_frame[ref]][0].as_mv,
nmvc, allow_hp);
}
}
}
}
......@@ -1555,7 +1571,11 @@ static size_t write_compressed_header(AV1_COMP *cpi, uint8_t *data) {
#endif
av1_write_nmv_probs(cm, cm->allow_high_precision_mv, &header_bc,
&counts->mv);
#if CONFIG_REF_MV
counts->mv);
#else
&counts->mv);
#endif
update_ext_tx_probs(cm, &header_bc);
}
......
......@@ -165,9 +165,49 @@ static void write_mv_update(const aom_tree_index *tree,
}
/* Write differential updates of the MV probability models into the
 * compressed header.  With CONFIG_REF_MV, nmv_counts points to an array of
 * NMV_CONTEXTS count structs and every context's model is updated in turn;
 * otherwise it points to a single struct.
 * usehp: nonzero when high-precision (1/8-pel) MV bits are in use, gating
 * the hp/class0_hp probability updates. */
void av1_write_nmv_probs(AV1_COMMON *cm, int usehp, aom_writer *w,
/* NOTE(review): the next line is the stale pre-change parameter as rendered
 * by the diff; the line after it is the current signature. */
nmv_context_counts *const counts) {
nmv_context_counts *const nmv_counts) {
int i, j;
#if CONFIG_REF_MV
int nmv_ctx = 0;
/* Update the model of every MV entropy context independently. */
for (nmv_ctx = 0; nmv_ctx < NMV_CONTEXTS; ++nmv_ctx) {
nmv_context *const mvc = &cm->fc->nmvc[nmv_ctx];
nmv_context_counts *const counts = &nmv_counts[nmv_ctx];
write_mv_update(av1_mv_joint_tree, mvc->joints, counts->joints,
MV_JOINTS, w);
/* i = 0: row component, i = 1: column component. */
for (i = 0; i < 2; ++i) {
nmv_component *comp = &mvc->comps[i];
nmv_component_counts *comp_counts = &counts->comps[i];
update_mv(w, comp_counts->sign, &comp->sign, MV_UPDATE_PROB);
write_mv_update(av1_mv_class_tree, comp->classes, comp_counts->classes,
MV_CLASSES, w);
write_mv_update(av1_mv_class0_tree, comp->class0, comp_counts->class0,
CLASS0_SIZE, w);
for (j = 0; j < MV_OFFSET_BITS; ++j)
update_mv(w, comp_counts->bits[j], &comp->bits[j], MV_UPDATE_PROB);
}
/* Fractional-pel trees for both components. */
for (i = 0; i < 2; ++i) {
for (j = 0; j < CLASS0_SIZE; ++j)
write_mv_update(av1_mv_fp_tree, mvc->comps[i].class0_fp[j],
counts->comps[i].class0_fp[j], MV_FP_SIZE, w);
write_mv_update(av1_mv_fp_tree, mvc->comps[i].fp, counts->comps[i].fp,
MV_FP_SIZE, w);
}
/* High-precision bits are only signaled when usehp. */
if (usehp) {
for (i = 0; i < 2; ++i) {
update_mv(w, counts->comps[i].class0_hp, &mvc->comps[i].class0_hp,
MV_UPDATE_PROB);
update_mv(w, counts->comps[i].hp, &mvc->comps[i].hp, MV_UPDATE_PROB);
}
}
}
#else
/* Single-model path; nmv_counts is a pointer to one count struct. */
nmv_context *const mvc = &cm->fc->nmvc;
nmv_context_counts *const counts = nmv_counts;
write_mv_update(av1_mv_joint_tree, mvc->joints, counts->joints, MV_JOINTS,
w);
/* NOTE(review): the diff view elides the middle of this branch here. */
......@@ -201,8 +241,10 @@ void av1_write_nmv_probs(AV1_COMMON *cm, int usehp, aom_writer *w,
update_mv(w, counts->comps[i].hp, &mvc->comps[i].hp, MV_UPDATE_PROB);
}
}
#endif
}
void av1_encode_mv(AV1_COMP *cpi, aom_writer *w, const MV *mv, const MV *ref,
const nmv_context *mvctx, int usehp) {
const MV diff = { mv->row - ref->row, mv->col - ref->col };
......@@ -233,13 +275,20 @@ void av1_build_nmv_cost_table(int *mvjoint, int *mvcost[2],
}
/* Accumulate MV entropy-coding statistics for every motion vector coded by
 * this block (one, or two for compound prediction).  The counted symbol is
 * the difference between the coded MV and its reference MV.
 * nmv_counts: with CONFIG_REF_MV, an array of NMV_CONTEXTS count structs
 * indexed by the per-reference MV context; otherwise a single struct.
 * (The stale pre-change parameter line left by the diff rendering is
 * removed; a duplicated declarator line would not compile.) */
static void inc_mvs(const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT *mbmi_ext,
                    const int_mv mvs[2], nmv_context_counts *nmv_counts) {
  int i;
  for (i = 0; i < 1 + has_second_ref(mbmi); ++i) {
    const MV *ref = &mbmi_ext->ref_mvs[mbmi->ref_frame[i]][0].as_mv;
    const MV diff = { mvs[i].as_mv.row - ref->row,
                      mvs[i].as_mv.col - ref->col };
#if CONFIG_REF_MV
    /* Route the counts into the struct for this MV's entropy context. */
    int nmv_ctx = av1_nmv_ctx(mbmi_ext->ref_mv_count[mbmi->ref_frame[i]],
                              mbmi_ext->ref_mv_stack[mbmi->ref_frame[i]]);
    nmv_context_counts *counts = &nmv_counts[nmv_ctx];
#else
    nmv_context_counts *counts = nmv_counts;
#endif
    av1_inc_mv(&diff, counts, av1_use_mv_hp(ref));
  }
}
......@@ -259,10 +308,21 @@ void av1_update_mv_count(ThreadData *td) {
for (idx = 0; idx < 2; idx += num_4x4_w) {
const int i = idy * 2 + idx;
if (mi->bmi[i].as_mode == NEWMV)
inc_mvs(mbmi, mbmi_ext, mi->bmi[i].as_mv, &td->counts->mv);
inc_mvs(mbmi, mbmi_ext, mi->bmi[i].as_mv,
#if CONFIG_REF_MV
td->counts->mv);
#else
&td->counts->mv);
#endif
}
}
} else {
if (mbmi->mode == NEWMV) inc_mvs(mbmi, mbmi_ext, mbmi->mv, &td->counts->mv);
if (mbmi->mode == NEWMV)
inc_mvs(mbmi, mbmi_ext, mbmi->mv,
#if CONFIG_REF_MV
td->counts->mv);
#else
&td->counts->mv);
#endif
}
}
......@@ -290,10 +290,19 @@ void av1_initialize_rd_consts(AV1_COMP *cpi) {
fill_mode_costs(cpi);
if (!frame_is_intra_only(cm)) {
av1_build_nmv_cost_table(
x->nmvjointcost,
cm->allow_high_precision_mv ? x->nmvcost_hp : x->nmvcost, &cm->fc->nmvc,
cm->allow_high_precision_mv);
#if CONFIG_REF_MV
int nmv_ctx = 0;
av1_build_nmv_cost_table(x->nmvjointcost,
cm->allow_high_precision_mv ? x->nmvcost_hp
: x->nmvcost,
&cm->fc->nmvc[nmv_ctx],
cm->allow_high_precision_mv);
#else
av1_build_nmv_cost_table(x->nmvjointcost,
cm->allow_high_precision_mv ? x->nmvcost_hp
: x->nmvcost,
&cm->fc->nmvc, cm->allow_high_precision_mv);
#endif
#if CONFIG_REF_MV
for (i = 0; i < NEWMV_MODE_CONTEXTS; ++i) {
......
......@@ -2471,13 +2471,6 @@ static int64_t handle_inter_mode(
if (mv_check_bounds(x, &cur_mv[i].as_mv)) return INT64_MAX;
mbmi->mv[i].as_int = cur_mv[i].as_int;
#if CONFIG_REF_MV
if (this_mode != NEWMV)
mbmi->pred_mv[i].as_int = mbmi->mv[i].as_int;
else
mbmi->pred_mv[i].as_int = mbmi_ext->ref_mvs[refs[i]][0].as_int;
#endif
}
#if CONFIG_REF_MV
......@@ -3680,6 +3673,15 @@ void av1_rd_pick_inter_mode_sb(AV1_COMP *cpi, TileDataEnc *tile_data,
*mbmi = best_mbmode;
x->skip |= best_skip2;
#if CONFIG_REF_MV
for (i = 0; i < 1 + has_second_ref(mbmi); ++i) {
if (mbmi->mode != NEWMV)
mbmi->pred_mv[i].as_int = mbmi->mv[i].as_int;
else
mbmi->pred_mv[i].as_int = mbmi_ext->ref_mvs[mbmi->ref_frame[i]][0].as_int;
}
#endif
for (i = 0; i < REFERENCE_MODES; ++i) {
if (best_pred_rd[i] == INT64_MAX)
best_pred_diff[i] = INT_MIN;
......@@ -4368,6 +4370,10 @@ void av1_rd_pick_inter_mode_sub8x8(AV1_COMP *cpi, TileDataEnc *tile_data,
for (i = 0; i < 4; ++i)
memcpy(&xd->mi[0]->bmi[i], &best_bmodes[i], sizeof(b_mode_info));
#if CONFIG_REF_MV
mbmi->pred_mv[0].as_int = xd->mi[0]->bmi[3].pred_mv[0].as_int;
mbmi->pred_mv[1].as_int = xd->mi[0]->bmi[3].pred_mv[1].as_int;
#endif
mbmi->mv[0].as_int = xd->mi[0]->bmi[3].as_mv[0].as_int;
mbmi->mv[1].as_int = xd->mi[0]->bmi[3].as_mv[1].as_int;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment