Commit 53f93dbd authored by Sarah Parker's avatar Sarah Parker

Add new MRC_DCT tx type

This adds the new transform to the list of possible transforms.
The impact on performance is in the noise range because the transform
implementation currently performs DCT as a placeholder. This transform
will initially only have an implementation for TX_32X32 and it is
skipped in the tx search for smaller transform sizes.

Change-Id: Iab2faddc525b478ca06972a753428a4f4ef53ac6
parent c2502b55
......@@ -797,18 +797,50 @@ typedef enum {
// DCT only
EXT_TX_SET_DCTONLY = 0,
// DCT + Identity only
EXT_TX_SET_DCT_IDTX = 1,
EXT_TX_SET_DCT_IDTX,
#if CONFIG_MRC_TX
// DCT + MRC_DCT
EXT_TX_SET_MRC_DCT,
// DCT + MRC_DCT + IDTX
EXT_TX_SET_MRC_DCT_IDTX,
#endif // CONFIG_MRC_TX
// Discrete Trig transforms w/o flip (4) + Identity (1)
EXT_TX_SET_DTT4_IDTX = 2,
EXT_TX_SET_DTT4_IDTX,
// Discrete Trig transforms w/o flip (4) + Identity (1) + 1D Hor/vert DCT (2)
EXT_TX_SET_DTT4_IDTX_1DDCT = 3,
EXT_TX_SET_DTT4_IDTX_1DDCT,
// Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver DCT (2)
EXT_TX_SET_DTT9_IDTX_1DDCT = 4,
EXT_TX_SET_DTT9_IDTX_1DDCT,
// Discrete Trig transforms w/ flip (9) + Identity (1) + 1D Hor/Ver (6)
EXT_TX_SET_ALL16 = 5,
EXT_TX_SET_ALL16,
EXT_TX_SET_TYPES
} TxSetType;
#if CONFIG_MRC_TX
// Number of transform types in each set type
static const int num_ext_tx_set[EXT_TX_SET_TYPES] = {
1, 2, 2, 3, 5, 7, 12, 16
};
// Maps intra set index to the set type
static const int ext_tx_set_type_intra[EXT_TX_SETS_INTRA] = {
EXT_TX_SET_DCTONLY, EXT_TX_SET_DTT4_IDTX_1DDCT, EXT_TX_SET_DTT4_IDTX,
EXT_TX_SET_MRC_DCT
};
// Maps inter set index to the set type
static const int ext_tx_set_type_inter[EXT_TX_SETS_INTER] = {
EXT_TX_SET_DCTONLY, EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT,
EXT_TX_SET_DCT_IDTX, EXT_TX_SET_MRC_DCT_IDTX
};
// Maps set types above to the indices used for intra
static const int ext_tx_set_index_intra[EXT_TX_SET_TYPES] = { 0, -1, 3, -1,
2, 1, -1, -1 };
// Maps set types above to the indices used for inter
static const int ext_tx_set_index_inter[EXT_TX_SET_TYPES] = { 0, 3, -1, 4,
-1, -1, 2, 1 };
#else // CONFIG_MRC_TX
// Number of transform types in each set type
static const int num_ext_tx_set[EXT_TX_SET_TYPES] = { 1, 2, 5, 7, 12, 16 };
......@@ -831,6 +863,7 @@ static const int ext_tx_set_index_intra[EXT_TX_SET_TYPES] = { 0, -1, 2,
static const int ext_tx_set_index_inter[EXT_TX_SET_TYPES] = {
0, 3, -1, -1, 2, 1
};
#endif // CONFIG_MRC_TX
static INLINE TxSetType get_ext_tx_set_type(TX_SIZE tx_size, BLOCK_SIZE bs,
int is_inter, int use_reduced_set) {
......@@ -844,6 +877,10 @@ static INLINE TxSetType get_ext_tx_set_type(TX_SIZE tx_size, BLOCK_SIZE bs,
#endif
if (use_reduced_set)
return is_inter ? EXT_TX_SET_DCT_IDTX : EXT_TX_SET_DTT4_IDTX;
#if CONFIG_MRC_TX
if (tx_size == TX_32X32)
return is_inter ? EXT_TX_SET_MRC_DCT_IDTX : EXT_TX_SET_MRC_DCT;
#endif // CONFIG_MRC_TX
if (tx_size_sqr_up == TX_32X32)
return is_inter ? EXT_TX_SET_DCT_IDTX : EXT_TX_SET_DCTONLY;
if (is_inter)
......@@ -862,6 +899,63 @@ static INLINE int get_ext_tx_set(TX_SIZE tx_size, BLOCK_SIZE bs, int is_inter,
: ext_tx_set_index_intra[set_type];
}
#if CONFIG_MRC_TX
static const int use_intra_ext_tx_for_txsize[EXT_TX_SETS_INTRA][EXT_TX_SIZES] =
{
#if CONFIG_CHROMA_2X2
{ 1, 1, 1, 1, 1 }, // unused
{ 0, 1, 1, 0, 0 },
{ 0, 0, 0, 1, 0 },
{ 0, 0, 0, 0, 1 },
#else
{ 1, 1, 1, 1 }, // unused
{ 1, 1, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 },
#endif // CONFIG_CHROMA_2X2
};
static const int use_inter_ext_tx_for_txsize[EXT_TX_SETS_INTER][EXT_TX_SIZES] =
{
#if CONFIG_CHROMA_2X2
{ 1, 1, 1, 1, 1 }, // unused
{ 0, 1, 1, 0, 0 }, { 0, 0, 0, 1, 0 },
{ 0, 0, 0, 0, 1 }, { 0, 0, 0, 0, 1 },
#else
{ 1, 1, 1, 1 }, // unused
{ 1, 1, 0, 0 }, { 0, 0, 1, 0 }, { 0, 0, 0, 1 }, { 0, 0, 0, 1 },
#endif // CONFIG_CHROMA_2X2
};
// Transform types used in each intra set
static const int ext_tx_used_intra[EXT_TX_SETS_INTRA][TX_TYPES] = {
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
{ 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0 },
{ 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
{ 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1 },
};
// Numbers of transform types used in each intra set
static const int ext_tx_cnt_intra[EXT_TX_SETS_INTRA] = { 1, 7, 5, 2 };
// Transform types used in each inter set
static const int ext_tx_used_inter[EXT_TX_SETS_INTER][TX_TYPES] = {
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
{ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
{ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0 },
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 },
{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1 },
};
// Numbers of transform types used in each inter set
static const int ext_tx_cnt_inter[EXT_TX_SETS_INTER] = { 1, 16, 12, 2, 3 };
// 1D Transforms used in inter set, this needs to be changed if
// ext_tx_used_inter is changed
static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
{ 1, 0, 0, 0 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 }, { 1, 0, 0, 1 }, { 1, 0, 0, 1 }
};
#else // CONFIG_MRC_TX
static const int use_intra_ext_tx_for_txsize[EXT_TX_SETS_INTRA][EXT_TX_SIZES] =
{
#if CONFIG_CHROMA_2X2
......@@ -916,6 +1010,7 @@ static const int ext_tx_cnt_inter[EXT_TX_SETS_INTER] = { 1, 16, 12, 2 };
static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
{ 1, 0, 0, 0 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 }, { 1, 0, 0, 1 },
};
#endif // CONFIG_MRC_TX
static INLINE int get_ext_tx_types(TX_SIZE tx_size, BLOCK_SIZE bs, int is_inter,
int use_reduced_set) {
......@@ -1149,6 +1244,15 @@ static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
#endif // FIXED_TX_TYPE
#if CONFIG_EXT_TX
#if CONFIG_MRC_TX
if (mbmi->tx_type == MRC_DCT) {
if (plane_type == PLANE_TYPE_Y) {
assert(tx_size == TX_32X32);
return mbmi->tx_type;
}
return DCT_DCT;
}
#endif // CONFIG_MRC_TX
if (xd->lossless[mbmi->segment_id] || txsize_sqr_map[tx_size] > TX_32X32 ||
(txsize_sqr_map[tx_size] >= TX_32X32 && !is_inter_block(mbmi)))
return DCT_DCT;
......@@ -1193,6 +1297,15 @@ static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
#endif // CONFIG_CB4X4
#else // CONFIG_EXT_TX
(void)block;
#if CONFIG_MRC_TX
if (mbmi->tx_type == MRC_DCT) {
if (plane_type == PLANE_TYPE_Y && !xd->lossless[mbmi->segment_id]) {
assert(tx_size == TX_32X32);
return mbmi->tx_type;
}
return DCT_DCT;
}
#endif // CONFIG_MRC_TX
if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
txsize_sqr_map[tx_size] >= TX_32X32)
return DCT_DCT;
......
This diff is collapsed.
......@@ -227,6 +227,8 @@ typedef enum {
ADST_1D = 1,
FLIPADST_1D = 2,
IDTX_1D = 3,
// TODO(sarahparker) need to eventually put something here for the
// mrc experiment to make this work with the ext-tx pruning functions
TX_TYPES_1D = 4,
} TX_TYPE_1D;
......@@ -249,6 +251,9 @@ typedef enum {
V_FLIPADST = 14,
H_FLIPADST = 15,
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
MRC_DCT, // DCT in both directions with mrc based bitmask
#endif // CONFIG_MRC_TX
TX_TYPES,
} TX_TYPE;
......@@ -273,10 +278,15 @@ typedef enum {
#if CONFIG_CHROMA_2X2
#define EXT_TX_SIZES 5 // number of sizes that use extended transforms
#else
#define EXT_TX_SIZES 4 // number of sizes that use extended transforms
#endif // CONFIG_CHROMA_2X2
#define EXT_TX_SIZES 4 // number of sizes that use extended transforms
#endif // CONFIG_CHROMA_2X2
#if CONFIG_MRC_TX
#define EXT_TX_SETS_INTER 5 // Sets of transform selections for INTER
#define EXT_TX_SETS_INTRA 4 // Sets of transform selections for INTRA
#else // CONFIG_MRC_TX
#define EXT_TX_SETS_INTER 4 // Sets of transform selections for INTER
#define EXT_TX_SETS_INTRA 3 // Sets of transform selections for INTRA
#endif // CONFIG_MRC_TX
#else
#if CONFIG_CHROMA_2X2
#define EXT_TX_SIZES 4 // number of sizes that use extended transforms
......
......@@ -254,6 +254,9 @@ int get_inv_lgt8(transform_1d tx_orig, const TxfmParam *txfm_param,
void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if !CONFIG_DAALA_DCT4
if (tx_type == DCT_DCT) {
aom_idct4x4_16_add(input, dest, stride);
......@@ -355,6 +358,9 @@ void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -435,6 +441,9 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -516,6 +525,9 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht4x16_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -587,6 +599,9 @@ void av1_iht4x16_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht16x4_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -659,6 +674,9 @@ void av1_iht16x4_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht8x16_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -731,6 +749,9 @@ void av1_iht8x16_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht16x8_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -804,6 +825,9 @@ void av1_iht16x8_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht8x32_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -875,6 +899,9 @@ void av1_iht8x32_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht32x8_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -947,6 +974,9 @@ void av1_iht32x8_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht16x32_512_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1006,6 +1036,9 @@ void av1_iht16x32_512_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht32x16_512_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1065,6 +1098,9 @@ void av1_iht32x16_512_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1149,6 +1185,9 @@ void av1_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht16x16_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1274,6 +1313,9 @@ void av1_iht32x32_1024_add_c(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iht64x64_4096_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1426,6 +1468,24 @@ static void idct32x32_add(const tran_low_t *input, uint8_t *dest, int stride,
aom_idct32x32_1024_add(input, dest, stride);
}
#if CONFIG_MRC_TX
static void get_masked_residual32_inv(const tran_low_t *input, uint8_t *dest,
tran_low_t *output) {
// placeholder for bitmask creation, in the future it
// will likely be made based on dest
(void)dest;
memcpy(output, input, 32 * 32 * sizeof(*input));
}
static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *param) {
// placeholder mrc tx function
tran_low_t masked_input[32 * 32];
get_masked_residual32_inv(input, dest, masked_input);
idct32x32_add(input, dest, stride, param);
}
#endif // CONFIG_MRC_TX
#if CONFIG_TX64X64
static void idct64x64_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
......@@ -1669,6 +1729,9 @@ static void inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest,
break;
case IDTX: inv_idtx_add_c(input, dest, stride, 16, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: assert(0 && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
default: assert(0); break;
}
}
......@@ -1697,6 +1760,9 @@ static void inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest,
break;
case IDTX: inv_idtx_add_c(input, dest, stride, 32, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: imrc32x32_add_c(input, dest, stride, txfm_param); break;
#endif // CONFIG_MRC_TX
default: assert(0); break;
}
}
......@@ -1726,6 +1792,9 @@ static void inv_txfm_add_64x64(const tran_low_t *input, uint8_t *dest,
break;
case IDTX: inv_idtx_add_c(input, dest, stride, 64, tx_type); break;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX
case MRC_DCT: assert(0 && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
default: assert(0); break;
}
}
......
......@@ -80,6 +80,10 @@ static INLINE const SCAN_ORDER *get_default_scan(TX_SIZE tx_size,
static INLINE const SCAN_ORDER *get_scan(const AV1_COMMON *cm, TX_SIZE tx_size,
TX_TYPE tx_type,
const MB_MODE_INFO *mbmi) {
#if CONFIG_MRC_TX
// use the DCT_DCT scan order for MRC_DCT for now
if (tx_type == MRC_DCT) tx_type = DCT_DCT;
#endif // CONFIG_MRC_TX
#if CONFIG_ADAPT_SCAN
(void)mbmi;
return &cm->fc->sc[tx_size][tx_type];
......
......@@ -1037,6 +1037,21 @@ static void fhalfright32(const tran_low_t *input, tran_low_t *output) {
// Note overall scaling factor is 4 times orthogonal
}
#if CONFIG_MRC_TX
static void get_masked_residual32_fwd(const tran_low_t *input,
tran_low_t *output) {
// placeholder for future bitmask creation
memcpy(output, input, 32 * 32 * sizeof(*input));
}
static void fmrc32(const tran_low_t *input, tran_low_t *output) {
// placeholder for mrc_dct, this just performs regular dct
tran_low_t masked_input[32 * 32];
get_masked_residual32_fwd(input, masked_input);
fdct32(masked_input, output);
}
#endif // CONFIG_MRC_TX
#if CONFIG_LGT
static void flgt4(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
......@@ -1181,6 +1196,9 @@ static void copy_fliplrud(const int16_t *src, int src_stride, int l, int w,
static void maybe_flip_input(const int16_t **src, int *src_stride, int l, int w,
int16_t *buff, int tx_type) {
switch (tx_type) {
#if CONFIG_MRC_TX
case MRC_DCT:
#endif // CONFIG_MRC_TX
case DCT_DCT:
case ADST_DCT:
case DCT_ADST:
......@@ -1217,6 +1235,9 @@ static void maybe_flip_input(const int16_t **src, int *src_stride, int l, int w,
void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1305,6 +1326,9 @@ void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1378,6 +1402,9 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1451,6 +1478,9 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1515,6 +1545,9 @@ void av1_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht16x4_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1579,6 +1612,9 @@ void av1_fht16x4_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1645,6 +1681,9 @@ void av1_fht8x16_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1711,6 +1750,9 @@ void av1_fht16x8_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht8x32_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1775,6 +1817,9 @@ void av1_fht8x32_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht32x8_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1839,6 +1884,9 @@ void av1_fht32x8_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -1895,6 +1943,9 @@ void av1_fht16x32_c(const int16_t *input, tran_low_t *output, int stride,
void av1_fht32x16_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -2076,6 +2127,9 @@ void av1_fdct8x8_quant_c(const int16_t *input, int stride,
void av1_fht8x8_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -2205,6 +2259,9 @@ void av1_fwht4x4_c(const int16_t *input, tran_low_t *output, int stride) {
void av1_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......@@ -2284,6 +2341,9 @@ void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
{ fhalfright32, fidtx32 }, // V_FLIPADST
{ fidtx32, fhalfright32 }, // H_FLIPADST
#endif
#if CONFIG_MRC_TX
{ fmrc32, fmrc32 }, // MRC_TX
#endif // CONFIG_MRC_TX
};
const transform_2d ht = FHT[tx_type];
tran_low_t out[1024];
......@@ -2354,6 +2414,9 @@ static void fdct64_row(const tran_low_t *input, tran_low_t *output) {
void av1_fht64x64_c(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif // CONFIG_MRC_TX
#if CONFIG_DCT_ONLY
assert(tx_type == DCT_DCT);
#endif
......
......@@ -121,6 +121,13 @@ static void fwd_txfm_16x16(const int16_t *src_diff, tran_low_t *coeff,
static void fwd_txfm_32x32(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TxfmParam *txfm_param) {
#if CONFIG_MRC_TX
// MRC_DCT currently only has a C implementation
if (txfm_param->tx_type == MRC_DCT) {
av1_fht32x32_c(src_diff, coeff, diff_stride, txfm_param);
return;
}
#endif // CONFIG_MRC_TX
av1_fht32x32(src_diff, coeff, diff_stride, txfm_param);
}
......
......@@ -2274,6 +2274,11 @@ static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
// transforms should be considered for pruning
prune = prune_tx_types(cpi, bs, x, xd, -1);
#if CONFIG_MRC_TX
// MRC_DCT only implemented for TX_32X32 so only include this tx in
// the search for TX_32X32
if (tx_type == MRC_DCT && tx_size != TX_32X32) return 1;
#endif // CONFIG_MRC_TX
if (mbmi->ref_mv_idx > 0 && tx_type != DCT_DCT) return 1;
if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, tx_size))
return 1;
......@@ -4503,7 +4508,13 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
#endif
}
#if CONFIG_MRC_TX
// If the tx type we are trying is MRC_DCT, we cannot partition the transform
// into anything smaller than TX_32X32
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && mbmi->tx_type != MRC_DCT) {
#else
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH) {
#endif
const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
const int bsl = tx_size_wide_unit[sub_txs];
int sub_step = tx_size_wide_unit[sub_txs] * tx_size_high_unit[sub_txs];
......@@ -4841,6 +4852,11 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
RD_STATS this_rd_stats;
av1_init_rd_stats(&this_rd_stats);
#if CONFIG_MRC_TX
// MRC_DCT only implemented for TX_32X32 so only include this tx in
// the search for TX_32X32
if (tx_type == MRC_DCT && max_tx_size != TX_32X32) continue;
#endif // CONFIG_MRC_TX
#if CONFIG_EXT_TX
if (is_inter) {
if (!ext_tx_used_inter[ext_tx_set][tx_type]) continue;
......
......@@ -206,6 +206,9 @@ void av1_fht4x4_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[4];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT: aom_fdct4x4_sse2(input, output, stride); break;
......@@ -1305,6 +1308,9 @@ void av1_fht8x8_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[8];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT: aom_fdct8x8_sse2(input, output, stride); break;
......@@ -2339,6 +2345,9 @@ void av1_fht16x16_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in0[16], in1[16];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT:
......@@ -2556,6 +2565,9 @@ void av1_fht4x8_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[8];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT:
......@@ -2731,6 +2743,9 @@ void av1_fht8x4_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[8];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT:
......@@ -2872,6 +2887,9 @@ void av1_fht8x16_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[16];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
__m128i *const t = in; // Alias to top 8x8 sub block
__m128i *const b = in + 8; // Alias to bottom 8x8 sub block
......@@ -3054,6 +3072,9 @@ void av1_fht16x8_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in[16];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
__m128i *const l = in; // Alias to left 8x8 sub block
__m128i *const r = in + 8; // Alias to right 8x8 sub block, which we store
......@@ -3365,6 +3386,9 @@ void av1_fht16x32_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i intl[16], intr[16], inbl[16], inbr[16];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
switch (tx_type) {
case DCT_DCT:
......@@ -3555,6 +3579,9 @@ void av1_fht32x16_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in0[16], in1[16], in2[16], in3[16];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "Invalid tx type for tx size");
#endif
load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0);
switch (tx_type) {
......@@ -3796,6 +3823,9 @@ void av1_fht32x32_sse2(const int16_t *input, tran_low_t *output, int stride,
TxfmParam *txfm_param) {
__m128i in0[32], in1[32], in2[32], in3[32];
int tx_type = txfm_param->tx_type;
#if CONFIG_MRC_TX
assert(tx_type != MRC_DCT && "No 32x32 sse2 MRC_DCT implementation");
#endif