Commit 4b5c2bb4 authored by Rupert Swarbrick, committed by Debargha Mukherjee
Browse files

Define missing subtract_xxx functions in highbd_subtract_sse2.c

Also, get rid of the boilerplate code by using some macros. STACK_V(h, f)
means "call f twice, stacking vertically at an offset of h rows".
STACK_H(w, f) means "call f twice, stacking horizontally at an offset of
w columns".

Note that functions like subtract_128x64 are now only defined when the
equivalent block sizes (e.g. BLOCK_128x64) are defined, i.e. when
CONFIG_EXT_PARTITION is enabled. As such, we have to fix up
subtract_test.cc so that it doesn't try to call
aom_highbd_subtract_block_sse2 with unsupported sizes.

BUG=aomedia:684

Change-Id: I5b0fefe70e4083786d11d25cdd5dcf02823bae7b
parent 5c700910
...@@ -177,177 +177,86 @@ static void subtract_8x8(int16_t *diff, ptrdiff_t diff_stride, ...@@ -177,177 +177,86 @@ static void subtract_8x8(int16_t *diff, ptrdiff_t diff_stride,
_mm_storeu_si128((__m128i *)(diff + 7 * diff_stride), x7); _mm_storeu_si128((__m128i *)(diff + 7 * diff_stride), x7);
} }
static void subtract_8x16(int16_t *diff, ptrdiff_t diff_stride, #define STACK_V(h, fun) \
const uint16_t *src, ptrdiff_t src_stride, do { \
const uint16_t *pred, ptrdiff_t pred_stride) { fun(diff, diff_stride, src, src_stride, pred, pred_stride); \
subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride); fun(diff + diff_stride * h, diff_stride, src + src_stride * h, src_stride, \
diff += diff_stride << 3; pred + pred_stride * h, pred_stride); \
src += src_stride << 3; } while (0)
pred += pred_stride << 3;
subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride); #define STACK_H(w, fun) \
} do { \
fun(diff, diff_stride, src, src_stride, pred, pred_stride); \
static void subtract_16x8(int16_t *diff, ptrdiff_t diff_stride, fun(diff + w, diff_stride, src + w, src_stride, pred + w, pred_stride); \
const uint16_t *src, ptrdiff_t src_stride, } while (0)
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride); #define SUBTRACT_FUN(size) \
diff += 8; static void subtract_##size(int16_t *diff, ptrdiff_t diff_stride, \
src += 8; const uint16_t *src, ptrdiff_t src_stride, \
pred += 8; const uint16_t *pred, ptrdiff_t pred_stride)
subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride);
} SUBTRACT_FUN(8x16) { STACK_V(8, subtract_8x8); }
SUBTRACT_FUN(16x8) { STACK_H(8, subtract_8x8); }
static void subtract_16x16(int16_t *diff, ptrdiff_t diff_stride, SUBTRACT_FUN(16x16) { STACK_V(8, subtract_16x8); }
const uint16_t *src, ptrdiff_t src_stride, SUBTRACT_FUN(16x32) { STACK_V(16, subtract_16x16); }
const uint16_t *pred, ptrdiff_t pred_stride) { SUBTRACT_FUN(32x16) { STACK_H(16, subtract_16x16); }
subtract_16x8(diff, diff_stride, src, src_stride, pred, pred_stride); SUBTRACT_FUN(32x32) { STACK_V(16, subtract_32x16); }
diff += diff_stride << 3; SUBTRACT_FUN(32x64) { STACK_V(32, subtract_32x32); }
src += src_stride << 3; SUBTRACT_FUN(64x32) { STACK_H(32, subtract_32x32); }
pred += pred_stride << 3; SUBTRACT_FUN(64x64) { STACK_V(32, subtract_64x32); }
subtract_16x8(diff, diff_stride, src, src_stride, pred, pred_stride); #if CONFIG_EXT_PARTITION
} SUBTRACT_FUN(64x128) { STACK_V(64, subtract_64x64); }
SUBTRACT_FUN(128x64) { STACK_H(64, subtract_64x64); }
static void subtract_16x32(int16_t *diff, ptrdiff_t diff_stride, SUBTRACT_FUN(128x128) { STACK_V(64, subtract_128x64); }
const uint16_t *src, ptrdiff_t src_stride, #endif
const uint16_t *pred, ptrdiff_t pred_stride) { SUBTRACT_FUN(4x16) { STACK_V(8, subtract_4x8); }
subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride); SUBTRACT_FUN(16x4) { STACK_H(8, subtract_8x4); }
diff += diff_stride << 4; SUBTRACT_FUN(8x32) { STACK_V(16, subtract_8x16); }
src += src_stride << 4; SUBTRACT_FUN(32x8) { STACK_H(16, subtract_16x8); }
pred += pred_stride << 4; SUBTRACT_FUN(16x64) { STACK_V(32, subtract_16x32); }
subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride); SUBTRACT_FUN(64x16) { STACK_H(32, subtract_32x16); }
}
static void subtract_32x16(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += 16;
src += 16;
pred += 16;
subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_32x32(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_32x16(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += diff_stride << 4;
src += src_stride << 4;
pred += pred_stride << 4;
subtract_32x16(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_32x64(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += diff_stride << 5;
src += src_stride << 5;
pred += pred_stride << 5;
subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_64x32(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += 32;
src += 32;
pred += 32;
subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_64x64(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_64x32(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += diff_stride << 5;
src += src_stride << 5;
pred += pred_stride << 5;
subtract_64x32(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_64x128(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += diff_stride << 6;
src += src_stride << 6;
pred += pred_stride << 6;
subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_128x64(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += 64;
src += 64;
pred += 64;
subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static void subtract_128x128(int16_t *diff, ptrdiff_t diff_stride,
const uint16_t *src, ptrdiff_t src_stride,
const uint16_t *pred, ptrdiff_t pred_stride) {
subtract_128x64(diff, diff_stride, src, src_stride, pred, pred_stride);
diff += diff_stride << 6;
src += src_stride << 6;
pred += pred_stride << 6;
subtract_128x64(diff, diff_stride, src, src_stride, pred, pred_stride);
}
static SubtractWxHFuncType getSubtractFunc(int rows, int cols) { static SubtractWxHFuncType getSubtractFunc(int rows, int cols) {
SubtractWxHFuncType ret_func_ptr = NULL;
if (rows == 4) { if (rows == 4) {
if (cols == 4) { if (cols == 4) return subtract_4x4;
ret_func_ptr = subtract_4x4; if (cols == 8) return subtract_8x4;
} else if (cols == 8) { if (cols == 16) return subtract_16x4;
ret_func_ptr = subtract_8x4; }
} if (rows == 8) {
} else if (rows == 8) { if (cols == 4) return subtract_4x8;
if (cols == 4) { if (cols == 8) return subtract_8x8;
ret_func_ptr = subtract_4x8; if (cols == 16) return subtract_16x8;
} else if (cols == 8) { if (cols == 32) return subtract_32x8;
ret_func_ptr = subtract_8x8; }
} else if (cols == 16) { if (rows == 16) {
ret_func_ptr = subtract_16x8; if (cols == 4) return subtract_4x16;
} if (cols == 8) return subtract_8x16;
} else if (rows == 16) { if (cols == 16) return subtract_16x16;
if (cols == 8) { if (cols == 32) return subtract_32x16;
ret_func_ptr = subtract_8x16; if (cols == 64) return subtract_64x16;
} else if (cols == 16) { }
ret_func_ptr = subtract_16x16; if (rows == 32) {
} else if (cols == 32) { if (cols == 8) return subtract_8x32;
ret_func_ptr = subtract_32x16; if (cols == 16) return subtract_16x32;
} if (cols == 32) return subtract_32x32;
} else if (rows == 32) { if (cols == 64) return subtract_64x32;
if (cols == 16) { }
ret_func_ptr = subtract_16x32; if (rows == 64) {
} else if (cols == 32) { if (cols == 16) return subtract_16x64;
ret_func_ptr = subtract_32x32; if (cols == 32) return subtract_32x64;
} else if (cols == 64) { if (cols == 64) return subtract_64x64;
ret_func_ptr = subtract_64x32; #if CONFIG_EXT_PARTITION
} if (cols == 128) return subtract_128x64;
} else if (rows == 64) { #endif // CONFIG_EXT_PARTITION
if (cols == 32) {
ret_func_ptr = subtract_32x64;
} else if (cols == 64) {
ret_func_ptr = subtract_64x64;
} else if (cols == 128) {
ret_func_ptr = subtract_128x64;
}
} else if (rows == 128) {
if (cols == 64) {
ret_func_ptr = subtract_64x128;
} else if (cols == 128) {
ret_func_ptr = subtract_128x128;
}
} }
if (!ret_func_ptr) { #if CONFIG_EXT_PARTITION
assert(0); if (rows == 128) {
if (cols == 64) return subtract_64x128;
if (cols == 128) return subtract_128x128;
} }
return ret_func_ptr; #endif // CONFIG_EXT_PARTITION
assert(0);
return NULL;
} }
void aom_highbd_subtract_block_sse2(int rows, int cols, int16_t *diff, void aom_highbd_subtract_block_sse2(int rows, int cols, int16_t *diff,
......
...@@ -130,7 +130,11 @@ class AV1HBDSubtractBlockTest : public ::testing::TestWithParam<Params> { ...@@ -130,7 +130,11 @@ class AV1HBDSubtractBlockTest : public ::testing::TestWithParam<Params> {
rnd_.Reset(ACMRandom::DeterministicSeed()); rnd_.Reset(ACMRandom::DeterministicSeed());
#if CONFIG_EXT_PARTITION
const size_t max_width = 128; const size_t max_width = 128;
#else
const size_t max_width = 64;
#endif
const size_t max_block_size = max_width * max_width; const size_t max_block_size = max_width * max_width;
src_ = CONVERT_TO_BYTEPTR(reinterpret_cast<uint16_t *>( src_ = CONVERT_TO_BYTEPTR(reinterpret_cast<uint16_t *>(
aom_memalign(16, max_block_size * sizeof(uint16_t)))); aom_memalign(16, max_block_size * sizeof(uint16_t))));
...@@ -163,7 +167,11 @@ class AV1HBDSubtractBlockTest : public ::testing::TestWithParam<Params> { ...@@ -163,7 +167,11 @@ class AV1HBDSubtractBlockTest : public ::testing::TestWithParam<Params> {
void AV1HBDSubtractBlockTest::CheckResult() { void AV1HBDSubtractBlockTest::CheckResult() {
const int test_num = 100; const int test_num = 100;
const int max_width = 128; #if CONFIG_EXT_PARTITION
const size_t max_width = 128;
#else
const size_t max_width = 64;
#endif
const int max_block_size = max_width * max_width; const int max_block_size = max_width * max_width;
const int mask = (1 << bit_depth_) - 1; const int mask = (1 << bit_depth_) - 1;
int i, j; int i, j;
...@@ -192,7 +200,11 @@ TEST_P(AV1HBDSubtractBlockTest, CheckResult) { CheckResult(); } ...@@ -192,7 +200,11 @@ TEST_P(AV1HBDSubtractBlockTest, CheckResult) { CheckResult(); }
void AV1HBDSubtractBlockTest::RunForSpeed() { void AV1HBDSubtractBlockTest::RunForSpeed() {
const int test_num = 200000; const int test_num = 200000;
const int max_width = 128; #if CONFIG_EXT_PARTITION
const size_t max_width = 128;
#else
const size_t max_width = 64;
#endif
const int max_block_size = max_width * max_width; const int max_block_size = max_width * max_width;
const int mask = (1 << bit_depth_) - 1; const int mask = (1 << bit_depth_) - 1;
int i, j; int i, j;
...@@ -239,12 +251,14 @@ const Params kAV1HBDSubtractBlock_sse2[] = { ...@@ -239,12 +251,14 @@ const Params kAV1HBDSubtractBlock_sse2[] = {
make_tuple(64, 32, 12, &aom_highbd_subtract_block_c), make_tuple(64, 32, 12, &aom_highbd_subtract_block_c),
make_tuple(64, 64, 12, &aom_highbd_subtract_block_sse2), make_tuple(64, 64, 12, &aom_highbd_subtract_block_sse2),
make_tuple(64, 64, 12, &aom_highbd_subtract_block_c), make_tuple(64, 64, 12, &aom_highbd_subtract_block_c),
#if CONFIG_EXT_PARTITION
make_tuple(64, 128, 12, &aom_highbd_subtract_block_sse2), make_tuple(64, 128, 12, &aom_highbd_subtract_block_sse2),
make_tuple(64, 128, 12, &aom_highbd_subtract_block_c), make_tuple(64, 128, 12, &aom_highbd_subtract_block_c),
make_tuple(128, 64, 12, &aom_highbd_subtract_block_sse2), make_tuple(128, 64, 12, &aom_highbd_subtract_block_sse2),
make_tuple(128, 64, 12, &aom_highbd_subtract_block_c), make_tuple(128, 64, 12, &aom_highbd_subtract_block_c),
make_tuple(128, 128, 12, &aom_highbd_subtract_block_sse2), make_tuple(128, 128, 12, &aom_highbd_subtract_block_sse2),
make_tuple(128, 128, 12, &aom_highbd_subtract_block_c) make_tuple(128, 128, 12, &aom_highbd_subtract_block_c)
#endif // CONFIG_EXT_PARTITION
}; };
INSTANTIATE_TEST_CASE_P(SSE2, AV1HBDSubtractBlockTest, INSTANTIATE_TEST_CASE_P(SSE2, AV1HBDSubtractBlockTest,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment