/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <float.h>
#include <limits.h>
#include <math.h>

#include "./aom_scale_rtcd.h"

#include "aom_dsp/psnr.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
#include "aom_ports/mem.h"

#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
#include "av1/common/restoration.h"

#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
#include "av1/encoder/quantize.h"

typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *cpi, int filter_level,
                                      int partial_frame, RestorationInfo *info,
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);

const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 3, 3, 2 };

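// Computes the SSE between src and dst over the tile rectangle
// [h_start, h_start + width) x [v_start, v_start + height) for the planes
// selected by components_pattern (Y only, or a subset of U/V).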
static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
                                    int width, int v_start, int height,
                                    int components_pattern) {
  int64_t filt_err = 0;
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
    }
    return filt_err;
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
  }
  return filt_err;
}

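// Computes the whole-frame SSE between src and dst for the planes selected
// by components_pattern.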
static int64_t sse_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     const AV1_COMMON *cm,
                                     int components_pattern) {
  int64_t filt_err = 0;
  (void)cm;
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err = aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

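// Applies the restoration parameters in rsi to the frame, then returns the
// resulting SSE against src over a single (sub)tile.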
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int64_t filt_err;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int ntiles, width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  if (components_pattern == 1) {  // Y only
    width = src->y_crop_width;
    height = src->y_crop_height;
  } else {  // Color
    width = src->uv_crop_width;
    height = src->uv_crop_height;
  }
  ntiles = av1_get_rest_ntiles(width, height, &tile_width, &tile_height,
                               &nhtiles, &nvtiles);
  (void)ntiles;

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                                  v_start, v_end - v_start, components_pattern);

  return filt_err;
}

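// Applies the restoration parameters in rsi to the frame, then returns the
// resulting whole-frame SSE against src.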
static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
                                     int components_pattern, int partial_frame,
                                     YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int64_t filt_err;
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  filt_err = sse_restoration_frame(src, dst_frame, cm, components_pattern);
  return filt_err;
}

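// Returns the sum of squared errors obtained when the two self-guided
// filter outputs flt1/flt2 are projected with the coded parameters xqd and
// added back to the degraded data dat.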
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  }
  return err;
}

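// Least-squares fit of the two projection coefficients: accumulates the 2x2
// normal equations H * x = C over the tile and solves them in double
// precision, keeping the default xq values if the system is ill-conditioned.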
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

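// Maps the projection gains xq to the clamped coded parameters xqd
// (the inverse mapping of decode_xq).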
void encode_xq(int *xq, int *xqd) {
  xqd[0] = -xq[0];
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1];
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

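// Exhaustive search over the SGRPROJ_PARAMS self-guided parameter sets:
// for each set, run the two guided filters, fit the projection
// coefficients, and keep the combination with the smallest pixel error.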
static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
  int i, j, ep, bestep = 0;
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };

  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          flt1[i * width + j] = (int32_t)dat[i * dat_stride + j];
          flt2[i * width + j] = (int32_t)dat[i * dat_stride + j];
        }
      }
    } else {
      uint8_t *dat = dat8;
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          const int k = i * width + j;
          const int l = i * dat_stride + j;
          flt1[k] = (int32_t)dat[l];
          flt2[k] = (int32_t)dat[l];
        }
      }
    }
    av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
    av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                               sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
    encode_xq(exq, exqd);
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

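// RD search for self-guided projection (sgrproj) restoration on luma: each
// tile is filtered and compared against no restoration, and best_tile_cost
// records the per-tile cost (including the switchable-restoration signalling
// cost) for the frame-level decision.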
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int filter_level, int partial_frame,
                             RestorationInfo *info, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  // Allocate for the src buffer at high precision
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi->sgrproj_info[tile_idx].level = 0;
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[tile_idx].ep, rsi->sgrproj_info[tile_idx].xqd,
        cm->rst_internal.tmpbuf);
    rsi->sgrproj_info[tile_idx].level = 1;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      sgrproj_info[tile_idx].level = 0;
    } else {
      memcpy(&sgrproj_info[tile_idx], &rsi->sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    rsi->sgrproj_info[tile_idx].level = 0;
  }
  // Cost for Sgrproj filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, sgrproj_info[tile_idx].level);
    memcpy(&rsi->sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (sgrproj_info[tile_idx].level) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_sgrproj;
}

static int64_t compute_sse(uint8_t *dgd, int width, int height, int dgd_stride,
                           uint8_t *src, int src_stride) {
  int64_t sse = 0;
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int diff =
          (int)dgd[i * dgd_stride + j] - (int)src[i * src_stride + j];
      sse += diff * diff;
    }
  }
  return sse;
}

#if CONFIG_AOM_HIGHBITDEPTH
static int64_t compute_sse_highbd(uint16_t *dgd, int width, int height,
                                  int dgd_stride, uint16_t *src,
                                  int src_stride) {
  int64_t sse = 0;
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int diff =
          (int)dgd[i * dgd_stride + j] - (int)src[i * src_stride + j];
      sse += diff * diff;
    }
  }
  return sse;
}
#endif  // CONFIG_AOM_HIGHBITDEPTH

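// Coarse-to-fine search over the domain-transform recursive filter
// parameter: a first pass in steps of 8, then refinement in steps of 2 and
// finally 1 around the best value found so far.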
static void search_domaintxfmrf_restoration(uint8_t *dgd8, int width,
                                            int height, int dgd_stride,
                                            uint8_t *src8, int src_stride,
                                            int bit_depth, int *sigma_r,
                                            uint8_t *fltbuf, int32_t *tmpbuf) {
  const int first_p_step = 8;
  const int second_p_range = first_p_step >> 1;
  const int second_p_step = 2;
  const int third_p_range = second_p_step >> 1;
  const int third_p_step = 1;
  int p, best_p0, best_p = -1;
  int64_t best_sse = INT64_MAX, sse;
  if (bit_depth == 8) {
    uint8_t *flt = fltbuf;
    uint8_t *dgd = dgd8;
    uint8_t *src = src8;
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
                                   width, tmpbuf);
      sse = compute_sse(flt, width, height, width, src, src_stride);
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
                                   width, tmpbuf);
      sse = compute_sse(flt, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
                                   width, tmpbuf);
      sse = compute_sse(flt, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
  } else {
#if CONFIG_AOM_HIGHBITDEPTH
    uint16_t *flt = (uint16_t *)fltbuf;
    uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
#else
    assert(0);
#endif  // CONFIG_AOM_HIGHBITDEPTH
  }
  *sigma_r = best_p;
}

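// RD search for domain-transform recursive filter restoration on luma,
// mirroring the structure of search_sgrproj.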
static double search_domaintxfmrf(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  int filter_level, int partial_frame,
                                  RestorationInfo *info, double *best_tile_cost,
                                  YV12_BUFFER_CONFIG *dst_frame) {
  DomaintxfmrfInfo *domaintxfmrf_info = info->domaintxfmrf_info;
  double cost_norestore, cost_domaintxfmrf;
  int64_t err;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  rsi->frame_restoration_type = RESTORE_DOMAINTXFMRF;

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi->domaintxfmrf_info[tile_idx].level = 0;
  // Compute best Domaintxfm filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

    search_domaintxfmrf_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->domaintxfmrf_info[tile_idx].sigma_r, cpi->extra_rstbuf,
        cm->rst_internal.tmpbuf);

    rsi->domaintxfmrf_info[tile_idx].level = 1;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 1);
    cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_domaintxfmrf >= cost_norestore) {
      domaintxfmrf_info[tile_idx].level = 0;
    } else {
      memcpy(&domaintxfmrf_info[tile_idx], &rsi->domaintxfmrf_info[tile_idx],
             sizeof(domaintxfmrf_info[tile_idx]));
      bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_DOMAINTXFMRF]) >> 4,
          err);
    }
    rsi->domaintxfmrf_info[tile_idx].level = 0;
  }
  // Cost for Domaintxfmrf filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB,
                         domaintxfmrf_info[tile_idx].level);
    memcpy(&rsi->domaintxfmrf_info[tile_idx], &domaintxfmrf_info[tile_idx],
           sizeof(domaintxfmrf_info[tile_idx]));
    if (domaintxfmrf_info[tile_idx].level) {
      bits += (DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT);
    }
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_domaintxfmrf;
}

static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

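// Accumulates the Wiener filter statistics for one tile: M is the
// cross-correlation between the degraded window and the source pixel, and H
// is the autocorrelation of the degraded window, both computed after
// removing the mean of the degraded data.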
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
  int i, j, k, l;
  double Y[WIENER_WIN2];
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < WIENER_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
        }
      }
    }
  }
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
}

#if CONFIG_AOM_HIGHBITDEPTH
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
  int i, j, k, l;
  double Y[WIENER_WIN2];
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);

  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < WIENER_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
        }
      }
    }
  }
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
}
#endif  // CONFIG_AOM_HIGHBITDEPTH

// Solves Ax = b, where x and b are column vectors
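// Gaussian elimination; returns 0 if a (near-)zero pivot is encountered,
// 1 on success.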
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Partial pivoting
  for (i = n - 1; i > 0; i--) {
    if (A[(i - 1) * stride] < A[i * stride]) {
      for (j = 0; j < n; j++) {
        c = A[i * stride + j];
        A[i * stride + j] = A[(i - 1) * stride + j];
        A[(i - 1) * stride + j] = c;
      }
      c = b[i];
      b[i] = b[i - 1];
      b[i - 1] = c;
    }
  }
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

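// Maps a tap position in [0, WIENER_WIN) to its index in the symmetric
// half-filter representation.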
static INLINE int wrap_index(int i) {
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
  int w, w2;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      int k, l;
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
        }
    }
  }
  // Normalization enforcement in the system of equations itself
  w = WIENER_WIN;
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
    }
    memcpy(a, S, w * sizeof(*a));
  }
}

// Fix vector a, update vector b
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
  int w, w2;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }

  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      const int ii = wrap_index(i);
      const int jj = wrap_index(j);
      int k, l;
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          B[jj * WIENER_HALFWIN1 + ii] +=
              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
    }
  }
  // Normalization enforcement in the system of equations itself
  w = WIENER_WIN;
  w2 = WIENER_HALFWIN1;
  for (i = 0; i < w2 - 1; ++i)
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
    }
    memcpy(b, S, w * sizeof(*b));
  }
}

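// Decomposes the jointly estimated statistics into a separable, symmetric
// vertical/horizontal filter pair by alternating minimization: starting from
// a fixed initial filter, a and b are re-optimized in turn for a fixed
// number of iterations.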
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  int i, j, iter;
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  for (i = 0; i < WIENER_WIN; i++) {
    Mc[i] = M + i * WIENER_WIN;
    for (j = 0; j < WIENER_WIN; j++) {
      Hc[i * WIENER_WIN + j] =
          H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
    }
  }
  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  iter = 1;
  while (iter < 10) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
    iter++;
  }
  return 1;
}

// Computes the function x'*H*x - 2*x'*M for the learned 2D filter x, and
// compares it against the identity filter; the final score is the difference
// between the two function values.
static double compute_score(double *M, double *H, int *vfilt, int *hfilt) {
  double ab[WIENER_WIN * WIENER_WIN];
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
  }
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
  }
  for (k = 0; k < WIENER_WIN2; ++k) {
    P += ab[k] * M[k];
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
  }
  Score = Q - 2 * P;

  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

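// Quantizes the first WIENER_HALFWIN taps to WIENER_FILT_STEP precision,
// clips them to the coded ranges, and then derives the mirrored taps and the
// center tap so that the filter remains symmetric and sums to
// WIENER_FILT_STEP.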
static void quantize_sym_filter(double *f, int *fi) {
  int i;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    fi[i] = RINT(f[i] * WIENER_FILT_STEP);
  }
  // Specialize for 7-tap filter
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  fi[3] = WIENER_FILT_STEP - 2 * (fi[0] + fi[1] + fi[2]);
}

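// Wiener filter search for a chroma plane: a single frame-level filter is
// estimated from the whole-plane statistics and kept only if its RD cost
// beats no restoration.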
static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                               int filter_level, int partial_frame, int plane,
                               RestorationInfo *info,
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
  double cost_wiener = 0, cost_norestore = 0;
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);

  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  aom_yv12_copy_u(cm->frame_to_show, &cpi->last_frame_uf);
  aom_yv12_copy_v(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        0, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
  aom_yv12_copy_u(cm->frame_to_show, &cpi->last_frame_db);
  aom_yv12_copy_v(cm->frame_to_show, &cpi->last_frame_db);

  rsi[plane].frame_restoration_type = RESTORE_NONE;

  err = sse_restoration_frame(src, cm->frame_to_show, cm, (1 << plane));
  bits = 0;
  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  rsi[plane].frame_restoration_type = RESTORE_WIENER;

#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (plane == AOM_PLANE_U) {
      extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->u_buffer), width, height,
                          dgd_stride);
      compute_stats_highbd(dgd->u_buffer, src->u_buffer, 0, width, 0, height,
                           dgd_stride, src_stride, M, H);
    } else if (plane == AOM_PLANE_V) {
      extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->v_buffer), width, height,
                          dgd_stride);
      compute_stats_highbd(dgd->v_buffer, src->v_buffer, 0, width, 0, height,
                           dgd_stride, src_stride, M, H);
    } else {
      assert(0);
    }
  } else {
#endif
    if (plane == AOM_PLANE_U) {
      extend_frame(dgd->u_buffer, width, height, dgd_stride);
      compute_stats(dgd->u_buffer, src->u_buffer, 0, width, 0, height,
                    dgd_stride, src_stride, M, H);
    } else if (plane == AOM_PLANE_V) {
      extend_frame(dgd->v_buffer, width, height, dgd_stride);
      compute_stats(dgd->v_buffer, src->v_buffer, 0, width, 0, height,
                    dgd_stride, src_stride, M, H);
    } else {
      assert(0);
    }
#if CONFIG_AOM_HIGHBITDEPTH
  }
#endif

  if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
    info->frame_restoration_type = RESTORE_NONE;
    aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_u(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_v(&cpi->last_frame_uf, cm->frame_to_show);
    return cost_norestore;
  }
  quantize_sym_filter(vfilterd, rsi[plane].wiener_info[0].vfilter);
  quantize_sym_filter(hfilterd, rsi[plane].wiener_info[0].hfilter);

  // The filter score is the value of the function x'*H*x - 2*x'*M for the
  // learned filter, compared against the identity filter. If there is no
  // reduction in the function value, the filter is reverted to identity.
  score = compute_score(M, H, rsi[plane].wiener_info[0].vfilter,
                        rsi[plane].wiener_info[0].hfilter);
  if (score > 0.0) {
    info->frame_restoration_type = RESTORE_NONE;
    aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_u(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_v(&cpi->last_frame_uf, cm->frame_to_show);
    return cost_norestore;
  }

  info->frame_restoration_type = RESTORE_WIENER;
  rsi[plane].restoration_type[0] = info->restoration_type[0] = RESTORE_WIENER;
  rsi[plane].wiener_info[0].level = 1;
  memcpy(&wiener_info[0], &rsi[plane].wiener_info[0], sizeof(wiener_info[0]));
  for (tile_idx = 1; tile_idx < ntiles; ++tile_idx) {
    info->restoration_type[tile_idx] = RESTORE_WIENER;
    memcpy(&rsi[plane].wiener_info[tile_idx], &rsi[plane].wiener_info[0],
           sizeof(rsi[plane].wiener_info[0]));
    memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[0],
           sizeof(rsi[plane].wiener_info[0]));
  }
  err = try_restoration_frame(src, cpi, rsi, (1 << plane), partial_frame,
                              dst_frame);
  bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
  if (cost_wiener > cost_norestore) {
    info->frame_restoration_type = RESTORE_NONE;
    aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_u(&cpi->last_frame_uf, cm->frame_to_show);
    aom_yv12_copy_v(&cpi->last_frame_uf, cm->frame_to_show);
    return cost_norestore;
  }

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  aom_yv12_copy_u(&cpi->last_frame_uf, cm->frame_to_show);
  aom_yv12_copy_v(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_wiener;
}

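// Wiener filter search for the luma plane: statistics are gathered and a
// separate filter is estimated for each restoration tile, with per-tile RD
// costs recorded in best_tile_cost.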
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int filter_level, int partial_frame,
                            RestorationInfo *info, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
  double cost_wiener, cost_norestore;
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int i;
  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  rsi->frame_restoration_type = RESTORE_WIENER;

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi->wiener_info[tile_idx].level = 0;

// Construct a (WIENER_HALFWIN)-pixel border around the frame
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start,