Commit 3f2b57d8 authored by Cheng Chen

Optimize av1_jnt_convolve_2d_copy function

By applying the left shift after the multiplication, the convolve copy
no longer needs a 32-bit multiply: the 8-bit source values and the small
distance weights produce a product that fits in 16 bits. Thus we can
implement it with sse2 instead of sse4_1.

Change-Id: I63e8ba414383a24f820bad4a6c607f222ec40ec2
parent 9ad440f5
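
The change relies on a simple identity: shifting after the multiplication
gives the same result as shifting before it, and because an 8-bit source
sample times a small distance weight fits in 16 bits, the SSE2
_mm_mullo_epi16 multiply is sufficient. A minimal standalone C sketch of
that reasoning follows; the concrete bits and weight values are
illustrative assumptions, not taken from the commit.

/* Scalar sketch: (src << bits) * w == (src * w) << bits, and src * w
 * fits in 16 bits for 8-bit pixels and small jnt weights, so a 16-bit
 * multiply followed by a 32-bit shift is enough. */
#include <assert.h>
#include <stdint.h>

int main(void) {
  const int bits = 4;   /* example value of FILTER_BITS * 2 - round_0 - round_1 */
  const int weight = 9; /* example fwd/bck offset; 255 * weight stays below 1 << 16 */
  for (int src = 0; src < 256; ++src) {
    const uint32_t shift_then_mul = ((uint32_t)src << bits) * (uint32_t)weight;
    const uint32_t mul_then_shift = ((uint32_t)(src * weight)) << bits;
    assert(src * weight < (1 << 16));         /* product fits a 16-bit lane */
    assert(shift_then_mul == mul_then_shift); /* shift/multiply order is interchangeable */
  }
  return 0;
}
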
@@ -602,7 +602,7 @@ if (aom_config("CONFIG_CONVOLVE_ROUND") eq "yes") {
if (aom_config("CONFIG_COMPOUND_ROUND") ne "yes") {
add_proto qw/void av1_jnt_convolve_2d_copy/, "const uint8_t *src, int src_stride, CONV_BUF_TYPE *dst, int dst_stride, int w, int h, InterpFilterParams *filter_params_x, InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, ConvolveParams *conv_params";
specialize qw/av1_jnt_convolve_2d_copy sse4_1/;
specialize qw/av1_jnt_convolve_2d_copy sse2/;
}
}
......
@@ -684,7 +684,7 @@ void av1_convolve_2d_copy_c(const uint8_t *src, int src_stride,
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
CONV_BUF_TYPE res = (1 << bits) * src[y * src_stride + x];
CONV_BUF_TYPE res = src[y * src_stride + x] << bits;
if (conv_params->do_average)
dst[y * dst_stride + x] += res;
else
@@ -776,12 +776,14 @@ void av1_jnt_convolve_2d_copy_c(const uint8_t *src, int src_stride,
CONV_BUF_TYPE res = (1 << bits) * src[y * src_stride + x];
if (conv_params->use_jnt_comp_avg) {
if (conv_params->do_average) {
dst[y * dst_stride + x] += res * conv_params->bck_offset;
dst[y * dst_stride + x] +=
(src[y * src_stride + x] * conv_params->bck_offset) << bits;
dst[y * dst_stride + x] = ROUND_POWER_OF_TWO(dst[y * dst_stride + x],
DIST_PRECISION_BITS - 1);
} else {
dst[y * dst_stride + x] = res * conv_params->fwd_offset;
dst[y * dst_stride + x] =
(src[y * src_stride + x] * conv_params->fwd_offset) << bits;
}
} else {
if (conv_params->do_average)
......
@@ -385,16 +385,17 @@ void av1_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
InterpFilterParams *filter_params_y,
const int subpel_x_q4, const int subpel_y_q4,
ConvolveParams *conv_params) {
(void)filter_params_x;
(void)filter_params_y;
(void)subpel_x_q4;
(void)subpel_y_q4;
const int bits =
FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
const int do_average = conv_params->do_average;
const __m128i zero = _mm_setzero_si128();
const __m128i left_shift = _mm_cvtsi32_si128(bits);
int i, j;
(void)filter_params_x;
(void)filter_params_y;
(void)subpel_x_q4;
(void)subpel_y_q4;
if (!(w % 16)) {
for (i = 0; i < h; ++i) {
@@ -489,4 +490,212 @@ void av1_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
}
}
}
#endif
#if CONFIG_JNT_COMP
void av1_jnt_convolve_2d_copy_sse2(const uint8_t *src, int src_stride,
CONV_BUF_TYPE *dst, int dst_stride, int w,
int h, InterpFilterParams *filter_params_x,
InterpFilterParams *filter_params_y,
const int subpel_x_q4, const int subpel_y_q4,
ConvolveParams *conv_params) {
(void)filter_params_x;
(void)filter_params_y;
(void)subpel_x_q4;
(void)subpel_y_q4;
const int bits =
FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
const int do_average = conv_params->do_average;
const __m128i zero = _mm_setzero_si128();
const __m128i left_shift = _mm_cvtsi32_si128(bits);
int i, j;
const int w0 = conv_params->fwd_offset;
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
if (!(w % 16)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 16) {
const __m128i d8 = _mm_loadu_si128((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
const __m128i d16_1 = _mm_unpackhi_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i d32_1 = _mm_unpackhi_epi16(d16_0, zero);
__m128i d32_2 = _mm_unpacklo_epi16(d16_1, zero);
__m128i d32_3 = _mm_unpackhi_epi16(d16_1, zero);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_2, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
d32_2 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_3, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
d32_3 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
d32_2 = _mm_sll_epi32(_mm_mullo_epi16(d32_2, wt0), left_shift);
d32_3 = _mm_sll_epi32(_mm_mullo_epi16(d32_3, wt0), left_shift);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
_mm_sll_epi32(d32_0, left_shift));
d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_sll_epi32(d32_1, left_shift));
d32_2 = _mm_add_epi32(_mm_loadu_si128(p + 2),
_mm_sll_epi32(d32_2, left_shift));
d32_3 = _mm_add_epi32(_mm_loadu_si128(p + 3),
_mm_sll_epi32(d32_3, left_shift));
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
d32_2 = _mm_sll_epi32(d32_2, left_shift);
d32_3 = _mm_sll_epi32(d32_3, left_shift);
}
}
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
_mm_storeu_si128(p + 2, d32_2);
_mm_storeu_si128(p + 3, d32_3);
}
src += src_stride;
dst += dst_stride;
}
} else if (!(w % 8)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 8) {
const __m128i d8 = _mm_loadl_epi64((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i d32_1 = _mm_unpackhi_epi16(d16_0, zero);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
mul = _mm_mullo_epi16(d32_1, wt1);
weighted_res = _mm_sll_epi32(mul, left_shift);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
d32_1 = _mm_sll_epi32(_mm_mullo_epi16(d32_1, wt0), left_shift);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
_mm_sll_epi32(d32_0, left_shift));
d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1),
_mm_sll_epi32(d32_1, left_shift));
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
}
}
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
}
src += src_stride;
dst += dst_stride;
}
} else if (!(w % 4)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 4) {
const __m128i d8 = _mm_loadl_epi64((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0),
_mm_sll_epi32(d32_0, left_shift));
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
}
_mm_storeu_si128(p, d32_0);
}
src += src_stride;
dst += dst_stride;
}
} else {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 2) {
const __m128i d8 = _mm_cvtsi32_si128(*(const int *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i mul = _mm_mullo_epi16(d32_0, wt1);
__m128i weighted_res = _mm_sll_epi32(mul, left_shift);
__m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_sll_epi32(_mm_mullo_epi16(d32_0, wt0), left_shift);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadl_epi64(p),
_mm_sll_epi32(d32_0, left_shift));
} else {
d32_0 = _mm_sll_epi32(d32_0, left_shift);
}
}
_mm_storel_epi64(p, d32_0);
}
src += src_stride;
dst += dst_stride;
}
}
}
#endif // CONFIG_JNT_COMP
#endif // CONFIG_COMPOUND_ROUND
@@ -450,194 +450,5 @@ void av1_jnt_convolve_2d_sse4_1(const uint8_t *src, int src_stride,
}
}
}
void av1_jnt_convolve_2d_copy_sse4_1(const uint8_t *src, int src_stride,
CONV_BUF_TYPE *dst, int dst_stride, int w,
int h, InterpFilterParams *filter_params_x,
InterpFilterParams *filter_params_y,
const int subpel_x_q4,
const int subpel_y_q4,
ConvolveParams *conv_params) {
const int bits =
FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
const int do_average = conv_params->do_average;
const __m128i zero = _mm_setzero_si128();
const __m128i left_shift = _mm_cvtsi32_si128(bits);
int i, j;
(void)filter_params_x;
(void)filter_params_y;
(void)subpel_x_q4;
(void)subpel_y_q4;
const int w0 = conv_params->fwd_offset;
const int w1 = conv_params->bck_offset;
const __m128i wt0 = _mm_set1_epi32(w0);
const __m128i wt1 = _mm_set1_epi32(w1);
const int jnt_round_const = 1 << (DIST_PRECISION_BITS - 2);
const __m128i jnt_r = _mm_set1_epi32(jnt_round_const);
if (!(w % 16)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 16) {
const __m128i d8 = _mm_loadu_si128((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
const __m128i d16_1 = _mm_unpackhi_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i d32_1 = _mm_unpackhi_epi16(d16_0, zero);
__m128i d32_2 = _mm_unpacklo_epi16(d16_1, zero);
__m128i d32_3 = _mm_unpackhi_epi16(d16_1, zero);
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
d32_2 = _mm_sll_epi32(d32_2, left_shift);
d32_3 = _mm_sll_epi32(d32_3, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i weighted_res = _mm_mullo_epi32(d32_0, wt1);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
weighted_res = _mm_mullo_epi32(d32_1, wt1);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
weighted_res = _mm_mullo_epi32(d32_2, wt1);
sum = _mm_add_epi32(_mm_loadu_si128(p + 2), weighted_res);
d32_2 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
weighted_res = _mm_mullo_epi32(d32_3, wt1);
sum = _mm_add_epi32(_mm_loadu_si128(p + 3), weighted_res);
d32_3 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_mullo_epi32(d32_0, wt0);
d32_1 = _mm_mullo_epi32(d32_1, wt0);
d32_2 = _mm_mullo_epi32(d32_2, wt0);
d32_3 = _mm_mullo_epi32(d32_3, wt0);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0);
d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1);
d32_2 = _mm_add_epi32(_mm_loadu_si128(p + 2), d32_2);
d32_3 = _mm_add_epi32(_mm_loadu_si128(p + 3), d32_3);
}
}
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
_mm_storeu_si128(p + 2, d32_2);
_mm_storeu_si128(p + 3, d32_3);
}
src += src_stride;
dst += dst_stride;
}
} else if (!(w % 8)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 8) {
const __m128i d8 = _mm_loadl_epi64((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
__m128i d32_1 = _mm_unpackhi_epi16(d16_0, zero);
d32_0 = _mm_sll_epi32(d32_0, left_shift);
d32_1 = _mm_sll_epi32(d32_1, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i weighted_res = _mm_mullo_epi32(d32_0, wt1);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
weighted_res = _mm_mullo_epi32(d32_1, wt1);
sum = _mm_add_epi32(_mm_loadu_si128(p + 1), weighted_res);
d32_1 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_mullo_epi32(d32_0, wt0);
d32_1 = _mm_mullo_epi32(d32_1, wt0);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0);
d32_1 = _mm_add_epi32(_mm_loadu_si128(p + 1), d32_1);
}
}
_mm_storeu_si128(p + 0, d32_0);
_mm_storeu_si128(p + 1, d32_1);
}
src += src_stride;
dst += dst_stride;
}
} else if (!(w % 4)) {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 4) {
const __m128i d8 = _mm_loadl_epi64((__m128i *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
d32_0 = _mm_sll_epi32(d32_0, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i weighted_res = _mm_mullo_epi32(d32_0, wt1);
__m128i sum = _mm_add_epi32(_mm_loadu_si128(p + 0), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_mullo_epi32(d32_0, wt0);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadu_si128(p + 0), d32_0);
}
}
_mm_storeu_si128(p, d32_0);
}
src += src_stride;
dst += dst_stride;
}
} else {
for (i = 0; i < h; ++i) {
for (j = 0; j < w; j += 2) {
const __m128i d8 = _mm_cvtsi32_si128(*(const int *)&src[j]);
const __m128i d16_0 = _mm_unpacklo_epi8(d8, zero);
__m128i d32_0 = _mm_unpacklo_epi16(d16_0, zero);
d32_0 = _mm_sll_epi32(d32_0, left_shift);
__m128i *const p = (__m128i *)&dst[j];
if (conv_params->use_jnt_comp_avg) {
if (do_average) {
__m128i weighted_res = _mm_mullo_epi32(d32_0, wt1);
__m128i sum = _mm_add_epi32(_mm_loadl_epi64(p), weighted_res);
d32_0 = _mm_srai_epi32(_mm_add_epi32(sum, jnt_r),
DIST_PRECISION_BITS - 1);
} else {
d32_0 = _mm_mullo_epi32(d32_0, wt0);
}
} else {
if (do_average) {
d32_0 = _mm_add_epi32(_mm_loadl_epi64(p), d32_0);
}
}
_mm_storel_epi64(p, d32_0);
}
src += src_stride;
dst += dst_stride;
}
}
}
#endif // CONFIG_COMPOUND_ROUND
#endif // CONFIG_JNT_COMP