Commit 63bd6dc9 authored by Yi Luo's avatar Yi Luo

Fix rectangle transform computation overflow

- Add 16-bit saturation in fdct_round_shift().
- Add extreme value tests and round trip error tests.
- Fix inv 4x8 txfm calculation accuracy.
- Fix 4x8, 8x4, 8x16, 16x8, 16x32, 32x16 extreme value tests.
- BDRate: lowres: -0.034
          midres: -0.036
          hdres:  -0.013
BUG=webm:1340

Change-Id: I48365c1e50a03a7b1aa69b8856b732b483299fb5
parent 125e7293
......@@ -14,12 +14,15 @@
#include "aom_dsp/txfm_common.h"
// Clamp `value` into the representable range of a 16-bit signed integer,
// [INT16_MIN, INT16_MAX]. Used to saturate intermediate transform values
// that may exceed 16 bits for extreme inputs.
static INLINE tran_high_t saturate_int16(tran_high_t value) {
  if (value > INT16_MAX) return INT16_MAX;
  if (value < INT16_MIN) return INT16_MIN;
  return value;
}
// Round-shift `input` down by DCT_CONST_BITS (with rounding), then saturate
// the result to the int16_t range. The saturation replaces the former
// commented-out assert: extreme-valued inputs to the rectangular transforms
// can legitimately produce values outside 16 bits here, so clamping (rather
// than asserting) is required for correctness.
static INLINE tran_high_t fdct_round_shift(tran_high_t input) {
  tran_high_t rv = ROUND_POWER_OF_TWO(input, DCT_CONST_BITS);
  // TODO(debargha, peter.derivaz): Find new bounds for this assert
  // and make the bounds consts.
  // assert(INT16_MIN <= rv && rv <= INT16_MAX);
  return saturate_int16(rv);
}
void aom_fdct32(const tran_high_t *input, tran_high_t *output, int round);
......
......@@ -1190,8 +1190,6 @@ void av1_iht4x8_32_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
in[6] = load_input_data(input + 2 * 8);
in[7] = load_input_data(input + 3 * 8);
scale_sqrt2_8x4(in + 4);
// Row transform
switch (tx_type) {
case DCT_DCT:
......@@ -1230,6 +1228,8 @@ void av1_iht4x8_32_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
default: assert(0); break;
}
scale_sqrt2_8x4(in + 4);
// Repack data
in[0] = _mm_unpacklo_epi64(in[4], in[6]);
in[1] = _mm_unpackhi_epi64(in[4], in[6]);
......
......@@ -787,10 +787,10 @@ static void fadst8(const tran_low_t *input, tran_low_t *output) {
s6 = cospi_26_64 * x6 + cospi_6_64 * x7;
s7 = cospi_6_64 * x6 - cospi_26_64 * x7;
x0 = fdct_round_shift(s0 + s4);
x1 = fdct_round_shift(s1 + s5);
x2 = fdct_round_shift(s2 + s6);
x3 = fdct_round_shift(s3 + s7);
x0 = s0 + s4;
x1 = s1 + s5;
x2 = s2 + s6;
x3 = s3 + s7;
x4 = fdct_round_shift(s0 - s4);
x5 = fdct_round_shift(s1 - s5);
x6 = fdct_round_shift(s2 - s6);
......@@ -806,10 +806,10 @@ static void fadst8(const tran_low_t *input, tran_low_t *output) {
s6 = -cospi_24_64 * x6 + cospi_8_64 * x7;
s7 = cospi_8_64 * x6 + cospi_24_64 * x7;
x0 = s0 + s2;
x1 = s1 + s3;
x2 = s0 - s2;
x3 = s1 - s3;
x0 = fdct_round_shift(s0 + s2);
x1 = fdct_round_shift(s1 + s3);
x2 = fdct_round_shift(s0 - s2);
x3 = fdct_round_shift(s1 - s3);
x4 = fdct_round_shift(s4 + s6);
x5 = fdct_round_shift(s5 + s7);
x6 = fdct_round_shift(s4 - s6);
......@@ -875,14 +875,15 @@ static void fadst16(const tran_low_t *input, tran_low_t *output) {
s14 = x14 * cospi_29_64 + x15 * cospi_3_64;
s15 = x14 * cospi_3_64 - x15 * cospi_29_64;
x0 = fdct_round_shift(s0 + s8);
x1 = fdct_round_shift(s1 + s9);
x2 = fdct_round_shift(s2 + s10);
x3 = fdct_round_shift(s3 + s11);
x4 = fdct_round_shift(s4 + s12);
x5 = fdct_round_shift(s5 + s13);
x6 = fdct_round_shift(s6 + s14);
x7 = fdct_round_shift(s7 + s15);
x0 = s0 + s8;
x1 = s1 + s9;
x2 = s2 + s10;
x3 = s3 + s11;
x4 = s4 + s12;
x5 = s5 + s13;
x6 = s6 + s14;
x7 = s7 + s15;
x8 = fdct_round_shift(s0 - s8);
x9 = fdct_round_shift(s1 - s9);
x10 = fdct_round_shift(s2 - s10);
......@@ -914,14 +915,15 @@ static void fadst16(const tran_low_t *input, tran_low_t *output) {
x1 = s1 + s5;
x2 = s2 + s6;
x3 = s3 + s7;
x4 = s0 - s4;
x5 = s1 - s5;
x6 = s2 - s6;
x7 = s3 - s7;
x8 = fdct_round_shift(s8 + s12);
x9 = fdct_round_shift(s9 + s13);
x10 = fdct_round_shift(s10 + s14);
x11 = fdct_round_shift(s11 + s15);
x4 = fdct_round_shift(s0 - s4);
x5 = fdct_round_shift(s1 - s5);
x6 = fdct_round_shift(s2 - s6);
x7 = fdct_round_shift(s3 - s7);
x8 = s8 + s12;
x9 = s9 + s13;
x10 = s10 + s14;
x11 = s11 + s15;
x12 = fdct_round_shift(s8 - s12);
x13 = fdct_round_shift(s9 - s13);
x14 = fdct_round_shift(s10 - s14);
......@@ -945,18 +947,21 @@ static void fadst16(const tran_low_t *input, tran_low_t *output) {
s14 = -x14 * cospi_24_64 + x15 * cospi_8_64;
s15 = x14 * cospi_8_64 + x15 * cospi_24_64;
x0 = s0 + s2;
x1 = s1 + s3;
x2 = s0 - s2;
x3 = s1 - s3;
x0 = fdct_round_shift(s0 + s2);
x1 = fdct_round_shift(s1 + s3);
x2 = fdct_round_shift(s0 - s2);
x3 = fdct_round_shift(s1 - s3);
x4 = fdct_round_shift(s4 + s6);
x5 = fdct_round_shift(s5 + s7);
x6 = fdct_round_shift(s4 - s6);
x7 = fdct_round_shift(s5 - s7);
x8 = s8 + s10;
x9 = s9 + s11;
x10 = s8 - s10;
x11 = s9 - s11;
x8 = fdct_round_shift(s8 + s10);
x9 = fdct_round_shift(s9 + s11);
x10 = fdct_round_shift(s8 - s10);
x11 = fdct_round_shift(s9 - s11);
x12 = fdct_round_shift(s12 + s14);
x13 = fdct_round_shift(s13 + s15);
x14 = fdct_round_shift(s12 - s14);
......@@ -1230,7 +1235,7 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n2; ++i) {
for (j = 0; j < n; ++j) temp_in[j] = out[j + i * n];
ht.rows(temp_in, temp_out);
for (j = 0; j < n; ++j) output[j + i * n] = (temp_out[j] + 1) >> 2;
for (j = 0; j < n; ++j) output[j + i * n] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 8 times unitary
}
......@@ -1281,7 +1286,7 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n; ++i) {
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
ht.rows(temp_in, temp_out);
for (j = 0; j < n2; ++j) output[j + i * n2] = (temp_out[j] + 1) >> 2;
for (j = 0; j < n2; ++j) output[j + i * n2] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 8 times unitary
}
......@@ -1332,8 +1337,7 @@ void av1_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n2; ++i) {
for (j = 0; j < n; ++j) temp_in[j] = out[j + i * n];
ht.rows(temp_in, temp_out);
for (j = 0; j < n; ++j)
output[j + i * n] = (temp_out[j] + 1 + (temp_out[j] < 0)) >> 2;
for (j = 0; j < n; ++j) output[j + i * n] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 8 times unitary
}
......@@ -1384,8 +1388,7 @@ void av1_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n; ++i) {
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
ht.rows(temp_in, temp_out);
for (j = 0; j < n2; ++j)
output[j + i * n2] = (temp_out[j] + 1 + (temp_out[j] < 0)) >> 2;
for (j = 0; j < n2; ++j) output[j + i * n2] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 8 times unitary
}
......@@ -1435,9 +1438,7 @@ void av1_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n2; ++i) {
for (j = 0; j < n; ++j) temp_in[j] = out[j + i * n];
ht.rows(temp_in, temp_out);
for (j = 0; j < n; ++j)
output[j + i * n] =
(tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
for (j = 0; j < n; ++j) output[j + i * n] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 4 times unitary
}
......@@ -1487,9 +1488,7 @@ void av1_fht32x16_c(const int16_t *input, tran_low_t *output, int stride,
for (i = 0; i < n; ++i) {
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
ht.rows(temp_in, temp_out);
for (j = 0; j < n2; ++j)
output[j + i * n2] =
(tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
for (j = 0; j < n2; ++j) output[j + i * n2] = temp_out[j] >> 2;
}
// Note: overall scale factor of transform is 4 times unitary
}
......
......@@ -1123,14 +1123,6 @@ static void fadst8_sse2(__m128i *in) {
w15 = _mm_sub_epi32(u7, u15);
// shift and rounding
v0 = _mm_add_epi32(w0, k__DCT_CONST_ROUNDING);
v1 = _mm_add_epi32(w1, k__DCT_CONST_ROUNDING);
v2 = _mm_add_epi32(w2, k__DCT_CONST_ROUNDING);
v3 = _mm_add_epi32(w3, k__DCT_CONST_ROUNDING);
v4 = _mm_add_epi32(w4, k__DCT_CONST_ROUNDING);
v5 = _mm_add_epi32(w5, k__DCT_CONST_ROUNDING);
v6 = _mm_add_epi32(w6, k__DCT_CONST_ROUNDING);
v7 = _mm_add_epi32(w7, k__DCT_CONST_ROUNDING);
v8 = _mm_add_epi32(w8, k__DCT_CONST_ROUNDING);
v9 = _mm_add_epi32(w9, k__DCT_CONST_ROUNDING);
v10 = _mm_add_epi32(w10, k__DCT_CONST_ROUNDING);
......@@ -1140,14 +1132,6 @@ static void fadst8_sse2(__m128i *in) {
v14 = _mm_add_epi32(w14, k__DCT_CONST_ROUNDING);
v15 = _mm_add_epi32(w15, k__DCT_CONST_ROUNDING);
u0 = _mm_srai_epi32(v0, DCT_CONST_BITS);
u1 = _mm_srai_epi32(v1, DCT_CONST_BITS);
u2 = _mm_srai_epi32(v2, DCT_CONST_BITS);
u3 = _mm_srai_epi32(v3, DCT_CONST_BITS);
u4 = _mm_srai_epi32(v4, DCT_CONST_BITS);
u5 = _mm_srai_epi32(v5, DCT_CONST_BITS);
u6 = _mm_srai_epi32(v6, DCT_CONST_BITS);
u7 = _mm_srai_epi32(v7, DCT_CONST_BITS);
u8 = _mm_srai_epi32(v8, DCT_CONST_BITS);
u9 = _mm_srai_epi32(v9, DCT_CONST_BITS);
u10 = _mm_srai_epi32(v10, DCT_CONST_BITS);
......@@ -1158,20 +1142,44 @@ static void fadst8_sse2(__m128i *in) {
u15 = _mm_srai_epi32(v15, DCT_CONST_BITS);
// back to 16-bit and pack 8 integers into __m128i
in[0] = _mm_packs_epi32(u0, u1);
in[1] = _mm_packs_epi32(u2, u3);
in[2] = _mm_packs_epi32(u4, u5);
in[3] = _mm_packs_epi32(u6, u7);
v0 = _mm_add_epi32(w0, w4);
v1 = _mm_add_epi32(w1, w5);
v2 = _mm_add_epi32(w2, w6);
v3 = _mm_add_epi32(w3, w7);
v4 = _mm_sub_epi32(w0, w4);
v5 = _mm_sub_epi32(w1, w5);
v6 = _mm_sub_epi32(w2, w6);
v7 = _mm_sub_epi32(w3, w7);
w0 = _mm_add_epi32(v0, k__DCT_CONST_ROUNDING);
w1 = _mm_add_epi32(v1, k__DCT_CONST_ROUNDING);
w2 = _mm_add_epi32(v2, k__DCT_CONST_ROUNDING);
w3 = _mm_add_epi32(v3, k__DCT_CONST_ROUNDING);
w4 = _mm_add_epi32(v4, k__DCT_CONST_ROUNDING);
w5 = _mm_add_epi32(v5, k__DCT_CONST_ROUNDING);
w6 = _mm_add_epi32(v6, k__DCT_CONST_ROUNDING);
w7 = _mm_add_epi32(v7, k__DCT_CONST_ROUNDING);
v0 = _mm_srai_epi32(w0, DCT_CONST_BITS);
v1 = _mm_srai_epi32(w1, DCT_CONST_BITS);
v2 = _mm_srai_epi32(w2, DCT_CONST_BITS);
v3 = _mm_srai_epi32(w3, DCT_CONST_BITS);
v4 = _mm_srai_epi32(w4, DCT_CONST_BITS);
v5 = _mm_srai_epi32(w5, DCT_CONST_BITS);
v6 = _mm_srai_epi32(w6, DCT_CONST_BITS);
v7 = _mm_srai_epi32(w7, DCT_CONST_BITS);
in[4] = _mm_packs_epi32(u8, u9);
in[5] = _mm_packs_epi32(u10, u11);
in[6] = _mm_packs_epi32(u12, u13);
in[7] = _mm_packs_epi32(u14, u15);
// stage 2
s0 = _mm_add_epi16(in[0], in[2]);
s1 = _mm_add_epi16(in[1], in[3]);
s2 = _mm_sub_epi16(in[0], in[2]);
s3 = _mm_sub_epi16(in[1], in[3]);
s0 = _mm_packs_epi32(v0, v1);
s1 = _mm_packs_epi32(v2, v3);
s2 = _mm_packs_epi32(v4, v5);
s3 = _mm_packs_epi32(v6, v7);
u0 = _mm_unpacklo_epi16(in[4], in[5]);
u1 = _mm_unpackhi_epi16(in[4], in[5]);
u2 = _mm_unpacklo_epi16(in[6], in[7]);
......@@ -1914,22 +1922,6 @@ static void fadst16_8col(__m128i *in) {
u[30] = _mm_sub_epi32(v[14], v[30]);
u[31] = _mm_sub_epi32(v[15], v[31]);
v[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING);
v[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING);
v[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING);
v[3] = _mm_add_epi32(u[3], k__DCT_CONST_ROUNDING);
v[4] = _mm_add_epi32(u[4], k__DCT_CONST_ROUNDING);
v[5] = _mm_add_epi32(u[5], k__DCT_CONST_ROUNDING);
v[6] = _mm_add_epi32(u[6], k__DCT_CONST_ROUNDING);
v[7] = _mm_add_epi32(u[7], k__DCT_CONST_ROUNDING);
v[8] = _mm_add_epi32(u[8], k__DCT_CONST_ROUNDING);
v[9] = _mm_add_epi32(u[9], k__DCT_CONST_ROUNDING);
v[10] = _mm_add_epi32(u[10], k__DCT_CONST_ROUNDING);
v[11] = _mm_add_epi32(u[11], k__DCT_CONST_ROUNDING);
v[12] = _mm_add_epi32(u[12], k__DCT_CONST_ROUNDING);
v[13] = _mm_add_epi32(u[13], k__DCT_CONST_ROUNDING);
v[14] = _mm_add_epi32(u[14], k__DCT_CONST_ROUNDING);
v[15] = _mm_add_epi32(u[15], k__DCT_CONST_ROUNDING);
v[16] = _mm_add_epi32(u[16], k__DCT_CONST_ROUNDING);
v[17] = _mm_add_epi32(u[17], k__DCT_CONST_ROUNDING);
v[18] = _mm_add_epi32(u[18], k__DCT_CONST_ROUNDING);
......@@ -1947,22 +1939,6 @@ static void fadst16_8col(__m128i *in) {
v[30] = _mm_add_epi32(u[30], k__DCT_CONST_ROUNDING);
v[31] = _mm_add_epi32(u[31], k__DCT_CONST_ROUNDING);
u[0] = _mm_srai_epi32(v[0], DCT_CONST_BITS);
u[1] = _mm_srai_epi32(v[1], DCT_CONST_BITS);
u[2] = _mm_srai_epi32(v[2], DCT_CONST_BITS);
u[3] = _mm_srai_epi32(v[3], DCT_CONST_BITS);
u[4] = _mm_srai_epi32(v[4], DCT_CONST_BITS);
u[5] = _mm_srai_epi32(v[5], DCT_CONST_BITS);
u[6] = _mm_srai_epi32(v[6], DCT_CONST_BITS);
u[7] = _mm_srai_epi32(v[7], DCT_CONST_BITS);
u[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS);
u[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS);
u[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS);
u[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS);
u[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS);
u[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS);
u[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS);
u[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS);
u[16] = _mm_srai_epi32(v[16], DCT_CONST_BITS);
u[17] = _mm_srai_epi32(v[17], DCT_CONST_BITS);
u[18] = _mm_srai_epi32(v[18], DCT_CONST_BITS);
......@@ -1980,14 +1956,77 @@ static void fadst16_8col(__m128i *in) {
u[30] = _mm_srai_epi32(v[30], DCT_CONST_BITS);
u[31] = _mm_srai_epi32(v[31], DCT_CONST_BITS);
s[0] = _mm_packs_epi32(u[0], u[1]);
s[1] = _mm_packs_epi32(u[2], u[3]);
s[2] = _mm_packs_epi32(u[4], u[5]);
s[3] = _mm_packs_epi32(u[6], u[7]);
s[4] = _mm_packs_epi32(u[8], u[9]);
s[5] = _mm_packs_epi32(u[10], u[11]);
s[6] = _mm_packs_epi32(u[12], u[13]);
s[7] = _mm_packs_epi32(u[14], u[15]);
v[0] = _mm_add_epi32(u[0], u[8]);
v[1] = _mm_add_epi32(u[1], u[9]);
v[2] = _mm_add_epi32(u[2], u[10]);
v[3] = _mm_add_epi32(u[3], u[11]);
v[4] = _mm_add_epi32(u[4], u[12]);
v[5] = _mm_add_epi32(u[5], u[13]);
v[6] = _mm_add_epi32(u[6], u[14]);
v[7] = _mm_add_epi32(u[7], u[15]);
v[16] = _mm_add_epi32(v[0], v[4]);
v[17] = _mm_add_epi32(v[1], v[5]);
v[18] = _mm_add_epi32(v[2], v[6]);
v[19] = _mm_add_epi32(v[3], v[7]);
v[20] = _mm_sub_epi32(v[0], v[4]);
v[21] = _mm_sub_epi32(v[1], v[5]);
v[22] = _mm_sub_epi32(v[2], v[6]);
v[23] = _mm_sub_epi32(v[3], v[7]);
v[16] = _mm_add_epi32(v[16], k__DCT_CONST_ROUNDING);
v[17] = _mm_add_epi32(v[17], k__DCT_CONST_ROUNDING);
v[18] = _mm_add_epi32(v[18], k__DCT_CONST_ROUNDING);
v[19] = _mm_add_epi32(v[19], k__DCT_CONST_ROUNDING);
v[20] = _mm_add_epi32(v[20], k__DCT_CONST_ROUNDING);
v[21] = _mm_add_epi32(v[21], k__DCT_CONST_ROUNDING);
v[22] = _mm_add_epi32(v[22], k__DCT_CONST_ROUNDING);
v[23] = _mm_add_epi32(v[23], k__DCT_CONST_ROUNDING);
v[16] = _mm_srai_epi32(v[16], DCT_CONST_BITS);
v[17] = _mm_srai_epi32(v[17], DCT_CONST_BITS);
v[18] = _mm_srai_epi32(v[18], DCT_CONST_BITS);
v[19] = _mm_srai_epi32(v[19], DCT_CONST_BITS);
v[20] = _mm_srai_epi32(v[20], DCT_CONST_BITS);
v[21] = _mm_srai_epi32(v[21], DCT_CONST_BITS);
v[22] = _mm_srai_epi32(v[22], DCT_CONST_BITS);
v[23] = _mm_srai_epi32(v[23], DCT_CONST_BITS);
s[0] = _mm_packs_epi32(v[16], v[17]);
s[1] = _mm_packs_epi32(v[18], v[19]);
s[2] = _mm_packs_epi32(v[20], v[21]);
s[3] = _mm_packs_epi32(v[22], v[23]);
v[8] = _mm_sub_epi32(u[0], u[8]);
v[9] = _mm_sub_epi32(u[1], u[9]);
v[10] = _mm_sub_epi32(u[2], u[10]);
v[11] = _mm_sub_epi32(u[3], u[11]);
v[12] = _mm_sub_epi32(u[4], u[12]);
v[13] = _mm_sub_epi32(u[5], u[13]);
v[14] = _mm_sub_epi32(u[6], u[14]);
v[15] = _mm_sub_epi32(u[7], u[15]);
v[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING);
v[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING);
v[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING);
v[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING);
v[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING);
v[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING);
v[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING);
v[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING);
v[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS);
v[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS);
v[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS);
v[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS);
v[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS);
v[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS);
v[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS);
v[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS);
s[4] = _mm_packs_epi32(v[8], v[9]);
s[5] = _mm_packs_epi32(v[10], v[11]);
s[6] = _mm_packs_epi32(v[12], v[13]);
s[7] = _mm_packs_epi32(v[14], v[15]);
//
s[8] = _mm_packs_epi32(u[16], u[17]);
s[9] = _mm_packs_epi32(u[18], u[19]);
s[10] = _mm_packs_epi32(u[20], u[21]);
......@@ -2041,14 +2080,6 @@ static void fadst16_8col(__m128i *in) {
u[14] = _mm_sub_epi32(v[6], v[14]);
u[15] = _mm_sub_epi32(v[7], v[15]);
v[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING);
v[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING);
v[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING);
v[3] = _mm_add_epi32(u[3], k__DCT_CONST_ROUNDING);
v[4] = _mm_add_epi32(u[4], k__DCT_CONST_ROUNDING);
v[5] = _mm_add_epi32(u[5], k__DCT_CONST_ROUNDING);
v[6] = _mm_add_epi32(u[6], k__DCT_CONST_ROUNDING);
v[7] = _mm_add_epi32(u[7], k__DCT_CONST_ROUNDING);
v[8] = _mm_add_epi32(u[8], k__DCT_CONST_ROUNDING);
v[9] = _mm_add_epi32(u[9], k__DCT_CONST_ROUNDING);
v[10] = _mm_add_epi32(u[10], k__DCT_CONST_ROUNDING);
......@@ -2058,14 +2089,6 @@ static void fadst16_8col(__m128i *in) {
v[14] = _mm_add_epi32(u[14], k__DCT_CONST_ROUNDING);
v[15] = _mm_add_epi32(u[15], k__DCT_CONST_ROUNDING);
u[0] = _mm_srai_epi32(v[0], DCT_CONST_BITS);
u[1] = _mm_srai_epi32(v[1], DCT_CONST_BITS);
u[2] = _mm_srai_epi32(v[2], DCT_CONST_BITS);
u[3] = _mm_srai_epi32(v[3], DCT_CONST_BITS);
u[4] = _mm_srai_epi32(v[4], DCT_CONST_BITS);
u[5] = _mm_srai_epi32(v[5], DCT_CONST_BITS);
u[6] = _mm_srai_epi32(v[6], DCT_CONST_BITS);
u[7] = _mm_srai_epi32(v[7], DCT_CONST_BITS);
u[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS);
u[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS);
u[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS);
......@@ -2075,28 +2098,46 @@ static void fadst16_8col(__m128i *in) {
u[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS);
u[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS);
x[0] = _mm_add_epi16(s[0], s[4]);
x[1] = _mm_add_epi16(s[1], s[5]);
x[2] = _mm_add_epi16(s[2], s[6]);
x[3] = _mm_add_epi16(s[3], s[7]);
x[4] = _mm_sub_epi16(s[0], s[4]);
x[5] = _mm_sub_epi16(s[1], s[5]);
x[6] = _mm_sub_epi16(s[2], s[6]);
x[7] = _mm_sub_epi16(s[3], s[7]);
x[8] = _mm_packs_epi32(u[0], u[1]);
x[9] = _mm_packs_epi32(u[2], u[3]);
x[10] = _mm_packs_epi32(u[4], u[5]);
x[11] = _mm_packs_epi32(u[6], u[7]);
v[8] = _mm_add_epi32(u[0], u[4]);
v[9] = _mm_add_epi32(u[1], u[5]);
v[10] = _mm_add_epi32(u[2], u[6]);
v[11] = _mm_add_epi32(u[3], u[7]);
v[12] = _mm_sub_epi32(u[0], u[4]);
v[13] = _mm_sub_epi32(u[1], u[5]);
v[14] = _mm_sub_epi32(u[2], u[6]);
v[15] = _mm_sub_epi32(u[3], u[7]);
v[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING);
v[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING);
v[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING);
v[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING);
v[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING);
v[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING);
v[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING);
v[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING);
v[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS);
v[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS);
v[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS);
v[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS);
v[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS);
v[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS);
v[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS);
v[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS);
s[8] = _mm_packs_epi32(v[8], v[9]);
s[9] = _mm_packs_epi32(v[10], v[11]);
s[10] = _mm_packs_epi32(v[12], v[13]);
s[11] = _mm_packs_epi32(v[14], v[15]);
x[12] = _mm_packs_epi32(u[8], u[9]);
x[13] = _mm_packs_epi32(u[10], u[11]);
x[14] = _mm_packs_epi32(u[12], u[13]);
x[15] = _mm_packs_epi32(u[14], u[15]);
// stage 3
u[0] = _mm_unpacklo_epi16(x[4], x[5]);
u[1] = _mm_unpackhi_epi16(x[4], x[5]);
u[2] = _mm_unpacklo_epi16(x[6], x[7]);
u[3] = _mm_unpackhi_epi16(x[6], x[7]);
u[0] = _mm_unpacklo_epi16(s[4], s[5]);
u[1] = _mm_unpackhi_epi16(s[4], s[5]);
u[2] = _mm_unpacklo_epi16(s[6], s[7]);
u[3] = _mm_unpackhi_epi16(s[6], s[7]);
u[4] = _mm_unpacklo_epi16(x[12], x[13]);
u[5] = _mm_unpackhi_epi16(x[12], x[13]);
u[6] = _mm_unpacklo_epi16(x[14], x[15]);
......@@ -2170,18 +2211,11 @@ static void fadst16_8col(__m128i *in) {
v[14] = _mm_srai_epi32(u[14], DCT_CONST_BITS);
v[15] = _mm_srai_epi32(u[15], DCT_CONST_BITS);
s[0] = _mm_add_epi16(x[0], x[2]);
s[1] = _mm_add_epi16(x[1], x[3]);
s[2] = _mm_sub_epi16(x[0], x[2]);
s[3] = _mm_sub_epi16(x[1], x[3]);
s[4] = _mm_packs_epi32(v[0], v[1]);
s[5] = _mm_packs_epi32(v[2], v[3]);
s[6] = _mm_packs_epi32(v[4], v[5]);
s[7] = _mm_packs_epi32(v[6], v[7]);
s[8] = _mm_add_epi16(x[8], x[10]);
s[9] = _mm_add_epi16(x[9], x[11]);
s[10] = _mm_sub_epi16(x[8], x[10]);
s[11] = _mm_sub_epi16(x[9], x[11]);
s[12] = _mm_packs_epi32(v[8], v[9]);
s[13] = _mm_packs_epi32(v[10], v[11]);
s[14] = _mm_packs_epi32(v[12], v[13]);
......@@ -2740,26 +2774,20 @@ static INLINE void load_buffer_4x8(const int16_t *input, __m128i *in,
}
// Pack the eight 4-wide row vectors in `res` into four 8-wide registers,
// scale down by 4 with a truncating arithmetic shift, and store them to
// `output`. Note: no +1 rounding bias is added before the shift — this
// matches the C reference path (av1_fht4x8_c), which writes
// temp_out[j] >> 2. The removed kOne/out* computation was dead code: its
// stores were immediately overwritten by the in* stores below.
static INLINE void write_buffer_4x8(tran_low_t *output, __m128i *res) {
  __m128i in01 = _mm_unpacklo_epi64(res[0], res[1]);
  __m128i in23 = _mm_unpacklo_epi64(res[2], res[3]);
  __m128i in45 = _mm_unpacklo_epi64(res[4], res[5]);
  __m128i in67 = _mm_unpacklo_epi64(res[6], res[7]);
  in01 = _mm_srai_epi16(in01, 2);
  in23 = _mm_srai_epi16(in23, 2);
  in45 = _mm_srai_epi16(in45, 2);
  in67 = _mm_srai_epi16(in67, 2);
  store_output(&in01, (output + 0 * 8));
  store_output(&in23, (output + 1 * 8));
  store_output(&in45, (output + 2 * 8));
  store_output(&in67, (output + 3 * 8));
}
void av1_fht4x8_sse2(const int16_t *input, tran_low_t *output, int stride,
......@@ -2975,16 +3003,10 @@ static INLINE void load_buffer_8x4(const int16_t *input, __m128i *in,
}
static INLINE void write_buffer_8x4(tran_low_t *output, __m128i *res) {
const __m128i kOne = _mm_set1_epi16(1);
__m128i out0 = _mm_add_epi16(res[0], kOne);
__m128i out1 = _mm_add_epi16(res[1], kOne);
__m128i out2 = _mm_add_epi16(res[2], kOne);
__m128i out3 = _mm_add_epi16(res[3], kOne);
out0 = _mm_srai_epi16(out0, 2);
out1 = _mm_srai_epi16(out1, 2);
out2 = _mm_srai_epi16(out2, 2);
out3 = _mm_srai_epi16(out3, 2);
const __m128i out0 = _mm_srai_epi16(res[0], 2);
const __m128i out1 = _mm_srai_epi16(res[1], 2);
const __m128i out2 = _mm_srai_epi16(res[2], 2);
const __m128i out3 = _mm_srai_epi16(res[3], 2);
store_output(&out0, (output + 0 * 8));
store_output(&out1, (output + 1 * 8));
......@@ -3118,6 +3140,14 @@ static INLINE void load_buffer_8x16(const int16_t *input, __m128i *in,
scale_sqrt2_8x8_signed(in + 8);
}
static INLINE void right_shift(__m128i *in, int size, int bit) {