Commit 79a37242, authored by Alexander Bokov, committed by Hui Su

Introducing a model for pruning the TX size search

Use a neural-network-based binary classifier to predict the first split
decision on the highest level of the TX size RD search tree. Depending
on how confident we are in the prediction, we either keep the full,
unmodified TX size search or use the largest possible TX size and stop
any further search.

Average speed-up: 3-4%
Quality loss (lowres): 0.062%
Quality loss (midres): 0.018%

Change-Id: I64c0317db74cbeddfbdf772147c43e99e275891f
parent 7ac01f8f
......@@ -1191,6 +1191,34 @@ static void get_energy_distribution_finer(const int16_t *diff, int stride,
for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
}
// Instead of the 1D projections of the block energy distribution computed by
// get_energy_distribution_finer(), this function computes a full
// two-dimensional energy distribution of the input block.
//
// The bw x bh block starting at diff (row pitch: stride) is divided into 4x4
// cells. For every cell the sum of squared samples is computed and then
// normalized by the total energy of the block. Only the first
// (bw / 4) * (bh / 4) - 1 cells are written to edist, in raster order; the
// entry for the last cell is left untouched.
//
// If the block carries no energy at all (every sample is zero), the energy is
// reported as uniformly spread over the cells, so that the outputs never
// become NaN because of a division by zero.
static void get_2D_energy_distribution(const int16_t *diff, int stride, int bw,
                                       int bh, float *edist) {
  unsigned int esq[256] = { 0 };
  const int esq_w = bw >> 2;
  const int esq_h = bh >> 2;
  const int esq_sz = esq_w * esq_h;
  // The per-cell scratch array is fixed-size.
  assert(esq_sz <= 256);
  uint64_t total = 0;
  for (int i = 0; i < bh; i += 4) {
    for (int j = 0; j < bw; j += 4) {
      unsigned int cur_sum_energy = 0;
      for (int k = 0; k < 4; k++) {
        const int16_t *cur_diff = diff + (i + k) * stride + j;
        for (int l = 0; l < 4; l++) {
          // A single int16_t square always fits in an int; the running sum
          // is kept unsigned so that it cannot overflow a signed int.
          const int d = cur_diff[l];
          cur_sum_energy += (unsigned int)(d * d);
        }
      }
      esq[(i >> 2) * esq_w + (j >> 2)] = cur_sum_energy;
      total += cur_sum_energy;
    }
  }

  if (total > 0) {
    const float e_recip = 1.0f / (float)total;
    for (int i = 0; i < esq_sz - 1; i++) edist[i] = esq[i] * e_recip;
  } else {
    // Flat (all-zero) block: fall back to a uniform distribution.
    for (int i = 0; i < esq_sz - 1; i++) edist[i] = 1.0f / (float)esq_sz;
  }
}
// Similar to get_horver_correlation, but also takes into account first
// row/column, when computing horizontal/vertical correlation.
static void get_horver_correlation_full(const int16_t *diff, int stride, int w,
......@@ -1302,13 +1330,65 @@ static void score_2D_transform_pow8(float *scores_2D, float shift) {
for (i = 0; i < 16; i++) scores_2D[i] /= sum;
}
static int prune_tx_types_2D(BLOCK_SIZE bsize, const MACROBLOCK *x,
int tx_set_type, int pruning_aggressiveness) {
// Forward pass through a neural network with two fully-connected layers,
// analogous to compute_1D_scores(). The difference is that this network has a
// single output neuron, as required by the classifier used for TX size
// pruning.
//
// features:         input vector of num_features values.
// fc1:              first-layer weights, num_hidden_units rows of
//                   num_features values each.
// b1:               first-layer biases, one per hidden unit.
// fc2:              output-layer weights, one per hidden unit.
// b2:               output-layer bias.
// num_hidden_units: hidden layer width, at most 64.
//
// Returns the raw output of the network (no thresholding is applied here).
static float compute_tx_split_prune_score(float *features, int num_features,
                                          const float *fc1, const float *b1,
                                          const float *fc2, float b2,
                                          int num_hidden_units) {
  assert(num_hidden_units <= 64);
  float activations[64];

  // Hidden layer: weighted sum of the inputs plus bias, clamped at zero.
  for (int unit = 0; unit < num_hidden_units; ++unit) {
    const float *const unit_weights = fc1 + unit * num_features;
    activations[unit] = 0.0f;
    for (int f = 0; f < num_features; ++f) {
      activations[unit] += unit_weights[f] * features[f];
    }
    activations[unit] = AOMMAX(activations[unit] + b1[unit], 0.0f);
  }

  // Output neuron: weighted sum of the hidden activations plus bias.
  float score = 0.0f;
  for (int unit = 0; unit < num_hidden_units; ++unit) {
    score += fc2[unit] * activations[unit];
  }
  score += b2;
  return score;
}
// Runs the TX split pruning classifier on one block.
//
// The feature vector consists of the 2D energy distribution of the residual
// `diff` (as produced by get_2D_energy_distribution(), with the block width
// used as the row stride) followed by `hcorr` and `vcorr` in its last two
// slots. The model for a given block size is stored as one flat array laid
// out as: first-layer weights, first-layer biases, output weights, output
// bias.
//
// Returns 1 when the classifier score exceeds the threshold for this block
// size and 0 otherwise. Block sizes outside (BLOCK_4X4, BLOCK_16X16] always
// yield 0.
static int prune_tx_split(BLOCK_SIZE bsize, const int16_t *diff, float hcorr,
                          float vcorr) {
  if (!(bsize > BLOCK_4X4 && bsize <= BLOCK_16X16)) return 0;

  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  // One entry per 4x4 cell except the last one, plus the two correlations.
  const int feature_num = (bw / 4) * (bh / 4) + 1;
  assert(feature_num <= 17);

  float features[17];
  get_2D_energy_distribution(diff, bw, bw, bh, features);
  features[feature_num - 2] = hcorr;
  features[feature_num - 1] = vcorr;

  // Locate the individual parameter groups inside the flat weight array.
  const int model_idx = bsize - BLOCK_4X4 - 1;
  const int hidden_units = av1_prune_tx_split_num_hidden_units[model_idx];
  const float *const fc1 = av1_prune_tx_split_learned_weights[model_idx];
  const float *const b1 = fc1 + hidden_units * feature_num;
  const float *const fc2 = b1 + hidden_units;
  const float b2 = fc2[hidden_units];

  const float score = compute_tx_split_prune_score(features, feature_num, fc1,
                                                   b1, fc2, b2, hidden_units);
  return score > av1_prune_tx_split_thresholds[model_idx];
}
static int prune_tx_2D(BLOCK_SIZE bsize, const MACROBLOCK *x, int tx_set_type,
int tx_type_pruning_aggressiveness,
int use_tx_split_prune) {
if (bsize >= BLOCK_32X32) return 0;
const struct macroblock_plane *const p = &x->plane[0];
const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
const float score_thresh =
av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
av1_prune_2D_adaptive_thresholds[bidx]
[tx_type_pruning_aggressiveness - 1];
float hfeatures[16], vfeatures[16];
float hscores[4], vscores[4];
......@@ -1377,11 +1457,23 @@ static int prune_tx_types_2D(BLOCK_SIZE bsize, const MACROBLOCK *x,
prune_bitmask |= (1 << tx_type_table_2D[i]);
}
// Also apply TX size pruning if it's turned on. The value
// of prune_tx_split_flag indicates whether we should do
// full TX size search (flag=0) or use the largest available
// TX size without performing any further search (flag=1).
int prune_tx_split_flag = 0;
if (use_tx_split_prune) {
prune_tx_split_flag =
prune_tx_split(bsize, p->src_diff, hfeatures[hfeatures_num - 1],
vfeatures[vfeatures_num - 1]);
}
prune_bitmask |= (prune_tx_split_flag << TX_TYPES);
return prune_bitmask;
}
static int prune_tx_types(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
const MACROBLOCKD *const xd, int tx_set_type) {
static int prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
const MACROBLOCKD *const xd, int tx_set_type,
int use_tx_split_prune) {
int tx_set = ext_tx_set_index[1][tx_set_type];
assert(tx_set >= 0);
const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
......@@ -1403,17 +1495,17 @@ static int prune_tx_types(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
break;
case PRUNE_2D_ACCURATE:
if (tx_set_type == EXT_TX_SET_ALL16)
return prune_tx_types_2D(bsize, x, tx_set_type, 6);
return prune_tx_2D(bsize, x, tx_set_type, 6, use_tx_split_prune);
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
return prune_tx_types_2D(bsize, x, tx_set_type, 4);
return prune_tx_2D(bsize, x, tx_set_type, 4, use_tx_split_prune);
else
return 0;
break;
case PRUNE_2D_FAST:
if (tx_set_type == EXT_TX_SET_ALL16)
return prune_tx_types_2D(bsize, x, tx_set_type, 10);
return prune_tx_2D(bsize, x, tx_set_type, 10, use_tx_split_prune);
else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
return prune_tx_types_2D(bsize, x, tx_set_type, 7);
return prune_tx_2D(bsize, x, tx_set_type, 7, use_tx_split_prune);
else
return 0;
break;
......@@ -2499,7 +2591,7 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
prune = prune_tx_types(cpi, bs, x, xd, tx_set_type);
prune = prune_tx(cpi, bs, x, xd, tx_set_type, 0);
}
if (get_ext_tx_types(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used) >
1 &&
......@@ -2772,7 +2864,7 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
int prune = 0;
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type) {
prune = prune_tx_types(cpi, bs, x, xd, EXT_TX_SET_ALL16);
prune = prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16, 0);
}
last_rd = INT64_MAX;
......@@ -3952,7 +4044,8 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
ENTROPY_CONTEXT *ta, ENTROPY_CONTEXT *tl,
TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
RD_STATS *rd_stats, int64_t ref_best_rd,
int *is_cost_valid, int fast) {
int *is_cost_valid, int fast,
int tx_split_prune_flag) {
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
struct macroblock_plane *const p = &x->plane[plane];
......@@ -4172,7 +4265,7 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
#endif
}
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH
if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH && tx_split_prune_flag == 0
#if CONFIG_MRC_TX
// If the tx type we are trying is MRC_DCT, we cannot partition the
// transform into anything smaller than TX_32X32
......@@ -4203,7 +4296,7 @@ static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
select_tx_block(cpi, x, offsetr, offsetc, plane, block, sub_txs,
depth + 1, plane_bsize, ta, tl, tx_above, tx_left,
&this_rd_stats, ref_best_rd - tmp_rd, &this_cost_valid,
fast);
fast, 0);
#if CONFIG_DIST_8X8
if (x->using_dist_8x8 && plane == 0 && tx_size == TX_8X8) {
sub8x8_eob[i] = p->eobs[block];
......@@ -4388,7 +4481,8 @@ static int get_search_init_depth(int mi_width, int mi_height,
static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int64_t ref_best_rd, int fast) {
int64_t ref_best_rd, int fast,
int tx_split_prune_flag) {
MACROBLOCKD *const xd = &x->e_mbd;
int is_cost_valid = 1;
int64_t this_rd = 0;
......@@ -4426,7 +4520,7 @@ static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
select_tx_block(cpi, x, idy, idx, 0, block, max_tx_size, init_depth,
plane_bsize, ctxa, ctxl, tx_above, tx_left,
&pn_rd_stats, ref_best_rd - this_rd, &is_cost_valid,
fast);
fast, tx_split_prune_flag);
if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
av1_invalid_rd_stats(rd_stats);
return;
......@@ -4452,7 +4546,8 @@ static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS *rd_stats, BLOCK_SIZE bsize,
int mi_row, int mi_col,
int64_t ref_best_rd, TX_TYPE tx_type) {
int64_t ref_best_rd, TX_TYPE tx_type,
int tx_split_prune_flag) {
const int fast = cpi->sf.tx_size_search_method > USE_FULL_RD;
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
......@@ -4477,7 +4572,8 @@ static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x,
(void)mi_col;
mbmi->tx_type = tx_type;
select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast);
select_inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, fast,
tx_split_prune_flag);
if (rd_stats->rate == INT_MAX) return INT64_MAX;
mbmi->min_tx_size = get_min_tx_size(mbmi->inter_tx_size[0][0]);
......@@ -4950,11 +5046,16 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
!x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
prune = prune_tx_types(cpi, bsize, x, xd, tx_set_type);
prune = prune_tx(cpi, bsize, x, xd, tx_set_type,
cpi->sf.tx_type_search.use_tx_size_pruning);
}
int found = 0;
int tx_split_prune_flag = 0;
if (is_inter && cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE)
tx_split_prune_flag = ((prune >> TX_TYPES) & 1);
for (tx_type = txk_start; tx_type < txk_end; ++tx_type) {
RD_STATS this_rd_stats;
av1_init_rd_stats(&this_rd_stats);
......@@ -4992,7 +5093,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
if (tx_type != DCT_DCT) continue;
rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
ref_best_rd, tx_type);
ref_best_rd, tx_type, tx_split_prune_flag);
// If the current tx_type is not included in the tx_set for the smallest
// tx size found, then all vartx partitions were actually transformed with
// DCT_DCT and we should avoid picking it.
......@@ -5027,7 +5128,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
RD_STATS this_rd_stats;
mbmi->use_lgt = 1;
rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, mi_row, mi_col,
ref_best_rd, 0);
ref_best_rd, 0, 0);
if (rd < best_rd) {
best_rd = rd;
*rd_stats = this_rd_stats;
......
......@@ -397,6 +397,7 @@ void av1_set_speed_features_framesize_independent(AV1_COMP *cpi) {
sf->alt_ref_search_fp = 0;
sf->partition_search_type = SEARCH_PARTITION;
sf->tx_type_search.prune_mode = PRUNE_2D_ACCURATE;
sf->tx_type_search.use_tx_size_pruning = 1;
sf->tx_type_search.use_skip_flag_prediction = 1;
sf->tx_type_search.fast_intra_tx_type_search = 0;
sf->tx_type_search.fast_inter_tx_type_search = 0;
......
......@@ -207,6 +207,12 @@ typedef struct {
// Use a skip flag prediction model to detect blocks with skip = 1 early
// and avoid doing full TX type search for such blocks.
int use_skip_flag_prediction;
// Use a model to predict TX block split decisions on the highest level
// of the TX partition tree and apply adaptive pruning based on that to speed
// up the RD search (currently works only when prune_mode is PRUNE_2D_ACCURATE
// or PRUNE_2D_FAST).
int use_tx_size_pruning;
} TX_TYPE_SEARCH;
typedef enum {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment