Commit 39cf8061 authored by David Barker's avatar David Barker Committed by Cheng Chen

[jnt-comp, normative] Avoid double-rounding in prediction

As per the linked bug report, the distance-weighted compound
prediction has two separate round operations, first by 3
bits (inside the various convolve functions), then by 10 bits
(after the convolution functions).

We can improve on this by right shifting by 3 bits inside the
convolve functions - this is equivalent to doing a single round
by 13 bits at the end.

Note: In the encoder, when doing joint_motion_search(), we do
things a bit differently: So that we can try modifying the two
"sides" of the prediction independently, we predict each side as
if it were a single prediction (including rounding), then blend
these single predictions together.

This is already an approximation to the "real" prediction, even
in the non-jnt-comp case. So we leave that code path as-is.

BUG=aomedia:1289

Change-Id: I9ad1fbcb3e12db2b5fc3c82b407f0fd9e6b39750
parent 0cf864fd
......@@ -702,9 +702,7 @@ void av1_jnt_convolve_2d_c(const uint8_t *src, int src_stride, uint8_t *dst0,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......@@ -742,9 +740,7 @@ void av1_jnt_convolve_2d_copy_c(const uint8_t *src, int src_stride,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......@@ -818,9 +814,7 @@ void av1_convolve_2d_scale_c(const uint8_t *src, int src_stride,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......@@ -985,9 +979,7 @@ void av1_highbd_jnt_convolve_2d_c(const uint16_t *src, int src_stride,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......@@ -1060,9 +1052,7 @@ void av1_highbd_convolve_2d_scale_c(const uint16_t *src, int src_stride,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......
......@@ -516,7 +516,7 @@ void av1_highbd_warp_affine_c(const int32_t *mat, const uint16_t *ref,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
*p >>= (DIST_PRECISION_BITS - 1);
} else {
*p = sum * conv_params->fwd_offset;
}
......@@ -820,7 +820,7 @@ void av1_warp_affine_c(const int32_t *mat, const uint8_t *ref, int width,
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
*p += sum * conv_params->bck_offset;
*p = ROUND_POWER_OF_TWO(*p, DIST_PRECISION_BITS - 1);
*p >>= (DIST_PRECISION_BITS - 1);
} else {
*p = sum * conv_params->fwd_offset;
}
......
......@@ -263,7 +263,6 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
#if CONFIG_JNT_COMP
const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
#endif // CONFIG_JNT_COMP
int y_qn = subpel_y_qn;
......@@ -315,11 +314,10 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
_mm_mullo_epi32(subbed, bck_offset)),
jnt_round),
DIST_PRECISION_BITS - 1);
result =
_mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
_mm_mullo_epi32(subbed, bck_offset)),
DIST_PRECISION_BITS - 1);
} else {
result = _mm_mullo_epi32(subbed, fwd_offset);
}
......@@ -347,8 +345,7 @@ static void vfilter(const int32_t *src, int src_stride, int32_t *dst,
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......@@ -385,7 +382,6 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
#if CONFIG_JNT_COMP
const __m128i fwd_offset = _mm_set1_epi32(conv_params->fwd_offset);
const __m128i bck_offset = _mm_set1_epi32(conv_params->bck_offset);
const __m128i jnt_round = _mm_set1_epi32(1 << (DIST_PRECISION_BITS - 2));
#endif // CONFIG_JNT_COMP
int y_qn = subpel_y_qn;
......@@ -434,11 +430,10 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
__m128i result;
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
result = _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
_mm_mullo_epi32(subbed, bck_offset)),
jnt_round),
DIST_PRECISION_BITS - 1);
result =
_mm_srai_epi32(_mm_add_epi32(_mm_loadu_si128((__m128i *)dst_x),
_mm_mullo_epi32(subbed, bck_offset)),
DIST_PRECISION_BITS - 1);
} else {
result = _mm_mullo_epi32(subbed, fwd_offset);
}
......@@ -466,8 +461,7 @@ static void vfilter8(const int32_t *src, int src_stride, int32_t *dst,
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
dst[y * dst_stride + x] >>= (DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
}
......
......@@ -676,8 +676,6 @@ void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
if (!(w % 16)) {
for (i = 0; i < h; ++i) {
......@@ -697,26 +695,22 @@ void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_2, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
d32_2 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_2 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_3, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
d32_3 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_3 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
......@@ -763,14 +757,12 @@ void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_1 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
......@@ -806,8 +798,7 @@ void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
......@@ -838,8 +829,7 @@ void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
d32_0 = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
......
......@@ -47,9 +47,6 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set_epi32(w0, w0, w0, w0);
const __m128i wt1 = _mm_set_epi32(w1, w1, w1, w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set_epi32(jnt_round_const, jnt_round_const,
jnt_round_const, jnt_round_const);
/* Horizontal filter */
{
......@@ -207,18 +204,14 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
if (do_average) {
_mm_storeu_si128(
p + 0, _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 0),
_mm_mullo_epi32(
res_lo_round, wt1)),
jnt_r),
_mm_add_epi32(_mm_loadu_si128(p + 0),
_mm_mullo_epi32(res_lo_round, wt1)),
DIST_PRECISION_BITS - 1));
_mm_storeu_si128(
p + 1, _mm_srai_epi32(
_mm_add_epi32(_mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(
res_hi_round, wt1)),
jnt_r),
_mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi_round, wt1)),
DIST_PRECISION_BITS - 1));
} else {
_mm_storeu_si128(p + 0, _mm_mullo_epi32(res_lo_round, wt0));
......
......@@ -39,8 +39,6 @@ void av1_highbd_jnt_convolve_2d_sse4_1(
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
// Check that, even with 12-bit input, the intermediate values will fit
// into an unsigned 15-bit intermediate array.
......@@ -202,12 +200,10 @@ void av1_highbd_jnt_convolve_2d_sse4_1(
_mm_loadu_si128(p + 0), _mm_mullo_epi32(res_lo_round, wt1));
const __m128i jnt_sum_hi = _mm_add_epi32(
_mm_loadu_si128(p + 1), _mm_mullo_epi32(res_hi_round, wt1));
const __m128i jnt_round_res_lo = _mm_add_epi32(jnt_sum_lo, jnt_r);
const __m128i jnt_round_res_hi = _mm_add_epi32(jnt_sum_hi, jnt_r);
const __m128i final_lo =
_mm_srai_epi32(jnt_round_res_lo, DIST_PRECISION_BITS - 1);
_mm_srai_epi32(jnt_sum_lo, DIST_PRECISION_BITS - 1);
const __m128i final_hi =
_mm_srai_epi32(jnt_round_res_hi, DIST_PRECISION_BITS - 1);
_mm_srai_epi32(jnt_sum_hi, DIST_PRECISION_BITS - 1);
_mm_storeu_si128(p + 0, final_lo);
_mm_storeu_si128(p + 1, final_hi);
......
......@@ -42,8 +42,6 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
#endif // CONFIG_JNT_COMP
/* Note: For this code to work, the left/right frame borders need to be
......@@ -320,8 +318,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_lo = _mm_mullo_epi32(res_lo, wt0);
}
......@@ -345,8 +342,7 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref,
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_hi = _mm_mullo_epi32(res_hi, wt0);
}
......
......@@ -39,8 +39,6 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
#endif // CONFIG_JNT_COMP
/* Note: For this code to work, the left/right frame borders need to be
......@@ -317,8 +315,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p),
_mm_mullo_epi32(res_lo, wt1));
const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
res_lo = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
res_lo = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_lo = _mm_mullo_epi32(res_lo, wt0);
}
......@@ -340,8 +337,7 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width,
if (comp_avg) {
const __m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_mullo_epi32(res_hi, wt1));
const __m128i sum_round = _mm_add_epi32(sum, jnt_r);
res_hi = _mm_srai_epi32(sum_round, DIST_PRECISION_BITS - 1);
res_hi = _mm_srai_epi32(sum, DIST_PRECISION_BITS - 1);
} else {
res_hi = _mm_mullo_epi32(res_hi, wt0);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment