Commit 8263f80c authored by Cheng Chen's avatar Cheng Chen
Browse files

JNT_COMP: refactor if statements

Refactor if statement that use frame_offset == -1 to indicate
jnt_comp is not chosen, as distance now can not be negative.
Instead, add a variable use_jnt_comp_avg for the same functionality.

Change-Id: Ie6b9c6ab36131b48bc9e066babada17046729cd8
parent d3af66c7
......@@ -591,6 +591,7 @@ typedef struct cfl_ctx {
#if CONFIG_JNT_COMP
typedef struct jnt_comp_params {
int use_jnt_comp_avg;
int fwd_offset;
int bck_offset;
} JNT_COMP_PARAMS;
......
......@@ -465,12 +465,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
sum += y_filter[k] * src_vert[(y - fo_vert + k) * im_stride + x];
}
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
if (conv_params->bck_offset == -1) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average == 0) {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} else {
......@@ -479,6 +474,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
}
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
}
}
}
......@@ -536,12 +536,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
}
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
#if CONFIG_JNT_COMP
if (conv_params->bck_offset == -1) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average == 0) {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} else {
......@@ -550,6 +545,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
}
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
}
#else
if (conv_params->do_average)
......@@ -669,12 +669,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1)));
if (conv_params->fwd_offset == -1) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
......@@ -683,6 +678,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
}
}
}
......@@ -746,12 +746,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1)));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset == -1) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
......@@ -760,6 +755,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
}
#else
if (conv_params->do_average)
......
......@@ -36,6 +36,7 @@ typedef struct ConvolveParams {
int plane;
int do_post_rounding;
#if CONFIG_JNT_COMP
int use_jnt_comp_avg;
int fwd_offset;
int bck_offset;
#endif
......
......@@ -924,14 +924,14 @@ typedef struct SubpelParams {
#if CONFIG_JNT_COMP
void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
int order_idx, int *fwd_offset, int *bck_offset,
int is_compound) {
int *use_jnt_comp_avg, int is_compound) {
assert(fwd_offset != NULL && bck_offset != NULL);
if (!is_compound || mbmi->compound_idx) {
*fwd_offset = -1;
*bck_offset = -1;
*use_jnt_comp_avg = 0;
return;
}
*use_jnt_comp_avg = 1;
const int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
const int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
const int cur_frame_index = cm->cur_frame->cur_frame_offset;
......@@ -1256,7 +1256,8 @@ static INLINE void build_inter_predictors(const AV1_COMMON *cm, MACROBLOCKD *xd,
get_conv_params_no_round(ref, ref, plane, tmp_dst, MAX_SB_SIZE);
#if CONFIG_JNT_COMP
av1_jnt_comp_weight_assign(cm, &mi->mbmi, 0, &conv_params.fwd_offset,
&conv_params.bck_offset, is_compound);
&conv_params.bck_offset,
&conv_params.use_jnt_comp_avg, is_compound);
#endif // CONFIG_JNT_COMP
#else
......
......@@ -563,7 +563,7 @@ void av1_build_wedge_inter_predictor_from_buf(MACROBLOCKD *xd, BLOCK_SIZE bsize,
#if CONFIG_JNT_COMP
void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
int order_idx, int *fwd_offset, int *bck_offset,
int is_compound);
int *use_jnt_comp_avg, int is_compound);
#endif // CONFIG_JNT_COMP
#ifdef __cplusplus
......
......@@ -535,8 +535,7 @@ void av1_highbd_warp_affine_c(const int32_t *mat, const uint16_t *ref,
conv_params->round_0 - conv_params->round_1)) -
(1 << (offset_bits_vert - conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 &&
conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
......@@ -628,8 +627,7 @@ static int64_t highbd_warp_error(
ConvolveParams conv_params = get_conv_params(0, 0, 0);
#if CONFIG_JNT_COMP
conv_params.fwd_offset = -1;
conv_params.bck_offset = -1;
conv_params.use_jnt_comp_avg = 0;
#endif
for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) {
for (int j = p_col; j < p_col + p_width; j += WARP_ERROR_BLOCK) {
......@@ -864,8 +862,7 @@ void av1_warp_affine_c(const int32_t *mat, const uint8_t *ref, int width,
conv_params->round_0 - conv_params->round_1)) -
(1 << (offset_bits_vert - conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 &&
conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
......@@ -952,8 +949,7 @@ static int64_t warp_error(WarpedMotionParams *wm, const uint8_t *const ref,
uint8_t tmp[WARP_ERROR_BLOCK * WARP_ERROR_BLOCK];
ConvolveParams conv_params = get_conv_params(0, 0, 0);
#if CONFIG_JNT_COMP
conv_params.fwd_offset = -1;
conv_params.bck_offset = -1;
conv_params.use_jnt_comp_avg = 0;
#endif
for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) {
......
......@@ -313,7 +313,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
int32_t *dst_x = dst + y * dst_stride + x;
#if CONFIG_JNT_COMP
__m128i result;
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
......@@ -343,7 +343,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32;
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
......@@ -432,7 +432,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
int32_t *dst_x = dst + y * dst_stride + x;
#if CONFIG_JNT_COMP
__m128i result;
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
......@@ -462,7 +462,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32;
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
......
......@@ -190,7 +190,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
const __m128i res_hi_round =
_mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift);
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
// NOTE(chengchen):
// only this part is different from av1_convolve_2d_sse2
// original c function at: av1/common/convolve.c:
......@@ -409,7 +409,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
const __m128i res_hi_round =
_mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift);
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
// FIXME(chengchen): validate this implementation
// original c function at: av1/common/convolve.c: av1_convolve_2d_c
__m128i *const p = (__m128i *)&dst[i * dst_stride + j];
......
......@@ -322,7 +322,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
res_lo =
_mm_srl_epi32(res_lo, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
......@@ -347,8 +347,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
_mm_srl_epi32(res_hi, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 &&
conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
......
......@@ -319,7 +319,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
res_lo =
_mm_srl_epi16(res_lo, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
......@@ -342,8 +342,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
res_hi =
_mm_srl_epi16(res_hi, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 &&
conv_params->bck_offset != -1) {
if (conv_params->use_jnt_comp_avg) {
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
......
......@@ -189,7 +189,7 @@ static INLINE const uint8_t *pre(const uint8_t *buf, int stride, int r, int c) {
src_address, src_stride, second_pred, mask, \
mask_stride, invert_mask, &sse); \
} else { \
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) \
if (xd->jcp_param.use_jnt_comp_avg) \
thismse = vfp->jsvaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r), \
src_address, src_stride, &sse, second_pred, \
&xd->jcp_param); \
......@@ -384,7 +384,7 @@ static unsigned int setup_center_error(
mask, mask_stride, invert_mask);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
y_stride, &xd->jcp_param);
else
......@@ -407,7 +407,7 @@ static unsigned int setup_center_error(
mask, mask_stride, invert_mask);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
y_stride, &xd->jcp_param);
else
......@@ -712,7 +712,7 @@ static int upsampled_pref_error(const MACROBLOCKD *xd,
mask_stride, invert_mask);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
subpel_y_q3, y, y_stride,
&xd->jcp_param);
......@@ -823,8 +823,7 @@ int av1_find_best_sub_pixel_tree(
mask_stride, invert_mask, &sse);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 &&
xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
thismse =
vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
src_stride, &sse, second_pred, &xd->jcp_param);
......@@ -875,7 +874,7 @@ int av1_find_best_sub_pixel_tree(
mask_stride, invert_mask, &sse);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
thismse =
vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
src_stride, &sse, second_pred, &xd->jcp_param);
......@@ -1458,7 +1457,7 @@ int av1_get_mvpred_av_var(const MACROBLOCK *x, const MV *best_mv,
unsigned int unused;
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
return vfp->jsvaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0,
what->buf, what->stride, &unused, second_pred,
&xd->jcp_param) +
......@@ -2482,7 +2481,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
best_sad = fn_ptr->jsdaf(what->buf, what->stride,
get_buf_from_mv(in_what, best_mv),
in_what->stride, second_pred, &xd->jcp_param) +
......@@ -2510,7 +2509,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
second_pred, mask, mask_stride, invert_mask);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1)
if (xd->jcp_param.use_jnt_comp_avg)
sad = fn_ptr->jsdaf(what->buf, what->stride,
get_buf_from_mv(in_what, &mv), in_what->stride,
second_pred, &xd->jcp_param);
......
......@@ -5894,7 +5894,8 @@ static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
#if CONFIG_JNT_COMP
const int order_idx = id != 0;
av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
&xd->jcp_param.bck_offset, 1);
&xd->jcp_param.bck_offset,
&xd->jcp_param.use_jnt_comp_avg, 1);
#endif // CONFIG_JNT_COMP
// Do compound motion search on the current reference frame.
......@@ -6537,7 +6538,8 @@ static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x,
#if CONFIG_JNT_COMP
av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
&xd->jcp_param.bck_offset, 1);
&xd->jcp_param.bck_offset,
&xd->jcp_param.use_jnt_comp_avg, 1);
#endif // CONFIG_JNT_COMP
if (scaled_ref_frame) {
......
......@@ -140,10 +140,8 @@ void AV1Convolve2DTest::RunCheckOutput2(convolve_2d_func test_impl) {
get_conv_params_no_round(0, do_average, 0, output2, MAX_SB_SIZE);
// Test special case where fwd and bck offsets are -1
conv_params1.fwd_offset = -1;
conv_params1.bck_offset = -1;
conv_params2.fwd_offset = -1;
conv_params2.bck_offset = -1;
conv_params1.use_jnt_comp_avg = 0;
conv_params2.use_jnt_comp_avg = 0;
for (subx = 0; subx < 16; ++subx)
for (suby = 0; suby < 16; ++suby) {
......@@ -179,6 +177,8 @@ void AV1Convolve2DTest::RunCheckOutput2(convolve_2d_func test_impl) {
// Test different combination of fwd and bck offset weights
for (l = 0; l < 2; ++l) {
for (m = 0; m < 4; ++m) {
conv_params1.use_jnt_comp_avg = 1;
conv_params2.use_jnt_comp_avg = 1;
conv_params1.fwd_offset = quant_dist_lookup_table[l][m][0];
conv_params1.bck_offset = quant_dist_lookup_table[l][m][1];
conv_params2.fwd_offset = quant_dist_lookup_table[l][m][0];
......
......@@ -264,9 +264,9 @@ class ConvolveScaleTestBase : public ::testing::Test {
#if CONFIG_JNT_COMP
void SetConvParamOffset(int i, int j) {
if (i == -1 && j == -1) {
convolve_params_.fwd_offset = -1;
convolve_params_.bck_offset = -1;
convolve_params_.use_jnt_comp_avg = 0;
} else {
convolve_params_.use_jnt_comp_avg = 1;
convolve_params_.fwd_offset = quant_dist_lookup_table[i][j][0];
convolve_params_.bck_offset = quant_dist_lookup_table[i][j][1];
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment