Commit 4c5b0204 authored by Alex Converse's avatar Alex Converse
Browse files

Make aom_sum_squares_2d_i16 take width and height parameters.

SSE2 may be needed for nx4 and 4xn.

Change-Id: I3c10112447fdb5fe51a68bc2c6e2f2641b102723
parent 2d5c2016
......@@ -1298,7 +1298,7 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
#
# Sum of Squares
#
add_proto qw/uint64_t aom_sum_squares_2d_i16/, "const int16_t *src, int stride, int size";
add_proto qw/uint64_t aom_sum_squares_2d_i16/, "const int16_t *src, int stride, int width, int height";
specialize qw/aom_sum_squares_2d_i16 sse2/;
add_proto qw/uint64_t aom_sum_squares_i16/, "const int16_t *src, uint32_t N";
......
......@@ -13,13 +13,13 @@
#include "./aom_dsp_rtcd.h"
uint64_t aom_sum_squares_2d_i16_c(const int16_t *src, int src_stride,
int size) {
uint64_t aom_sum_squares_2d_i16_c(const int16_t *src, int src_stride, int width,
int height) {
int r, c;
uint64_t ss = 0;
for (r = 0; r < size; r++) {
for (c = 0; c < size; c++) {
for (r = 0; r < height; r++) {
for (c = 0; c < width; c++) {
const int16_t v = src[c];
ss += v * v;
}
......
......@@ -50,16 +50,17 @@ static uint64_t aom_sum_squares_2d_i16_4x4_sse2(const int16_t *src,
__attribute__((noinline))
#endif
static uint64_t
aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride, int size) {
aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride, int width,
int height) {
int r, c;
const __m128i v_zext_mask_q = _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff);
__m128i v_acc_q = _mm_setzero_si128();
for (r = 0; r < size; r += 8) {
for (r = 0; r < height; r += 8) {
__m128i v_acc_d = _mm_setzero_si128();
for (c = 0; c < size; c += 8) {
for (c = 0; c < width; c += 8) {
const int16_t *b = src + c;
const __m128i v_val_0_w =
......@@ -119,15 +120,18 @@ aom_sum_squares_2d_i16_nxn_sse2(const int16_t *src, int stride, int size) {
#endif
}
uint64_t aom_sum_squares_2d_i16_sse2(const int16_t *src, int stride, int size) {
uint64_t aom_sum_squares_2d_i16_sse2(const int16_t *src, int stride, int width,
int height) {
// 4 elements per row only requires half an XMM register, so this
// must be a special case, but also note that over 75% of all calls
// are with size == 4, so it is also the common case.
if (LIKELY(size == 4)) {
if (LIKELY(width == 4 && height == 4)) {
return aom_sum_squares_2d_i16_4x4_sse2(src, stride);
} else {
} else if (LIKELY(width % 8 == 0 && height % 8 == 0)) {
// Generic case
return aom_sum_squares_2d_i16_nxn_sse2(src, stride, size);
return aom_sum_squares_2d_i16_nxn_sse2(src, stride, width, height);
} else {
return aom_sum_squares_2d_i16_c(src, stride, width, height);
}
}
......
......@@ -1448,43 +1448,8 @@ static int rate_block(int plane, int block, const ENTROPY_CONTEXT *a,
static uint64_t sum_squares_2d(const int16_t *diff, int diff_stride,
TX_SIZE tx_size) {
uint64_t sse;
switch (tx_size) {
#if CONFIG_CB4X4
case TX_2X2:
sse = aom_sum_squares_2d_i16_c(diff, diff_stride, tx_size_wide[tx_size]);
break;
#endif // CONFIG_CB4X4
case TX_4X8:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 4) +
aom_sum_squares_2d_i16(diff + 4 * diff_stride, diff_stride, 4);
break;
case TX_8X4:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 4) +
aom_sum_squares_2d_i16(diff + 4, diff_stride, 4);
break;
case TX_8X16:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 8) +
aom_sum_squares_2d_i16(diff + 8 * diff_stride, diff_stride, 8);
break;
case TX_16X8:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 8) +
aom_sum_squares_2d_i16(diff + 8, diff_stride, 8);
break;
case TX_16X32:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 16) +
aom_sum_squares_2d_i16(diff + 16 * diff_stride, diff_stride, 16);
break;
case TX_32X16:
sse = aom_sum_squares_2d_i16(diff, diff_stride, 16) +
aom_sum_squares_2d_i16(diff + 16, diff_stride, 16);
break;
default:
assert(tx_size < TX_SIZES);
sse = aom_sum_squares_2d_i16(diff, diff_stride, tx_size_wide[tx_size]);
break;
}
return sse;
return aom_sum_squares_2d_i16(diff, diff_stride, tx_size_wide[tx_size],
tx_size_high[tx_size]);
}
static void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
......
......@@ -32,7 +32,8 @@ const int kNumIterations = 10000;
static const int16_t kInt13Max = (1 << 12) - 1;
typedef uint64_t (*SSI16Func)(const int16_t *src, int stride, int size);
typedef uint64_t (*SSI16Func)(const int16_t *src, int stride, int width,
int height);
typedef libaom_test::FuncParam<SSI16Func> TestFuncs;
class SumSquaresTest : public ::testing::TestWithParam<TestFuncs> {
......@@ -56,21 +57,23 @@ TEST_P(SumSquaresTest, OperationCheck) {
const int limit = 1 << (msb + 1);
for (int k = 0; k < kNumIterations; k++) {
int size = 4 << rnd(6); // Up to 128x128
int width = 4 * rnd(32); // Up to 128x128
int height = 4 * rnd(32); // Up to 128x128
int stride = 4 << rnd(7); // Up to 256 stride
while (stride < size) { // Make sure it's valid
while (stride < width) { // Make sure it's valid
stride = 4 << rnd(7);
}
for (int ii = 0; ii < size; ii++) {
for (int jj = 0; jj < size; jj++) {
for (int ii = 0; ii < height; ii++) {
for (int jj = 0; jj < width; jj++) {
src[ii * stride + jj] = rnd(2) ? rnd(limit) : -rnd(limit);
}
}
const uint64_t res_ref = params_.ref_func(src, stride, size);
const uint64_t res_ref = params_.ref_func(src, stride, width, height);
uint64_t res_tst;
ASM_REGISTER_STATE_CHECK(res_tst = params_.tst_func(src, stride, size));
ASM_REGISTER_STATE_CHECK(res_tst =
params_.tst_func(src, stride, width, height));
if (!failed) {
failed = res_ref != res_tst;
......@@ -91,22 +94,24 @@ TEST_P(SumSquaresTest, ExtremeValues) {
const int limit = 1 << (msb + 1);
for (int k = 0; k < kNumIterations; k++) {
int size = 4 << rnd(6); // Up to 128x128
int width = 4 * rnd(32); // Up to 128x128
int height = 4 * rnd(32); // Up to 128x128
int stride = 4 << rnd(7); // Up to 256 stride
while (stride < size) { // Make sure it's valid
while (stride < width) { // Make sure it's valid
stride = 4 << rnd(7);
}
int val = rnd(2) ? limit - 1 : -(limit - 1);
for (int ii = 0; ii < size; ii++) {
for (int jj = 0; jj < size; jj++) {
for (int ii = 0; ii < height; ii++) {
for (int jj = 0; jj < width; jj++) {
src[ii * stride + jj] = val;
}
}
const uint64_t res_ref = params_.ref_func(src, stride, size);
const uint64_t res_ref = params_.ref_func(src, stride, width, height);
uint64_t res_tst;
ASM_REGISTER_STATE_CHECK(res_tst = params_.tst_func(src, stride, size));
ASM_REGISTER_STATE_CHECK(res_tst =
params_.tst_func(src, stride, width, height));
if (!failed) {
failed = res_ref != res_tst;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment