Commit 0c7eb10d authored by Alexander Bokov, committed by Hui Su

Improving the model for pruning the TX type search

Introduces two new TX type pruning modes that provide a better
speed-quality trade-off than the existing ones. A shallow neural
network with one hidden layer, trained separately for each block
size, is used as the prediction model. The new modes differ in the
thresholds applied to the output of the neural net, so that they
prune different numbers of TX types on average.

Owing to its relatively low quality loss, PRUNE_2D_ACCURATE is used
by default, regardless of the speed setting. Starting with speed
setting 3, we switch to PRUNE_2D_FAST mode to get a better speed-up.
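
In essence, both modes share the decision rule sketched below; the
actual implementation is prune_tx_types_2D in the diff. Only the
threshold differs between the two modes (names taken from the diff):

    // Sketch, not part of this change: prune every TX type whose
    // predicted probability falls below a mode- and block-size-
    // dependent threshold, but never prune the highest-scoring one.
    int prune_mask = 0;
    for (int i = 0; i < 16; i++)
      if (scores_2D[i] < score_thresh && i != max_score_i)
        prune_mask |= 1 << tx_type_table_2D[i];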

Evaluation results:
----------------------------------------------------------
Prune mode | Avg. speed-up | Quality loss | Quality loss
           |(high bitrates)|   (lowres)   |   (midres)
----------------------------------------------------------
PRUNE_ONE  |     18.7%     |    0.396%    |    0.308%
----------------------------------------------------------
PRUNE_TWO  |     27.2%     |    0.439%    |    0.389%
----------------------------------------------------------
PRUNE_2D_  |     18.8%     |    0.032%    |    0.063%
ACCURATE   |               |              |
----------------------------------------------------------
PRUNE_2D_  |     33.3%     |    0.504%    |     ---
FAST       |               |              |

Change-Id: Ibd59f52eef493a499e529d824edad267daa65f9d
parent 0b34a79f
@@ -173,4 +173,8 @@ ifeq ($(CONFIG_GLOBAL_MOTION),yes)
 AV1_CX_SRCS-$(HAVE_SSE4_1) += encoder/x86/corner_match_sse4.c
 endif
+
+ifeq ($(CONFIG_EXT_TX),yes)
+AV1_CX_SRCS-yes += encoder/tx_prune_model_weights.h
+endif
 
 AV1_CX_SRCS-yes := $(filter-out $(AV1_CX_SRCS_REMOVE-yes),$(AV1_CX_SRCS-yes))
@@ -62,6 +62,9 @@
 #include "av1/encoder/rd.h"
 #include "av1/encoder/rdopt.h"
 #include "av1/encoder/tokenize.h"
+#if CONFIG_EXT_TX
+#include "av1/encoder/tx_prune_model_weights.h"
+#endif  // CONFIG_EXT_TX
 #if CONFIG_PVQ
 #include "av1/encoder/pvq_encoder.h"
 #include "av1/common/pvq.h"
@@ -1146,50 +1149,315 @@ static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
   { 1, 0, 0, 1 },
 #endif  // CONFIG_MRC_TX
 };
+
+static void get_energy_distribution_finer(const int16_t *diff, int stride,
+                                          int bw, int bh, float *hordist,
+                                          float *verdist) {
+  // First compute downscaled block energy values (esq); downscale factors
+  // are defined by w_shift and h_shift.
+  unsigned int esq[256];
+  const int w_shift = bw <= 8 ? 0 : 1;
+  const int h_shift = bh <= 8 ? 0 : 1;
+  const int esq_w = bw <= 8 ? bw : bw / 2;
+  const int esq_h = bh <= 8 ? bh : bh / 2;
+  const int esq_sz = esq_w * esq_h;
+  int i, j;
+  memset(esq, 0, esq_sz * sizeof(esq[0]));
+
+  for (i = 0; i < bh; i++) {
+    unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
+    const int16_t *cur_diff_row = diff + i * stride;
+    for (j = 0; j < bw; j++) {
+      cur_esq_row[j >> w_shift] += cur_diff_row[j] * cur_diff_row[j];
+    }
+  }
+
+  uint64_t total = 0;
+  for (i = 0; i < esq_sz; i++) total += esq[i];
+
+  // Output hordist and verdist arrays are normalized 1D projections of esq
+  if (total == 0) {
+    float hor_val = 1.0f / esq_w;
+    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
+    float ver_val = 1.0f / esq_h;
+    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
+    return;
+  }
+
+  const float e_recip = 1.0f / (float)total;
+  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
+  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
+  const unsigned int *cur_esq_row;
+  for (i = 0; i < esq_h - 1; i++) {
+    cur_esq_row = esq + i * esq_w;
+    for (j = 0; j < esq_w - 1; j++) {
+      hordist[j] += (float)cur_esq_row[j];
+      verdist[i] += (float)cur_esq_row[j];
+    }
+    verdist[i] += (float)cur_esq_row[j];
+  }
+  cur_esq_row = esq + i * esq_w;
+  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
+
+  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
+  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
+}
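
A note on the output layout (an annotation, not part of the change): each
projection carries only the first esq_w - 1 or esq_h - 1 bins; since the
distribution is normalized to sum to 1, the last bin is implied. A toy
check under that assumption:

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    // A 4x4 block keeps w_shift = h_shift = 0, so with a uniform residual
    // each of the first 3 bins holds 1/4 of the energy; the implied 4th
    // bin holds the remaining 1/4.
    static void energy_distribution_toy_check(void) {
      int16_t diff[16];
      float hordist[3], verdist[3];
      for (int i = 0; i < 16; i++) diff[i] = 2;
      get_energy_distribution_finer(diff, 4 /* stride */, 4, 4, hordist,
                                    verdist);
      for (int i = 0; i < 3; i++) {
        assert(fabsf(hordist[i] - 0.25f) < 1e-6f);
        assert(fabsf(verdist[i] - 0.25f) < 1e-6f);
      }
    }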
+// Similar to get_horver_correlation, but also takes into account first
+// row/column, when computing horizontal/vertical correlation.
+static void get_horver_correlation_full(const int16_t *diff, int stride, int w,
+                                        int h, float *hcorr, float *vcorr) {
+  const float num_hor = h * (w - 1);
+  const float num_ver = (h - 1) * w;
+  int i, j;
+
+  // The following notation is used:
+  // x - current pixel
+  // y - left neighbor pixel
+  // z - top neighbor pixel
+  int64_t xy_sum = 0, xz_sum = 0;
+  int64_t xhor_sum = 0, xver_sum = 0, y_sum = 0, z_sum = 0;
+  int64_t x2hor_sum = 0, x2ver_sum = 0, y2_sum = 0, z2_sum = 0;
+
+  int16_t x, y, z;
+  for (j = 1; j < w; ++j) {
+    x = diff[j];
+    y = diff[j - 1];
+    xy_sum += x * y;
+    xhor_sum += x;
+    y_sum += y;
+    x2hor_sum += x * x;
+    y2_sum += y * y;
+  }
+  for (i = 1; i < h; ++i) {
+    x = diff[i * stride];
+    z = diff[(i - 1) * stride];
+    xz_sum += x * z;
+    xver_sum += x;
+    z_sum += z;
+    x2ver_sum += x * x;
+    z2_sum += z * z;
+    for (j = 1; j < w; ++j) {
+      x = diff[i * stride + j];
+      y = diff[i * stride + j - 1];
+      z = diff[(i - 1) * stride + j];
+      xy_sum += x * y;
+      xz_sum += x * z;
+      xhor_sum += x;
+      xver_sum += x;
+      y_sum += y;
+      z_sum += z;
+      x2hor_sum += x * x;
+      x2ver_sum += x * x;
+      y2_sum += y * y;
+      z2_sum += z * z;
+    }
+  }
+  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
+  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
+  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
+  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;
+  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;
+  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;
+  *hcorr = *vcorr = 1;
+  if (xhor_var_n > 0 && y_var_n > 0) {
+    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
+    *hcorr = *hcorr < 0 ? 0 : *hcorr;
+  }
+  if (xver_var_n > 0 && z_var_n > 0) {
+    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
+    *vcorr = *vcorr < 0 ? 0 : *vcorr;
+  }
+}
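
In other words (an annotation, not part of the change), each output is the
sample Pearson correlation of a pixel with its left neighbor (y) or top
neighbor (z), computed from the running sums above and clamped to [0, 1];
for the horizontal case, with n = num_hor:

    hcorr = (xy_sum - xhor_sum * y_sum / n) /
            sqrt((x2hor_sum - xhor_sum^2 / n) * (y2_sum - y_sum^2 / n))

and analogously for vcorr with n = num_ver. When either variance term is
not positive, the output stays at its default of 1.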
+// Performs a forward pass through a neural network with 2 fully-connected
+// layers, assuming ReLU as activation function. Number of output neurons
+// is always equal to 4.
+// fc1, fc2 - weight matrices of the respective layers.
+// b1, b2 - bias vectors of the respective layers.
+static void compute_1D_scores(float *features, int num_features,
+                              const float *fc1, const float *b1,
+                              const float *fc2, const float *b2,
+                              int num_hidden_units, float *dst_scores) {
+  assert(num_hidden_units <= 32);
+  float hidden_layer[32];
+  for (int i = 0; i < num_hidden_units; i++) {
+    const float *cur_coef = fc1 + i * num_features;
+    hidden_layer[i] = 0.0f;
+    for (int j = 0; j < num_features; j++)
+      hidden_layer[i] += cur_coef[j] * features[j];
+    hidden_layer[i] = AOMMAX(hidden_layer[i] + b1[i], 0.0f);
+  }
+  for (int i = 0; i < 4; i++) {
+    const float *cur_coef = fc2 + i * num_hidden_units;
+    dst_scores[i] = 0.0f;
+    for (int j = 0; j < num_hidden_units; j++)
+      dst_scores[i] += cur_coef[j] * hidden_layer[j];
+    dst_scores[i] += b2[i];
+  }
+}
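
The caller below carves fc1, b1, fc2, and b2 out of one flat per-block-size
array from tx_prune_model_weights.h. A sketch of the packed layout that the
pointer arithmetic in prune_tx_types_2D assumes (the name `weights` here is
illustrative, not from the source):

    // [ fc1 : num_hidden_units * num_features ]  hidden-layer weights
    // [ b1  : num_hidden_units ]                 hidden-layer biases
    // [ fc2 : 4 * num_hidden_units ]             output-layer weights
    // [ b2  : 4 ]                                output-layer biases
    const float *fc1 = weights;
    const float *b1 = fc1 + num_hidden_units * num_features;
    const float *fc2 = b1 + num_hidden_units;
    const float *b2 = fc2 + num_hidden_units * 4;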
+// Transforms raw scores into a probability distribution across 16 TX types
+static void score_2D_transform_pow8(float *scores_2D, float shift) {
+  float sum = 0.0f;
+  int i;
+  for (i = 0; i < 16; i++) {
+    float v, v2, v4;
+    v = AOMMAX(scores_2D[i] + shift, 0.0f);
+    v2 = v * v;
+    v4 = v2 * v2;
+    scores_2D[i] = v4 * v4;
+    sum += scores_2D[i];
+  }
+  for (i = 0; i < 16; i++) scores_2D[i] /= sum;
+}
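
Raising the shifted scores to the eighth power acts like a sharpened
softmax. For example (worked arithmetic, not from the source), shifted
scores of 1.0 and 2.0 become 1 and 256, i.e. about 0.4% and 99.6% of the
mass after normalization, so the resulting distribution strongly favors
the top-scoring TX types.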
+
+static int prune_tx_types_2D(BLOCK_SIZE bsize, const MACROBLOCK *x,
+                             int tx_set_type, int pruning_aggressiveness) {
+  if (bsize >= BLOCK_32X32) return 0;
+  const struct macroblock_plane *const p = &x->plane[0];
+  const int bidx = AOMMAX(bsize - BLOCK_4X4, 0);
+  const float score_thresh =
+      av1_prune_2D_adaptive_thresholds[bidx][pruning_aggressiveness - 1];
+
+  float hfeatures[16], vfeatures[16];
+  float hscores[4], vscores[4];
+  float scores_2D[16];
+  int tx_type_table_2D[16] = {
+    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
+    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
+    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
+    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
+  };
+  const int bw = block_size_wide[bsize], bh = block_size_high[bsize];
+  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
+  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
+  assert(hfeatures_num <= 16);
+  assert(vfeatures_num <= 16);
+
+  get_energy_distribution_finer(p->src_diff, bw, bw, bh, hfeatures, vfeatures);
+  get_horver_correlation_full(p->src_diff, bw, bw, bh,
+                              &hfeatures[hfeatures_num - 1],
+                              &vfeatures[vfeatures_num - 1]);
+
+  const float *fc1_hor = av1_prune_2D_learned_weights_hor[bidx];
+  const float *b1_hor =
+      fc1_hor + av1_prune_2D_num_hidden_units_hor[bidx] * hfeatures_num;
+  const float *fc2_hor = b1_hor + av1_prune_2D_num_hidden_units_hor[bidx];
+  const float *b2_hor = fc2_hor + av1_prune_2D_num_hidden_units_hor[bidx] * 4;
+  compute_1D_scores(hfeatures, hfeatures_num, fc1_hor, b1_hor, fc2_hor, b2_hor,
+                    av1_prune_2D_num_hidden_units_hor[bidx], hscores);
+
+  const float *fc1_ver = av1_prune_2D_learned_weights_ver[bidx];
+  const float *b1_ver =
+      fc1_ver + av1_prune_2D_num_hidden_units_ver[bidx] * vfeatures_num;
+  const float *fc2_ver = b1_ver + av1_prune_2D_num_hidden_units_ver[bidx];
+  const float *b2_ver = fc2_ver + av1_prune_2D_num_hidden_units_ver[bidx] * 4;
+  compute_1D_scores(vfeatures, vfeatures_num, fc1_ver, b1_ver, fc2_ver, b2_ver,
+                    av1_prune_2D_num_hidden_units_ver[bidx], vscores);
+
+  float score_2D_average = 0.0f;
+  for (int i = 0; i < 4; i++) {
+    float *cur_scores_2D = scores_2D + i * 4;
+    cur_scores_2D[0] = vscores[i] * hscores[0];
+    cur_scores_2D[1] = vscores[i] * hscores[1];
+    cur_scores_2D[2] = vscores[i] * hscores[2];
+    cur_scores_2D[3] = vscores[i] * hscores[3];
+    score_2D_average += cur_scores_2D[0] + cur_scores_2D[1] +
+                        cur_scores_2D[2] + cur_scores_2D[3];
+  }
+  score_2D_average /= 16;
+  score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
+
+  // Always keep the TX type with the highest score, prune all others with
+  // score below score_thresh.
+  int max_score_i = 0;
+  float max_score = 0.0f;
+  for (int i = 0; i < 16; i++) {
+    if (scores_2D[i] > max_score &&
+        av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
+      max_score = scores_2D[i];
+      max_score_i = i;
+    }
+  }
+
+  int prune_bitmask = 0;
+  for (int i = 0; i < 16; i++) {
+    if (scores_2D[i] < score_thresh && i != max_score_i)
+      prune_bitmask |= (1 << tx_type_table_2D[i]);
+  }
+  return prune_bitmask;
+}
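
Structurally (an annotation, not part of the change), the 16 entries of
scores_2D form the outer product of the two 4-entry 1D score vectors:
scores_2D[4 * i + j] = vscores[i] * hscores[j], with both vectors ordered
{DCT, ADST, FLIPADST, identity} to match tx_type_table_2D above. For
example, the score for ADST_DCT (ADST vertically, DCT horizontally) is
vscores[1] * hscores[0].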
 #endif  // CONFIG_EXT_TX
 
 static int prune_tx_types(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
-                          const MACROBLOCKD *const xd, int tx_set) {
+                          const MACROBLOCKD *const xd, int tx_set_type) {
 #if CONFIG_EXT_TX
-  const int *tx_set_1D = tx_set >= 0 ? ext_tx_used_inter_1D[tx_set] : NULL;
+  int tx_set = ext_tx_set_index[1][tx_set_type];
+  assert(tx_set >= 0);
+  const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
 #else
   const int tx_set_1D[TX_TYPES_1D] = { 0 };
+  (void)tx_set_type;
 #endif  // CONFIG_EXT_TX
 
   switch (cpi->sf.tx_type_search.prune_mode) {
     case NO_PRUNE: return 0; break;
     case PRUNE_ONE:
-      if ((tx_set >= 0) && !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D]))
-        return 0;
+      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return 0;
       return prune_one_for_sby(cpi, bsize, x, xd);
       break;
 #if CONFIG_EXT_TX
     case PRUNE_TWO:
-      if ((tx_set >= 0) && !(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
+      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
         if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return 0;
         return prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
       }
-      if ((tx_set >= 0) && !(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
+      if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D]))
        return prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
       return prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
       break;
+    case PRUNE_2D_ACCURATE:
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 6);
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 4);
+      else
+        return 0;
+      break;
+    case PRUNE_2D_FAST:
+      if (tx_set_type == EXT_TX_SET_ALL16)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 10);
+      else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
+        return prune_tx_types_2D(bsize, x, tx_set_type, 7);
+      else
+        return 0;
+      break;
 #endif  // CONFIG_EXT_TX
   }
   assert(0);
   return 0;
 }
 
-static int do_tx_type_search(TX_TYPE tx_type, int prune) {
+static int do_tx_type_search(TX_TYPE tx_type, int prune,
+                             TX_TYPE_PRUNE_MODE mode) {
   // TODO(sarahparker) implement for non ext tx
 #if CONFIG_EXT_TX
-  return !(((prune >> vtx_tab[tx_type]) & 1) |
-           ((prune >> (htx_tab[tx_type] + 8)) & 1));
+  if (mode >= PRUNE_2D_ACCURATE) {
+    return !((prune >> tx_type) & 1);
+  } else {
+    return !(((prune >> vtx_tab[tx_type]) & 1) |
+             ((prune >> (htx_tab[tx_type] + 8)) & 1));
+  }
 #else
   // temporary to avoid compiler warnings
   (void)vtx_tab;
   (void)htx_tab;
   (void)tx_type;
   (void)prune;
+  (void)mode;
   return 1;
 #endif  // CONFIG_EXT_TX
 }
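
A summary (not part of the change) of the two prune-mask encodings this
function now dispatches on, inferred from the bit tests above:

    // 1D modes (PRUNE_ONE, PRUNE_TWO): bit v set means vertical 1D type v
    // is pruned; bit (8 + h) set means horizontal 1D type h is pruned --
    // hence the vtx_tab / htx_tab lookups.
    // 2D modes (PRUNE_2D_ACCURATE, PRUNE_2D_FAST): bit t set means 2D TX
    // type t is pruned, as produced by prune_tx_types_2D.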
@@ -2290,16 +2558,11 @@ static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
 }
 
 static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
-                            TX_TYPE tx_type, TX_SIZE tx_size) {
+                            TX_TYPE tx_type, TX_SIZE tx_size, int prune) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   const MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   const TX_SIZE max_tx_size = max_txsize_lookup[bs];
   const int is_inter = is_inter_block(mbmi);
-  int prune = 0;
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
-    // passing -1 in for tx_type indicates that all 1D
-    // transforms should be considered for pruning
-    prune = prune_tx_types(cpi, bs, x, xd, -1);
 
 #if CONFIG_MRC_TX
   // MRC_DCT only implemented for TX_32X32 so only include this tx in
@@ -2329,7 +2592,8 @@ static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
   if (!av1_ext_tx_used[tx_set_type][tx_type]) return 1;
   if (is_inter) {
     if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-      if (!do_tx_type_search(tx_type, prune)) return 1;
+      if (!do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
+        return 1;
     }
   } else {
     if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
@@ -2339,7 +2603,7 @@ static int skip_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bs,
 #else   // CONFIG_EXT_TX
   if (tx_size >= TX_32X32 && tx_type != DCT_DCT) return 1;
   if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-      !do_tx_type_search(tx_type, prune))
+      !do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
     return 1;
 #endif  // CONFIG_EXT_TX
   return 0;
@@ -2389,18 +2653,18 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
   mbmi->min_tx_size = get_min_tx_size(mbmi->tx_size);
 #endif  // CONFIG_VAR_TX
 #if CONFIG_EXT_TX
-  int ext_tx_set =
-      get_ext_tx_set(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
+  const TxSetType tx_set_type =
+      get_ext_tx_set_type(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type) {
 #if CONFIG_EXT_TX
-    prune = prune_tx_types(cpi, bs, x, xd, ext_tx_set);
+    prune = prune_tx_types(cpi, bs, x, xd, tx_set_type);
 #else
     prune = prune_tx_types(cpi, bs, x, xd, 0);
 #endif  // CONFIG_EXT_TX
+  }
 #if CONFIG_EXT_TX
   if (get_ext_tx_types(mbmi->tx_size, bs, is_inter, cm->reduced_tx_set_used) >
           1 &&
@@ -2420,7 +2684,9 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
           tx_type != get_default_tx_type(0, xd, 0, mbmi->tx_size))
         continue;
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-        if (!do_tx_type_search(tx_type, prune)) continue;
+        if (!do_tx_type_search(tx_type, prune,
+                               cpi->sf.tx_type_search.prune_mode))
+          continue;
       }
     } else {
       if (x->use_default_intra_tx_type &&
@@ -2512,7 +2778,8 @@ static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
           av1_tx_type_cost(cm, x, xd, bs, plane, mbmi->tx_size, tx_type);
     if (is_inter) {
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-          !do_tx_type_search(tx_type, prune))
+          !do_tx_type_search(tx_type, prune,
+                             cpi->sf.tx_type_search.prune_mode))
         continue;
     }
     if (this_rd_stats.skip)
@@ -2733,6 +3000,16 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
     end_tx = chosen_tx_size;
   }
 
+  int prune = 0;
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type) {
+#if CONFIG_EXT_TX
+    prune = prune_tx_types(cpi, bs, x, xd, EXT_TX_SET_ALL16);
+#else
+    prune = prune_tx_types(cpi, bs, x, xd, 0);
+#endif  // CONFIG_EXT_TX
+  }
+
   last_rd = INT64_MAX;
   for (n = start_tx; n >= end_tx; --n) {
 #if CONFIG_EXT_TX && CONFIG_RECT_TX
@@ -2748,7 +3025,7 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
     TX_TYPE tx_type;
     for (tx_type = tx_start; tx_type < tx_end; ++tx_type) {
       RD_STATS this_rd_stats;
-      if (skip_txfm_search(cpi, x, bs, tx_type, n)) continue;
+      if (skip_txfm_search(cpi, x, bs, tx_type, n, prune)) continue;
       rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, tx_type, n);
 #if CONFIG_PVQ
       od_encode_rollback(&x->daala_enc, &buf);
@@ -2785,8 +3062,8 @@ static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
   }
 
 #if CONFIG_LGT_FROM_PRED
   mbmi->use_lgt = 1;
-  if (is_lgt_allowed(mbmi->mode, n) && !skip_txfm_search(cpi, x, bs, 0, n) &&
-      !breakout) {
+  if (is_lgt_allowed(mbmi->mode, n) &&
+      !skip_txfm_search(cpi, x, bs, 0, n, prune) && !breakout) {
     RD_STATS this_rd_stats;
     rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, 0, n);
     if (rd < best_rd) {
@@ -5310,8 +5587,6 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
 #if CONFIG_EXT_TX
   const TxSetType tx_set_type = get_ext_tx_set_type(
       max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
-  const int ext_tx_set =
-      get_ext_tx_set(max_tx_size, bsize, is_inter, cm->reduced_tx_set_used);
 #endif  // CONFIG_EXT_TX
 
   av1_invalid_rd_stats(rd_stats);
@@ -5353,12 +5628,14 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
     }
   }
 
-  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE)
+  if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
+      !x->use_default_inter_tx_type && !xd->lossless[mbmi->segment_id]) {
 #if CONFIG_EXT_TX
-    prune = prune_tx_types(cpi, bsize, x, xd, ext_tx_set);
+    prune = prune_tx_types(cpi, bsize, x, xd, tx_set_type);
 #else
     prune = prune_tx_types(cpi, bsize, x, xd, 0);
 #endif  // CONFIG_EXT_TX
+  }
 
   int found = 0;
@@ -5377,7 +5654,9 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
     if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
     if (is_inter) {
       if (cpi->sf.tx_type_search.prune_mode > NO_PRUNE) {
-        if (!do_tx_type_search(tx_type, prune)) continue;
+        if (!do_tx_type_search(tx_type, prune,
+                               cpi->sf.tx_type_search.prune_mode))
+          continue;
       }
     } else {
       if (!ALLOW_INTRA_EXT_TX && bsize >= BLOCK_8X8) {
@@ -5386,7 +5665,7 @@ static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
     }
 #else   // CONFIG_EXT_TX
     if (is_inter && cpi->sf.tx_type_search.prune_mode > NO_PRUNE &&
-        !do_tx_type_search(tx_type, prune))
+        !do_tx_type_search(tx_type, prune, cpi->sf.tx_type_search.prune_mode))
       continue;
 #endif  // CONFIG_EXT_TX
     if (is_inter && x->use_default_inter_tx_type &&
......
@@ -192,7 +192,7 @@ static void set_good_speed_features_framesize_independent(AV1_COMP *cpi,
     sf->tx_size_search_breakout = 1;
     sf->partition_search_breakout_rate_thr = 80;
-    sf->tx_type_search.prune_mode = PRUNE_ONE;
 
     // Use transform domain distortion.
     // Note var-tx expt always uses pixel domain distortion.
     sf->use_transform_domain_distortion = 1;
@@ -215,7 +215,7 @@ static void set_good_speed_features_framesize_independent(AV1_COMP *cpi,
     sf->use_upsampled_references = 0;
     sf->adaptive_rd_thresh = 2;
 #if CONFIG_EXT_TX
-    sf->tx_type_search.prune_mode = PRUNE_TWO;
+    sf->tx_type_search.prune_mode = PRUNE_2D_FAST;
 #endif
 #if CONFIG_GLOBAL_MOTION
     sf->gm_search_type = GM_DISABLE_SEARCH;
@@ -392,7 +392,11 @@ void av1_set_speed_features_framesize_independent(AV1_COMP *cpi) {
   sf->cb_partition_search = 0;
   sf->alt_ref_search_fp = 0;
   sf->partition_search_type = SEARCH_PARTITION;
+#if CONFIG_EXT_TX
+  sf->tx_type_search.prune_mode = PRUNE_2D_ACCURATE;
+#else
   sf->tx_type_search.prune_mode = NO_PRUNE;
+#endif  // CONFIG_EXT_TX
   sf->tx_type_search.use_skip_flag_prediction = 1;
   sf->tx_type_search.fast_intra_tx_type_search = 0;
   sf->tx_type_search.fast_inter_tx_type_search = 0;
......
@@ -193,6 +193,11 @@ typedef enum {
 #if CONFIG_EXT_TX
   // eliminates two tx types in each direction
   PRUNE_TWO = 2,
+  // adaptively prunes the least promising TX types out of all 16
+  // (tuned to provide negligible quality loss)
+  PRUNE_2D_ACCURATE = 3,
+  // similar, but applies much more aggressive pruning to get better speed-up
+  PRUNE_2D_FAST = 4,
 #endif
 } TX_TYPE_PRUNE_MODE;
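
Note that the numeric ordering above is load-bearing: do_tx_type_search
(earlier in this diff) switches its bitmask interpretation with
`mode >= PRUNE_2D_ACCURATE`, so any future 1D pruning mode would need a
value below PRUNE_2D_ACCURATE.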
......
The diff for the new file encoder/tx_prune_model_weights.h (the trained model weights) is collapsed.