Commit d8b1ddce authored by Lester Lu's avatar Lester Lu

Signature changes for the LGT experiment

The input arguments of av1_fht* and av1_iht* functions (and their
HBD versions) are slightly changed. Input arguments tx_type and
bd are carried by a struct fwd_txfm_param/inv_txfm_param. This
struct is meant to later on carry other prediction information,
such as intra top/left boundaries to the transform level, so
that the choice of transforms can be more adaptive to the
prediction mode and local video content.

Change-Id: Ia42544248a51845be64b72855b642ef1fe5910a9
parent cff9171e
......@@ -16,6 +16,7 @@
#include "./av1_rtcd.h"
#include "aom_dsp/txfm_common.h"
#include "av1/common/common.h"
#include "av1/common/idct.h"
static INLINE void TRANSPOSE4X4(int16x8_t *q8s16, int16x8_t *q9s16) {
int32x4_t q8s32, q9s32;
......@@ -134,7 +135,7 @@ static INLINE void IADST4x4_1D(int16x4_t *d3s16, int16x4_t *d4s16,
}
void av1_iht4x4_16_add_neon(const tran_low_t *input, uint8_t *dest,
int dest_stride, int tx_type) {
int dest_stride, const INV_TXFM_PARAM *param) {
uint8x8_t d26u8, d27u8;
int16x4_t d0s16, d1s16, d2s16, d3s16, d4s16, d5s16;
uint32x2_t d26u32, d27u32;
......@@ -148,9 +149,10 @@ void av1_iht4x4_16_add_neon(const tran_low_t *input, uint8_t *dest,
TRANSPOSE4X4(&q8s16, &q9s16);
int tx_type = param->tx_type;
switch (tx_type) {
case 0: // idct_idct is not supported. Fall back to C
av1_iht4x4_16_add_c(input, dest, dest_stride, tx_type);
av1_iht4x4_16_add_c(input, dest, dest_stride, param);
return;
break;
case 1: // iadst_idct
......
......@@ -16,6 +16,7 @@
#include "./av1_rtcd.h"
#include "aom_dsp/txfm_common.h"
#include "av1/common/common.h"
#include "av1/common/idct.h"
static INLINE void TRANSPOSE8X8(int16x8_t *q8s16, int16x8_t *q9s16,
int16x8_t *q10s16, int16x8_t *q11s16,
......@@ -458,7 +459,7 @@ static INLINE void IADST8X8_1D(int16x8_t *q8s16, int16x8_t *q9s16,
}
void av1_iht8x8_64_add_neon(const tran_low_t *input, uint8_t *dest,
int dest_stride, int tx_type) {
int dest_stride, const INV_TXFM_PARAM *param) {
int i;
uint8_t *d1, *d2;
uint8x8_t d0u8, d1u8, d2u8, d3u8;
......@@ -478,9 +479,10 @@ void av1_iht8x8_64_add_neon(const tran_low_t *input, uint8_t *dest,
TRANSPOSE8X8(&q8s16, &q9s16, &q10s16, &q11s16, &q12s16, &q13s16, &q14s16,
&q15s16);
int tx_type = param->tx_type;
switch (tx_type) {
case 0: // idct_idct is not supported. Fall back to C
av1_iht8x8_64_add_c(input, dest, dest_stride, tx_type);
av1_iht8x8_64_add_c(input, dest, dest_stride, param);
return;
break;
case 1: // iadst_idct
......
This diff is collapsed.
This diff is collapsed.
......@@ -26,7 +26,16 @@
extern "C" {
#endif
typedef struct INV_TXFM_PARAM {
// TODO(kslu) Combine FWD_TXFM_PARAM and INV_TXFM_PARAM into a common struct.
// and move the common stuff in idct.h to av1_txfm.h or txfm_common.h
typedef struct fwd_txfm_param {
TX_TYPE tx_type;
TX_SIZE tx_size;
int lossless;
int bd;
} FWD_TXFM_PARAM;
typedef struct inv_txfm_param {
#if CONFIG_ADAPT_SCAN
const int16_t *eob_threshold;
#endif
......@@ -71,12 +80,11 @@ void av1_inverse_transform_block_facade(MACROBLOCKD *xd, int plane, int block,
void av1_highbd_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
int eob, int bd);
void av1_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
int stride, int eob, int bd, TX_TYPE tx_type,
int lossless);
int stride, const INV_TXFM_PARAM *param);
void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
int stride, int eob, int bd, TX_TYPE tx_type);
int stride, const INV_TXFM_PARAM *param);
void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
int stride, int eob, int bd, TX_TYPE tx_type);
int stride, const INV_TXFM_PARAM *param);
void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
INV_TXFM_PARAM *inv_txfm_param);
......
......@@ -23,12 +23,13 @@
#if HAVE_DSPR2
void av1_iht16x16_256_add_dspr2(const int16_t *input, uint8_t *dest, int pitch,
int tx_type) {
FWD_TXFM_PARAM *param) {
int i, j;
DECLARE_ALIGNED(32, int16_t, out[16 * 16]);
int16_t *outptr = out;
int16_t temp_out[16];
uint32_t pos = 45;
int tx_type = param->tx_type;
/* bit positon for extract from acc */
__asm__ __volatile__("wrdsp %[pos], 1 \n\t" : : [pos] "r"(pos));
......
......@@ -23,12 +23,13 @@
#if HAVE_DSPR2
void av1_iht4x4_16_add_dspr2(const int16_t *input, uint8_t *dest,
int dest_stride, int tx_type) {
int dest_stride, FWD_TXFM_PARAM *param) {
int i, j;
DECLARE_ALIGNED(32, int16_t, out[4 * 4]);
int16_t *outptr = out;
int16_t temp_in[4 * 4], temp_out[4];
uint32_t pos = 45;
int tx_type = param->tx_type;
/* bit positon for extract from acc */
__asm__ __volatile__("wrdsp %[pos], 1 \n\t"
......
......@@ -16,18 +16,20 @@
#include "./av1_rtcd.h"
#include "av1/common/common.h"
#include "av1/common/blockd.h"
#include "av1/common/idct.h"
#include "aom_dsp/mips/inv_txfm_dspr2.h"
#include "aom_dsp/txfm_common.h"
#include "aom_ports/mem.h"
#if HAVE_DSPR2
void av1_iht8x8_64_add_dspr2(const int16_t *input, uint8_t *dest,
int dest_stride, int tx_type) {
int dest_stride, FWD_TXFM_PARAM *param) {
int i, j;
DECLARE_ALIGNED(32, int16_t, out[8 * 8]);
int16_t *outptr = out;
int16_t temp_in[8 * 8], temp_out[8];
uint32_t pos = 45;
int tx_type = param->tx_type;
/* bit positon for extract from acc */
__asm__ __volatile__("wrdsp %[pos], 1 \n\t" : : [pos] "r"(pos));
......
......@@ -12,13 +12,15 @@
#include <assert.h>
#include "av1/common/enums.h"
#include "av1/common/idct.h"
#include "aom_dsp/mips/inv_txfm_msa.h"
void av1_iht16x16_256_add_msa(const int16_t *input, uint8_t *dst,
int32_t dst_stride, int32_t tx_type) {
int32_t dst_stride, FWD_TXFM_PARAM *param) {
int32_t i;
DECLARE_ALIGNED(32, int16_t, out[16 * 16]);
int16_t *out_ptr = &out[0];
int32_t tx_type = param->tx_type;
switch (tx_type) {
case DCT_DCT:
......
......@@ -12,11 +12,13 @@
#include <assert.h>
#include "av1/common/enums.h"
#include "av1/common/idct.h"
#include "aom_dsp/mips/inv_txfm_msa.h"
void av1_iht4x4_16_add_msa(const int16_t *input, uint8_t *dst,
int32_t dst_stride, int32_t tx_type) {
int32_t dst_stride, FWD_TXFM_PARAM *param) {
v8i16 in0, in1, in2, in3;
int32_t tx_type = param->tx_type;
/* load vector elements of 4x4 block */
LD4x4_SH(input, in0, in1, in2, in3);
......
......@@ -12,11 +12,13 @@
#include <assert.h>
#include "av1/common/enums.h"
#include "av1/common/idct.h"
#include "aom_dsp/mips/inv_txfm_msa.h"
void av1_iht8x8_64_add_msa(const int16_t *input, uint8_t *dst,
int32_t dst_stride, int32_t tx_type) {
int32_t dst_stride, FWD_TXFM_PARAM *param) {
v8i16 in0, in1, in2, in3, in4, in5, in6, in7;
int32_t tx_type = param->tx_type;
/* load vector elements of 8x8 block */
LD_SH8(input, 8, in0, in1, in2, in3, in4, in5, in6, in7);
......
......@@ -14,6 +14,7 @@
#include "./aom_config.h"
#include "./av1_rtcd.h"
#include "av1/common/idct.h"
#include "aom_dsp/x86/inv_txfm_common_avx2.h"
void av1_idct16_avx2(__m256i *in) {
......@@ -364,8 +365,10 @@ static void iidtx16(__m256i *in) {
#endif
void av1_iht16x16_256_add_avx2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m256i in[16];
int tx_type = inv_txfm_param->tx_type;
load_buffer_16x16(input, in);
switch (tx_type) {
......
......@@ -15,6 +15,7 @@
#include "aom_dsp/x86/txfm_common_sse2.h"
#include "aom_ports/mem.h"
#include "av1/common/enums.h"
#include "av1/common/idct.h"
#if CONFIG_EXT_TX
static INLINE void fliplr_4x4(__m128i *in /*in[2]*/) {
......@@ -59,10 +60,11 @@ static INLINE void fliplr_16x8(__m128i *in /*in[16]*/) {
#endif
void av1_iht4x4_16_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
int tx_type) {
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[2];
const __m128i zero = _mm_setzero_si128();
const __m128i eight = _mm_set1_epi16(8);
int tx_type = inv_txfm_param->tx_type;
in[0] = load_input_data(input);
in[1] = load_input_data(input + 8);
......@@ -150,10 +152,11 @@ void av1_iht4x4_16_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
}
void av1_iht8x8_64_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
int tx_type) {
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[8];
const __m128i zero = _mm_setzero_si128();
const __m128i final_rounding = _mm_set1_epi16(1 << 4);
int tx_type = inv_txfm_param->tx_type;
// load input data
in[0] = load_input_data(input);
......@@ -251,10 +254,12 @@ static void iidtx16_sse2(__m128i *in0, __m128i *in1) {
#endif // CONFIG_EXT_TX
void av1_iht16x16_256_add_sse2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[32];
__m128i *in0 = &in[0];
__m128i *in1 = &in[16];
int tx_type = inv_txfm_param->tx_type;
load_buffer_8x16(input, in0);
input += 8;
......@@ -388,8 +393,10 @@ static INLINE void flip_buffer_lr_8x8(__m128i *in) {
#endif // CONFIG_EXT_TX
void av1_iht8x16_128_add_sse2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[16];
int tx_type = inv_txfm_param->tx_type;
in[0] = load_input_data(input + 0 * 8);
in[1] = load_input_data(input + 1 * 8);
......@@ -553,8 +560,10 @@ static INLINE void write_buffer_8x8_round6(uint8_t *dest, __m128i *in,
}
void av1_iht16x8_128_add_sse2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[16];
int tx_type = inv_txfm_param->tx_type;
// Transpose 16x8 input into in[]
in[0] = load_input_data(input + 0 * 16);
......@@ -713,8 +722,9 @@ static INLINE void write_buffer_8x4_round5(uint8_t *dest, __m128i *in,
}
void av1_iht8x4_32_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
int tx_type) {
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[8];
int tx_type = inv_txfm_param->tx_type;
in[0] = load_input_data(input + 0 * 8);
in[1] = load_input_data(input + 1 * 8);
......@@ -897,8 +907,9 @@ static INLINE void write_buffer_4x8_round5(uint8_t *dest, __m128i *in,
}
void av1_iht4x8_32_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
int tx_type) {
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in[8];
int tx_type = inv_txfm_param->tx_type;
// Load rows, packed two per element of 'in'.
// We pack into the bottom half of 'in' so that the
......@@ -1119,8 +1130,10 @@ static INLINE void write_buffer_16x32_round6(uint8_t *dest, __m128i *intl,
}
void av1_iht16x32_512_add_sse2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i intl[16], intr[16], inbl[16], inbr[16];
int tx_type = inv_txfm_param->tx_type;
int i;
for (i = 0; i < 16; ++i) {
......@@ -1272,8 +1285,10 @@ static INLINE void write_buffer_32x16_round6(uint8_t *dest, __m128i *in0,
}
void av1_iht32x16_512_add_sse2(const tran_low_t *input, uint8_t *dest,
int stride, int tx_type) {
int stride,
const INV_TXFM_PARAM *inv_txfm_param) {
__m128i in0[16], in1[16], in2[16], in3[16];
int tx_type = inv_txfm_param->tx_type;
int i;
for (i = 0; i < 16; ++i) {
......
......@@ -1175,7 +1175,8 @@ static void maybe_flip_input(const int16_t **src, int *src_stride, int l, int w,
#endif // CONFIG_EXT_TX
void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
if (tx_type == DCT_DCT) {
aom_fdct4x4_c(input, output, stride);
} else {
......@@ -1227,7 +1228,8 @@ void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct8, fdct4 }, // DCT_DCT
{ fadst8, fdct4 }, // ADST_DCT
......@@ -1279,7 +1281,8 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct4, fdct8 }, // DCT_DCT
{ fadst4, fdct8 }, // ADST_DCT
......@@ -1331,7 +1334,8 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct16, fdct4 }, // DCT_DCT
{ fadst16, fdct4 }, // ADST_DCT
......@@ -1381,7 +1385,8 @@ void av1_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht16x4_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct4, fdct16 }, // DCT_DCT
{ fadst4, fdct16 }, // ADST_DCT
......@@ -1431,7 +1436,8 @@ void av1_fht16x4_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct16, fdct8 }, // DCT_DCT
{ fadst16, fdct8 }, // ADST_DCT
......@@ -1483,7 +1489,8 @@ void av1_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct8, fdct16 }, // DCT_DCT
{ fadst8, fdct16 }, // ADST_DCT
......@@ -1535,7 +1542,8 @@ void av1_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht8x32_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct32, fdct8 }, // DCT_DCT
{ fhalfright32, fdct8 }, // ADST_DCT
......@@ -1585,7 +1593,8 @@ void av1_fht8x32_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht32x8_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct8, fdct32 }, // DCT_DCT
{ fadst8, fdct32 }, // ADST_DCT
......@@ -1635,7 +1644,8 @@ void av1_fht32x8_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct32, fdct16 }, // DCT_DCT
{ fhalfright32, fdct16 }, // ADST_DCT
......@@ -1687,7 +1697,8 @@ void av1_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
}
void av1_fht32x16_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct16, fdct32 }, // DCT_DCT
{ fadst16, fdct32 }, // ADST_DCT
......@@ -1864,7 +1875,8 @@ void av1_fdct8x8_quant_c(const int16_t *input, int stride,
}
void av1_fht8x8_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
if (tx_type == DCT_DCT) {
aom_fdct8x8_c(input, output, stride);
} else {
......@@ -1972,7 +1984,8 @@ void av1_fwht4x4_c(const int16_t *input, tran_low_t *output, int stride) {
}
void av1_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct16, fdct16 }, // DCT_DCT
{ fadst16, fdct16 }, // ADST_DCT
......@@ -2026,7 +2039,8 @@ void av1_highbd_fwht4x4_c(const int16_t *input, tran_low_t *output,
}
void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct32, fdct32 }, // DCT_DCT
#if CONFIG_EXT_TX
......@@ -2114,7 +2128,8 @@ static void fdct64_row(const tran_low_t *input, tran_low_t *output) {
}
void av1_fht64x64_c(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
int tx_type = fwd_txfm_param->tx_type;
static const transform_2d FHT[] = {
{ fdct64_col, fdct64_row }, // DCT_DCT
#if CONFIG_EXT_TX
......
......@@ -846,10 +846,14 @@ static void encode_block_pass1(int plane, int block, int blk_row, int blk_col,
}
#endif // !CONFIG_PVQ
#if CONFIG_HIGHBITDEPTH
INV_TXFM_PARAM inv_txfm_param;
inv_txfm_param.bd = xd->bd;
inv_txfm_param.tx_type = DCT_DCT;
inv_txfm_param.eob = p->eobs[block];
inv_txfm_param.lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
av1_highbd_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride, p->eobs[block],
xd->bd, DCT_DCT,
xd->lossless[xd->mi[0]->mbmi.segment_id]);
av1_highbd_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride,
&inv_txfm_param);
return;
}
#endif // CONFIG_HIGHBITDEPTH
......
This diff is collapsed.
......@@ -14,15 +14,6 @@
#include "./aom_config.h"
typedef enum FWD_TXFM_OPT { FWD_TXFM_OPT_NORMAL } FWD_TXFM_OPT;
typedef struct FWD_TXFM_PARAM {
TX_TYPE tx_type;
TX_SIZE tx_size;
int lossless;
int bd;
} FWD_TXFM_PARAM;
#ifdef __cplusplus
extern "C" {
#endif
......
......@@ -14,6 +14,7 @@
#include "./aom_dsp_rtcd.h"
#include "./av1_rtcd.h"
#include "av1/common/idct.h"
#include "aom_dsp/txfm_common.h"
#include "aom_dsp/x86/fwd_txfm_sse2.h"
#include "aom_dsp/x86/synonyms.h"
......@@ -203,8 +204,9 @@ static void fidtx4_sse2(__m128i *in) {
#endif // CONFIG_EXT_TX
void av1_fht4x4_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[4];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT: aom_fdct4x4_sse2(input, output, stride); break;
......@@ -1301,8 +1303,9 @@ static void fidtx8_sse2(__m128i *in) {
#endif // CONFIG_EXT_TX
void av1_fht8x8_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[8];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT: aom_fdct8x8_sse2(input, output, stride); break;
......@@ -2334,8 +2337,9 @@ static void fidtx16_sse2(__m128i *in0, __m128i *in1) {
#endif // CONFIG_EXT_TX
void av1_fht16x16_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in0[16], in1[16];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......@@ -2550,8 +2554,9 @@ static INLINE void write_buffer_4x8(tran_low_t *output, __m128i *res) {
}
void av1_fht4x8_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[8];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......@@ -2724,8 +2729,9 @@ static INLINE void write_buffer_8x4(tran_low_t *output, __m128i *res) {
}
void av1_fht8x4_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[8];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......@@ -2864,8 +2870,9 @@ static void row_8x16_rounding(__m128i *in, int bits) {
}
void av1_fht8x16_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[16];
int tx_type = fwd_txfm_param->tx_type;
__m128i *const t = in; // Alias to top 8x8 sub block
__m128i *const b = in + 8; // Alias to bottom 8x8 sub block
......@@ -3045,8 +3052,9 @@ static INLINE void load_buffer_16x8(const int16_t *input, __m128i *in,
#define col_16x8_rounding row_8x16_rounding
void av1_fht16x8_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in[16];
int tx_type = fwd_txfm_param->tx_type;
__m128i *const l = in; // Alias to left 8x8 sub block
__m128i *const r = in + 8; // Alias to right 8x8 sub block, which we store
......@@ -3355,8 +3363,9 @@ static INLINE void fhalfright32_16col(__m128i *tl, __m128i *tr, __m128i *bl,
// For 16x32, this means the input is a 2x2 grid of such blocks.
// For 32x16, it means the input is a 4x1 grid.
void av1_fht16x32_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i intl[16], intr[16], inbl[16], inbr[16];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......@@ -3544,8 +3553,9 @@ static INLINE void write_buffer_32x16(tran_low_t *output, __m128i *res0,
}
void av1_fht32x16_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in0[16], in1[16], in2[16], in3[16];
int tx_type = fwd_txfm_param->tx_type;
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
switch (tx_type) {
......@@ -3784,8 +3794,9 @@ static INLINE void write_buffer_32x32(__m128i *in0, __m128i *in1, __m128i *in2,
}
void av1_fht32x32_sse2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m128i in0[32], in1[32], in2[32], in3[32];
int tx_type = fwd_txfm_param->tx_type;
load_buffer_32x32(input, in0, in1, in2, in3, stride, 0, 0);
switch (tx_type) {
......
......@@ -14,6 +14,7 @@
#include "./av1_rtcd.h"
#include "./aom_dsp_rtcd.h"
#include "av1/common/idct.h"
#include "aom_dsp/x86/fwd_txfm_avx2.h"
#include "aom_dsp/txfm_common.h"
#include "aom_dsp/x86/txfm_common_avx2.h"
......@@ -914,8 +915,9 @@ static void fidtx16_avx2(__m256i *in) {
#endif
void av1_fht16x16_avx2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m256i in[16];
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......@@ -1509,9 +1511,10 @@ static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
#endif
void av1_fht32x32_avx2(const int16_t *input, tran_low_t *output, int stride,
int tx_type) {
FWD_TXFM_PARAM *fwd_txfm_param) {
__m256i in0[32]; // left 32 columns
__m256i in1[32]; // right 32 columns
int tx_type = fwd_txfm_param->tx_type;
switch (tx_type) {
case DCT_DCT:
......
......@@ -25,18 +25,19 @@ using libaom_test::ACMRandom;
namespace {
typedef void (*IhtFunc)(const tran_low_t *in, uint8_t *out, int stride,
int tx_type);
const INV_TXFM_PARAM *inv_txfm_param);
using std::tr1::tuple;
using libaom_test::FhtFunc;
typedef tuple<FhtFunc, IhtFunc, int, aom_bit_depth_t, int> Ht16x16Param;
void fht16x16_ref(const int16_t *in, tran_low_t *out, int stride, int tx_type) {
av1_fht16x16_c(in, out, stride, tx_type);
void fht16x16_ref(const int16_t *in, tran_low_t *out, int stride,
FWD_TXFM_PARAM *fwd_txfm_param) {
av1_fht16x16_c(in, out, stride, fwd_txfm_param);
}
void iht16x16_ref(const tran_low_t *in, uint8_t *dest, int stride,
int tx_type) {
av1_iht16x16_256_add_c(in, dest, stride, tx_type);
const INV_TXFM_PARAM *inv_txfm_param) {
av1_iht16x16_256_add_c(in, dest, stride, inv_txfm_param);
}
#if CONFIG_HIGHBITDEPTH
......@@ -62,7 +63,6 @@ class AV1Trans16x16HT : public libaom_test::TransformTestBase,
virtual void SetUp() {
fwd_txfm_ = GET_PARAM(0);
inv_txfm_ = GET_PARAM(1);
tx_type_ = GET_PARAM(2);
pitch_ = 16;
height_ = 16;
fwd_txfm_ref = fht16x16_ref;
......@@ -70,16 +70,18 @@ class AV1Trans16x16HT : public libaom_test::TransformTestBase,
bit_depth_ = GET_PARAM(3);
mask_ = (1 << bit_depth_) - 1;
num_coeffs_ = GET_PARAM(4);
fwd_txfm_param_.tx_type = (TX_TYPE)GET_PARAM(2);
inv_txfm_param_.tx_type = (TX_TYPE)GET_PARAM(2);
}
virtual void TearDown() { libaom_test::ClearSystemState(); }
protected:
void RunFwdTxfm(const int16_t *in, tran_low_t *out, int stride) {
fwd_txfm_(in, out, stride, tx_type_);
fwd_txfm_(in, out, stride, &fwd_txfm_param_);
}
void RunInvTxfm(const tran_low_t *out, uint8_t *dst, int stride) {
inv_txfm_(out, dst, stride, tx_type_);
inv_txfm_(out, dst, stride, &inv_txfm_param_);
}
FhtFunc fwd_txfm_;
......
......@@ -25,17 +25,19 @@ using libaom_test::ACMRandom;