pickrst.c 42.8 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20
21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24
25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29
30
31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
34
35
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
36
                                      RestorationType *rest_level,
37
38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
// Frame-level signalling cost in bits, indexed by the frame restoration type.
// Callers in this file shift the value left by AV1_PROB_COST_SHIFT before
// adding it to a rate estimate.
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
41
42

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
43
44
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
45
                                    int width, int v_start, int height,
46
47
                                    int components_pattern) {
  int64_t filt_err = 0;
48
49
50
51
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
52
53
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
54
55
56
57
58
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
59
60
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
61
62
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
63
64
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
65
66
    }
    return filt_err;
67
68
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
69
70
71
72
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
73
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
74
75
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
76
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
77
  }
78
79
80
  return filt_err;
}

81
82
// Returns the whole-frame sum of squared errors between `src` and `dst`,
// accumulated over every plane whose bit (AOM_PLANE_Y / AOM_PLANE_U /
// AOM_PLANE_V) is set in `components_pattern`.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  int64_t filt_err = 0;
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH
  // Every plane accumulates with +=, matching the high-bit-depth branch.
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

114
115
// Runs loop restoration with the settings in `rsi` on cm->frame_to_show,
// writing the result to `dst_frame`, and returns the SSE against `src`
// measured only inside the tile (or sub-tile) selected by `tile_idx`,
// `subtile_idx` and `subtile_bits`.
//
// `components_pattern` follows the same plane-bitmask convention as
// sse_restoration_tile(): 1 selects luma, 2/4/6 select chroma.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int64_t filt_err;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  if (components_pattern == 1) {  // Y only
    width = src->y_crop_width;
    height = src->y_crop_height;
  } else {  // Color
    width = src->uv_crop_width;
    height = src->uv_crop_height;
  }

  // Only the tile geometry outputs are needed here; the tile count that the
  // call returns is discarded. rst_info[0] is used for luma, rst_info[1] for
  // chroma.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                                  v_start, v_end - v_start, components_pattern);

  return filt_err;
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
154
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
155
                                     int components_pattern, int partial_frame,
156
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
157
  AV1_COMMON *const cm = &cpi->common;
158
  int64_t filt_err;
159
160
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
161
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
162
163
164
  return filt_err;
}

165
166
167
168
169
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
170
171
172
173
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
174
175
176
177
178
179
180
181
182
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
183
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
199
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
200
201
202
203
204
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
205
206
207
208
209
    }
  }
  return err;
}

210
211
212
213
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
214
215
216
217
218
219
220
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

221
222
223
  // Default
  xq[0] = 0;
  xq[1] = 0;
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
274
  xqd[0] = xq[0];
275
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
276
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
277
278
279
280
281
282
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
283
284
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
285
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
286
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
287
  int ep, bestep = 0;
288
289
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
290

291
292
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
293
#if CONFIG_AOM_HIGHBITDEPTH
294
295
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
296
297
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
298
                                 sgr_params[ep].corner, sgr_params[ep].edge);
299
#else
300
301
302
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
303
#endif  // USE_HIGHPASS_IN_SGRPROJ
304
305
306
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
307
    } else {
308
#endif
309
310
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
311
                          sgr_params[ep].corner, sgr_params[ep].edge);
312
313
314
315
#else
    av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
#endif  // USE_HIGHPASS_IN_SGRPROJ
316
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
317
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
318
#if CONFIG_AOM_HIGHBITDEPTH
319
    }
320
#endif
321
322
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
323
    encode_xq(exq, exqd);
324
325
326
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
327
328
329
330
331
332
333
334
335
336
337
338
339
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Searches the self-guided projection filter for the luma plane.
//
// For every restoration tile the best parameters are found with
// search_selfguided_restoration() and the tile is restored on its own; the
// rate-distortion cost of doing so is compared with leaving the tile
// unrestored. The per-tile decision is written to type[tile_idx], the chosen
// parameters to info->sgrproj_info[tile_idx], and best_tile_cost[tile_idx]
// receives the tile's cost when restoration wins (DBL_MAX otherwise).
//
// Finally the whole frame is restored using the per-tile decisions and the
// resulting frame-level rate-distortion cost is returned. cpi->rst_search[0]
// is used as the working RestorationInfo throughout.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);
  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  // Start with every tile disabled so each tile can be evaluated in isolation.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[tile_idx].ep, rsi->sgrproj_info[tile_idx].xqd,
        cm->rst_internal.tmpbuf);
    // Enable only this tile, measure it, then disable it again below.
    rsi->restoration_type[tile_idx] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_SGRPROJ;
      memcpy(&sgrproj_info[tile_idx], &rsi->sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Sgrproj filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (type[tile_idx] == RESTORE_SGRPROJ) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_sgrproj;
}

421
422
// Returns the mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of `src`, where consecutive rows are `stride` samples
// apart. Returns 0 for an empty (or inverted) rectangle instead of dividing
// by zero.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  // Count in 64 bits so rows * columns cannot overflow int.
  const int64_t count = (int64_t)(v_end - v_start) * (h_end - h_start);
  int i, j;
  if (v_end <= v_start || h_end <= h_start) return 0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  return (double)sum / (double)count;
}

432
433
434
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
435
  int i, j, k, l;
436
  double Y[WIENER_WIN2];
437
438
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
439

440
441
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
442
443
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
444
445
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
446
447
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
448
449
450
451
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
452
      for (k = 0; k < WIENER_WIN2; ++k) {
453
        M[k] += Y[k] * X;
454
455
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
456
457
458
459
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
460
461
462
463
        }
      }
    }
  }
464
465
466
467
468
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
469
470
}

Yaowu Xu's avatar
Yaowu Xu committed
471
#if CONFIG_AOM_HIGHBITDEPTH
472
473
// Returns the mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of `src`, where consecutive rows are `stride` samples
// apart. Returns 0 for an empty (or inverted) rectangle instead of dividing
// by zero.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  // Count in 64 bits so rows * columns cannot overflow int.
  const int64_t count = (int64_t)(v_end - v_start) * (h_end - h_start);
  int i, j;
  if (v_end <= v_start || h_end <= h_start) return 0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  return (double)sum / (double)count;
}

483
484
485
486
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
487
  int i, j, k, l;
488
  double Y[WIENER_WIN2];
489
490
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
491
492
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
493

494
495
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
496
497
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
498
499
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
500
501
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
502
503
504
505
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
506
      for (k = 0; k < WIENER_WIN2; ++k) {
507
        M[k] += Y[k] * X;
508
509
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
510
511
512
513
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
514
515
516
517
        }
      }
    }
  }
518
519
520
521
522
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
523
}
Yaowu Xu's avatar
Yaowu Xu committed
524
#endif  // CONFIG_AOM_HIGHBITDEPTH
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546

// Solves Ax = b by Gaussian elimination with partial pivoting, where x and b
// are column vectors of length n and A is an n x n matrix stored row-major
// with `stride` doubles between the starts of consecutive rows.
//
// A and b are overwritten during elimination. Returns 1 and fills x[0..n-1]
// on success; returns 0 (x unspecified) when a pivot has magnitude below
// 1e-10, i.e. the system is singular or too ill-conditioned to solve.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Partial pivoting: bubble the entry of largest magnitude in column k
    // (among rows k..n-1) up to the diagonal, swapping b along with the rows.
    for (i = n - 1; i > k; i--) {
      if (fabs(A[(i - 1) * stride + k]) < fabs(A[i * stride + k])) {
        for (j = 0; j < n; j++) {
          c = A[i * stride + j];
          A[i * stride + j] = A[(i - 1) * stride + j];
          A[(i - 1) * stride + j] = c;
        }
        c = b[i];
        b[i] = b[i - 1];
        b[i - 1] = c;
      }
    }
    // Even the best available pivot is (numerically) zero: give up before
    // dividing by it.
    if (fabs(A[k * stride + k]) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
562
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
563
564
565
566
567
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
568
569
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
570
  int w, w2;
571
572
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
573
574
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
575
576
577
578
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
579
580
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
581
      int k, l;
582
583
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
584
585
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
586
587
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
588
589
590
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
591
  // Normalization enforcement in the system of equations itself
592
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
593
594
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
595
596
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
597
598
599
600
601
602
603
604
605
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
606
    }
Aamir Anis's avatar
Aamir Anis committed
607
    memcpy(a, S, w * sizeof(*a));
608
609
610
611
612
613
  }
}

// Fix vector a, update vector b
//
// Builds a reduced linear system for the symmetric filter `b` (mirrored taps
// are folded together with wrap_index()) from the per-row views `Mc` / `Hc`
// of the statistics and the current `a`, folds the normalization constraint
// into the system, and solves it with linsolve(). `b` is overwritten with
// the full WIENER_WIN-tap symmetric result only when the solve succeeds;
// otherwise it is left unchanged.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
  int w, w2;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }

  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      const int ii = wrap_index(i);
      const int jj = wrap_index(j);
      int k, l;
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          B[jj * WIENER_HALFWIN1 + ii] +=
              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
    }
  }
  // Normalization enforcement in the system of equations itself
  w = WIENER_WIN;
  w2 = WIENER_HALFWIN1;
  for (i = 0; i < w2 - 1; ++i)
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    // Rebuild the full filter: mirror the solved taps and choose the centre
    // tap so that all taps sum to 1.
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
    }
    memcpy(b, S, w * sizeof(*b));
  }
}

655
656
// Fits a separable symmetric filter pair (a, b) to the statistics M and H by
// alternating updates: both filters start from a fixed initial kernel, then
// update_a_sep_sym() and update_b_sep_sym() are applied in turn for nine
// rounds. `a` and `b` each receive WIENER_WIN taps. Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  int i, j, iter;
  // Row-pointer views into M and H as indexed by the update routines.
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  for (i = 0; i < WIENER_WIN; i++) {
    Mc[i] = M + i * WIENER_WIN;
    for (j = 0; j < WIENER_WIN; j++) {
      Hc[i * WIENER_WIN + j] =
          H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
    }
  }
  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  for (iter = 1; iter < 10; iter++) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

682
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
683
684
// against identity filters; Final score is defined as the difference between
// the function values
685
686
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
687
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
688
689
690
691
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
692
693
694
695
696
697
698
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
699
  }
700
701
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
702
  }
703
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
704
    P += ab[k] * M[k];
705
706
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
707
708
709
  }
  Score = Q - 2 * P;

710
711
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
712
713
714
715
716
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

717
// Quantizes the symmetric 7-tap filter `f` into the integer kernel `fi`.
//
// The first three taps are scaled by WIENER_FILT_STEP, rounded, and clipped
// to their per-tap ranges; they are then mirrored onto the last three taps.
// The centre tap is stored as -2 * (fi[0] + fi[1] + fi[2]) because an
// implicit +WIENER_FILT_STEP is added to it elsewhere.
static void quantize_sym_filter(double *f, InterpKernel fi) {
  int i;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    fi[i] = RINT(f[i] * WIENER_FILT_STEP);
  }
  // Specialize for 7-tap filter
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
735
                               int partial_frame, int plane,
736
                               RestorationInfo *info, RestorationType *type,
737
738
739
740
741
742
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
743
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
744
745
746
747
748
749
750
751
752
753
754
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
755
  int h_start, h_end, v_start, v_end;
756
757
758
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
759
760
761
762
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
763
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
764
  bits = 0;
765
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
766
767

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
768

769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
790
    if (plane == AOM_PLANE_U) {
791
792
793
794
795
796
797
798
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
799
    } else if (plane == AOM_PLANE_V) {
800
801
802
803
804
805
806
807
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
808
809
810
    } else {
      assert(0);
    }
811
812
813
814
815
816

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
817
    }
818
819
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
820

821
822
823
824
825
826
827
828
829
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filer. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
830

831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
857
  }
858
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
859
                              dst_frame);
860
861
862
863
864
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
865
866
867
    info->frame_restoration_type = RESTORE_NONE;
  }

868
869
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
870
871
}

872
// Searches per-tile Wiener filters for the luma plane and returns the RD cost
// of coding the frame with RESTORE_WIENER as the frame-level mode.
//
//   src            - source (original) frame.
//   cpi            - encoder instance; the first entry of cpi->rst_search is
//                    used as scratch and is overwritten (frame type, per-tile
//                    types and per-tile filters).
//   partial_frame  - forwarded unchanged to try_restoration_tile() and
//                    try_restoration_frame().
//   info           - output: info->wiener_info[tile] receives the filter of
//                    every tile that selects Wiener.
//   type           - output: per-tile decision (RESTORE_WIENER/RESTORE_NONE).
//   best_tile_cost - output: per-tile RD cost of Wiener, including
//                    cpi->switchable_restore_cost[RESTORE_WIENER]; DBL_MAX for
//                    tiles where Wiener was rejected.
//   dst_frame      - forwarded unchanged to the try_restoration_*() helpers.
//
// Side effect: the luma plane of cm->frame_to_show is passed to
// extend_frame()/extend_frame_highbd() before statistics are gathered.
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
  double cost_wiener, cost_norestore;
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);

  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  rsi->frame_restoration_type = RESTORE_WIENER;

  // Start with every tile off so that each tile can be tried in isolation.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Construct a (WIENER_HALFWIN)-pixel border around the frame
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // Stays at DBL_MAX unless Wiener wins for this tile below.
    best_tile_cost[tile_idx] = DBL_MAX;

    // NOTE(review): this call has the same arguments as the one above (zero
    // margins, because the frame was extended), so it looks redundant -
    // confirm before removing.
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
#if CONFIG_AOM_HIGHBITDEPTH
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
#endif  // CONFIG_AOM_HIGHBITDEPTH
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

    type[tile_idx] = RESTORE_WIENER;

    // No usable separable filter could be derived: leave the tile unrestored.
    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
    quantize_sym_filter(vfilterd, rsi->wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi->wiener_info[tile_idx].hfilter);

    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filter. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi->wiener_info[tile_idx].vfilter,
                          rsi->wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }

    // Try the quantized filter on this tile only and compare RD costs.
    rsi->restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi->wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
      // Per-tile cost is re-derived with the switchable-mode signalling cost
      // in place of the NONE/WIENER bit used for the comparison above.
      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
    }
    // Switch the tile back off so the next tile is evaluated in isolation.
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    // Load the chosen filters and per-tile decisions back into the search
    // state before the whole-frame trial below.
    memcpy(&rsi->wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_wiener;
}

998
static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
999
1000
                               int partial_frame, RestorationInfo *info,
                               RestorationType *type, double *best_tile_cost,
1001
                               YV12_B