/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <float.h>
#include <limits.h>
#include <math.h>

#include "./aom_scale_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
#include "aom_mem/aom_mem.h"
#include "aom_ports/mem.h"
#include "aom_ports/system_state.h"

#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
#include "av1/common/restoration.h"

#include "av1/encoder/av1_quantize.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/mathutils.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"

// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
37
// When set to RESTORE_TYPES we allow switchable.
38
static const RestorationType force_restore_type = RESTORE_TYPES;
39
40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
44
                                      AV1_COMP *cpi, int partial_frame,
45
                                      int plane, RestorationInfo *info,
46
                                      RestorationType *rest_level,
47
48
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
49

50
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51
52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53
54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56
57
                                    int components_pattern) {
  int64_t filt_err = 0;
58
59
60
61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64
65
66
67
68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69
70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71
72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73
74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75
76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79
80
81
82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84
85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88
89
90
  return filt_err;
}

91
92
// Returns the whole-frame sum of squared errors between `src` and `dst`,
// accumulated over every plane selected in `components_pattern` (a bitmask
// indexed by AOM_PLANE_Y / AOM_PLANE_U / AOM_PLANE_V).
// cm is only read when CONFIG_HIGHBITDEPTH is enabled.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  int64_t filt_err = 0;
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#else
  (void)cm;
#endif  // CONFIG_HIGHBITDEPTH
  // Every plane accumulates with += so the result does not depend on the
  // order in which the planes are tested.
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

124
125
// Runs loop restoration with the settings in `rsi` on cm->frame_to_show,
// writing the result to `dst_frame`, and returns the SSE against `src`
// measured only over restoration tile `tile_idx`.
//
// components_pattern is a plane bitmask: 1 selects luma, any other allowed
// value selects chroma dimensions and the chroma entry of cm->rst_info.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int tile_width, tile_height, nhtiles, nvtiles;
  int width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  if (components_pattern == 1) {  // Y only
    width = src->y_crop_width;
    height = src->y_crop_height;
  } else {  // Color
    width = src->uv_crop_width;
    height = src->uv_crop_height;
  }

  // Only the tile geometry outputs are needed here; the returned tile count
  // is intentionally discarded.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  RestorationTileLimits limits = av1_get_rest_tile_limits(
      tile_idx, nhtiles, nvtiles, tile_width, tile_height, width,
#if CONFIG_STRIPED_LOOP_RESTORATION
      height, components_pattern > 1 ? cm->subsampling_y : 0);
#else
      height);
#endif

  return sse_restoration_tile(
      src, dst_frame, cm, limits.h_start, limits.h_end - limits.h_start,
      limits.v_start, limits.v_end - limits.v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
167
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
168
                                     int components_pattern, int partial_frame,
169
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
170
  AV1_COMMON *const cm = &cpi->common;
171
  int64_t filt_err;
172
173
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
174
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
175
176
177
  return filt_err;
}

178
179
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
180
                                    int dat_stride, int use_highbitdepth,
181
182
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
183
184
185
186
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
187
  if (!use_highbitdepth) {
188
189
190
191
192
193
194
195
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
196
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
212
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
213
214
215
216
217
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
218
219
220
221
222
    }
  }
  return err;
}

223
224
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
static int64_t finer_search_pixel_proj_error(
225
    const uint8_t *src8, int width, int height, int src_stride,
226
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
227
    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
228
  int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
229
230
                                     dat_stride, use_highbitdepth, flt1,
                                     flt1_stride, flt2, flt2_stride, xqd);
231
232
233
234
235
236
237
238
239
240
241
242
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int skip = 0;
      do {
        if (xqd[p] - s >= tap_min[p]) {
          xqd[p] -= s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
243
244
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
          if (err2 > err) {
            xqd[p] += s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (xqd[p] + s <= tap_max[p]) {
          xqd[p] += s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
261
262
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
          if (err2 > err) {
            xqd[p] -= s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

279
static void get_proj_subspace(const uint8_t *src8, int width, int height,
280
                              int src_stride, uint8_t *dat8, int dat_stride,
281
282
283
                              int use_highbitdepth, int32_t *flt1,
                              int flt1_stride, int32_t *flt2, int flt2_stride,
                              int *xq) {
284
285
286
287
288
289
290
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

291
292
  aom_clear_system_state();

293
294
295
  // Default
  xq[0] = 0;
  xq[1] = 0;
296
  if (!use_highbitdepth) {
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
346
  xqd[0] = xq[0];
347
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
348
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
349
350
351
352
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
353
                                          int dat_stride, const uint8_t *src8,
354
                                          int src_stride, int use_highbitdepth,
355
356
                                          int bit_depth, int pu_width,
                                          int pu_height, int *eps, int *xqd,
357
                                          int32_t *rstbuf) {
358
  int32_t *flt1 = rstbuf;
359
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
360
  int ep, bestep = 0;
361
362
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
363
364
  int flt1_stride = ((width + 7) & ~7) + 8;
  int flt2_stride = ((width + 7) & ~7) + 8;
365
366
367
368
  assert(pu_width == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_width == RESTORATION_PROC_UNIT_SIZE);
  assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_height == RESTORATION_PROC_UNIT_SIZE);
369
370
371
#if !CONFIG_HIGHBITDEPTH
  (void)bit_depth;
#endif
372

373
374
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
375
#if CONFIG_HIGHBITDEPTH
376
    if (use_highbitdepth) {
377
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
378
379
380
381
382
383
384
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint16_t *dat_p = dat + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
385
#if USE_HIGHPASS_IN_SGRPROJ
386
387
388
          av1_highpass_filter_highbd(dat_p, w, h, dat_stride, flt1_p,
                                     flt1_stride, sgr_params[ep].corner,
                                     sgr_params[ep].edge);
389
#else
390
          av1_selfguided_restoration_highbd(
391
              dat_p, w, h, dat_stride, flt1_p, flt1_stride, bit_depth,
392
              sgr_params[ep].r1, sgr_params[ep].e1);
393
#endif  // USE_HIGHPASS_IN_SGRPROJ
394
          av1_selfguided_restoration_highbd(
395
              dat_p, w, h, dat_stride, flt2_p, flt2_stride, bit_depth,
396
              sgr_params[ep].r2, sgr_params[ep].e2);
397
        }
398
    } else {
399
#endif
400
401
402
403
404
405
406
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint8_t *dat_p = dat8 + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
407
#if USE_HIGHPASS_IN_SGRPROJ
408
409
          av1_highpass_filter(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
                              sgr_params[ep].corner, sgr_params[ep].edge);
410
#else
411
        av1_selfguided_restoration(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
412
                                   sgr_params[ep].r1, sgr_params[ep].e1);
413
#endif  // USE_HIGHPASS_IN_SGRPROJ
414
415
          av1_selfguided_restoration(dat_p, w, h, dat_stride, flt2_p,
                                     flt2_stride, sgr_params[ep].r2,
416
                                     sgr_params[ep].e2);
417
        }
418
#if CONFIG_HIGHBITDEPTH
419
    }
420
#endif
421
    aom_clear_system_state();
422
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
423
424
                      use_highbitdepth, flt1, flt1_stride, flt2, flt2_stride,
                      exq);
425
    aom_clear_system_state();
426
    encode_xq(exq, exqd);
427
428
429
    err = finer_search_pixel_proj_error(
        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
        flt1, flt1_stride, flt2, flt2_stride, 2, exqd);
430
431
432
433
434
435
436
437
438
439
440
441
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

442
443
444
445
446
447
448
449
450
451
452
453
454
455
// Returns the number of bits needed to signal `sgrproj_info` when it is coded
// relative to `ref_sgrproj_info`: SGRPROJ_PARAMS_BITS for the parameter-set
// index plus the reference-based sub-exponential cost of each of the two xqd
// coefficients, each offset so that its minimum maps to zero.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int xqd0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int xqd1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return SGRPROJ_PARAMS_BITS + xqd0_bits + xqd1_bits;
}

456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
// Per-plane state shared by the restoration search routines. Filled in by
// init_rest_search_ctxt().
struct rest_search_ctxt {
  const YV12_BUFFER_CONFIG *src;  // Source frame.
  AV1_COMP *cpi;                  // Encoder instance.
  uint8_t *dgd_buffer;        // Plane data of cm->frame_to_show for `plane`.
  const uint8_t *src_buffer;  // Plane data of `src` for `plane`.
  int dgd_stride;             // Stride of dgd_buffer.
  int src_stride;             // Stride of src_buffer.
  int partial_frame;          // Passed through to the restoration routines.
  RestorationInfo *info;      // Receives the chosen per-tile parameters.
  RestorationType *type;      // Receives the chosen per-tile type.
  double *best_tile_cost;     // Per-tile output written by the search.
  int plane;                  // AOM_PLANE_Y, AOM_PLANE_U or AOM_PLANE_V.
  int plane_width;            // Cropped width of `plane`.
  int plane_height;           // Cropped height of `plane`.
  int nrtiles_x;              // Horizontal tile count from av1_get_rest_ntiles().
  int nrtiles_y;              // Vertical tile count from av1_get_rest_ntiles().
  YV12_BUFFER_CONFIG *dst_frame;  // Scratch frame for trial restorations.
};

// Fill in ctxt. Returns the number of restoration tiles for this plane
//
// The pass-through arguments are copied verbatim. The plane dimensions,
// buffers and strides are taken from `src` and from cm->frame_to_show (luma
// fields for AOM_PLANE_Y, chroma fields otherwise). nrtiles_x / nrtiles_y are
// filled in by av1_get_rest_ntiles() using the restoration tile size stored
// in cm->rst_info[plane].
static INLINE int init_rest_search_ctxt(
    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
    RestorationInfo *info, RestorationType *type, double *best_tile_cost,
    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
  AV1_COMMON *const cm = &cpi->common;
  ctxt->src = src;
  ctxt->cpi = cpi;
  ctxt->partial_frame = partial_frame;
  ctxt->info = info;
  ctxt->type = type;
  ctxt->best_tile_cost = best_tile_cost;
  ctxt->plane = plane;
  ctxt->dst_frame = dst_frame;

  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  if (plane == AOM_PLANE_Y) {
    ctxt->plane_width = src->y_crop_width;
    ctxt->plane_height = src->y_crop_height;
    ctxt->src_buffer = src->y_buffer;
    ctxt->src_stride = src->y_stride;
    ctxt->dgd_buffer = dgd->y_buffer;
    ctxt->dgd_stride = dgd->y_stride;
    // The source and the frame being restored must have matching dimensions.
    assert(ctxt->plane_width == dgd->y_crop_width);
    assert(ctxt->plane_height == dgd->y_crop_height);
  } else {
    ctxt->plane_width = src->uv_crop_width;
    ctxt->plane_height = src->uv_crop_height;
    ctxt->src_stride = src->uv_stride;
    ctxt->dgd_stride = dgd->uv_stride;
    ctxt->src_buffer = plane == AOM_PLANE_U ? src->u_buffer : src->v_buffer;
    ctxt->dgd_buffer = plane == AOM_PLANE_U ? dgd->u_buffer : dgd->v_buffer;
    // The source and the frame being restored must have matching dimensions.
    assert(ctxt->plane_width == dgd->uv_crop_width);
    assert(ctxt->plane_height == dgd->uv_crop_height);
  }

  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
                             cm->rst_info[plane].restoration_tilesize, NULL,
                             NULL, &ctxt->nrtiles_x, &ctxt->nrtiles_y);
}
517

518
// Callback invoked by foreach_rtile_in_tile() for one restoration tile.
// rtile_idx is the tile's raster index, `limits` its pixel bounds, and `arg`
// the opaque pointer that was handed to foreach_rtile_in_tile().
typedef void (*rtile_visitor_t)(const struct rest_search_ctxt *search_ctxt,
                                int rtile_idx,
                                const RestorationTileLimits *limits, void *arg);
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555

// Calls `fun` once for every restoration tile attributed to the coding tile
// at (tile_row, tile_col), in raster order, passing the tile's raster index,
// its pixel limits and `arg`.
//
// The restoration-tile row/column range is derived from the number of
// restoration tiles per coding tile and clipped to ctxt->nrtiles_x /
// ctxt->nrtiles_y.
static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
                                  int tile_row, int tile_col,
                                  rtile_visitor_t fun, void *arg) {
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  const RestorationInfo *rsi = ctxt->cpi->rst_search;

  const int tile_width_y = cm->tile_width * MI_SIZE;
  const int tile_height_y = cm->tile_height * MI_SIZE;

  // Chroma planes scale the luma tile size by the subsampling shifts.
  const int tile_width =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_width_y, cm->subsampling_x)
                        : tile_width_y;
  const int tile_height =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_height_y, cm->subsampling_y)
                        : tile_height_y;

  // NOTE(review): rsi points at element 0 of rst_search, so this reads that
  // entry's tile size regardless of ctxt->plane, whereas
  // init_rest_search_ctxt() uses cm->rst_info[plane] - confirm intended.
  const int rtile_size = rsi->restoration_tilesize;
  // NOTE(review): tile_width / tile_height were already derived from
  // cm->tile_width * MI_SIZE above and are multiplied by MI_SIZE again here -
  // confirm intended.
  const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
  const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;

  const int rtile_row0 = rtiles_per_tile_y * tile_row;
  const int rtile_row1 =
      AOMMIN(rtile_row0 + rtiles_per_tile_y, ctxt->nrtiles_y);

  const int rtile_col0 = rtiles_per_tile_x * tile_col;
  const int rtile_col1 =
      AOMMIN(rtile_col0 + rtiles_per_tile_x, ctxt->nrtiles_x);

  const int rtile_width = AOMMIN(tile_width, rtile_size);
  const int rtile_height = AOMMIN(tile_height, rtile_size);

  for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
    for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
      RestorationTileLimits limits = av1_get_rest_tile_limits(
          rtile_idx, ctxt->nrtiles_x, ctxt->nrtiles_y, rtile_width,
          rtile_height, ctxt->plane_width,
#if CONFIG_STRIPED_LOOP_RESTORATION
          ctxt->plane_height, ctxt->plane > 0 ? cm->subsampling_y : 0);
#else
          ctxt->plane_height);
#endif
      fun(ctxt, rtile_idx, &limits, arg);
    }
  }
}

static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
570
571
572
                                     int rtile_idx,
                                     const RestorationTileLimits *limits,
                                     void *arg) {
573
574
575
576
577
578
579
  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  RestorationInfo *rsi = ctxt->cpi->rst_search;
  SgrprojInfo *sgrproj_info = ctxt->info->sgrproj_info;

  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;

580
581
582
583
  int64_t err =
      sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, limits->h_start,
                           limits->h_end - limits->h_start, limits->v_start,
                           limits->v_end - limits->v_start, (1 << ctxt->plane));
584
585
586
587
588
589
590
  // #bits when a tile is not restored
  int bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;

  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
  SgrprojInfo *rtile_sgrproj_info = &plane_rsi->sgrproj_info[rtile_idx];
591
592
  uint8_t *dgd_start =
      ctxt->dgd_buffer + limits->v_start * ctxt->dgd_stride + limits->h_start;
593
  const uint8_t *src_start =
594
      ctxt->src_buffer + limits->v_start * ctxt->src_stride + limits->h_start;
595

596
  search_selfguided_restoration(
597
598
      dgd_start, limits->h_end - limits->h_start,
      limits->v_end - limits->v_start, ctxt->dgd_stride, src_start,
599
      ctxt->src_stride,
600
#if CONFIG_HIGHBITDEPTH
601
      cm->use_highbitdepth, cm->bit_depth,
602
#else
603
      0, 8,
604
#endif  // CONFIG_HIGHBITDEPTH
605
606
607
      rsi[ctxt->plane].procunit_width, rsi[ctxt->plane].procunit_height,
      &rtile_sgrproj_info->ep, rtile_sgrproj_info->xqd,
      cm->rst_internal.tmpbuf);
608
609
  plane_rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
  err = try_restoration_tile(ctxt->src, ctxt->cpi, rsi, (1 << ctxt->plane),
610
                             ctxt->partial_frame, rtile_idx, ctxt->dst_frame);
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
  bits =
      count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx], ref_sgrproj_info)
      << AV1_PROB_COST_SHIFT;
  bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  if (cost_sgrproj >= cost_norestore) {
    ctxt->type[rtile_idx] = RESTORE_NONE;
  } else {
    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
    *ref_sgrproj_info = sgrproj_info[rtile_idx] =
        plane_rsi->sgrproj_info[rtile_idx];
    ctxt->best_tile_cost[rtile_idx] = err;
  }
  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
}

// Searches self-guided projection parameters for every restoration tile of
// `plane` and returns the RD cost of applying the resulting configuration to
// the whole frame.
//
// On return, type[] holds RESTORE_SGRPROJ or RESTORE_NONE per tile,
// info->sgrproj_info[] holds the parameters of the tiles that selected the
// filter, and cpi->rst_search[plane] mirrors that configuration.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane,
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  struct rest_search_ctxt ctxt;
  const int nrtiles =
      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
                            best_tile_cost, dst_frame, &ctxt);

  // Start with the filter disabled on every tile; the per-tile search enables
  // it one tile at a time.
  RestorationInfo *plane_rsi = &cpi->rst_search[plane];
  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
    plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
  }

  // Compute best Sgrproj filters for each rtile, one (encoder/decoder)
  // tile at a time.
  const AV1_COMMON *const cm = &cpi->common;
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(ctxt.dgd_buffer), ctxt.plane_width,
                        ctxt.plane_height, ctxt.dgd_stride, SGRPROJ_BORDER_HORZ,
                        SGRPROJ_BORDER_VERT);
  else
#endif
    extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
                 ctxt.dgd_stride, SGRPROJ_BORDER_HORZ, SGRPROJ_BORDER_VERT);

  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
      // The coding reference is reset to the defaults for every tile.
      SgrprojInfo ref_sgrproj_info;
      set_default_sgrproj(&ref_sgrproj_info);
      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
                            &ref_sgrproj_info);
    }
  }

  // Cost for Sgrproj filtering
  SgrprojInfo ref_sgrproj_info;
  set_default_sgrproj(&ref_sgrproj_info);
  SgrprojInfo *sgrproj_info = info->sgrproj_info;

  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
             << AV1_PROB_COST_SHIFT;
  for (int rtile_idx = 0; rtile_idx < nrtiles; ++rtile_idx) {
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB,
                         type[rtile_idx] != RESTORE_NONE);
    plane_rsi->sgrproj_info[rtile_idx] = sgrproj_info[rtile_idx];
    if (type[rtile_idx] == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx],
                                 &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      ref_sgrproj_info = plane_rsi->sgrproj_info[rtile_idx];
    }
    plane_rsi->restoration_type[rtile_idx] = type[rtile_idx];
  }
  // Apply the final per-tile configuration to the whole frame to get the
  // distortion term of the frame-level cost.
  double err = try_restoration_frame(src, cpi, cpi->rst_search, (1 << plane),
                                     partial_frame, dst_frame);
  double cost_sgrproj = RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
  return cost_sgrproj;
}

690
691
// Returns the mean of the 8-bit samples in `src` over rows
// [v_start, v_end) and columns [h_start, h_end), where row i starts at
// src + i * stride.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;
  // NOTE(review): presumably resets SIMD/FPU state before the floating-point
  // division below - confirm against aom_ports/system_state.h.
  aom_clear_system_state();
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  return (double)sum / ((v_end - v_start) * (h_end - h_start));
}

702
703
704
705
static void compute_stats(int wiener_win, const uint8_t *dgd,
                          const uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
706
  int i, j, k, l;
707
  double Y[WIENER_WIN2];
708
709
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
710
711
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
712

713
714
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
715
716
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
717
718
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
719
720
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
721
722
723
724
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
725
      for (k = 0; k < wiener_win2; ++k) {
726
        M[k] += Y[k] * X;
727
728
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
729
730
731
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
732
          H[k * wiener_win2 + l] += Y[k] * Y[l];
733
734
735
736
        }
      }
    }
  }
737
738
739
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
740
741
    }
  }
742
743
}

744
#if CONFIG_HIGHBITDEPTH
745
// Returns the arithmetic mean of the 16-bit samples of src lying in rows
// [v_start, v_end) and columns [h_start, h_end); consecutive rows are
// `stride` elements apart.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  const int num_samples = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  double mean = 0;
  int row, col;

  aom_clear_system_state();

  for (row = v_start; row < v_end; row++) {
    const int row_base = row * stride;
    for (col = h_start; col < h_end; col++) total += src[row_base + col];
  }
  mean = (double)total / num_samples;
  return mean;
}

757
758
759
760
static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
                                 const uint8_t *src8, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, double *M, double *H) {
761
  int i, j, k, l;
762
  double Y[WIENER_WIN2];
763
764
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
765
766
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
767
768
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
769

770
771
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
772
773
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
774
775
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
776
777
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
778
779
780
781
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
782
      for (k = 0; k < wiener_win2; ++k) {
783
        M[k] += Y[k] * X;
784
785
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
786
787
788
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
789
          H[k * wiener_win2 + l] += Y[k] * Y[l];
790
791
792
793
        }
      }
    }
  }
794
795
796
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
797
798
    }
  }
799
}
800
#endif  // CONFIG_HIGHBITDEPTH
801

802
803
804
// Folds tap index i of a symmetric wiener_win-tap filter onto its mirror
// image in the first half: indices up to and including the centre
// (wiener_win >> 1) map to themselves, later ones map to wiener_win - 1 - i.
static INLINE int wrap_index(int i, int wiener_win) {
  const int first_mirrored = (wiener_win >> 1) + 1;
  if (i < first_mirrored) return i;
  return wiener_win - 1 - i;
}

// Fix vector b, update vector a
808
809
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
810
  int i, j;
811
  double S[WIENER_WIN];
812
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
813
814
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
815
816
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
817
818
819
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
820
821
822
      A[jj] += Mc[i][j] * b[i];
    }
  }
823
824
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
825
      int k, l;
826
827
828
829
830
831
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
832
833
834
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
835
  // Normalization enforcement in the system of equations itself
836
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
837
    A[i] -=
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
853
    }
854
    memcpy(a, S, wiener_win * sizeof(*a));
855
856
857
858
  }
}

// Fix vector a, update vector b
859
860
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
861
  int i, j;
862
  double S[WIENER_WIN];
863
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
864
865
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
866
867
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
868
869
870
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
871
872
  }

873
874
875
876
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
877
      int k, l;
878
879
880
881
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
882
883
    }
  }
Aamir Anis's avatar
Aamir Anis committed
884
  // Normalization enforcement in the system of equations itself
885
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
886
    A[i] -=
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
902
    }
903
    memcpy(b, S, wiener_win * sizeof(*b));
904
905
906
  }
}

907
908
static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
                                    double *a, double *b) {
909
910
911
912
  static const int init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
913
  };
914
915
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
916
917
918
919
920
  int i, j, iter;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
  for (i = 0; i < wiener_win; i++) {
    a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
921
  }
922
923
924
925
926
927
  for (i = 0; i < wiener_win; i++) {
    Mc[i] = M + i * wiener_win;
    for (j = 0; j < wiener_win; j++) {
      Hc[i * wiener_win + j] =
          H + i * wiener_win * wiener_win2 + j * wiener_win;
    }
928
  }
929
930

  iter = 1;
931
  while (iter < NUM_WIENER_ITERS) {
932
933
    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
934
935
    iter++;
  }
936
  return 1;
937
938
}

939
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
940
941
// against identity filters; Final score is defined as the difference between
// the function values
942
943
static double compute_score(int wiener_win, double *M, double *H,
                            InterpKernel vfilt, InterpKernel hfilt) {
944
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
945
946
947
948
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
949
  double a[WIENER_WIN], b[WIENER_WIN];
950
951
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
952
953
954

  aom_clear_system_state();

955
956
957
958
959
960
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
961
  }
962
963
964
  for (k = 0; k < wiener_win; ++k) {
    for (l = 0; l < wiener_win; ++l)
      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
Aamir Anis's avatar
Aamir Anis committed
965
  }
966
  for (k = 0; k < wiener_win2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
967
    P += ab[k] * M[k];
968
969
    for (l = 0; l < wiener_win2; ++l)
      Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
970
971
972
  }
  Score = Q - 2 * P;

973
974
  iP = M[wiener_win2 >> 1];
  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
975
976
977
978
979
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

980
// Converts the first wiener_win >> 1 taps of the real-valued symmetric filter
// f into integer taps in fi (scaled by WIENER_FILT_STEP and rounded with
// RINT), clamps them to the per-tap ranges, mirrors them into the upper half
// of the WIENER_WIN-tap kernel and sets the centre tap fi[3].
static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
  const int num_free_taps = wiener_win >> 1;
  int tap;

  for (tap = 0; tap < num_free_taps; ++tap) {
    fi[tap] = RINT(f[tap] * WIENER_FILT_STEP);
  }

  if (wiener_win != WIENER_WIN) {
    // Not the full 7-tap window: move the rounded taps one slot towards the
    // centre (highest slot first, so each source is read before it is
    // overwritten) and zero the outermost tap.
    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[0] = 0;
  } else {
    // Full 7-tap filter: clamp each tap to its own range.
    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  }

  // Satisfy filter constraints: the kernel is symmetric about its centre.
  for (tap = 0; tap < 3; ++tap) {
    fi[WIENER_WIN - 1 - tap] = fi[tap];
  }
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

1004
// Returns the sum of aom_count_primitive_refsubexpfin() over the taps of
// wiener_info's vertical and then horizontal filter, each tap counted against
// the same tap of ref_wiener_info. Taps 1 and 2 are always counted; tap 0 is
// counted only when wiener_win == WIENER_WIN.
static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  const int count_tap0 = (wiener_win == WIENER_WIN);
  const int tap0_range = WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1;
  const int tap1_range = WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1;
  const int tap2_range = WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1;
  int bits = 0;

  // Vertical filter.
  if (count_tap0) {
    bits += aom_count_primitive_refsubexpfin(
        tap0_range, WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->vfilter[0] - WIENER_FILT_TAP0_MINV);
  }
  bits += aom_count_primitive_refsubexpfin(
      tap1_range, WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->vfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      tap2_range, WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->vfilter[2] - WIENER_FILT_TAP2_MINV);

  // Horizontal filter.
  if (count_tap0) {
    bits += aom_count_primitive_refsubexpfin(
        tap0_range, WIENER_FILT_TAP0_SUBEXP_K,
        ref_wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV,
        wiener_info->hfilter[0] - WIENER_FILT_TAP0_MINV);
  }
  bits += aom_count_primitive_refsubexpfin(
      tap1_range, WIENER_FILT_TAP1_SUBEXP_K,
      ref_wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV,
      wiener_info->hfilter[1] - WIENER_FILT_TAP1_MINV);
  bits += aom_count_primitive_refsubexpfin(
      tap2_range, WIENER_FILT_TAP2_SUBEXP_K,
      ref_wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV,
      wiener_info->hfilter[2] - WIENER_FILT_TAP2_MINV);

  return bits;
}

1042
1043
1044
#define USE_WIENER_REFINEMENT_SEARCH 1
static int64_t finer_tile_search_wiener(const YV12_BUFFER_CONFIG *src,
                                        AV1_COMP *cpi, RestorationInfo *rsi,
1045
1046
                                        int start_step, int plane,
                                        int wiener_win, int tile_idx,
1047
1048
                                        int partial_frame,
                                        YV12_BUFFER_CONFIG *dst_frame) {
1049
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
1050
  int64_t err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
1051
                                     tile_idx, dst_frame);
1052