pickrst.c 42.5 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20
21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24
25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29
30
31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
34
35
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
36
                                      RestorationType *rest_level,
37
38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
41
42

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
43
44
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
45
                                    int width, int v_start, int height,
46
47
                                    int components_pattern) {
  int64_t filt_err = 0;
48
49
50
51
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
52
53
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
54
55
56
57
58
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
59
60
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
61
62
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
63
64
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
65
66
    }
    return filt_err;
67
68
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
69
70
71
72
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
73
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
74
75
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
76
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
77
  }
78
79
80
  return filt_err;
}

81
82
// Whole-frame sum of squared errors between `src` and `dst`, accumulated over
// the planes whose bit (AOM_PLANE_Y / _U / _V) is set in
// `components_pattern`. `cm` is only read for the high bit depth flag.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t sse = 0;

#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) sse += aom_highbd_get_y_sse(src, dst);
    if (want_u) sse += aom_highbd_get_u_sse(src, dst);
    if (want_v) sse += aom_highbd_get_v_sse(src, dst);
    return sse;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH

  if (want_y) sse += aom_get_y_sse(src, dst);
  if (want_u) sse += aom_get_u_sse(src, dst);
  if (want_v) sse += aom_get_v_sse(src, dst);
  return sse;
}

114
115
// Runs loop restoration with the settings in `rsi` into `dst_frame` and
// returns the squared error against `src`, measured only over the tile (or
// sub-tile) identified by tile_idx / subtile_idx / subtile_bits.
//
// components_pattern selects the planes (1 = Y only, otherwise chroma); it
// also decides whether luma or chroma dimensions and tile size are used.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int luma_only = (components_pattern == 1);
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  width = luma_only ? src->y_crop_width : src->uv_crop_width;
  height = luma_only ? src->y_crop_height : src->uv_crop_height;

  // Only the tile geometry outputs are needed; the tile count is discarded.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);

  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
154
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
155
                                     int components_pattern, int partial_frame,
156
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
157
  AV1_COMMON *const cm = &cpi->common;
158
  int64_t filt_err;
159
160
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
161
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
162
163
164
  return filt_err;
}

165
166
167
168
169
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
170
171
172
173
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
174
175
176
177
178
179
180
181
182
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
183
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
199
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
200
201
202
203
204
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
205
206
207
208
209
    }
  }
  return err;
}

210
211
212
213
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
214
215
216
217
218
219
220
221
222
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

// Converts the projection coefficients xq[0..1] into the clamped pair
// xqd[0..1]:
//   xqd[0] = clamp(-xq[0])                                   to [MIN0, MAX0]
//   xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1]) to [MIN1, MAX1]
// Note that xqd[1] is derived from the already clamped xqd[0].
void encode_xq(int *xq, int *xqd) {
  const int first = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  xqd[0] = first;
  xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) + first - xq[1], SGRPROJ_PRJ_MIN1,
                 SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
282
283
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
284
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
285
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
286
  int ep, bestep = 0;
287
288
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
289

290
291
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
292
#if CONFIG_AOM_HIGHBITDEPTH
293
294
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
295
296
297
298
299
300
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
301
    } else {
302
303
#endif
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
304
                                 sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
305
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
306
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
307
#if CONFIG_AOM_HIGHBITDEPTH
308
    }
309
#endif
310
311
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
312
    encode_xq(exq, exqd);
313
314
315
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
316
317
318
319
320
321
322
323
324
325
326
327
328
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Per-tile search for self-guided projection restoration on the luma plane.
//
// For every tile the cost of leaving it unrestored is compared with the cost
// of the best self-guided filter found by search_selfguided_restoration().
// Results:
//   type[t]            RESTORE_SGRPROJ or RESTORE_NONE for tile t
//   info->sgrproj_info parameters of the tiles that chose RESTORE_SGRPROJ
//   best_tile_cost[t]  cost including the switchable-restore rate, or DBL_MAX
//                      when the tile is left unrestored
// The return value is the rate-distortion cost of restoring the whole frame
// with the per-tile decisions above. cpi->rst_search[0] is used as working
// state and is left holding those decisions.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *const chosen_info = info->sgrproj_info;
  MACROBLOCK *const mb = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;
  RestorationInfo *const rsi = &cpi->rst_search[0];
  double err, cost_norestore, cost_sgrproj;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int rate;
  int t;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  // Start with every tile off so each trial below enables exactly one tile.
  for (t = 0; t < ntiles; ++t) rsi->restoration_type[t] = RESTORE_NONE;

  // Compute best Sgrproj filters for each tile
  for (t = 0; t < ntiles; ++t) {
    av1_get_rest_tile_limits(t, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);

    // Cost when the tile is not restored.
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    rate = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), err);
    best_tile_cost[t] = DBL_MAX;

    // Pick filter parameters for this tile, then measure the real error by
    // running restoration with only this tile enabled.
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[t].ep, rsi->sgrproj_info[t].xqd,
        cm->rst_internal.tmpbuf);
    rsi->restoration_type[t] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, t, 0, 0,
                               dst_frame);
    rate = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    rate += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), err);

    if (cost_sgrproj >= cost_norestore) {
      type[t] = RESTORE_NONE;
    } else {
      const int param_rate = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      type[t] = RESTORE_SGRPROJ;
      memcpy(&chosen_info[t], &rsi->sgrproj_info[t], sizeof(chosen_info[t]));
      best_tile_cost[t] = RDCOST_DBL(
          mb->rdmult, mb->rddiv,
          (param_rate + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4,
          err);
    }

    // Switch the tile back off before trying the next one.
    rsi->restoration_type[t] = RESTORE_NONE;
  }

  // Cost for Sgrproj filtering
  rate = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (t = 0; t < ntiles; ++t) {
    rate += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[t] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[t], &chosen_info[t], sizeof(chosen_info[t]));
    if (type[t] == RESTORE_SGRPROJ) {
      rate += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[t] = type[t];
  }

  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), err);
  return cost_sgrproj;
}

410
411
// Returns the mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of `src`, where `stride` is the distance between rows.
// An empty region (no rows or no columns) yields 0.0 rather than the NaN that
// a 0/0 division would produce.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;

  if (v_end <= v_start || h_end <= h_start) return 0.0;

  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];

  // Form the pixel count in double so a very large region cannot overflow
  // the int product.
  return (double)sum / ((double)(v_end - v_start) * (h_end - h_start));
}

421
422
423
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
424
  int i, j, k, l;
425
  double Y[WIENER_WIN2];
426
427
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
428

429
430
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
431
432
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
433
434
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
435
436
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
437
438
439
440
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
441
      for (k = 0; k < WIENER_WIN2; ++k) {
442
        M[k] += Y[k] * X;
443
444
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
445
446
447
448
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
449
450
451
452
        }
      }
    }
  }
453
454
455
456
457
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
458
459
}

Yaowu Xu's avatar
Yaowu Xu committed
460
#if CONFIG_AOM_HIGHBITDEPTH
461
462
// Returns the mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of `src`, where `stride` is the distance between rows.
// An empty region (no rows or no columns) yields 0.0 rather than the NaN that
// a 0/0 division would produce.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;

  if (v_end <= v_start || h_end <= h_start) return 0.0;

  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];

  // Form the pixel count in double so a very large region cannot overflow
  // the int product.
  return (double)sum / ((double)(v_end - v_start) * (h_end - h_start));
}

472
473
474
475
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
476
  int i, j, k, l;
477
  double Y[WIENER_WIN2];
478
479
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
480
481
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
482

483
484
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
485
486
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
487
488
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
489
490
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
491
492
493
494
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
495
      for (k = 0; k < WIENER_WIN2; ++k) {
496
        M[k] += Y[k] * X;
497
498
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
499
500
501
502
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
503
504
505
506
        }
      }
    }
  }
507
508
509
510
511
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
512
}
Yaowu Xu's avatar
Yaowu Xu committed
513
#endif  // CONFIG_AOM_HIGHBITDEPTH
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535

// Solves Ax = b, where x and b are column vectors.
//
// A is an n x n matrix stored row-major with `stride` doubles between rows.
// Gaussian elimination with partial pivoting is used: A and b are modified
// in place and must be treated as scratch by the caller.
//
// Returns 1 and fills x[0..n-1] on success. Returns 0 if a pivot smaller in
// magnitude than 1e-10 is met (singular or near-singular system); x is then
// not meaningful.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;

  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Partial pivoting: bubble the entry of largest magnitude in column k
    // (among rows k..n-1) up to the diagonal.
    for (i = n - 1; i > k; i--) {
      if (fabs(A[(i - 1) * stride + k]) < fabs(A[i * stride + k])) {
        for (j = 0; j < n; j++) {
          c = A[i * stride + j];
          A[i * stride + j] = A[(i - 1) * stride + j];
          A[(i - 1) * stride + j] = c;
        }
        c = b[i];
        b[i] = b[i - 1];
        b[i - 1] = c;
      }
    }
    // Even the best available pivot is (almost) zero: dividing by it would
    // spread inf/NaN through the system, so give up instead.
    if (fabs(A[k * stride + k]) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
551
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
552
553
554
555
556
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
557
558
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
559
  int w, w2;
560
561
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
562
563
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
564
565
566
567
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
568
569
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
570
      int k, l;
571
572
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
573
574
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
575
576
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
577
578
579
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
580
  // Normalization enforcement in the system of equations itself
581
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
582
583
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
584
585
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
586
587
588
589
590
591
592
593
594
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
595
    }
Aamir Anis's avatar
Aamir Anis committed
596
    memcpy(a, S, w * sizeof(*a));
597
598
599
600
601
602
  }
}

// Fix vector a, update vector b.
//
// Mirror image of update_a_sep_sym(): with a held constant, a reduced linear
// system in the first half of the symmetric taps of `b` is assembled from
// Mc / Hc, the "taps sum to one" constraint is folded into it, and it is
// solved with linsolve(). `b` is overwritten only when the solve succeeds.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  double rhs[WIENER_WIN];
  double sys[WIENER_WIN2];
  double sol[WIENER_WIN];
  const int w = WIENER_WIN;
  const int w2 = WIENER_HALFWIN1;
  const int last = w2 - 1;  // index of the centre tap in the folded system
  int i, j, k, l;

  memset(rhs, 0, sizeof(rhs));
  memset(sys, 0, sizeof(sys));

  // Right-hand side, folded onto the first half of the taps.
  for (i = 0; i < WIENER_WIN; i++) {
    double *const dst = &rhs[wrap_index(i)];
    for (j = 0; j < WIENER_WIN; j++) *dst += Mc[i][j] * a[j];
  }

  // System matrix, folded the same way.
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      const double *const blk = Hc[i * WIENER_WIN + j];
      double *const cell = &sys[wrap_index(j) * WIENER_HALFWIN1 + wrap_index(i)];
      for (k = 0; k < WIENER_WIN; ++k) {
        for (l = 0; l < WIENER_WIN; ++l) {
          *cell += blk[k * WIENER_WIN2 + l] * a[k] * a[l];
        }
      }
    }
  }

  // Normalization enforcement in the system of equations itself
  for (i = 0; i < last; ++i) {
    rhs[i] -=
        rhs[last] * 2 + sys[i * w2 + last] - 2 * sys[last * w2 + last];
  }
  for (i = 0; i < last; ++i) {
    for (j = 0; j < last; ++j) {
      sys[i * w2 + j] -= 2 * (sys[i * w2 + last] + sys[last * w2 + j] -
                              2 * sys[last * w2 + last]);
    }
  }

  if (!linsolve(last, sys, w2, rhs, sol)) return;

  // Rebuild the full symmetric filter; the centre tap takes whatever is
  // left so that all taps sum to one.
  sol[last] = 1.0;
  for (i = w2; i < w; ++i) {
    sol[i] = sol[w - 1 - i];
    sol[last] -= 2 * sol[i];
  }
  memcpy(b, sol, w * sizeof(*b));
}

644
645
// Derives a separable, symmetric pair of WIENER_WIN-tap filters (a, b) from
// the statistics M and H by alternating optimisation: both filters start
// from a fixed initial kernel and are then refined by nine rounds of
// update_a_sep_sym() followed by update_b_sep_sym().
//
// M is viewed as WIENER_WIN rows of WIENER_WIN entries and H as a grid of
// blocks addressed through the Hc[] pointer table built here.
// The function always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  int r, c, iter;

  for (r = 0; r < WIENER_WIN; r++) {
    Mc[r] = M + r * WIENER_WIN;
    for (c = 0; c < WIENER_WIN; c++) {
      Hc[r * WIENER_WIN + c] =
          H + r * WIENER_WIN * WIENER_WIN2 + c * WIENER_WIN;
    }
  }

  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  for (iter = 1; iter < 10; ++iter) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

671
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
672
673
// against identity filters; Final score is defined as the difference between
// the function values
674
675
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
676
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
677
678
679
680
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
681
682
683
684
685
686
687
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
688
  }
689
690
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
691
  }
692
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
693
    P += ab[k] * M[k];
694
695
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
696
697
698
  }
  Score = Q - 2 * P;

699
700
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
701
702
703
704
705
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

706
// Quantizes a symmetric 7-tap filter `f` into the integer kernel `fi`.
//
// The first WIENER_HALFWIN taps are scaled by WIENER_FILT_STEP and rounded,
// the first three are clipped to their per-tap ranges and mirrored onto the
// last three positions, and the centre tap fi[3] is set to minus twice their
// sum (the centre carries an implicit +WIENER_FILT_STEP).
static void quantize_sym_filter(double *f, InterpKernel fi) {
  const int tap_min[3] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                           WIENER_FILT_TAP2_MINV };
  const int tap_max[3] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                           WIENER_FILT_TAP2_MAXV };
  int i;

  for (i = 0; i < WIENER_HALFWIN; ++i) {
    fi[i] = RINT(f[i] * WIENER_FILT_STEP);
  }

  // Specialize for 7-tap filter
  for (i = 0; i < 3; ++i) {
    fi[i] = CLIP(fi[i], tap_min[i], tap_max[i]);
  }

  // Satisfy filter constraints
  for (i = 0; i < 3; ++i) {
    fi[WIENER_WIN - 1 - i] = fi[i];
  }

  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
724
                               int partial_frame, int plane,
725
                               RestorationInfo *info, RestorationType *type,
726
727
728
729
730
731
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
732
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
733
734
735
736
737
738
739
740
741
742
743
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
744
  int h_start, h_end, v_start, v_end;
745
746
747
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
748
749
750
751
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
752
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
753
  bits = 0;
754
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
755
756

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
757

758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
779
    if (plane == AOM_PLANE_U) {
780
781
782
783
784
785
786
787
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
788
    } else if (plane == AOM_PLANE_V) {
789
790
791
792
793
794
795
796
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
797
798
799
    } else {
      assert(0);
    }
800
801
802
803
804
805

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
806
    }
807
808
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
809

810
811
812
813
814
815
816
817
818
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filter. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
819

820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
846
  }
847
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
848
                              dst_frame);
849
850
851
852
853
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
854
855
856
    info->frame_restoration_type = RESTORE_NONE;
  }

857
858
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
859
860
}

861
// Searches, tile by tile, for Wiener restoration filters on the luma plane
// (component pattern 1) and evaluates the rate-distortion cost of signalling
// Wiener restoration for the frame.
//
// For every restoration tile the function:
//   1. measures the SSE of the unrestored tile and its "no restore" RD cost,
//   2. gathers the statistics M and H over the tile and decomposes them into
//      separable vertical/horizontal filters,
//   3. quantizes the filters and rejects them when compute_score() reports a
//      positive score,
//   4. applies the filter to the tile and keeps it only if its RD cost is
//      strictly lower than the "no restore" cost.
//
// Outputs:
//   type[tile]            RESTORE_WIENER or RESTORE_NONE for each tile.
//   best_tile_cost[tile]  RD cost of the tile with Wiener (including the
//                         switchable restore cost) when Wiener wins,
//                         DBL_MAX otherwise.
//   info->wiener_info[]   quantized filters of the tiles where Wiener wins.
//
// cpi->rst_search is used as scratch state for the trial restorations and is
// left holding the final per-tile decisions.
//
// Returns the frame-level RD cost obtained with the selected per-tile types.
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
  double cost_wiener, cost_norestore;
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  // Start the search state with every tile unrestored; tiles are switched to
  // RESTORE_WIENER one at a time while they are being evaluated.
  rsi->frame_restoration_type = RESTORE_WIENER;

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

// Construct a (WIENER_HALFWIN)-pixel border around the frame
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

    // Limits used for gathering the statistics. No border margin is requested
    // here (0, 0), unlike the WIENER_HALFWIN margin used by the chroma search.
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
#if CONFIG_AOM_HIGHBITDEPTH
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
#endif  // CONFIG_AOM_HIGHBITDEPTH
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

    type[tile_idx] = RESTORE_WIENER;

    // Give up on this tile if the statistics cannot be decomposed into a
    // separable filter pair.
    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
    quantize_sym_filter(vfilterd, rsi->wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi->wiener_info[tile_idx].hfilter);

    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filter. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi->wiener_info[tile_idx].vfilter,
                          rsi->wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }

    // Trial-restore only this tile and measure the resulting error.
    rsi->restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      // Save the winning filter in the caller's RestorationInfo.
      memcpy(&wiener_info[tile_idx], &rsi->wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
      // Per-tile cost as seen by the switchable search: filter bits plus the
      // cost of signalling RESTORE_WIENER as the tile's type.
      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
    }
    // Reset the tile so that the next trial again restores a single tile.
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Cost for Wiener filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    // Load the saved filters and final decisions back into the search state
    // before the whole-frame trial below.
    memcpy(&rsi->wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_wiener;
}

987
static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
988
989
                               int partial_frame, RestorationInfo *info,
                               RestorationType *type, double *best_tile_cost,
990
                               YV12_BUFFER_CONFIG *dst_frame) {
991
992
  double err, cost_norestore;
  int bits;
993
  MACROBLOCK *x = &cpi->td.mb;
994
995
  AV1_COMMON *const cm = &cpi->common;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
996
  int h_start, h_end, v_start, v_end;
997
998
999
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);
1000
  (void)info;
1001
  (void)dst_frame;
1002
  (void)partial_frame;
1003
1004
1005
1006
1007

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);