Commit 157e45a4 authored by Yi Luo

Fix the overflow of av1_fht32x32() in 2D DCT_DCT

- Use a range check function to avoid DCT_DCT overflow.
  The column txfm side scaling/rounding still needs to be re-developed;
  for now, we prefer to maintain the current BDRate level.
- Encoder user-level time reduction is <1%, owing to av1_fht32x32_avx2.
- Add a MemCheck unit test and an fdct32() unit test.

Change-Id: I1e67030f67bc637859798ebe2f6698afffb8531c
parent cfc5ac50
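
In outline: av1_fht32x32_c() and av1_fht32x32_avx2() now handle DCT_DCT directly, but gate that path on the magnitude of the input residuals; any block with a sample outside the checked bound falls back to the existing aom_fdct32x32 path, whose internal scaling does not overflow. A minimal sketch of the guard, with the constants from the diff spelled out (the helper name here is illustrative, and treating 63 as the safe headroom bound for fdct32 is an assumption; the commit does not spell out the derivation):

#include <stdint.h>
#include <stdlib.h> /* abs() */

/* Scan all 1 << 10 = 1024 residuals of a 32x32 block and report whether
 * any magnitude exceeds bound = (1 << 6) - 1 = 63.  On a hit, the caller
 * uses aom_fdct32x32_c()/aom_fdct32x32_avx2() instead of the new
 * fdct32-based path. */
static int exceeds_dct32x32_bound(const int16_t *input, int16_t bound,
                                  int size) {
  int i;
  for (i = 0; i < size; ++i) {
    if (abs(input[i]) > bound) return 1;
  }
  return 0;
}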
@@ -391,6 +391,9 @@ specialize qw/av1_fht8x8 sse2/;
 add_proto qw/void av1_fht16x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
 specialize qw/av1_fht16x16 sse2 avx2/;

+add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
+specialize qw/av1_fht32x32 avx2/;
+
 if (aom_config("CONFIG_EXT_TX") eq "yes") {
   add_proto qw/void av1_fht4x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_fht4x8 sse2/;
@@ -409,9 +412,6 @@ if (aom_config("CONFIG_EXT_TX") eq "yes") {
   add_proto qw/void av1_fht32x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_fht32x16 sse2/;

-  add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-  specialize qw/av1_fht32x32 avx2/;
 }

 if (aom_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
......
@@ -325,7 +325,6 @@ static void fdct16(const tran_low_t *input, tran_low_t *output) {
   range_check(output, 16, 16);
 }

-#if CONFIG_EXT_TX
 static void fdct32(const tran_low_t *input, tran_low_t *output) {
   tran_high_t temp;
   tran_low_t step[32];
@@ -723,7 +722,6 @@ static void fdct32(const tran_low_t *input, tran_low_t *output) {
   range_check(output, 32, 18);
 }
-#endif  // CONFIG_EXT_TX

 static void fadst4(const tran_low_t *input, tran_low_t *output) {
   tran_high_t x0, x1, x2, x3;
@@ -1809,57 +1807,74 @@ void av1_highbd_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH

-#if CONFIG_EXT_TX
+// TODO(luoyi): Adding this function to avoid DCT_DCT overflow.
+// Remove this function after we scale the column txfm output correctly.
+static INLINE int range_check_dct32x32(const int16_t *input, int16_t bound,
+                                       int size) {
+  int i;
+  for (i = 0; i < size; ++i) {
+    if (abs(input[i]) > bound) return 1;
+  }
+  return 0;
+}
+
 void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
                     int tx_type) {
-  if (tx_type == DCT_DCT) {
-    aom_fdct32x32_c(input, output, stride);
-  } else {
-    static const transform_2d FHT[] = {
-      { fdct32, fdct32 },              // DCT_DCT
-      { fhalfright32, fdct32 },        // ADST_DCT
-      { fdct32, fhalfright32 },        // DCT_ADST
-      { fhalfright32, fhalfright32 },  // ADST_ADST
-      { fhalfright32, fdct32 },        // FLIPADST_DCT
-      { fdct32, fhalfright32 },        // DCT_FLIPADST
-      { fhalfright32, fhalfright32 },  // FLIPADST_FLIPADST
-      { fhalfright32, fhalfright32 },  // ADST_FLIPADST
-      { fhalfright32, fhalfright32 },  // FLIPADST_ADST
-      { fidtx32, fidtx32 },            // IDTX
-      { fdct32, fidtx32 },             // V_DCT
-      { fidtx32, fdct32 },             // H_DCT
-      { fhalfright32, fidtx32 },       // V_ADST
-      { fidtx32, fhalfright32 },       // H_ADST
-      { fhalfright32, fidtx32 },       // V_FLIPADST
-      { fidtx32, fhalfright32 },       // H_FLIPADST
-    };
-    const transform_2d ht = FHT[tx_type];
-    tran_low_t out[1024];
-    int i, j;
-    tran_low_t temp_in[32], temp_out[32];
-    int16_t flipped_input[32 * 32];
-    maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
+  static const transform_2d FHT[] = {
+    { fdct32, fdct32 },  // DCT_DCT
+#if CONFIG_EXT_TX
+    { fhalfright32, fdct32 },        // ADST_DCT
+    { fdct32, fhalfright32 },        // DCT_ADST
+    { fhalfright32, fhalfright32 },  // ADST_ADST
+    { fhalfright32, fdct32 },        // FLIPADST_DCT
+    { fdct32, fhalfright32 },        // DCT_FLIPADST
+    { fhalfright32, fhalfright32 },  // FLIPADST_FLIPADST
+    { fhalfright32, fhalfright32 },  // ADST_FLIPADST
+    { fhalfright32, fhalfright32 },  // FLIPADST_ADST
+    { fidtx32, fidtx32 },            // IDTX
+    { fdct32, fidtx32 },             // V_DCT
+    { fidtx32, fdct32 },             // H_DCT
+    { fhalfright32, fidtx32 },       // V_ADST
+    { fidtx32, fhalfright32 },       // H_ADST
+    { fhalfright32, fidtx32 },       // V_FLIPADST
+    { fidtx32, fhalfright32 },       // H_FLIPADST
+#endif
+  };
+  const transform_2d ht = FHT[tx_type];
+  tran_low_t out[1024];
+  int i, j;
+  tran_low_t temp_in[32], temp_out[32];
+
+#if CONFIG_EXT_TX
+  int16_t flipped_input[32 * 32];
+  maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
+#endif
+
+  if (DCT_DCT == tx_type) {
+    if (range_check_dct32x32(input, (1 << 6) - 1, 1 << 10)) {
+      aom_fdct32x32_c(input, output, stride);
+      return;
+    }
+  }

-    // Columns
-    for (i = 0; i < 32; ++i) {
-      for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
-      ht.cols(temp_in, temp_out);
-      for (j = 0; j < 32; ++j)
-        out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
-    }
+  // Columns
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
+    ht.cols(temp_in, temp_out);
+    for (j = 0; j < 32; ++j)
+      out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
+  }

-    // Rows
-    for (i = 0; i < 32; ++i) {
-      for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
-      ht.rows(temp_in, temp_out);
-      for (j = 0; j < 32; ++j)
-        output[j + i * 32] =
-            (tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
-    }
+  // Rows
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
+    ht.rows(temp_in, temp_out);
+    for (j = 0; j < 32; ++j)
+      output[j + i * 32] =
+          (tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
   }
 }

+#if CONFIG_EXT_TX
 // Forward identity transform.
 void av1_fwd_idtx_c(const int16_t *src_diff, tran_low_t *coeff, int stride,
                     int bs, int tx_type) {
......
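
A quick way to see the contract of the new C path: when the range check trips, av1_fht32x32_c() forwards to aom_fdct32x32_c() and must therefore match it bit for bit. A hypothetical driver, not part of the commit, assuming the prototypes come from the generated av1_rtcd.h and the usual aom_dsp/av1 headers for tran_low_t and DCT_DCT:

#include <assert.h>
#include <stdint.h>
#include <string.h>
/* Assumed includes: "./av1_rtcd.h", "./aom_dsp_rtcd.h", av1 enums. */

static void check_dct_dct_fallback(void) {
  int16_t residual[32 * 32];
  tran_low_t got[1024], want[1024];
  int i;
  memset(residual, 0, sizeof(residual));
  residual[0] = 200; /* |200| > 63, so the range check trips */
  av1_fht32x32_c(residual, got, 32 /* stride */, DCT_DCT);
  aom_fdct32x32_c(residual, want, 32);
  for (i = 0; i < 1024; ++i) assert(got[i] == want[i]);
}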
@@ -21,7 +21,7 @@ static INLINE void fdct32x32(int rd_transform, const int16_t *src,
   if (rd_transform)
     aom_fdct32x32_rd(src, dst, src_stride);
   else
-    aom_fdct32x32(src, dst, src_stride);
+    av1_fht32x32(src, dst, src_stride, DCT_DCT);
 }

 static void fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
......
@@ -198,8 +198,8 @@ static void mm256_transpose_16x16(__m256i *in) {
   in[15] = _mm256_permute2x128_si256(tr0_7, tr0_f, 0x31);
 }

-static void load_buffer_16x16(const int16_t *input, int stride, int flipud,
-                              int fliplr, __m256i *in) {
+static INLINE void load_buffer_16x16(const int16_t *input, int stride,
+                                     int flipud, int fliplr, __m256i *in) {
   if (!flipud) {
     in[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * stride));
     in[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * stride));
@@ -1273,7 +1273,6 @@ void aom_fdct32x32_1_avx2(const int16_t *input, tran_low_t *output,
   _mm256_zeroupper();
 }

-#if CONFIG_EXT_TX
 static void mm256_vectors_swap(__m256i *a0, __m256i *a1, const int size) {
   int i = 0;
   __m256i temp;
@@ -1622,7 +1621,6 @@ static void fdct32_avx2(__m256i *in0, __m256i *in1) {
   mm256_transpose_32x32(in0, in1);
 }
-#endif  // CONFIG_EXT_TX

 static INLINE void write_buffer_32x32(const __m256i *in0, const __m256i *in1,
                                       int stride, tran_low_t *output) {
@@ -1667,9 +1665,11 @@ static void fhalfright32_avx2(__m256i *in0, __m256i *in1) {
   mm256_vectors_swap(in1, &in1[16], 16);
   mm256_transpose_32x32(in0, in1);
 }
-#endif  // CONFIG_EXT_TX

-static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
-                              int fliplr, __m256i *in0, __m256i *in1) {
+static INLINE void load_buffer_32x32(const int16_t *input, int stride,
+                                     int flipud, int fliplr, __m256i *in0,
+                                     __m256i *in1) {
   // Load 4 16x16 blocks
   const int16_t *topL = input;
   const int16_t *topR = input + 16;
@@ -1708,7 +1708,6 @@ static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
   load_buffer_16x16(topR, stride, flipud, fliplr, in1);
   load_buffer_16x16(botR, stride, flipud, fliplr, in1 + 16);
 }
-#endif  // CONFIG_EXT_TX

 static void nr_right_shift_32x32_16col(__m256i *in) {
   int i = 0;
@@ -1729,8 +1728,7 @@ static void nr_right_shift_32x32(__m256i *in0, __m256i *in1) {
   nr_right_shift_32x32_16col(in1);
 }

-#if CONFIG_EXT_TX
-static void pr_right_shift_32x32_16col(__m256i *in) {
+static INLINE void pr_right_shift_32x32_16col(__m256i *in) {
   int i = 0;
   const __m256i zero = _mm256_setzero_si256();
   const __m256i one = _mm256_set1_epi16(1);
@@ -1745,11 +1743,12 @@ static void pr_right_shift_32x32_16col(__m256i *in) {
 }

 // Positive rounding
-static void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
+static INLINE void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
   pr_right_shift_32x32_16col(in0);
   pr_right_shift_32x32_16col(in1);
 }

+#if CONFIG_EXT_TX
 static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
   int i = 0;
   while (i < 32) {
@@ -1761,23 +1760,42 @@ static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
 }
 #endif

+static INLINE int range_check_dct32x32(const __m256i *in0, const __m256i *in1,
+                                       int row) {
+  __m256i value, bits0, bits1;
+  const __m256i bound = _mm256_set1_epi16((1 << 6) - 1);
+  int flag;
+  int i = 0;
+  while (i < row) {
+    value = _mm256_abs_epi16(in0[i]);
+    bits0 = _mm256_cmpgt_epi16(value, bound);
+    value = _mm256_abs_epi16(in1[i]);
+    bits1 = _mm256_cmpgt_epi16(value, bound);
+    bits0 = _mm256_or_si256(bits0, bits1);
+    flag = _mm256_movemask_epi8(bits0);
+    if (flag) return 1;
+    i++;
+  }
+  return 0;
+}
+
 void av1_fht32x32_avx2(const int16_t *input, tran_low_t *output, int stride,
                        int tx_type) {
   __m256i in0[32];  // left 32 columns
   __m256i in1[32];  // right 32 columns
-  (void)input;
-  (void)stride;

   switch (tx_type) {
-    // TODO(luoyi): For DCT_DCT, fwd_txfm_32x32() uses aom set. But this
-    // function has better speed. The replacement must work with the
-    // corresponding inverse transform.
-    // case DCT_DCT:
-    //   load_buffer_32x32(input, stride, 0, 0, in0, in1);
-    //   fdct32_avx2(in0, in1);
-    //   pr_right_shift_32x32(in0, in1);
-    //   fdct32_avx2(in0, in1);
-    //   break;
+    case DCT_DCT:
+      load_buffer_32x32(input, stride, 0, 0, in0, in1);
+      if (range_check_dct32x32(in0, in1, 32)) {
+        aom_fdct32x32_avx2(input, output, stride);
+        return;
+      }
+      fdct32_avx2(in0, in1);
+      pr_right_shift_32x32(in0, in1);
+      fdct32_avx2(in0, in1);
+      break;
 #if CONFIG_EXT_TX
     case ADST_DCT:
       load_buffer_32x32(input, stride, 0, 0, in0, in1);
......
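
The AVX2 guard is the same test vectorized: lane-wise absolute value, a greater-than compare against the 63 bound, an OR across the two 16-column halves, and _mm256_movemask_epi8 to collapse the 256-bit compare mask into a scalar that a single branch can test. A self-contained sketch of the pattern (illustrative names, not from the tree; compile with -mavx2; like the code above it assumes |x| fits in int16, which holds for residuals):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

/* Returns 1 if any int16 element of buf (n a multiple of 16) has a
 * magnitude above (1 << 6) - 1 = 63. */
static int any_above_63(const int16_t *buf, int n) {
  const __m256i bound = _mm256_set1_epi16((1 << 6) - 1);
  int i;
  for (i = 0; i < n; i += 16) {
    const __m256i v = _mm256_loadu_si256((const __m256i *)(buf + i));
    const __m256i gt = _mm256_cmpgt_epi16(_mm256_abs_epi16(v), bound);
    /* movemask is nonzero iff some lane compared greater. */
    if (_mm256_movemask_epi8(gt)) return 1;
  }
  return 0;
}

int main(void) {
  int16_t block[32] = { 0 };
  block[20] = 64; /* just above the bound */
  printf("%d\n", any_above_63(block, 32)); /* prints 1 */
  return 0;
}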
@@ -102,5 +102,6 @@ INSTANTIATE_TEST_CASE_P(
     C, AV1FwdTxfm,
     ::testing::Values(FdctParam(&fdct4, &reference_dct_1d, 4, 1),
                       FdctParam(&fdct8, &reference_dct_1d, 8, 1),
-                      FdctParam(&fdct16, &reference_dct_1d, 16, 2)));
+                      FdctParam(&fdct16, &reference_dct_1d, 16, 2),
+                      FdctParam(&fdct32, &reference_dct_1d, 32, 3)));
 }  // namespace
@@ -69,6 +69,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
     inv_txfm_ = GET_PARAM(1);
     tx_type_ = GET_PARAM(2);
     pitch_ = 32;
+    height_ = 32;
     fwd_txfm_ref = fht32x32_ref;
     bit_depth_ = GET_PARAM(3);
     mask_ = (1 << bit_depth_) - 1;
@@ -90,6 +91,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
 };

 TEST_P(AV1Trans32x32HT, CoeffCheck) { RunCoeffCheck(); }
+TEST_P(AV1Trans32x32HT, MemCheck) { RunMemCheck(); }

 #if CONFIG_AOM_HIGHBITDEPTH
 class AV1HighbdTrans32x32HT
@@ -164,8 +166,7 @@ using std::tr1::make_tuple;
 #if HAVE_AVX2
 const Ht32x32Param kArrayHt32x32Param_avx2[] = {
-  // TODO(luoyi): DCT_DCT tx_type is not enabled in av1_fht32x32_c(avx2) yet.
-  // make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
+  make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 1, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 2, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 3, AOM_BITS_8, 1024),
......