/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29 30 31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
34 35
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
36
                                      RestorationType *rest_level,
37 38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
// Bit count looked up by frame restoration type. search_sgrproj() shifts the
// selected entry left by AV1_PROB_COST_SHIFT to seed its frame-level rate.
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
41 42

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
43 44
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
45
                                    int width, int v_start, int height,
46 47
                                    int components_pattern) {
  int64_t filt_err = 0;
48 49 50 51
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
52 53
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
54 55 56 57 58
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
59 60
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
61 62
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
63 64
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
65 66
    }
    return filt_err;
67 68
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
69 70 71 72
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
73
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
74 75
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
76
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
77
  }
78 79 80
  return filt_err;
}

81 82
// Whole-frame counterpart of sse_restoration_tile(): sums the per-plane SSE
// between src and dst for every plane whose bit is set in
// components_pattern.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t total = 0;
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) total += aom_highbd_get_y_sse(src, dst);
    if (want_u) total += aom_highbd_get_u_sse(src, dst);
    if (want_v) total += aom_highbd_get_v_sse(src, dst);
    return total;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH
  if (want_y) total += aom_get_y_sse(src, dst);
  if (want_u) total += aom_get_u_sse(src, dst);
  if (want_v) total += aom_get_v_sse(src, dst);
  return total;
}

114 115
// Runs loop restoration with the settings in rsi (output in dst_frame) and
// returns the SSE against src measured over a single tile/subtile only.
// The plane dimensions come from src: luma when components_pattern == 1,
// chroma otherwise.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int luma_only = (components_pattern == 1);
  int plane_w, plane_h;
  int tile_w, tile_h, tiles_across, tiles_down;
  int x0, x1, y0, y1;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  plane_w = luma_only ? src->y_crop_width : src->uv_crop_width;
  plane_h = luma_only ? src->y_crop_height : src->uv_crop_height;

  // Only the tile geometry is needed here; the tile count is discarded.
  (void)av1_get_rest_ntiles(
      plane_w, plane_h,
      cm->rst_info[components_pattern > 1].restoration_tilesize, &tile_w,
      &tile_h, &tiles_across, &tiles_down);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, tiles_across,
                           tiles_down, tile_w, tile_h, plane_w, plane_h, 0, 0,
                           &x0, &x1, &y0, &y1);
  return sse_restoration_tile(src, dst_frame, cm, x0, x1 - x0, y0, y1 - y0,
                              components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
154
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
155
                                     int components_pattern, int partial_frame,
156
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
157
  AV1_COMMON *const cm = &cpi->common;
158
  int64_t filt_err;
159 160
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
161
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
162 163 164
  return filt_err;
}

165 166 167 168 169
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
170 171 172 173
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
174 175 176 177 178 179 180 181 182
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
183
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
199
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
200 201 202 203 204
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
205 206 207 208 209
    }
  }
  return err;
}

210 211 212 213
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
214 215 216 217 218 219 220 221 222
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

// Maps projection weights xq[] to the clamped parameter pair xqd[].
// xqd[0] is -xq[0] limited to [SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0]; xqd[1] is
// derived from the already-clamped xqd[0] and then limited to
// [SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1].
void encode_xq(int *xq, int *xqd) {
  const int first = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  xqd[0] = first;
  xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) + first - xq[1], SGRPROJ_PRJ_MIN1,
                 SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
282 283
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
284
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
285
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
286
  int ep, bestep = 0;
287 288
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
289

290 291
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
292
#if CONFIG_AOM_HIGHBITDEPTH
293 294
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
295 296 297 298 299 300
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt1,
                                        width, bit_depth, sgr_params[ep].r1,
                                        sgr_params[ep].e1, tmpbuf2);
      av1_selfguided_restoration_highbd(dat, width, height, dat_stride, flt2,
                                        width, bit_depth, sgr_params[ep].r2,
                                        sgr_params[ep].e2, tmpbuf2);
301
    } else {
302 303 304 305 306 307 308 309
#endif
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, width,
                                 bit_depth, sgr_params[ep].r1,
                                 sgr_params[ep].e1, tmpbuf2);
      av1_selfguided_restoration(dat8, width, height, dat_stride, flt2, width,
                                 bit_depth, sgr_params[ep].r2,
                                 sgr_params[ep].e2, tmpbuf2);
#if CONFIG_AOM_HIGHBITDEPTH
310
    }
311
#endif
312 313
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
314
    encode_xq(exq, exqd);
315 316 317
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
318 319 320 321 322 323 324 325 326 327 328 329 330
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Per-tile search for the self-guided projection filter on the luma plane.
//
// For every tile this compares the RD cost of leaving the tile unrestored
// with the cost of the best self-guided parameters found by
// search_selfguided_restoration(). The decision goes to type[tile]; when the
// filter wins, its parameters are copied to info->sgrproj_info[tile] and
// best_tile_cost[tile] is set (it stays DBL_MAX otherwise).
//
// Afterwards the per-tile decisions are loaded into cpi->rst_search[0], the
// whole frame is restored once, and the resulting frame-level RD cost is
// returned.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const mb = &cpi->td.mb;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;
  SgrprojInfo *const chosen = info->sgrproj_info;
  RestorationInfo *const rsi = &cpi->rst_search[0];
#if CONFIG_AOM_HIGHBITDEPTH
  const int bit_depth = cm->bit_depth;
#else
  const int bit_depth = 8;
#endif  // CONFIG_AOM_HIGHBITDEPTH
  int tile_w, tile_h, tiles_across, tiles_down;
  int x0, x1, y0, y1;
  int t, rate;
  double sse, cost_off, cost_on;
  const int ntiles = av1_get_rest_ntiles(
      cm->width, cm->height, cm->rst_info[0].restoration_tilesize, &tile_w,
      &tile_h, &tiles_across, &tiles_down);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (t = 0; t < ntiles; ++t) rsi->restoration_type[t] = RESTORE_NONE;

  // Compute best Sgrproj filters for each tile
  for (t = 0; t < ntiles; ++t) {
    av1_get_rest_tile_limits(t, 0, 0, tiles_across, tiles_down, tile_w, tile_h,
                             cm->width, cm->height, 0, 0, &x0, &x1, &y0, &y1);

    // Cost of leaving this tile unrestored.
    sse = sse_restoration_tile(src, cm->frame_to_show, cm, x0, x1 - x0, y0,
                               y1 - y0, 1);
    rate = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_off = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);
    best_tile_cost[t] = DBL_MAX;

    // Find the best parameters for this tile, written straight into rsi.
    search_selfguided_restoration(
        dgd->y_buffer + y0 * dgd->y_stride + x0, x1 - x0, y1 - y0,
        dgd->y_stride, src->y_buffer + y0 * src->y_stride + x0, src->y_stride,
        bit_depth, &rsi->sgrproj_info[t].ep, rsi->sgrproj_info[t].xqd,
        cm->rst_internal.tmpbuf);

    // Cost with only this tile switched on.
    rsi->restoration_type[t] = RESTORE_SGRPROJ;
    sse = try_restoration_tile(src, cpi, rsi, 1, partial_frame, t, 0, 0,
                               dst_frame);
    rate = (SGRPROJ_BITS << AV1_PROB_COST_SHIFT) +
           av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_on = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);

    if (cost_on >= cost_off) {
      type[t] = RESTORE_NONE;
    } else {
      type[t] = RESTORE_SGRPROJ;
      memcpy(&chosen[t], &rsi->sgrproj_info[t], sizeof(chosen[t]));
      rate = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[t] = RDCOST_DBL(
          mb->rdmult, mb->rddiv,
          (rate + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, sse);
    }
    // Switch the tile back off before evaluating the next one.
    rsi->restoration_type[t] = RESTORE_NONE;
  }

  // Cost for Sgrproj filtering
  rate = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (t = 0; t < ntiles; ++t) {
    rate += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[t] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[t], &chosen[t], sizeof(chosen[t]));
    if (type[t] == RESTORE_SGRPROJ) {
      rate += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[t] = type[t];
  }
  sse = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  return RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);
}

412 413
// Mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of a buffer with the given row stride.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t total = 0;
  int row, col;
  for (row = v_start; row < v_end; row++) {
    const int row_base = row * stride;
    for (col = h_start; col < h_end; col++) {
      total += src[row_base + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

423 424 425
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
426
  int i, j, k, l;
427
  double Y[WIENER_WIN2];
428 429
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
430

431 432
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
433 434
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
435 436
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
437 438
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
439 440 441 442
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
443
      for (k = 0; k < WIENER_WIN2; ++k) {
444
        M[k] += Y[k] * X;
445 446
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
447 448 449 450
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
451 452 453 454
        }
      }
    }
  }
455 456 457 458 459
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
460 461
}

Yaowu Xu's avatar
Yaowu Xu committed
462
#if CONFIG_AOM_HIGHBITDEPTH
463 464
// Mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of a buffer with the given row stride.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t total = 0;
  int row, col;
  for (row = v_start; row < v_end; row++) {
    const int row_base = row * stride;
    for (col = h_start; col < h_end; col++) {
      total += src[row_base + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

474 475 476 477
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
478
  int i, j, k, l;
479
  double Y[WIENER_WIN2];
480 481
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
482 483
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
484

485 486
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
487 488
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
489 490
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
491 492
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
493 494 495 496
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
497
      for (k = 0; k < WIENER_WIN2; ++k) {
498
        M[k] += Y[k] * X;
499 500
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
501 502 503 504
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
505 506 507 508
        }
      }
    }
  }
509 510 511 512 513
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
514
}
Yaowu Xu's avatar
Yaowu Xu committed
515
#endif  // CONFIG_AOM_HIGHBITDEPTH
516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537

// Solves Ax = b, where x and b are column vectors
//
// A is an n x n matrix stored row-major with the given row stride. A and b
// are modified in place (row swaps and elimination). Returns 1 on success
// with the solution in x. Returns 0 if a diagonal pivot has magnitude below
// 1e-10; in that case x is not written.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Partial pivoting: one bottom-up pass that swaps adjacent rows whenever
  // the lower row has the larger first-column entry.
  for (i = n - 1; i > 0; i--) {
    if (A[(i - 1) * stride] < A[i * stride]) {
      for (j = 0; j < n; j++) {
        c = A[i * stride + j];
        A[i * stride + j] = A[(i - 1) * stride + j];
        A[(i - 1) * stride + j] = c;
      }
      c = b[i];
      b[i] = b[i - 1];
      b[i - 1] = c;
    }
  }
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Row k is final at this point, so its pivot can be validated before it
    // is used as a divisor instead of only during back substitution.
    const double pivot = A[k * stride + k];
    if (fabs(pivot) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / pivot;
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
553
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
554 555 556 557 558
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
559 560
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
561
  int w, w2;
562 563
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
564 565
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
566 567 568 569
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
570 571
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
572
      int k, l;
573 574
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
575 576
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
577 578
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
579 580 581
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
582
  // Normalization enforcement in the system of equations itself
583
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
584 585
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
586 587
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
588 589 590 591 592 593 594 595 596
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
597
    }
Aamir Anis's avatar
Aamir Anis committed
598
    memcpy(a, S, w * sizeof(*a));
599 600 601 602 603 604
  }
}

// Fix vector a, update vector b
//
// Builds a reduced linear system for the symmetric half of b (indices folded
// with wrap_index), removes the centre tap from the unknowns via the
// normalization step below, and solves it with linsolve(). b is overwritten
// only when the solve succeeds.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  double S[WIENER_WIN];
  double A[WIENER_WIN] = { 0 };
  double B[WIENER_WIN2] = { 0 };
  const int w = WIENER_WIN;
  const int w2 = WIENER_HALFWIN1;
  const int last = w2 - 1;  // index of the centre tap in the folded system
  int i, j, k, l;

  // Right-hand side.
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }

  // System matrix.
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      const double *const h = Hc[i * WIENER_WIN + j];
      double *const cell = &B[wrap_index(j) * WIENER_HALFWIN1 + wrap_index(i)];
      for (k = 0; k < WIENER_WIN; ++k) {
        for (l = 0; l < WIENER_WIN; ++l) {
          *cell += h[k * WIENER_WIN2 + l] * a[k] * a[l];
        }
      }
    }
  }
  // Normalization enforcement in the system of equations itself
  {
    // Neither value is modified by the two loops below.
    const double a_last = A[last];
    const double corner = B[last * w2 + last];
    for (i = 0; i < last; ++i) {
      A[i] -= a_last * 2 + B[i * w2 + last] - 2 * corner;
    }
    for (i = 0; i < last; ++i) {
      for (j = 0; j < last; ++j) {
        B[i * w2 + j] -=
            2 * (B[i * w2 + last] + B[last * w2 + j] - 2 * corner);
      }
    }
  }
  if (!linsolve(last, B, w2, A, S)) return;

  // Mirror the solved half and derive the centre tap from the others.
  S[last] = 1.0;
  for (i = w2; i < w; ++i) {
    S[i] = S[w - 1 - i];
    S[last] -= 2 * S[i];
  }
  memcpy(b, S, w * sizeof(*b));
}

646 647
// Fits a separable symmetric filter pair (a, b) to the statistics M and H.
// Both vectors start from a fixed initial filter and are then refined by
// nine alternating update_a_sep_sym() / update_b_sep_sym() passes.
// Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  int row, col, pass;

  // Pointer tables into M (one entry per row) and H (one entry per
  // WIENER_WIN-wide block).
  for (row = 0; row < WIENER_WIN; row++) {
    Mc[row] = M + row * WIENER_WIN;
    for (col = 0; col < WIENER_WIN; col++) {
      Hc[row * WIENER_WIN + col] =
          H + row * WIENER_WIN * WIENER_WIN2 + col * WIENER_WIN;
    }
  }
  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  for (pass = 1; pass < 10; ++pass) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

673
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
674 675
// against identity filters; Final score is defined as the difference between
// the function values
676 677
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
678
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
679 680 681 682
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
683 684 685 686 687 688 689
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
690
  }
691 692
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
693
  }
694
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
695
    P += ab[k] * M[k];
696 697
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k * WIENER_WIN2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
698 699 700
  }
  Score = Q - 2 * P;

701 702
  iP = M[WIENER_WIN2 >> 1];
  iQ = H[(WIENER_WIN2 >> 1) * WIENER_WIN2 + (WIENER_WIN2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
703 704 705 706 707
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

708
// Converts the first WIENER_HALFWIN taps of f to integers in units of
// 1 / WIENER_FILT_STEP, clips them to their per-tap ranges, mirrors them
// onto the far side of fi and derives the centre tap.
static void quantize_sym_filter(double *f, InterpKernel fi) {
  int tap = 0;
  while (tap < WIENER_HALFWIN) {
    fi[tap] = RINT(f[tap] * WIENER_FILT_STEP);
    ++tap;
  }
  // Specialize for 7-tap filter
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  // Satisfy filter constraints
  fi[WIENER_WIN - 1] = fi[0];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 3] = fi[2];
  // The central element has an implicit +WIENER_FILT_STEP
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);
}

static double search_wiener_uv(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
726
                               int partial_frame, int plane,
727
                               RestorationInfo *info, RestorationType *type,
728 729 730 731 732 733
                               YV12_BUFFER_CONFIG *dst_frame) {
  WienerInfo *wiener_info = info->wiener_info;
  AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *rsi = cpi->rst_search;
  int64_t err;
  int bits;
734
  double cost_wiener, cost_norestore, cost_wiener_frame, cost_norestore_frame;
735 736 737 738 739 740 741 742 743 744 745
  MACROBLOCK *x = &cpi->td.mb;
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = src->uv_crop_width;
  const int height = src->uv_crop_height;
  const int src_stride = src->uv_stride;
  const int dgd_stride = dgd->uv_stride;
  double score;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
746
  int h_start, h_end, v_start, v_end;
747 748 749
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[1].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
750 751 752 753
  assert(width == dgd->uv_crop_width);
  assert(height == dgd->uv_crop_height);

  rsi[plane].frame_restoration_type = RESTORE_NONE;
754
  err = sse_restoration_frame(cm, src, cm->frame_to_show, (1 << plane));
755
  bits = 0;
756
  cost_norestore_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
757 758

  rsi[plane].frame_restoration_type = RESTORE_WIENER;
759

760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
                               h_end - h_start, v_start, v_end - v_start,
                               1 << plane);
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    // best_tile_cost[tile_idx] = DBL_MAX;

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, WIENER_HALFWIN,
                             WIENER_HALFWIN, &h_start, &h_end, &v_start,
                             &v_end);
781
    if (plane == AOM_PLANE_U) {
782 783 784 785 786 787 788 789
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->u_buffer, src->u_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->u_buffer, src->u_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
790
    } else if (plane == AOM_PLANE_V) {
791 792 793 794 795 796 797 798
#if CONFIG_AOM_HIGHBITDEPTH
      if (cm->use_highbitdepth)
        compute_stats_highbd(dgd->v_buffer, src->v_buffer, h_start, h_end,
                             v_start, v_end, dgd_stride, src_stride, M, H);
      else
#endif  // CONFIG_AOM_HIGHBITDEPTH
        compute_stats(dgd->v_buffer, src->v_buffer, h_start, h_end, v_start,
                      v_end, dgd_stride, src_stride, M, H);
799 800 801
    } else {
      assert(0);
    }
802 803 804 805 806 807

    type[tile_idx] = RESTORE_WIENER;

    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
      type[tile_idx] = RESTORE_NONE;
      continue;
808
    }
809 810
    quantize_sym_filter(vfilterd, rsi[plane].wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi[plane].wiener_info[tile_idx].hfilter);
811

812 813 814 815 816 817 818 819 820
    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against the identity filter. If there is
    // no reduction in the function, the filter is reverted back to identity.
    score = compute_score(M, H, rsi[plane].wiener_info[tile_idx].vfilter,
                          rsi[plane].wiener_info[tile_idx].hfilter);
    if (score > 0.0) {
      type[tile_idx] = RESTORE_NONE;
      continue;
    }
821

822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847
    rsi[plane].restoration_type[tile_idx] = RESTORE_WIENER;
    err = try_restoration_tile(src, cpi, rsi, 1 << plane, partial_frame,
                               tile_idx, 0, 0, dst_frame);
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_wiener >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_WIENER;
      memcpy(&wiener_info[tile_idx], &rsi[plane].wiener_info[tile_idx],
             sizeof(wiener_info[tile_idx]));
    }
    rsi[plane].restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Wiener filtering
  bits = 0;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_WIENER_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi[plane].wiener_info[tile_idx], &wiener_info[tile_idx],
           sizeof(wiener_info[tile_idx]));
    if (type[tile_idx] == RESTORE_WIENER) {
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi[plane].restoration_type[tile_idx] = type[tile_idx];
848
  }
849
  err = try_restoration_frame(src, cpi, rsi, 1 << plane, partial_frame,
850
                              dst_frame);
851 852 853 854 855
  cost_wiener_frame = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  if (cost_wiener_frame < cost_norestore_frame) {
    info->frame_restoration_type = RESTORE_WIENER;
  } else {
856 857 858
    info->frame_restoration_type = RESTORE_NONE;
  }

859 860
  return info->frame_restoration_type == RESTORE_WIENER ? cost_wiener_frame
                                                        : cost_norestore_frame;
861 862
}

863
static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
864 865
                            int partial_frame, RestorationInfo *info,
                            RestorationType *type, double *best_tile_cost,
866
                            YV12_BUFFER_CONFIG *dst_frame) {
867
  WienerInfo *wiener_info = info->wiener_info;
Yaowu Xu's avatar
Yaowu Xu committed
868
  AV1_COMMON *const cm = &cpi->common;
869
  RestorationInfo *rsi = cpi->rst_search;
870 871
  int64_t err;
  int bits;
872
  double cost_wiener, cost_norestore;
873
  MACROBLOCK *x = &cpi->td.mb;
874 875 876
  double M[WIENER_WIN2];
  double H[WIENER_WIN2 * WIENER_WIN2];
  double vfilterd[WIENER_WIN], hfilterd[WIENER_WIN];
877 878 879 880 881
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
Aamir Anis's avatar
Aamir Anis committed
882
  double score;
883
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
884
  int h_start, h_end, v_start, v_end;
885 886 887
  const int ntiles =
      av1_get_rest_ntiles(width, height, cm->rst_info[0].restoration_tilesize,
                          &tile_width, &tile_height, &nhtiles, &nvtiles);
888 889 890 891 892
  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

893
  rsi->frame_restoration_type = RESTORE_WIENER;
894

895 896 897
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
898

899 900 901 902 903 904 905 906 907
// Construct a (WIENER_HALFWIN)-pixel border around the frame
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth)
    extend_frame_highbd(CONVERT_TO_SHORTPTR(dgd->y_buffer), width, height,
                        dgd_stride);
  else
#endif
    extend_frame(dgd->y_buffer, width, height, dgd_stride);

908 909
  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
910 911 912
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
913
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,