diff --git a/av1/encoder/block.h b/av1/encoder/block.h index 0b334cf65f61286484c4cee5d9f546165a3b054b..4e6b8eec6532882a8d50534dd72efa7a21bcb21e 100644 --- a/av1/encoder/block.h +++ b/av1/encoder/block.h @@ -365,6 +365,8 @@ struct macroblock { int comp_idx_cost[COMP_INDEX_CONTEXTS][2]; int comp_group_idx_cost[COMP_GROUP_IDX_CONTEXTS][2]; #endif // CONFIG_JNT_COMP + // Bit flags for pruning tx type search, tx split, etc. + int tx_search_prune[EXT_TX_SET_TYPES]; }; static INLINE int is_rect_tx_allowed_bsize(BLOCK_SIZE bsize) { diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c index c6b0c167c9c1a14618d81452f0c4c28fe0b8b1db..dc07334b955b637d96227e4e3a54501a9bd64416 100644 --- a/av1/encoder/rdopt.c +++ b/av1/encoder/rdopt.c @@ -1255,20 +1255,15 @@ static int prune_tx_split(BLOCK_SIZE bsize, const int16_t *diff, float hcorr, return (score > av1_prune_tx_split_thresholds[bidx]); } -static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type, - int tx_type_pruning_aggressiveness, - int use_tx_split_prune) { - if (bsize >= BLOCK_32X32) return 0; +static void prune_tx_2D(BLOCK_SIZE bsize, MACROBLOCK *x, + TX_TYPE_PRUNE_MODE prune_mode, int use_tx_split_prune) { + if (bsize >= BLOCK_32X32) return; aom_clear_system_state(); const struct macroblock_plane *const p = &x->plane[0]; - const int bidx = AOMMAX(bsize - BLOCK_4X4, 0); - const float score_thresh = - av1_prune_2D_adaptive_thresholds[bidx] - [tx_type_pruning_aggressiveness - 1]; float hfeatures[16], vfeatures[16]; float hscores[4], vscores[4]; float scores_2D[16]; - int tx_type_table_2D[16] = { + const int tx_type_table_2D[16] = { DCT_DCT, DCT_ADST, DCT_FLIPADST, V_DCT, ADST_DCT, ADST_ADST, ADST_FLIPADST, V_ADST, FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST, @@ -1284,7 +1279,7 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type, get_horver_correlation_full(p->src_diff, bw, bw, bh, &hfeatures[hfeatures_num - 1], &vfeatures[vfeatures_num - 1]); - + const int bidx = AOMMAX(bsize - BLOCK_4X4, 0); const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx]; const float *b1_hor = fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num; @@ -1314,22 +1309,43 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type, score_2D_average /= 16; score_2D_transform_pow8(scores_2D, (20 - score_2D_average)); - // Always keep the TX type with the highest score, prune all others with - // score below score_thresh. - int max_score_i = 0; - float max_score = 0.0f; - for (int i = 0; i < 16; i++) { - if (scores_2D[i] > max_score && - av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) { - max_score = scores_2D[i]; - max_score_i = i; + // TODO(huisu@google.com): support more tx set types. + const int tx_set_types[2] = { EXT_TX_SET_ALL16, EXT_TX_SET_DTT9_IDTX_1DDCT }; + for (int tx_set_idx = 0; tx_set_idx < 2; ++tx_set_idx) { + const int tx_set_type = tx_set_types[tx_set_idx]; + // Always keep the TX type with the highest score, prune all others with + // score below score_thresh. + int max_score_i = 0; + float max_score = 0.0f; + for (int i = 0; i < 16; i++) { + if (scores_2D[i] > max_score && + av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) { + max_score = scores_2D[i]; + max_score_i = i; + } } - } - int prune_bitmask = 0; - for (int i = 0; i < 16; i++) { - if (scores_2D[i] < score_thresh && i != max_score_i) - prune_bitmask |= (1 << tx_type_table_2D[i]); + int pruning_aggressiveness = 0; + if (prune_mode == PRUNE_2D_ACCURATE) { + if (tx_set_type == EXT_TX_SET_ALL16) + pruning_aggressiveness = 6; + else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) + pruning_aggressiveness = 4; + } else if (prune_mode == PRUNE_2D_FAST) { + if (tx_set_type == EXT_TX_SET_ALL16) + pruning_aggressiveness = 10; + else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) + pruning_aggressiveness = 7; + } + const float score_thresh = + av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1]; + + int prune_bitmask = 0; + for (int i = 0; i < 16; i++) { + if (scores_2D[i] < score_thresh && i != max_score_i) + prune_bitmask |= (1 << tx_type_table_2D[i]); + } + x->tx_search_prune[tx_set_type] = prune_bitmask; } // Also apply TX size pruning if it's turned on. The value @@ -1342,51 +1358,42 @@ static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type, prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1], vfeatures[vfeatures_num - 1]); } - prune_bitmask |= (prune_tx_split_flag << TX_TYPES); - return prune_bitmask; + x->tx_search_prune[0] |= (prune_tx_split_flag << TX_TYPES); } -static int prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x, - const MACROBLOCKD *const xd, int tx_set_type, - int use_tx_split_prune) { +static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x, + const MACROBLOCKD *const xd, int tx_set_type, + int use_tx_split_prune) { int tx_set = ext_tx_set_index[1][tx_set_type]; assert(tx_set >= 0); + av1_zero(x->tx_search_prune); const int *tx_set_1D = ext_tx_used_inter_1D[tx_set]; - switch (cpi->sf.tx_type_search.prune_mode) { - case NO_PRUNE: return 0; break; + case NO_PRUNE: return; case PRUNE_ONE: - if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0; - return prune_one_for_sby(cpi, bsize, x, xd); + if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return; + x->tx_search_prune[tx_set_type] = prune_one_for_sby(cpi, bsize, x, xd); break; case PRUNE_TWO: if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) { - if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0; - return prune_two_for_sby(cpi, bsize, x, xd, 0, 1); + if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return; + x->tx_search_prune[tx_set_type] = + prune_two_for_sby(cpi, bsize, x, xd, 0, 1); } - if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) - return prune_two_for_sby(cpi, bsize, x, xd, 1, 0); - return prune_two_for_sby(cpi, bsize, x, xd, 1, 1); + if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) { + x->tx_search_prune[tx_set_type] = + prune_two_for_sby(cpi, bsize, x, xd, 1, 0); + } + x->tx_search_prune[tx_set_type] = + prune_two_for_sby(cpi, bsize, x, xd, 1, 1); break; case PRUNE_2D_ACCURATE: - if (tx_set_type == EXT_TX_SET_ALL16) - return prune_tx_2D(bsize, x, tx_set_type, 6, use_tx_split_prune); - else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) - return prune_tx_2D(bsize, x, tx_set_type, 4, use_tx_split_prune); - else - return 0; - break; case PRUNE_2D_FAST: - if (tx_set_type == EXT_TX_SET_ALL16) - return prune_tx_2D(bsize, x, tx_set_type, 10, use_tx_split_prune); - else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) - return prune_tx_2D(bsize, x, tx_set_type, 7, use_tx_split_prune); - else - return 0; + prune_tx_2D(bsize, x, cpi->sf.tx_type_search.prune_mode, + use_tx_split_prune); break; + default: assert(0); } - assert(0); - return 0; } static int do_tx_type_search(TX_TYPE tx_type, int prune, @@ -2319,7 +2326,7 @@ static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, } static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs, - TX_TYPE tx_type, TX_SIZE tx_size, int prune) { + TX_TYPE tx_type, TX_SIZE tx_size) { const MACROBLOCKD *const xd = &x->e_mbd; const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi; const int is_inter = is_inter_block(mbmi); @@ -2337,7 +2344,8 @@ static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs, if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1; if (is_inter) { if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) { - if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode)) + if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type], + cpi->sf.tx_type_search.prune_mode)) return 1; } } @@ -2371,7 +2379,6 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x, int s0 = x->skip_cost[skip_ctx][0]; int s1 = x->skip_cost[skip_ctx][1]; const int is_inter = is_inter_block(mbmi); - int prune = 0; av1_invalid_rd_stats(rd_stats); mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode, is_inter); @@ -2381,7 +2388,7 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x, if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE && !x->use_default_inter_tx_type) { - prune = prune_tx(cpi, bs, x, xd, tx_set_type, 0); + prune_tx(cpi, bs, x, xd, tx_set_type, 0); } #if CONFIG_FILTER_INTRA if (skip_invalid_tx_size_for_filter_intra(mbmi, AOM_PLANE_Y, rd_stats)) { @@ -2399,7 +2406,7 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x, tx_type != get_default_tx_type(0, xd, mbmi->tx_size)) continue; if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) { - if (!do_tx_type_search(tx_type, prune, + if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type], cpi->sf.tx_type_search.prune_mode)) continue; } @@ -2440,6 +2447,8 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x, mbmi->tx_size, cpi->sf.use_fast_coef_costing); } mbmi->tx_type = best_tx_type; + // Reset the pruning flags. + av1_zero(x->tx_search_prune); } static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x, @@ -2507,10 +2516,9 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi, depth = MAX_TX_DEPTH; } - int prune = 0; if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE && !x->use_default_inter_tx_type) { - prune = prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0); + prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0); } last_rd = INT64_MAX; @@ -2526,7 +2534,7 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi, TX_TYPE tx_type; for (tx_type = tx_start; tx_type < tx_end; ++tx_type) { RD_STATS this_rd_stats; - if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue; + if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue; if (mbmi->ref_mv_idx > 0) x->rd_model = LOW_TXFM_RD; rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n); @@ -2578,6 +2586,8 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi, memcpy(x->blk_skip[0], best_blk_skip, sizeof(best_blk_skip[0]) * n4); mbmi->min_tx_size = mbmi->tx_size; + // Reset the pruning flags. + av1_zero(x->tx_search_prune); } static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x, @@ -3820,7 +3830,6 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, RD_STATS *rd_stats, int64_t ref_best_rd, int *is_cost_valid, int fast, - int tx_split_prune_flag, TX_SIZE_RD_INFO_NODE *rd_info_node) { MACROBLOCKD *const xd = &x->e_mbd; MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi; @@ -3920,6 +3929,9 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, #endif } + int tx_split_prune_flag = 0; + if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) + tx_split_prune_flag = ((x->tx_search_prune[0] >> TX_TYPES) & 1); if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && tx_split_prune_flag == 0) { const TX_SIZE sub_txs = sub_tx_size_map[1][tx_size]; const int bsw = tx_size_wide_unit[sub_txs]; @@ -3947,7 +3959,7 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, select_tx_block( cpi, x, offsetr, offsetc, plane, block, sub_txs, depth + 1, plane_bsize, ta, tl, tx_above, tx_left, &this_rd_stats, - ref_best_rd - tmp_rd, &this_cost_valid, fast, 0, + ref_best_rd - tmp_rd, &this_cost_valid, fast, (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL); #if CONFIG_DIST_8X8 @@ -4099,7 +4111,6 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *rd_stats, BLOCK_SIZE bsize, int64_t ref_best_rd, int fast, - int tx_split_prune_flag, TX_SIZE_RD_INFO_NODE *rd_info_tree) { MACROBLOCKD *const xd = &x->e_mbd; int is_cost_valid = 1; @@ -4138,7 +4149,7 @@ static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, select_tx_block(cpi, x, idy, idx, 0, block, max_tx_size, init_depth, plane_bsize, ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, ref_best_rd - this_rd, &is_cost_valid, - fast, tx_split_prune_flag, rd_info_tree); + fast, rd_info_tree); if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) { av1_invalid_rd_stats(rd_stats); return; @@ -4172,7 +4183,6 @@ static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x, RD_STATS *rd_stats, BLOCK_SIZE bsize, int mi_row, int mi_col, int64_t ref_best_rd, TX_TYPE tx_type, - int tx_split_prune_flag, TX_SIZE_RD_INFO_NODE *rd_info_tree) { const int fast = cpi->sf.tx_size_search_method > USE_FULL_RD; const AV1_COMMON *const cm = &cpi->common; @@ -4199,7 +4209,7 @@ static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x, mbmi->tx_type = tx_type; select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast, - tx_split_prune_flag, rd_info_tree); + rd_info_tree); if (rd_stats->rate == INT_MAX) return INT64_MAX; mbmi->min_tx_size = mbmi->inter_tx_size[0][0]; @@ -4800,7 +4810,6 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, #endif const int n4 = bsize_to_num_blk(bsize); int idx, idy; - int prune = 0; // Get the tx_size 1 level down const TX_SIZE min_tx_size = sub_tx_size_map[1][max_txsize_rect_lookup[1][bsize]]; @@ -4851,16 +4860,12 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE && !x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) { - prune = prune_tx(cpi, bsize, x, xd, tx_set_type, - cpi->sf.tx_type_search.use_tx_size_pruning); + prune_tx(cpi, bsize, x, xd, tx_set_type, + cpi->sf.tx_type_search.use_tx_size_pruning); } int found = 0; - int tx_split_prune_flag = 0; - if (is_inter && cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) - tx_split_prune_flag = ((prune >> TX_TYPES) & 1); - for (tx_type = txk_start; tx_type < txk_end; ++tx_type) { RD_STATS this_rd_stats; av1_init_rd_stats(&this_rd_stats); @@ -4868,7 +4873,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, #if !CONFIG_TXK_SEL if (is_inter) { if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) { - if (!do_tx_type_search(tx_type, prune, + if (!do_tx_type_search(tx_type, x->tx_search_prune[tx_set_type], cpi->sf.tx_type_search.prune_mode)) continue; } @@ -4882,7 +4887,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, if (tx_type != DCT_DCT) continue; rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col, - ref_best_rd, tx_type, tx_split_prune_flag, + ref_best_rd, tx_type, found_rd_info ? matched_rd_info : NULL); #if !CONFIG_TXK_SEL // If the current tx_type is not included in the tx_set for the smallest @@ -4928,6 +4933,9 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x, #endif } + // Reset the pruning flags. + av1_zero(x->tx_search_prune); + // We should always find at least one candidate unless ref_best_rd is less // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type // might have failed to find something better)