Commit b9f757c7 authored by Sarah Parker's avatar Sarah Parker

Refactor compound_segment to try different segmentation masks

Change-Id: I7c992c9aae895aebcfb5c147cb179cf665c0ac10
parent ab44fd14
......@@ -41,6 +41,18 @@ extern "C" {
#if CONFIG_EXT_INTER
// Should we try rectangular interintra predictions?
#define USE_RECT_INTERINTRA 1
#if CONFIG_COMPOUND_SEGMENT
#define MAX_SEG_MASK_BITS 3
// SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
typedef enum {
UNIFORM_45 = 0,
UNIFORM_45_INV,
UNIFORM_55,
UNIFORM_55_INV,
SEG_MASK_TYPES,
} SEG_MASK_TYPE;
#endif // CONFIG_COMPOUND_SEGMENT
#endif
typedef enum {
......@@ -256,8 +268,8 @@ typedef struct {
int wedge_index;
int wedge_sign;
#if CONFIG_COMPOUND_SEGMENT
int which;
DECLARE_ALIGNED(16, uint8_t, seg_mask[2][2 * MAX_SB_SQUARE]);
SEG_MASK_TYPE mask_type;
DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
#endif // CONFIG_COMPOUND_SEGMENT
} INTERINTER_COMPOUND_DATA;
#endif // CONFIG_EXT_INTER
......
......@@ -251,45 +251,80 @@ const uint8_t *av1_get_soft_mask(int wedge_index, int wedge_sign,
return mask;
}
// get a mask according to the compound type
// TODO(sarahparker) this needs to be extended for other experiments and
// is currently only intended for ext_inter alone
const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
BLOCK_SIZE sb_type, int invert) {
#if CONFIG_COMPOUND_SEGMENT
static uint8_t *invert_mask(uint8_t *mask_inv_buffer, const uint8_t *const mask,
int h, int w, int stride) {
int i, j;
for (i = 0; i < h; ++i)
for (j = 0; j < w; ++j) {
mask_inv_buffer[i * stride + j] =
AOM_BLEND_A64_MAX_ALPHA - mask[i * stride + j];
}
return mask_inv_buffer;
}
#endif // CONFIG_COMPOUND_SEGMENT
const uint8_t *av1_get_compound_type_mask_inverse(
const INTERINTER_COMPOUND_DATA *const comp_data,
#if CONFIG_COMPOUND_SEGMENT
uint8_t *mask_buffer, int h, int w, int stride,
#endif
BLOCK_SIZE sb_type) {
assert(is_masked_compound_type(comp_data->type));
switch (comp_data->type) {
case COMPOUND_WEDGE:
return av1_get_contiguous_soft_mask(
comp_data->wedge_index,
invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type);
return av1_get_contiguous_soft_mask(comp_data->wedge_index,
!comp_data->wedge_sign, sb_type);
#if CONFIG_COMPOUND_SEGMENT
case COMPOUND_SEG:
if (invert) return comp_data->seg_mask[!comp_data->which];
return comp_data->seg_mask[comp_data->which];
return invert_mask(mask_buffer, comp_data->seg_mask, h, w, stride);
#endif // CONFIG_COMPOUND_SEGMENT
default: assert(0); return NULL;
}
}
const uint8_t *av1_get_compound_type_mask(
const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type) {
assert(is_masked_compound_type(comp_data->type));
switch (comp_data->type) {
case COMPOUND_WEDGE:
return av1_get_contiguous_soft_mask(comp_data->wedge_index,
comp_data->wedge_sign, sb_type);
#if CONFIG_COMPOUND_SEGMENT
case COMPOUND_SEG: return comp_data->seg_mask;
#endif // CONFIG_COMPOUND_SEGMENT
default: assert(0); return NULL;
}
}
#if CONFIG_COMPOUND_SEGMENT
// temporary placeholder mask, this will be generated using segmentation later
void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type, int h,
int w, int mask_val) {
int i, j;
int block_stride = block_size_wide[sb_type];
for (i = 0; i < h; ++i)
for (j = 0; j < w; ++j) {
mask[i * block_stride + j] =
which_inverse ? AOM_BLEND_A64_MAX_ALPHA - mask_val : mask_val;
}
}
void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
BLOCK_SIZE sb_type, int h, int w) {
int block_stride = block_size_wide[sb_type];
int i, j;
(void)src0;
(void)src1;
(void)src0_stride;
(void)src1_stride;
for (i = 0; i < h; ++i)
for (j = 0; j < w; ++j) {
// if which == 0, put more weight on the first predictor
comp_data->seg_mask[0][i * block_stride + j] = 45;
comp_data->seg_mask[1][i * block_stride + j] =
AOM_BLEND_A64_MAX_ALPHA - 45;
}
switch (mask_type) {
case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
default: assert(0);
}
}
#endif // CONFIG_COMPOUND_SEGMENT
......@@ -420,16 +455,16 @@ static void build_masked_compound_wedge_extend_highbd(
#endif // CONFIG_AOM_HIGHBITDEPTH
#endif // CONFIG_SUPERTX
static void build_masked_compound(uint8_t *dst, int dst_stride,
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
INTERINTER_COMPOUND_DATA *comp_data,
BLOCK_SIZE sb_type, int h, int w) {
static void build_masked_compound(
uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
int w) {
// Derive subsampling from h and w passed in. May be refactored to
// pass in subsampling factors directly.
const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride,
mask, block_size_wide[sb_type], h, w, subh, subw);
}
......@@ -520,8 +555,9 @@ void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
#else
#if CONFIG_COMPOUND_SEGMENT
if (!plane && comp_data->type == COMPOUND_SEG)
build_compound_seg_mask(comp_data, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
mi->mbmi.sb_type, h, w);
build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
dst_stride, tmp_dst, MAX_SB_SIZE, mi->mbmi.sb_type,
h, w);
#endif // CONFIG_COMPOUND_SEGMENT
build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
comp_data, mi->mbmi.sb_type, h, w);
......@@ -2216,7 +2252,8 @@ static void build_wedge_inter_predictor_from_buf(
is_masked_compound_type(mbmi->interinter_compound_data.type)) {
#if CONFIG_COMPOUND_SEGMENT
if (!plane && comp_data->type == COMPOUND_SEG)
build_compound_seg_mask(comp_data, ext_dst0, ext_dst_stride0, ext_dst1,
build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
ext_dst0, ext_dst_stride0, ext_dst1,
ext_dst_stride1, mbmi->sb_type, h, w);
#endif // CONFIG_COMPOUND_SEGMENT
#if CONFIG_AOM_HIGHBITDEPTH
......
......@@ -189,7 +189,7 @@ static INLINE int get_interintra_wedge_bits(BLOCK_SIZE sb_type) {
}
#if CONFIG_COMPOUND_SEGMENT
void build_compound_seg_mask(INTERINTER_COMPOUND_DATA *comp_data,
void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
const uint8_t *src0, int src0_stride,
const uint8_t *src1, int src1_stride,
BLOCK_SIZE sb_type, int h, int w);
......@@ -514,8 +514,15 @@ const uint8_t *av1_get_soft_mask(int wedge_index, int wedge_sign,
BLOCK_SIZE sb_type, int wedge_offset_x,
int wedge_offset_y);
const uint8_t *av1_get_compound_type_mask(INTERINTER_COMPOUND_DATA *comp_data,
BLOCK_SIZE sb_type, int invert);
const uint8_t *av1_get_compound_type_mask_inverse(
const INTERINTER_COMPOUND_DATA *const comp_data,
#if CONFIG_COMPOUND_SEGMENT
uint8_t *mask_buffer, int h, int w, int stride,
#endif
BLOCK_SIZE sb_type);
const uint8_t *av1_get_compound_type_mask(
const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type);
void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred,
uint8_t *upred, uint8_t *vpred,
......
......@@ -1888,7 +1888,8 @@ static void read_inter_block_mode_info(AV1Decoder *const pbi,
}
#if CONFIG_COMPOUND_SEGMENT
else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
mbmi->interinter_compound_data.which = aom_read_bit(r, ACCT_STR);
mbmi->interinter_compound_data.mask_type =
aom_read_literal(r, MAX_SEG_MASK_BITS, ACCT_STR);
}
#endif // CONFIG_COMPOUND_SEGMENT
}
......
......@@ -1626,7 +1626,8 @@ static void pack_inter_mode_mvs(AV1_COMP *cpi, const MODE_INFO *mi,
}
#if CONFIG_COMPOUND_SEGMENT
else if (mbmi->interinter_compound_data.type == COMPOUND_SEG) {
aom_write_bit(w, mbmi->interinter_compound_data.which);
aom_write_literal(w, mbmi->interinter_compound_data.mask_type,
MAX_SEG_MASK_BITS);
}
#endif // CONFIG_COMPOUND_SEGMENT
}
......
......@@ -6468,8 +6468,9 @@ static void do_masked_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
static void do_masked_motion_search_indexed(
const AV1_COMP *const cpi, MACROBLOCK *x,
INTERINTER_COMPOUND_DATA *comp_data, BLOCK_SIZE bsize, int mi_row,
int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2], int which) {
const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2],
int which) {
// NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both
MACROBLOCKD *xd = &x->e_mbd;
MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
......@@ -6477,15 +6478,22 @@ static void do_masked_motion_search_indexed(
const uint8_t *mask;
const int mask_stride = block_size_wide[bsize];
mask = av1_get_compound_type_mask(comp_data, sb_type, 0);
mask = av1_get_compound_type_mask(comp_data, sb_type);
if (which == 0 || which == 2)
do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
&tmp_mv[0], &rate_mv[0], 0, mv_idx[0]);
if (which == 1 || which == 2) {
// get the negative mask
mask = av1_get_compound_type_mask(comp_data, sb_type, 1);
// get the negative mask
#if CONFIG_COMPOUND_SEGMENT
uint8_t inv_mask_buf[2 * MAX_SB_SQUARE];
const int h = block_size_high[bsize];
mask = av1_get_compound_type_mask_inverse(
comp_data, inv_mask_buf, h, mask_stride, mask_stride, sb_type);
#else
mask = av1_get_compound_type_mask_inverse(comp_data, sb_type);
#endif
do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col,
&tmp_mv[1], &rate_mv[1], 1, mv_idx[1]);
}
......@@ -6846,7 +6854,10 @@ static int64_t pick_interinter_seg_mask(const AV1_COMP *const cpi,
int rate;
uint64_t sse;
int64_t dist;
int rd0, rd1;
int rd0;
SEG_MASK_TYPE cur_mask_type;
int64_t best_rd = INT64_MAX;
SEG_MASK_TYPE best_mask_type = 0;
#if CONFIG_AOM_HIGHBITDEPTH
const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
......@@ -6874,26 +6885,31 @@ static int64_t pick_interinter_seg_mask(const AV1_COMP *const cpi,
aom_subtract_block(bh, bw, d10, bw, p1, bw, p0, bw);
}
// build mask and inverse
build_compound_seg_mask(comp_data, p0, bw, p1, bw, bsize, bh, bw);
// try each mask type and its inverse
for (cur_mask_type = 0; cur_mask_type < SEG_MASK_TYPES; cur_mask_type++) {
// build mask and inverse
build_compound_seg_mask(comp_data->seg_mask, cur_mask_type, p0, bw, p1, bw,
bsize, bh, bw);
// compute rd for mask0
sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask[0], N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
// compute rd for mask
sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask, N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
rd0 = RDCOST(x->rdmult, x->rddiv, rate, dist);
model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
rd0 = RDCOST(x->rdmult, x->rddiv, rate, dist);
// compute rd for mask1
sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask[1], N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
if (rd0 < best_rd) {
best_mask_type = cur_mask_type;
best_rd = rd0;
}
}
model_rd_from_sse(cpi, xd, bsize, 0, sse, &rate, &dist);
rd1 = RDCOST(x->rdmult, x->rddiv, rate, dist);
// make final mask
comp_data->mask_type = best_mask_type;
build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, p0, bw, p1,
bw, bsize, bh, bw);
// pick the better of the two
mbmi->interinter_compound_data.which = rd1 < rd0;
return mbmi->interinter_compound_data.which ? rd1 : rd0;
return best_rd;
}
#endif // CONFIG_COMPOUND_SEGMENT
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment