Commit 8263f80c authored by Cheng Chen
Browse files

JNT_COMP: refactor if statements

Refactor the if statements that use frame_offset == -1 to indicate
that jnt_comp is not chosen, as the distance now cannot be negative.
Instead, add a variable, use_jnt_comp_avg, that provides the same functionality.

Change-Id: Ie6b9c6ab36131b48bc9e066babada17046729cd8
parent d3af66c7
...@@ -591,6 +591,7 @@ typedef struct cfl_ctx { ...@@ -591,6 +591,7 @@ typedef struct cfl_ctx {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
typedef struct jnt_comp_params { typedef struct jnt_comp_params {
int use_jnt_comp_avg;
int fwd_offset; int fwd_offset;
int bck_offset; int bck_offset;
} JNT_COMP_PARAMS; } JNT_COMP_PARAMS;
......
...@@ -465,12 +465,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride, ...@@ -465,12 +465,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
sum += y_filter[k] * src_vert[(y - fo_vert + k) * im_stride + x]; sum += y_filter[k] * src_vert[(y - fo_vert + k) * im_stride + x];
} }
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1); CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
if (conv_params->bck_offset == -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->do_average == 0) { if (conv_params->do_average == 0) {
dst[y * dst_stride + x] = res * conv_params->fwd_offset; dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} else { } else {
...@@ -479,6 +474,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride, ...@@ -479,6 +474,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x], dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1); DIST_PRECISION_BITS - 1);
} }
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} }
} }
} }
...@@ -536,12 +536,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride, ...@@ -536,12 +536,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
} }
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1); CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1);
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->bck_offset == -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->do_average == 0) { if (conv_params->do_average == 0) {
dst[y * dst_stride + x] = res * conv_params->fwd_offset; dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} else { } else {
...@@ -550,6 +545,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride, ...@@ -550,6 +545,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x], dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1); DIST_PRECISION_BITS - 1);
} }
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} }
#else #else
if (conv_params->do_average) if (conv_params->do_average)
...@@ -669,12 +669,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride, ...@@ -669,12 +669,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) -
((1 << (offset_bits - conv_params->round_1)) + ((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1))); (1 << (offset_bits - conv_params->round_1 - 1)));
if (conv_params->fwd_offset == -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->do_average) { if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset; dst[y * dst_stride + x] += res * conv_params->bck_offset;
...@@ -683,6 +678,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride, ...@@ -683,6 +678,11 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride,
} else { } else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset; dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} }
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} }
} }
} }
...@@ -746,12 +746,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride, ...@@ -746,12 +746,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
((1 << (offset_bits - conv_params->round_1)) + ((1 << (offset_bits - conv_params->round_1)) +
(1 << (offset_bits - conv_params->round_1 - 1))); (1 << (offset_bits - conv_params->round_1 - 1)));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset == -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} else {
if (conv_params->do_average) { if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset; dst[y * dst_stride + x] += res * conv_params->bck_offset;
...@@ -760,6 +755,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride, ...@@ -760,6 +755,11 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
} else { } else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset; dst[y * dst_stride + x] = res * conv_params->fwd_offset;
} }
} else {
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
dst[y * dst_stride + x] = res;
} }
#else #else
if (conv_params->do_average) if (conv_params->do_average)
......
...@@ -36,6 +36,7 @@ typedef struct ConvolveParams { ...@@ -36,6 +36,7 @@ typedef struct ConvolveParams {
int plane; int plane;
int do_post_rounding; int do_post_rounding;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
int use_jnt_comp_avg;
int fwd_offset; int fwd_offset;
int bck_offset; int bck_offset;
#endif #endif
......
...@@ -924,14 +924,14 @@ typedef struct SubpelParams { ...@@ -924,14 +924,14 @@ typedef struct SubpelParams {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi, void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
int order_idx, int *fwd_offset, int *bck_offset, int order_idx, int *fwd_offset, int *bck_offset,
int is_compound) { int *use_jnt_comp_avg, int is_compound) {
assert(fwd_offset != NULL && bck_offset != NULL); assert(fwd_offset != NULL && bck_offset != NULL);
if (!is_compound || mbmi->compound_idx) { if (!is_compound || mbmi->compound_idx) {
*fwd_offset = -1; *use_jnt_comp_avg = 0;
*bck_offset = -1;
return; return;
} }
*use_jnt_comp_avg = 1;
const int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx; const int bck_idx = cm->frame_refs[mbmi->ref_frame[0] - LAST_FRAME].idx;
const int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx; const int fwd_idx = cm->frame_refs[mbmi->ref_frame[1] - LAST_FRAME].idx;
const int cur_frame_index = cm->cur_frame->cur_frame_offset; const int cur_frame_index = cm->cur_frame->cur_frame_offset;
...@@ -1256,7 +1256,8 @@ static INLINE void build_inter_predictors(const AV1_COMMON *cm, MACROBLOCKD *xd, ...@@ -1256,7 +1256,8 @@ static INLINE void build_inter_predictors(const AV1_COMMON *cm, MACROBLOCKD *xd,
get_conv_params_no_round(ref, ref, plane, tmp_dst, MAX_SB_SIZE); get_conv_params_no_round(ref, ref, plane, tmp_dst, MAX_SB_SIZE);
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
av1_jnt_comp_weight_assign(cm, &mi->mbmi, 0, &conv_params.fwd_offset, av1_jnt_comp_weight_assign(cm, &mi->mbmi, 0, &conv_params.fwd_offset,
&conv_params.bck_offset, is_compound); &conv_params.bck_offset,
&conv_params.use_jnt_comp_avg, is_compound);
#endif // CONFIG_JNT_COMP #endif // CONFIG_JNT_COMP
#else #else
......
...@@ -563,7 +563,7 @@ void av1_build_wedge_inter_predictor_from_buf(MACROBLOCKD *xd, BLOCK_SIZE bsize, ...@@ -563,7 +563,7 @@ void av1_build_wedge_inter_predictor_from_buf(MACROBLOCKD *xd, BLOCK_SIZE bsize,
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi, void av1_jnt_comp_weight_assign(const AV1_COMMON *cm, const MB_MODE_INFO *mbmi,
int order_idx, int *fwd_offset, int *bck_offset, int order_idx, int *fwd_offset, int *bck_offset,
int is_compound); int *use_jnt_comp_avg, int is_compound);
#endif // CONFIG_JNT_COMP #endif // CONFIG_JNT_COMP
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -535,8 +535,7 @@ void av1_highbd_warp_affine_c(const int32_t *mat, const uint16_t *ref, ...@@ -535,8 +535,7 @@ void av1_highbd_warp_affine_c(const int32_t *mat, const uint16_t *ref,
conv_params->round_0 - conv_params->round_1)) - conv_params->round_0 - conv_params->round_1)) -
(1 << (offset_bits_vert - conv_params->round_1)); (1 << (offset_bits_vert - conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && if (conv_params->use_jnt_comp_avg) {
conv_params->bck_offset != -1) {
if (conv_params->do_average) { if (conv_params->do_average) {
*p += sum * conv_params->bck_offset; *p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1); *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
...@@ -628,8 +627,7 @@ static int64_t highbd_warp_error( ...@@ -628,8 +627,7 @@ static int64_t highbd_warp_error(
ConvolveParams conv_params = get_conv_params(0, 0, 0); ConvolveParams conv_params = get_conv_params(0, 0, 0);
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
conv_params.fwd_offset = -1; conv_params.use_jnt_comp_avg = 0;
conv_params.bck_offset = -1;
#endif #endif
for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) { for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) {
for (int j = p_col; j < p_col + p_width; j += WARP_ERROR_BLOCK) { for (int j = p_col; j < p_col + p_width; j += WARP_ERROR_BLOCK) {
...@@ -864,8 +862,7 @@ void av1_warp_affine_c(const int32_t *mat, const uint8_t *ref, int width, ...@@ -864,8 +862,7 @@ void av1_warp_affine_c(const int32_t *mat, const uint8_t *ref, int width,
conv_params->round_0 - conv_params->round_1)) - conv_params->round_0 - conv_params->round_1)) -
(1 << (offset_bits_vert - conv_params->round_1)); (1 << (offset_bits_vert - conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && if (conv_params->use_jnt_comp_avg) {
conv_params->bck_offset != -1) {
if (conv_params->do_average) { if (conv_params->do_average) {
*p += sum * conv_params->bck_offset; *p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1); *p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
...@@ -952,8 +949,7 @@ static int64_t warp_error(WarpedMotionParams *wm, const uint8_t *const ref, ...@@ -952,8 +949,7 @@ static int64_t warp_error(WarpedMotionParams *wm, const uint8_t *const ref,
uint8_t tmp[WARP_ERROR_BLOCK * WARP_ERROR_BLOCK]; uint8_t tmp[WARP_ERROR_BLOCK * WARP_ERROR_BLOCK];
ConvolveParams conv_params = get_conv_params(0, 0, 0); ConvolveParams conv_params = get_conv_params(0, 0, 0);
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
conv_params.fwd_offset = -1; conv_params.use_jnt_comp_avg = 0;
conv_params.bck_offset = -1;
#endif #endif
for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) { for (int i = p_row; i < p_row + p_height; i += WARP_ERROR_BLOCK) {
......
...@@ -313,7 +313,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst, ...@@ -313,7 +313,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
int32_t *dst_x = dst + y * dst_stride + x; int32_t *dst_x = dst + y * dst_stride + x;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
__m128i result; __m128i result;
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) { if (conv_params->do_average) {
result = _mm_srai_epi32( result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x), _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
...@@ -343,7 +343,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst, ...@@ -343,7 +343,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k]; for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32; CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) { if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset; dst[y * dst_stride + x] += res * conv_params->bck_offset;
...@@ -432,7 +432,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst, ...@@ -432,7 +432,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
int32_t *dst_x = dst + y * dst_stride + x; int32_t *dst_x = dst + y * dst_stride + x;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
__m128i result; __m128i result;
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) { if (conv_params->do_average) {
result = _mm_srai_epi32( result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x), _mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
...@@ -462,7 +462,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst, ...@@ -462,7 +462,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k]; for (int k = 0; k < ntaps; ++k) sum += filter[k] * src_x[k];
CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32; CONV_BUF_TYPE res = ROUND_POWER_OF_TWO(sum, conv_params->round_1) - sub32;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) { if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset; dst[y * dst_stride + x] += res * conv_params->bck_offset;
......
...@@ -190,7 +190,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride, ...@@ -190,7 +190,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
const __m128i res_hi_round = const __m128i res_hi_round =
_mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift); _mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift);
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
// NOTE(chengchen): // NOTE(chengchen):
// only this part is different from av1_convolve_2d_sse2 // only this part is different from av1_convolve_2d_sse2
// original c function at: av1/common/convolve.c: // original c function at: av1/common/convolve.c:
...@@ -409,7 +409,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride, ...@@ -409,7 +409,7 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
const __m128i res_hi_round = const __m128i res_hi_round =
_mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift); _mm_sra_epi32(_mm_add_epi32(res_hi, round_const), round_shift);
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
// FIXME(chengchen): validate this implementation // FIXME(chengchen): validate this implementation
// original c function at: av1/common/convolve.c: av1_convolve_2d_c // original c function at: av1/common/convolve.c: av1_convolve_2d_c
__m128i *const p = (__m128i *)&dst[i * dst_stride + j]; __m128i *const p = (__m128i *)&dst[i * dst_stride + j];
......
...@@ -322,7 +322,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref, ...@@ -322,7 +322,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
res_lo = res_lo =
_mm_srl_epi32(res_lo, _mm_cvtsi32_si128(conv_params->round_1)); _mm_srl_epi32(res_lo, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (comp_avg) { if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p), const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1)); _mm_mullo_epi32(res_lo, wt1));
...@@ -347,8 +347,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref, ...@@ -347,8 +347,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
_mm_srl_epi32(res_hi, _mm_cvtsi32_si128(conv_params->round_1)); _mm_srl_epi32(res_hi, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && if (conv_params->use_jnt_comp_avg) {
conv_params->bck_offset != -1) {
if (comp_avg) { if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1), const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1)); _mm_mullo_epi32(res_hi, wt1));
......
...@@ -319,7 +319,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, ...@@ -319,7 +319,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
res_lo = res_lo =
_mm_srl_epi16(res_lo, _mm_cvtsi32_si128(conv_params->round_1)); _mm_srl_epi16(res_lo, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && conv_params->bck_offset != -1) { if (conv_params->use_jnt_comp_avg) {
if (comp_avg) { if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p), const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1)); _mm_mullo_epi32(res_lo, wt1));
...@@ -342,8 +342,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, ...@@ -342,8 +342,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
res_hi = res_hi =
_mm_srl_epi16(res_hi, _mm_cvtsi32_si128(conv_params->round_1)); _mm_srl_epi16(res_hi, _mm_cvtsi32_si128(conv_params->round_1));
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (conv_params->fwd_offset != -1 && if (conv_params->use_jnt_comp_avg) {
conv_params->bck_offset != -1) {
if (comp_avg) { if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1), const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1)); _mm_mullo_epi32(res_hi, wt1));
......
...@@ -189,7 +189,7 @@ static INLINE const uint8_t *pre(const uint8_t *buf, int stride, int r, int c) { ...@@ -189,7 +189,7 @@ static INLINE const uint8_t *pre(const uint8_t *buf, int stride, int r, int c) {
src_address, src_stride, second_pred, mask, \ src_address, src_stride, second_pred, mask, \
mask_stride, invert_mask, &sse); \ mask_stride, invert_mask, &sse); \
} else { \ } else { \
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) \ if (xd->jcp_param.use_jnt_comp_avg) \
thismse = vfp->jsvaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r), \ thismse = vfp->jsvaf(pre(y, y_stride, r, c), y_stride, sp(c), sp(r), \
src_address, src_stride, &sse, second_pred, \ src_address, src_stride, &sse, second_pred, \
&xd->jcp_param); \ &xd->jcp_param); \
...@@ -384,7 +384,7 @@ static unsigned int setup_center_error( ...@@ -384,7 +384,7 @@ static unsigned int setup_center_error(
mask, mask_stride, invert_mask); mask, mask_stride, invert_mask);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
y_stride, &xd->jcp_param); y_stride, &xd->jcp_param);
else else
...@@ -407,7 +407,7 @@ static unsigned int setup_center_error( ...@@ -407,7 +407,7 @@ static unsigned int setup_center_error(
mask, mask_stride, invert_mask); mask, mask_stride, invert_mask);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset, aom_jnt_comp_avg_pred(comp_pred, second_pred, w, h, y + offset,
y_stride, &xd->jcp_param); y_stride, &xd->jcp_param);
else else
...@@ -712,7 +712,7 @@ static int upsampled_pref_error(const MACROBLOCKD *xd, ...@@ -712,7 +712,7 @@ static int upsampled_pref_error(const MACROBLOCKD *xd,
mask_stride, invert_mask); mask_stride, invert_mask);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
aom_jnt_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3, aom_jnt_comp_avg_upsampled_pred(pred, second_pred, w, h, subpel_x_q3,
subpel_y_q3, y, y_stride, subpel_y_q3, y, y_stride,
&xd->jcp_param); &xd->jcp_param);
...@@ -823,8 +823,7 @@ int av1_find_best_sub_pixel_tree( ...@@ -823,8 +823,7 @@ int av1_find_best_sub_pixel_tree(
mask_stride, invert_mask, &sse); mask_stride, invert_mask, &sse);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && if (xd->jcp_param.use_jnt_comp_avg)
xd->jcp_param.bck_offset != -1)
thismse = thismse =
vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address, vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
src_stride, &sse, second_pred, &xd->jcp_param); src_stride, &sse, second_pred, &xd->jcp_param);
...@@ -875,7 +874,7 @@ int av1_find_best_sub_pixel_tree( ...@@ -875,7 +874,7 @@ int av1_find_best_sub_pixel_tree(
mask_stride, invert_mask, &sse); mask_stride, invert_mask, &sse);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
thismse = thismse =
vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address, vfp->jsvaf(pre_address, y_stride, sp(tc), sp(tr), src_address,
src_stride, &sse, second_pred, &xd->jcp_param); src_stride, &sse, second_pred, &xd->jcp_param);
...@@ -1458,7 +1457,7 @@ int av1_get_mvpred_av_var(const MACROBLOCK *x, const MV *best_mv, ...@@ -1458,7 +1457,7 @@ int av1_get_mvpred_av_var(const MACROBLOCK *x, const MV *best_mv,
unsigned int unused; unsigned int unused;
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
return vfp->jsvaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0, return vfp->jsvaf(get_buf_from_mv(in_what, best_mv), in_what->stride, 0, 0,
what->buf, what->stride, &unused, second_pred, what->buf, what->stride, &unused, second_pred,
&xd->jcp_param) + &xd->jcp_param) +
...@@ -2482,7 +2481,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range, ...@@ -2482,7 +2481,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit); mvsad_err_cost(x, best_mv, &fcenter_mv, error_per_bit);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
best_sad = fn_ptr->jsdaf(what->buf, what->stride, best_sad = fn_ptr->jsdaf(what->buf, what->stride,
get_buf_from_mv(in_what, best_mv), get_buf_from_mv(in_what, best_mv),
in_what->stride, second_pred, &xd->jcp_param) + in_what->stride, second_pred, &xd->jcp_param) +
...@@ -2510,7 +2509,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range, ...@@ -2510,7 +2509,7 @@ int av1_refining_search_8p_c(MACROBLOCK *x, int error_per_bit, int search_range,
second_pred, mask, mask_stride, invert_mask); second_pred, mask, mask_stride, invert_mask);
} else { } else {
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
if (xd->jcp_param.fwd_offset != -1 && xd->jcp_param.bck_offset != -1) if (xd->jcp_param.use_jnt_comp_avg)
sad = fn_ptr->jsdaf(what->buf, what->stride, sad = fn_ptr->jsdaf(what->buf, what->stride,
get_buf_from_mv(in_what, &mv), in_what->stride, get_buf_from_mv(in_what, &mv), in_what->stride,
second_pred, &xd->jcp_param); second_pred, &xd->jcp_param);
......
...@@ -5894,7 +5894,8 @@ static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x, ...@@ -5894,7 +5894,8 @@ static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
const int order_idx = id != 0; const int order_idx = id != 0;
av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset, av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
&xd->jcp_param.bck_offset, 1); &xd->jcp_param.bck_offset,
&xd->jcp_param.use_jnt_comp_avg, 1);
#endif // CONFIG_JNT_COMP #endif // CONFIG_JNT_COMP
// Do compound motion search on the current reference frame. // Do compound motion search on the current reference frame.
...@@ -6537,7 +6538,8 @@ static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x, ...@@ -6537,7 +6538,8 @@ static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x,
#if CONFIG_JNT_COMP #if CONFIG_JNT_COMP
av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset, av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
&xd->jcp_param.bck_offset, 1); &xd->jcp_param.bck_offset,
&xd->jcp_param.use_jnt_comp_avg, 1);
#endif // CONFIG_JNT_COMP #endif // CONFIG_JNT_COMP
if (scaled_ref_frame) { if (scaled_ref_frame) {
......
...@@ -140,10 +140,8 @@ void AV1Convolve2DTest::RunCheckOutput2(convolve_2d_func test_impl) { ...@@ -140,10 +140,8 @@ void AV1Convolve2DTest::RunCheckOutput2(convolve_2d_func test_impl) {
get_conv_params_no_round(0, do_average, 0, output2, MAX_SB_SIZE); get_conv_params_no_round(0, do_average, 0, output2, MAX_SB_SIZE);
// Test special case where fwd and bck offsets are -1 // Test special case where fwd and bck offsets are -1
conv_params1.fwd_offset = -1; conv_params1.use_jnt_comp_avg = 0;
conv_params1.bck_offset = -1; conv_params2.use_jnt_comp_avg = 0;
conv_params2.fwd_offset = -1;
conv_params2.bck_offset = -1;