pickrst.c 28.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
/*
 *  Copyright (c) 2010 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <assert.h>
12
#include <float.h>
13
14
15
16
17
#include <limits.h>
#include <math.h>

#include "./vpx_scale_rtcd.h"

18
19
20
21
#include "aom_dsp/psnr.h"
#include "aom_dsp/vpx_dsp_common.h"
#include "aom_mem/vpx_mem.h"
#include "aom_ports/mem.h"
22

23
24
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
25

26
27
28
29
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
#include "av1/encoder/quantize.h"
30

31
static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *sd,
32
                                     VP10_COMP *const cpi, RestorationInfo *rsi,
33
                                     int partial_frame) {
34
  VP10_COMMON *const cm = &cpi->common;
35
  int64_t filt_err;
36
  vp10_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame);
37
38
#if CONFIG_VP9_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
39
    filt_err = vpx_highbd_get_y_sse(sd, cm->frame_to_show);
40
  } else {
41
    filt_err = vpx_get_y_sse(sd, cm->frame_to_show);
42
43
  }
#else
44
  filt_err = vpx_get_y_sse(sd, cm->frame_to_show);
45
46
47
48
49
50
51
#endif  // CONFIG_VP9_HIGHBITDEPTH

  // Re-instate the unfiltered frame
  vpx_yv12_copy_y(&cpi->last_frame_db, cm->frame_to_show);
  return filt_err;
}

52
// Chooses, for loop filter level |filter_level|, a bilateral restoration level
// for every BILATERAL_TILESIZE tile.
//
// The frame is loop filtered first (saving the unfiltered luma in
// cpi->last_frame_uf and the filtered luma in cpi->last_frame_db). Each tile is
// then tried on its own, with all other tiles disabled, at every level in
// [0, 1 << bilateral_level_bits). The tile keeps the cheapest level if its RD
// cost beats the no-restoration cost; otherwise it keeps -1 (disabled).
// Finally the combined per-tile configuration is costed as a whole.
//
// On return:
//  - |bilateral_level| (ntiles entries) holds the per-tile decisions;
//  - the luma plane of cm->frame_to_show is restored from cpi->last_frame_uf;
//  - *best_cost_ret (if non-NULL) receives the cheaper of the combined
//    bilateral cost and the no-restoration cost.
// Returns 1 when the bilateral configuration is cheaper than no restoration,
// else 0.
static int search_bilateral_level(const YV12_BUFFER_CONFIG *sd, VP10_COMP *cpi,
                                  int filter_level, int partial_frame,
                                  int *bilateral_level, double *best_cost_ret) {
  VP10_COMMON *const cm = &cpi->common;
  int i, j, tile_idx;
  int64_t err;
  int bits;
  double cost, best_cost, cost_norestore, cost_bilateral;
  const int bilateral_level_bits = vp10_bilateral_level_bits(&cpi->common);
  const int bilateral_levels = 1 << bilateral_level_bits;
  MACROBLOCK *x = &cpi->td.mb;
  RestorationInfo rsi;
  const int ntiles =
      vp10_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);

  //  Make a copy of the unfiltered / processed recon buffer
  vpx_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  vp10_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                         1, partial_frame);
  vpx_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  // RD cost associated with no restoration
  rsi.restoration_type = RESTORE_NONE;
  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
  bits = 0;
  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv,
                              (bits << (VP10_PROB_COST_SHIFT - 4)), err);

  for (j = 0; j < ntiles; ++j) bilateral_level[j] = -1;

  // RD cost associated with bilateral filtering
  rsi.restoration_type = RESTORE_BILATERAL;
  rsi.bilateral_level =
      (int *)vpx_malloc(sizeof(*rsi.bilateral_level) * ntiles);
  if (rsi.bilateral_level == NULL) {
    // Out of memory: report "no restoration" (every tile disabled) instead of
    // dereferencing NULL, leaving the frame as the normal path would.
    vpx_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
    if (best_cost_ret) *best_cost_ret = cost_norestore;
    return 0;
  }

  // Find best filter for each tile. All tiles start disabled. Only the tile
  // under test is switched on, and it is switched off again afterwards, so
  // every trial sees all other tiles unfiltered.
  for (j = 0; j < ntiles; ++j) rsi.bilateral_level[j] = -1;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    best_cost = cost_norestore;
    for (i = 0; i < bilateral_levels; ++i) {
      rsi.bilateral_level[tile_idx] = i;
      err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
      bits = bilateral_level_bits + 1;
      // Normally the rate is rate in bits * 256 and dist is sum sq err * 64
      // when RDCOST is used.  However below we just scale both in the correct
      // ratios appropriately but not exactly by these values.
      cost = RDCOST_DBL(x->rdmult, x->rddiv,
                        (bits << (VP10_PROB_COST_SHIFT - 4)), err);
      if (cost < best_cost) {
        bilateral_level[tile_idx] = i;
        best_cost = cost;
      }
    }
    rsi.bilateral_level[tile_idx] = -1;
  }

  // Find cost for combined configuration: enabled tiles pay the level bits
  // plus one flag bit, disabled tiles only the flag bit.
  bits = 0;
  for (j = 0; j < ntiles; ++j) {
    rsi.bilateral_level[j] = bilateral_level[j];
    bits += (bilateral_level[j] >= 0) ? (bilateral_level_bits + 1) : 1;
  }
  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
  cost_bilateral = RDCOST_DBL(x->rdmult, x->rddiv,
                              (bits << (VP10_PROB_COST_SHIFT - 4)), err);

  vpx_free(rsi.bilateral_level);

  vpx_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  if (cost_bilateral < cost_norestore) {
    if (best_cost_ret) *best_cost_ret = cost_bilateral;
    return 1;
  } else {
    if (best_cost_ret) *best_cost_ret = cost_norestore;
    return 0;
  }
}

static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
135
                                         VP10_COMP *cpi, int partial_frame,
136
                                         int *filter_best, int *bilateral_level,
137
138
139
140
141
142
                                         double *best_cost_ret) {
  const VP10_COMMON *const cm = &cpi->common;
  const struct loopfilter *const lf = &cm->lf;
  const int min_filter_level = 0;
  const int max_filter_level = vp10_get_max_filter_level(cpi);
  int filt_direction = 0;
143
  int filt_best;
144
  double best_err;
145
146
147
148
149
150
  int i, j;
  int *tmp_level;
  int bilateral_success[MAX_LOOP_FILTER + 1];

  const int ntiles =
      vp10_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
151
152
153
154
155
156
157

  // Start the search at the previous frame filter level unless it is now out of
  // range.
  int filt_mid = clamp(lf->filter_level, min_filter_level, max_filter_level);
  int filter_step = filt_mid < 16 ? 4 : filt_mid / 4;
  double ss_err[MAX_LOOP_FILTER + 1];
  // Set each entry to -1
158
  for (i = 0; i <= MAX_LOOP_FILTER; ++i) ss_err[i] = -1.0;
159

160
161
162
163
  tmp_level = (int *)vpx_malloc(sizeof(*tmp_level) * ntiles);

  bilateral_success[filt_mid] = search_bilateral_level(
      sd, cpi, filt_mid, partial_frame, tmp_level, &best_err);
164
165
  filt_best = filt_mid;
  ss_err[filt_mid] = best_err;
166
167
168
  for (j = 0; j < ntiles; ++j) {
    bilateral_level[j] = tmp_level[j];
  }
169
170
171
172
173
174
175
176
177
178
179
180

  while (filter_step > 0) {
    const int filt_high = VPXMIN(filt_mid + filter_step, max_filter_level);
    const int filt_low = VPXMAX(filt_mid - filter_step, min_filter_level);

    // Bias against raising loop filter in favor of lowering it.
    double bias = (best_err / (1 << (15 - (filt_mid / 8)))) * filter_step;

    if ((cpi->oxcf.pass == 2) && (cpi->twopass.section_intra_rating < 20))
      bias = (bias * cpi->twopass.section_intra_rating) / 20;

    // yx, bias less for large block size
181
    if (cm->tx_mode != ONLY_4X4) bias /= 2;
182
183
184
185

    if (filt_direction <= 0 && filt_low != filt_mid) {
      // Get Low filter error score
      if (ss_err[filt_low] < 0) {
186
187
        bilateral_success[filt_low] = search_bilateral_level(
            sd, cpi, filt_low, partial_frame, tmp_level, &ss_err[filt_low]);
188
189
190
      }
      // If value is close to the best so far then bias towards a lower loop
      // filter value.
191
      if (ss_err[filt_low] < (best_err + bias)) {
192
193
194
195
196
        // Was it actually better than the previous best?
        if (ss_err[filt_low] < best_err) {
          best_err = ss_err[filt_low];
        }
        filt_best = filt_low;
197
198
199
        for (j = 0; j < ntiles; ++j) {
          bilateral_level[j] = tmp_level[j];
        }
200
201
202
203
204
205
      }
    }

    // Now look at filt_high
    if (filt_direction >= 0 && filt_high != filt_mid) {
      if (ss_err[filt_high] < 0) {
206
207
        bilateral_success[filt_high] = search_bilateral_level(
            sd, cpi, filt_high, partial_frame, tmp_level, &ss_err[filt_high]);
208
      }
209
210
      // If value is significantly better than previous best, bias added against
      // raising filter value
211
212
213
      if (ss_err[filt_high] < (best_err - bias)) {
        best_err = ss_err[filt_high];
        filt_best = filt_high;
214
215
216
        for (j = 0; j < ntiles; ++j) {
          bilateral_level[j] = tmp_level[j];
        }
217
218
219
220
221
222
223
224
225
226
227
228
      }
    }

    // Half the step distance if the best filter value was the same as last time
    if (filt_best == filt_mid) {
      filter_step /= 2;
      filt_direction = 0;
    } else {
      filt_direction = (filt_best < filt_mid) ? -1 : 1;
      filt_mid = filt_best;
    }
  }
229

230
231
  vpx_free(tmp_level);

232
233
234
  // Update best error
  best_err = ss_err[filt_best];

235
  if (best_cost_ret) *best_cost_ret = best_err;
236
237
238
  if (filter_best) *filter_best = filt_best;

  return bilateral_success[filt_best];
239
240
}

241
242
// Returns the mean of the 8-bit samples src[] over rows [v_start, v_end) and
// columns [h_start, h_end). Consecutive rows are |stride| samples apart.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row, col;

  for (row = v_start; row < v_end; ++row) {
    const int base = row * stride;
    for (col = h_start; col < h_end; ++col) total += src[base + col];
  }
  return (double)total / count;
}

static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
255
256
  int i, j, k, l;
  double Y[RESTORATION_WIN2];
257
258
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
259
260
261

  memset(M, 0, sizeof(*M) * RESTORATION_WIN2);
  memset(H, 0, sizeof(*H) * RESTORATION_WIN2 * RESTORATION_WIN2);
262
263
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -RESTORATION_HALFWIN; k <= RESTORATION_HALFWIN; k++) {
        for (l = -RESTORATION_HALFWIN; l <= RESTORATION_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < RESTORATION_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * RESTORATION_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < RESTORATION_WIN2; ++l) {
          double value = Y[k] * Y[l];
          H[k * RESTORATION_WIN2 + l] += value;
          H[l * RESTORATION_WIN2 + k] += value;
        }
      }
    }
  }
}

#if CONFIG_VP9_HIGHBITDEPTH
286
287
// Returns the mean of the 16-bit samples src[] over rows [v_start, v_end) and
// columns [h_start, h_end). Consecutive rows are |stride| samples apart.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row = v_start;

  while (row < v_end) {
    const int base = row * stride;
    int col;
    for (col = h_start; col < h_end; ++col) total += src[base + col];
    ++row;
  }
  return (double)total / count;
}

static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
301
302
303
304
  int i, j, k, l;
  double Y[RESTORATION_WIN2];
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
305
306
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
307
308
309

  memset(M, 0, sizeof(*M) * RESTORATION_WIN2);
  memset(H, 0, sizeof(*H) * RESTORATION_WIN2 * RESTORATION_WIN2);
310
311
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -RESTORATION_HALFWIN; k <= RESTORATION_HALFWIN; k++) {
        for (l = -RESTORATION_HALFWIN; l <= RESTORATION_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < RESTORATION_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * RESTORATION_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < RESTORATION_WIN2; ++l) {
          double value = Y[k] * Y[l];
          H[k * RESTORATION_WIN2 + l] += value;
          H[l * RESTORATION_WIN2 + k] += value;
        }
      }
    }
  }
}
#endif  // CONFIG_VP9_HIGHBITDEPTH

// Solves A x = b by Gaussian elimination with partial pivoting.
//  - A is an n x n row-major matrix whose rows are |stride| doubles apart.
//  - b and x are length-n column vectors.
//  - A and b are overwritten: A is reduced to upper-triangular form.
// Returns 1 on success. Returns 0 if a pivot of magnitude below 1e-10 is met
// (A singular or numerically ill-conditioned); x is then not fully written.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Partial pivoting: one bubble pass from the bottom moves the row with
    // the largest |A[.][k]| among rows k..n-1 up to row k.
    for (i = n - 1; i > k; i--) {
      if (fabs(A[(i - 1) * stride + k]) < fabs(A[i * stride + k])) {
        for (j = 0; j < n; j++) {
          c = A[i * stride + j];
          A[i * stride + j] = A[(i - 1) * stride + j];
          A[(i - 1) * stride + j] = c;
        }
        c = b[i];
        b[i] = b[i - 1];
        b[i - 1] = c;
      }
    }
    // Reject a (near-)zero pivot before dividing by it. Otherwise inf/NaN
    // would spread through A and b, and NaN slips past the fabs() test below.
    if (fabs(A[k * stride + k]) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

// Folds a tap index of the symmetric RESTORATION_WIN-tap filter onto the first
// half. Indices below RESTORATION_HALFWIN1 are returned unchanged; larger ones
// map to their mirror RESTORATION_WIN - 1 - i.
static INLINE int wrap_index(int i) {
  if (i < RESTORATION_HALFWIN1) return i;
  return RESTORATION_WIN - 1 - i;
}

// Solves the folded least-squares system for one factor of the symmetric
// separable filter and writes the full RESTORATION_WIN-tap result to |out|.
//
// |A| (RESTORATION_HALFWIN1 entries) and |B| (RESTORATION_HALFWIN1 square,
// row stride RESTORATION_HALFWIN1) describe the objective over the folded
// taps. The last folded index is the center tap. The steps are:
//  1. Substitute the unit-sum constraint center = 1 - 2 * sum(outer taps).
//  2. Solve the reduced (RESTORATION_HALFWIN1 - 1)-dimensional system.
//  3. Unfold the taps symmetrically.
// |A| and |B| are modified in place. If the solve fails, |out| is untouched.
static void solve_sym_filter(double *A, double *B, double *out) {
  const int w = RESTORATION_WIN;
  const int w2 = RESTORATION_HALFWIN1;
  const int c = w2 - 1;  // folded index of the center tap
  double S[RESTORATION_WIN];
  int i, j;

  // Normalization enforcement in the system of equations itself
  for (i = 0; i < c; ++i)
    A[i] -= A[c] * 2 + B[i * w2 + c] - 2 * B[c * w2 + c];
  for (i = 0; i < c; ++i)
    for (j = 0; j < c; ++j)
      B[i * w2 + j] -=
          2 * (B[i * w2 + c] + B[c * w2 + j] - 2 * B[c * w2 + c]);
  if (linsolve(c, B, w2, A, S)) {
    S[c] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[c] -= 2 * S[i];
    }
    memcpy(out, S, w * sizeof(*out));
  }
}

// Fix vector b, update vector a.
// a is the factor applied along the second (fast) index of M, i.e. Mc[i][j]
// with j; b is applied along the first.
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j, k, l;
  double A[RESTORATION_WIN], B[RESTORATION_WIN2];

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < RESTORATION_WIN; i++) {
    for (j = 0; j < RESTORATION_WIN; ++j) {
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
  for (i = 0; i < RESTORATION_WIN; i++) {
    for (j = 0; j < RESTORATION_WIN; j++) {
      for (k = 0; k < RESTORATION_WIN; ++k) {
        const int kk = wrap_index(k);
        for (l = 0; l < RESTORATION_WIN; ++l) {
          const int ll = wrap_index(l);
          B[ll * RESTORATION_HALFWIN1 + kk] +=
              Hc[j * RESTORATION_WIN + i][k * RESTORATION_WIN2 + l] * b[i] *
              b[j];
        }
      }
    }
  }
  solve_sym_filter(A, B, a);
}

// Fix vector a, update vector b.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j, k, l;
  double A[RESTORATION_WIN], B[RESTORATION_WIN2];

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < RESTORATION_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < RESTORATION_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }
  for (i = 0; i < RESTORATION_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < RESTORATION_WIN; j++) {
      const int jj = wrap_index(j);
      for (k = 0; k < RESTORATION_WIN; ++k)
        for (l = 0; l < RESTORATION_WIN; ++l)
          B[jj * RESTORATION_HALFWIN1 + ii] +=
              Hc[i * RESTORATION_WIN + j][k * RESTORATION_WIN2 + l] * a[k] *
              a[l];
    }
  }
  solve_sym_filter(A, B, b);
}

// Fits a separable, symmetric filter a x b (RESTORATION_WIN taps each) to the
// statistics (M, H). Both factors start from init_filt and are refined by
// alternating updates: a with b held fixed, then b with a held fixed. There
// are 9 rounds in total. Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[RESTORATION_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Mc[RESTORATION_WIN];
  double *Hc[RESTORATION_WIN2];
  int r, s, it;

  // Row views into M and H:
  //  - Mc[r] is row r of the WIN x WIN matrix M;
  //  - Hc[r * WIN + s] starts at offset r * WIN * WIN2 + s * WIN of H.
  for (r = 0; r < RESTORATION_WIN; r++) {
    Mc[r] = M + r * RESTORATION_WIN;
    for (s = 0; s < RESTORATION_WIN; s++)
      Hc[r * RESTORATION_WIN + s] =
          H + r * RESTORATION_WIN * RESTORATION_WIN2 + s * RESTORATION_WIN;
  }
  memcpy(a, init_filt, sizeof(*a) * RESTORATION_WIN);
  memcpy(b, init_filt, sizeof(*b) * RESTORATION_WIN);

  for (it = 1; it < 10; it++) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

// Computes the function x'*A*x - x'*b for the learned filters, and compares
// against identity filters; Final score is defined as the difference between
// the function values
497
static double compute_score(double *M, double *H, int *vfilt, int *hfilt) {
Aamir Anis's avatar
Aamir Anis committed
498
499
500
501
502
503
504
505
506
507
  double ab[RESTORATION_WIN * RESTORATION_WIN];
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
  int w;
  double a[RESTORATION_WIN], b[RESTORATION_WIN];
  w = RESTORATION_WIN;
  a[RESTORATION_HALFWIN] = b[RESTORATION_HALFWIN] = 1.0;
  for (i = 0; i < RESTORATION_HALFWIN; ++i) {
508
509
510
511
    a[i] = a[RESTORATION_WIN - i - 1] =
        (double)vfilt[i] / RESTORATION_FILT_STEP;
    b[i] = b[RESTORATION_WIN - i - 1] =
        (double)hfilt[i] / RESTORATION_FILT_STEP;
Aamir Anis's avatar
Aamir Anis committed
512
513
514
515
516
517
518
519
520
521
    a[RESTORATION_HALFWIN] -= 2 * a[i];
    b[RESTORATION_HALFWIN] -= 2 * b[i];
  }
  for (k = 0; k < w; ++k) {
    for (l = 0; l < w; ++l) {
      ab[k * w + l] = a[l] * b[k];
    }
  }
  for (k = 0; k < w * w; ++k) {
    P += ab[k] * M[k];
522
    for (l = 0; l < w * w; ++l) Q += ab[k] * H[k * w * w + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
523
524
525
526
527
528
529
530
531
532
  }
  Score = Q - 2 * P;

  iP = M[(w * w) >> 1];
  iQ = H[((w * w) >> 1) * w * w + ((w * w) >> 1)];
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

533
// Clamps x to [lo, hi]. Arguments may be evaluated more than once.
#define CLIP(x, lo, hi) (((x) < (lo)) ? (lo) : (((x) > (hi)) ? (hi) : (x)))
// Rounds x to the nearest int, halves away from zero. x is evaluated twice.
#define RINT(x) (((x) < 0) ? (int)((x)-0.5) : (int)((x) + 0.5))

// Converts one real-valued tap, already scaled to integer filter units, to an
// int in [lo, hi], rounding to nearest.
//  - Clamping happens before the double -> int conversion, so the conversion
//    is always in range. Converting an out-of-range or NaN double to int is
//    undefined behavior in C.
//  - NaN is treated as 0 before clamping.
//  - For finite inputs the result equals clamping the rounded value.
static int quantize_tap(double v, int lo, int hi) {
  if (isnan(v)) v = 0.0;
  v = CLIP(v, (double)lo, (double)hi);
  return RINT(v);
}

// Quantizes the outer taps of the symmetric filter |f| to multiples of
// 1 / RESTORATION_FILT_STEP, storing them in |fi|. Each tap is clamped to its
// [WIENER_FILT_TAPn_MINV, WIENER_FILT_TAPn_MAXV] range. Specialized for the
// 7-tap filter (RESTORATION_HALFWIN == 3).
static void quantize_sym_filter(double *f, int *fi) {
  assert(RESTORATION_HALFWIN == 3);
  fi[0] = quantize_tap(f[0] * RESTORATION_FILT_STEP, WIENER_FILT_TAP0_MINV,
                       WIENER_FILT_TAP0_MAXV);
  fi[1] = quantize_tap(f[1] * RESTORATION_FILT_STEP, WIENER_FILT_TAP1_MINV,
                       WIENER_FILT_TAP1_MAXV);
  fi[2] = quantize_tap(f[2] * RESTORATION_FILT_STEP, WIENER_FILT_TAP2_MINV,
                       WIENER_FILT_TAP2_MAXV);
}

// Searches a per-tile separable Wiener restoration filter for loop filter
// level |filter_level|.
//
// For each WIENER_TILESIZE tile:
//  1. Gather the statistics of the loop-filtered frame against |src|.
//  2. Fit a symmetric separable filter and quantize it.
//  3. Enable the tile (process_tile[tile] = 1) only if the quantized filter
//     scores better than identity and, applied to that tile alone, brings the
//     RD cost below no restoration.
// The combined configuration is then costed. Taps of enabled tiles are copied
// to |vfilter| / |hfilter|; entries of disabled tiles are not written.
//
// On return:
//  - the luma plane of cm->frame_to_show is restored from cpi->last_frame_uf;
//  - *best_cost_ret (if non-NULL) receives the cheaper of the Wiener cost and
//    the no-restoration cost.
// Returns 1 when Wiener filtering is cheaper than no restoration, else 0.
static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, VP10_COMP *cpi,
                                int filter_level, int partial_frame,
                                int (*vfilter)[RESTORATION_HALFWIN],
                                int (*hfilter)[RESTORATION_HALFWIN],
                                int *process_tile, double *best_cost_ret) {
  VP10_COMMON *const cm = &cpi->common;
  RestorationInfo rsi;
  int64_t err;
  int bits;
  double cost_wiener, cost_norestore;
  MACROBLOCK *x = &cpi->td.mb;
  double M[RESTORATION_WIN2];
  double H[RESTORATION_WIN2 * RESTORATION_WIN2];
  double vfilterd[RESTORATION_WIN], hfilterd[RESTORATION_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
  double score;
  int tile_idx, htile_idx, vtile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int i;

  const int tilesize = WIENER_TILESIZE;
  const int ntiles = vp10_get_restoration_ntiles(tilesize, width, height);

  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  vp10_get_restoration_tile_size(tilesize, width, height, &tile_width,
                                 &tile_height, &nhtiles, &nvtiles);

  //  Make a copy of the unfiltered / processed recon buffer
  vpx_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  vp10_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                         1, partial_frame);
  vpx_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  // RD cost associated with no restoration
  rsi.restoration_type = RESTORE_NONE;
  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
  bits = 0;
  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv,
                              (bits << (VP10_PROB_COST_SHIFT - 4)), err);

  rsi.restoration_type = RESTORE_WIENER;
  rsi.vfilter =
      (int(*)[RESTORATION_HALFWIN])vpx_malloc(sizeof(*rsi.vfilter) * ntiles);
  rsi.hfilter =
      (int(*)[RESTORATION_HALFWIN])vpx_malloc(sizeof(*rsi.hfilter) * ntiles);
  rsi.wiener_level = (int *)vpx_malloc(sizeof(*rsi.wiener_level) * ntiles);
  if (rsi.vfilter == NULL || rsi.hfilter == NULL || rsi.wiener_level == NULL) {
    // Out of memory: report "no restoration" for every tile instead of
    // dereferencing NULL, leaving the frame as the normal path would.
    vpx_free(rsi.vfilter);
    vpx_free(rsi.hfilter);
    vpx_free(rsi.wiener_level);
    for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
      process_tile[tile_idx] = 0;
    vpx_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
    if (best_cost_ret) *best_cost_ret = cost_norestore;
    return 0;
  }

  // All tiles start disabled. In the per-tile trials below only the tile
  // under test is switched on (and switched off again afterwards), so each
  // trial sees all other tiles unfiltered.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi.wiener_level[tile_idx] = 0;

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    htile_idx = tile_idx % nhtiles;
    vtile_idx = tile_idx / nhtiles;
    // Border tiles are inset by RESTORATION_HALFWIN so the statistics window
    // stays inside the frame.
    h_start =
        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
    h_end = (htile_idx < nhtiles - 1) ? ((htile_idx + 1) * tile_width)
                                      : (width - RESTORATION_HALFWIN);
    v_start =
        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
    v_end = (vtile_idx < nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
                                      : (height - RESTORATION_HALFWIN);

#if CONFIG_VP9_HIGHBITDEPTH
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
#endif  // CONFIG_VP9_HIGHBITDEPTH
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      for (i = 0; i < RESTORATION_HALFWIN; ++i)
        rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
      process_tile[tile_idx] = 0;
      continue;
    }
    quantize_sym_filter(vfilterd, rsi.vfilter[tile_idx]);
    quantize_sym_filter(hfilterd, rsi.hfilter[tile_idx]);
    process_tile[tile_idx] = 1;

    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filter. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi.vfilter[tile_idx], rsi.hfilter[tile_idx]);
    if (score > 0.0) {
      for (i = 0; i < RESTORATION_HALFWIN; ++i)
        rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
      process_tile[tile_idx] = 0;
      continue;
    }

    rsi.wiener_level[tile_idx] = 1;
    err = try_restoration_frame(src, cpi, &rsi, partial_frame);
    rsi.wiener_level[tile_idx] = 0;
    bits = 1 + WIENER_FILT_BITS;
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv,
                             (bits << (VP10_PROB_COST_SHIFT - 4)), err);
    if (cost_wiener >= cost_norestore) process_tile[tile_idx] = 0;
  }

  // Cost for Wiener filtering: enabled tiles pay the filter bits plus one flag
  // bit, disabled tiles only the flag bit.
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits += (process_tile[tile_idx] ? (WIENER_FILT_BITS + 1) : 1);
    rsi.wiener_level[tile_idx] = process_tile[tile_idx];
  }
  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv,
                           (bits << (VP10_PROB_COST_SHIFT - 4)), err);

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    if (process_tile[tile_idx] == 0) continue;
    for (i = 0; i < RESTORATION_HALFWIN; ++i) {
      vfilter[tile_idx][i] = rsi.vfilter[tile_idx][i];
      hfilter[tile_idx][i] = rsi.hfilter[tile_idx][i];
    }
  }

  vpx_free(rsi.vfilter);
  vpx_free(rsi.hfilter);
  vpx_free(rsi.wiener_level);

  vpx_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  if (cost_wiener < cost_norestore) {
    if (best_cost_ret) *best_cost_ret = cost_wiener;
    return 1;
  } else {
    if (best_cost_ret) *best_cost_ret = cost_norestore;
    return 0;
  }
}

void vp10_pick_filter_restoration(const YV12_BUFFER_CONFIG *sd, VP10_COMP *cpi,
                                  LPF_PICK_METHOD method) {
690
691
  VP10_COMMON *const cm = &cpi->common;
  struct loopfilter *const lf = &cm->lf;
692
  int wiener_success = 0;
693
  int bilateral_success = 0;
694
695
696
  double cost_bilateral = DBL_MAX;
  double cost_wiener = DBL_MAX;
  double cost_norestore = DBL_MAX;
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
  int ntiles;

  ntiles =
      vp10_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
  cm->rst_info.bilateral_level =
      (int *)vpx_realloc(cm->rst_info.bilateral_level,
                         sizeof(*cm->rst_info.bilateral_level) * ntiles);
  assert(cm->rst_info.bilateral_level != NULL);

  ntiles = vp10_get_restoration_ntiles(WIENER_TILESIZE, cm->width, cm->height);
  cm->rst_info.wiener_level = (int *)vpx_realloc(
      cm->rst_info.wiener_level, sizeof(*cm->rst_info.wiener_level) * ntiles);
  assert(cm->rst_info.wiener_level != NULL);
  cm->rst_info.vfilter = (int(*)[RESTORATION_HALFWIN])vpx_realloc(
      cm->rst_info.vfilter, sizeof(*cm->rst_info.vfilter) * ntiles);
  assert(cm->rst_info.vfilter != NULL);
  cm->rst_info.hfilter = (int(*)[RESTORATION_HALFWIN])vpx_realloc(
      cm->rst_info.hfilter, sizeof(*cm->rst_info.hfilter) * ntiles);
  assert(cm->rst_info.hfilter != NULL);
716

717
  lf->sharpness_level = cm->frame_type == KEY_FRAME ? 0 : cpi->oxcf.sharpness;
718
719

  if (method == LPF_PICK_MINIMAL_LPF && lf->filter_level) {
720
721
    lf->filter_level = 0;
    cm->rst_info.restoration_type = RESTORE_NONE;
722
723
724
725
  } else if (method >= LPF_PICK_FROM_Q) {
    const int min_filter_level = 0;
    const int max_filter_level = vp10_get_max_filter_level(cpi);
    const int q = vp10_ac_quant(cm->base_qindex, 0, cm->bit_depth);
726
727
// These values were determined by linear fitting the result of the
// searched level, filt_guess = q * 0.316206 + 3.87252
728
729
730
731
732
733
734
735
736
737
738
739
740
#if CONFIG_VP9_HIGHBITDEPTH
    int filt_guess;
    switch (cm->bit_depth) {
      case VPX_BITS_8:
        filt_guess = ROUND_POWER_OF_TWO(q * 20723 + 1015158, 18);
        break;
      case VPX_BITS_10:
        filt_guess = ROUND_POWER_OF_TWO(q * 20723 + 4060632, 20);
        break;
      case VPX_BITS_12:
        filt_guess = ROUND_POWER_OF_TWO(q * 20723 + 16242526, 22);
        break;
      default:
741
742
743
        assert(0 &&
               "bit_depth should be VPX_BITS_8, VPX_BITS_10 "
               "or VPX_BITS_12");
744
745
746
747
748
        return;
    }
#else
    int filt_guess = ROUND_POWER_OF_TWO(q * 20723 + 1015158, 18);
#endif  // CONFIG_VP9_HIGHBITDEPTH
749
    if (cm->frame_type == KEY_FRAME) filt_guess -= 4;
750
    lf->filter_level = clamp(filt_guess, min_filter_level, max_filter_level);
751
    bilateral_success = search_bilateral_level(
752
        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
753
        cm->rst_info.bilateral_level, &cost_bilateral);
754
755
    wiener_success = search_wiener_filter(
        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
756
757
        cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
        &cost_wiener);
758
    if (cost_bilateral < cost_wiener) {
759
      if (bilateral_success)
760
761
762
763
764
765
766
767
768
769
770
        cm->rst_info.restoration_type = RESTORE_BILATERAL;
      else
        cm->rst_info.restoration_type = RESTORE_NONE;
    } else {
      if (wiener_success)
        cm->rst_info.restoration_type = RESTORE_WIENER;
      else
        cm->rst_info.restoration_type = RESTORE_NONE;
    }
  } else {
    int blf_filter_level = -1;
771
772
773
    bilateral_success = search_filter_bilateral_level(
        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &blf_filter_level,
        cm->rst_info.bilateral_level, &cost_bilateral);
774
775
776
777
    lf->filter_level = vp10_search_filter_level(
        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &cost_norestore);
    wiener_success = search_wiener_filter(
        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
778
779
        cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
        &cost_wiener);
780
781
    if (cost_bilateral < cost_wiener) {
      lf->filter_level = blf_filter_level;
782
      if (bilateral_success)
783
784
785
786
787
788
789
790
791
        cm->rst_info.restoration_type = RESTORE_BILATERAL;
      else
        cm->rst_info.restoration_type = RESTORE_NONE;
    } else {
      if (wiener_success)
        cm->rst_info.restoration_type = RESTORE_WIENER;
      else
        cm->rst_info.restoration_type = RESTORE_NONE;
    }
792
    // printf("[%d] Costs %g %g (%d) %g (%d)\n", cm->rst_info.restoration_type,
793
794
795
796
797
798
799
800
801
802
803
804
805
806
    //        cost_norestore, cost_bilateral, lf->filter_level, cost_wiener,
    //        wiener_success);
  }
  if (cm->rst_info.restoration_type != RESTORE_BILATERAL) {
    vpx_free(cm->rst_info.bilateral_level);
    cm->rst_info.bilateral_level = NULL;
  }
  if (cm->rst_info.restoration_type != RESTORE_WIENER) {
    vpx_free(cm->rst_info.vfilter);
    cm->rst_info.vfilter = NULL;
    vpx_free(cm->rst_info.hfilter);
    cm->rst_info.hfilter = NULL;
    vpx_free(cm->rst_info.wiener_level);
    cm->rst_info.wiener_level = NULL;
807
808
  }
}