Commit 3b0b5f17 authored by Yi Luo's avatar Yi Luo
Browse files

Fix 16x32, 32x16 rectangular transform SSE2 to match C

- Turn on SSE2 unit tests

Change-Id: I285771b04c0dec0501210fde570b9ac3cb9c4be0
parent ab9ecbab
......@@ -92,14 +92,14 @@ static void fwd_txfm_16x32(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TX_TYPE tx_type,
FWD_TXFM_OPT fwd_txfm_opt) {
(void)fwd_txfm_opt;
av1_fht16x32_c(src_diff, coeff, diff_stride, tx_type);
av1_fht16x32(src_diff, coeff, diff_stride, tx_type);
}
static void fwd_txfm_32x16(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TX_TYPE tx_type,
FWD_TXFM_OPT fwd_txfm_opt) {
(void)fwd_txfm_opt;
av1_fht32x16_c(src_diff, coeff, diff_stride, tx_type);
av1_fht32x16(src_diff, coeff, diff_stride, tx_type);
}
static void fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff,
......
......@@ -3453,36 +3453,6 @@ static INLINE void fdct32_16col(__m128i *tl, __m128i *tr, __m128i *bl,
array_transpose_16x16(bl, br);
}
static INLINE void fhalfright32_16col(__m128i *tl, __m128i *tr, __m128i *bl,
__m128i *br) {
__m128i tmpl[16], tmpr[16];
int i;
// Copy the bottom half of the input to temporary storage
for (i = 0; i < 16; ++i) {
tmpl[i] = bl[i];
tmpr[i] = br[i];
}
// Generate the bottom half of the output
for (i = 0; i < 16; ++i) {
bl[i] = _mm_slli_epi16(tl[i], 2);
br[i] = _mm_slli_epi16(tr[i], 2);
}
array_transpose_16x16(bl, br);
// Copy the temporary storage back to the top half of the input
for (i = 0; i < 16; ++i) {
tl[i] = tmpl[i];
tr[i] = tmpr[i];
}
// Generate the top half of the output
scale_sqrt2_8x16(tl);
scale_sqrt2_8x16(tr);
fdct16_sse2(tl, tr);
}
#if CONFIG_EXT_TX
static INLINE void fidtx32_16col(__m128i *tl, __m128i *tr, __m128i *bl,
__m128i *br) {
......@@ -3541,8 +3511,6 @@ static INLINE void write_buffer_16x32(tran_low_t *output, __m128i *restl,
__m128i *restr, __m128i *resbl,
__m128i *resbr) {
int i;
right_shift_16x16(restl, restr);
right_shift_16x16(resbl, resbr);
for (i = 0; i < 16; ++i) {
store_output(&restl[i], output + i * 16 + 0);
store_output(&restr[i], output + i * 16 + 8);
......@@ -3551,6 +3519,104 @@ static INLINE void write_buffer_16x32(tran_low_t *output, __m128i *restl,
}
}
static INLINE void round_signed_8x8(__m128i *in, const int bit) {
const __m128i rounding = _mm_set1_epi16((1 << bit) >> 1);
__m128i sign0 = _mm_srai_epi16(in[0], 15);
__m128i sign1 = _mm_srai_epi16(in[1], 15);
__m128i sign2 = _mm_srai_epi16(in[2], 15);
__m128i sign3 = _mm_srai_epi16(in[3], 15);
__m128i sign4 = _mm_srai_epi16(in[4], 15);
__m128i sign5 = _mm_srai_epi16(in[5], 15);
__m128i sign6 = _mm_srai_epi16(in[6], 15);
__m128i sign7 = _mm_srai_epi16(in[7], 15);
in[0] = _mm_add_epi16(_mm_add_epi16(in[0], rounding), sign0);
in[1] = _mm_add_epi16(_mm_add_epi16(in[1], rounding), sign1);
in[2] = _mm_add_epi16(_mm_add_epi16(in[2], rounding), sign2);
in[3] = _mm_add_epi16(_mm_add_epi16(in[3], rounding), sign3);
in[4] = _mm_add_epi16(_mm_add_epi16(in[4], rounding), sign4);
in[5] = _mm_add_epi16(_mm_add_epi16(in[5], rounding), sign5);
in[6] = _mm_add_epi16(_mm_add_epi16(in[6], rounding), sign6);
in[7] = _mm_add_epi16(_mm_add_epi16(in[7], rounding), sign7);
in[0] = _mm_srai_epi16(in[0], bit);
in[1] = _mm_srai_epi16(in[1], bit);
in[2] = _mm_srai_epi16(in[2], bit);
in[3] = _mm_srai_epi16(in[3], bit);
in[4] = _mm_srai_epi16(in[4], bit);
in[5] = _mm_srai_epi16(in[5], bit);
in[6] = _mm_srai_epi16(in[6], bit);
in[7] = _mm_srai_epi16(in[7], bit);
}
static INLINE void round_signed_16x16(__m128i *in0, __m128i *in1) {
const int bit = 4;
round_signed_8x8(in0, bit);
round_signed_8x8(in0 + 8, bit);
round_signed_8x8(in1, bit);
round_signed_8x8(in1 + 8, bit);
}
// Note:
// suffix "t" indicates the transpose operation comes first
static void fdct16t_sse2(__m128i *in0, __m128i *in1) {
array_transpose_16x16(in0, in1);
fdct16_8col(in0);
fdct16_8col(in1);
}
static void fadst16t_sse2(__m128i *in0, __m128i *in1) {
array_transpose_16x16(in0, in1);
fadst16_8col(in0);
fadst16_8col(in1);
}
static INLINE void fdct32t_16col(__m128i *tl, __m128i *tr, __m128i *bl,
__m128i *br) {
array_transpose_16x16(tl, tr);
array_transpose_16x16(bl, br);
fdct32_8col(tl, bl);
fdct32_8col(tr, br);
}
typedef enum transpose_indicator_ {
transpose,
no_transpose,
} transpose_indicator;
static INLINE void fhalfright32_16col(__m128i *tl, __m128i *tr, __m128i *bl,
__m128i *br, transpose_indicator t) {
__m128i tmpl[16], tmpr[16];
int i;
// Copy the bottom half of the input to temporary storage
for (i = 0; i < 16; ++i) {
tmpl[i] = bl[i];
tmpr[i] = br[i];
}
// Generate the bottom half of the output
for (i = 0; i < 16; ++i) {
bl[i] = _mm_slli_epi16(tl[i], 2);
br[i] = _mm_slli_epi16(tr[i], 2);
}
array_transpose_16x16(bl, br);
// Copy the temporary storage back to the top half of the input
for (i = 0; i < 16; ++i) {
tl[i] = tmpl[i];
tr[i] = tmpr[i];
}
// Generate the top half of the output
scale_sqrt2_8x16(tl);
scale_sqrt2_8x16(tr);
if (t == transpose)
fdct16t_sse2(tl, tr);
else
fdct16_sse2(tl, tr);
}
// Note on data layout, for both this and the 32x16 transforms:
// So that we can reuse the 16-element transforms easily,
// we want to split the input into 8x16 blocks.
......@@ -3563,132 +3629,132 @@ void av1_fht16x32_sse2(const int16_t *input, tran_low_t *output, int stride,
switch (tx_type) {
case DCT_DCT:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fdct32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fdct16_sse2(intl, intr);
fdct16_sse2(inbl, inbr);
fdct16t_sse2(intl, intr);
fdct16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fdct32t_16col(intl, intr, inbl, inbr);
break;
case ADST_DCT:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fdct16_sse2(intl, intr);
fdct16_sse2(inbl, inbr);
fdct16t_sse2(intl, intr);
fdct16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case DCT_ADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fdct32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fdct32t_16col(intl, intr, inbl, inbr);
break;
case ADST_ADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
#if CONFIG_EXT_TX
case FLIPADST_DCT:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fdct16_sse2(intl, intr);
fdct16_sse2(inbl, inbr);
fdct16t_sse2(intl, intr);
fdct16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case DCT_FLIPADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1);
fdct32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fdct32t_16col(intl, intr, inbl, inbr);
break;
case FLIPADST_FLIPADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 1);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case ADST_FLIPADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case FLIPADST_ADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case IDTX:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fidtx32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fidtx16_sse2(intl, intr);
fidtx16_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fidtx32_16col(intl, intr, inbl, inbr);
break;
case V_DCT:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fdct32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fidtx16_sse2(intl, intr);
fidtx16_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fdct32t_16col(intl, intr, inbl, inbr);
break;
case H_DCT:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fdct16t_sse2(intl, intr);
fdct16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fidtx32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fdct16_sse2(intl, intr);
fdct16_sse2(inbl, inbr);
break;
case V_ADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fidtx16_sse2(intl, intr);
fidtx16_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case H_ADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fidtx32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
break;
case V_FLIPADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0);
fhalfright32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fidtx16_sse2(intl, intr);
fidtx16_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fhalfright32_16col(intl, intr, inbl, inbr, transpose);
break;
case H_FLIPADST:
load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1);
fadst16t_sse2(intl, intr);
fadst16t_sse2(inbl, inbr);
round_signed_16x16(intl, intr);
round_signed_16x16(inbl, inbr);
fidtx32_16col(intl, intr, inbl, inbr);
right_shift_16x16(intl, intr);
right_shift_16x16(inbl, inbr);
fadst16_sse2(intl, intr);
fadst16_sse2(inbl, inbr);
break;
#endif
default: assert(0); break;
......@@ -3737,8 +3803,6 @@ static INLINE void write_buffer_32x16(tran_low_t *output, __m128i *res0,
__m128i *res1, __m128i *res2,
__m128i *res3) {
int i;
right_shift_16x16(res0, res1);
right_shift_16x16(res2, res3);
for (i = 0; i < 16; ++i) {
store_output(&res0[i], output + i * 32 + 0);
store_output(&res1[i], output + i * 32 + 8);
......@@ -3756,127 +3820,127 @@ void av1_fht32x16_sse2(const int16_t *input, tran_low_t *output, int stride,
case DCT_DCT:
fdct16_sse2(in0, in1);
fdct16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fdct32_16col(in0, in1, in2, in3);
break;
case ADST_DCT:
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fdct32_16col(in0, in1, in2, in3);
break;
case DCT_ADST:
fdct16_sse2(in0, in1);
fdct16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case ADST_ADST:
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
#if CONFIG_EXT_TX
case FLIPADST_DCT:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fdct32_16col(in0, in1, in2, in3);
break;
case DCT_FLIPADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1);
fdct16_sse2(in0, in1);
fdct16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case FLIPADST_FLIPADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 1);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case ADST_FLIPADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case FLIPADST_ADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case IDTX:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
fidtx16_sse2(in0, in1);
fidtx16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fidtx32_16col(in0, in1, in2, in3);
break;
case V_DCT:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
fdct16_sse2(in0, in1);
fdct16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fidtx32_16col(in0, in1, in2, in3);
break;
case H_DCT:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
fidtx16_sse2(in0, in1);
fidtx16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fdct32_16col(in0, in1, in2, in3);
break;
case V_ADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fidtx32_16col(in0, in1, in2, in3);
break;
case H_ADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
fidtx16_sse2(in0, in1);
fidtx16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
case V_FLIPADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0);
fadst16_sse2(in0, in1);
fadst16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fidtx32_16col(in0, in1, in2, in3);
break;
case H_FLIPADST:
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1);
fidtx16_sse2(in0, in1);
fidtx16_sse2(in2, in3);
right_shift_16x16(in0, in1);
right_shift_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3);
round_signed_16x16(in0, in1);
round_signed_16x16(in2, in3);
fhalfright32_16col(in0, in1, in2, in3, no_transpose);
break;
#endif
default: assert(0); break;
......
......@@ -137,7 +137,7 @@ const Ht16x32Param kArrayHt16x32Param_sse2[] = {
512)
#endif // CONFIG_EXT_TX
};
INSTANTIATE_TEST_CASE_P(DISABLED_SSE2, AV1Trans16x32HT,
INSTANTIATE_TEST_CASE_P(SSE2, AV1Trans16x32HT,
::testing::ValuesIn(kArrayHt16x32Param_sse2));
#endif // HAVE_SSE2
......
......@@ -137,7 +137,7 @@ const Ht32x16Param kArrayHt32x16Param_sse2[] = {
512)
#endif // CONFIG_EXT_TX
};
INSTANTIATE_TEST_CASE_P(DISABLED_SSE2, AV1Trans32x16HT,
INSTANTIATE_TEST_CASE_P(SSE2, AV1Trans32x16HT,
::testing::ValuesIn(kArrayHt32x16Param_sse2));
#endif // HAVE_SSE2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment