Commit c5ccd4ca authored by Sarah Parker's avatar Sarah Parker

Avoid using MRC_DCT when the mask produced is invalid

If the mask is invalid, do not allow the encoder to select MRC_DCT.
Currently the mask is invalid if it is all 1 or all 0, but these
criteria will likely expand in a future patch.

Change-Id: I77230ea8357bfdb2bf1e6338903d44bbf1db22d1
parent 2b5907eb
......@@ -31,6 +31,9 @@ typedef struct txfm_param {
int is_inter;
int stride;
uint8_t *dst;
#if CONFIG_MRC_TX
int *valid_mask;
#endif // CONFIG_MRC_TX
#endif // CONFIG_MRC_TX || CONFIG_LGT
#if CONFIG_LGT
int mode;
......
......@@ -210,12 +210,20 @@ static INLINE void set_flip_cfg(int tx_type, TXFM_2D_FLIP_CFG *cfg) {
}
#if CONFIG_MRC_TX
static INLINE void get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
int mask_stride, int width, int height) {
static INLINE int get_mrc_mask(const uint8_t *pred, int pred_stride, int *mask,
int mask_stride, int width, int height) {
int n_masked_vals = 0;
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j)
for (int j = 0; j < width; ++j) {
mask[i * mask_stride + j] = pred[i * pred_stride + j] > 100 ? 1 : 0;
n_masked_vals += mask[i * mask_stride + j];
}
}
return n_masked_vals;
}
static INLINE int is_valid_mrc_mask(int n_masked_vals, int width, int height) {
return !(n_masked_vals == 0 || n_masked_vals == (width * height));
}
#endif // CONFIG_MRC_TX
......
......@@ -377,6 +377,10 @@ typedef struct MB_MODE_INFO {
#endif // CONFIG_SUPERTX
int8_t seg_id_predicted; // valid only when temporal_update is enabled
#if CONFIG_MRC_TX
int valid_mrc_mask;
#endif // CONFIG_MRC_TX
// Only for INTRA blocks
UV_PREDICTION_MODE uv_mode;
#if CONFIG_PALETTE
......
......@@ -1495,8 +1495,12 @@ static void imrc32x32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
if (eob == 1) {
aom_idct32x32_1_add_c(input, dest, stride);
} else {
tran_low_t mask[32 * 32];
get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
int mask[32 * 32];
int n_masked_vals =
get_mrc_mask(txfm_param->dst, txfm_param->stride, mask, 32, 32, 32);
if (!is_valid_mrc_mask(n_masked_vals, 32, 32))
assert(0 && "Invalid MRC mask");
if (eob <= quarter)
// non-zero coeff only in upper-left 8x8
aom_imrc32x32_34_add_c(input, dest, stride, mask);
......
......@@ -1642,6 +1642,11 @@ void av1_write_tx_type(const AV1_COMMON *const cm, const MACROBLOCKD *xd,
!supertx_enabled &&
#endif // CONFIG_SUPERTX
!segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
#if CONFIG_MRC_TX
if (tx_type == MRC_DCT)
assert(mbmi->valid_mrc_mask && "Invalid MRC mask");
#endif // CONFIG_MRC_TX
const int eset =
get_ext_tx_set(tx_size, bsize, is_inter, cm->reduced_tx_set_used);
// eset == 0 should correspond to a set with only DCT_DCT and there
......
......@@ -1092,9 +1092,14 @@ static void fhalfright32(const tran_low_t *input, tran_low_t *output) {
#if CONFIG_MRC_TX
static void get_masked_residual32(const int16_t **input, int *input_stride,
const uint8_t *pred, int pred_stride,
int16_t *masked_input) {
int16_t *masked_input, int *valid_mask) {
int mrc_mask[32 * 32];
get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
int n_masked_vals = get_mrc_mask(pred, pred_stride, mrc_mask, 32, 32, 32);
// Do not use MRC_DCT if mask is invalid. DCT_DCT will be used instead.
if (!is_valid_mrc_mask(n_masked_vals, 32, 32)) {
*valid_mask = 0;
return;
}
int32_t sum = 0;
int16_t avg;
// Get the masked average of the prediction
......@@ -1103,7 +1108,7 @@ static void get_masked_residual32(const int16_t **input, int *input_stride,
sum += mrc_mask[i * 32 + j] * (*input)[i * (*input_stride) + j];
}
}
avg = ROUND_POWER_OF_TWO_SIGNED(sum, 10);
avg = sum / n_masked_vals;
// Replace all of the unmasked pixels in the prediction with the average
// of the masked pixels
for (int i = 0; i < 32; ++i) {
......@@ -1113,6 +1118,7 @@ static void get_masked_residual32(const int16_t **input, int *input_stride,
}
*input = masked_input;
*input_stride = 32;
*valid_mask = 1;
}
#endif // CONFIG_MRC_TX
......@@ -2464,7 +2470,7 @@ void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
if (tx_type == MRC_DCT) {
int16_t masked_input[32 * 32];
get_masked_residual32(&input, &stride, txfm_param->dst, txfm_param->stride,
masked_input);
masked_input, txfm_param->valid_mask);
}
#endif // CONFIG_MRC_TX
......
......@@ -600,6 +600,9 @@ void av1_xform_quant(const AV1_COMMON *cm, MACROBLOCK *x, int plane, int block,
txfm_param.is_inter = is_inter_block(mbmi);
txfm_param.dst = dst;
txfm_param.stride = dst_stride;
#if CONFIG_MRC_TX
txfm_param.valid_mask = &mbmi->valid_mrc_mask;
#endif // CONFIG_MRC_TX
#endif // CONFIG_MRC_TX || CONFIG_LGT
#if CONFIG_LGT
txfm_param.mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
......
......@@ -2159,6 +2159,13 @@ static void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
}
#endif // DISABLE_TRELLISQ_SEARCH
#if CONFIG_MRC_TX
if (mbmi->tx_type == MRC_DCT && !mbmi->valid_mrc_mask) {
args->exit_early = 1;
return;
}
#endif // CONFIG_MRC_TX
if (!is_inter_block(mbmi)) {
struct macroblock_plane *const p = &x->plane[plane];
av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment