/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20
21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24
25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29
30
31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
34
35
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
36
                                      RestorationType *rest_level,
37
38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
41
42

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
43
44
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
45
                                    int width, int v_start, int height,
46
47
                                    int components_pattern) {
  int64_t filt_err = 0;
48
49
50
51
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
52
53
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
54
55
56
57
58
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
59
60
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
61
62
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
63
64
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
65
66
    }
    return filt_err;
67
68
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
69
70
71
72
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
73
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
74
75
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
76
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
77
  }
78
79
80
  return filt_err;
}

81
82
// Whole-frame sum of squared errors between src and dst over the planes
// selected by components_pattern (bit AOM_PLANE_Y/U/V selects that plane).
// cm is consulted only in high-bitdepth builds, to choose the 16-bit helpers.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t sse = 0;

#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) {
      sse += aom_highbd_get_y_sse(src, dst);
    }
    if (want_u) {
      sse += aom_highbd_get_u_sse(src, dst);
    }
    if (want_v) {
      sse += aom_highbd_get_v_sse(src, dst);
    }
    return sse;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH

  if (want_y) {
    sse += aom_get_y_sse(src, dst);
  }
  if (want_u) {
    sse += aom_get_u_sse(src, dst);
  }
  if (want_v) {
    sse += aom_get_v_sse(src, dst);
  }
  return sse;
}

// Runs loop restoration as described by rsi on the planes in
// components_pattern (output written to dst_frame) and returns the SSE
// against src measured only inside restoration tile tile_idx (or its
// sub-tile subtile_idx at subtile_bits granularity, as laid out by
// av1_get_rest_tile_limits). Filtering goes through
// av1_loop_restoration_frame(), subject to partial_frame.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int is_luma = (components_pattern == 1);
  const int width = is_luma ? src->y_crop_width : src->uv_crop_width;
  const int height = is_luma ? src->y_crop_height : src->uv_crop_height;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;

  // Luma and chroma planes must not be mixed in one call.
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Only the tile geometry is needed here; the tile count is discarded.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
154
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
155
                                     int components_pattern, int partial_frame,
156
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
157
  AV1_COMMON *const cm = &cpi->common;
158
  int64_t filt_err;
159
160
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
161
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
162
163
164
  return filt_err;
}

165
166
167
168
169
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
170
171
172
173
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
174
175
176
177
178
179
180
181
182
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
183
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
199
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
200
201
202
203
204
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
205
206
207
208
209
    }
  }
  return err;
}

210
211
212
213
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
214
215
216
217
218
219
220
221
222
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

// Converts projection coefficients xq (SGRPROJ_PRJ_BITS fixed point) into the
// clamped, signalled form xqd that decode_xq() reads back:
//   xqd[0] = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0)
//   xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1],
//                  SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1)
void encode_xq(int *xq, int *xqd) {
  const int d0 = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  const int d1 = clamp((1 << SGRPROJ_PRJ_BITS) + d0 - xq[1], SGRPROJ_PRJ_MIN1,
                       SGRPROJ_PRJ_MAX1);
  xqd[0] = d0;
  xqd[1] = d1;
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
282
283
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
284
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
285
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
286
  int ep, bestep = 0;
287
288
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
289

290
291
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
292
#if CONFIG_AOM_HIGHBITDEPTH
293
294
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
295
296
297
298
299
300
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
301
    } else {
302
303
304
305
306
307
308
309
#endif
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                                 bit_depth, sgr_params[ep].r1,
                                 sgr_params[ep].e1, tmpbuf2);
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
                                 bit_depth, sgr_params[ep].r2,
                                 sgr_params[ep].e2, tmpbuf2);
#if CONFIG_AOM_HIGHBITDEPTH
310
    }
311
#endif
312
313
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
314
    encode_xq(exq, exqd);
315
316
317
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
318
319
320
321
322
323
324
325
326
327
328
329
330
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Self-guided projection (SGRPROJ) search for the luma plane.
//
// For each luma restoration tile (laid out over cm->width x cm->height with
// rst_info[0].restoration_tilesize) this measures the unrestored SSE, finds
// the best self-guided parameters for the tile, measures the SSE with only
// that tile restored, and keeps whichever option has the lower RD cost:
//   type[t]            - RESTORE_SGRPROJ or RESTORE_NONE for tile t;
//   info->sgrproj_info - parameters copied in for tiles choosing SGRPROJ;
//   best_tile_cost[t]  - SGRPROJ RD cost of tile t including its
//                        switchable-restore cost, or DBL_MAX when the tile is
//                        better left unrestored.
// The per-tile choices are then applied to the whole frame (rsi =
// cpi->rst_search[0] is left holding them, dst_frame receives the restored
// frame) and the frame-level RD cost is returned.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const x = &cpi->td.mb;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;
  RestorationInfo *const rsi = &cpi->rst_search[0];
  SgrprojInfo *const best_info = info->sgrproj_info;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int t, bits;
  double err, cost_none, cost_sgr;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (t = 0; t < ntiles; ++t) {
    rsi->restoration_type[t] = RESTORE_NONE;
  }

  for (t = 0; t < ntiles; ++t) {
    av1_get_rest_tile_limits(t, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0,
                             &h_start, &h_end, &v_start, &v_end);

    // Option 1: leave the tile unrestored.
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_none = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[t] = DBL_MAX;

    // Option 2: restore the tile with its best self-guided parameters. Only
    // this tile is switched on in rsi while its error is measured.
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[t].ep, rsi->sgrproj_info[t].xqd,
        cm->rst_internal.tmpbuf);
    rsi->restoration_type[t] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, t, 0, 0,
                               dst_frame);
    bits = (SGRPROJ_BITS << AV1_PROB_COST_SHIFT) +
           av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgr = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

    if (cost_sgr >= cost_none) {
      type[t] = RESTORE_NONE;
    } else {
      type[t] = RESTORE_SGRPROJ;
      memcpy(&best_info[t], &rsi->sgrproj_info[t], sizeof(best_info[t]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[t] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    rsi->restoration_type[t] = RESTORE_NONE;
  }

  // Frame-level cost: frame header bits, per-tile on/off flags, and the
  // parameters of every restored tile, plus the whole-frame distortion.
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (t = 0; t < ntiles; ++t) {
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[t] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[t], &best_info[t], sizeof(best_info[t]));
    if (type[t] == RESTORE_SGRPROJ) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[t] = type[t];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  return RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
}

// Mean of the 8-bit samples of src in rows [v_start, v_end) and columns
// [h_start, h_end), with row pitch `stride`. A region containing no samples
// yields 0 instead of the NaN that dividing 0 by 0 would produce.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  const int64_t count = (int64_t)(v_end - v_start) * (h_end - h_start);
  uint64_t sum = 0;
  int i, j;

  if (count <= 0) return 0.0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  return (double)sum / (double)count;
}

static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
426
  int i, j, k, l;
427
  double Y[WIENER_WIN2];
428
429
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
430

431
432
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
433
434
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
435
436
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
437
438
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
439
440
441
442
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
443
      for (k = 0; k < WIENER_WIN2; ++k) {
444
        M[k] += Y[k] * X;
445
446
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
447
448
449
450
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
451
452
453
454
        }
      }
    }
  }
455
456
457
458
459
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
460
461
}

Yaowu Xu's avatar
Yaowu Xu committed
462
#if CONFIG_AOM_HIGHBITDEPTH
// Mean of the 16-bit samples of src in rows [v_start, v_end) and columns
// [h_start, h_end), with row pitch `stride`. A region containing no samples
// yields 0 instead of the NaN that dividing 0 by 0 would produce.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  const int64_t count = (int64_t)(v_end - v_start) * (h_end - h_start);
  uint64_t sum = 0;
  int i, j;

  if (count <= 0) return 0.0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  return (double)sum / (double)count;
}

static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
478
  int i, j, k, l;
479
  double Y[WIENER_WIN2];
480
481
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
482
483
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
484

485
486
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
487
488
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
489
490
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
491
492
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
493
494
495
496
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
497
      for (k = 0; k < WIENER_WIN2; ++k) {
498
        M[k] += Y[k] * X;
499
500
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
501
502
503
504
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
505
506
507
508
        }
      }
    }
  }
509
510
511
512
513
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
514
}
Yaowu Xu's avatar
Yaowu Xu committed
515
#endif  // CONFIG_AOM_HIGHBITDEPTH
// Solves A x = b by Gaussian elimination with partial pivoting.
//   n      - number of unknowns;
//   A      - n x n matrix, row-major with row pitch `stride` (>= n);
//   b      - right-hand side, n entries;
//   x      - receives the solution, n entries.
// A and b are used as scratch and are overwritten. Returns 1 on success, or
// 0 when a pivot of magnitude below 1e-10 is met (the matrix is singular or
// nearly so); x may then be partially written.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;

  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Pivot on every column, by magnitude: one bottom-up bubble pass carries
    // the row with the largest |A[i][k]| (i >= k) up to row k, so a zero or
    // tiny diagonal entry is never used as a divisor while a usable row
    // exists below it.
    for (i = n - 1; i > k; i--) {
      if (fabs(A[(i - 1) * stride + k]) < fabs(A[i * stride + k])) {
        for (j = 0; j < n; j++) {
          c = A[i * stride + j];
          A[i * stride + j] = A[(i - 1) * stride + j];
          A[(i - 1) * stride + j] = c;
        }
        c = b[i];
        b[i] = b[i - 1];
        b[i - 1] = c;
      }
    }
    // Column k is (numerically) zero from row k down: singular.
    if (fabs(A[k * stride + k]) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
553
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
554
555
556
557
558
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
559
560
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
561
  int w, w2;
562
563
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
564
565
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
566
567
568
569
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
570
571
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
572
      int k, l;
573
574
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
575
576
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
577
578
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
579
580
581
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
582
  // Normalization enforcement in the system of equations itself
583
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
584
585
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
586
587
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
588
589
590
591
592
593
594
595
596
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
597
    }
Aamir Anis's avatar
Aamir Anis committed
598
    memcpy(a, S, w * sizeof(*a));
599
600
601
602
603
604
  }
}

// One alternating-least-squares step of the separable, symmetric Wiener fit:
// holding the vertical filter a fixed, re-solves for the horizontal filter b.
// Mc/Hc are the row views of M and H built by wiener_decompose_sep_sym().
// The system is folded onto the WIENER_HALFWIN1 distinct taps of the
// symmetric filter (see wrap_index()), and the unit-sum constraint is folded
// in, leaving WIENER_HALFWIN1 - 1 unknowns for linsolve(). The centre tap is
// then set so that all taps sum to 1. If linsolve() fails, b is left
// unchanged.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  const int w = WIENER_WIN;
  const int w2 = WIENER_HALFWIN1;
  double A[WIENER_WIN], B[WIENER_WIN2], S[WIENER_WIN];
  int i, j, k, l;

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));

  // Folded right-hand side.
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }
  // Folded system matrix (row pitch WIENER_HALFWIN1).
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      double *const acc = &B[wrap_index(j) * WIENER_HALFWIN1 + wrap_index(i)];
      for (k = 0; k < WIENER_WIN; ++k) {
        for (l = 0; l < WIENER_WIN; ++l) {
          *acc += Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
        }
      }
    }
  }

  // Normalization enforcement in the system of equations itself.
  for (i = 0; i < w2 - 1; ++i) {
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  }
  for (i = 0; i < w2 - 1; ++i) {
    for (j = 0; j < w2 - 1; ++j) {
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
    }
  }

  if (!linsolve(w2 - 1, B, w2, A, S)) return;  // keep the previous b

  // Rebuild the full symmetric filter; the centre tap makes the sum 1.
  S[w2 - 1] = 1.0;
  for (i = w2; i < w; ++i) {
    S[i] = S[w - 1 - i];
    S[w2 - 1] -= 2 * S[i];
  }
  memcpy(b, S, w * sizeof(*b));
}

// Fits a separable, symmetric WIENER_WIN-tap filter pair to the statistics
// M and H by alternating least squares. Both a and b start from a fixed
// initial filter; then update_a_sep_sym() and update_b_sep_sym() are applied
// in turn for nine rounds. a and b are overwritten with the result. Always
// returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Mc[WIENER_WIN];
  double *Hc[WIENER_WIN2];
  int r, c, pass;

  // Mc[r] is row r of M viewed as WIENER_WIN x WIENER_WIN;
  // Hc[r * WIENER_WIN + c] points at H + r * WIENER_WIN * WIENER_WIN2 +
  // c * WIENER_WIN.
  for (r = 0; r < WIENER_WIN; r++) {
    Mc[r] = M + r * WIENER_WIN;
    for (c = 0; c < WIENER_WIN; c++) {
      Hc[r * WIENER_WIN + c] =
          H + r * WIENER_WIN * WIENER_WIN2 + c * WIENER_WIN;
    }
  }

  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  for (pass = 1; pass < 10; pass++) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
674
675
// against identity filters; Final score is defined as the difference between
// the function values
676
677
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
678
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
679
680
681
682
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
683
684
685
686
687
688
689
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
690
  }
691
692
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
693
  }
694
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
695
    P += ab[k] * M[k];
696
697
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
698
699
700
  }
  Score = Q - 2 * P;

701
702
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
703
704
705
706
707
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

708
// Quantizes the symmetric real-valued filter f into integer taps fi, in
// units of 1/WIENER_FILT_STEP. Only f[0 .. WIENER_HALFWIN - 1] is read. The
// layout is specialised for a 7-tap filter: taps 0..2 are rounded, clipped to
// their per-tap ranges and mirrored into taps WIENER_WIN-1 .. WIENER_WIN-3.
// Tap 3 is set so that the seven taps sum to zero, because the centre tap
// carries an implicit +WIENER_FILT_STEP.
static void quantize_sym_filter(double *f, InterpKernel fi) {
  int t;

  for (t = 0; t < WIENER_HALFWIN; ++t) {
    fi[t] = RINT(f[t] * WIENER_FILT_STEP);
  }

  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);

  for (t = 0; t < 3; ++t) fi[WIENER_WIN - 1 - t] = fi[t];

  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
726
                               int partial_frame, int plane,
727
                               RestorationInfo *info, RestorationType *type,
728
729
730
731
732
733
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
734
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
735
736
737
738
739
740
741
742
743
744
745
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
746
  int h_start, h_end, v_start, v_end;
747
748
749
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
750
751
752
753
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
754
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
755
  bits = 0;
756
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
757
758

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
759

760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
781
    if (plane == AOM_PLANE_U) {
782
783
784
785
786
787
788
789
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
790
    } else if (plane == AOM_PLANE_V) {
791
792
793
794
795
796
797
798
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
799
800
801
    } else {
      assert(0);
    }
802
803
804
805
806
807

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
808
    }
809
810
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
811

812
813
814
815
816
817
818
819
820
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filer. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
821

822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
848
  }
849
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
850
                              dst_frame);
851
852
853
854
855
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
856
857
858
    info->frame_restoration_type = RESTORE_NONE;
  }

859
860
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
861
862
}

863
/* Searches per-tile Wiener filters for the luma plane of the frame being
 * restored (cm->frame_to_show) against the source frame `src`, and returns
 * the rate-distortion cost of applying the chosen per-tile configuration to
 * the whole frame.
 *
 * For every restoration tile:
 *   1. the RD cost of leaving the tile unfiltered is measured;
 *   2. statistics (M, H) are gathered and a separable symmetric filter is
 *      derived and quantized;
 *   3. the filter is discarded if the decomposition fails or compute_score()
 *      reports no improvement over the identity filter;
 *   4. otherwise that tile alone is trial-filtered, and the filter is kept
 *      only if its RD cost is strictly lower than the unfiltered cost.
 * Finally, all per-tile decisions are applied together through
 * try_restoration_frame(), and the frame RD cost, including signalling bits,
 * is returned.
 *
 * src            - source frame; asserted to match cm->width x cm->height.
 * cpi            - encoder context. cpi->rst_search is used as scratch and on
 *                  return holds frame_restoration_type = RESTORE_WIENER plus
 *                  the chosen per-tile types and the filters copied from info.
 * partial_frame  - forwarded to try_restoration_tile/try_restoration_frame.
 * info           - info->wiener_info[tile] is overwritten only for tiles that
 *                  chose Wiener filtering; other entries are not written.
 * type           - per-tile output: RESTORE_WIENER or RESTORE_NONE.
 * best_tile_cost - per-tile output. For tiles that chose Wiener, this is the
 *                  RD cost using the switchable-restore signalling cost.
 *                  For all other tiles it is DBL_MAX.
 * dst_frame      - forwarded to try_restoration_tile/try_restoration_frame
 *                  (presumably scratch output for trial filtering; TODO
 *                  confirm).
 *
 * Side effect: the luma plane of cm->frame_to_show has its border extended
 * in place (extend_frame / extend_frame_highbd) before statistics are
 * collected.
 */
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  MACROBLOCK *x = &cpi->td.mb;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
  // Luma-only component mask, following the `1 << plane` convention used by
  // the chroma search.
  const int y_pattern = 1 << AOM_PLANE_Y;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  double score, cost_wiener, cost_norestore;
  int64_t err;
  int bits;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);

  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  rsi->frame_restoration_type = RESTORE_WIENER;

  // Start with every tile unfiltered so that each trial below filters
  // exactly one tile.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

// Construct a (WIENER_HALFWIN)-pixel border around the frame. Presumably this
// is what allows the statistics below to use unclamped tile limits (the
// chroma search clamps its limits by WIENER_HALFWIN instead).
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    // The same unclamped limits serve both the unfiltered SSE and the filter
    // statistics, so they are computed once per tile.
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               y_pattern);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

#if CONFIG_AOM_HIGHBITDEPTH
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
#endif  // CONFIG_AOM_HIGHBITDEPTH
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }

    quantize_sym_filter(vfilterd, rsi->wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi->wiener_info[tile_idx].hfilter);

    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against the identity filter. If there is
    // no reduction in the function, the filter is reverted back to identity.
    score = compute_score(M, H, rsi->wiener_info[tile_idx].vfilter,
                          rsi->wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }

    // Trial-filter this tile alone; every other tile is RESTORE_NONE in rsi.
    rsi->restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, y_pattern, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi->wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
      // The cost reported to the caller uses the switchable-restore
      // signalling cost rather than the RESTORE_NONE_WIENER_PROB flag cost.
      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
    }
    // Undo the trial so the next tile is again tested in isolation.
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Cost for Wiener filtering: frame-level type bits, plus a per-tile on/off
  // flag, plus the filter coefficients of each filtered tile.
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi->wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, y_pattern, partial_frame,
                              dst_frame);
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_wiener;
}

989
static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
990
991
                               int partial_frame, RestorationInfo *info,
                               RestorationType *type, double *best_tile_cost,
992
                               YV12_BUFFER_CONFIG *dst_frame) {
993
994
  double err, cost_norestore;
  int bits;
995
  MACROBLOCK *x = &cpi->td.mb;
996
997
  AV1_COMMON *const cm = &cpi->common;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
998
  int h_start, h_end, v_start, v_end;
999
1000
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,