Commit b80020d4 authored by Deb Mukherjee's avatar Deb Mukherjee
Browse files

Refactoring motion search libs

The core motion estimation fucntions all return sad now consistently.
The only exception is vp9_full_pixel_diamond(), however the core diamond
and refining search routines called from vp9_full_pixel_diamond() also
return SAD. If variance of pred error + mv cost is desired it must be
calculated explicitly outside these functions. For very fast encoding,
hopefully this will eliminate some redundant computations.

Also suggests reimplementing FAST_HEX with the vp9_pattern_search
framework. It is not exactly the same as the existing FAST_HEX, but
performance is slightly better and speed is very similar. Enables
removing a lot of duplicate code.

Change-Id: I152736393438c25bdf7e96b37cbb8ce330f4f94a
parent be647f7b
......@@ -415,6 +415,8 @@ static void first_pass_motion_search(VP9_COMP *cpi, MACROBLOCK *x,
x->sadperbit16, &num00, &v_fn_ptr,
x->nmvjointcost,
x->mvcost, ref_mv);
if (tmp_err < INT_MAX)
tmp_err = vp9_get_mvpred_var(x, &tmp_mv, ref_mv, &v_fn_ptr, 1);
if (tmp_err < INT_MAX - new_mv_mode_penalty)
tmp_err += new_mv_mode_penalty;
......@@ -439,6 +441,8 @@ static void first_pass_motion_search(VP9_COMP *cpi, MACROBLOCK *x,
&num00, &v_fn_ptr,
x->nmvjointcost,
x->mvcost, ref_mv);
if (tmp_err < INT_MAX)
tmp_err = vp9_get_mvpred_var(x, &tmp_mv, ref_mv, &v_fn_ptr, 1);
if (tmp_err < INT_MAX - new_mv_mode_penalty)
tmp_err += new_mv_mode_penalty;
......
......@@ -728,16 +728,52 @@ static int vp9_pattern_search(const MACROBLOCK *x,
best_mv->col;
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad == INT_MAX)
return INT_MAX;
return bestsad;
}
return vfp->vf(what, what_stride, this_offset, in_what_stride,
(unsigned int *)&bestsad) +
use_mvcost ? mv_err_cost(&this_mv, center_mv,
x->nmvjointcost, x->mvcost, x->errorperbit)
: 0;
int vp9_get_mvpred_var(const MACROBLOCK *x,
MV *best_mv,
const MV *center_mv,
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost) {
unsigned int bestsad;
MV this_mv;
const MACROBLOCKD *const xd = &x->e_mbd;
const uint8_t *what = x->plane[0].src.buf;
const int what_stride = x->plane[0].src.stride;
const int in_what_stride = xd->plane[0].pre[0].stride;
const uint8_t *base_offset = xd->plane[0].pre[0].buf;
const uint8_t *this_offset = base_offset + (best_mv->row * in_what_stride) +
best_mv->col;
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
return vfp->vf(what, what_stride, this_offset, in_what_stride, &bestsad) +
(use_mvcost ? mv_err_cost(&this_mv, center_mv, x->nmvjointcost,
x->mvcost, x->errorperbit) : 0);
}
int vp9_get_mvpred_av_var(const MACROBLOCK *x,
MV *best_mv,
const MV *center_mv,
const uint8_t *second_pred,
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost) {
unsigned int bestsad;
MV this_mv;
const MACROBLOCKD *const xd = &x->e_mbd;
const uint8_t *what = x->plane[0].src.buf;
const int what_stride = x->plane[0].src.stride;
const int in_what_stride = xd->plane[0].pre[0].stride;
const uint8_t *base_offset = xd->plane[0].pre[0].buf;
const uint8_t *this_offset = base_offset + (best_mv->row * in_what_stride) +
best_mv->col;
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
return vfp->svaf(this_offset, in_what_stride, 0, 0, what, what_stride,
&bestsad, second_pred) +
(use_mvcost ? mv_err_cost(&this_mv, center_mv, x->nmvjointcost,
x->mvcost, x->errorperbit) : 0);
}
int vp9_hex_search(const MACROBLOCK *x,
MV *ref_mv,
......@@ -855,182 +891,18 @@ int vp9_square_search(const MACROBLOCK *x,
square_num_candidates, square_candidates);
};
// Number of candidates in first hex search
#define FIRST_HEX_CANDIDATES 6
// Index of previous hex search's best match
#define PRE_BEST_CANDIDATE 6
// Number of candidates in following hex search
#define NEXT_HEX_CANDIDATES 3
// Number of candidates in refining search
#define REFINE_CANDIDATES 4
int vp9_fast_hex_search(const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int sad_per_bit,
int do_init_search, // must be zero for fast_hex
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost,
const MV *center_mv,
MV *best_mv) {
const MACROBLOCKD* const xd = &x->e_mbd;
static const MV hex[FIRST_HEX_CANDIDATES] = {
{ -1, -2}, {1, -2}, {2, 0}, {1, 2}, { -1, 2}, { -2, 0}
};
static const MV next_chkpts[PRE_BEST_CANDIDATE][NEXT_HEX_CANDIDATES] = {
{{ -2, 0}, { -1, -2}, {1, -2}},
{{ -1, -2}, {1, -2}, {2, 0}},
{{1, -2}, {2, 0}, {1, 2}},
{{2, 0}, {1, 2}, { -1, 2}},
{{1, 2}, { -1, 2}, { -2, 0}},
{{ -1, 2}, { -2, 0}, { -1, -2}}
};
static const MV neighbors[REFINE_CANDIDATES] = {
{0, -1}, { -1, 0}, {1, 0}, {0, 1}
};
int i, j;
const uint8_t *what = x->plane[0].src.buf;
const int what_stride = x->plane[0].src.stride;
const int in_what_stride = xd->plane[0].pre[0].stride;
int br, bc;
MV this_mv;
unsigned int bestsad = 0x7fffffff;
unsigned int thissad;
const uint8_t *base_offset;
const uint8_t *this_offset;
int k = -1;
int best_site = -1;
const int max_hex_search = 512;
const int max_dia_search = 32;
const int *mvjsadcost = x->nmvjointsadcost;
int *mvsadcost[2] = {x->nmvsadcost[0], x->nmvsadcost[1]};
const MV fcenter_mv = {center_mv->row >> 3, center_mv->col >> 3};
// Adjust ref_mv to make sure it is within MV range
clamp_mv(ref_mv, x->mv_col_min, x->mv_col_max, x->mv_row_min, x->mv_row_max);
br = ref_mv->row;
bc = ref_mv->col;
// Check the start point
base_offset = xd->plane[0].pre[0].buf;
this_offset = base_offset + (br * in_what_stride) + bc;
this_mv.row = br;
this_mv.col = bc;
bestsad = vfp->sdf(what, what_stride, this_offset, in_what_stride, 0x7fffffff)
+ mvsad_err_cost(&this_mv, &fcenter_mv, mvjsadcost, mvsadcost,
sad_per_bit);
// Initial 6-point hex search
if (check_bounds(x, br, bc, 2)) {
for (i = 0; i < FIRST_HEX_CANDIDATES; i++) {
this_mv.row = br + hex[i].row;
this_mv.col = bc + hex[i].col;
this_offset = base_offset + (this_mv.row * in_what_stride) + this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
} else {
for (i = 0; i < FIRST_HEX_CANDIDATES; i++) {
this_mv.row = br + hex[i].row;
this_mv.col = bc + hex[i].col;
if (!is_mv_in(x, &this_mv))
continue;
this_offset = base_offset + (this_mv.row * in_what_stride) + this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
}
// Continue hex search if we find a better match in first round
if (best_site != -1) {
br += hex[best_site].row;
bc += hex[best_site].col;
k = best_site;
// Allow search covering maximum MV range
for (j = 1; j < max_hex_search; j++) {
best_site = -1;
if (check_bounds(x, br, bc, 2)) {
for (i = 0; i < 3; i++) {
this_mv.row = br + next_chkpts[k][i].row;
this_mv.col = bc + next_chkpts[k][i].col;
this_offset = base_offset + (this_mv.row * in_what_stride) +
this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
} else {
for (i = 0; i < 3; i++) {
this_mv.row = br + next_chkpts[k][i].row;
this_mv.col = bc + next_chkpts[k][i].col;
if (!is_mv_in(x, &this_mv))
continue;
this_offset = base_offset + (this_mv.row * in_what_stride) +
this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
}
if (best_site == -1) {
break;
} else {
br += next_chkpts[k][best_site].row;
bc += next_chkpts[k][best_site].col;
k += 5 + best_site;
if (k >= 12) k -= 12;
else if (k >= 6) k -= 6;
}
}
}
// Check 4 1-away neighbors
for (j = 0; j < max_dia_search; j++) {
best_site = -1;
if (check_bounds(x, br, bc, 1)) {
for (i = 0; i < REFINE_CANDIDATES; i++) {
this_mv.row = br + neighbors[i].row;
this_mv.col = bc + neighbors[i].col;
this_offset = base_offset + (this_mv.row * in_what_stride) +
this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
} else {
for (i = 0; i < REFINE_CANDIDATES; i++) {
this_mv.row = br + neighbors[i].row;
this_mv.col = bc + neighbors[i].col;
if (!is_mv_in(x, &this_mv))
continue;
this_offset = base_offset + (this_mv.row * in_what_stride) +
this_mv.col;
thissad = vfp->sdf(what, what_stride, this_offset, in_what_stride,
bestsad);
CHECK_BETTER
}
}
if (best_site == -1) {
break;
} else {
br += neighbors[best_site].row;
bc += neighbors[best_site].col;
}
}
best_mv->row = br;
best_mv->col = bc;
return bestsad;
return vp9_hex_search(x, ref_mv, MAX(MAX_MVSEARCH_STEPS - 2, search_param),
sad_per_bit, do_init_search, vfp, use_mvcost,
center_mv, best_mv);
}
#undef CHECK_BETTER
......@@ -1045,8 +917,6 @@ int vp9_full_range_search_c(const MACROBLOCK *x, MV *ref_mv, MV *best_mv,
const int what_stride = x->plane[0].src.stride;
const uint8_t *in_what;
const int in_what_stride = xd->plane[0].pre[0].stride;
const uint8_t *best_address;
MV this_mv;
unsigned int bestsad = INT_MAX;
......@@ -1076,7 +946,6 @@ int vp9_full_range_search_c(const MACROBLOCK *x, MV *ref_mv, MV *best_mv,
// Work out the start point for the search
in_what = xd->plane[0].pre[0].buf + ref_row * in_what_stride + ref_col;
best_address = in_what;
// Check the starting position
bestsad = fn_ptr->sdf(what, what_stride, in_what, in_what_stride, 0x7fffffff)
......@@ -1134,20 +1003,9 @@ int vp9_full_range_search_c(const MACROBLOCK *x, MV *ref_mv, MV *best_mv,
}
}
}
best_mv->row += best_tr;
best_mv->col += best_tc;
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad == INT_MAX)
return INT_MAX;
return fn_ptr->vf(what, what_stride, best_address, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
return bestsad;
}
int vp9_diamond_search_sad_c(const MACROBLOCK *x,
......@@ -1272,17 +1130,7 @@ int vp9_diamond_search_sad_c(const MACROBLOCK *x,
(*num00)++;
}
}
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad == INT_MAX)
return INT_MAX;
return fn_ptr->vf(what, what_stride, best_address, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
return bestsad;
}
int vp9_diamond_search_sadx4(const MACROBLOCK *x,
......@@ -1448,24 +1296,14 @@ int vp9_diamond_search_sadx4(const MACROBLOCK *x,
(*num00)++;
}
}
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad == INT_MAX)
return INT_MAX;
return fn_ptr->vf(what, what_stride, best_address, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
return bestsad;
}
/* do_refine: If last step (1-away) of n-step search doesn't pick the center
point as the best match, we will do a final 1-away diamond
refining search */
int vp9_full_pixel_diamond(VP9_COMP *cpi, MACROBLOCK *x,
int vp9_full_pixel_diamond(const VP9_COMP *cpi, MACROBLOCK *x,
MV *mvp_full, int step_param,
int sadpb, int further_steps, int do_refine,
const vp9_variance_fn_ptr_t *fn_ptr,
......@@ -1476,6 +1314,8 @@ int vp9_full_pixel_diamond(VP9_COMP *cpi, MACROBLOCK *x,
step_param, sadpb, &n,
fn_ptr, x->nmvjointcost,
x->mvcost, ref_mv);
if (bestsme < INT_MAX)
bestsme = vp9_get_mvpred_var(x, &temp_mv, ref_mv, fn_ptr, 1);
*dst_mv = temp_mv;
// If there won't be more n-step search, check to see if refining search is
......@@ -1493,6 +1333,8 @@ int vp9_full_pixel_diamond(VP9_COMP *cpi, MACROBLOCK *x,
step_param + n, sadpb, &num00,
fn_ptr, x->nmvjointcost, x->mvcost,
ref_mv);
if (thissme < INT_MAX)
thissme = vp9_get_mvpred_var(x, &temp_mv, ref_mv, fn_ptr, 1);
// check to see if refining search is needed.
if (num00 > further_steps - n)
......@@ -1512,12 +1354,13 @@ int vp9_full_pixel_diamond(VP9_COMP *cpi, MACROBLOCK *x,
thissme = cpi->refining_search_sad(x, &best_mv, sadpb, search_range,
fn_ptr, x->nmvjointcost, x->mvcost,
ref_mv);
if (thissme < INT_MAX)
thissme = vp9_get_mvpred_var(x, &best_mv, ref_mv, fn_ptr, 1);
if (thissme < bestsme) {
bestsme = thissme;
*dst_mv = best_mv;
}
}
return bestsme;
}
......@@ -1562,15 +1405,7 @@ int vp9_full_search_sad_c(const MACROBLOCK *x, const MV *ref_mv,
}
}
}
if (best_sad < INT_MAX) {
unsigned int unused;
const MV mv = {best_mv->row * 8, best_mv->col * 8};
return fn_ptr->vf(what, what_stride, best_address, in_what_stride, &unused)
+ mv_err_cost(&mv, center_mv, mvjcost, mvcost, x->errorperbit);
} else {
return INT_MAX;
}
return best_sad;
}
int vp9_full_search_sadx3(const MACROBLOCK *x, const MV *ref_mv,
......@@ -1665,17 +1500,7 @@ int vp9_full_search_sadx3(const MACROBLOCK *x, const MV *ref_mv,
c++;
}
}
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad < INT_MAX)
return fn_ptr->vf(what, what_stride, bestaddress, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
else
return INT_MAX;
return bestsad;
}
int vp9_full_search_sadx8(const MACROBLOCK *x, const MV *ref_mv,
......@@ -1798,17 +1623,7 @@ int vp9_full_search_sadx8(const MACROBLOCK *x, const MV *ref_mv,
c++;
}
}
this_mv.row = best_mv->row * 8;
this_mv.col = best_mv->col * 8;
if (bestsad < INT_MAX)
return fn_ptr->vf(what, what_stride, bestaddress, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
else
return INT_MAX;
return bestsad;
}
int vp9_refining_search_sad_c(const MACROBLOCK *x,
......@@ -1866,16 +1681,7 @@ int vp9_refining_search_sad_c(const MACROBLOCK *x,
best_address = &in_what[ref_mv->row * in_what_stride + ref_mv->col];
}
}
if (bestsad < INT_MAX) {
unsigned int unused;
const MV mv = {ref_mv->row * 8, ref_mv->col * 8};
return fn_ptr->vf(what, what_stride, best_address, in_what_stride,
&unused) +
mv_err_cost(&mv, center_mv, mvjcost, mvcost, x->errorperbit);
} else {
return INT_MAX;
}
return bestsad;
}
int vp9_refining_search_sadx4(const MACROBLOCK *x,
......@@ -1977,17 +1783,7 @@ int vp9_refining_search_sadx4(const MACROBLOCK *x,
neighbors[best_site].col;
}
}
this_mv.row = ref_mv->row * 8;
this_mv.col = ref_mv->col * 8;
if (bestsad < INT_MAX)
return fn_ptr->vf(what, what_stride, best_address, in_what_stride,
(unsigned int *)(&thissad)) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
else
return INT_MAX;
return bestsad;
}
// This function is called when we do joint motion search in comp_inter_inter
......@@ -2055,18 +1851,5 @@ int vp9_refining_search_8p_c(const MACROBLOCK *x,
best_address = &in_what[ref_mv->row * in_what_stride + ref_mv->col];
}
}
this_mv.row = ref_mv->row * 8;
this_mv.col = ref_mv->col * 8;
if (bestsad < INT_MAX) {
// FIXME(rbultje, yunqing): add full-pixel averaging variance functions
// so we don't have to use the subpixel with xoff=0,yoff=0 here.
return fn_ptr->svaf(best_address, in_what_stride, 0, 0, what, what_stride,
(unsigned int *)(&thissad), second_pred) +
mv_err_cost(&this_mv, center_mv,
mvjcost, mvcost, x->errorperbit);
} else {
return INT_MAX;
}
return bestsad;
}
......@@ -35,6 +35,19 @@ extern "C" {
void vp9_set_mv_search_range(MACROBLOCK *x, const MV *mv);
int vp9_mv_bit_cost(const MV *mv, const MV *ref,
const int *mvjcost, int *mvcost[2], int weight);
// Utility to compute variance + MV rate cost for a given MV
int vp9_get_mvpred_var(const MACROBLOCK *x,
MV *best_mv,
const MV *center_mv,
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost);
int vp9_get_mvpred_av_var(const MACROBLOCK *x,
MV *best_mv,
const MV *center_mv,
const uint8_t *second_pred,
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost);
void vp9_init_dsmotion_compensation(MACROBLOCK *x, int stride);
void vp9_init3smotion_compensation(MACROBLOCK *x, int stride);
......@@ -42,47 +55,27 @@ struct VP9_COMP;
int vp9_init_search_range(struct VP9_COMP *cpi, int size);
// Runs sequence of diamond searches in smaller steps for RD
int vp9_full_pixel_diamond(struct VP9_COMP *cpi, MACROBLOCK *x,
int vp9_full_pixel_diamond(const struct VP9_COMP *cpi, MACROBLOCK *x,
MV *mvp_full, int step_param,
int sadpb, int further_steps, int do_refine,
const vp9_variance_fn_ptr_t *fn_ptr,
const MV *ref_mv, MV *dst_mv);
int vp9_hex_search(const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int error_per_bit,
int do_init_search,
const vp9_variance_fn_ptr_t *vf,
int use_mvcost,
const MV *center_mv,
MV *best_mv);
int vp9_bigdia_search(const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int error_per_bit,
int do_init_search,
const vp9_variance_fn_ptr_t *vf,
int use_mvcost,
const MV *center_mv,
MV *best_mv);
int vp9_square_search(const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int error_per_bit,
int do_init_search,
const vp9_variance_fn_ptr_t *vf,
int use_mvcost,
const MV *center_mv,
MV *best_mv);
int vp9_fast_hex_search(const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int sad_per_bit,
const vp9_variance_fn_ptr_t *vfp,
int use_mvcost,
const MV *center_mv,
MV *best_mv);
typedef int (integer_mv_pattern_search_fn) (
const MACROBLOCK *x,
MV *ref_mv,
int search_param,
int error_per_bit,
int do_init_search,
const vp9_variance_fn_ptr_t *vf,
int use_mvcost,
const MV *center_mv,
MV *best_mv);
integer_mv_pattern_search_fn vp9_hex_search;
integer_mv_pattern_search_fn vp9_bigdia_search;
integer_mv_pattern_search_fn vp9_square_search;
integer_mv_pattern_search_fn vp9_fast_hex_search;
typedef int (fractional_mv_step_fp) (
const MACROBLOCK *x,
......
......@@ -34,7 +34,7 @@ static int full_pixel_motion_search(VP9_COMP *cpi, MACROBLOCK *x,
MB_MODE_INFO *mbmi = &xd->mi_8x8[0]->mbmi;
struct buf_2d backup_yv12[MAX_MB_PLANE] = {{0}};
int bestsme = INT_MAX;
int further_steps, step_param;
int step_param;
int sadpb = x->sadperbit16;
MV mvp_full;
int ref = mbmi->ref_frame[0];
......@@ -67,7 +67,6 @@ static int full_pixel_motion_search(VP9_COMP *cpi, MACROBLOCK *x,
// TODO(jingning) exploiting adaptive motion search control in non-RD
// mode decision too.
step_param = 6;
further_steps = (cpi->sf.max_step_search_steps - 1) - step_param;
for (i = LAST_FRAME; i <= LAST_FRAME && cpi->common.show_frame; ++i) {
if ((x->pred_mv_sad[ref] >> 3) > x->pred_mv_sad[i]) {
......@@ -88,22 +87,28 @@ static int full_pixel_motion_search(VP9_COMP *cpi, MACROBLOCK *x,
mvp_full.row >>= 3;
if (cpi->sf.search_method == FAST_HEX) {
bestsme = vp9_fast_hex_search(x, &mvp_full, step_param, sadpb,
// NOTE: this returns SAD
bestsme = vp9_fast_hex_search(x, &mvp_full, step_param, sadpb, 0,
&cpi->fn_ptr[bsize], 1,
&ref_mv.as_mv, &tmp_mv->as_mv);
} else if (cpi->sf.search_method == HEX) {
// NOTE: this returns SAD
bestsme = vp9_hex_search(x, &mvp_full, step_param, sadpb, 1,
&cpi->fn_ptr[bsize], 1,
&ref_mv.as_mv, &tmp_mv->as_mv);
} else if (cpi->sf.search_method == SQUARE) {