/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20
21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26
27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
32
33
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
34
#include "av1/encoder/mathutils.h"
35

36
37
38
// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
// When set to RESTORE_NONE (0) we allow switchable.
const RestorationType force_restore_type = RESTORE_NONE;

// Number of Wiener iterations
// NOTE(review): the alternating-update loop in wiener_decompose_sep_sym starts
// its counter at 1 and stops at NUM_WIENER_ITERS, so it performs one pass fewer
// than this value.
#define NUM_WIENER_ITERS 10

// Common signature of the per-restoration-type search routines: they receive
// the source frame, the encoder instance, per-tile outputs (restoration info,
// chosen type, best cost) and a scratch destination frame, and return a cost
// as a double.
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
                                      RestorationType *rest_level,
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);

// Frame-level signalling cost in bits, indexed by frame restoration type.
// Callers scale the entry by AV1_PROB_COST_SHIFT before adding it to a rate.
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51
52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53
54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56
57
                                    int components_pattern) {
  int64_t filt_err = 0;
58
59
60
61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64
65
66
67
68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69
70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71
72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73
74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75
76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79
80
81
82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84
85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88
89
90
  return filt_err;
}

91
92
// Whole-frame sum of squared errors between `src` and `dst`, accumulated over
// the planes whose bit (AOM_PLANE_Y / AOM_PLANE_U / AOM_PLANE_V) is set in
// `components_pattern`. `cm` is only consulted in the CONFIG_HIGHBITDEPTH
// build, to choose the high-bitdepth helpers.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int use_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int use_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int use_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t total = 0;
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (use_y) total += aom_highbd_get_y_sse(src, dst);
    if (use_u) total += aom_highbd_get_u_sse(src, dst);
    if (use_v) total += aom_highbd_get_v_sse(src, dst);
    return total;
  }
#else
  (void)cm;
#endif  // CONFIG_HIGHBITDEPTH
  if (use_y) total += aom_get_y_sse(src, dst);
  if (use_u) total += aom_get_u_sse(src, dst);
  if (use_v) total += aom_get_v_sse(src, dst);
  return total;
}

124
125
// Runs loop restoration with the settings in `rsi` from cm->frame_to_show into
// `dst_frame`, then returns the SSE against `src` restricted to the tile (or
// sub-tile) identified by tile_idx / subtile_idx / subtile_bits.
// `components_pattern` selects luma (1) or chroma planes (2, 4 or 6).
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int luma_only = (components_pattern == 1);
  // Luma uses the Y crop dimensions, any chroma pattern the UV ones.
  const int width = luma_only ? src->y_crop_width : src->uv_crop_width;
  const int height = luma_only ? src->y_crop_height : src->uv_crop_height;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Only the tile geometry outputs are needed; the tile count is discarded.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
164
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
165
                                     int components_pattern, int partial_frame,
166
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
167
  AV1_COMMON *const cm = &cpi->common;
168
  int64_t filt_err;
169
170
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
171
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
172
173
174
  return filt_err;
}

175
176
177
178
179
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
180
181
182
183
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
184
185
186
187
188
189
190
191
192
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
193
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
209
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
210
211
212
213
214
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
215
216
217
218
219
    }
  }
  return err;
}

220
221
222
223
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
224
225
226
227
228
229
230
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

231
232
  aom_clear_system_state();

233
234
235
  // Default
  xq[0] = 0;
  xq[1] = 0;
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
286
  xqd[0] = xq[0];
287
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
288
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
289
290
291
292
293
294
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
295
296
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
297
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
298
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
299
  int ep, bestep = 0;
300
301
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
302

303
304
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
305
#if CONFIG_HIGHBITDEPTH
306
307
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
308
309
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
310
                                 sgr_params[ep].corner, sgr_params[ep].edge);
311
#else
312
313
314
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
315
#endif  // USE_HIGHPASS_IN_SGRPROJ
316
317
318
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
319
    } else {
320
#endif
321
322
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
323
                          sgr_params[ep].corner, sgr_params[ep].edge);
324
325
326
327
#else
    av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
#endif  // USE_HIGHPASS_IN_SGRPROJ
328
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
329
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
330
#if CONFIG_HIGHBITDEPTH
331
    }
332
#endif
333
    aom_clear_system_state();
334
335
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
336
    aom_clear_system_state();
337
    encode_xq(exq, exqd);
338
339
340
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
341
342
343
344
345
346
347
348
349
350
351
352
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

353
354
355
356
357
358
359
360
361
362
363
364
365
366
// Number of bits needed to signal `sgrproj_info`: SGRPROJ_PARAMS_BITS plus the
// cost of coding each of the two xqd coefficients relative to the matching
// coefficient of `ref_sgrproj_info`, as reported by
// aom_count_primitive_refsubexpfin().
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int xqd0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int xqd1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);

  return SGRPROJ_PARAMS_BITS + xqd0_bits + xqd1_bits;
}

367
// Luma self-guided (Sgrproj) restoration search.
//
// For every restoration tile the best Sgrproj parameters are searched and the
// rate-distortion cost of using them is compared with leaving the tile
// unrestored; the decision is stored in type[], accepted parameters in
// info->sgrproj_info[], and best_tile_cost[] receives the tile error for
// accepted tiles (DBL_MAX otherwise). The function then applies the per-tile
// decisions to the whole frame and returns the resulting frame-level cost.
// cpi->rst_search[0] is used as the working RestorationInfo.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const x = &cpi->td.mb;
  RestorationInfo *const rsi = &cpi->rst_search[0];
  SgrprojInfo *const sgrproj_info = info->sgrproj_info;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;
#if CONFIG_HIGHBITDEPTH
  const int bit_depth = cm->bit_depth;
#else
  const int bit_depth = 8;
#endif  // CONFIG_HIGHBITDEPTH
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int tile_idx, bits;
  double err, cost_norestore, cost_sgrproj;
  SgrprojInfo ref_sgrproj_info;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);

  set_default_sgrproj(&ref_sgrproj_info);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  // Start with every tile disabled so each trial below filters one tile only.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi->restoration_type[tile_idx] = RESTORE_NONE;

  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    SgrprojInfo *const trial = &rsi->sgrproj_info[tile_idx];

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    err = sse_restoration_tile(src, dgd, cm, h_start, h_end - h_start, v_start,
                               v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
        bit_depth, &trial->ep, trial->xqd, cm->rst_internal.tmpbuf);

    // Enable this tile alone and measure the restored error.
    rsi->restoration_type[tile_idx] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = count_sgrproj_bits(trial, &ref_sgrproj_info) << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

    if (cost_sgrproj >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_SGRPROJ;
      memcpy(&sgrproj_info[tile_idx], trial, sizeof(sgrproj_info[tile_idx]));
      // This value is not read before `bits` is reassigned after the loop.
      bits = count_sgrproj_bits(trial, &ref_sgrproj_info)
             << AV1_PROB_COST_SHIFT;
      // Accepted parameters become the reference for the next tile's cost.
      memcpy(&ref_sgrproj_info, &sgrproj_info[tile_idx],
             sizeof(ref_sgrproj_info));
      best_tile_cost[tile_idx] = err;
    }
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Cost for Sgrproj filtering
  set_default_sgrproj(&ref_sgrproj_info);
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    SgrprojInfo *const chosen = &rsi->sgrproj_info[tile_idx];

    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(chosen, &sgrproj_info[tile_idx], sizeof(sgrproj_info[tile_idx]));
    if (type[tile_idx] == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(chosen, &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      memcpy(&ref_sgrproj_info, chosen, sizeof(ref_sgrproj_info));
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }

  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_sgrproj;
}

459
460
// Mean of the 8-bit samples of `src` in rows [v_start, v_end) and columns
// [h_start, h_end), with rows `stride` samples apart.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row, col;

  aom_clear_system_state();
  for (row = v_start; row < v_end; ++row) {
    for (col = h_start; col < h_end; ++col) {
      total += src[row * stride + col];
    }
  }
  return (double)total / count;
}

471
472
473
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
474
  int i, j, k, l;
475
  double Y[WIENER_WIN2];
476
477
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
478

479
480
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
481
482
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
483
484
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
485
486
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
487
488
489
490
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
491
      for (k = 0; k < WIENER_WIN2; ++k) {
492
        M[k] += Y[k] * X;
493
494
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
495
496
497
498
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
499
500
501
502
        }
      }
    }
  }
503
504
505
506
507
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
508
509
}

510
#if CONFIG_HIGHBITDEPTH
511
512
// Mean of the 16-bit samples of `src` in rows [v_start, v_end) and columns
// [h_start, h_end), with rows `stride` samples apart.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row, col;

  aom_clear_system_state();
  for (row = v_start; row < v_end; ++row) {
    for (col = h_start; col < h_end; ++col) {
      total += src[row * stride + col];
    }
  }
  return (double)total / count;
}

523
524
525
526
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
527
  int i, j, k, l;
528
  double Y[WIENER_WIN2];
529
530
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
531
532
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
533

534
535
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
536
537
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
538
539
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
540
541
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
542
543
544
545
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
546
      for (k = 0; k < WIENER_WIN2; ++k) {
547
        M[k] += Y[k] * X;
548
549
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
550
551
552
553
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
554
555
556
557
        }
      }
    }
  }
558
559
560
561
562
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
563
}
564
#endif  // CONFIG_HIGHBITDEPTH
565
566

static INLINE int wrap_index(int i) {
567
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
568
569
570
571
572
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
573
  double S[WIENER_WIN];
574
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
Aamir Anis's avatar
Aamir Anis committed
575
  int w, w2;
576
577
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
578
579
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
580
581
582
583
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
584
585
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
586
      int k, l;
587
588
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
589
590
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
591
592
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
593
594
595
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
596
  // Normalization enforcement in the system of equations itself
597
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
598
599
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
600
601
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
602
603
604
605
606
607
608
609
610
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
611
    }
Aamir Anis's avatar
Aamir Anis committed
612
    memcpy(a, S, w * sizeof(*a));
613
614
615
616
617
618
  }
}

// Fix vector a, update vector b
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
619
  double S[WIENER_WIN];
620
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
Aamir Anis's avatar
Aamir Anis committed
621
  int w, w2;
622
623
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
624
  for (i = 0; i < WIENER_WIN; i++) {
625
    const int ii = wrap_index(i);
626
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
627
628
  }

629
630
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
631
632
633
      const int ii = wrap_index(i);
      const int jj = wrap_index(j);
      int k, l;
634
635
636
637
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          B[jj * WIENER_HALFWIN1 + ii] +=
              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
638
639
    }
  }
Aamir Anis's avatar
Aamir Anis committed
640
  // Normalization enforcement in the system of equations itself
641
642
  w = WIENER_WIN;
  w2 = WIENER_HALFWIN1;
Aamir Anis's avatar
Aamir Anis committed
643
  for (i = 0; i < w2 - 1; ++i)
644
645
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
646
647
648
649
650
651
652
653
654
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
655
    }
Aamir Anis's avatar
Aamir Anis committed
656
    memcpy(b, S, w * sizeof(*b));
657
658
659
  }
}

660
661
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
662
  static const double init_filt[WIENER_WIN] = {
663
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
664
665
  };
  int i, j, iter;
666
667
668
669
670
671
672
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  for (i = 0; i < WIENER_WIN; i++) {
    Mc[i] = M + i * WIENER_WIN;
    for (j = 0; j < WIENER_WIN; j++) {
      Hc[i * WIENER_WIN + j] =
          H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
673
674
    }
  }
675
676
  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);
677
678

  iter = 1;
679
  while (iter < NUM_WIENER_ITERS) {
680
681
682
683
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
    iter++;
  }
684
  return 1;
685
686
}

687
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
688
689
// against identity filters; Final score is defined as the difference between
// the function values
690
691
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
692
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
693
694
695
696
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
697
  double a[WIENER_WIN], b[WIENER_WIN];
698
699
700

  aom_clear_system_state();

701
702
703
704
705
706
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
707
  }
708
709
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
710
  }
711
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
712
    P += ab[k] * M[k];
713
714
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
715
716
717
  }
  Score = Q - 2 * P;

718
719
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
720
721
722
723
724
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

725
// Quantizes the symmetric floating-point filter f into the integer kernel fi:
// the first WIENER_HALFWIN taps are scaled by WIENER_FILT_STEP, rounded and
// clipped to their per-tap ranges, mirrored onto the trailing taps, and the
// centre tap is set to minus twice their sum.
static void quantize_sym_filter(double *f, InterpKernel fi) {
  int tap;

  for (tap = 0; tap < WIENER_HALFWIN; ++tap) {
    fi[tap] = RINT(f[tap] * WIENER_FILT_STEP);
  }

  // Specialize for 7-tap filter
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);

  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];

  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
// Returns the sum of aom_count_primitive_refsubexpfin() over the three
// leading taps of the vertical filter followed by the three leading taps of
// the horizontal filter in `wiener_info`, each coded relative to the matching
// tap in `ref_wiener_info`.
//
// Every tap is offset by its own minimum value before counting, and uses its
// own range (MAXV - MINV + 1) and subexp parameter K.
static int count_wiener_bits(WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  // Per-tap coding parameters, indexed by tap position 0..2.
  const int tap_minv[3] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                            WIENER_FILT_TAP2_MINV };
  const int tap_maxv[3] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                            WIENER_FILT_TAP2_MAXV };
  const int tap_subexp_k[3] = { WIENER_FILT_TAP0_SUBEXP_K,
                                WIENER_FILT_TAP1_SUBEXP_K,
                                WIENER_FILT_TAP2_SUBEXP_K };
  int total_bits = 0;
  int tap;

  // Vertical filter taps.
  for (tap = 0; tap < 3; ++tap) {
    total_bits += aom_count_primitive_refsubexpfin(
        tap_maxv[tap] - tap_minv[tap] + 1, tap_subexp_k[tap],
        ref_wiener_info->vfilter[tap] - tap_minv[tap],
        wiener_info->vfilter[tap] - tap_minv[tap]);
  }

  // Horizontal filter taps.
  for (tap = 0; tap < 3; ++tap) {
    total_bits += aom_count_primitive_refsubexpfin(
        tap_maxv[tap] - tap_minv[tap] + 1, tap_subexp_k[tap],
        ref_wiener_info->hfilter[tap] - tap_minv[tap],
        wiener_info->hfilter[tap] - tap_minv[tap]);
  }

  return total_bits;
}

778
static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
779
                               int partial_frame, int plane,
780
                               RestorationInfo *info, RestorationType *type,
781
782
783
784
785
786
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
787
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
788
789
790
791
792
793
794
795
796
797
798
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
799
  int h_start, h_end, v_start, v_end;
800
801
802
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
803
804
  WienerInfo ref_wiener_info;
  set_default_wiener(&ref_wiener_info);
805
806
807
808
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
809
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
810
  bits = 0;
811
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
812
813

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
814

815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
836
    if (plane == AOM_PLANE_U) {
837
#if CONFIG_HIGHBITDEPTH
838
839
840
841
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
842
#endif  // CONFIG_HIGHBITDEPTH
843
844
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
845
    } else if (plane == AOM_PLANE_V) {
846
#if CONFIG_HIGHBITDEPTH
847
848
849
850
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
851
#endif  // CONFIG_HIGHBITDEPTH
852
853
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
854
855
856
    } else {
      assert(0);
    }
857
858
859
860
861
862

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
863
    }
864
865
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
866

867
868
869
870
871
872
873
874
875
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filer. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
876
    aom_clear_system_state();
877

878
879
880
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
881
882
883
884
    bits =
        count_wiener_bits(&rsi[plane].wiener_info[tile_idx], &ref_wiener_info)
        << AV1_PROB_COST_SHIFT;
    // bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
885
886
887
888
889
890
891
892
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
893
894
      memcpy(&ref_wiener_info, &rsi[plane].wiener_info[tile_idx],
             sizeof(ref_wiener_info));
895
896
897
898
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
899
  set_default_wiener(&ref_wiener_info);
900
901
902
903
904
905
906
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
907
908
909
910
911
      bits +=
          count_wiener_bits(&rsi[plane].wiener_info[tile_idx], &ref_wiener_info)
          << AV1_PROB_COST_SHIFT;
      memcpy(&ref_wiener_info, &rsi[plane].wiener_info[tile_idx],
             sizeof(ref_wiener_info));
912
913
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
914
  }
915
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
916
                              dst_frame);
917
918
919
920
921
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
922
923
924
    info->frame_restoration_type = RESTORE_NONE;
  }

925
926
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
927
928
}

929
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
930
931
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
932
                            YV12_BUFFER_CONFIG *dst_frame) {
933
  WienerInfo *wiener_info = info->wiener_info;
Yaowu Xu's avatar
Yaowu Xu committed
934
  AV1_COMMON *const cm = &cpi->common;
935
  RestorationInfo *rsi = cpi->rst_search;
936
937
  int64_t err;
  int bits;
938
  double cost_wiener, cost_norestore;
939
  MACROBLOCK *x = &cpi->td.mb;
940
941
942
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
943
944
945
946
947
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
Aamir Anis's avatar
Aamir Anis committed
948
  double score;
949
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
950
  int h_start, h_end, v_start, v_end;
951
952
953
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
954
955
956
  WienerInfo ref_wiener_info;
  set_default_wiener(&ref_wiener_info);

957
958
959
960
961
  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

962
  rsi->frame_restoration_type = RESTORE_WIENER;
963

964
965
966
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
967

968
// Construct a (WIENER_HALFWIN)-pixel border around the frame
969
#if CONFIG_HIGHBITDEPTH
970
971
972
973
974
975
976
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

977
978
  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
979
980
981
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
982
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
983
                               h_end - h_start, v_start, v_end - v_start, 1);
984
985
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
986
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
987
    best_tile_cost[tile_idx] = DBL_MAX;
988
989

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
990
991
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
992
#if CONFIG_HIGHBITDEPTH
993
994
995
996
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
997
#endif  // CONFIG_HIGHBITDEPTH
998
999
1000
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

1001
1002
    type[tile_idx] = RESTORE_WIENER;