Commit ca6958c6 authored by Cheng Chen's avatar Cheng Chen

JNT_COMP: 3. rd select the best weight

Select the best compound_idx in rd.
The rate/cost for compound_idx and their ctx will be in patch 4.

But there's a bug for now if we don't encode one more time using the
selected compound_idx. It remains a issue to be solved in the future.

Change-Id: I5e1ba51da2b6ab5bacd8aba752dda43bd2257014
parent bda536a4
......@@ -422,6 +422,10 @@ typedef struct MB_MODE_INFO {
int sign;
int delta;
#endif
#if CONFIG_JNT_COMP
int compound_idx;
#endif
} MB_MODE_INFO;
typedef struct MODE_INFO {
......
......@@ -1155,6 +1155,11 @@ static void jnt_comp_weight_assign(const AV1_COMMON *cm,
conv_params->fwd_offset = (DIST_PRECISION >> 1);
conv_params->bck_offset = (DIST_PRECISION >> 1);
}
if (mbmi->compound_idx) {
conv_params->fwd_offset = -1;
conv_params->bck_offset = -1;
}
} else {
conv_params->bck_offset = -1;
conv_params->fwd_offset = -1;
......
......@@ -5795,43 +5795,48 @@ static int check_best_zero_mv(
static void jnt_comp_weight_assign(const AV1_COMMON *cm,
const MB_MODE_INFO *mbmi, int order_idx,
uint8_t *second_pred) {
int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
int bck_frame_index = 0, fwd_frame_index = 0;
int cur_frame_index = cm->cur_frame->cur_frame_offset;
if (mbmi->compound_idx) {
second_pred[4096] = -1;
second_pred[4097] = -1;
} else {
int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
int bck_frame_index = 0, fwd_frame_index = 0;
int cur_frame_index = cm->cur_frame->cur_frame_offset;
if (bck_idx >= 0) {
bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
}
if (bck_idx >= 0) {
bck_frame_index = cm->buffer_pool->frame_bufs[bck_idx].cur_frame_offset;
}
if (fwd_idx >= 0) {
fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
}
if (fwd_idx >= 0) {
fwd_frame_index = cm->buffer_pool->frame_bufs[fwd_idx].cur_frame_offset;
}
const double fwd = abs(fwd_frame_index - cur_frame_index);
const double bck = abs(cur_frame_index - bck_frame_index);
int order;
double ratio;
const double fwd = abs(fwd_frame_index - cur_frame_index);
const double bck = abs(cur_frame_index - bck_frame_index);
int order;
double ratio;
if (COMPOUND_WEIGHT_MODE == DIST) {
if (fwd > bck) {
ratio = (bck != 0) ? fwd / bck : 5.0;
order = 0;
if (COMPOUND_WEIGHT_MODE == DIST) {
if (fwd > bck) {
ratio = (bck != 0) ? fwd / bck : 5.0;
order = 0;
} else {
ratio = (fwd != 0) ? bck / fwd : 5.0;
order = 1;
}
int quant_dist_idx;
for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
if (ratio < quant_dist_category[quant_dist_idx]) break;
}
second_pred[4096] =
quant_dist_lookup_table[order_idx][quant_dist_idx][order];
second_pred[4097] =
quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
} else {
ratio = (fwd != 0) ? bck / fwd : 5.0;
order = 1;
}
int quant_dist_idx;
for (quant_dist_idx = 0; quant_dist_idx < 4; ++quant_dist_idx) {
if (ratio < quant_dist_category[quant_dist_idx]) break;
second_pred[4096] = (DIST_PRECISION >> 1);
second_pred[4097] = (DIST_PRECISION >> 1);
}
second_pred[4096] =
quant_dist_lookup_table[order_idx][quant_dist_idx][order];
second_pred[4097] =
quant_dist_lookup_table[order_idx][quant_dist_idx][1 - order];
} else {
second_pred[4096] = (DIST_PRECISION >> 1);
second_pred[4097] = (DIST_PRECISION >> 1);
}
}
#endif // CONFIG_JNT_COMP
......@@ -10217,6 +10222,130 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
}
}
}
#if CONFIG_JNT_COMP
{
int cum_rate = rate2;
MB_MODE_INFO backup_mbmi = *mbmi;
int_mv backup_frame_mv[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
int_mv backup_single_newmv[TOTAL_REFS_PER_FRAME];
int backup_single_newmv_rate[TOTAL_REFS_PER_FRAME];
int64_t backup_modelled_rd[MB_MODE_COUNT][TOTAL_REFS_PER_FRAME];
memcpy(backup_frame_mv, frame_mv, sizeof(frame_mv));
memcpy(backup_single_newmv, single_newmv, sizeof(single_newmv));
memcpy(backup_single_newmv_rate, single_newmv_rate,
sizeof(single_newmv_rate));
memcpy(backup_modelled_rd, modelled_rd, sizeof(modelled_rd));
InterpFilters backup_interp_filters = mbmi->interp_filters;
for (int comp_idx = 0; comp_idx < 1 + has_second_ref(mbmi);
++comp_idx) {
RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
av1_init_rd_stats(&rd_stats);
av1_init_rd_stats(&rd_stats_y);
av1_init_rd_stats(&rd_stats_uv);
rd_stats.rate = cum_rate;
memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
memcpy(single_newmv_rate, backup_single_newmv_rate,
sizeof(single_newmv_rate));
memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
mbmi->interp_filters = backup_interp_filters;
int dummy_disable_skip = 0;
// Point to variables that are maintained between loop iterations
args.single_newmv = single_newmv;
args.single_newmv_rate = single_newmv_rate;
args.modelled_rd = modelled_rd;
mbmi->compound_idx = comp_idx;
int64_t tmp_rd = handle_inter_mode(
cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
&dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
if (tmp_rd < INT64_MAX) {
if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
RDCOST(x->rdmult, 0, rd_stats.sse))
tmp_rd =
RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
rd_stats.dist);
else
tmp_rd = RDCOST(x->rdmult,
rd_stats.rate + x->skip_cost[skip_ctx][1] -
rd_stats_y.rate - rd_stats_uv.rate,
rd_stats.sse);
}
if (tmp_rd < this_rd) {
this_rd = tmp_rd;
rate2 = rd_stats.rate;
skippable = rd_stats.skip;
distortion2 = rd_stats.dist;
total_sse = rd_stats.sse;
rate_y = rd_stats_y.rate;
rate_uv = rd_stats_uv.rate;
disable_skip = dummy_disable_skip;
backup_mbmi = *mbmi;
}
}
*mbmi = backup_mbmi;
// TODO(chengchen): Redo encoding use the selected compound_idx
// But ideally, this is unnecessary
{
RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
av1_init_rd_stats(&rd_stats);
av1_init_rd_stats(&rd_stats_y);
av1_init_rd_stats(&rd_stats_uv);
rd_stats.rate = cum_rate;
memcpy(frame_mv, backup_frame_mv, sizeof(frame_mv));
memcpy(single_newmv, backup_single_newmv, sizeof(single_newmv));
memcpy(single_newmv_rate, backup_single_newmv_rate,
sizeof(single_newmv_rate));
memcpy(modelled_rd, backup_modelled_rd, sizeof(modelled_rd));
mbmi->interp_filters = backup_interp_filters;
int dummy_disable_skip = 0;
args.single_newmv = single_newmv;
args.single_newmv_rate = single_newmv_rate;
args.modelled_rd = modelled_rd;
int64_t tmp_rd = handle_inter_mode(
cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv,
&dummy_disable_skip, frame_mv, mi_row, mi_col, &args, best_rd);
if (tmp_rd < INT64_MAX) {
if (RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist) <
RDCOST(x->rdmult, 0, rd_stats.sse))
tmp_rd =
RDCOST(x->rdmult, rd_stats.rate + x->skip_cost[skip_ctx][0],
rd_stats.dist);
else
tmp_rd = RDCOST(x->rdmult,
rd_stats.rate + x->skip_cost[skip_ctx][1] -
rd_stats_y.rate - rd_stats_uv.rate,
rd_stats.sse);
}
this_rd = tmp_rd;
rate2 = rd_stats.rate;
skippable = rd_stats.skip;
distortion2 = rd_stats.dist;
total_sse = rd_stats.sse;
rate_y = rd_stats_y.rate;
rate_uv = rd_stats_uv.rate;
disable_skip = dummy_disable_skip;
}
}
#else // CONFIG_JNT_COMP
{
RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
av1_init_rd_stats(&rd_stats);
......@@ -10240,6 +10369,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
rate_y = rd_stats_y.rate;
rate_uv = rd_stats_uv.rate;
}
#endif // CONFIG_JNT_COMP
// TODO(jingning): This needs some refactoring to improve code quality
// and reduce redundant steps.
......@@ -10293,12 +10423,22 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
memcpy(x->blk_skip_drl[i], x->blk_skip[i],
sizeof(uint8_t) * ctx->num_4x4_blk);
for (ref_idx = 0; ref_idx < ref_set; ++ref_idx) {
#if CONFIG_JNT_COMP
for (int sidx = 0; sidx < ref_set * (1 + has_second_ref(mbmi)); ++sidx)
#else
for (ref_idx = 0; ref_idx < ref_set; ++ref_idx)
#endif // CONFIG_JNT_COMP
{
int64_t tmp_alt_rd = INT64_MAX;
int dummy_disable_skip = 0;
int ref;
int_mv cur_mv;
RD_STATS tmp_rd_stats, tmp_rd_stats_y, tmp_rd_stats_uv;
#if CONFIG_JNT_COMP
ref_idx = sidx;
if (has_second_ref(mbmi)) ref_idx /= 2;
mbmi->compound_idx = sidx % 2;
#endif // CONFIG_JNT_COMP
av1_invalid_rd_stats(&tmp_rd_stats);
......@@ -10480,6 +10620,9 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data,
for (i = 0; i < MAX_MB_PLANE; ++i)
memcpy(x->blk_skip[i], x->blk_skip_drl[i],
sizeof(uint8_t) * ctx->num_4x4_blk);
#if CONFIG_JNT_COMP
*mbmi = backup_mbmi;
#endif // CONFIG_JNT_COMP
}
mbmi_ext->ref_mvs[ref_frame][0] = backup_ref_mv[0];
if (comp_pred) mbmi_ext->ref_mvs[second_ref_frame][0] = backup_ref_mv[1];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment