Commit 3c74dd45 authored by Peng Bin's avatar Peng Bin Committed by Zoe Liu

Add aom_comp_mask_pred_avx2

1. Add AVX2 implementation of aom_comp_mask_pred.
2. For width 8, still use the SSSE3 version.
3. For other widths (16, 32), the AVX2 version is 1.2x-2.0x faster
than the SSSE3 version.

Change-Id: I80acc1be54ab21a52f7847e91b1299853add757c
parent b80466f6
......@@ -392,6 +392,7 @@ if (CONFIG_AV1_ENCODER)
set(AOM_DSP_ENCODER_INTRIN_SSSE3
${AOM_DSP_ENCODER_INTRIN_SSSE3}
"${AOM_ROOT}/aom_dsp/x86/masked_sad_intrin_ssse3.c"
"${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.h"
"${AOM_ROOT}/aom_dsp/x86/masked_variance_intrin_ssse3.c")
set(AOM_DSP_ENCODER_INTRIN_SSE2
......
......@@ -1784,7 +1784,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
add_proto qw/void aom_comp_mask_pred/, "uint8_t *comp_pred, const uint8_t *pred, int width, int height, const uint8_t *ref, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
specialize qw/aom_comp_mask_pred ssse3/;
specialize qw/aom_comp_mask_pred ssse3 avx2/;
add_proto qw/void aom_highbd_comp_mask_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask";
add_proto qw/void aom_highbd_comp_mask_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, int subsample_x_q3, int subsample_y_q3, const uint8_t *ref8, int ref_stride, const uint8_t *mask, int mask_stride, int invert_mask, int bd";
......
......@@ -15,11 +15,12 @@
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_ports/mem.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/blend.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"
// For width a multiple of 16
static void bilinear_filter(const uint8_t *src, int src_stride, int xoffset,
......@@ -1040,32 +1041,6 @@ static void highbd_masked_variance4xh(const uint16_t *src_ptr, int src_stride,
*sse = _mm_cvtsi128_si32(_mm_srli_si128(sum, 4));
}
// Blends one 16-pixel row of src0/src1 under `mask`:
//   dst[j] = round((mask[j] * src0[j] +
//                   (AOM_BLEND_A64_MAX_ALPHA - mask[j]) * src1[j])
//                  >> AOM_BLEND_A64_ROUND_BITS)
static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
                                           const uint8_t *src1,
                                           const uint8_t *mask, uint8_t *dst) {
  const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
  // _mm_mulhrs_epi16(x, 1 << (15 - r)) == (x + (1 << (r - 1))) >> r, so
  // multiplying by this constant performs the rounded shift by ROUND_BITS.
  const __m128i round_offset =
      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  const __m128i sA0 = _mm_lddqu_si128((const __m128i *)(src0));  // unaligned
  const __m128i sA1 = _mm_lddqu_si128((const __m128i *)(src1));  // unaligned
  // Aligned load: assumes each mask row is 16-byte aligned — TODO confirm
  // against the callers' mask buffers.
  const __m128i aA = _mm_load_si128((const __m128i *)(mask));
  const __m128i maA = _mm_sub_epi8(alpha_max, aA);  // complementary weight
  // Interleave pixels with pixels and weights with weights so that maddubs
  // computes src0 * a + src1 * (alpha_max - a) in each 16-bit lane.
  const __m128i ssAL = _mm_unpacklo_epi8(sA0, sA1);
  const __m128i aaAL = _mm_unpacklo_epi8(aA, maA);
  const __m128i ssAH = _mm_unpackhi_epi8(sA0, sA1);
  const __m128i aaAH = _mm_unpackhi_epi8(aA, maA);
  const __m128i blendAL = _mm_maddubs_epi16(ssAL, aaAL);
  const __m128i blendAH = _mm_maddubs_epi16(ssAH, aaAH);
  const __m128i roundAL = _mm_mulhrs_epi16(blendAL, round_offset);
  const __m128i roundAH = _mm_mulhrs_epi16(blendAH, round_offset);
  // Saturating pack back to 8-bit and store all 16 results at once.
  _mm_store_si128((__m128i *)dst, _mm_packus_epi16(roundAL, roundAH));
}
void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
int width, int height, const uint8_t *ref,
int ref_stride, const uint8_t *mask,
......@@ -1074,46 +1049,11 @@ void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
const uint8_t *src1 = invert_mask ? ref : pred;
const int stride0 = invert_mask ? width : ref_stride;
const int stride1 = invert_mask ? ref_stride : width;
const __m128i alpha_max = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
const __m128i round_offset =
_mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
assert(height % 2 == 0);
assert(width % 8 == 0);
int i = 0;
if (width == 8) {
do {
// odd line A
const __m128i sA0 = _mm_loadl_epi64((const __m128i *)(src0));
const __m128i sA1 = _mm_loadl_epi64((const __m128i *)(src1));
const __m128i aA = _mm_loadl_epi64((const __m128i *)(mask));
// even line B
const __m128i sB0 = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
const __m128i sB1 = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
const __m128i a = _mm_castps_si128(_mm_loadh_pi(
_mm_castsi128_ps(aA), (const __m64 *)(mask + mask_stride)));
const __m128i ssA = _mm_unpacklo_epi8(sA0, sA1);
const __m128i ssB = _mm_unpacklo_epi8(sB0, sB1);
const __m128i ma = _mm_sub_epi8(alpha_max, a);
const __m128i aaA = _mm_unpacklo_epi8(a, ma);
const __m128i aaB = _mm_unpackhi_epi8(a, ma);
const __m128i blendA = _mm_maddubs_epi16(ssA, aaA);
const __m128i blendB = _mm_maddubs_epi16(ssB, aaB);
const __m128i roundA = _mm_mulhrs_epi16(blendA, round_offset);
const __m128i roundB = _mm_mulhrs_epi16(blendB, round_offset);
const __m128i round = _mm_packus_epi16(roundA, roundB);
// comp_pred's stride == width == 8
_mm_store_si128((__m128i *)(comp_pred), round);
comp_pred += (width << 1);
src0 += (stride0 << 1);
src1 += (stride1 << 1);
mask += (mask_stride << 1);
i += 2;
} while (i < height);
comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
mask, mask_stride);
} else if (width == 16) {
do {
comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
......@@ -1126,6 +1066,7 @@ void aom_comp_mask_pred_ssse3(uint8_t *comp_pred, const uint8_t *pred,
i += 2;
} while (i < height);
} else { // width == 32
assert(width == 32);
do {
comp_mask_pred_16_ssse3(src0, src1, mask, comp_pred);
comp_mask_pred_16_ssse3(src0 + 16, src1 + 16, mask + 16, comp_pred + 16);
......
/*
* Copyright (c) 2018, Alliance for Open Media. All rights reserved
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#ifndef _AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H
#define _AOM_DSP_X86_MASKED_VARIANCE_INTRIN_SSSE3_H
#include <stdlib.h>
#include <string.h>
#include <tmmintrin.h>
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
#include "aom_dsp/blend.h"
// Blends one 16-pixel row: dst = AOM_BLEND_A64(mask, src0, src1), i.e.
// dst[j] = rounded (mask[j]*src0[j] + (MAX_ALPHA - mask[j])*src1[j]) >> bits.
static INLINE void comp_mask_pred_16_ssse3(const uint8_t *src0,
                                           const uint8_t *src1,
                                           const uint8_t *mask, uint8_t *dst) {
  const __m128i max_alpha = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
  // Multiplying by 1 << (15 - r) with mulhrs yields a rounded shift by r.
  const __m128i mulhrs_scale =
      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  const __m128i s0 = _mm_lddqu_si128((const __m128i *)src0);
  const __m128i s1 = _mm_lddqu_si128((const __m128i *)src1);
  const __m128i m = _mm_load_si128((const __m128i *)mask);
  const __m128i m_inv = _mm_sub_epi8(max_alpha, m);
  // Pair each src0 pixel with its src1 pixel, and each weight with its
  // complement, so maddubs produces s0*m + s1*(max_alpha - m) per lane.
  const __m128i s_lo = _mm_unpacklo_epi8(s0, s1);
  const __m128i s_hi = _mm_unpackhi_epi8(s0, s1);
  const __m128i w_lo = _mm_unpacklo_epi8(m, m_inv);
  const __m128i w_hi = _mm_unpackhi_epi8(m, m_inv);
  const __m128i sum_lo = _mm_maddubs_epi16(s_lo, w_lo);
  const __m128i sum_hi = _mm_maddubs_epi16(s_hi, w_hi);
  const __m128i res_lo = _mm_mulhrs_epi16(sum_lo, mulhrs_scale);
  const __m128i res_hi = _mm_mulhrs_epi16(sum_hi, mulhrs_scale);
  // Pack back to unsigned 8-bit with saturation and store the full row.
  _mm_store_si128((__m128i *)dst, _mm_packus_epi16(res_lo, res_hi));
}
// Blends an 8-pixel-wide block two rows at a time:
//   comp_pred[j] = AOM_BLEND_A64(mask[j], src0[j], src1[j])
// comp_pred is contiguous (its stride equals the width, 8), so two output
// rows fill exactly one 128-bit store.
static INLINE void comp_mask_pred_8_ssse3(uint8_t *comp_pred, int height,
                                          const uint8_t *src0, int stride0,
                                          const uint8_t *src1, int stride1,
                                          const uint8_t *mask,
                                          int mask_stride) {
  int row;
  const __m128i max_alpha = _mm_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
  // mulhrs by 1 << (15 - r) performs a rounded shift right by r.
  const __m128i mulhrs_scale =
      _mm_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  for (row = 0; row < height; row += 2) {
    // Load 8 pixels from each source for the row pair.
    const __m128i r0a = _mm_loadl_epi64((const __m128i *)src0);
    const __m128i r1a = _mm_loadl_epi64((const __m128i *)src1);
    const __m128i r0b = _mm_loadl_epi64((const __m128i *)(src0 + stride0));
    const __m128i r1b = _mm_loadl_epi64((const __m128i *)(src1 + stride1));
    // Combine both mask rows into one register: low half = first row,
    // high half = second row.
    const __m128i m_first = _mm_loadl_epi64((const __m128i *)mask);
    const __m128i m = _mm_castps_si128(_mm_loadh_pi(
        _mm_castsi128_ps(m_first), (const __m64 *)(mask + mask_stride)));
    const __m128i m_inv = _mm_sub_epi8(max_alpha, m);
    // Interleave pixels/weights so maddubs gives s0*m + s1*(max - m).
    const __m128i s_a = _mm_unpacklo_epi8(r0a, r1a);
    const __m128i s_b = _mm_unpacklo_epi8(r0b, r1b);
    const __m128i w_a = _mm_unpacklo_epi8(m, m_inv);
    const __m128i w_b = _mm_unpackhi_epi8(m, m_inv);
    const __m128i sum_a = _mm_maddubs_epi16(s_a, w_a);
    const __m128i sum_b = _mm_maddubs_epi16(s_b, w_b);
    const __m128i res_a = _mm_mulhrs_epi16(sum_a, mulhrs_scale);
    const __m128i res_b = _mm_mulhrs_epi16(sum_b, mulhrs_scale);
    // One 16-byte store covers both 8-pixel output rows.
    _mm_store_si128((__m128i *)comp_pred, _mm_packus_epi16(res_a, res_b));
    comp_pred += 8 * 2;
    src0 += stride0 * 2;
    src1 += stride1 * 2;
    mask += mask_stride * 2;
  }
}
#endif
......@@ -11,6 +11,7 @@
#include <assert.h>
#include <immintrin.h>

#include "./aom_dsp_rtcd.h"
#include "aom_dsp/x86/masked_variance_intrin_ssse3.h"
typedef void (*get_var_avx2)(const uint8_t *src, int src_stride,
const uint8_t *ref, int ref_stride,
......@@ -190,3 +191,87 @@ unsigned int aom_sub_pixel_avg_variance32x32_avx2(
_mm256_zeroupper();
return variance;
}
// Loads two unaligned 16-byte rows into one 256-bit register:
// the low 128 bits come from p1 and the high 128 bits from p0.
static INLINE __m256i mm256_loadu2(const uint8_t *p0, const uint8_t *p1) {
  const __m128i lo = _mm_loadu_si128((const __m128i *)p1);
  const __m128i hi = _mm_loadu_si128((const __m128i *)p0);
  return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
}
// Blends 32 pixels in one shot:
//   comp_pred[j] = rounded (a[j]*s0[j] + (MAX_ALPHA - a[j])*s1[j]) >> bits
static INLINE void comp_mask_pred_line_avx2(const __m256i s0, const __m256i s1,
                                            const __m256i a,
                                            uint8_t *comp_pred) {
  const __m256i max_alpha = _mm256_set1_epi8(AOM_BLEND_A64_MAX_ALPHA);
  // mulhrs by 1 << (15 - r) implements a rounded shift right by r.
  const __m256i mulhrs_scale =
      _mm256_set1_epi16(1 << (15 - AOM_BLEND_A64_ROUND_BITS));
  const __m256i a_inv = _mm256_sub_epi8(max_alpha, a);
  // Interleave pixels with pixels and weights with their complements so
  // maddubs produces s0*a + s1*(max_alpha - a) in each 16-bit lane.
  const __m256i s_lo = _mm256_unpacklo_epi8(s0, s1);
  const __m256i s_hi = _mm256_unpackhi_epi8(s0, s1);
  const __m256i w_lo = _mm256_unpacklo_epi8(a, a_inv);
  const __m256i w_hi = _mm256_unpackhi_epi8(a, a_inv);
  const __m256i sum_lo = _mm256_maddubs_epi16(s_lo, w_lo);
  const __m256i sum_hi = _mm256_maddubs_epi16(s_hi, w_hi);
  const __m256i res_lo = _mm256_mulhrs_epi16(sum_lo, mulhrs_scale);
  const __m256i res_hi = _mm256_mulhrs_epi16(sum_hi, mulhrs_scale);
  // Saturating pack back to 8-bit and store all 32 output pixels.
  _mm256_storeu_si256((__m256i *)comp_pred,
                      _mm256_packus_epi16(res_lo, res_hi));
}
// AVX2 masked compound predictor:
//   comp_pred[j] = AOM_BLEND_A64(mask[j], ref[j], pred[j])
// When invert_mask is nonzero the roles of ref and pred are swapped.
// comp_pred and pred are contiguous (stride == width); ref uses ref_stride
// and mask uses mask_stride. Width 8 falls back to the SSSE3 helper; widths
// 16 and 32 use 256-bit blends.
void aom_comp_mask_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
                             int height, const uint8_t *ref, int ref_stride,
                             const uint8_t *mask, int mask_stride,
                             int invert_mask) {
  int i = 0;
  const uint8_t *src0 = invert_mask ? pred : ref;
  const uint8_t *src1 = invert_mask ? ref : pred;
  // pred is packed at the block width, so its stride is `width`.
  const int stride0 = invert_mask ? width : ref_stride;
  const int stride1 = invert_mask ? ref_stride : width;
  if (width == 8) {
    assert(height % 2 == 0);  // the SSSE3 helper consumes two rows per step
    comp_mask_pred_8_ssse3(comp_pred, height, src0, stride0, src1, stride1,
                           mask, mask_stride);
  } else if (width == 16) {
    // This path consumes four rows per iteration; a height that is not a
    // multiple of 4 would read and write past the end of the buffers.
    assert(height % 4 == 0);
    do {
      // Two rows per 256-bit register (low lane = earlier row).
      const __m256i sA0 = mm256_loadu2(src0 + stride0, src0);
      const __m256i sA1 = mm256_loadu2(src1 + stride1, src1);
      const __m256i aA = mm256_loadu2(mask + mask_stride, mask);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      const __m256i sB0 = mm256_loadu2(src0 + stride0, src0);
      const __m256i sB1 = mm256_loadu2(src1 + stride1, src1);
      const __m256i aB = mm256_loadu2(mask + mask_stride, mask);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      // comp_pred's stride == width == 16, so four rows span 64 bytes.
      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
      comp_pred += (16 << 2);
      i += 4;
    } while (i < height);
  } else {
    assert(width == 32);  // only 8/16/32 are specialized here
    assert(height % 2 == 0);
    do {
      // One full 32-pixel row per register, two rows per iteration.
      const __m256i sA0 = _mm256_lddqu_si256((const __m256i *)(src0));
      const __m256i sA1 = _mm256_lddqu_si256((const __m256i *)(src1));
      const __m256i aA = _mm256_lddqu_si256((const __m256i *)(mask));
      const __m256i sB0 = _mm256_lddqu_si256((const __m256i *)(src0 + stride0));
      const __m256i sB1 = _mm256_lddqu_si256((const __m256i *)(src1 + stride1));
      const __m256i aB =
          _mm256_lddqu_si256((const __m256i *)(mask + mask_stride));
      comp_mask_pred_line_avx2(sA0, sA1, aA, comp_pred);
      comp_mask_pred_line_avx2(sB0, sB1, aB, comp_pred + 32);
      comp_pred += (32 << 1);
      src0 += (stride0 << 1);
      src1 += (stride1 << 1);
      mask += (mask_stride << 1);
      i += 2;
    } while (i < height);
  }
}
......@@ -11,7 +11,6 @@
#include <cstdlib>
#include <new>
#include <vector>
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
......@@ -28,15 +27,13 @@
#include "test/util.h"
#include "third_party/googletest/src/googletest/include/gtest/gtest.h"
using std::vector;
namespace AV1CompMaskVariance {
typedef void (*comp_mask_pred_func)(uint8_t *comp_pred, const uint8_t *pred,
int width, int height, const uint8_t *ref,
int ref_stride, const uint8_t *mask,
int mask_stride, int invert_mask);
const BLOCK_SIZE valid_bsize[] = {
const BLOCK_SIZE kValidBlockSize[] = {
BLOCK_8X8, BLOCK_8X16, BLOCK_8X32, BLOCK_16X8, BLOCK_16X16,
BLOCK_16X32, BLOCK_32X8, BLOCK_32X16, BLOCK_32X32,
};
......@@ -53,12 +50,13 @@ class AV1CompMaskVarianceTest
protected:
void RunCheckOutput(comp_mask_pred_func test_impl, BLOCK_SIZE bsize, int inv);
void RunSpeedTest(comp_mask_pred_func test_impl, BLOCK_SIZE bsize);
bool CheckResult(int w, int h) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int idx = i * w + j;
bool CheckResult(int width, int height) {
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
const int idx = y * width + x;
if (comp_pred1_[idx] != comp_pred2_[idx]) {
printf("%dx%d mismatch @%d(%d,%d) ", w, h, idx, i, j);
printf("%dx%d mismatch @%d(%d,%d) ", width, height, idx, y, x);
printf("%d != %d ", comp_pred1_[idx], comp_pred2_[idx]);
return false;
}
}
......@@ -160,7 +158,14 @@ TEST_P(AV1CompMaskVarianceTest, DISABLED_Speed) {
INSTANTIATE_TEST_CASE_P(
SSSE3, AV1CompMaskVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
::testing::ValuesIn(valid_bsize)));
::testing::ValuesIn(kValidBlockSize)));
#endif
#if HAVE_AVX2
INSTANTIATE_TEST_CASE_P(
AVX2, AV1CompMaskVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
::testing::ValuesIn(kValidBlockSize)));
#endif
#ifndef aom_comp_mask_pred
......@@ -249,7 +254,15 @@ TEST_P(AV1CompMaskUpVarianceTest, DISABLED_Speed) {
INSTANTIATE_TEST_CASE_P(
SSSE3, AV1CompMaskUpVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_ssse3),
::testing::ValuesIn(valid_bsize)));
::testing::ValuesIn(kValidBlockSize)));
#endif
#if HAVE_AVX2
INSTANTIATE_TEST_CASE_P(
AVX2, AV1CompMaskUpVarianceTest,
::testing::Combine(::testing::Values(&aom_comp_mask_pred_avx2),
::testing::ValuesIn(kValidBlockSize)));
#endif
#endif // ifndef aom_comp_mask_pred
} // namespace AV1CompMaskVariance
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment