diff --git a/aom_dsp/aom_dsp.mk b/aom_dsp/aom_dsp.mk index 07fbe02277c368e0bde810b4b2fb0da97b15d5d9..a06326da746691fc005d7198babc058f3bcbc871 100644 --- a/aom_dsp/aom_dsp.mk +++ b/aom_dsp/aom_dsp.mk @@ -349,6 +349,9 @@ ifeq ($(CONFIG_MOTION_VAR),yes) DSP_SRCS-$(HAVE_SSE4_1) += x86/obmc_sad_sse4.c DSP_SRCS-$(HAVE_SSE4_1) += x86/obmc_variance_sse4.c endif #CONFIG_MOTION_VAR +ifeq ($(CONFIG_EXT_PARTITION),yes) +DSP_SRCS-$(HAVE_AVX2) += x86/sad_impl_avx2.c +endif endif #CONFIG_AV1_ENCODER DSP_SRCS-$(HAVE_SSE) += x86/sad4d_sse2.asm diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl index 7b2b5fa85ffc3738aa50f674600ebcb40d072d9b..6d230d20210abb2bb6b51720a3c37be519a397d0 100644 --- a/aom_dsp/aom_dsp_rtcd_defs.pl +++ b/aom_dsp/aom_dsp_rtcd_defs.pl @@ -1167,9 +1167,9 @@ foreach (@block_sizes) { add_proto qw/unsigned int/, "aom_sad${w}x${h}_avg", "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred"; } -specialize qw/aom_sad128x128 sse2/; -specialize qw/aom_sad128x64 sse2/; -specialize qw/aom_sad64x128 sse2/; +specialize qw/aom_sad128x128 avx2 sse2/; +specialize qw/aom_sad128x64 avx2 sse2/; +specialize qw/aom_sad64x128 avx2 sse2/; specialize qw/aom_sad64x64 avx2 neon msa sse2/; specialize qw/aom_sad64x32 avx2 msa sse2/; specialize qw/aom_sad32x64 avx2 msa sse2/; diff --git a/aom_dsp/x86/sad_impl_avx2.c b/aom_dsp/x86/sad_impl_avx2.c new file mode 100644 index 0000000000000000000000000000000000000000..9183fa666a6b5251a7a19d313d4c78d7aaf15fa1 --- /dev/null +++ b/aom_dsp/x86/sad_impl_avx2.c @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include +#include "./aom_dsp_rtcd.h" + +static unsigned int sad32x32(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + __m256i s1, s2, r1, r2; + __m256i sum = _mm256_setzero_si256(); + __m128i sum_i128; + int i; + + for (i = 0; i < 16; ++i) { + r1 = _mm256_loadu_si256((__m256i const *)ref_ptr); + r2 = _mm256_loadu_si256((__m256i const *)(ref_ptr + ref_stride)); + s1 = _mm256_sad_epu8(r1, _mm256_loadu_si256((__m256i const *)src_ptr)); + s2 = _mm256_sad_epu8( + r2, _mm256_loadu_si256((__m256i const *)(src_ptr + src_stride))); + sum = _mm256_add_epi32(sum, _mm256_add_epi32(s1, s2)); + ref_ptr += ref_stride << 1; + src_ptr += src_stride << 1; + } + + sum = _mm256_add_epi32(sum, _mm256_srli_si256(sum, 8)); + sum_i128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 1), + _mm256_castsi256_si128(sum)); + return _mm_cvtsi128_si32(sum_i128); +} + +static unsigned int sad64x32(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + unsigned int half_width = 32; + uint32_t sum = sad32x32(src_ptr, src_stride, ref_ptr, ref_stride); + src_ptr += half_width; + ref_ptr += half_width; + sum += sad32x32(src_ptr, src_stride, ref_ptr, ref_stride); + return sum; +} + +static unsigned int sad64x64(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + uint32_t sum = sad64x32(src_ptr, src_stride, ref_ptr, ref_stride); + src_ptr += src_stride << 5; + ref_ptr += ref_stride << 5; + sum += sad64x32(src_ptr, src_stride, ref_ptr, ref_stride); + return sum; +} + +unsigned int aom_sad128x64_avx2(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + unsigned int half_width = 64; + uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride); + src_ptr += half_width; + ref_ptr += half_width; + sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride); + return sum; +} + +unsigned int aom_sad64x128_avx2(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + uint32_t sum = sad64x64(src_ptr, src_stride, ref_ptr, ref_stride); + src_ptr += src_stride << 6; + ref_ptr += ref_stride << 6; + sum += sad64x64(src_ptr, src_stride, ref_ptr, ref_stride); + return sum; +} + +unsigned int aom_sad128x128_avx2(const uint8_t *src_ptr, int src_stride, + const uint8_t *ref_ptr, int ref_stride) { + uint32_t sum = aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride); + src_ptr += src_stride << 6; + ref_ptr += ref_stride << 6; + sum += aom_sad128x64_avx2(src_ptr, src_stride, ref_ptr, ref_stride); + return sum; +} diff --git a/test/sad_test.cc b/test/sad_test.cc index b7776583d6c52d46a7103d6f79d457c8adea1b78..4ccf0f4cccddafdbb90767945712eee7ad5c20f3 100644 --- a/test/sad_test.cc +++ b/test/sad_test.cc @@ -930,6 +930,11 @@ INSTANTIATE_TEST_CASE_P(SSE2, SADx4Test, ::testing::ValuesIn(x4d_sse2_tests)); #if HAVE_AVX2 const SadMxNParam avx2_tests[] = { +#if CONFIG_EXT_PARTITION + make_tuple(64, 128, &aom_sad64x128_avx2, -1), + make_tuple(128, 64, &aom_sad128x64_avx2, -1), + make_tuple(128, 128, &aom_sad128x128_avx2, -1), +#endif make_tuple(64, 64, &aom_sad64x64_avx2, -1), make_tuple(64, 32, &aom_sad64x32_avx2, -1), make_tuple(32, 64, &aom_sad32x64_avx2, -1),