pickrst.c 53 KB
Newer Older
1
/*
2
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3
 *
4
5
6
7
8
9
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10
11
12
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20
21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26
27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
32
33
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
34
#include "av1/encoder/mathutils.h"
35

36
37
38
// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
// When set to RESTORE_NONE (0) we allow switchable.
const RestorationType force_restore_type = RESTORE_NONE;
39
40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
44
                                      AV1_COMP *cpi, int partial_frame,
45
                                      int plane, RestorationInfo *info,
46
                                      RestorationType *rest_level,
47
48
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
49

50
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51
52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53
54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56
57
                                    int components_pattern) {
  int64_t filt_err = 0;
58
59
60
61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64
65
66
67
68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69
70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71
72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73
74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75
76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79
80
81
82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84
85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88
89
90
  return filt_err;
}

91
92
// Accumulates the whole-plane sum of squared errors between src and dst for
// every plane whose bit (AOM_PLANE_Y / AOM_PLANE_U / AOM_PLANE_V) is set in
// components_pattern.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int has_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int has_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int has_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t total = 0;
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (has_y) total += aom_highbd_get_y_sse(src, dst);
    if (has_u) total += aom_highbd_get_u_sse(src, dst);
    if (has_v) total += aom_highbd_get_v_sse(src, dst);
    return total;
  }
#else
  (void)cm;  // unused without high bit depth support
#endif  // CONFIG_HIGHBITDEPTH
  if (has_y) total += aom_get_y_sse(src, dst);
  if (has_u) total += aom_get_u_sse(src, dst);
  if (has_v) total += aom_get_v_sse(src, dst);
  return total;
}

124
125
// Runs av1_loop_restoration_frame() on cm->frame_to_show with the settings in
// rsi, writing into dst_frame, and returns the squared error against src
// measured only over the tile (or subtile) selected by tile_idx, subtile_idx
// and subtile_bits.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int luma_only = (components_pattern == 1);
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Luma uses the Y crop size; any chroma pattern uses the UV crop size.
  const int width = luma_only ? src->y_crop_width : src->uv_crop_width;
  const int height = luma_only ? src->y_crop_height : src->uv_crop_height;

  // Only the tile geometry is needed here; the returned tile count is unused.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
164
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
165
                                     int components_pattern, int partial_frame,
166
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
167
  AV1_COMMON *const cm = &cpi->common;
168
  int64_t filt_err;
169
170
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
171
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
172
173
174
  return filt_err;
}

175
176
177
178
179
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
180
181
182
183
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
184
185
186
187
188
189
190
191
192
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
193
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
209
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
210
211
212
213
214
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
215
216
217
218
219
    }
  }
  return err;
}

220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
// Evaluates get_pixel_proj_error() for xqd and, when
// USE_SGRPROJ_REFINEMENT_SEARCH is set, refines xqd in place with a
// coordinate search whose step starts at start_step and halves down to 1.
// A move is kept when it does not increase the error, and each coefficient
// stays within its [SGRPROJ_PRJ_MINn, SGRPROJ_PRJ_MAXn] range. Returns the
// error of the final xqd.
static int64_t finer_search_pixel_proj_error(
    uint8_t *src8, int width, int height, int src_stride, uint8_t *dat8,
    int dat_stride, int bit_depth, int32_t *flt1, int flt1_stride,
    int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
  int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                     dat_stride, bit_depth, flt1, flt1_stride,
                                     flt2, flt2_stride, xqd);
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  const int tap_min[2] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  const int tap_max[2] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int step = start_step; step >= 1; step >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int moved_down = 0;
      // Try decreasing coefficient p.
      for (;;) {
        if (xqd[p] - step < tap_min[p]) break;
        xqd[p] -= step;
        const int64_t trial = get_pixel_proj_error(
            src8, width, height, src_stride, dat8, dat_stride, bit_depth, flt1,
            flt1_stride, flt2, flt2_stride, xqd);
        if (trial > err) {
          xqd[p] += step;  // worse: undo the move
          break;
        }
        err = trial;
        moved_down = 1;
        // At the highest step size continue moving in the same direction
        if (step != start_step) break;
      }
      // NOTE(review): this exits the loop over p, so when xqd[0] improved
      // downwards xqd[1] is not examined at this step size - confirm that
      // this is intended rather than only skipping the upward pass.
      if (moved_down) break;
      // Try increasing coefficient p.
      for (;;) {
        if (xqd[p] + step > tap_max[p]) break;
        xqd[p] += step;
        const int64_t trial = get_pixel_proj_error(
            src8, width, height, src_stride, dat8, dat_stride, bit_depth, flt1,
            flt1_stride, flt2, flt2_stride, xqd);
        if (trial > err) {
          xqd[p] -= step;  // worse: undo the move
          break;
        }
        err = trial;
        // At the highest step size continue moving in the same direction
        if (step != start_step) break;
      }
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

276
277
278
279
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
280
281
282
283
284
285
286
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

287
288
  aom_clear_system_state();

289
290
291
  // Default
  xq[0] = 0;
  xq[1] = 0;
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
342
  xqd[0] = xq[0];
343
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
344
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
345
346
347
348
349
350
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
351
352
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
353
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
354
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
355
  int ep, bestep = 0;
356
357
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
358

359
360
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
361
#if CONFIG_HIGHBITDEPTH
362
363
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
364
365
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
366
                                 sgr_params[ep].corner, sgr_params[ep].edge);
367
#else
368
369
370
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
371
#endif  // USE_HIGHPASS_IN_SGRPROJ
372
373
374
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
375
    } else {
376
#endif
377
378
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
379
                          sgr_params[ep].corner, sgr_params[ep].edge);
380
381
382
383
#else
    av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
#endif  // USE_HIGHPASS_IN_SGRPROJ
384
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
385
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
386
#if CONFIG_HIGHBITDEPTH
387
    }
388
#endif
389
    aom_clear_system_state();
390
391
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
392
    aom_clear_system_state();
393
    encode_xq(exq, exqd);
394
395
396
    err = finer_search_pixel_proj_error(src8, width, height, src_stride, dat8,
                                        dat_stride, bit_depth, flt1, width,
                                        flt2, width, 2, exqd);
397
398
399
400
401
402
403
404
405
406
407
408
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

409
410
411
412
413
414
415
416
417
418
419
420
421
422
// Returns the number of bits counted for sgrproj_info: SGRPROJ_PARAMS_BITS
// plus the aom_count_primitive_refsubexpfin() count of each xqd coefficient
// relative to the matching coefficient of ref_sgrproj_info.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int xqd0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int xqd1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return SGRPROJ_PARAMS_BITS + xqd0_bits + xqd1_bits;
}

423
424
425
426
427
// Searches, tile by tile, whether the self-guided projection filter should be
// applied to the given plane of cm->frame_to_show.
//
// For every tile the best parameters are found with
// search_selfguided_restoration(), the tile is restored on its own through
// try_restoration_tile(), and its rate-distortion cost is compared with the
// cost of leaving the tile untouched.
//
// Outputs:
//   type[tile_idx]           RESTORE_SGRPROJ or RESTORE_NONE per tile.
//   info->sgrproj_info[]     parameters of the tiles that chose the filter.
//   best_tile_cost[tile_idx] squared error returned by try_restoration_tile()
//                            for tiles that chose the filter, DBL_MAX
//                            otherwise.
// Returns the RDCOST_DBL cost of the frame restored with the per-tile
// decisions, including the frame-level signalling bits.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane,
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int width, height, src_stride, dgd_stride;
  uint8_t *dgd_buffer, *src_buffer;
  if (plane == AOM_PLANE_Y) {
    width = src->y_crop_width;
    height = src->y_crop_height;
    src_buffer = src->y_buffer;
    src_stride = src->y_stride;
    dgd_buffer = dgd->y_buffer;
    dgd_stride = dgd->y_stride;
    assert(width == dgd->y_crop_width);
    assert(height == dgd->y_crop_height);
    assert(width == src->y_crop_width);
    assert(height == src->y_crop_height);
  } else {
    width = src->uv_crop_width;
    height = src->uv_crop_height;
    src_stride = src->uv_stride;
    dgd_stride = dgd->uv_stride;
    src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
    dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
    assert(width == dgd->uv_crop_width);
    assert(height == dgd->uv_crop_height);
  }
  // Use the same rst_info entry as try_restoration_tile() does for this plane
  // (index 0 for luma, 1 for chroma), so that a tile_idx describes the same
  // rectangle both in this function and in the restoration trial.
  const int ntiles = av1_get_rest_ntiles(
      width, height, cm->rst_info[(1 << plane) > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);
  SgrprojInfo ref_sgrproj_info;
  set_default_sgrproj(&ref_sgrproj_info);

  rsi[plane].frame_restoration_type = RESTORE_SGRPROJ;

  // Start with every tile disabled; each trial below enables one tile only.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               (1 << plane));
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd_buffer + v_start * dgd_stride + h_start, h_end - h_start,
        v_end - v_start, dgd_stride,
        src_buffer + v_start * src_stride + h_start, src_stride,
#if CONFIG_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_HIGHBITDEPTH
        &rsi[plane].sgrproj_info[tile_idx].ep,
        rsi[plane].sgrproj_info[tile_idx].xqd, cm->rst_internal.tmpbuf);
    rsi[plane].restoration_type[tile_idx] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, (1 << plane), partial_frame,
                               tile_idx, 0, 0, dst_frame);
    // Parameters are costed relative to the most recently accepted tile.
    bits = count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
                              &ref_sgrproj_info)
           << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_SGRPROJ;
      memcpy(&sgrproj_info[tile_idx], &rsi[plane].sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      memcpy(&ref_sgrproj_info, &sgrproj_info[tile_idx],
             sizeof(ref_sgrproj_info));
      best_tile_cost[tile_idx] = err;
    }
    // Disable the tile again so that the next trial is not affected by it.
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Sgrproj filtering
  set_default_sgrproj(&ref_sgrproj_info);
  bits = frame_level_restore_bits[rsi[plane].frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (type[tile_idx] == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(&rsi[plane].sgrproj_info[tile_idx],
                                 &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      memcpy(&ref_sgrproj_info, &rsi[plane].sgrproj_info[tile_idx],
             sizeof(ref_sgrproj_info));
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
  }
  // Evaluate the whole frame with the final per-tile decisions applied.
  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
                              dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);

  return cost_sgrproj;
}

539
540
// Returns the mean of the 8-bit samples of src inside the rectangle
// [h_start, h_end) x [v_start, v_end); stride is the distance between rows.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t total = 0;
  aom_clear_system_state();
  for (int row = v_start; row < v_end; row++) {
    const int row_off = row * stride;
    for (int col = h_start; col < h_end; col++) total += src[row_off + col];
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

551
552
553
554
static void compute_stats(int wiener_win, uint8_t *dgd, uint8_t *src,
                          int h_start, int h_end, int v_start, int v_end,
                          int dgd_stride, int src_stride, double *M,
                          double *H) {
555
  int i, j, k, l;
556
  double Y[WIENER_WIN2];
557
558
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
559
560
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
561

562
563
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
564
565
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
566
567
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
568
569
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
570
571
572
573
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
574
      for (k = 0; k < wiener_win2; ++k) {
575
        M[k] += Y[k] * X;
576
577
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
578
579
580
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
581
          H[k * wiener_win2 + l] += Y[k] * Y[l];
582
583
584
585
        }
      }
    }
  }
586
587
588
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
589
590
    }
  }
591
592
}

593
#if CONFIG_HIGHBITDEPTH
594
595
// Returns the mean of the 16-bit samples of src inside the rectangle
// [h_start, h_end) x [v_start, v_end); stride is the distance between rows.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t total = 0;
  aom_clear_system_state();
  for (int row = v_start; row < v_end; row++) {
    const int row_off = row * stride;
    for (int col = h_start; col < h_end; col++) total += src[row_off + col];
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

606
607
static void compute_stats_highbd(int wiener_win, uint8_t *dgd8, uint8_t *src8,
                                 int h_start, int h_end, int v_start, int v_end,
608
609
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
610
  int i, j, k, l;
611
  double Y[WIENER_WIN2];
612
613
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
614
615
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
616
617
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
618

619
620
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
621
622
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
623
624
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
625
626
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
627
628
629
630
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
631
      for (k = 0; k < wiener_win2; ++k) {
632
        M[k] += Y[k] * X;
633
634
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
635
636
637
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
638
          H[k * wiener_win2 + l] += Y[k] * Y[l];
639
640
641
642
        }
      }
    }
  }
643
644
645
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
646
647
    }
  }
648
}
649
#endif  // CONFIG_HIGHBITDEPTH
650

651
652
653
// Maps tap index i of a symmetric wiener_win-tap filter onto the index of the
// matching tap in its first half: indices up to and including the centre are
// returned unchanged, later ones are mirrored to wiener_win - 1 - i.
static INLINE int wrap_index(int i, int wiener_win) {
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
  if (i < wiener_halfwin1) return i;
  return wiener_win - 1 - i;
}

// Fix vector b, update vector a
657
658
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
659
  int i, j;
660
  double S[WIENER_WIN];
661
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
662
663
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
664
665
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
666
667
668
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
669
670
671
      A[jj] += Mc[i][j] * b[i];
    }
  }
672
673
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
674
      int k, l;
675
676
677
678
679
680
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
681
682
683
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
684
  // Normalization enforcement in the system of equations itself
685
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
686
    A[i] -=
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
702
    }
703
    memcpy(a, S, wiener_win * sizeof(*a));
704
705
706
707
  }
}

// Fix vector a, update vector b
708
709
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
710
  int i, j;
711
  double S[WIENER_WIN];
712
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
713
714
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
715
716
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
717
718
719
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
720
721
  }

722
723
724
725
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
726
      int k, l;
727
728
729
730
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
731
732
    }
  }
Aamir Anis's avatar
Aamir Anis committed
733
  // Normalization enforcement in the system of equations itself
734
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
735
    A[i] -=
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
751
    }
752
    memcpy(b, S, wiener_win * sizeof(*b));
753
754
755
  }
}

756
757
static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
                                    double *a, double *b) {
758
759
760
761
  static const int init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
762
  };
763
764
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
765
766
767
768
769
  int i, j, iter;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
  for (i = 0; i < wiener_win; i++) {
    a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
770
  }
771
772
773
774
775
776
  for (i = 0; i < wiener_win; i++) {
    Mc[i] = M + i * wiener_win;
    for (j = 0; j < wiener_win; j++) {
      Hc[i * wiener_win + j] =
          H + i * wiener_win * wiener_win2 + j * wiener_win;
    }
777
  }
778
779

  iter = 1;
780
  while (iter < NUM_WIENER_ITERS) {
781
782
    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
783
784
    iter++;
  }
785
  return 1;
786
787
}

788
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
789
790
// against identity filters; Final score is defined as the difference between
// the function values
791
792
static double compute_score(int wiener_win, double *M, double *H,
                            InterpKernel vfilt, InterpKernel hfilt) {
793
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
794
795
796
797
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
798
  double a[WIENER_WIN], b[WIENER_WIN];
799
800
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
801
802
803

  aom_clear_system_state();

804
805
806
807
808
809
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
810
  }
811
812
813
  for (k = 0; k < wiener_win; ++k) {
    for (l = 0; l < wiener_win; ++l)
      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
Aamir Anis's avatar
Aamir Anis committed
814
  }
815
  for (k = 0; k < wiener_win2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
816
    P += ab[k] * M[k];
817
818
    for (l = 0; l < wiener_win2; ++l)
      Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
819
820
821
  }
  Score = Q - 2 * P;

822
823
  iP = M[wiener_win2 >> 1];
  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
824
825
826
827
828
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

829
// Quantizes the first half of the real-valued symmetric filter f into the
// integer kernel fi.
//
// The leading wiener_win / 2 taps are scaled by WIENER_FILT_STEP and rounded.
// For a full WIENER_WIN-tap filter they are clamped in place to the per-tap
// ranges. For a shorter window the two rounded taps are shifted one position
// towards the centre, clamped to the tap-2 / tap-1 ranges, and the outermost
// tap is set to zero. The kernel is then mirrored, and the centre tap is set
// so that all taps sum to zero (the centre carries an implicit
// +WIENER_FILT_STEP).
static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
  const int wiener_halfwin = (wiener_win >> 1);

  for (int tap = 0; tap < wiener_halfwin; ++tap) {
    fi[tap] = RINT(f[tap] * WIENER_FILT_STEP);
  }

  if (wiener_win != WIENER_WIN) {
    // Shorter window: move the taps inwards before clamping. fi[2] must be
    // written before fi[1], and fi[1] before fi[0], since each reads the slot
    // below it.
    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[0] = 0;
  } else {
    // Full-size (7-tap) window: clamp each tap to its own range.
    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  }

  // Enforce symmetry by mirroring the leading taps.
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];

  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

853
// Returns the number of bits counted by aom_count_primitive_refsubexpfin()
// for the taps of wiener_info, each measured against the corresponding tap of
// ref_wiener_info.
//
// Taps 1 and 2 of both the vertical and horizontal filters are always
// counted; tap 0 is counted only when wiener_win == WIENER_WIN. Values are
// offset by the tap's minimum so they are passed as non-negative indices into
// a range of (max - min + 1) symbols.
static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  const int has_tap0 = (wiener_win == WIENER_WIN);
  const int tap0_range = WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1;
  const int tap1_range = WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1;
  const int tap2_range = WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1;
  int bits = 0;

  // Vertical filter.
  if (has_tap0) {
    bits += aom_count_primitive_refsubexpfin(
        tap0_range, WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
  }
  bits += aom_count_primitive_refsubexpfin(
      tap1_range, WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      tap2_range, WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);

  // Horizontal filter.
  if (has_tap0) {
    bits += aom_count_primitive_refsubexpfin(
        tap0_range, WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
  }
  bits += aom_count_primitive_refsubexpfin(
      tap1_range, WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      tap2_range, WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV);

  return bits;
}

891
892
893
#define USE_WIENER_REFINEMENT_SEARCH 1
// Evaluates the Wiener filter currently stored for tile_idx of the given
// plane and, when USE_WIENER_REFINEMENT_SEARCH is enabled, refines its integer
// taps by a coordinate search.
//
// The step size starts at start_step and is halved down to 1. At each step
// size the horizontal taps are scanned first, then the vertical taps. For a
// tap p, the search first tries lowering it by the step, then raising it. A
// move changes tap p and its mirror by the step and the centre tap by twice
// the step in the opposite direction, so the tap sum is unchanged. A move is
// kept when try_restoration_tile() reports an error that is not larger than
// the best so far; otherwise it is undone. At the largest step size a kept
// move is repeated in the same direction until it stops helping or the tap
// limit is reached.
//
// Returns the error of the filter left in rsi[plane].wiener_info[tile_idx].
static int64_t finer_tile_search_wiener(const YV12_BUFFER_CONFIG *src,
                                        AV1_COMP *cpi, RestorationInfo *rsi,
                                        int start_step, int plane,
                                        int wiener_win, int tile_idx,
                                        int partial_frame,
                                        YV12_BUFFER_CONFIG *dst_frame) {
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  int64_t err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                     tile_idx, 0, 0, dst_frame);
  (void)start_step;
#if USE_WIENER_REFINEMENT_SEARCH
  const int tap_min[] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                          WIENER_FILT_TAP2_MINV };
  const int tap_max[] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                          WIENER_FILT_TAP2_MAXV };
  WienerInfo *const tile_wiener = &rsi[plane].wiener_info[tile_idx];

  for (int s = start_step; s >= 1; s >>= 1) {
    // Horizontal filter taps.
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      const int mirror = WIENER_WIN - p - 1;
      int lowered = 0;

      // Try lowering tap p.
      for (;;) {
        if (tile_wiener->hfilter[p] - s < tap_min[p]) break;
        tile_wiener->hfilter[p] -= s;
        tile_wiener->hfilter[mirror] -= s;
        tile_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
        const int64_t trial =
            try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                 tile_idx, 0, 0, dst_frame);
        if (trial > err) {
          // Worse: undo the move and stop.
          tile_wiener->hfilter[p] += s;
          tile_wiener->hfilter[mirror] += s;
          tile_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
          break;
        }
        err = trial;
        lowered = 1;
        // Only at the highest step size keep moving in the same direction.
        if (s != start_step) break;
      }
      // NOTE(review): this leaves the tap loop for this step size entirely,
      // not just the "raise" attempt for tap p - confirm this is intended.
      if (lowered) break;

      // Try raising tap p.
      for (;;) {
        if (tile_wiener->hfilter[p] + s > tap_max[p]) break;
        tile_wiener->hfilter[p] += s;
        tile_wiener->hfilter[mirror] += s;
        tile_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
        const int64_t trial =
            try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                 tile_idx, 0, 0, dst_frame);
        if (trial > err) {
          // Worse: undo the move and stop.
          tile_wiener->hfilter[p] -= s;
          tile_wiener->hfilter[mirror] -= s;
          tile_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
          break;
        }
        err = trial;
        // Only at the highest step size keep moving in the same direction.
        if (s != start_step) break;
      }
    }

    // Vertical filter taps.
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      const int mirror = WIENER_WIN - p - 1;
      int lowered = 0;

      // Try lowering tap p.
      for (;;) {
        if (tile_wiener->vfilter[p] - s < tap_min[p]) break;
        tile_wiener->vfilter[p] -= s;
        tile_wiener->vfilter[mirror] -= s;
        tile_wiener->vfilter[WIENER_HALFWIN] += 2 * s;
        const int64_t trial =
            try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                 tile_idx, 0, 0, dst_frame);
        if (trial > err) {
          // Worse: undo the move and stop.
          tile_wiener->vfilter[p] += s;
          tile_wiener->vfilter[mirror] += s;
          tile_wiener->vfilter[WIENER_HALFWIN] -= 2 * s;
          break;
        }
        err = trial;
        lowered = 1;
        // Only at the highest step size keep moving in the same direction.
        if (s != start_step) break;
      }
      // NOTE(review): as above, this ends the scan over the remaining taps.
      if (lowered) break;

      // Try raising tap p.
      for (;;) {
        if (tile_wiener->vfilter[p] + s > tap_max[p]) break;
        tile_wiener->vfilter[p] += s;
        tile_wiener->vfilter[mirror] += s;
        tile_wiener->vfilter[WIENER_HALFWIN] -= 2 * s;
        const int64_t trial =
            try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                 tile_idx, 0, 0, dst_frame);
        if (trial > err) {
          // Worse: undo the move and stop.
          tile_wiener->vfilter[p] -= s;
          tile_wiener->vfilter[mirror] -= s;
          tile_wiener->vfilter[WIENER_HALFWIN] += 2 * s;
          break;
        }
        err = trial;
        // Only at the highest step size keep moving in the same direction.
        if (s != start_step) break;
      }
    }
  }
#endif  // USE_WIENER_REFINEMENT_SEARCH
  return err;
}

1002
1003
1004
1005
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int partial_frame, int plane, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
1006
1007
1008
1009
1010
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
1011
  double cost_wiener, cost_norestore;
1012
1013
1014
1015
1016
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
1017
  int width, height, src_stride, dgd_stride, wiener_win;
1018
1019
  uint8_t *dgd_buffer, *src_buffer;
  if (plane == AOM_PLANE_Y) {
1020
1021
    width = src->y_crop_width;
    height = src->y_crop_height;
1022
1023
1024
1025
1026
1027
1028
1029
    src_buffer = src->y_buffer;
    src_stride = src->y_stride;
    dgd_buffer = dgd->y_buffer;
    dgd_stride = dgd->y_stride;
    assert(width == dgd->y_crop_width);
    assert(height == dgd->y_crop_height);
    assert(width == src->y_crop_width);
    assert(height == src->y_crop_height);
1030
    wiener_win = WIENER_WIN;
1031
1032
1033
1034
1035
1036
1037
1038
1039
  } else {
    width = src->uv_crop_width;
    height = src->uv_crop_height;
    src_stride = src->uv_stride;
    dgd_stride = dgd->uv_stride;
    src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
    dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
    assert(width == dgd->uv_crop_width);
    assert(height == dgd->uv_crop_height);
1040
    wiener_win = WIENER_WIN_CHROMA;
1041
  }
1042
1043
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
1044
  int h_start, h_end, v_start, v_end;
1045
1046
1047
  const int ntiles = av1_get_rest_ntiles(
      width, height, cm->rst_info[plane].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);
1048
1049
  WienerInfo ref_wiener_info;
  set_default_wiener(&ref_wiener_info);
1050
1051

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
1052

1053
1054
1055
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
1056

1057
1058
1059
// Construct a (WIENER_HALFWIN)-pixel border around the frame
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth)
1060
1061
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd_buffer), width, height,
                        dgd_stride);
1062
1063
  else
#endif
1064
    extend_frame(dgd_buffer, width, height, dgd_stride);
1065
1066
1067
1068
1069
1070
1071
1072

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
1073
                               (1 << plane));
1074
1075
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
1076
    cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
1077
    best_tile_cost[tile_idx] = DBL_MAX;
1078
1079

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
1080
1081
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
1082
#if CONFIG_HIGHBITDEPTH
1083
    if (cm->use_highbitdepth)
1084
1085
      compute_stats_highbd(wiener_win, dgd_buffer, src_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
1086
    else
1087
#endif  // CONFIG_HIGHBITDEPTH
1088
1089
      compute_stats(wiener_win, dgd_buffer, src_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);
1090
1091
1092

    type[tile_idx] = RESTORE_WIENER;

1093
    if (!wiener_decompose_sep_sym(wiener_win, M, H, vfilterd, hfilterd)) {
1094
1095
      type[tile_idx] = RESTORE_NONE;
      continue;
1096
    }
1097
1098