Commit 90024e44 authored by Sarah Parker's avatar Sarah Parker

Use tx_size 1 level down for transform type search

This addresses an inconsistency between the set used
to decode the tx_type in the bitstream and the set used
for the tx_type search. Previously, the set used to
read/write the tx_type was based on the smallest tx_size
in the vartx partitioning, but the search uses a set
based on the largest possible tx_size. This patch
changes the tx_type search to use the transform type
set associated with the tx_size 1 recursive level down from
the max square tx_size to make the search more consistent
with the bitstream syntax. If a tx_size is selected for an
invalid tx_type, DCT_DCT is used for that partition instead.

This patch also adds assertions to all exposed transform
functions to ensure that no illegal transform type/size
combinations occur.

This currently gets a 0.1% drop in performance on lowres.
The drop is due to the reduction of the tx_types available
for 32x16 and 16x32 transform sizes. Before this patch,
32x16 and 16x32 transforms were getting assigned a
set of 12 tx_types, some of which we did not intend to
support for these sizes.

Change-Id: I44aca4876b261c345623cd04ad6235bca4532701
parent 670b660d
......@@ -28,6 +28,9 @@ typedef struct txfm_param {
TX_SIZE tx_size;
int lossless;
int bd;
#if CONFIG_EXT_TX
TxSetType tx_set_type;
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX || CONFIG_LGT
int is_inter;
#endif // CONFIG_MRC_TX || CONFIG_LGT
......
......@@ -1144,6 +1144,11 @@ static INLINE TX_TYPE get_default_tx_type(PLANE_TYPE plane_type,
: get_uv_mode(mbmi->uv_mode)];
}
static INLINE BLOCK_SIZE
get_plane_block_size(BLOCK_SIZE bsize, const struct macroblockd_plane *pd) {
return ss_size_lookup[bsize][pd->subsampling_x][pd->subsampling_y];
}
static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
const MACROBLOCKD *xd, int blk_row,
int blk_col, int block, TX_SIZE tx_size) {
......@@ -1156,6 +1161,18 @@ static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
if (is_intrabc_block(mbmi)) return DCT_DCT;
#endif // CONFIG_INTRABC && (!CONFIG_EXT_TX || CONFIG_TXK_SEL)
#if CONFIG_EXT_TX
const struct macroblockd_plane *const pd = &xd->plane[plane_type];
const BLOCK_SIZE plane_bsize = get_plane_block_size(mbmi->sb_type, pd);
// TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
// follow up refactor to make the actual value of reduced_tx_set_used
// within this function.
const TxSetType tx_set_type =
get_ext_tx_set_type(tx_size, plane_bsize, is_inter_block(mbmi), 0);
if (is_inter_block(mbmi) && !av1_ext_tx_used[tx_set_type][mbmi->tx_type])
return DCT_DCT;
#endif // CONFIG_EXT_TX
#if CONFIG_TXK_SEL
TX_TYPE tx_type;
if (xd->lossless[mbmi->segment_id] || txsize_sqr_map[tx_size] >= TX_32X32) {
......@@ -1169,6 +1186,10 @@ static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
tx_type = intra_mode_to_tx_type_context[mbmi->uv_mode];
}
assert(tx_type >= DCT_DCT && tx_type < TX_TYPES);
#if CONFIG_EXT_TX
if (is_inter_block(mbmi) && !av1_ext_tx_used[tx_set_type][tx_type])
return DCT_DCT;
#endif // CONFIG_EXT_TX
return tx_type;
#endif // CONFIG_TXK_SEL
......@@ -1211,8 +1232,12 @@ static INLINE TX_TYPE av1_get_tx_type(PLANE_TYPE plane_type,
: mbmi->tx_type;
}
// UV Intra only
(void)block;
return intra_mode_to_tx_type_context[get_uv_mode(mbmi->uv_mode)];
TX_TYPE intra_type =
intra_mode_to_tx_type_context[get_uv_mode(mbmi->uv_mode)];
if (!av1_ext_tx_used[tx_set_type][intra_type]) return DCT_DCT;
return intra_type;
#else // CONFIG_EXT_TX
(void)block;
#if CONFIG_MRC_TX
......@@ -1257,11 +1282,6 @@ static INLINE TX_SIZE av1_get_tx_size(int plane, const MACROBLOCKD *xd) {
return av1_get_uv_tx_size(mbmi, pd);
}
static INLINE BLOCK_SIZE
get_plane_block_size(BLOCK_SIZE bsize, const struct macroblockd_plane *pd) {
return ss_size_lookup[bsize][pd->subsampling_x][pd->subsampling_y];
}
void av1_reset_skip_context(MACROBLOCKD *xd, int mi_row, int mi_col,
BLOCK_SIZE bsize);
......
......@@ -256,6 +256,9 @@ void ilgt8(const tran_low_t *input, tran_low_t *output,
// apply. Otherwise they return 0
int get_lgt4(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
......@@ -271,6 +274,9 @@ int get_lgt4(const TxfmParam *txfm_param, int is_col,
int get_lgt8(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
if (is_col && (vtx_tab[txfm_param->tx_type] == ADST_1D ||
vtx_tab[txfm_param->tx_type] == FLIPADST_1D)) {
lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
......@@ -388,6 +394,9 @@ int idx_selfloop_wrt_mode(PREDICTION_MODE mode, int is_col) {
void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
PREDICTION_MODE mode = txfm_param->mode;
int stride = txfm_param->stride;
uint8_t *dst = txfm_param->dst;
......@@ -469,6 +478,9 @@ void get_lgt4_from_pred(const TxfmParam *txfm_param, int is_col,
void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
PREDICTION_MODE mode = txfm_param->mode;
int stride = txfm_param->stride;
uint8_t *dst = txfm_param->dst;
......@@ -538,6 +550,9 @@ void get_lgt8_from_pred(const TxfmParam *txfm_param, int is_col,
// will just call DCT or ADST
void get_lgt16up_from_pred(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx, int ntx) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
int tx_length = is_col ? tx_size_high[txfm_param->tx_size]
: tx_size_wide[txfm_param->tx_size];
assert(tx_length == 16 || tx_length == 32);
......@@ -2414,6 +2429,9 @@ void av1_iht32x64_2048_add_c(const tran_low_t *input, uint8_t *dest, int stride,
// idct
void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const int eob = txfm_param->eob;
if (eob > 1)
av1_iht4x4_16_add(input, dest, stride, txfm_param);
......@@ -2423,6 +2441,9 @@ void av1_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
void av1_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const int eob = txfm_param->eob;
if (eob > 1)
aom_iwht4x4_16_add(input, dest, stride);
......@@ -2897,6 +2918,9 @@ static const int32_t *cast_to_int32(const tran_low_t *input) {
void av1_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
int eob = txfm_param->eob;
int bd = txfm_param->bd;
int lossless = txfm_param->lossless;
......@@ -2942,6 +2966,9 @@ void av1_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const int32_t *src = cast_to_int32(input);
av1_inv_txfm2d_add_4x8_c(src, CONVERT_TO_SHORTPTR(dest), stride,
txfm_param->tx_type, txfm_param->bd);
......@@ -2949,6 +2976,9 @@ void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const int32_t *src = cast_to_int32(input);
av1_inv_txfm2d_add_8x4_c(src, CONVERT_TO_SHORTPTR(dest), stride,
txfm_param->tx_type, txfm_param->bd);
......@@ -3158,6 +3188,9 @@ static void highbd_inv_txfm_add_64x64(const tran_low_t *input, uint8_t *dest,
void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const TX_SIZE tx_size = txfm_param->tx_size;
#if CONFIG_LGT_FROM_PRED
if (txfm_param->use_lgt) {
......@@ -3199,13 +3232,27 @@ void av1_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
}
}
static void init_txfm_param(const MACROBLOCKD *xd, TX_SIZE tx_size,
TX_TYPE tx_type, int eob, TxfmParam *txfm_param) {
static void init_txfm_param(const MACROBLOCKD *xd,
#if CONFIG_EXT_TX
int plane,
#endif // CONFIG_EXT_TX
TX_SIZE tx_size, TX_TYPE tx_type, int eob,
TxfmParam *txfm_param) {
txfm_param->tx_type = tx_type;
txfm_param->tx_size = tx_size;
txfm_param->eob = eob;
txfm_param->lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
txfm_param->bd = xd->bd;
#if CONFIG_EXT_TX
const struct macroblockd_plane *const pd = &xd->plane[plane];
const BLOCK_SIZE plane_bsize =
get_plane_block_size(xd->mi[0]->mbmi.sb_type, pd);
// TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
// follow up refactor to make the actual value of reduced_tx_set_used
// within this function.
txfm_param->tx_set_type = get_ext_tx_set_type(
txfm_param->tx_size, plane_bsize, is_inter_block(&xd->mi[0]->mbmi), 0);
#endif // CONFIG_EXT_TX
#if CONFIG_LGT
txfm_param->is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif
......@@ -3234,12 +3281,19 @@ void av1_inverse_transform_block(const MACROBLOCKD *xd,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
uint8_t *mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
int plane,
#endif // CONFIG_EXT_TX
TX_TYPE tx_type, TX_SIZE tx_size, uint8_t *dst,
int stride, int eob) {
if (!eob) return;
TxfmParam txfm_param;
init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
init_txfm_param(xd,
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_size, tx_type, eob, &txfm_param);
#if CONFIG_LGT || CONFIG_MRC_TX
txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif // CONFIG_LGT || CONFIG_MRC_TX
......@@ -3253,6 +3307,9 @@ void av1_inverse_transform_block(const MACROBLOCKD *xd,
txfm_param.mode = mode;
#endif // CONFIG_LGT_FROM_PRED
#endif // CONFIG_LGT_FROM_PRED || CONFIG_MRC_TX
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param.tx_set_type][txfm_param.tx_type]);
#endif // CONFIG_EXT_TX
const int is_hbd = get_bitdepth_data_path_index(xd);
#if CONFIG_TXMG
......@@ -3304,12 +3361,18 @@ void av1_inverse_transform_block_facade(MACROBLOCKD *xd, int plane, int block,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, dst, dst_stride, eob);
}
void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param) {
const TX_SIZE tx_size = txfm_param->tx_size;
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
switch (tx_size) {
#if CONFIG_TX64X64
case TX_64X64:
......
......@@ -74,6 +74,9 @@ void av1_inverse_transform_block(const MACROBLOCKD *xd,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
uint8_t *mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
int plane,
#endif // CONFIG_EXT_TX
TX_TYPE tx_type, TX_SIZE tx_size, uint8_t *dst,
int stride, int eob);
void av1_inverse_transform_block_facade(MACROBLOCKD *xd, int plane, int block,
......@@ -88,8 +91,7 @@ void av1_highbd_inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest,
void av1_highbd_inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest,
int stride, const TxfmParam *param);
void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
TxfmParam *txfm_param);
TxfmParam *param);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -248,6 +248,9 @@ static void inverse_transform_block(MACROBLOCKD *xd, int plane,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
xd->mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, dst, stride, eob);
memset(dqcoeff, 0, (scan_line + 1) * sizeof(dqcoeff[0]));
}
......
......@@ -948,7 +948,11 @@ void av1_read_tx_type(const AV1_COMMON *const cm, MACROBLOCKD *xd,
MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
const int inter_block = is_inter_block(mbmi);
#if !CONFIG_TXK_SEL
const TX_SIZE tx_size = inter_block ? mbmi->min_tx_size : mbmi->tx_size;
const TX_SIZE sqr_up_tx_size =
txsize_sqr_up_map[max_txsize_rect_lookup[xd->mi[0]->mbmi.sb_type]];
const TX_SIZE tx_size =
inter_block ? AOMMAX(sub_tx_size_map[sqr_up_tx_size], mbmi->min_tx_size)
: mbmi->tx_size;
#endif // !CONFIG_TXK_SEL
FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
......
......@@ -1301,7 +1301,11 @@ void av1_write_tx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
const int is_inter = is_inter_block(mbmi);
#if !CONFIG_TXK_SEL
const TX_SIZE tx_size = is_inter ? mbmi->min_tx_size : mbmi->tx_size;
const TX_SIZE sqr_up_tx_size =
txsize_sqr_up_map[max_txsize_rect_lookup[xd->mi[0]->mbmi.sb_type]];
const TX_SIZE tx_size =
is_inter ? AOMMAX(sub_tx_size_map[sqr_up_tx_size], mbmi->min_tx_size)
: mbmi->tx_size;
#endif // !CONFIG_TXK_SEL
FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
......
......@@ -539,6 +539,11 @@ void av1_xform_quant(const AV1_COMMON *cm, MACROBLOCK *x, int plane, int block,
txfm_param.tx_type = tx_type;
txfm_param.tx_size = tx_size;
txfm_param.lossless = xd->lossless[mbmi->segment_id];
#if CONFIG_EXT_TX
txfm_param.tx_set_type =
get_ext_tx_set_type(txfm_param.tx_size, plane_bsize, is_inter_block(mbmi),
cm->reduced_tx_set_used);
#endif // CONFIG_EXT_TX
#if CONFIG_MRC_TX || CONFIG_LGT
txfm_param.is_inter = is_inter_block(mbmi);
#endif
......@@ -637,6 +642,9 @@ static void encode_block(int plane, int block, int blk_row, int blk_col,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, dst, pd->dst.stride,
p->eobs[block]);
}
......@@ -732,8 +740,14 @@ static void encode_block_pass1(int plane, int block, int blk_row, int blk_col,
if (p->eobs[block] > 0) {
txfm_param.bd = xd->bd;
txfm_param.tx_type = DCT_DCT;
txfm_param.tx_size = tx_size;
txfm_param.eob = p->eobs[block];
txfm_param.lossless = xd->lossless[xd->mi[0]->mbmi.segment_id];
#if CONFIG_EXT_TX
txfm_param.tx_set_type = get_ext_tx_set_type(
txfm_param.tx_size, plane_bsize, is_inter_block(&xd->mi[0]->mbmi),
cm->reduced_tx_set_used);
#endif // CONFIG_EXT_TX
#if CONFIG_HIGHBITDEPTH
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
av1_highbd_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride, &txfm_param);
......@@ -900,6 +914,9 @@ void av1_encode_block_intra(int plane, int block, int blk_row, int blk_col,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, dst, dst_stride, *eob);
if (*eob) *(args->skip) = 0;
......
......@@ -508,6 +508,9 @@ static void highbd_fwd_txfm_64x64(const int16_t *src_diff, tran_low_t *coeff,
void av1_fwd_txfm(const int16_t *src_diff, tran_low_t *coeff, int diff_stride,
TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const TX_SIZE tx_size = txfm_param->tx_size;
#if CONFIG_LGT_FROM_PRED
if (txfm_param->use_lgt) {
......@@ -571,6 +574,9 @@ void av1_fwd_txfm(const int16_t *src_diff, tran_low_t *coeff, int diff_stride,
void av1_highbd_fwd_txfm(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TxfmParam *txfm_param) {
#if CONFIG_EXT_TX
assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]);
#endif // CONFIG_EXT_TX
const TX_SIZE tx_size = txfm_param->tx_size;
switch (tx_size) {
#if CONFIG_TX64X64
......
......@@ -1976,6 +1976,9 @@ void av1_dist_block(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, recon, MAX_TX_SIZE, eob);
#if CONFIG_DIST_8X8
......@@ -4006,6 +4009,9 @@ void av1_tx_block_rd_b(const AV1_COMP *cpi, MACROBLOCK *x, TX_SIZE tx_size,
#if CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
mrc_mask,
#endif // CONFIG_MRC_TX && SIGNAL_ANY_MRC_MASK
#if CONFIG_EXT_TX
plane,
#endif // CONFIG_EXT_TX
tx_type, tx_size, rec_buffer, MAX_TX_SIZE, eob);
if (eob > 0) {
#if CONFIG_DIST_8X8
......@@ -4897,6 +4903,17 @@ static int predict_skip_flag_8bit(const MACROBLOCK *x, BLOCK_SIZE bsize) {
param.tx_size = max_txsize_rect_lookup[bsize];
param.bd = 8;
param.lossless = 0;
#if CONFIG_EXT_TX
const MACROBLOCKD *xd = &x->e_mbd;
const struct macroblockd_plane *const pd = &xd->plane[0];
const BLOCK_SIZE plane_bsize =
get_plane_block_size(xd->mi[0]->mbmi.sb_type, pd);
// TODO(sarahparker) This assumes reduced_tx_set_used == 0. I will do a
// follow up refactor to make the actual value of reduced_tx_set_used
// within this function.
param.tx_set_type = get_ext_tx_set_type(param.tx_size, plane_bsize,
is_inter_block(&xd->mi[0]->mbmi), 0);
#endif // CONFIG_EXT_TX
#if CONFIG_TXMG
av1_highbd_fwd_txfm(p->src_diff, DCT_coefs, bw, &param);
......@@ -5001,8 +5018,12 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
int idx, idy;
int prune = 0;
#if CONFIG_EXT_TX
const TX_SIZE sqr_up_tx_size =
txsize_sqr_up_map[max_txsize_rect_lookup[bsize]];
// Get the tx_size 1 level down
TX_SIZE min_tx_size = sub_tx_size_map[sqr_up_tx_size];
const TxSetType tx_set_type = get_ext_tx_set_type(
max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
min_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
#endif // CONFIG_EXT_TX
int within_border = (mi_row + mi_size_high[bsize] <= cm->mi_rows) &&
(mi_col + mi_size_wide[bsize] <= cm->mi_cols);
......@@ -5070,6 +5091,11 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
#endif // CONFIG_EXT_TX && CONFIG_MRC_TX
#if CONFIG_EXT_TX
if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
(void)prune;
// TODO(sarahparker) This speed feature has been temporarily disabled
// with ext-tx because it is not compatible with the current
// search method. It will be fixed in a followup.
/*
if (is_inter) {
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
if (!do_tx_type_search(tx_type, prune,
......@@ -5081,6 +5107,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
if (tx_type != intra_mode_to_tx_type_context[mbmi->mode]) continue;
}
}
*/
#else // CONFIG_EXT_TX
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
......@@ -5095,6 +5122,15 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
ref_best_rd, tx_type);
#if CONFIG_EXT_TX
// If the current tx_type is not included in the tx_set for the smallest
// tx size found, then all vartx partitions were actually transformed with
// DCT_DCT and we should avoid picking it.
const TxSetType min_tx_set_type = get_ext_tx_set_type(
mbmi->min_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
if (!av1_ext_tx_used[min_tx_set_type][tx_type]) continue;
#endif // CONFIG_EXT_TX
ref_best_rd = AOMMIN(rd, ref_best_rd);
if (rd < best_rd) {
best_rd = rd;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment