/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14
15
16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20
21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26
27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
32
33
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
34
#include "av1/encoder/mathutils.h"
35

36
37
38
// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
// When set to RESTORE_NONE (0) we allow switchable.
const RestorationType force_restore_type = RESTORE_NONE;
39
40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
44
                                      AV1_COMP *cpi, int partial_frame,
45
                                      int plane, RestorationInfo *info,
46
                                      RestorationType *rest_level,
47
48
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
49

50
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51
52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53
54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56
57
                                    int components_pattern) {
  int64_t filt_err = 0;
58
59
60
61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64
65
66
67
68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69
70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71
72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73
74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75
76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79
80
81
82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84
85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88
89
90
  return filt_err;
}

91
92
// Returns the whole-frame sum of squared errors between |src| and |dst|,
// accumulated over every plane whose bit (AOM_PLANE_Y, AOM_PLANE_U,
// AOM_PLANE_V) is set in components_pattern.
// |cm| is only consulted for use_highbitdepth when CONFIG_HIGHBITDEPTH is on.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  int64_t filt_err = 0;
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#else
  (void)cm;
#endif  // CONFIG_HIGHBITDEPTH
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    // Accumulate like every other plane (filt_err is still 0 here, so the
    // result is the same as the plain assignment this replaces).
    filt_err += aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

124
125
// Runs av1_loop_restoration_frame() on cm->frame_to_show with the settings in
// |rsi| (dst_frame is passed as its last argument), then returns the SSE
// between |src| and |dst_frame| restricted to the restoration tile / subtile
// selected by tile_idx, subtile_idx and subtile_bits.
// components_pattern selects the plane(s): 1 = Y, 2 = U, 4 = V, 6 = U and V.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  if (components_pattern == 1) {  // Y only
    width = src->y_crop_width;
    height = src->y_crop_height;
  } else {  // Color
    width = src->uv_crop_width;
    height = src->uv_crop_height;
  }

  // Only the tile geometry outputs are needed; the returned tile count is
  // deliberately discarded. Any chroma pattern uses rst_info[1].
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);

  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
164
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
165
                                     int components_pattern, int partial_frame,
166
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
167
  AV1_COMMON *const cm = &cpi->common;
168
  int64_t filt_err;
169
170
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
171
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
172
173
174
  return filt_err;
}

175
176
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
177
                                    int dat_stride, int use_highbitdepth,
178
179
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
180
181
182
183
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
184
  if (!use_highbitdepth) {
185
186
187
188
189
190
191
192
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
193
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
209
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
210
211
212
213
214
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
215
216
217
218
219
    }
  }
  return err;
}

220
221
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
static int64_t finer_search_pixel_proj_error(
222
    const uint8_t *src8, int width, int height, int src_stride,
223
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
224
    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
225
  int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
226
227
                                     dat_stride, use_highbitdepth, flt1,
                                     flt1_stride, flt2, flt2_stride, xqd);
228
229
230
231
232
233
234
235
236
237
238
239
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int skip = 0;
      do {
        if (xqd[p] - s >= tap_min[p]) {
          xqd[p] -= s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
240
241
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
          if (err2 > err) {
            xqd[p] += s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (xqd[p] + s <= tap_max[p]) {
          xqd[p] += s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
258
259
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
          if (err2 > err) {
            xqd[p] -= s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

276
static void get_proj_subspace(const uint8_t *src8, int width, int height,
277
                              int src_stride, uint8_t *dat8, int dat_stride,
278
279
280
                              int use_highbitdepth, int32_t *flt1,
                              int flt1_stride, int32_t *flt2, int flt2_stride,
                              int *xq) {
281
282
283
284
285
286
287
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

288
289
  aom_clear_system_state();

290
291
292
  // Default
  xq[0] = 0;
  xq[1] = 0;
293
  if (!use_highbitdepth) {
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
343
  xqd[0] = xq[0];
344
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
345
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
346
347
348
349
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
350
                                          int dat_stride, const uint8_t *src8,
351
352
353
                                          int src_stride, int use_highbitdepth,
                                          int bit_depth, int *eps, int *xqd,
                                          int32_t *rstbuf) {
354
  int32_t *flt1 = rstbuf;
355
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
356
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
357
  int ep, bestep = 0;
358
359
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
360

361
362
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
363
#if CONFIG_HIGHBITDEPTH
364
    if (use_highbitdepth) {
365
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
366
367
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
368
                                 sgr_params[ep].corner, sgr_params[ep].edge);
369
#else
370
371
372
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
373
#endif  // USE_HIGHPASS_IN_SGRPROJ
374
375
376
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
377
    } else {
378
#endif
379
380
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
381
                          sgr_params[ep].corner, sgr_params[ep].edge);
382
383
384
385
#else
    av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
#endif  // USE_HIGHPASS_IN_SGRPROJ
386
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
387
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
388
#if CONFIG_HIGHBITDEPTH
389
    }
390
#endif
391
    aom_clear_system_state();
392
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
393
                      use_highbitdepth, flt1, width, flt2, width, exq);
394
    aom_clear_system_state();
395
    encode_xq(exq, exqd);
396
397
398
    err = finer_search_pixel_proj_error(src8, width, height, src_stride, dat8,
                                        dat_stride, bit_depth, flt1, width,
                                        flt2, width, 2, exqd);
399
400
401
402
403
404
405
406
407
408
409
410
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

411
412
413
414
415
416
417
418
419
420
421
422
423
424
// Returns the number of bits needed to signal |sgrproj_info| given the
// reference |ref_sgrproj_info|: SGRPROJ_PARAMS_BITS for the parameter-set
// index plus one aom_count_primitive_refsubexpfin() term per xqd tap, each
// tap being coded relative to the same tap of the reference.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int tap0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int tap1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return SGRPROJ_PARAMS_BITS + tap0_bits + tap1_bits;
}

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
// Per-plane state shared by the restoration search helpers. Filled in by
// init_rest_search_ctxt().
struct rest_search_ctxt {
  const YV12_BUFFER_CONFIG *src;  // Source frame passed to the search.
  AV1_COMP *cpi;                  // Encoder instance.
  uint8_t *dgd_buffer;        // This plane's buffer in cm->frame_to_show.
  const uint8_t *src_buffer;  // This plane's buffer in |src|.
  int dgd_stride;             // Stride of dgd_buffer.
  int src_stride;             // Stride of src_buffer.
  int partial_frame;          // Forwarded unchanged to the try_* helpers.
  RestorationInfo *info;      // Caller's output restoration info.
  RestorationType *type;      // Caller's per-rtile chosen type, by rtile index.
  double *best_tile_cost;     // Caller's per-rtile cost array, by rtile index.
  int plane;                  // AOM_PLANE_Y, AOM_PLANE_U or AOM_PLANE_V.
  int plane_width;            // Crop width of this plane (from |src|).
  int plane_height;           // Crop height of this plane (from |src|).
  int nrtiles_x;              // Horizontal count from av1_get_rest_ntiles().
  int nrtiles_y;              // Vertical count from av1_get_rest_ntiles().
  YV12_BUFFER_CONFIG *dst_frame;  // Frame handed to the try_* helpers.
};

// Fill in ctxt. Returns the number of restoration tiles for this plane
//
// The caller's arguments are copied into the context, and the plane geometry
// (crop width/height, strides, buffer pointers) is taken from |src| and from
// cm->frame_to_show: the Y fields for AOM_PLANE_Y, the UV fields otherwise.
// nrtiles_x / nrtiles_y are filled in by av1_get_rest_ntiles().
static INLINE int init_rest_search_ctxt(
    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
    RestorationInfo *info, RestorationType *type, double *best_tile_cost,
    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
  AV1_COMMON *const cm = &cpi->common;
  ctxt->src = src;
  ctxt->cpi = cpi;
  ctxt->partial_frame = partial_frame;
  ctxt->info = info;
  ctxt->type = type;
  ctxt->best_tile_cost = best_tile_cost;
  ctxt->plane = plane;
  ctxt->dst_frame = dst_frame;

  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  if (plane == AOM_PLANE_Y) {
    ctxt->plane_width = src->y_crop_width;
    ctxt->plane_height = src->y_crop_height;
    ctxt->src_buffer = src->y_buffer;
    ctxt->src_stride = src->y_stride;
    ctxt->dgd_buffer = dgd->y_buffer;
    ctxt->dgd_stride = dgd->y_stride;
    // Source and degraded frame must have the same plane dimensions.
    assert(ctxt->plane_width == dgd->y_crop_width);
    assert(ctxt->plane_height == dgd->y_crop_height);
  } else {
    ctxt->plane_width = src->uv_crop_width;
    ctxt->plane_height = src->uv_crop_height;
    ctxt->src_stride = src->uv_stride;
    ctxt->dgd_stride = dgd->uv_stride;
    ctxt->src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
    ctxt->dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
    // Source and degraded frame must have the same plane dimensions.
    assert(ctxt->plane_width == dgd->uv_crop_width);
    assert(ctxt->plane_height == dgd->uv_crop_height);
  }

  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
                             cm->rst_info[plane].restoration_tilesize, NULL,
                             NULL, &ctxt->nrtiles_x, &ctxt->nrtiles_y);
}
486

487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
// Callback invoked by foreach_rtile_in_tile() once per restoration tile.
// It receives the search context, the restoration tile index, the limits
// produced by av1_get_rest_tile_limits() for that tile, and the opaque |arg|
// the caller handed to foreach_rtile_in_tile().
typedef void (*rtile_visitor_t)(const struct rest_search_ctxt *search_ctxt,
                                int rtile_idx, int h_start, int h_end,
                                int v_start, int v_end, void *arg);

// Calls |fun| for every restoration tile (rtile) attributed to the coding
// tile (tile_row, tile_col), in raster order, passing it the rtile index, the
// limits computed by av1_get_rest_tile_limits() and the opaque |arg|.
// The rtile row/column ranges are clamped to ctxt->nrtiles_y / nrtiles_x.
static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
                                  int tile_row, int tile_col,
                                  rtile_visitor_t fun, void *arg) {
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  // NOTE(review): this reads element 0 of rst_search, so the restoration tile
  // size comes from the first plane whatever ctxt->plane is. Confirm that is
  // intended for the chroma planes.
  const RestorationInfo *rsi = ctxt->cpi->rst_search;

  const int tile_width_y = cm->tile_width * MI_SIZE;
  const int tile_height_y = cm->tile_height * MI_SIZE;

  // Coding tile size for this plane, scaled down by the subsampling factors
  // for the chroma planes.
  const int tile_width =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_width_y, cm->subsampling_x)
                        : tile_width_y;
  const int tile_height =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_height_y, cm->subsampling_y)
                        : tile_height_y;

  const int rtile_size = rsi->restoration_tilesize;
  // NOTE(review): tile_width / tile_height above already include a factor of
  // MI_SIZE, and they are compared directly with rtile_size further down, yet
  // they are multiplied by MI_SIZE again here. This looks like double
  // scaling; verify the units of cm->tile_width before changing it.
  const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
  const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;

  const int rtile_row0 = rtiles_per_tile_y * tile_row;
  const int rtile_row1 =
      AOMMIN(rtile_row0 + rtiles_per_tile_y, ctxt->nrtiles_y);

  const int rtile_col0 = rtiles_per_tile_x * tile_col;
  const int rtile_col1 =
      AOMMIN(rtile_col0 + rtiles_per_tile_x, ctxt->nrtiles_x);

  const int rtile_width = AOMMIN(tile_width, rtile_size);
  const int rtile_height = AOMMIN(tile_height, rtile_size);

  for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
    for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
      int h_start, h_end, v_start, v_end;
      av1_get_rest_tile_limits(rtile_idx, 0, 0, ctxt->nrtiles_x,
                               ctxt->nrtiles_y, rtile_width, rtile_height,
                               ctxt->plane_width, ctxt->plane_height, 0, 0,
                               &h_start, &h_end, &v_start, &v_end);

      fun(ctxt, rtile_idx, h_start, h_end, v_start, v_end, arg);
    }
  }
}

static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
                                     int rtile_idx, int h_start, int h_end,
                                     int v_start, int v_end, void *arg) {
  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  RestorationInfo *rsi = ctxt->cpi->rst_search;
  SgrprojInfo *sgrproj_info = ctxt->info->sgrproj_info;

  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;

  int64_t err = sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, h_start,
                                     h_end - h_start, v_start, v_end - v_start,
                                     (1 << ctxt->plane));
  // #bits when a tile is not restored
  int bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;

  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
  SgrprojInfo *rtile_sgrproj_info = &plane_rsi->sgrproj_info[rtile_idx];
  uint8_t *dgd_start = ctxt->dgd_buffer + v_start * ctxt->dgd_stride + h_start;
  const uint8_t *src_start =
      ctxt->src_buffer + v_start * ctxt->src_stride + h_start;

  search_selfguided_restoration(dgd_start, h_end - h_start, v_end - v_start,
                                ctxt->dgd_stride, src_start, ctxt->src_stride,
562
#if CONFIG_HIGHBITDEPTH
563
                                cm->use_highbitdepth, cm->bit_depth,
564
#else
565
                                0, 8,
566
#endif  // CONFIG_HIGHBITDEPTH
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
                                &rtile_sgrproj_info->ep,
                                rtile_sgrproj_info->xqd,
                                cm->rst_internal.tmpbuf);
  plane_rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
  err = try_restoration_tile(ctxt->src, ctxt->cpi, rsi, (1 << ctxt->plane),
                             ctxt->partial_frame, rtile_idx, 0, 0,
                             ctxt->dst_frame);
  bits =
      count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx], ref_sgrproj_info)
      << AV1_PROB_COST_SHIFT;
  bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  if (cost_sgrproj >= cost_norestore) {
    ctxt->type[rtile_idx] = RESTORE_NONE;
  } else {
    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
    *ref_sgrproj_info = sgrproj_info[rtile_idx] =
        plane_rsi->sgrproj_info[rtile_idx];
    ctxt->best_tile_cost[rtile_idx] = err;
  }
  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
}

// Searches self-guided (SGRPROJ) restoration for one plane and returns the
// frame-level RD cost of using it.
//
// Every restoration tile is first reset to RESTORE_NONE in
// cpi->rst_search[plane]. Each coding tile is then visited with
// search_sgrproj_for_rtile(), starting from a default reference SgrprojInfo
// per coding tile. That fills type[], best_tile_cost[] and
// info->sgrproj_info[]. Finally the signalling bits are summed over all
// restoration tiles, the chosen types and parameters are installed in
// cpi->rst_search[plane], and the whole frame is restored once to measure
// the error used in the returned cost.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane,
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  struct rest_search_ctxt ctxt;
  const int nrtiles =
      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
                            best_tile_cost, dst_frame, &ctxt);

  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
    plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
  }

  // Compute best Sgrproj filters for each rtile, one (encoder/decoder)
  // tile at a time.
  const AV1_COMMON *const cm = &cpi->common;
  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
      SgrprojInfo ref_sgrproj_info;
      set_default_sgrproj(&ref_sgrproj_info);

      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
                            &ref_sgrproj_info);
    }
  }

  // Cost for Sgrproj filtering
  // NOTE(review): a single reference is carried across all rtiles in index
  // order here, whereas the search above restarts from the default reference
  // for every coding tile. Confirm the two are meant to differ.
  SgrprojInfo ref_sgrproj_info;
  set_default_sgrproj(&ref_sgrproj_info);
  SgrprojInfo *sgrproj_info = info->sgrproj_info;

  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
             << AV1_PROB_COST_SHIFT;
  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB,
                         type[rtile_idx] != RESTORE_NONE);
    plane_rsi->sgrproj_info[rtile_idx] = sgrproj_info[rtile_idx];
    if (type[rtile_idx] == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx],
                                 &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      ref_sgrproj_info = plane_rsi->sgrproj_info[rtile_idx];
    }
    plane_rsi->restoration_type[rtile_idx] = type[rtile_idx];
  }
  double err = try_restoration_frame(src, cpi, cpi->rst_search, (1 << plane),
                                     partial_frame, dst_frame);
  double cost_sgrproj = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
  return cost_sgrproj;
}

644
645
// Returns the mean of the 8-bit samples src[i * stride + j] for
// v_start <= i < v_end and h_start <= j < h_end.
// The region must be non-empty: the sample count is used as the divisor.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  aom_clear_system_state();
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

656
657
658
659
static void compute_stats(int wiener_win, const uint8_t *dgd,
                          const uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
660
  int i, j, k, l;
661
  double Y[WIENER_WIN2];
662
663
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
664
665
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
666

667
668
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
669
670
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
671
672
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
673
674
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
675
676
677
678
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
679
      for (k = 0; k < wiener_win2; ++k) {
680
        M[k] += Y[k] * X;
681
682
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
683
684
685
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
686
          H[k * wiener_win2 + l] += Y[k] * Y[l];
687
688
689
690
        }
      }
    }
  }
691
692
693
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
694
695
    }
  }
696
697
}

698
#if CONFIG_HIGHBITDEPTH
699
// Returns the mean of the 16-bit samples src[i * stride + j] for
// v_start <= i < v_end and h_start <= j < h_end.
// The region must be non-empty: the sample count is used as the divisor.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  aom_clear_system_state();
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

711
712
713
714
static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
                                 const uint8_t *src8, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, double *M, double *H) {
715
  int i, j, k, l;
716
  double Y[WIENER_WIN2];
717
718
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
719
720
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
721
722
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
723

724
725
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
726
727
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
728
729
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
730
731
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
732
733
734
735
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
736
      for (k = 0; k < wiener_win2; ++k) {
737
        M[k] += Y[k] * X;
738
739
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
740
741
742
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
743
          H[k * wiener_win2 + l] += Y[k] * Y[l];
744
745
746
747
        }
      }
    }
  }
748
749
750
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
751
752
    }
  }
753
}
754
#endif  // CONFIG_HIGHBITDEPTH
// Folds tap index i of a symmetric wiener_win-tap filter onto the first half
// of the filter (centre tap included): indices past the centre are mirrored
// to wiener_win - 1 - i, all others are returned unchanged.
static INLINE int wrap_index(int i, int wiener_win) {
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
  return (i >= wiener_halfwin1 ? wiener_win - 1 - i : i);
}

// Fix vector b, update vector a
762
763
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
764
  int i, j;
765
  double S[WIENER_WIN];
766
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
767
768
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
769
770
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
771
772
773
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
774
775
776
      A[jj] += Mc[i][j] * b[i];
    }
  }
777
778
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
779
      int k, l;
780
781
782
783
784
785
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
786
787
788
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
789
  // Normalization enforcement in the system of equations itself
790
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
791
    A[i] -=
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
807
    }
808
    memcpy(a, S, wiener_win * sizeof(*a));
809
810
811
812
  }
}

// Fix vector a, update vector b
813
814
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
815
  int i, j;
816
  double S[WIENER_WIN];
817
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
818
819
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
820
821
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
822
823
824
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
825
826
  }

827
828
829
830
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
831
      int k, l;
832
833
834
835
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
836
837
    }
  }
Aamir Anis's avatar
Aamir Anis committed
838
  // Normalization enforcement in the system of equations itself
839
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
840
    A[i] -=
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
856
    }
857
    memcpy(b, S, wiener_win * sizeof(*b));
858
859
860
  }
}

static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
                                    double *a, double *b) {
863
864
865
866
  static const int init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
867
  };
868
869
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
870
871
872
873
874
  int i, j, iter;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
  for (i = 0; i < wiener_win; i++) {
    a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
875
  }
876
877
878
879
880
881
  for (i = 0; i < wiener_win; i++) {
    Mc[i] = M + i * wiener_win;
    for (j = 0; j < wiener_win; j++) {
      Hc[i * wiener_win + j] =
          H + i * wiener_win * wiener_win2 + j * wiener_win;
    }
882
  }
883
884

  iter = 1;
885
  while (iter < NUM_WIENER_ITERS) {
886
887
    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
888
889
    iter++;
  }
890
  return 1;
891
892
}

// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
894
895
// against identity filters; Final score is defined as the difference between
// the function values
896
897
static double compute_score(int wiener_win, double *M, double *H,
                            InterpKernel vfilt, InterpKernel hfilt) {
898
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
899
900
901
902
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
903
  double a[WIENER_WIN], b[WIENER_WIN];
904
905
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
906
907
908

  aom_clear_system_state();

909
910
911
912
913
914
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
915
  }
916
917
918
  for (k = 0; k < wiener_win; ++k) {
    for (l = 0; l < wiener_win; ++l)
      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
Aamir Anis's avatar
Aamir Anis committed
919
  }
920
  for (k = 0; k < wiener_win2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
921
    P += ab[k] * M[k];
922
923
    for (l = 0; l < wiener_win2; ++l)
      Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
924
925
926
  }
  Score = Q - 2 * P;

927
928
  iP = M[wiener_win2 >> 1];
  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
929
930
931
932
933
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

// Quantizes the symmetric real-valued filter f (wiener_win taps) into the
// integer kernel fi, which is always laid out as a WIENER_WIN-tap filter.
//
// The first wiener_win >> 1 taps of f are scaled by WIENER_FILT_STEP and
// rounded with RINT, then clamped to the per-tap [MINV, MAXV] ranges. For a
// window smaller than WIENER_WIN the taps are shifted one position towards
// the centre and fi[0] is set to 0. The upper half of fi is filled by
// mirroring, and the centre tap is set to minus twice the sum of the other
// three.
static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
  int i;
  const int wiener_halfwin = (wiener_win >> 1);
  for (i = 0; i < wiener_halfwin; ++i) {
    fi[i] = RINT(f[i] * WIENER_FILT_STEP);
  }
  // Specialize for 7-tap filter
  if (wiener_win == WIENER_WIN) {
    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  } else {
    // Shift from the highest index down so that each source tap is read
    // before it is overwritten.
    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[0] = 0;
  }
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

// Returns the sum of the aom_count_primitive_refsubexpfin() counts for the
// taps of wiener_info, each coded relative to the same tap of
// ref_wiener_info.
//
// Taps 1 and 2 of vfilter and hfilter are always counted; tap 0 of each is
// counted only when wiener_win == WIENER_WIN. Every value is offset by the
// tap's MINV before counting, and the range passed is MAXV - MINV + 1.
static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  int bits = 0;
  // First filter: vfilter.
  if (wiener_win == WIENER_WIN)
    bits += aom_count_primitive_refsubexpfin(
        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
        WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
      WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1,
      WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);
  // Second filter: hfilter.
  if (wiener_win == WIENER_WIN)
    bits += aom_count_primitive_refsubexpfin(
        WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
        WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
      WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1,
      WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV);
  return bits;
}

#define USE_WIENER_REFINEMENT_SEARCH 1
// Refines the quantized Wiener taps of one tile by a greedy coordinate search
// and returns the lowest error found.
//
// The taps in rsi[plane].wiener_info[tile_idx] are modified in place. The
// starting error is try_restoration_tile() with the taps as given. Then, for
// step sizes s = start_step, start_step / 2, ..., 1, the taps p in
// [plane_off, WIENER_HALFWIN) of hfilter, and afterwards of vfilter, are
// tried at -s and then at +s. Each trial moves tap p and its mirror
// WIENER_WIN - p - 1 by the same amount and the centre tap by twice the
// opposite amount, so the sum of the taps is unchanged. A trial is kept when
// its error is not larger than the current best; otherwise it is undone.
//
// When USE_WIENER_REFINEMENT_SEARCH is 0 only the starting error is computed.
static int64_t finer_tile_search_wiener(const YV12_BUFFER_CONFIG *src,
                                        AV1_COMP *cpi, RestorationInfo *rsi,
                                        int start_step, int plane,
                                        int wiener_win, int tile_idx,
                                        int partial_frame,
                                        YV12_BUFFER_CONFIG *dst_frame) {
  int64_t err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                     tile_idx, 0, 0, dst_frame);
  // Keep these referenced when the refinement search is compiled out.
  (void)start_step;
  (void)wiener_win;
#if USE_WIENER_REFINEMENT_SEARCH
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  int64_t err2;
  // Allowed range of each tap, indexed by tap position p.
  const int tap_min[] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                          WIENER_FILT_TAP2_MINV };
  const int tap_max[] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                          WIENER_FILT_TAP2_MAXV };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      int skip = 0;
      // Try decreasing hfilter tap p. `continue` re-enters the do/while body
      // (its condition is the constant 1); every other path leaves via break.
      do {
        if (rsi[plane].wiener_info[tile_idx].hfilter[p] - s >= tap_min[p]) {
          rsi[plane].wiener_info[tile_idx].hfilter[p] -= s;
          rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - p - 1] -= s;
          rsi[plane].wiener_info[tile_idx].hfilter[WIENER_HALFWIN] += 2 * s;
          err2 = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                      tile_idx, 0, 0, dst_frame);
          if (err2 > err) {
            rsi[plane].wiener_info[tile_idx].hfilter[p] += s;
            rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - p - 1] += s;
            rsi[plane].wiener_info[tile_idx].hfilter[WIENER_HALFWIN] -= 2 * s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      // NOTE(review): this leaves the loop over p, so after an accepted
      // decrease the remaining hfilter taps are not tried at this step size.
      // Kept as in the original; confirm that `break` (not `continue`) is
      // intended.
      if (skip) break;
      // Try increasing hfilter tap p.
      do {
        if (rsi[plane].wiener_info[tile_idx].hfilter[p] + s <= tap_max[p]) {
          rsi[plane].wiener_info[tile_idx].hfilter[p] += s;
          rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - p - 1] += s;
          rsi[plane].wiener_info[tile_idx].hfilter[WIENER_HALFWIN] -= 2 * s;
          err2 = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                      tile_idx, 0, 0, dst_frame);
          if (err2 > err) {
            rsi[plane].wiener_info[tile_idx].hfilter[p] -= s;
            rsi[plane].wiener_info[tile_idx].hfilter[WIENER_WIN - p - 1] -= s;
            rsi[plane].wiener_info[tile_idx].hfilter[WIENER_HALFWIN] += 2 * s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
      int skip = 0;
      // Try decreasing vfilter tap p.
      do {
        if (rsi[plane].wiener_info[tile_idx].vfilter[p] - s >= tap_min[p]) {
          rsi[plane].wiener_info[tile_idx].vfilter[p] -= s;
          rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - p - 1] -= s;
          rsi[plane].wiener_info[tile_idx].vfilter[WIENER_HALFWIN] += 2 * s;
          err2 = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                      tile_idx, 0, 0, dst_frame);
          if (err2 > err) {
            rsi[plane].wiener_info[tile_idx].vfilter[p] += s;
            rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - p - 1] += s;
            rsi[plane].wiener_info[tile_idx].vfilter[WIENER_HALFWIN] -= 2 * s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      // NOTE(review): as for hfilter, this leaves the loop over p after an
      // accepted decrease. Kept as in the original.
      if (skip) break;
      // Try increasing vfilter tap p.
      do {
        if (rsi[plane].wiener_info[tile_idx].vfilter[p] + s <= tap_max[p]) {
          rsi[plane].wiener_info[tile_idx].vfilter[p] += s;
          rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - p - 1] += s;
          rsi[plane].wiener_info[tile_idx].vfilter[WIENER_HALFWIN] -= 2 * s;
          err2 = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                                      tile_idx, 0, 0, dst_frame);
          if (err2 > err) {
            rsi[plane].wiener_info[tile_idx].vfilter[p] -= s;
            rsi[plane].wiener_info[tile_idx].vfilter[WIENER_WIN - p - 1] -= s;
            rsi[plane].wiener_info[tile_idx].vfilter[WIENER_HALFWIN] += 2 * s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_WIENER_REFINEMENT_SEARCH
  return err;
}

static void search_wiener_for_rtile(const struct rest_search_ctxt *ctxt,
                                    int rtile_idx, int h_start, int h_end,
                                    int v_start, int v_end, void *arg) {
  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  RestorationInfo *rsi = ctxt->cpi->rst_search;

  const int wiener_win =
      (ctxt->plane == AOM_PLANE_Y) ? WIENER_WIN : WIENER_WIN_CHROMA;

1117
1118
1119
1120
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];

1121
  WienerInfo *ref_wiener_info = (WienerInfo *)arg;
1122

1123
1124
1125
1126
1127
1128
1129
  int64_t err = sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, h_start,
                                     h_end - h_start, v_start, v_end - v_start,
                                     (1 << ctxt->plane));
  // #bits when a tile is not restored
  int bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;
1130

1131
1132
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth)
1133
1134
1135
    compute_stats_highbd(wiener_win, ctxt->dgd_buffer, ctxt->src_buffer,
                         h_start, h_end, v_start, v_end, ctxt->dgd_stride,
                         ctxt->src_stride, M, H);
1136
  else