diff --git a/vp10/common/idct.c b/vp10/common/idct.c index a941f64c86ed4aabacd9fa3939a9ddff43649200..f621ec61b5bfbeaeacda8df512f97d29f4d51d1f 100644 --- a/vp10/common/idct.c +++ b/vp10/common/idct.c @@ -700,11 +700,78 @@ void highbd_idst16_c(const tran_low_t *input, tran_low_t *output, int bd) { } static void highbd_inv_idtx_add_c(const tran_low_t *input, uint8_t *dest8, - int stride, int bs, int bd) { + int stride, int bs, int tx_type, int bd) { int r, c; const int shift = bs < 32 ? 3 : 2; uint16_t *dest = CONVERT_TO_SHORTPTR(dest8); + tran_low_t temp_in[32], temp_out[32]; + highbd_transform_2d ht = {vpx_highbd_idct4_c, vpx_highbd_idct4_c}; + int out_scale = 1; + int coeff_stride = 0; + + switch (bs) { + case 4: + ht.cols = vpx_highbd_idct4_c; + ht.rows = vpx_highbd_idct4_c; + out_scale = cospi_16_64 >> 3; + coeff_stride = 4; + break; + case 8: + ht.cols = vpx_highbd_idct8_c; + ht.rows = vpx_highbd_idct8_c; + out_scale = (1 << (DCT_CONST_BITS - 4)); + coeff_stride = 8; + break; + case 16: + ht.cols = vpx_highbd_idct16_c; + ht.rows = vpx_highbd_idct16_c; + out_scale = cospi_16_64 >> 4; + coeff_stride = 16; + break; + case 32: + ht.cols = vpx_highbd_idct32_c; + ht.rows = vpx_highbd_idct32_c; + out_scale = (1 << (DCT_CONST_BITS - 4)); + coeff_stride = 32; + break; + default: + assert(0); + } + + // Columns + if (tx_type == V_DCT) { + for (c = 0; c < bs; ++c) { + for (r = 0; r < bs; ++r) + temp_in[r] = input[r * coeff_stride + c]; + ht.cols(temp_in, temp_out, bd); + + for (r = 0; r < bs; ++r) { + tran_high_t temp = (tran_high_t)temp_out[r] * out_scale; + temp >>= DCT_CONST_BITS; + dest[r * stride + c] = highbd_clip_pixel_add(dest[r * stride + c], + (tran_low_t)temp, bd); + } + } + return; + } + + if (tx_type == H_DCT) { + for (r = 0; r < bs; ++r) { + for (c = 0; c < bs; ++c) + temp_in[c] = input[r * coeff_stride + c]; + ht.rows(temp_in, temp_out, bd); + + for (c = 0; c < bs; ++c) { + tran_high_t temp = (tran_high_t)temp_out[c] * out_scale; + temp >>= DCT_CONST_BITS; + dest[r * stride + c] = highbd_clip_pixel_add(dest[r * stride + c], + (tran_low_t)temp, bd); + } + } + return; + } + for (r = 0; r < bs; ++r) { for (c = 0; c < bs; ++c) dest[c] = highbd_clip_pixel_add(dest[c], input[c] >> shift, bd); @@ -1593,8 +1660,10 @@ void vp10_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest, // Use C version since DST only exists in C code vp10_highbd_iht4x4_16_add_c(input, dest, stride, tx_type, bd); break; + case H_DCT: + case V_DCT: case IDTX: - highbd_inv_idtx_add_c(input, dest, stride, 4, bd); + highbd_inv_idtx_add_c(input, dest, stride, 4, tx_type, bd); break; #endif // CONFIG_EXT_TX default: @@ -1633,8 +1702,10 @@ void vp10_highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest, // Use C version since DST only exists in C code vp10_highbd_iht8x8_64_add_c(input, dest, stride, tx_type, bd); break; + case H_DCT: + case V_DCT: case IDTX: - highbd_inv_idtx_add_c(input, dest, stride, 8, bd); + highbd_inv_idtx_add_c(input, dest, stride, 8, tx_type, bd); break; #endif // CONFIG_EXT_TX default: @@ -1673,8 +1744,10 @@ void vp10_highbd_inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest, // Use C version since DST only exists in C code vp10_highbd_iht16x16_256_add_c(input, dest, stride, tx_type, bd); break; + case H_DCT: + case V_DCT: case IDTX: - highbd_inv_idtx_add_c(input, dest, stride, 16, bd); + highbd_inv_idtx_add_c(input, dest, stride, 16, tx_type, bd); break; #endif // CONFIG_EXT_TX default: @@ -1708,8 +1781,10 @@ void vp10_highbd_inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest, case DST_FLIPADST: vp10_highbd_iht32x32_1024_add_c(input, dest, stride, tx_type, bd); break; + case H_DCT: + case V_DCT: case IDTX: - highbd_inv_idtx_add_c(input, dest, stride, 32, bd); + highbd_inv_idtx_add_c(input, dest, stride, 32, tx_type, bd); break; #endif // CONFIG_EXT_TX default: diff --git a/vp10/encoder/hybrid_fwd_txfm.c b/vp10/encoder/hybrid_fwd_txfm.c index 029240f7166d44b4146f616199c3696f6e781c67..ee0ca8c55c4aba0d861b2a66d220b124e02d74d8 100644 --- a/vp10/encoder/hybrid_fwd_txfm.c +++ b/vp10/encoder/hybrid_fwd_txfm.c @@ -232,6 +232,8 @@ void vp10_highbd_fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff, // Use C version since DST exists only in C vp10_highbd_fht4x4_c(src_diff, coeff, diff_stride, tx_type); break; + case H_DCT: + case V_DCT: case IDTX: vp10_fwd_idtx_c(src_diff, coeff, diff_stride, 4, tx_type); break; @@ -274,6 +276,8 @@ static void highbd_fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff, // Use C version since DST exists only in C vp10_highbd_fht8x8_c(src_diff, coeff, diff_stride, tx_type); break; + case H_DCT: + case V_DCT: case IDTX: vp10_fwd_idtx_c(src_diff, coeff, diff_stride, 8, tx_type); break; @@ -316,6 +320,8 @@ static void highbd_fwd_txfm_16x16(const int16_t *src_diff, tran_low_t *coeff, // Use C version since DST exists only in C vp10_highbd_fht16x16_c(src_diff, coeff, diff_stride, tx_type); break; + case H_DCT: + case V_DCT: case IDTX: vp10_fwd_idtx_c(src_diff, coeff, diff_stride, 16, tx_type); break; @@ -354,6 +360,8 @@ static void highbd_fwd_txfm_32x32(int rd_transform, const int16_t *src_diff, case FLIPADST_DST: vp10_highbd_fht32x32_c(src_diff, coeff, diff_stride, tx_type); break; + case H_DCT: + case V_DCT: case IDTX: vp10_fwd_idtx_c(src_diff, coeff, diff_stride, 32, tx_type); break;