Commit 918fe698 authored by Lester Lu, committed by Debargha Mukherjee

Refactor lgt

Change get_lgt in order to integrate lgt with the later lgt-from-pred
experiment. There are two main changes.

The main purpose for this change is to unify get_fwd_lgt and
get_inv_lgt functions into a get_lgt function so the lgt basis
functions can always be selected through the same function in
both forward and inverse transform paths. The structure of those
functions will also be consistent with the get_lgt_from_pred
functions that will be added in the lgt-from-pred experiment.

These changes have no impact on the bitstream.

Change-Id: Ifd3dfc1a9e1a250495830ddbf42c201e80aa913e
parent 8f661605
......@@ -31,13 +31,13 @@ typedef struct txfm_param {
int is_inter;
int stride;
uint8_t *dst;
#if CONFIG_LGT
int mode;
#endif
#if CONFIG_MRC_TX
int *valid_mask;
#endif // CONFIG_MRC_TX
#endif // CONFIG_MRC_TX || CONFIG_LGT
#if CONFIG_LGT
int mode;
#endif
// for inverse transforms only
#if CONFIG_ADAPT_SCAN
const int16_t *eob_threshold;
......@@ -97,9 +97,10 @@ static INLINE tran_high_t fdct_round_shift(tran_high_t input) {
}
#if CONFIG_LGT
// The Line Graph Transforms (LGTs) matrices are written as follows.
// Each 2D array is 16384 times an LGT matrix, which is the matrix of
// eigenvectors of the graph Laplacian matrices for the line graph.
/* The Line Graph Transforms (LGTs) matrices are defined as follows.
* Each 2D array is sqrt(2)*16384 times an LGT matrix, which is the
* matrix of eigenvectors of the graph Laplacian matrix of the associated
* line graph. */
// LGT4 name: lgt4_140
// Self loops: 1.400, 0.000, 0.000, 0.000
......
......@@ -1235,19 +1235,6 @@ static INLINE int av1_raster_order_to_block_index(TX_SIZE tx_size,
return (tx_size == TX_4X4) ? raster_order : (raster_order > 0) ? 2 : 0;
}
#if CONFIG_LGT
// Returns the prediction mode to use for LGT selection: the block's own
// mode for inter blocks, otherwise the intra mode of the given plane
// (per-sub-block luma mode for Y, the uv mode for chroma).
static INLINE PREDICTION_MODE get_prediction_mode(const MODE_INFO *mi,
                                                  int plane, TX_SIZE tx_size,
                                                  int block_idx) {
  const MB_MODE_INFO *const mbmi = &mi->mbmi;
  // Inter blocks carry a single prediction mode directly.
  if (is_inter_block(mbmi)) return mbmi->mode;
  // Intra luma modes are stored per sub-block in raster order.
  const int raster_idx = av1_block_index_to_raster_order(tx_size, block_idx);
  if (plane == PLANE_TYPE_Y) return get_y_mode(mi, raster_idx);
  return get_uv_mode(mbmi->uv_mode);
}
#endif // CONFIG_LGT
static INLINE TX_TYPE get_default_tx_type(PLANE_TYPE plane_type,
const MACROBLOCKD *xd, int block_idx,
TX_SIZE tx_size) {
......
......@@ -265,12 +265,8 @@ static void highbd_inv_idtx_add_c(const tran_low_t *input, uint8_t *dest8,
#if CONFIG_LGT
void ilgt4(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
if (!(input[0] | input[1] | input[2] | input[3])) {
output[0] = output[1] = output[2] = output[3] = 0;
return;
}
// evaluate s[j] = sum of all lgtmtx[i][j]*input[i] over i=1,...,4
if (!lgtmtx) assert(0);
// evaluate s[j] = sum of all lgtmtx[i * 4 + j]*input[i] over i=0,...,3
tran_high_t s[4] = { 0 };
for (int i = 0; i < 4; ++i)
for (int j = 0; j < 4; ++j) s[j] += lgtmtx[i * 4 + j] * input[i];
......@@ -280,7 +276,8 @@ void ilgt4(const tran_low_t *input, tran_low_t *output,
void ilgt8(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
// evaluate s[j] = sum of all lgtmtx[i][j]*input[i] over i=1,...,8
if (!lgtmtx) assert(0);
// evaluate s[j] = sum of all lgtmtx[i * 8 + j]*input[i] over i=0,...,7
tran_high_t s[8] = { 0 };
for (int i = 0; i < 8; ++i)
for (int j = 0; j < 8; ++j) s[j] += lgtmtx[i * 8 + j] * input[i];
......@@ -288,26 +285,35 @@ void ilgt8(const tran_low_t *input, tran_low_t *output,
for (int i = 0; i < 8; ++i) output[i] = WRAPLOW(dct_const_round_shift(s[i]));
}
// The get_inv_lgt functions return 1 if LGT is chosen to apply, and 0 otherwise
int get_inv_lgt4(transform_1d tx_orig, const TxfmParam *txfm_param,
const tran_high_t *lgtmtx[], int ntx) {
// inter/intra split
if (tx_orig == &aom_iadst4_c) {
for (int i = 0; i < ntx; ++i)
lgtmtx[i] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
// get_lgt4 and get_lgt8 return 1 and pick an LGT matrix if LGT is chosen to
// apply. Otherwise they return 0.
// Selects the LGT4 basis matrix for one 1D transform stage.
// is_col nonzero selects the column (vertical) stage, zero the row stage.
// Returns 1 and stores the matrix in lgtmtx[0] if LGT applies to that
// stage; returns 0 and stores NULL otherwise.
int get_lgt4(const TxfmParam *txfm_param, int is_col,
             const tran_high_t **lgtmtx) {
  // Both directions select the same matrix, so look up the relevant 1D
  // transform type once instead of duplicating the branch bodies.
  const int tx_1d = is_col ? vtx_tab[txfm_param->tx_type]
                           : htx_tab[txfm_param->tx_type];
  // LGT only replaces the ADST/FlipADST 1D stages.
  if (tx_1d == ADST_1D || tx_1d == FLIPADST_1D) {
    // Inter and intra residuals use differently tuned line graphs.
    lgtmtx[0] = txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
    return 1;
  }
  lgtmtx[0] = NULL;
  return 0;
}
int get_inv_lgt8(transform_1d tx_orig, const TxfmParam *txfm_param,
const tran_high_t *lgtmtx[], int ntx) {
// inter/intra split
if (tx_orig == &aom_iadst8_c) {
for (int i = 0; i < ntx; ++i)
lgtmtx[i] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
// Selects the LGT8 basis matrix for one 1D transform stage.
// is_col nonzero selects the column (vertical) stage, zero the row stage.
// Returns 1 and stores the matrix in lgtmtx[0] if LGT applies to that
// stage; returns 0 and stores NULL otherwise.
int get_lgt8(const TxfmParam *txfm_param, int is_col,
             const tran_high_t **lgtmtx) {
  // Both directions select the same matrix, so look up the relevant 1D
  // transform type once instead of duplicating the branch bodies.
  const int tx_1d = is_col ? vtx_tab[txfm_param->tx_type]
                           : htx_tab[txfm_param->tx_type];
  // LGT only replaces the ADST/FlipADST 1D stages.
  if (tx_1d == ADST_1D || tx_1d == FLIPADST_1D) {
    // Inter and intra residuals use differently tuned line graphs.
    lgtmtx[0] = txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
    return 1;
  }
  lgtmtx[0] = NULL;
  return 0;
}
#endif // CONFIG_LGT
......@@ -356,12 +362,10 @@ void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
#endif
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[4];
const tran_high_t *lgtmtx_row[4];
int use_lgt_col =
get_inv_lgt4(IHT_4[tx_type].cols, txfm_param, lgtmtx_col, 4);
int use_lgt_row =
get_inv_lgt4(IHT_4[tx_type].rows, txfm_param, lgtmtx_row, 4);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors
......@@ -373,7 +377,7 @@ void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
#else
#if CONFIG_LGT
if (use_lgt_row)
ilgt4(input, out[i], lgtmtx_row[i]);
ilgt4(input, out[i], lgtmtx_row[0]);
else
#endif
IHT_4[tx_type].rows(input, out[i]);
......@@ -392,7 +396,7 @@ void av1_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < 4; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt4(tmp[i], out[i], lgtmtx_col[i]);
ilgt4(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_4[tx_type].cols(tmp[i], out[i]);
......@@ -454,19 +458,17 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n2;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[4];
const tran_high_t *lgtmtx_row[8];
int use_lgt_col =
get_inv_lgt8(IHT_4x8[tx_type].cols, txfm_param, lgtmtx_col, 4);
int use_lgt_row =
get_inv_lgt4(IHT_4x8[tx_type].rows, txfm_param, lgtmtx_row, 8);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors and transpose
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
ilgt4(input, outtmp, lgtmtx_row[i]);
ilgt4(input, outtmp, lgtmtx_row[0]);
else
#endif
IHT_4x8[tx_type].rows(input, outtmp);
......@@ -479,7 +481,7 @@ void av1_iht4x8_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < n; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt8(tmp[i], out[i], lgtmtx_col[i]);
ilgt8(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_4x8[tx_type].cols(tmp[i], out[i]);
......@@ -538,19 +540,17 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[8];
const tran_high_t *lgtmtx_row[4];
int use_lgt_col =
get_inv_lgt4(IHT_8x4[tx_type].cols, txfm_param, lgtmtx_col, 8);
int use_lgt_row =
get_inv_lgt8(IHT_8x4[tx_type].rows, txfm_param, lgtmtx_row, 4);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors and transpose
for (i = 0; i < n; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
ilgt8(input, outtmp, lgtmtx_row[i]);
ilgt8(input, outtmp, lgtmtx_row[0]);
else
#endif
IHT_8x4[tx_type].rows(input, outtmp);
......@@ -563,7 +563,7 @@ void av1_iht8x4_32_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt4(tmp[i], out[i], lgtmtx_col[i]);
ilgt4(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_8x4[tx_type].cols(tmp[i], out[i]);
......@@ -621,16 +621,15 @@ void av1_iht4x16_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n4;
#if CONFIG_LGT
const tran_high_t *lgtmtx_row[16];
int use_lgt_row =
get_inv_lgt4(IHT_4x16[tx_type].rows, txfm_param, lgtmtx_row, 16);
const tran_high_t *lgtmtx_row[1];
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors and transpose
for (i = 0; i < n4; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
ilgt4(input, outtmp, lgtmtx_row[i]);
ilgt4(input, outtmp, lgtmtx_row[0]);
else
#endif
IHT_4x16[tx_type].rows(input, outtmp);
......@@ -696,9 +695,8 @@ void av1_iht16x4_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[16];
int use_lgt_col =
get_inv_lgt4(IHT_16x4[tx_type].cols, txfm_param, lgtmtx_col, 16);
const tran_high_t *lgtmtx_col[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
#endif
// inverse transform row vectors and transpose
......@@ -712,7 +710,7 @@ void av1_iht16x4_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < n4; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt4(tmp[i], out[i], lgtmtx_col[i]);
ilgt4(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_16x4[tx_type].cols(tmp[i], out[i]);
......@@ -770,16 +768,15 @@ void av1_iht8x16_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n2;
#if CONFIG_LGT
const tran_high_t *lgtmtx_row[16];
int use_lgt_row =
get_inv_lgt8(IHT_8x16[tx_type].rows, txfm_param, lgtmtx_row, 16);
const tran_high_t *lgtmtx_row[1];
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors and transpose
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
ilgt8(input, outtmp, lgtmtx_row[i]);
ilgt8(input, outtmp, lgtmtx_row[0]);
else
#endif
IHT_8x16[tx_type].rows(input, outtmp);
......@@ -846,9 +843,8 @@ void av1_iht16x8_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[16];
int use_lgt_col =
get_inv_lgt8(IHT_16x8[tx_type].cols, txfm_param, lgtmtx_col, 16);
const tran_high_t *lgtmtx_col[1];
int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
#endif
// inverse transform row vectors and transpose
......@@ -863,7 +859,7 @@ void av1_iht16x8_128_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < n2; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt8(tmp[i], out[i], lgtmtx_col[i]);
ilgt8(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_16x8[tx_type].cols(tmp[i], out[i]);
......@@ -921,16 +917,15 @@ void av1_iht8x32_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n4;
#if CONFIG_LGT
const tran_high_t *lgtmtx_row[32];
int use_lgt_row =
get_inv_lgt8(IHT_8x32[tx_type].rows, txfm_param, lgtmtx_row, 32);
const tran_high_t *lgtmtx_row[1];
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors and transpose
for (i = 0; i < n4; ++i) {
#if CONFIG_LGT
if (use_lgt_row)
ilgt8(input, outtmp, lgtmtx_row[i]);
ilgt8(input, outtmp, lgtmtx_row[0]);
else
#endif
IHT_8x32[tx_type].rows(input, outtmp);
......@@ -996,9 +991,8 @@ void av1_iht32x8_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = n;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[32];
int use_lgt_col =
get_inv_lgt4(IHT_32x8[tx_type].cols, txfm_param, lgtmtx_col, 32);
const tran_high_t *lgtmtx_col[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
#endif
// inverse transform row vectors and transpose
......@@ -1012,7 +1006,7 @@ void av1_iht32x8_256_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < n4; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt8(tmp[i], out[i], lgtmtx_col[i]);
ilgt8(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_32x8[tx_type].cols(tmp[i], out[i]);
......@@ -1193,12 +1187,10 @@ void av1_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
int outstride = 8;
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[8];
const tran_high_t *lgtmtx_row[8];
int use_lgt_col =
get_inv_lgt8(IHT_8[tx_type].cols, txfm_param, lgtmtx_col, 8);
int use_lgt_row =
get_inv_lgt8(IHT_8[tx_type].rows, txfm_param, lgtmtx_row, 8);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// inverse transform row vectors
......@@ -1210,7 +1202,7 @@ void av1_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
#else
#if CONFIG_LGT
if (use_lgt_row)
ilgt8(input, out[i], lgtmtx_row[i]);
ilgt8(input, out[i], lgtmtx_row[0]);
else
#endif
IHT_8[tx_type].rows(input, out[i]);
......@@ -1229,7 +1221,7 @@ void av1_iht8x8_64_add_c(const tran_low_t *input, uint8_t *dest, int stride,
for (i = 0; i < 8; ++i) {
#if CONFIG_LGT
if (use_lgt_col)
ilgt8(tmp[i], out[i], lgtmtx_col[i]);
ilgt8(tmp[i], out[i], lgtmtx_col[0]);
else
#endif
IHT_8[tx_type].cols(tmp[i], out[i]);
......@@ -2294,9 +2286,6 @@ static InvTxfmFunc inv_txfm_func[2] = { av1_inv_txfm_add,
av1_highbd_inv_txfm_add };
#endif
// TODO(kslu) Change input arguments to TxfmParam, which contains mode,
// tx_type, tx_size, dst, stride, eob. Thus, the additional argument when LGT
// is on will no longer be needed.
void av1_inverse_transform_block(const MACROBLOCKD *xd,
const tran_low_t *dqcoeff,
#if CONFIG_LGT
......@@ -2321,13 +2310,13 @@ void av1_inverse_transform_block(const MACROBLOCKD *xd,
TxfmParam txfm_param;
init_txfm_param(xd, tx_size, tx_type, eob, &txfm_param);
#if CONFIG_LGT || CONFIG_MRC_TX
txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
txfm_param.dst = dst;
txfm_param.stride = stride;
txfm_param.is_inter = is_inter_block(&xd->mi[0]->mbmi);
#endif // CONFIG_LGT || CONFIG_MRC_TX
#if CONFIG_LGT
txfm_param.mode = mode;
#endif
#endif // CONFIG_LGT
#endif // CONFIG_LGT || CONFIG_MRC_TX
const int is_hbd = get_bitdepth_data_path_index(xd);
#if CONFIG_TXMG
......@@ -2369,14 +2358,11 @@ void av1_inverse_transform_block_facade(MACROBLOCKD *xd, int plane, int block,
const int dst_stride = pd->dst.stride;
uint8_t *dst =
&pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
av1_inverse_transform_block(xd, dqcoeff,
#if CONFIG_LGT
PREDICTION_MODE mode = get_prediction_mode(xd->mi[0], plane, tx_size, block);
av1_inverse_transform_block(xd, dqcoeff, mode, tx_type, tx_size, dst,
dst_stride, eob);
#else
av1_inverse_transform_block(xd, dqcoeff, tx_type, tx_size, dst, dst_stride,
eob);
xd->mi[0]->mbmi.mode,
#endif // CONFIG_LGT
tx_type, tx_size, dst, dst_stride, eob);
}
void av1_highbd_inv_txfm_add(const tran_low_t *input, uint8_t *dest, int stride,
......
......@@ -26,13 +26,19 @@
extern "C" {
#endif
// TODO(kslu) move the common stuff in idct.h to av1_txfm.h or txfm_common.h
typedef void (*transform_1d)(const tran_low_t *, tran_low_t *);
typedef struct {
transform_1d cols, rows; // vertical and horizontal
} transform_2d;
#if CONFIG_LGT
int get_lgt4(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx);
int get_lgt8(const TxfmParam *txfm_param, int is_col,
const tran_high_t **lgtmtx);
#endif // CONFIG_LGT
#if CONFIG_HIGHBITDEPTH
typedef void (*highbd_transform_1d)(const tran_low_t *, tran_low_t *, int bd);
......
......@@ -504,13 +504,9 @@ static void predict_and_reconstruct_intra_block(
if (eob) {
uint8_t *dst =
&pd->dst.buf[(row * pd->dst.stride + col) << tx_size_wide_log2[0]];
#if CONFIG_LGT
const PREDICTION_MODE mode =
get_prediction_mode(xd->mi[0], plane, tx_size, block_idx);
#endif // CONFIG_LGT
inverse_transform_block(xd, plane,
#if CONFIG_LGT
mode,
mbmi->mode,
#endif
tx_type, tx_size, dst, pd->dst.stride,
max_scan_line, eob);
......
......@@ -1179,10 +1179,7 @@ static void get_masked_residual32(const int16_t **input, int *input_stride,
#if CONFIG_LGT
static void flgt4(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
if (!(input[0] | input[1] | input[2] | input[3])) {
output[0] = output[1] = output[2] = output[3] = 0;
return;
}
if (!lgtmtx) assert(0);
// evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,4
tran_high_t s[4] = { 0 };
......@@ -1194,6 +1191,8 @@ static void flgt4(const tran_low_t *input, tran_low_t *output,
static void flgt8(const tran_low_t *input, tran_low_t *output,
const tran_high_t *lgtmtx) {
if (!lgtmtx) assert(0);
// evaluate s[j] = sum of all lgtmtx[j][i]*input[i] over i=1,...,8
tran_high_t s[8] = { 0 };
for (int i = 0; i < 8; ++i)
......@@ -1201,29 +1200,6 @@ static void flgt8(const tran_low_t *input, tran_low_t *output,
for (int i = 0; i < 8; ++i) output[i] = (tran_low_t)fdct_round_shift(s[i]);
}
// The get_fwd_lgt functions return 1 if LGT is chosen to apply, and 0 otherwise
// Returns 1 and fills lgtmtx with ntx copies of the chosen LGT4 basis when
// the original 1D transform is ADST4; returns 0 otherwise.
int get_fwd_lgt4(transform_1d tx_orig, TxfmParam *txfm_param,
                 const tran_high_t *lgtmtx[], int ntx) {
  // LGT only substitutes for the forward ADST4 stage.
  if (tx_orig != &fadst4) return 0;
  // Inter and intra blocks use differently tuned line graphs; the choice
  // is invariant across the ntx 1D transforms, so make it once.
  const tran_high_t *mtx =
      txfm_param->is_inter ? &lgt4_170[0][0] : &lgt4_140[0][0];
  for (int i = 0; i < ntx; ++i) lgtmtx[i] = mtx;
  return 1;
}
// Returns 1 and fills lgtmtx with ntx copies of the chosen LGT8 basis when
// the original 1D transform is ADST8; returns 0 otherwise.
int get_fwd_lgt8(transform_1d tx_orig, TxfmParam *txfm_param,
                 const tran_high_t *lgtmtx[], int ntx) {
  // LGT only substitutes for the forward ADST8 stage.
  if (tx_orig != &fadst8) return 0;
  // Inter and intra blocks use differently tuned line graphs; the choice
  // is invariant across the ntx 1D transforms, so make it once.
  const tran_high_t *mtx =
      txfm_param->is_inter ? &lgt8_170[0][0] : &lgt8_150[0][0];
  for (int i = 0; i < ntx; ++i) lgtmtx[i] = mtx;
  return 1;
}
#endif // CONFIG_LGT
#if CONFIG_EXT_TX
......@@ -1422,10 +1398,10 @@ void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
#if CONFIG_LGT
// Choose LGT adaptive to the prediction. We may apply different LGTs for
// different rows/columns, indicated by the pointers to 2D arrays
const tran_high_t *lgtmtx_col[4];
const tran_high_t *lgtmtx_row[4];
int use_lgt_col = get_fwd_lgt4(ht.cols, txfm_param, lgtmtx_col, 4);
int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 4);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// Columns
......@@ -1437,7 +1413,7 @@ void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
#endif
#if CONFIG_LGT
if (use_lgt_col)
flgt4(temp_in, temp_out, lgtmtx_col[i]);
flgt4(temp_in, temp_out, lgtmtx_col[0]);
else
#endif
ht.cols(temp_in, temp_out);
......@@ -1449,7 +1425,7 @@ void av1_fht4x4_c(const int16_t *input, tran_low_t *output, int stride,
for (j = 0; j < 4; ++j) temp_in[j] = out[j + i * 4];
#if CONFIG_LGT
if (use_lgt_row)
flgt4(temp_in, temp_out, lgtmtx_row[i]);
flgt4(temp_in, temp_out, lgtmtx_row[0]);
else
#endif
ht.rows(temp_in, temp_out);
......@@ -1505,10 +1481,10 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
#endif
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[4];
const tran_high_t *lgtmtx_row[8];
int use_lgt_col = get_fwd_lgt8(ht.cols, txfm_param, lgtmtx_col, 4);
int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 8);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt8(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// Rows
......@@ -1518,7 +1494,7 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
(tran_low_t)fdct_round_shift(input[i * stride + j] * 4 * Sqrt2);
#if CONFIG_LGT
if (use_lgt_row)
flgt4(temp_in, temp_out, lgtmtx_row[i]);
flgt4(temp_in, temp_out, lgtmtx_row[0]);
else
#endif
ht.rows(temp_in, temp_out);
......@@ -1530,7 +1506,7 @@ void av1_fht4x8_c(const int16_t *input, tran_low_t *output, int stride,
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
#if CONFIG_LGT
if (use_lgt_col)
flgt8(temp_in, temp_out, lgtmtx_col[i]);
flgt8(temp_in, temp_out, lgtmtx_col[0]);
else
#endif
ht.cols(temp_in, temp_out);
......@@ -1581,10 +1557,10 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
#endif
#if CONFIG_LGT
const tran_high_t *lgtmtx_col[8];
const tran_high_t *lgtmtx_row[4];
int use_lgt_col = get_fwd_lgt4(ht.cols, txfm_param, lgtmtx_col, 8);
int use_lgt_row = get_fwd_lgt8(ht.rows, txfm_param, lgtmtx_row, 4);
const tran_high_t *lgtmtx_col[1];
const tran_high_t *lgtmtx_row[1];
int use_lgt_col = get_lgt4(txfm_param, 1, lgtmtx_col);
int use_lgt_row = get_lgt8(txfm_param, 0, lgtmtx_row);
#endif
// Columns
......@@ -1594,7 +1570,7 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
(tran_low_t)fdct_round_shift(input[j * stride + i] * 4 * Sqrt2);
#if CONFIG_LGT
if (use_lgt_col)
flgt4(temp_in, temp_out, lgtmtx_col[i]);
flgt4(temp_in, temp_out, lgtmtx_col[0]);
else
#endif
ht.cols(temp_in, temp_out);
......@@ -1606,7 +1582,7 @@ void av1_fht8x4_c(const int16_t *input, tran_low_t *output, int stride,
for (j = 0; j < n2; ++j) temp_in[j] = out[j + i * n2];
#if CONFIG_LGT
if (use_lgt_row)
flgt8(temp_in, temp_out, lgtmtx_row[i]);
flgt8(temp_in, temp_out, lgtmtx_row[0]);
else
#endif
ht.rows(temp_in, temp_out);
......@@ -1657,8 +1633,8 @@ void av1_fht4x16_c(const int16_t *input, tran_low_t *output, int stride,
#endif
#if CONFIG_LGT
const tran_high_t *lgtmtx_row[16];
int use_lgt_row = get_fwd_lgt4(ht.rows, txfm_param, lgtmtx_row, 16);
const tran_high_t *lgtmtx_row[1];
int use_lgt_row = get_lgt4(txfm_param, 0, lgtmtx_row);
#endif
// Rows
...