Commit 1cb757cb authored by Debargha Mukherjee

Adds the option to use 5x5 Wiener for chroma

Change-Id: I1b789acc18f1e69fb5db069ccd8bd17815938e9d
parent d02642f2
......@@ -78,6 +78,10 @@ extern "C" {
#define WIENER_TMPBUF_SIZE (0)
#define WIENER_EXTBUF_SIZE (0)
// If WIENER_WIN_CHROMA == WIENER_WIN - 2, that implies 5x5 filters are used for
// chroma. To use 7x7 for chroma set WIENER_WIN_CHROMA to WIENER_WIN.
#define WIENER_WIN_CHROMA (WIENER_WIN - 2)
#define WIENER_FILT_PREC_BITS 7
#define WIENER_FILT_STEP (1 << WIENER_FILT_PREC_BITS)
......
......@@ -2685,14 +2685,17 @@ static void decode_restoration_mode(AV1_COMMON *cm,
cm->rst_info[2].restoration_tilesize = cm->rst_info[1].restoration_tilesize;
}
static void read_wiener_filter(WienerInfo *wiener_info,
static void read_wiener_filter(int wiener_win, WienerInfo *wiener_info,
WienerInfo *ref_wiener_info, aom_reader *rb) {
wiener_info->vfilter[0] = wiener_info->vfilter[WIENER_WIN - 1] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV, ACCT_STR) +
WIENER_FILT_TAP0_MINV;
if (wiener_win == WIENER_WIN)
wiener_info->vfilter[0] = wiener_info->vfilter[WIENER_WIN - 1] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV, ACCT_STR) +
WIENER_FILT_TAP0_MINV;
else
wiener_info->vfilter[0] = wiener_info->vfilter[WIENER_WIN - 1] = 0;
wiener_info->vfilter[1] = wiener_info->vfilter[WIENER_WIN - 2] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
......@@ -2710,12 +2713,15 @@ static void read_wiener_filter(WienerInfo *wiener_info,
-2 * (wiener_info->vfilter[0] + wiener_info->vfilter[1] +
wiener_info->vfilter[2]);
wiener_info->hfilter[0] = wiener_info->hfilter[WIENER_WIN - 1] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV, ACCT_STR) +
WIENER_FILT_TAP0_MINV;
if (wiener_win == WIENER_WIN)
wiener_info->hfilter[0] = wiener_info->hfilter[WIENER_WIN - 1] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV, ACCT_STR) +
WIENER_FILT_TAP0_MINV;
else
wiener_info->hfilter[0] = wiener_info->hfilter[WIENER_WIN - 1] = 0;
wiener_info->hfilter[1] = wiener_info->hfilter[WIENER_WIN - 2] =
aom_read_primitive_refsubexpfin(
rb, WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
......@@ -2779,7 +2785,8 @@ static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
aom_read_tree(rb, av1_switchable_restore_tree,
cm->fc->switchable_restore_prob, ACCT_STR);
if (rsi->restoration_type[i] == RESTORE_WIENER) {
read_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, rb);
read_wiener_filter(WIENER_WIN, &rsi->wiener_info[i], &ref_wiener_info,
rb);
} else if (rsi->restoration_type[i] == RESTORE_SGRPROJ) {
read_sgrproj_filter(&rsi->sgrproj_info[i], &ref_sgrproj_info, rb);
}
......@@ -2788,7 +2795,8 @@ static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
for (i = 0; i < ntiles; ++i) {
if (aom_read(rb, RESTORE_NONE_WIENER_PROB, ACCT_STR)) {
rsi->restoration_type[i] = RESTORE_WIENER;
read_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, rb);
read_wiener_filter(WIENER_WIN, &rsi->wiener_info[i], &ref_wiener_info,
rb);
} else {
rsi->restoration_type[i] = RESTORE_NONE;
}
......@@ -2817,7 +2825,8 @@ static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
else
rsi->restoration_type[i] = RESTORE_WIENER;
if (rsi->restoration_type[i] == RESTORE_WIENER) {
read_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, rb);
read_wiener_filter(WIENER_WIN_CHROMA, &rsi->wiener_info[i],
&ref_wiener_info, rb);
}
}
} else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
......
......@@ -3283,13 +3283,17 @@ static void encode_restoration_mode(AV1_COMMON *cm,
}
}
static void write_wiener_filter(WienerInfo *wiener_info,
static void write_wiener_filter(int wiener_win, WienerInfo *wiener_info,
WienerInfo *ref_wiener_info, aom_writer *wb) {
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
if (wiener_win == WIENER_WIN)
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
else
assert(wiener_info->vfilter[0] == 0 &&
wiener_info->vfilter[WIENER_WIN - 1] == 0);
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
WIENER_FILT_TAP1_SUBEXP_K,
......@@ -3300,11 +3304,15 @@ static void write_wiener_filter(WienerInfo *wiener_info,
WIENER_FILT_TAP2_SUBEXP_K,
ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
if (wiener_win == WIENER_WIN)
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
else
assert(wiener_info->hfilter[0] == 0 &&
wiener_info->hfilter[WIENER_WIN - 1] == 0);
aom_write_primitive_refsubexpfin(
wb, WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
WIENER_FILT_TAP1_SUBEXP_K,
......@@ -3362,7 +3370,8 @@ static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
wb, av1_switchable_restore_tree, cm->fc->switchable_restore_prob,
&switchable_restore_encodings[rsi->restoration_type[i]]);
if (rsi->restoration_type[i] == RESTORE_WIENER) {
write_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, wb);
write_wiener_filter(WIENER_WIN, &rsi->wiener_info[i],
&ref_wiener_info, wb);
} else if (rsi->restoration_type[i] == RESTORE_SGRPROJ) {
write_sgrproj_filter(&rsi->sgrproj_info[i], &ref_sgrproj_info, wb);
}
......@@ -3372,7 +3381,8 @@ static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
aom_write(wb, rsi->restoration_type[i] != RESTORE_NONE,
RESTORE_NONE_WIENER_PROB);
if (rsi->restoration_type[i] != RESTORE_NONE) {
write_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, wb);
write_wiener_filter(WIENER_WIN, &rsi->wiener_info[i],
&ref_wiener_info, wb);
}
}
} else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
......@@ -3395,7 +3405,8 @@ static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
aom_write(wb, rsi->restoration_type[i] != RESTORE_NONE,
RESTORE_NONE_WIENER_PROB);
if (rsi->restoration_type[i] != RESTORE_NONE) {
write_wiener_filter(&rsi->wiener_info[i], &ref_wiener_info, wb);
write_wiener_filter(WIENER_WIN_CHROMA, &rsi->wiener_info[i],
&ref_wiener_info, wb);
}
}
} else if (rsi->frame_restoration_type == RESTORE_SGRPROJ) {
......
......@@ -548,41 +548,44 @@ static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
return avg;
}
// Accumulates the statistics used to fit a Wiener filter over the pixel
// rectangle [v_start, v_end) x [h_start, h_end).
//   M[k]      += Y[k] * X     for every pixel, where X is the mean-removed
//                             source sample and Y[] is the mean-removed
//                             neighbourhood of the degraded sample.
//   H[k][l]   += Y[k] * Y[l]  stored row-major with a stride of wiener_win2.
// The mean ("avg") is taken from the degraded frame via find_average() and is
// subtracted from both src and dgd samples.
//
// NOTE(review): this span is rendered diff output. Each pre-change line (the
// three-line signature without wiener_win, and every line that uses
// WIENER_WIN2 / WIENER_HALFWIN) is directly followed by its post-change
// replacement (signature with wiener_win, wiener_win2 / wiener_halfwin). It is
// not compilable as shown.
// NOTE(review): Y[] is still sized with the compile-time WIENER_WIN2, so the
// post-change code presumably relies on wiener_win <= WIENER_WIN -- confirm
// against callers.
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
int v_start, int v_end, int dgd_stride,
int src_stride, double *M, double *H) {
static void compute_stats(int wiener_win, uint8_t *dgd, uint8_t *src,
int h_start, int h_end, int v_start, int v_end,
int dgd_stride, int src_stride, double *M,
double *H) {
int i, j, k, l;
double Y[WIENER_WIN2];
const int wiener_win2 = wiener_win * wiener_win;
const int wiener_halfwin = (wiener_win >> 1);
const double avg =
find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
memset(M, 0, sizeof(*M) * WIENER_WIN2);
memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
memset(M, 0, sizeof(*M) * wiener_win2);
memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
for (i = v_start; i < v_end; i++) {
for (j = h_start; j < h_end; j++) {
const double X = (double)src[i * src_stride + j] - avg;
int idx = 0;
// Gather the neighbourhood into Y[]: k is the column offset and l the row
// offset, so Y is filled column by column.
for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
idx++;
}
}
for (k = 0; k < WIENER_WIN2; ++k) {
for (k = 0; k < wiener_win2; ++k) {
M[k] += Y[k] * X;
H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
for (l = k + 1; l < WIENER_WIN2; ++l) {
H[k * wiener_win2 + k] += Y[k] * Y[k];
for (l = k + 1; l < wiener_win2; ++l) {
// H is a symmetric matrix, so we only need to fill out the upper
// triangle here. We can copy it down to the lower triangle outside
// the (i, j) loops.
H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
H[k * wiener_win2 + l] += Y[k] * Y[l];
}
}
}
}
// Mirror the accumulated upper triangle of H into the lower triangle.
for (k = 0; k < WIENER_WIN2; ++k) {
for (l = k + 1; l < WIENER_WIN2; ++l) {
H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
for (k = 0; k < wiener_win2; ++k) {
for (l = k + 1; l < wiener_win2; ++l) {
H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
}
}
}
......@@ -600,168 +603,183 @@ static double find_average_highbd(uint16_t *src, int h_start, int h_end,
return avg;
}
// High-bit-depth counterpart of compute_stats(): the uint8_t pointers are
// converted with CONVERT_TO_SHORTPTR() and samples are read as uint16_t.
// Accumulates M[k] += Y[k] * X and H[k][l] += Y[k] * Y[l] over the rectangle
// [v_start, v_end) x [h_start, h_end), with the mean of the degraded frame
// (find_average_highbd) removed from both the source sample X and the
// neighbourhood samples Y[].
//
// NOTE(review): this span is rendered diff output. Each pre-change line (the
// two-line signature prefix without wiener_win, and every line that uses
// WIENER_WIN2 / WIENER_HALFWIN) is directly followed by its post-change
// replacement. It is not compilable as shown.
// NOTE(review): Y[] keeps the compile-time size WIENER_WIN2, so wiener_win is
// presumably never larger than WIENER_WIN -- confirm against callers.
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
int h_end, int v_start, int v_end,
static void compute_stats_highbd(int wiener_win, uint8_t *dgd8, uint8_t *src8,
int h_start, int h_end, int v_start, int v_end,
int dgd_stride, int src_stride, double *M,
double *H) {
int i, j, k, l;
double Y[WIENER_WIN2];
const int wiener_win2 = wiener_win * wiener_win;
const int wiener_halfwin = (wiener_win >> 1);
uint16_t *src = CONVERT_TO_SHORTPTR(src8);
uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
const double avg =
find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
memset(M, 0, sizeof(*M) * WIENER_WIN2);
memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
memset(M, 0, sizeof(*M) * wiener_win2);
memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
for (i = v_start; i < v_end; i++) {
for (j = h_start; j < h_end; j++) {
const double X = (double)src[i * src_stride + j] - avg;
int idx = 0;
// Gather the neighbourhood into Y[]: k is the column offset and l the row
// offset, so Y is filled column by column.
for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
idx++;
}
}
for (k = 0; k < WIENER_WIN2; ++k) {
for (k = 0; k < wiener_win2; ++k) {
M[k] += Y[k] * X;
H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
for (l = k + 1; l < WIENER_WIN2; ++l) {
H[k * wiener_win2 + k] += Y[k] * Y[k];
for (l = k + 1; l < wiener_win2; ++l) {
// H is a symmetric matrix, so we only need to fill out the upper
// triangle here. We can copy it down to the lower triangle outside
// the (i, j) loops.
H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
H[k * wiener_win2 + l] += Y[k] * Y[l];
}
}
}
}
// Mirror the accumulated upper triangle of H into the lower triangle.
for (k = 0; k < WIENER_WIN2; ++k) {
for (l = k + 1; l < WIENER_WIN2; ++l) {
H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
for (k = 0; k < wiener_win2; ++k) {
for (l = k + 1; l < wiener_win2; ++l) {
H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
}
}
}
#endif // CONFIG_HIGHBITDEPTH
// Folds a tap index of a symmetric filter onto the first half of the taps:
// indices i >= (window >> 1) + 1 are mapped to window - 1 - i, all others are
// returned unchanged.
// NOTE(review): rendered diff output -- the first two lines below are the
// pre-change version (window fixed at WIENER_WIN via WIENER_HALFWIN1), the next
// three are the post-change version that takes the window size as wiener_win.
static INLINE int wrap_index(int i) {
return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
static INLINE int wrap_index(int i, int wiener_win) {
const int wiener_halfwin1 = (wiener_win >> 1) + 1;
return (i >= wiener_halfwin1 ? wiener_win - 1 - i : i);
}
// Fix vector b, update vector a
// Builds a reduced system (A of wiener_halfwin1 entries, B of
// wiener_halfwin1 x wiener_halfwin1 entries) by folding mirrored tap indices
// together with wrap_index(), applies the normalization adjustment below, and
// hands it to linsolve(). Only when linsolve() returns non-zero is the result
// expanded back to a full symmetric filter and copied into a; otherwise a is
// left untouched.
//
// NOTE(review): rendered diff output -- pre-change lines (WIENER_WIN,
// WIENER_WIN2, WIENER_HALFWIN1, and the locals w / w2) are interleaved with
// their post-change replacements (wiener_win, wiener_win2, wiener_halfwin1).
// It is not compilable as shown.
// NOTE(review): S, A and B keep their compile-time WIENER_WIN-based sizes, so
// wiener_win is presumably never larger than WIENER_WIN -- confirm.
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
double *a, double *b) {
int i, j;
double S[WIENER_WIN];
double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
int w, w2;
const int wiener_win2 = wiener_win * wiener_win;
const int wiener_halfwin1 = (wiener_win >> 1) + 1;
// sizeof(A) / sizeof(B) clear the whole fixed-size arrays, not just the
// wiener_halfwin1-sized part that is used.
memset(A, 0, sizeof(A));
memset(B, 0, sizeof(B));
// A[jj] collects Mc[i][j] * b[i] for every column j that folds onto jj.
for (i = 0; i < WIENER_WIN; i++) {
for (j = 0; j < WIENER_WIN; ++j) {
const int jj = wrap_index(j);
for (i = 0; i < wiener_win; i++) {
for (j = 0; j < wiener_win; ++j) {
const int jj = wrap_index(j, wiener_win);
A[jj] += Mc[i][j] * b[i];
}
}
// B[ll][kk] collects Hc[...] * b[i] * b[j] with k and l folded onto kk / ll.
for (i = 0; i < WIENER_WIN; i++) {
for (j = 0; j < WIENER_WIN; j++) {
for (i = 0; i < wiener_win; i++) {
for (j = 0; j < wiener_win; j++) {
int k, l;
for (k = 0; k < WIENER_WIN; ++k)
for (l = 0; l < WIENER_WIN; ++l) {
const int kk = wrap_index(k);
const int ll = wrap_index(l);
B[ll * WIENER_HALFWIN1 + kk] +=
Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
for (k = 0; k < wiener_win; ++k)
for (l = 0; l < wiener_win; ++l) {
const int kk = wrap_index(k, wiener_win);
const int ll = wrap_index(l, wiener_win);
B[ll * wiener_halfwin1 + kk] +=
Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
}
}
}
// Normalization enforcement in the system of equations itself
w = WIENER_WIN;
w2 = (w >> 1) + 1;
for (i = 0; i < w2 - 1; ++i)
for (i = 0; i < wiener_halfwin1 - 1; ++i)
A[i] -=
A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
for (i = 0; i < w2 - 1; ++i)
for (j = 0; j < w2 - 1; ++j)
B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
2 * B[(w2 - 1) * w2 + (w2 - 1)]);
if (linsolve(w2 - 1, B, w2, A, S)) {
S[w2 - 1] = 1.0;
for (i = w2; i < w; ++i) {
S[i] = S[w - 1 - i];
S[w2 - 1] -= 2 * S[i];
A[wiener_halfwin1 - 1] * 2 +
B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
for (i = 0; i < wiener_halfwin1 - 1; ++i)
for (j = 0; j < wiener_halfwin1 - 1; ++j)
B[i * wiener_halfwin1 + j] -=
2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
(wiener_halfwin1 - 1)]);
// linsolve() is defined elsewhere; presumably it solves the reduced
// (wiener_halfwin1 - 1)-unknown system into S and returns non-zero on
// success -- confirm against its definition.
if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
// Centre tap starts at 1.0 and is reduced by twice each mirrored tap, so the
// full symmetric filter sums to 1.
S[wiener_halfwin1 - 1] = 1.0;
for (i = wiener_halfwin1; i < wiener_win; ++i) {
S[i] = S[wiener_win - 1 - i];
S[wiener_halfwin1 - 1] -= 2 * S[i];
}
memcpy(a, S, w * sizeof(*a));
memcpy(a, S, wiener_win * sizeof(*a));
}
}
// Fix vector a, update vector b
// Mirror image of update_a_sep_sym(): builds the reduced system from Mc, Hc
// and the fixed vector a (row indices folded with wrap_index()), applies the
// same normalization adjustment, and calls linsolve(). Only when linsolve()
// returns non-zero is the result expanded to a full symmetric filter and copied
// into b; otherwise b is left untouched.
//
// NOTE(review): rendered diff output -- pre-change lines (WIENER_WIN,
// WIENER_WIN2, WIENER_HALFWIN1, and the locals w / w2) are interleaved with
// their post-change replacements (wiener_win, wiener_win2, wiener_halfwin1).
// It is not compilable as shown.
// NOTE(review): S, A and B keep their compile-time WIENER_WIN-based sizes, so
// wiener_win is presumably never larger than WIENER_WIN -- confirm.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
double *a, double *b) {
int i, j;
double S[WIENER_WIN];
double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
int w, w2;
const int wiener_win2 = wiener_win * wiener_win;
const int wiener_halfwin1 = (wiener_win >> 1) + 1;
// sizeof(A) / sizeof(B) clear the whole fixed-size arrays, not just the
// wiener_halfwin1-sized part that is used.
memset(A, 0, sizeof(A));
memset(B, 0, sizeof(B));
// A[ii] collects Mc[i][j] * a[j] for every row i that folds onto ii.
for (i = 0; i < WIENER_WIN; i++) {
const int ii = wrap_index(i);
for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
for (i = 0; i < wiener_win; i++) {
const int ii = wrap_index(i, wiener_win);
for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
}
// B[jj][ii] collects Hc[...] * a[k] * a[l] with i and j folded onto ii / jj.
for (i = 0; i < WIENER_WIN; i++) {
for (j = 0; j < WIENER_WIN; j++) {
const int ii = wrap_index(i);
const int jj = wrap_index(j);
for (i = 0; i < wiener_win; i++) {
for (j = 0; j < wiener_win; j++) {
const int ii = wrap_index(i, wiener_win);
const int jj = wrap_index(j, wiener_win);
int k, l;
for (k = 0; k < WIENER_WIN; ++k)
for (l = 0; l < WIENER_WIN; ++l)
B[jj * WIENER_HALFWIN1 + ii] +=
Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
for (k = 0; k < wiener_win; ++k)
for (l = 0; l < wiener_win; ++l)
B[jj * wiener_halfwin1 + ii] +=
Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
}
}
// Normalization enforcement in the system of equations itself
w = WIENER_WIN;
w2 = WIENER_HALFWIN1;
for (i = 0; i < w2 - 1; ++i)
for (i = 0; i < wiener_halfwin1 - 1; ++i)
A[i] -=
A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
for (i = 0; i < w2 - 1; ++i)
for (j = 0; j < w2 - 1; ++j)
B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
2 * B[(w2 - 1) * w2 + (w2 - 1)]);
if (linsolve(w2 - 1, B, w2, A, S)) {
S[w2 - 1] = 1.0;
for (i = w2; i < w; ++i) {
S[i] = S[w - 1 - i];
S[w2 - 1] -= 2 * S[i];
A[wiener_halfwin1 - 1] * 2 +
B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
for (i = 0; i < wiener_halfwin1 - 1; ++i)
for (j = 0; j < wiener_halfwin1 - 1; ++j)
B[i * wiener_halfwin1 + j] -=
2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
(wiener_halfwin1 - 1)]);
// linsolve() is defined elsewhere; presumably it solves the reduced
// (wiener_halfwin1 - 1)-unknown system into S and returns non-zero on
// success -- confirm against its definition.
if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
// Centre tap starts at 1.0 and is reduced by twice each mirrored tap, so the
// full symmetric filter sums to 1.
S[wiener_halfwin1 - 1] = 1.0;
for (i = wiener_halfwin1; i < wiener_win; ++i) {
S[i] = S[wiener_win - 1 - i];
S[wiener_halfwin1 - 1] -= 2 * S[i];
}
memcpy(b, S, w * sizeof(*b));
memcpy(b, S, wiener_win * sizeof(*b));
}
}
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
double *b) {
static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
double *a, double *b) {
static const int init_filt[WIENER_WIN] = {
WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
WIENER_FILT_TAP0_MIDV,
};
int i, j, iter;
double *Hc[WIENER_WIN2];
double *Mc[WIENER_WIN];
for (i = 0; i < WIENER_WIN; i++) {
Mc[i] = M + i * WIENER_WIN;
for (j = 0; j < WIENER_WIN; j++) {
Hc[i * WIENER_WIN + j] =
H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
}
int i, j, iter;
const int plane_off = (WIENER_WIN - wiener_win) >> 1;
const int wiener_win2 = wiener_win * wiener_win;
for (i = 0; i < wiener_win; i++) {
a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
}
for (i = 0; i < WIENER_WIN; i++) {
a[i] = b[i] = (double)init_filt[i] / WIENER_FILT_STEP;
for (i = 0; i < wiener_win; i++) {
Mc[i] = M + i * wiener_win;
for (j = 0; j < wiener_win; j++) {
Hc[i * wiener_win + j] =
H + i * wiener_win * wiener_win2 + j * wiener_win;
}
}
iter = 1;
while (iter < NUM_WIENER_ITERS) {
update_a_sep_sym(Mc, Hc, a, b);
update_b_sep_sym(Mc, Hc, a, b);
update_a_sep_sym(wiener_win, Mc, Hc, a, b);
update_b_sep_sym(wiener_win, Mc, Hc, a, b);
iter++;
}
return 1;
......@@ -770,14 +788,16 @@ static int wiener_decompose_sep_sym(double *M, double *H, double *a,
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
// against identity filters; Final score is defined as the difference between
// the function values
static double compute_score(double *M, double *H, InterpKernel vfilt,
InterpKernel hfilt) {
static double compute_score(int wiener_win, double *M, double *H,
InterpKernel vfilt, InterpKernel hfilt) {
double ab[WIENER_WIN * WIENER_WIN];
int i, k, l;
double P = 0, Q = 0;
double iP = 0, iQ = 0;
double Score, iScore;
double a[WIENER_WIN], b[WIENER_WIN];
const int plane_off = (WIENER_WIN - wiener_win) >> 1;
const int wiener_win2 = wiener_win * wiener_win;
aom_clear_system_state();
......@@ -788,32 +808,40 @@ static double compute_score(double *M, double *H, InterpKernel vfilt,
a[WIENER_HALFWIN] -= 2 * a[i];
b[WIENER_HALFWIN] -= 2 * b[i];
}
for (k = 0; k < WIENER_WIN; ++k) {
for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
for (k = 0; k < wiener_win; ++k) {
for (l = 0; l < wiener_win; ++l)
ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
}
for (k = 0; k < WIENER_WIN2; ++k) {
for (k = 0; k < wiener_win2; ++k) {
P += ab[k] * M[k];
for (l = 0; l < WIENER_WIN2; ++l)
Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
for (l = 0; l < wiener_win2; ++l)
Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
}
Score = Q - 2 * P;
iP = M[WIENER_WIN2 >> 1];
iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
iP = M[wiener_win2 >> 1];
iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
iScore = iQ - 2 * iP;
return Score - iScore;
}
static void quantize_sym_filter(double *f, InterpKernel fi) {
static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
int i;
for (i = 0; i < WIENER_HALFWIN; ++i) {
const int wiener_halfwin = (wiener_win >> 1);
for (i = 0; i < wiener_halfwin; ++i) {
fi[i] = RINT(f[i] * WIENER_FILT_STEP);
}
// Specialize for 7-tap filter
fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
if (wiener_win == WIENER_WIN) {
fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
} else {
fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
fi[0] = 0;
}
// Satisfy filter constraints
fi[WIENER_WIN - 1] = fi[0];
fi[WIENER_WIN - 2] = fi[1];
......@@ -822,14 +850,15 @@ static void quantize_sym_filter(double *f, InterpKernel fi) {
fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}
static int count_wiener_bits(WienerInfo *wiener_info,
static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
WienerInfo *ref_wiener_info) {
int bits = 0;
bits += aom_count_primitive_refsubexpfin(
WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
if (wiener_win == WIENER_WIN)
bits += aom_count_primitive_refsubexpfin(
WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
bits += aom_count_primitive_refsubexpfin(
WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
WIENER_FILT_TAP1_SUBEXP_K,
......@@ -840,11 +869,12 @@ static int count_wiener_bits(WienerInfo *wiener_info,
WIENER_FILT_TAP2_SUBEXP_K,
ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);
bits += aom_count_primitive_refsubexpfin(
WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
if (wiener_win == WIENER_WIN)
bits += aom_count_primitive_refsubexpfin(
WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
WIENER_FILT_TAP0_SUBEXP_K,
ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
bits += aom_count_primitive_refsubexpfin(