Commit e4abb97b, authored by Yi Luo and committed by Gerrit Code Review

Merge "Fix the overflow of av1_fht32x32() in 2D DCT_DCT" into nextgenv2

parents b97c3a13 157e45a4
......@@ -391,6 +391,9 @@ specialize qw/av1_fht8x8 sse2/;
add_proto qw/void av1_fht16x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
specialize qw/av1_fht16x16 sse2 avx2/;
add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
specialize qw/av1_fht32x32 avx2/;
if (aom_config("CONFIG_EXT_TX") eq "yes") {
add_proto qw/void av1_fht4x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
specialize qw/av1_fht4x8 sse2/;
......@@ -409,9 +412,6 @@ if (aom_config("CONFIG_EXT_TX") eq "yes") {
add_proto qw/void av1_fht32x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
specialize qw/av1_fht32x16 sse2/;
add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
specialize qw/av1_fht32x32 avx2/;
}
if (aom_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
......
......@@ -325,7 +325,6 @@ static void fdct16(const tran_low_t *input, tran_low_t *output) {
range_check(output, 16, 16);
}
#if CONFIG_EXT_TX
static void fdct32(const tran_low_t *input, tran_low_t *output) {
tran_high_t temp;
tran_low_t step[32];
......@@ -723,7 +722,6 @@ static void fdct32(const tran_low_t *input, tran_low_t *output) {
range_check(output, 32, 18);
}
#endif // CONFIG_EXT_TX
static void fadst4(const tran_low_t *input, tran_low_t *output) {
tran_high_t x0, x1, x2, x3;
......@@ -1809,57 +1807,74 @@ void av1_highbd_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
}
#endif // CONFIG_AOM_HIGHBITDEPTH
#if CONFIG_EXT_TX
// TODO(luoyi): Adding this function to avoid DCT_DCT overflow.
// Remove this function after we scale the column txfm output correctly.
//
// Scans input[0] .. input[size - 1] and reports whether any entry has a
// magnitude strictly greater than `bound`.
//
//   input: first entry of the run to scan; entries are read contiguously,
//          no stride is applied.
//   bound: largest magnitude that is still accepted.
//   size:  number of entries to scan; nothing is read when size <= 0.
//
// Returns 1 at the first out-of-range entry, 0 if every entry is in range.
//
// NOTE(review): when the caller's block is stored with a row stride other
// than 32, a contiguous scan does not visit exactly the 32x32 block --
// confirm this is acceptable at the call site.
static INLINE int range_check_dct32x32(const int16_t *input, int16_t bound,
                                       int size) {
  // Widen once so that negating the limit cannot wrap.
  const int limit = bound;
  int idx = 0;
  while (idx < size) {
    const int sample = input[idx];
    ++idx;
    // Same test as |sample| > limit, written as a two-sided comparison.
    if (sample > limit || sample < -limit) return 1;
  }
  return 0;
}
void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
if (tx_type == DCT_DCT) {
aom_fdct32x32_c(input, output, stride);
} else {
static const transform_2d FHT[] = {
{ fdct32, fdct32 }, // DCT_DCT
{ fhalfright32, fdct32 }, // ADST_DCT
{ fdct32, fhalfright32 }, // DCT_ADST
{ fhalfright32, fhalfright32 }, // ADST_ADST
{ fhalfright32, fdct32 }, // FLIPADST_DCT
{ fdct32, fhalfright32 }, // DCT_FLIPADST
{ fhalfright32, fhalfright32 }, // FLIPADST_FLIPADST
{ fhalfright32, fhalfright32 }, // ADST_FLIPADST
{ fhalfright32, fhalfright32 }, // FLIPADST_ADST
{ fidtx32, fidtx32 }, // IDTX
{ fdct32, fidtx32 }, // V_DCT
{ fidtx32, fdct32 }, // H_DCT
{ fhalfright32, fidtx32 }, // V_ADST
{ fidtx32, fhalfright32 }, // H_ADST
{ fhalfright32, fidtx32 }, // V_FLIPADST
{ fidtx32, fhalfright32 }, // H_FLIPADST
};
const transform_2d ht = FHT[tx_type];
tran_low_t out[1024];
int i, j;
tran_low_t temp_in[32], temp_out[32];
static const transform_2d FHT[] = {
{ fdct32, fdct32 }, // DCT_DCT
#if CONFIG_EXT_TX
{ fhalfright32, fdct32 }, // ADST_DCT
{ fdct32, fhalfright32 }, // DCT_ADST
{ fhalfright32, fhalfright32 }, // ADST_ADST
{ fhalfright32, fdct32 }, // FLIPADST_DCT
{ fdct32, fhalfright32 }, // DCT_FLIPADST
{ fhalfright32, fhalfright32 }, // FLIPADST_FLIPADST
{ fhalfright32, fhalfright32 }, // ADST_FLIPADST
{ fhalfright32, fhalfright32 }, // FLIPADST_ADST
{ fidtx32, fidtx32 }, // IDTX
{ fdct32, fidtx32 }, // V_DCT
{ fidtx32, fdct32 }, // H_DCT
{ fhalfright32, fidtx32 }, // V_ADST
{ fidtx32, fhalfright32 }, // H_ADST
{ fhalfright32, fidtx32 }, // V_FLIPADST
{ fidtx32, fhalfright32 }, // H_FLIPADST
#endif
};
const transform_2d ht = FHT[tx_type];
tran_low_t out[1024];
int i, j;
tran_low_t temp_in[32], temp_out[32];
int16_t flipped_input[32 * 32];
maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
#if CONFIG_EXT_TX
int16_t flipped_input[32 * 32];
maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
#endif
// Columns
for (i = 0; i < 32; ++i) {
for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
ht.cols(temp_in, temp_out);
for (j = 0; j < 32; ++j)
out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
if (DCT_DCT == tx_type) {
if (range_check_dct32x32(input, (1 << 6) - 1, 1 << 10)) {
aom_fdct32x32_c(input, output, stride);
return;
}
}
// Columns
for (i = 0; i < 32; ++i) {
for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
ht.cols(temp_in, temp_out);
for (j = 0; j < 32; ++j)
out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
}
// Rows
for (i = 0; i < 32; ++i) {
for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
ht.rows(temp_in, temp_out);
for (j = 0; j < 32; ++j)
output[j + i * 32] =
(tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
}
// Rows
for (i = 0; i < 32; ++i) {
for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
ht.rows(temp_in, temp_out);
for (j = 0; j < 32; ++j)
output[j + i * 32] =
(tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
}
}
#if CONFIG_EXT_TX
// Forward identity transform.
void av1_fwd_idtx_c(const int16_t *src_diff, tran_low_t *coeff, int stride,
int bs, int tx_type) {
......
......@@ -21,7 +21,7 @@ static INLINE void fdct32x32(int rd_transform, const int16_t *src,
if (rd_transform)
aom_fdct32x32_rd(src, dst, src_stride);
else
aom_fdct32x32(src, dst, src_stride);
av1_fht32x32(src, dst, src_stride, DCT_DCT);
}
static void fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
......
......@@ -198,8 +198,8 @@ static void mm256_transpose_16x16(__m256i *in) {
in[15] = _mm256_permute2x128_si256(tr0_7, tr0_f, 0x31);
}
static void load_buffer_16x16(const int16_t *input, int stride, int flipud,
int fliplr, __m256i *in) {
static INLINE void load_buffer_16x16(const int16_t *input, int stride,
int flipud, int fliplr, __m256i *in) {
if (!flipud) {
in[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * stride));
in[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * stride));
......@@ -1273,7 +1273,6 @@ void aom_fdct32x32_1_avx2(const int16_t *input, tran_low_t *output,
_mm256_zeroupper();
}
#if CONFIG_EXT_TX
static void mm256_vectors_swap(__m256i *a0, __m256i *a1, const int size) {
int i = 0;
__m256i temp;
......@@ -1622,7 +1621,6 @@ static void fdct32_avx2(__m256i *in0, __m256i *in1) {
mm256_transpose_32x32(in0, in1);
}
#endif // CONFIG_EXT_TX
static INLINE void write_buffer_32x32(const __m256i *in0, const __m256i *in1,
int stride, tran_low_t *output) {
......@@ -1667,9 +1665,11 @@ static void fhalfright32_avx2(__m256i *in0, __m256i *in1) {
mm256_vectors_swap(in1, &in1[16], 16);
mm256_transpose_32x32(in0, in1);
}
#endif // CONFIG_EXT_TX
static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
int fliplr, __m256i *in0, __m256i *in1) {
static INLINE void load_buffer_32x32(const int16_t *input, int stride,
int flipud, int fliplr, __m256i *in0,
__m256i *in1) {
// Load 4 16x16 blocks
const int16_t *topL = input;
const int16_t *topR = input + 16;
......@@ -1708,7 +1708,6 @@ static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
load_buffer_16x16(topR, stride, flipud, fliplr, in1);
load_buffer_16x16(botR, stride, flipud, fliplr, in1 + 16);
}
#endif // CONFIG_EXT_TX
static void nr_right_shift_32x32_16col(__m256i *in) {
int i = 0;
......@@ -1729,8 +1728,7 @@ static void nr_right_shift_32x32(__m256i *in0, __m256i *in1) {
nr_right_shift_32x32_16col(in1);
}
#if CONFIG_EXT_TX
static void pr_right_shift_32x32_16col(__m256i *in) {
static INLINE void pr_right_shift_32x32_16col(__m256i *in) {
int i = 0;
const __m256i zero = _mm256_setzero_si256();
const __m256i one = _mm256_set1_epi16(1);
......@@ -1745,11 +1743,12 @@ static void pr_right_shift_32x32_16col(__m256i *in) {
}
// Positive rounding
static void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
static INLINE void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
pr_right_shift_32x32_16col(in0);
pr_right_shift_32x32_16col(in1);
}
#if CONFIG_EXT_TX
static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
int i = 0;
while (i < 32) {
......@@ -1761,23 +1760,42 @@ static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
}
#endif
// Reports whether any signed 16-bit lane in the first `row` vectors of `in0`
// or `in1` lies outside [-((1 << 6) - 1), (1 << 6) - 1].
//
//   in0, in1: arrays of at least `row` 256-bit vectors; every vector is
//             treated as 16 signed 16-bit lanes.
//   row:      number of vectors to inspect in each array.
//
// Returns 1 at the first vector pair containing an out-of-range lane, 0 if
// all inspected lanes are in range.
//
// The check is a two-sided signed comparison rather than abs() followed by a
// compare: _mm256_abs_epi16() leaves INT16_MIN unchanged (0x8000), which a
// signed compare would then treat as a large negative value and accept.
static INLINE int range_check_dct32x32(const __m256i *in0, const __m256i *in1,
                                       int row) {
  const __m256i upper = _mm256_set1_epi16((1 << 6) - 1);
  const __m256i lower = _mm256_set1_epi16(-((1 << 6) - 1));
  int i;
  for (i = 0; i < row; ++i) {
    // Lanes above the upper limit or below the lower limit become all-ones.
    const __m256i high0 = _mm256_cmpgt_epi16(in0[i], upper);
    const __m256i low0 = _mm256_cmpgt_epi16(lower, in0[i]);
    const __m256i high1 = _mm256_cmpgt_epi16(in1[i], upper);
    const __m256i low1 = _mm256_cmpgt_epi16(lower, in1[i]);
    const __m256i outside = _mm256_or_si256(_mm256_or_si256(high0, low0),
                                            _mm256_or_si256(high1, low1));
    // Any set byte in the mask means at least one lane is out of range.
    if (_mm256_movemask_epi8(outside)) return 1;
  }
  return 0;
}
void av1_fht32x32_avx2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
__m256i in0[32]; // left 32 columns
__m256i in1[32]; // right 32 columns
(void)input;
(void)stride;
switch (tx_type) {
// TODO(luoyi): For DCT_DCT, fwd_txfm_32x32() uses aom set. But this
// function has better speed. The replacement must work with the
// corresponding inverse transform.
// case DCT_DCT:
// load_buffer_32x32(input, stride, 0, 0, in0, in1);
// fdct32_avx2(in0, in1);
// pr_right_shift_32x32(in0, in1);
// fdct32_avx2(in0, in1);
// break;
case DCT_DCT:
load_buffer_32x32(input, stride, 0, 0, in0, in1);
if (range_check_dct32x32(in0, in1, 32)) {
aom_fdct32x32_avx2(input, output, stride);
return;
}
fdct32_avx2(in0, in1);
pr_right_shift_32x32(in0, in1);
fdct32_avx2(in0, in1);
break;
#if CONFIG_EXT_TX
case ADST_DCT:
load_buffer_32x32(input, stride, 0, 0, in0, in1);
......
......@@ -102,5 +102,6 @@ INSTANTIATE_TEST_CASE_P(
C, AV1FwdTxfm,
::testing::Values(FdctParam(&fdct4, &reference_dct_1d, 4, 1),
FdctParam(&fdct8, &reference_dct_1d, 8, 1),
FdctParam(&fdct16, &reference_dct_1d, 16, 2)));
FdctParam(&fdct16, &reference_dct_1d, 16, 2),
FdctParam(&fdct32, &reference_dct_1d, 32, 3)));
} // namespace
......@@ -69,6 +69,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
inv_txfm_ = GET_PARAM(1);
tx_type_ = GET_PARAM(2);
pitch_ = 32;
height_ = 32;
fwd_txfm_ref = fht32x32_ref;
bit_depth_ = GET_PARAM(3);
mask_ = (1 << bit_depth_) - 1;
......@@ -90,6 +91,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
};
TEST_P(AV1Trans32x32HT, CoeffCheck) { RunCoeffCheck(); }
TEST_P(AV1Trans32x32HT, MemCheck) { RunMemCheck(); }
#if CONFIG_AOM_HIGHBITDEPTH
class AV1HighbdTrans32x32HT
......@@ -164,8 +166,7 @@ using std::tr1::make_tuple;
#if HAVE_AVX2
const Ht32x32Param kArrayHt32x32Param_avx2[] = {
// TODO(luoyi): DCT_DCT tx_type is not enabled in av1_fht32x32_c(avx2) yet.
// make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 1, AOM_BITS_8, 1024),
make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 2, AOM_BITS_8, 1024),
make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 3, AOM_BITS_8, 1024),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment