Commit 6bbd9321 authored by Hui Su
Browse files

Move the tx pruning flags into MACROBLOCK

This allows them to be generated at the prediction block level and
then easily accessed at the transform block level.

Change-Id: I376042e8d57e00586d3cf90e237544e705b77e8b
parent c30c18f5
......@@ -365,6 +365,8 @@ struct macroblock {
int comp_idx_cost[COMP_INDEX_CONTEXTS][2];
int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2];
#endif // CONFIG_JNT_COMP
// Bit flags for pruning tx type search, tx split, etc.
int tx_search_prune[EXT_TX_SET_TYPES];
};
static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) {
......
......@@ -1255,20 +1255,15 @@ static int prune_tx_split(BLOCK_SIZE bsize, const int16_t *diff, float hcorr,
return (score > av1_prune_tx_split_thresholds[bidx]);
}
static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
int tx_type_pruning_aggressiveness,
int use_tx_split_prune) {
if (bsize >= BLOCK_32X32) return 0;
static void prune_tx_2D(BLOCK_SIZE bsize, MACROBLOCK *x,
TX_TYPE_PRUNE_MODE prune_mode, int use_tx_split_prune) {
if (bsize >= BLOCK_32X32) return;
aom_clear_system_state();
const struct macroblock_plane *const p = &x->plane[0];
const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
const float score_thresh =
av1_prune_2D_adaptive_thresholds[bidx]
[tx_type_pruning_aggressiveness - 1];
float hfeatures[16], vfeatures[16];
float hscores[4], vscores[4];
float scores_2D[16];
int tx_type_table_2D[16] = {
const int tx_type_table_2D[16] = {
DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT,
ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST,
FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
......@@ -1284,7 +1279,7 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
get_horver_correlation_full(p->src_diff, bw, bw, bh,
&hfeatures[hfeatures_num - 1],
&vfeatures[vfeatures_num - 1]);
const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
const float *b1_hor =
fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
......@@ -1314,22 +1309,43 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
score_2D_average /= 16;
score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
// Always keep the TX type with the highest score, prune all others with
// score below score_thresh.
int max_score_i = 0;
float max_score = 0.0f;
for (int i = 0; i < 16; i++) {
if (scores_2D[i] > max_score &&
av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
max_score = scores_2D[i];
max_score_i = i;
// TODO(huisu@google.com): support more tx set types.
const int tx_set_types[2] = { EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT };
for (int tx_set_idx = 0; tx_set_idx < 2; ++tx_set_idx) {
const int tx_set_type = tx_set_types[tx_set_idx];
// Always keep the TX type with the highest score, prune all others with
// score below score_thresh.
int max_score_i = 0;
float max_score = 0.0f;
for (int i = 0; i < 16; i++) {
if (scores_2D[i] > max_score &&
av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
max_score = scores_2D[i];
max_score_i = i;
}
}
}
int prune_bitmask = 0;
for (int i = 0; i < 16; i++) {
if (scores_2D[i] < score_thresh && i != max_score_i)
prune_bitmask |= (1 << tx_type_table_2D[i]);
int pruning_aggressiveness = 0;
if (prune_mode == PRUNE_2D_ACCURATE) {
if (tx_set_type == EXT_TX_SET_ALL16)
pruning_aggressiveness = 6;
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
pruning_aggressiveness = 4;
} else if (prune_mode == PRUNE_2D_FAST) {
if (tx_set_type == EXT_TX_SET_ALL16)
pruning_aggressiveness = 10;
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
pruning_aggressiveness = 7;
}
const float score_thresh =
av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
int prune_bitmask = 0;
for (int i = 0; i < 16; i++) {
if (scores_2D[i] < score_thresh && i != max_score_i)
prune_bitmask |= (1 << tx_type_table_2D[i]);
}
x->tx_search_prune[tx_set_type] = prune_bitmask;
}
// Also apply TX size pruning if it's turned on. The value
......@@ -1342,51 +1358,42 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1],
vfeatures[vfeatures_num - 1]);
}
prune_bitmask |= (prune_tx_split_flag << TX_TYPES);
return prune_bitmask;
x->tx_search_prune[0] |= (prune_tx_split_flag << TX_TYPES);
}
static int prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
const MACROBLOCKD *const xd, int tx_set_type,
int use_tx_split_prune) {
static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
const MACROBLOCKD *const xd, int tx_set_type,
int use_tx_split_prune) {
int tx_set = ext_tx_set_index[1][tx_set_type];
assert(tx_set >= 0);
av1_zero(x->tx_search_prune);
const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
switch (cpi->sf.tx_type_search.prune_mode) {
case NO_PRUNE: return 0; break;
case NO_PRUNE: return;
case PRUNE_ONE:
if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0;
return prune_one_for_sby(cpi, bsize, x, xd);
if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return;
x->tx_search_prune[tx_set_type] = prune_one_for_sby(cpi, bsize, x, xd);
break;
case PRUNE_TWO:
if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0;
return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return;
x->tx_search_prune[tx_set_type] =
prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
}
if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) {
x->tx_search_prune[tx_set_type] =
prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
}
x->tx_search_prune[tx_set_type] =
prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
break;
case PRUNE_2D_ACCURATE:
if (tx_set_type == EXT_TX_SET_ALL16)
return prune_tx_2D(bsize, x, tx_set_type, 6, use_tx_split_prune);
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
return prune_tx_2D(bsize, x, tx_set_type, 4, use_tx_split_prune);
else
return 0;
break;
case PRUNE_2D_FAST:
if (tx_set_type == EXT_TX_SET_ALL16)
return prune_tx_2D(bsize, x, tx_set_type, 10, use_tx_split_prune);
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
return prune_tx_2D(bsize, x, tx_set_type, 7, use_tx_split_prune);
else
return 0;
prune_tx_2D(bsize, x, cpi->sf.tx_type_search.prune_mode,
use_tx_split_prune);
break;
default: assert(0);
}
assert(0);
return 0;
}
static int do_tx_type_search(TX_TYPE tx_type, int prune,
......@@ -2319,7 +2326,7 @@ static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
}
static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
TX_TYPE tx_type, TX_SIZE tx_size) {
const MACROBLOCKD *const xd = &x->e_mbd;
const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
const int is_inter = is_inter_block(mbmi);
......@@ -2337,7 +2344,8 @@ static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
if (is_inter) {
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
cpi->sf.tx_type_search.prune_mode))
return 1;
}
}
......@@ -2371,7 +2379,6 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
int s0 = x->skip_cost[skip_ctx][0];
int s1 = x->skip_cost[skip_ctx][1];
const int is_inter = is_inter_block(mbmi);
int prune = 0;
av1_invalid_rd_stats(rd_stats);
mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter);
......@@ -2381,7 +2388,7 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
prune = prune_tx(cpi, bs, x, xd, tx_set_type, 0);
prune_tx(cpi, bs, x, xd, tx_set_type, 0);
}
#if CONFIG_FILTER_INTRA
if (skip_invalid_tx_size_for_filter_intra(mbmi, AOM_PLANE_Y, rd_stats)) {
......@@ -2399,7 +2406,7 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
tx_type != get_default_tx_type(0, xd, mbmi->tx_size))
continue;
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
if (!do_tx_type_search(tx_type, prune,
if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
cpi->sf.tx_type_search.prune_mode))
continue;
}
......@@ -2440,6 +2447,8 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
mbmi->tx_size, cpi->sf.use_fast_coef_costing);
}
mbmi->tx_type = best_tx_type;
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
}
static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
......@@ -2507,10 +2516,9 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
depth = MAX_TX_DEPTH;
}
int prune = 0;
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
prune = prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
}
last_rd = INT64_MAX;
......@@ -2526,7 +2534,7 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
TX_TYPE tx_type;
for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
RD_STATS this_rd_stats;
if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue;
if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
if (mbmi->ref_mv_idx > 0) x->rd_model = LOW_TXFM_RD;
rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
......@@ -2578,6 +2586,8 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
memcpy(x->blk_skip[0], best_blk_skip, sizeof(best_blk_skip[0]) * n4);
mbmi->min_tx_size = mbmi->tx_size;
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
}
static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
......@@ -3820,7 +3830,6 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
RD_STATS *rd_stats, int64_t ref_best_rd,
int *is_cost_valid, int fast,
int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_node) {
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
......@@ -3920,6 +3929,9 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
#endif
}
int tx_split_prune_flag = 0;
if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
tx_split_prune_flag = ((x->tx_search_prune[0] >> TX_TYPES) & 1);
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && tx_split_prune_flag == 0) {
const TX_SIZE sub_txs = sub_tx_size_map[1][tx_size];
const int bsw = tx_size_wide_unit[sub_txs];
......@@ -3947,7 +3959,7 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
select_tx_block(
cpi, x, offsetr, offsetc, plane, block, sub_txs, depth + 1,
plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats,
ref_best_rd - tmp_rd, &this_cost_valid, fast, 0,
ref_best_rd - tmp_rd, &this_cost_valid, fast,
(rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
#if CONFIG_DIST_8X8
......@@ -4099,7 +4111,6 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int64_t ref_best_rd, int fast,
int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_tree) {
MACROBLOCKD *const xd = &x->e_mbd;
int is_cost_valid = 1;
......@@ -4138,7 +4149,7 @@ static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
select_tx_block(cpi, x, idy, idx, 0, block, max_tx_size, init_depth,
plane_bsize, ctxa, ctxl, tx_above, tx_left,
&pn_rd_stats, ref_best_rd - this_rd, &is_cost_valid,
fast, tx_split_prune_flag, rd_info_tree);
fast, rd_info_tree);
if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
av1_invalid_rd_stats(rd_stats);
return;
......@@ -4172,7 +4183,6 @@ static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int mi_row, int mi_col,
int64_t ref_best_rd, TX_TYPE tx_type,
int tx_split_prune_flag,
TX_SIZE_RD_INFO_NODE *rd_info_tree) {
const int fast = cpi->sf.tx_size_search_method > USE_FULL_RD;
const AV1_COMMON *const cm = &cpi->common;
......@@ -4199,7 +4209,7 @@ static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x,
mbmi->tx_type = tx_type;
select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast,
tx_split_prune_flag, rd_info_tree);
rd_info_tree);
if (rd_stats->rate == INT_MAX) return INT64_MAX;
mbmi->min_tx_size = mbmi->inter_tx_size[0][0];
......@@ -4800,7 +4810,6 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
#endif
const int n4 = bsize_to_num_blk(bsize);
int idx, idy;
int prune = 0;
// Get the tx_size 1 level down
const TX_SIZE min_tx_size =
sub_tx_size_map[1][max_txsize_rect_lookup[1][bsize]];
......@@ -4851,16 +4860,12 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
prune = prune_tx(cpi, bsize, x, xd, tx_set_type,
cpi->sf.tx_type_search.use_tx_size_pruning);
prune_tx(cpi, bsize, x, xd, tx_set_type,
cpi->sf.tx_type_search.use_tx_size_pruning);
}
int found = 0;
int tx_split_prune_flag = 0;
if (is_inter && cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
tx_split_prune_flag = ((prune >> TX_TYPES) & 1);
for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
RD_STATS this_rd_stats;
av1_init_rd_stats(&this_rd_stats);
......@@ -4868,7 +4873,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
#if !CONFIG_TXK_SEL
if (is_inter) {
if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
if (!do_tx_type_search(tx_type, prune,
if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type],
cpi->sf.tx_type_search.prune_mode))
continue;
}
......@@ -4882,7 +4887,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
if (tx_type != DCT_DCT) continue;
rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
ref_best_rd, tx_type, tx_split_prune_flag,
ref_best_rd, tx_type,
found_rd_info ? matched_rd_info : NULL);
#if !CONFIG_TXK_SEL
// If the current tx_type is not included in the tx_set for the smallest
......@@ -4928,6 +4933,9 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
#endif
}
// Reset the pruning flags.
av1_zero(x->tx_search_prune);
// We should always find at least one candidate unless ref_best_rd is less
// than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
// might have failed to find something better)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment