Commit 7c1f6d18 authored by Yue Chen

Refactor transform type-size search function

Decompose choose_tx_size_from_rd into three functions that determine
the transform coding rd at different stages. Besides the original
function, txfm_yrd() calculates the rd for fixed size and type.
choose_tx_size_fix_type() fixes the type and searches for the size.
It can enable other experiments to do restricted tx searches so as to
reduce the impact on speed.
Similar refactoring is done for select_tx_type_yrd() in VAR_TX.

Performance change in baseline is trivial:
0.014/0.001/-0.020 for lowres/midres/hdres.

Change-Id: I2ecbf6066329be088ec1bfb69013b657b14b8afe
parent cbfc15b1
......@@ -1311,6 +1311,179 @@ void vp10_txfm_rd_in_plane_supertx(MACROBLOCK *x,
}
#endif // CONFIG_SUPERTX
/*
 * RD cost of coding the block with one fixed transform type and one fixed
 * transform size.
 *
 * Records tx_type and tx_size in the block's mode info, calls
 * txfm_rd_in_plane() to fill *r (rate), *d (distortion), *s (skip flag) and
 * *sse, and then folds the signalling costs into *r:
 *   - the transform-type cost, when the build configuration, the transform
 *     size and the lossless flag allow a type to be signalled;
 *   - the transform-size cost, when tx_mode is TX_MODE_SELECT and the block
 *     is not a skipped inter block.
 *
 * Returns INT64_MAX if txfm_rd_in_plane() leaves *r == INT_MAX; otherwise the
 * RD cost. For a non-skipped, non-lossless inter block the result is capped
 * by the cost of signalling skip and taking *sse as the distortion.
 */
static int64_t txfm_yrd(VP10_COMP *cpi, MACROBLOCK *x,
                        int *r, int64_t *d, int *s, int64_t *sse,
                        int64_t ref_best_rd,
                        BLOCK_SIZE bs, TX_TYPE tx_type, int tx_size) {
  VP10_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
  const vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
  const int is_inter = is_inter_block(mbmi);
  const int r_tx_size =
      cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)][tx_size];
  /* Size-signalling bits that enter the RD cost; zero unless TX_MODE_SELECT. */
  const int tx_size_bits = tx_select ? r_tx_size : 0;
  int no_skip_cost, skip_cost;
  int lossless;
  int64_t rd, forced_skip_rd;
#if CONFIG_EXT_TX
  int ext_tx_set;
#endif  // CONFIG_EXT_TX

  assert(skip_prob > 0);
  no_skip_cost = vp10_cost_bit(skip_prob, 0);
  skip_cost = vp10_cost_bit(skip_prob, 1);

  mbmi->tx_type = tx_type;
  mbmi->tx_size = tx_size;

  txfm_rd_in_plane(x, cpi, r, d, s, sse, ref_best_rd, 0, bs, tx_size,
                   cpi->sf.use_fast_coef_costing);
  if (*r == INT_MAX) {
    return INT64_MAX;
  }

  lossless = xd->lossless[mbmi->segment_id];

  /* Transform-type signalling cost. */
#if CONFIG_EXT_TX
  ext_tx_set = get_ext_tx_set(tx_size, bs, is_inter);
  if (get_ext_tx_types(tx_size, bs, is_inter) > 1 && !lossless &&
      ext_tx_set > 0) {
    if (is_inter) {
      *r += cpi->inter_tx_type_costs[ext_tx_set]
                                    [mbmi->tx_size][mbmi->tx_type];
    } else if (ALLOW_INTRA_EXT_TX) {
      *r += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
                                    [mbmi->mode][mbmi->tx_type];
    }
  }
#else
  if (tx_size < TX_32X32 && !lossless && !FIXED_TX_TYPE) {
    if (is_inter) {
      *r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
    } else {
      *r += cpi->intra_tx_type_costs[mbmi->tx_size]
                                    [intra_mode_to_tx_type_context[mbmi->mode]]
                                    [mbmi->tx_type];
    }
  }
#endif  // CONFIG_EXT_TX

  if (*s) {
    /* Skipped inter block: only the skip flag is costed, *r is untouched. */
    if (is_inter) {
      return RDCOST(x->rdmult, x->rddiv, skip_cost, *sse);
    }
    /* Skipped intra block: the transform size is still costed. */
    rd = RDCOST(x->rdmult, x->rddiv, skip_cost + tx_size_bits, *sse);
    if (tx_select) {
      *r += r_tx_size;
    }
    return rd;
  }

  /* Non-skipped block: the cost is taken before *r absorbs the size bits. */
  rd = RDCOST(x->rdmult, x->rddiv, *r + no_skip_cost + tx_size_bits, *d);
  if (tx_select) {
    *r += r_tx_size;
  }

  if (is_inter && !lossless) {
    /* Cap by the cost of signalling skip with *sse as the distortion. */
    forced_skip_rd = RDCOST(x->rdmult, x->rddiv, skip_cost, *sse);
    if (!(rd < forced_skip_rd)) {
      rd = forced_skip_rd;
    }
  }
  return rd;
}
/*
 * Searches transform sizes for one fixed transform type.
 *
 * Candidate sizes run from the largest down to 0 when tx_mode is
 * TX_MODE_SELECT; otherwise only the single size
 * min(max size for bs, biggest size allowed by tx_mode) is tried. Sizes for
 * which tx_type is not permitted (fixed-type builds, ext-tx set membership,
 * pruning via do_tx_type_search(), 32x32 with a non-DCT type) are passed over.
 * Each remaining size is evaluated with txfm_yrd().
 *
 * Outputs *rate, *distortion, *skip and *psse describe the best size found;
 * if none is accepted they keep their initial INT_MAX / INT64_MAX / 0 /
 * INT64_MAX values. On return mbmi->tx_type is tx_type and mbmi->tx_size is
 * the best size (the maximum size for bs when nothing was accepted).
 *
 * Returns the best RD cost, or INT64_MAX if no size was accepted.
 */
static int64_t choose_tx_size_fix_type(VP10_COMP *cpi, MACROBLOCK *x,
                                       int *rate,
                                       int64_t *distortion,
                                       int *skip,
                                       int64_t *psse,
                                       int64_t ref_best_rd,
                                       BLOCK_SIZE bs, TX_TYPE tx_type,
                                       int prune) {
  VP10_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
  const TX_SIZE max_tx_size = max_txsize_lookup[bs];
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
  /* Search range: every size when selectable, otherwise a single size. */
  const int first_tx =
      tx_select ? max_tx_size
                : VPXMIN(max_tx_size, tx_mode_to_biggest_tx_size[cm->tx_mode]);
  const int final_tx = tx_select ? 0 : first_tx;
  TX_SIZE winner_tx = max_tx_size;
  int64_t winner_rd = INT64_MAX;
  int64_t prev_rd = INT64_MAX;
  int cand;

  *distortion = INT64_MAX;
  *rate = INT_MAX;
  *skip = 0;
  *psse = INT64_MAX;

  mbmi->tx_type = tx_type;

  for (cand = first_tx; cand >= final_tx; --cand) {
    int this_rate, this_skip;
    int64_t this_dist, this_sse;
    int64_t this_rd;
#if CONFIG_EXT_TX
    int ext_tx_set;
#endif  // CONFIG_EXT_TX

    if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, cand)) {
      continue;
    }

#if CONFIG_EXT_TX
    ext_tx_set = get_ext_tx_set(cand, bs, is_inter);
    if (is_inter) {
      if (!ext_tx_used_inter[ext_tx_set][tx_type]) {
        continue;
      }
      if (cpi->sf.tx_type_search > 0 && !do_tx_type_search(tx_type, prune)) {
        continue;
      }
    } else {
      if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8 &&
          tx_type != intra_mode_to_tx_type_context[mbmi->mode]) {
        continue;
      }
      if (!ext_tx_used_intra[ext_tx_set][tx_type]) {
        continue;
      }
    }
#else  // CONFIG_EXT_TX
    if (cand >= TX_32X32 && tx_type != DCT_DCT) {
      continue;
    }
    if (is_inter && cpi->sf.tx_type_search > 0 &&
        !do_tx_type_search(tx_type, prune)) {
      continue;
    }
#endif  // CONFIG_EXT_TX

    this_rd = txfm_yrd(cpi, x, &this_rate, &this_dist, &this_skip, &this_sse,
                       ref_best_rd, bs, tx_type, cand);

    /* Early termination in transform size search. The tests are kept in
     * this order so the skip flag is only read after a valid evaluation. */
    if (cpi->sf.tx_size_search_breakout) {
      if (this_rd == INT64_MAX) {
        break;
      }
      if (this_skip == 1 && tx_type != DCT_DCT && cand < first_tx) {
        break;
      }
      if (cand < (int) max_tx_size && this_rd > prev_rd) {
        break;
      }
    }
    prev_rd = this_rd;

    if (this_rd < winner_rd) {
      winner_rd = this_rd;
      winner_tx = cand;
      *rate = this_rate;
      *distortion = this_dist;
      *skip = this_skip;
      *psse = this_sse;
    }
  }

  mbmi->tx_size = winner_tx;
  return winner_rd;
}
static void choose_largest_tx_size(VP10_COMP *cpi, MACROBLOCK *x,
int *rate, int64_t *distortion,
int *skip, int64_t *sse,
......@@ -1464,155 +1637,36 @@ static void choose_tx_size_from_rd(VP10_COMP *cpi, MACROBLOCK *x,
int64_t *psse,
int64_t ref_best_rd,
BLOCK_SIZE bs) {
const TX_SIZE max_tx_size = max_txsize_lookup[bs];
VP10_COMMON *const cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
int r, s;
int64_t d, sse;
int64_t rd = INT64_MAX;
int n;
int s0, s1;
int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
TX_SIZE best_tx = max_tx_size;
int start_tx, end_tx;
const int tx_select = cm->tx_mode == TX_MODE_SELECT;
int64_t best_rd = INT64_MAX;
TX_SIZE best_tx = max_txsize_lookup[bs];
const int is_inter = is_inter_block(mbmi);
TX_TYPE tx_type, best_tx_type = DCT_DCT;
int prune = 0;
#if CONFIG_EXT_TX
int ext_tx_set;
#endif // CONFIG_EXT_TX
if (is_inter && cpi->sf.tx_type_search > 0)
prune = prune_tx_types(cpi, bs, x, xd);
assert(skip_prob > 0);
s0 = vp10_cost_bit(skip_prob, 0);
s1 = vp10_cost_bit(skip_prob, 1);
if (tx_select) {
start_tx = max_tx_size;
end_tx = 0;
} else {
const TX_SIZE chosen_tx_size =
VPXMIN(max_tx_size, tx_mode_to_biggest_tx_size[cm->tx_mode]);
start_tx = chosen_tx_size;
end_tx = chosen_tx_size;
}
*distortion = INT64_MAX;
*rate = INT_MAX;
*skip = 0;
*psse = INT64_MAX;
for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
last_rd = INT64_MAX;
for (n = start_tx; n >= end_tx; --n) {
const int r_tx_size =
cpi->tx_size_cost[max_tx_size - TX_8X8][get_tx_size_context(xd)][n];
if (FIXED_TX_TYPE && tx_type != get_default_tx_type(0, xd, 0, n))
continue;
#if CONFIG_EXT_TX
ext_tx_set = get_ext_tx_set(n, bs, is_inter);
if (is_inter) {
if (!ext_tx_used_inter[ext_tx_set][tx_type])
continue;
if (cpi->sf.tx_type_search > 0) {
if (!do_tx_type_search(tx_type, prune))
continue;
}
} else {
if (!ALLOW_INTRA_EXT_TX && bs >= BLOCK_8X8) {
if (tx_type != intra_mode_to_tx_type_context[mbmi->mode])
continue;
}
if (!ext_tx_used_intra[ext_tx_set][tx_type])
continue;
}
mbmi->tx_type = tx_type;
txfm_rd_in_plane(x,
cpi,
&r, &d, &s,
&sse, ref_best_rd, 0, bs, n,
cpi->sf.use_fast_coef_costing);
if (get_ext_tx_types(n, bs, is_inter) > 1 &&
!xd->lossless[xd->mi[0]->mbmi.segment_id] &&
r != INT_MAX) {
if (is_inter) {
if (ext_tx_set > 0)
r += cpi->inter_tx_type_costs[ext_tx_set]
[mbmi->tx_size][mbmi->tx_type];
} else {
if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
r += cpi->intra_tx_type_costs[ext_tx_set][mbmi->tx_size]
[mbmi->mode][mbmi->tx_type];
}
}
#else // CONFIG_EXT_TX
if (n >= TX_32X32 && tx_type != DCT_DCT) {
continue;
}
mbmi->tx_type = tx_type;
txfm_rd_in_plane(x,
cpi,
&r, &d, &s,
&sse, ref_best_rd, 0, bs, n,
cpi->sf.use_fast_coef_costing);
if (n < TX_32X32 &&
!xd->lossless[xd->mi[0]->mbmi.segment_id] &&
r != INT_MAX && !FIXED_TX_TYPE) {
if (is_inter) {
r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
if (cpi->sf.tx_type_search > 0 && !do_tx_type_search(tx_type, prune))
continue;
} else {
r += cpi->intra_tx_type_costs[mbmi->tx_size]
[intra_mode_to_tx_type_context[mbmi->mode]]
[mbmi->tx_type];
}
}
#endif // CONFIG_EXT_TX
if (r == INT_MAX)
continue;
if (s) {
if (is_inter) {
rd = RDCOST(x->rdmult, x->rddiv, s1, sse);
} else {
rd = RDCOST(x->rdmult, x->rddiv, s1 + r_tx_size * tx_select, sse);
}
} else {
rd = RDCOST(x->rdmult, x->rddiv, r + s0 + r_tx_size * tx_select, d);
}
if (tx_select && !(s && is_inter))
r += r_tx_size;
if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !s)
rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, sse));
// Early termination in transform size search.
if (cpi->sf.tx_size_search_breakout &&
(rd == INT64_MAX ||
(s == 1 && tx_type != DCT_DCT && n < start_tx) ||
(n < (int) max_tx_size && rd > last_rd)))
break;
last_rd = rd;
if (rd <
(is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) *
best_rd) {
best_tx = n;
best_rd = rd;
*distortion = d;
*rate = r;
*skip = s;
*psse = sse;
best_tx_type = mbmi->tx_type;
}
rd = choose_tx_size_fix_type(cpi, x, &r, &d, &s, &sse, ref_best_rd, bs,
tx_type, prune);
if (rd < (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) * best_rd) {
best_rd = rd;
*distortion = d;
*rate = r;
*skip = s;
*psse = sse;
best_tx_type = tx_type;
best_tx = mbmi->tx_size;
}
}
......@@ -3102,21 +3156,75 @@ static void inter_block_yrd(const VP10_COMP *cpi, MACROBLOCK *x,
}
}
/*
 * RD cost of the block for one fixed transform type, with the per-block
 * evaluation delegated to inter_block_yrd().
 *
 * Stores tx_type in the block's mode info, calls inter_block_yrd() to fill
 * *rate, *dist, *skippable and *sse, and then adds the transform-type
 * signalling cost to *rate. That cost is looked up with the maximum
 * transform size for bsize.
 *
 * Returns INT64_MAX if inter_block_yrd() leaves *rate == INT_MAX; otherwise
 * the RD cost. For a non-skippable, non-lossless inter block the result is
 * capped by the cost of signalling skip and taking *sse as the distortion.
 */
static int64_t select_tx_size_fix_type(const VP10_COMP *cpi, MACROBLOCK *x,
                                       int *rate, int64_t *dist,
                                       int *skippable,
                                       int64_t *sse, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd, TX_TYPE tx_type) {
  const VP10_COMMON *const cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
  const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
  const int is_inter = is_inter_block(mbmi);
#if CONFIG_EXT_TX
  const int ext_tx_set = get_ext_tx_set(max_tx_size, bsize, is_inter);
#endif  // CONFIG_EXT_TX
  const vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
  const int no_skip_cost = vp10_cost_bit(skip_prob, 0);
  const int skip_cost = vp10_cost_bit(skip_prob, 1);
  int lossless;
  int64_t coded_rd, forced_skip_rd;

  mbmi->tx_type = tx_type;
  inter_block_yrd(cpi, x, rate, dist, skippable, sse, bsize, ref_best_rd);
  if (*rate == INT_MAX) {
    return INT64_MAX;
  }

  lossless = xd->lossless[mbmi->segment_id];

  /* Transform-type signalling cost. */
#if CONFIG_EXT_TX
  if (get_ext_tx_types(max_tx_size, bsize, is_inter) > 1 && !lossless &&
      ext_tx_set > 0) {
    if (is_inter) {
      *rate += cpi->inter_tx_type_costs[ext_tx_set]
                                       [max_tx_size][mbmi->tx_type];
    } else if (ALLOW_INTRA_EXT_TX) {
      *rate += cpi->intra_tx_type_costs[ext_tx_set][max_tx_size]
                                       [mbmi->mode][mbmi->tx_type];
    }
  }
#else  // CONFIG_EXT_TX
  if (max_tx_size < TX_32X32 && !lossless) {
    if (is_inter) {
      *rate += cpi->inter_tx_type_costs[max_tx_size][mbmi->tx_type];
    } else {
      *rate += cpi->intra_tx_type_costs[max_tx_size]
          [intra_mode_to_tx_type_context[mbmi->mode]][mbmi->tx_type];
    }
  }
#endif  // CONFIG_EXT_TX

  /* Skippable block: only the skip flag is costed. */
  if (*skippable) {
    return RDCOST(x->rdmult, x->rddiv, skip_cost, *sse);
  }

  coded_rd = RDCOST(x->rdmult, x->rddiv, *rate + no_skip_cost, *dist);
  if (!is_inter || lossless) {
    return coded_rd;
  }

  /* Cap by the cost of signalling skip with *sse as the distortion. */
  forced_skip_rd = RDCOST(x->rdmult, x->rddiv, skip_cost, *sse);
  return (coded_rd < forced_skip_rd) ? coded_rd : forced_skip_rd;
}
static void select_tx_type_yrd(const VP10_COMP *cpi, MACROBLOCK *x,
int *rate, int64_t *distortion, int *skippable,
int64_t *sse, BLOCK_SIZE bsize,
int64_t ref_best_rd) {
const TX_SIZE max_tx_size = max_txsize_lookup[bsize];
const VP10_COMMON *const cm = &cpi->common;
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
int64_t rd = INT64_MAX;
int64_t best_rd = INT64_MAX;
TX_TYPE tx_type, best_tx_type = DCT_DCT;
const int is_inter = is_inter_block(mbmi);
vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
int s0 = vp10_cost_bit(skip_prob, 0);
int s1 = vp10_cost_bit(skip_prob, 1);
TX_SIZE best_tx_size[MI_BLOCK_SIZE][MI_BLOCK_SIZE];
TX_SIZE best_tx = TX_SIZES;
uint8_t best_blk_skip[256];
......@@ -3156,59 +3264,15 @@ static void select_tx_type_yrd(const VP10_COMP *cpi, MACROBLOCK *x,
if (!ext_tx_used_intra[ext_tx_set][tx_type])
continue;
}
mbmi->tx_type = tx_type;
inter_block_yrd(cpi, x, &this_rate, &this_dist, &this_skip, &this_sse,
bsize, ref_best_rd);
if (get_ext_tx_types(max_tx_size, bsize, is_inter) > 1 &&
!xd->lossless[xd->mi[0]->mbmi.segment_id] &&
this_rate != INT_MAX) {
if (is_inter) {
if (ext_tx_set > 0)
this_rate += cpi->inter_tx_type_costs[ext_tx_set]
[max_tx_size][mbmi->tx_type];
} else {
if (ext_tx_set > 0 && ALLOW_INTRA_EXT_TX)
this_rate += cpi->intra_tx_type_costs[ext_tx_set][max_tx_size]
[mbmi->mode][mbmi->tx_type];
}
}
#else // CONFIG_EXT_TX
if (max_tx_size >= TX_32X32 && tx_type != DCT_DCT)
continue;
mbmi->tx_type = tx_type;
inter_block_yrd(cpi, x, &this_rate, &this_dist, &this_skip, &this_sse,
bsize, ref_best_rd);
if (max_tx_size < TX_32X32 &&
!xd->lossless[xd->mi[0]->mbmi.segment_id] &&
this_rate != INT_MAX) {
if (is_inter) {
this_rate += cpi->inter_tx_type_costs[max_tx_size][mbmi->tx_type];
if (cpi->sf.tx_type_search > 0 && !do_tx_type_search(tx_type, prune))
continue;
} else {
this_rate += cpi->intra_tx_type_costs[max_tx_size]
[intra_mode_to_tx_type_context[mbmi->mode]]
[mbmi->tx_type];
}
}
#endif // CONFIG_EXT_TX
if (this_rate == INT_MAX)
if (max_tx_size >= TX_32X32 && tx_type != DCT_DCT)
continue;
if (this_skip)
rd = RDCOST(x->rdmult, x->rddiv, s1, this_sse);
else
rd = RDCOST(x->rdmult, x->rddiv, this_rate + s0, this_dist);
if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !this_skip)
rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, this_sse));
if (is_inter && cpi->sf.tx_type_search > 0 &&
!do_tx_type_search(tx_type, prune))
continue;
#endif // CONFIG_EXT_TX
rd = select_tx_size_fix_type(cpi, x, &this_rate, &this_dist, &this_skip,
&this_sse, bsize, ref_best_rd, tx_type);
if (rd < (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) * best_rd) {
best_rd = rd;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment