pickrst.c 43.7 KB
Newer Older
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29 30 31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33 34 35 36 37 38
// When set to 1, only RESTORE_WIENER or RESTORE_NONE are allowed
#define FORCE_RESTORE_WIENER_ONLY 0

// Number of Wiener iterations
#define NUM_WIENER_ITERS 10

39
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
40 41
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
42
                                      RestorationType *rest_level,
43 44
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
45

46
// Frame-level signalling cost per frame restoration type. It is indexed by
// the frame restoration type and shifted left by AV1_PROB_COST_SHIFT by the
// search code before being added to a rate, so the entries are presumably
// whole bits (TODO confirm against the bitstream writer).
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
47 48

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
49 50
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
51
                                    int width, int v_start, int height,
52 53
                                    int components_pattern) {
  int64_t filt_err = 0;
54 55 56 57
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
58 59
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
60 61 62 63 64
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
65 66
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
67 68
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
69 70
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
71 72
    }
    return filt_err;
73 74
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
75 76 77 78
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
79
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
80 81
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
82
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
83
  }
84 85 86
  return filt_err;
}

87 88
// Returns the whole-frame sum of squared errors between src and dst, summed
// over the planes selected by components_pattern (bit AOM_PLANE_Y / AOM_PLANE_U
// / AOM_PLANE_V set => that plane contributes). cm is only consulted (for
// use_highbitdepth) when CONFIG_AOM_HIGHBITDEPTH is enabled.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  int64_t filt_err = 0;
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH
  // Every plane accumulates with +=, so the result does not depend on the
  // order in which the planes are tested.
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

120 121
// Applies loop restoration with the settings in rsi to cm->frame_to_show,
// writing the result to dst_frame, and returns the squared error against src
// measured only over the (sub)tile identified by tile_idx / subtile_idx /
// subtile_bits.
//
// components_pattern selects the planes (1 = Y only; any other accepted value
// uses the chroma dimensions and the chroma restoration tile size).
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  int width, height;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  if (components_pattern == 1) {  // Y only
    width = src->y_crop_width;
    height = src->y_crop_height;
  } else {  // Color
    width = src->uv_crop_width;
    height = src->uv_crop_height;
  }
  // Only the tile geometry outputs are needed here; the tile count itself is
  // not used.
  (void)av1_get_rest_ntiles(
      width, height, cm->rst_info[components_pattern > 1].restoration_tilesize,
      &tile_width, &tile_height, &nhtiles, &nvtiles);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  return sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                              v_start, v_end - v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
160
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
161
                                     int components_pattern, int partial_frame,
162
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
163
  AV1_COMMON *const cm = &cpi->common;
164
  int64_t filt_err;
165 166
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
167
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
168 169 170
  return filt_err;
}

171 172 173 174 175
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
176 177 178 179
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
180 181 182 183 184 185 186 187 188
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
189
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
205
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
206 207 208 209 210
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
211 212 213 214 215
    }
  }
  return err;
}

216 217 218 219
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
220 221 222 223 224 225 226
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

227 228 229
  // Default
  xq[0] = 0;
  xq[1] = 0;
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
280
  xqd[0] = xq[0];
281
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
282
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
283 284 285 286 287 288
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
289 290
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
291
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
292
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
293
  int ep, bestep = 0;
294 295
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
296

297 298
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
299
#if CONFIG_AOM_HIGHBITDEPTH
300 301
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
302 303
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter_highbd(dat, width, height, dat_stride, flt1, width,
304
                                 sgr_params[ep].corner, sgr_params[ep].edge);
305
#else
306 307 308
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
309
#endif  // USE_HIGHPASS_IN_SGRPROJ
310 311 312
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
313
    } else {
314
#endif
315 316
#if USE_HIGHPASS_IN_SGRPROJ
      av1_highpass_filter(dat8, width, height, dat_stride, flt1, width,
317
                          sgr_params[ep].corner, sgr_params[ep].edge);
318 319 320 321
#else
    av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
#endif  // USE_HIGHPASS_IN_SGRPROJ
322
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
323
                                 sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
324
#if CONFIG_AOM_HIGHBITDEPTH
325
    }
326
#endif
327 328
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
329
    encode_xq(exq, exqd);
330 331 332
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
333 334 335 336 337 338 339 340 341 342 343 344 345
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Searches the self-guided projection (Sgrproj) restoration for the luma plane
// and returns the frame-level rate-distortion cost of using it.
//
// Per tile: the cost of leaving the tile unrestored is compared with the cost
// of restoring it with the best parameters found by
// search_selfguided_restoration(). Each tile is evaluated with every other
// tile switched to RESTORE_NONE in the scratch settings cpi->rst_search[0].
//
// Outputs:
//   type[tile]           RESTORE_SGRPROJ if restoring the tile is strictly
//                        cheaper, else RESTORE_NONE.
//   info->sgrproj_info   parameters of the tiles that chose RESTORE_SGRPROJ.
//   best_tile_cost[tile] cost including the switchable-restore signalling for
//                        RESTORE_SGRPROJ tiles, DBL_MAX otherwise.
// Finally the whole frame is restored with the per-tile decisions and the
// resulting cost (frame-level bits + per-tile bits, with the frame SSE) is
// returned.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);
  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start, 1);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[tile_idx].ep, rsi->sgrproj_info[tile_idx].xqd,
        cm->rst_internal.tmpbuf);
    // Enable restoration on this tile only, then measure the tile error.
    rsi->restoration_type[tile_idx] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_SGRPROJ;
      memcpy(&sgrproj_info[tile_idx], &rsi->sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    // Switch the tile back off so the next tile is evaluated in isolation.
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Sgrproj filtering
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (type[tile_idx] == RESTORE_SGRPROJ) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_sgrproj;
}

427 428
// Returns the mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src, where consecutive rows are stride elements apart.
// Returns 0 for an empty region instead of dividing by zero.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;
  if (v_end <= v_start || h_end <= h_start) return 0.0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  // Form the pixel count in floating point so it cannot overflow an int.
  return (double)sum / ((double)(v_end - v_start) * (h_end - h_start));
}

438 439 440
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
441
  int i, j, k, l;
442
  double Y[WIENER_WIN2];
443 444
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
445

446 447
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
448 449
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
450 451
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
452 453
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
454 455 456 457
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
458
      for (k = 0; k < WIENER_WIN2; ++k) {
459
        M[k] += Y[k] * X;
460 461
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
462 463 464 465
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
466 467 468 469
        }
      }
    }
  }
470 471 472 473 474
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
475 476
}

Yaowu Xu's avatar
Yaowu Xu committed
477
#if CONFIG_AOM_HIGHBITDEPTH
478 479
// Returns the mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src, where consecutive rows are stride elements apart.
// Returns 0 for an empty region instead of dividing by zero.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;
  if (v_end <= v_start || h_end <= h_start) return 0.0;
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  }
  // Form the pixel count in floating point so it cannot overflow an int.
  return (double)sum / ((double)(v_end - v_start) * (h_end - h_start));
}

489 490 491 492
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
493
  int i, j, k, l;
494
  double Y[WIENER_WIN2];
495 496
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
497 498
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
499

500 501
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
502 503
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
504 505
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
506 507
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
508 509 510 511
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
512
      for (k = 0; k < WIENER_WIN2; ++k) {
513
        M[k] += Y[k] * X;
514 515
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
516 517 518 519
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
520 521 522 523
        }
      }
    }
  }
524 525 526 527 528
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
529
}
Yaowu Xu's avatar
Yaowu Xu committed
530
#endif  // CONFIG_AOM_HIGHBITDEPTH
531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552

// Solves Ax = b, where x and b are column vectors.
//
// A is an n x n matrix stored row-major with `stride` doubles between rows.
// A and b are modified in place (used as elimination workspace).
// Returns 1 and writes the solution to x on success. Returns 0 when a pivot
// has magnitude below 1e-10 or is NaN; in that case x may be partly written
// and must not be used.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Row reordering: a single bottom-up pass that swaps adjacent rows so that
  // larger first-column entries move towards the top.
  for (i = n - 1; i > 0; i--) {
    if (A[(i - 1) * stride] < A[i * stride]) {
      for (j = 0; j < n; j++) {
        c = A[i * stride + j];
        A[i * stride + j] = A[(i - 1) * stride + j];
        A[(i - 1) * stride + j] = c;
      }
      c = b[i];
      b[i] = b[i - 1];
      b[i - 1] = c;
    }
  }
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Row k is final at this point. Reject an unusable pivot before dividing
    // by it; the negated comparison also catches NaN.
    const double pivot = A[k * stride + k];
    if (!(fabs(pivot) >= 1e-10)) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / pivot;
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (!(fabs(A[i * stride + i]) >= 1e-10)) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
568
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
569 570 571 572 573
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
574 575
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
576
  int w, w2;
577 578
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
579 580
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
581 582 583 584
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
585 586
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
587
      int k, l;
588 589
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
590 591
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
592 593
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
594 595 596
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
597
  // Normalization enforcement in the system of equations itself
598
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
599 600
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
601 602
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
603 604 605 606 607 608 609 610 611
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
612
    }
Aamir Anis's avatar
Aamir Anis committed
613
    memcpy(a, S, w * sizeof(*a));
614 615 616 617 618 619
  }
}

// Fix vector a, update vector b.
//
// Mc[i] points at row i of the cross-correlation M, and Hc[i * WIENER_WIN + j]
// at a WIENER_WIN-wide block of the auto-correlation H (as set up by
// wiener_decompose_sep_sym). With a held fixed, the reduced linear system for
// the symmetric filter b is built by folding mirrored taps together with
// wrap_index(), the sum-to-one constraint is substituted into the system, and
// the free taps are solved with linsolve(). b is overwritten only if the solve
// succeeds; otherwise it is left unchanged.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
  int w, w2;
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }

  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) {
      const int jj = wrap_index(j);
      int k, l;
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          B[jj * WIENER_HALFWIN1 + ii] +=
              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
    }
  }
  // Normalization enforcement in the system of equations itself
  w = WIENER_WIN;
  w2 = WIENER_HALFWIN1;
  for (i = 0; i < w2 - 1; ++i)
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    // Rebuild the full symmetric filter: mirror the solved taps and set the
    // centre tap so that all taps sum to one.
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
    }
    memcpy(b, S, w * sizeof(*b));
  }
}

661 662
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
663
  static const double init_filt[WIENER_WIN] = {
664
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
665 666
  };
  int i, j, iter;
667 668 669 670 671 672 673
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  for (i = 0; i < WIENER_WIN; i++) {
    Mc[i] = M + i * WIENER_WIN;
    for (j = 0; j < WIENER_WIN; j++) {
      Hc[i * WIENER_WIN + j] =
          H + i * WIENER_WIN * WIENER_WIN2 + j * WIENER_WIN;
674 675
    }
  }
676 677
  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);
678 679

  iter = 1;
680
  while (iter < NUM_WIENER_ITERS) {
681 682 683 684
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
    iter++;
  }
685
  return 1;
686 687
}

688
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
689 690
// against identity filters; Final score is defined as the difference between
// the function values
691 692
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
693
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
694 695 696 697
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
698 699 700 701 702 703 704
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
705
  }
706 707
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
708
  }
709
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
710
    P += ab[k] * M[k];
711 712
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
713 714 715
  }
  Score = Q - 2 * P;

716 717
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
718 719 720 721 722
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

723
// Quantizes the real-valued symmetric filter f into the integer kernel fi.
//
// The first WIENER_HALFWIN taps are scaled by WIENER_FILT_STEP, rounded with
// RINT and clipped to their per-tap ranges; they are then mirrored onto the
// last taps. The centre tap fi[3] is stored without its implicit
// +WIENER_FILT_STEP, i.e. as minus twice the sum of the three outer taps.
// The clipping and mirroring are written out for a 7-tap filter.
static void quantize_sym_filter(double *f, InterpKernel fi) {
  int i;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    fi[i] = RINT(f[i] * WIENER_FILT_STEP);
  }
  // Specialize for 7-tap filter
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
741
                               int partial_frame, int plane,
742
                               RestorationInfo *info, RestorationType *type,
743 744 745 746 747 748
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
749
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
750 751 752 753 754 755 756 757 758 759 760
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
761
  int h_start, h_end, v_start, v_end;
762 763 764
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
765 766 767 768
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
769
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
770
  bits = 0;
771
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
772 773

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
774

775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
796
    if (plane == AOM_PLANE_U) {
797 798 799 800 801 802 803 804
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
805
    } else if (plane == AOM_PLANE_V) {
806 807 808 809 810 811 812 813
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
814 815 816
    } else {
      assert(0);
    }
817 818 819 820 821 822

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
823
    }
824 825
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
826

827 828 829 830 831 832 833 834 835
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filer. If there is no
    // reduction in the function, the filter is reverted back to identity
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
836

837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
863
  }
864
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
865
                              dst_frame);
866 867 868 869 870
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
871 872 873
    info->frame_restoration_type = RESTORE_NONE;
  }

874 875
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
876 877
}

878
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
879 880
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
881
                            YV12_BUFFER_CONFIG *dst_frame) {
882
  WienerInfo *wiener_info = info->wiener_info;
Yaowu Xu's avatar
Yaowu Xu committed
883
  AV1_COMMON *const cm = &cpi->common;
884
  RestorationInfo *rsi = cpi->rst_search;
885 886
  int64_t err;
  int bits;
887
  double cost_wiener, cost_norestore;
888
  MACROBLOCK *x = &cpi->td.mb;
889 890 891
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
892 893 894 895 896
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
Aamir Anis's avatar
Aamir Anis committed
897
  double score;