Commit abd94510 authored by Monty Montgomery's avatar Monty Montgomery Committed by Christopher Montgomery

Add Daala TX to 4x8 and 8x4 transforms

Rectangular 4x8 and 8x4 will now use Daala TX when CONFIG_DAALA_TX4 and
CONFIG_DAALA_TX8 are both enabled.

Change-Id: I56659c3e98e4bbd5bd3591404f9ff72120b33d6f
parent 1eac584f
......@@ -839,6 +839,26 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
assert(tx_type == DCT_DCT);
#endif
static const transform_2d IHT_4x8[] = {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
{ daala_idct8, daala_idct4 }, // DCT_DCT = 0
{ daala_idst8, daala_idct4 }, // ADST_DCT = 1
{ daala_idct8, daala_idst4 }, // DCT_ADST = 2
{ daala_idst8, daala_idst4 }, // ADST_ADST = 3
#if CONFIG_EXT_TX
{ daala_idst8, daala_idct4 }, // FLIPADST_DCT
{ daala_idct8, daala_idst4 }, // DCT_FLIPADST
{ daala_idst8, daala_idst4 }, // FLIPADST_FLIPADST
{ daala_idst8, daala_idst4 }, // ADST_FLIPADST
{ daala_idst8, daala_idst4 }, // FLIPADST_ADST
{ daala_idtx8, daala_idtx4 }, // IDTX
{ daala_idct8, daala_idtx4 }, // V_DCT
{ daala_idtx8, daala_idct4 }, // H_DCT
{ daala_idst8, daala_idtx4 }, // V_ADST
{ daala_idtx8, daala_idst4 }, // H_ADST
{ daala_idst8, daala_idtx4 }, // V_FLIPADST
{ daala_idtx8, daala_idst4 }, // H_FLIPADST
#endif
#else
{ aom_idct8_c, aom_idct4_c }, // DCT_DCT
{ aom_iadst8_c, aom_idct4_c }, // ADST_DCT
{ aom_idct8_c, aom_iadst4_c }, // DCT_ADST
......@@ -856,6 +876,7 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
{ iidtx8_c, aom_iadst4_c }, // H_ADST
{ aom_iadst8_c, iidtx4_c }, // V_FLIPADST
{ iidtx8_c, aom_iadst4_c }, // H_FLIPADST
#endif
#endif
};
......@@ -873,20 +894,50 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// Multi-way scaling matrix (bits):
// LGT/AV1 row,col input+0, rowTX+.5, mid+.5, colTX+1, out-5 == -3
// LGT row, Daala col input+0, rowTX+.5, mid+.5, colTX+0, out-4 == -3
// Daala row, LGT col input+1, rowTX+0, mid+0, colTX+1, out-5 == -3
// Daala row,col input+1, rowTX+0, mid+0, colTX+0, out-4 == -3
// inverse transform row vectors and transpose
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
if (use_lgt_row) {
// Scaling cases 1 and 2 above
// No input scaling
// Row transform (LGT; scales up .5 bits)
ilgt4(input, outtmp, lgtmtx_row[0]);
else
// Transpose and mid scaling up by .5 bit
for (j = 0; j < n; ++j)
tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
} else {
#endif
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
// Daala row transform; Scaling cases 3 and 4 above
tran_low_t temp_in[4];
// Input scaling up by 1 bit
for (j = 0; j < n; j++) temp_in[j] = input[j] * 2;
// Row transform; Daala does not scale
IHT_4x8[tx_type].rows(temp_in, outtmp);
// Transpose; no mid scaling
for (j = 0; j < n; ++j) tmp[j][i] = outtmp[j];
#else
// AV1 row transform; Scaling case 1 only
// Row transform (AV1 scales up .5 bits)
IHT_4x8[tx_type].rows(input, outtmp);
// Transpose and mid scaling up by .5 bit
for (j = 0; j < n; ++j)
tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
#endif
#if CONFIG_LGT
}
#endif
input += n;
}
// inverse transform column vectors
// AV1/LGT column TX scales up by 1 bit, Daala does not scale
for (i = 0; i < n; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
......@@ -905,7 +956,19 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (j = 0; j < n; ++j) {
int d = i * stride + j;
int s = j * outstride + i;
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
#if CONFIG_LGT
if (use_lgt_col)
// Output Scaling cases 1, 3
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
else
#endif
// Output scaling cases 2, 4
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 4));
#else
// Output scaling case 1 only
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
#endif
}
}
}
......@@ -920,6 +983,26 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
assert(tx_type == DCT_DCT);
#endif
static const transform_2d IHT_8x4[] = {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
{ daala_idct4, daala_idct8 }, // DCT_DCT = 0
{ daala_idst4, daala_idct8 }, // ADST_DCT = 1
{ daala_idct4, daala_idst8 }, // DCT_ADST = 2
{ daala_idst4, daala_idst8 }, // ADST_ADST = 3
#if CONFIG_EXT_TX
{ daala_idst4, daala_idct8 }, // FLIPADST_DCT
{ daala_idct4, daala_idst8 }, // DCT_FLIPADST
{ daala_idst4, daala_idst8 }, // FLIPADST_FLIPADST
{ daala_idst4, daala_idst8 }, // ADST_FLIPADST
{ daala_idst4, daala_idst8 }, // FLIPADST_ADST
{ daala_idtx4, daala_idtx8 }, // IDTX
{ daala_idct4, daala_idtx8 }, // V_DCT
{ daala_idtx4, daala_idct8 }, // H_DCT
{ daala_idst4, daala_idtx8 }, // V_ADST
{ daala_idtx4, daala_idst8 }, // H_ADST
{ daala_idst4, daala_idtx8 }, // V_FLIPADST
{ daala_idtx4, daala_idst8 }, // H_FLIPADST
#endif
#else
{ aom_idct4_c, aom_idct8_c }, // DCT_DCT
{ aom_iadst4_c, aom_idct8_c }, // ADST_DCT
{ aom_idct4_c, aom_iadst8_c }, // DCT_ADST
......@@ -937,6 +1020,7 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
{ iidtx4_c, aom_iadst8_c }, // H_ADST
{ aom_iadst4_c, iidtx8_c }, // V_FLIPADST
{ iidtx4_c, aom_iadst8_c }, // H_FLIPADST
#endif
#endif
};
......@@ -955,20 +1039,50 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// Multi-way scaling matrix (bits):
// LGT/AV1 row,col input+0, rowTX+1, mid+.5, colTX+.5, out-5 == -3
// LGT row, Daala col input+0, rowTX+1, mid+.5, colTX+.5, out-4 == -3
// Daala row, LGT col input+1, rowTX+0, mid+0, colTX+1, out-5 == -3
// Daala row,col input+1, rowTX+0, mid+0, colTX+0, out-4 == -3
// inverse transform row vectors and transpose
for (i = 0; i < n; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
if (use_lgt_row) {
// Scaling cases 1 and 2 above
// No input scaling
// Row transform (LGT; scales up 1 bit)
ilgt8(input, outtmp, lgtmtx_row[0]);
else
// Transpose and mid scaling up by .5 bit
for (j = 0; j < n2; ++j)
tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
} else {
#endif
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
// Daala row transform; Scaling cases 3 and 4 above
tran_low_t temp_in[8];
// Input scaling up by 1 bit
for (j = 0; j < n2; j++) temp_in[j] = input[j] * 2;
// Row transform; Daala does not scale
IHT_8x4[tx_type].rows(temp_in, outtmp);
// Transpose; no mid scaling
for (j = 0; j < n2; ++j) tmp[j][i] = outtmp[j];
#else
// AV1 row transform; Scaling case 1 only
// Row transform (AV1 scales up 1 bit)
IHT_8x4[tx_type].rows(input, outtmp);
// Transpose and mid scaling up by .5 bit
for (j = 0; j < n2; ++j)
tmp[j][i] = (tran_low_t)dct_const_round_shift(outtmp[j] * Sqrt2);
#endif
#if CONFIG_LGT
}
#endif
input += n2;
}
// inverse transform column vectors
// AV1 and LGT scale up by .5 bits; Daala does not scale
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
......@@ -987,7 +1101,19 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (j = 0; j < n2; ++j) {
int d = i * stride + j;
int s = j * outstride + i;
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
#if CONFIG_LGT
if (use_lgt_col)
// Output scaling cases 1, 3
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
else
#endif
// Output scaling cases 2, 4
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 4));
#else
// Output scaling case 1
dest[d] = clip_pixel_add(dest[d], ROUND_POWER_OF_TWO(outp[s], 5));
#endif
}
}
}
......@@ -2297,7 +2423,7 @@ static void inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest, int stride,
static void inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
#if CONFIG_LGT
#if CONFIG_LGT || (CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8)
av1_iht4x8_32_add_c(input, dest, stride, txfm_param);
#else
av1_iht4x8_32_add(input, dest, stride, txfm_param);
......@@ -2306,7 +2432,7 @@ static void inv_txfm_add_4x8(const tran_low_t *input, uint8_t *dest, int stride,
static void inv_txfm_add_8x4(const tran_low_t *input, uint8_t *dest, int stride,
const TxfmParam *txfm_param) {
#if CONFIG_LGT
#if CONFIG_LGT || (CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8)
av1_iht8x4_32_add_c(input, dest, stride, txfm_param);
#else
av1_iht8x4_32_add(input, dest, stride, txfm_param);
......
......@@ -1502,6 +1502,26 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
assert(tx_type == DCT_DCT);
#endif
static const transform_2d FHT[] = {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
{ daala_fdct8, daala_fdct4 }, // DCT_DCT
{ daala_fdst8, daala_fdct4 }, // ADST_DCT
{ daala_fdct8, daala_fdst4 }, // DCT_ADST
{ daala_fdst8, daala_fdst4 }, // ADST_ADST
#if CONFIG_EXT_TX
{ daala_fdst8, daala_fdct4 }, // FLIPADST_DCT
{ daala_fdct8, daala_fdst4 }, // DCT_FLIPADST
{ daala_fdst8, daala_fdst4 }, // FLIPADST_FLIPADST
{ daala_fdst8, daala_fdst4 }, // ADST_FLIPADST
{ daala_fdst8, daala_fdst4 }, // FLIPADST_ADST
{ daala_idtx8, daala_idtx4 }, // IDTX
{ daala_fdct8, daala_idtx4 }, // V_DCT
{ daala_idtx8, daala_fdct4 }, // H_DCT
{ daala_fdst8, daala_idtx4 }, // V_ADST
{ daala_idtx8, daala_fdst4 }, // H_ADST
{ daala_fdst8, daala_idtx4 }, // V_FLIPADST
{ daala_idtx8, daala_fdst4 }, // H_FLIPADST
#endif
#else
{ fdct8, fdct4 }, // DCT_DCT
{ fadst8, fdct4 }, // ADST_DCT
{ fdct8, fadst4 }, // DCT_ADST
......@@ -1519,6 +1539,7 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
{ fidtx8, fadst4 }, // H_ADST
{ fadst8, fidtx4 }, // V_FLIPADST
{ fidtx8, fadst4 }, // H_FLIPADST
#endif
#endif
};
const transform_2d ht = FHT[tx_type];
......@@ -1539,29 +1560,55 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// Multi-way scaling matrix (bits):
// LGT/AV1 row,col input+2.5, rowTX+.5, mid+0, colTX+1, out-1 == 3
// LGT row, Daala col input+3.5, rowTX+.5, mid+0, colTX+0, out-1 == 3
// Daala row, LGT col input+3, rowTX+0, mid+0, colTX+1, out-1 == 3
// Daala row,col input+4, rowTX+0, mid+0, colTX+0, out-1 == 3
// Rows
for (i = 0; i < n2; ++i) {
for (j = 0; j < n; ++j)
// Input scaling
for (j = 0; j < n; ++j) {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
#if CONFIG_LGT
// Input scaling when LGT might be active (1-4 above)
      temp_in[j] = use_lgt_row
                       ? (tran_low_t)fdct_round_shift(input[i * stride + j] *
                                                      Sqrt2 *
                                                      (use_lgt_col ? 4 : 8))
                       : input[i * stride + j] * (use_lgt_col ? 8 : 16);
#else
// Input scaling when LGT is not possible, Daala only (4 above)
temp_in[j] = input[i * stride + j] * 16;
#endif
#else
// Input scaling when Daala is not possible, LGT/AV1 only (1 above)
temp_in[j] =
(tran_low_t)fdct_round_shift(input[i * stride + j] * 4 * Sqrt2);
#endif
}
// Row transform (AV1/LGT scale up .5 bit, Daala does not scale)
#if CONFIG_LGT
if (use_lgt_row)
flgt4(temp_in, temp_out, lgtmtx_row[0]);
else
#endif
ht.rows(temp_in, temp_out);
// No mid scaling
for (j = 0; j < n; ++j) out[j * n2 + i] = temp_out[j];
}
// Columns
for (i = 0; i < n; ++i) {
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
// Column transform (AV1/LGT scale up 1 bit, Daala does not scale)
#if CONFIG_LGT
if (use_lgt_col)
flgt8(temp_in, temp_out, lgtmtx_col[0]);
else
#endif
ht.cols(temp_in, temp_out);
// Output scaling is always a downshift of 1
for (j = 0; j < n2; ++j)
output[i + j * n] = (temp_out[j] + (temp_out[j] < 0)) >> 1;
}
......@@ -1578,6 +1625,26 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
assert(tx_type == DCT_DCT);
#endif
static const transform_2d FHT[] = {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
{ daala_fdct4, daala_fdct8 }, // DCT_DCT
{ daala_fdst4, daala_fdct8 }, // ADST_DCT
{ daala_fdct4, daala_fdst8 }, // DCT_ADST
{ daala_fdst4, daala_fdst8 }, // ADST_ADST
#if CONFIG_EXT_TX
{ daala_fdst4, daala_fdct8 }, // FLIPADST_DCT
{ daala_fdct4, daala_fdst8 }, // DCT_FLIPADST
{ daala_fdst4, daala_fdst8 }, // FLIPADST_FLIPADST
{ daala_fdst4, daala_fdst8 }, // ADST_FLIPADST
{ daala_fdst4, daala_fdst8 }, // FLIPADST_ADST
{ daala_idtx4, daala_idtx8 }, // IDTX
{ daala_fdct4, daala_idtx8 }, // V_DCT
{ daala_idtx4, daala_fdct8 }, // H_DCT
{ daala_fdst4, daala_idtx8 }, // V_ADST
{ daala_idtx4, daala_fdst8 }, // H_ADST
{ daala_fdst4, daala_idtx8 }, // V_FLIPADST
{ daala_idtx4, daala_fdst8 }, // H_FLIPADST
#endif
#else
{ fdct4, fdct8 }, // DCT_DCT
{ fadst4, fdct8 }, // ADST_DCT
{ fdct4, fadst8 }, // DCT_ADST
......@@ -1595,6 +1662,7 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
{ fidtx4, fadst8 }, // H_ADST
{ fadst4, fidtx8 }, // V_FLIPADST
{ fidtx4, fadst8 }, // H_FLIPADST
#endif
#endif
};
const transform_2d ht = FHT[tx_type];
......@@ -1615,29 +1683,54 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// Multi-way scaling matrix (bits):
// LGT/AV1 row,col input+2.5, rowTX+1, mid+0, colTX+.5, out-1 == 3
// LGT row, Daala col input+3, rowTX+1, mid+0, colTX+0, out-1 == 3
// Daala row, LGT col input+3.5, rowTX+0, mid+0, colTX+.5, out-1 == 3
// Daala row,col input+4, rowTX+0, mid+0, colTX+0, out-1 == 3
// Columns
for (i = 0; i < n2; ++i) {
for (j = 0; j < n; ++j)
for (j = 0; j < n; ++j) {
#if CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8
#if CONFIG_LGT
// Input scaling when LGT might be active (1-4 above)
      temp_in[j] = use_lgt_col
                       ? (tran_low_t)fdct_round_shift(input[j * stride + i] *
                                                      Sqrt2 *
                                                      (use_lgt_row ? 4 : 8))
                       : input[j * stride + i] * (use_lgt_row ? 8 : 16);
#else
// Input scaling when LGT is not possible, Daala only (4 above)
temp_in[j] = input[j * stride + i] * 16;
#endif
#else
// Input scaling when Daala is not possible, AV1/LGT only (1 above)
temp_in[j] =
(tran_low_t)fdct_round_shift(input[j * stride + i] * 4 * Sqrt2);
#endif
}
// Column transform (AV1/LGT scale up .5 bit, Daala does not scale)
#if CONFIG_LGT
if (use_lgt_col)
flgt4(temp_in, temp_out, lgtmtx_col[0]);
else
#endif
ht.cols(temp_in, temp_out);
// No scaling between transforms
for (j = 0; j < n; ++j) out[j * n2 + i] = temp_out[j];
}
// Rows
for (i = 0; i < n; ++i) {
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
// Row transform (AV1/LGT scale up 1 bit, Daala does not scale)
#if CONFIG_LGT
if (use_lgt_row)
flgt8(temp_in, temp_out, lgtmtx_row[0]);
else
#endif
ht.rows(temp_in, temp_out);
// Output scaling is always a downshift of 1
for (j = 0; j < n2; ++j)
output[j + i * n2] = (temp_out[j] + (temp_out[j] < 0)) >> 1;
}
......
......@@ -34,7 +34,7 @@ static void fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
static void fwd_txfm_4x8(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TxfmParam *txfm_param) {
#if CONFIG_LGT
#if CONFIG_LGT || (CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8)
av1_fht4x8_c(src_diff, coeff, diff_stride, txfm_param);
#else
av1_fht4x8(src_diff, coeff, diff_stride, txfm_param);
......@@ -43,7 +43,7 @@ static void fwd_txfm_4x8(const int16_t *src_diff, tran_low_t *coeff,
static void fwd_txfm_8x4(const int16_t *src_diff, tran_low_t *coeff,
int diff_stride, TxfmParam *txfm_param) {
#if CONFIG_LGT
#if CONFIG_LGT || (CONFIG_DAALA_TX4 && CONFIG_DAALA_TX8)
av1_fht8x4_c(src_diff, coeff, diff_stride, txfm_param);
#else
av1_fht8x4(src_diff, coeff, diff_stride, txfm_param);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment