/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26

27 28 29 30
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
#include "av1/encoder/quantize.h"
31

32 33 34
// Common signature of the per-tool restoration searches defined in this file
// (search_sgrproj() and search_domaintxfmrf() both have this shape). Each
// search receives the source frame, the encoder context, the loop-filter
// level to apply first, a partial-frame flag, the RestorationInfo to fill in,
// a per-tile cost array to fill in, and a scratch destination frame; it
// returns a frame-level rate-distortion cost as a double.
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *cpi, int filter_level,
                                      int partial_frame, RestorationInfo *info,
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
37

38
// Bit counts indexed by frame-level restoration type. The search functions in
// this file shift the selected entry left by AV1_PROB_COST_SHIFT and use it as
// the starting value of the frame-level rate.
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 3, 3, 2 };
39 40

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
41 42
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
43 44
                                    int width, int v_start, int height,
                                    int y_only) {
45 46 47
  int64_t filt_err;
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
48 49
    filt_err =
        aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
50 51 52 53 54 55 56 57 58
    if (!y_only) {
      filt_err += aom_highbd_get_u_sse_part(
          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
      filt_err += aom_highbd_get_v_sse_part(
          src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
          v_start >> cm->subsampling_y, height >> cm->subsampling_y);
    }
    return filt_err;
59 60
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
61 62 63 64 65 66 67 68 69
  filt_err = aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  if (!y_only) {
    filt_err += aom_get_u_sse_part(
        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
    filt_err += aom_get_u_sse_part(
        src, dst, h_start >> cm->subsampling_x, width >> cm->subsampling_x,
        v_start >> cm->subsampling_y, height >> cm->subsampling_y);
  }
70 71 72 73 74
  return filt_err;
}

// Runs loop restoration over the frame with the settings in `rsi`, writing
// the result to `dst_frame`, and returns the SSE against `src` measured only
// inside the tile (or sub-tile) selected by tile_idx / subtile_idx /
// subtile_bits.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int y_only, int partial_frame, int tile_idx,
                                    int subtile_idx, int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  int64_t filt_err;
  int tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  // Called for its output parameters (tile geometry); the tile count itself
  // is not needed here.
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  (void)ntiles;

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
                             dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
                           nvtiles, tile_width, tile_height, cm->width,
                           cm->height, 0, 0, &h_start, &h_end, &v_start,
                           &v_end);
  filt_err = sse_restoration_tile(src, dst_frame, cm, h_start, h_end - h_start,
                                  v_start, v_end - v_start, y_only);

  return filt_err;
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
99
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
100
                                     int y_only, int partial_frame,
101
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
102
  AV1_COMMON *const cm = &cpi->common;
103
  int64_t filt_err;
104
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, y_only, partial_frame,
105
                             dst_frame);
Yaowu Xu's avatar
Yaowu Xu committed
106
#if CONFIG_AOM_HIGHBITDEPTH
107
  if (cm->use_highbitdepth) {
108
    filt_err = aom_highbd_get_y_sse(src, dst_frame);
109 110 111 112 113
    if (!y_only) {
      filt_err += aom_highbd_get_u_sse(src, dst_frame);
      filt_err += aom_highbd_get_v_sse(src, dst_frame);
    }
    return filt_err;
114
  }
Yaowu Xu's avatar
Yaowu Xu committed
115
#endif  // CONFIG_AOM_HIGHBITDEPTH
116 117 118 119 120
  filt_err = aom_get_y_sse(src, dst_frame);
  if (!y_only) {
    filt_err += aom_get_u_sse(src, dst_frame);
    filt_err += aom_get_v_sse(src, dst_frame);
  }
121 122 123
  return filt_err;
}

124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
// Returns the sum of squared errors obtained when the degraded signal `dgd`
// is corrected by the projection described by `xqd` onto the two filtered
// signals `flt1` and `flt2`, compared against `src`.
//
// `xqd` is first expanded with decode_xq(). Per pixel the projected value is
//   xq[0] * (flt1 - dgd) + xq[1] * (flt2 - dgd) + (dgd << SGRPROJ_PRJ_BITS),
// which is rounded down by SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS bits and
// compared with `src` rounded down by SGRPROJ_RST_BITS bits.
static int64_t get_pixel_proj_error(int64_t *src, int width, int height,
                                    int src_stride, int64_t *dgd,
                                    int dgd_stride, int64_t *flt1,
                                    int flt1_stride, int64_t *flt2,
                                    int flt2_stride, int *xqd) {
  int xq[2];
  int64_t total = 0;
  int row, col;

  decode_xq(xqd, xq);

  for (row = 0; row < height; ++row) {
    for (col = 0; col < width; ++col) {
      const int64_t target = (int64_t)src[row * src_stride + col];
      const int64_t base = (int64_t)dgd[row * dgd_stride + col];
      const int64_t delta1 = (int64_t)flt1[row * flt1_stride + col] - base;
      const int64_t delta2 = (int64_t)flt2[row * flt2_stride + col] - base;
      const int64_t projected =
          xq[0] * delta1 + xq[1] * delta2 + (base << SGRPROJ_PRJ_BITS);
      const int64_t diff =
          ROUND_POWER_OF_TWO(projected, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
          ROUND_POWER_OF_TWO(target, SGRPROJ_RST_BITS);
      total += diff * diff;
    }
  }
  return total;
}

// Computes the pair of projection coefficients xq[0..1] (scaled by
// 1 << SGRPROJ_PRJ_BITS and rounded with rint()) that solve the 2x2 normal
// equations built from (flt1 - dgd), (flt2 - dgd) and (src - dgd), averaged
// over the width x height window.
//
// If the determinant of the 2x2 system is below 1e-8 the system is treated
// as ill-posed and xq keeps the default values written at the top.
static void get_proj_subspace(int64_t *src, int width, int height,
                              int src_stride, int64_t *dgd, int dgd_stride,
                              int64_t *flt1, int flt1_stride, int64_t *flt2,
                              int flt2_stride, int *xq) {
  const int size = width * height;
  double h00 = 0, h01 = 0, h11 = 0;  // Symmetric 2x2 matrix entries.
  double c0 = 0, c1 = 0;             // Right-hand side.
  double det;
  int row, col;

  // Defaults, used when the system cannot be solved reliably.
  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];

  for (row = 0; row < height; ++row) {
    for (col = 0; col < width; ++col) {
      const double base = (double)dgd[row * dgd_stride + col];
      const double target = (double)src[row * src_stride + col] - base;
      const double d1 = (double)flt1[row * flt1_stride + col] - base;
      const double d2 = (double)flt2[row * flt2_stride + col] - base;
      h00 += d1 * d1;
      h11 += d2 * d2;
      h01 += d1 * d2;
      c0 += d1 * target;
      c1 += d2 * target;
    }
  }

  h00 /= size;
  h01 /= size;
  h11 /= size;
  c0 /= size;
  c1 /= size;

  det = (h00 * h11 - h01 * h01);
  if (det < 1e-8) return;  // ill-posed, return default values

  {
    const double x0 = (h11 * c0 - h01 * c1) / det;
    const double x1 = (h00 * c1 - h01 * c0) / det;
    xq[0] = (int)rint(x0 * (1 << SGRPROJ_PRJ_BITS));
    xq[1] = (int)rint(x1 * (1 << SGRPROJ_PRJ_BITS));
  }
}

// Converts the projection coefficients xq[0..1] into the coded form xqd[0..1]:
//   xqd[0] = clamp(-xq[0])                                   to [MIN0, MAX0]
//   xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1]) to [MIN1, MAX1]
// Note that xqd[1] is derived from the already-clamped xqd[0].
void encode_xq(int *xq, int *xqd) {
  const int negated_first = -xq[0];
  xqd[0] = clamp(negated_first, SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  {
    const int second = (1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1];
    xqd[1] = clamp(second, SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
  }
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
199 200 201 202
                                          int *eps, int *xqd, void *srcbuf,
                                          void *tmpbuf) {
  int64_t *srd = (int64_t *)srcbuf;
  int64_t *dgd = (int64_t *)tmpbuf;
203
  int64_t *flt1 = dgd + RESTORATION_TILEPELS_MAX;
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
  int64_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
  uint8_t *tmpbuf2 = (uint8_t *)(flt2 + RESTORATION_TILEPELS_MAX);
  int i, j, ep, bestep = 0;
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
    if (bit_depth > 8) {
      uint16_t *src = CONVERT_TO_SHORTPTR(src8);
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          flt1[i * width + j] = (int64_t)dat[i * dat_stride + j];
          flt2[i * width + j] = (int64_t)dat[i * dat_stride + j];
          dgd[i * width + j] = (int64_t)dat[i * dat_stride + j]
                               << SGRPROJ_RST_BITS;
          srd[i * width + j] = (int64_t)src[i * src_stride + j]
                               << SGRPROJ_RST_BITS;
        }
      }
    } else {
      uint8_t *src = src8;
      uint8_t *dat = dat8;
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          const int k = i * width + j;
          const int l = i * dat_stride + j;
          flt1[k] = (int64_t)dat[l];
          flt2[k] = (int64_t)dat[l];
          dgd[k] = (int64_t)dat[l] << SGRPROJ_RST_BITS;
          srd[k] = (int64_t)src[i * src_stride + j] << SGRPROJ_RST_BITS;
        }
      }
    }
    av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
    av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                               sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
    get_proj_subspace(srd, width, height, width, dgd, width, flt1, width, flt2,
                      width, exq);
    encode_xq(exq, exqd);
    err = get_pixel_proj_error(srd, width, height, width, dgd, width, flt1,
                               width, flt2, width, exqd);
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int filter_level, int partial_frame,
261 262
                             RestorationInfo *info, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
263 264 265 266 267 268 269 270 271
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo rsi;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
272
  // Allocate for the src buffer at high precision
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  rsi.frame_restoration_type = RESTORE_SGRPROJ;
  rsi.sgrproj_info =
      (SgrprojInfo *)aom_malloc(sizeof(*rsi.sgrproj_info) * ntiles);
  assert(rsi.sgrproj_info != NULL);

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi.sgrproj_info[tile_idx].level = 0;
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
293
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
294
                               h_end - h_start, v_start, v_end - v_start, 1);
295 296 297 298 299 300 301 302 303 304 305 306 307
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
308 309
        &rsi.sgrproj_info[tile_idx].ep, rsi.sgrproj_info[tile_idx].xqd,
        cpi->highprec_srcbuf, cm->rst_internal.tmpbuf);
310
    rsi.sgrproj_info[tile_idx].level = 1;
311
    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
312
                               dst_frame);
313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      sgrproj_info[tile_idx].level = 0;
    } else {
      memcpy(&sgrproj_info[tile_idx], &rsi.sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    rsi.sgrproj_info[tile_idx].level = 0;
  }
  // Cost for Sgrproj filtering
  bits = frame_level_restore_bits[rsi.frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, sgrproj_info[tile_idx].level);
    memcpy(&rsi.sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (sgrproj_info[tile_idx].level) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
  }
340
  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
341 342 343 344 345 346 347 348
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_free(rsi.sgrproj_info);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_sgrproj;
}

349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388
// Returns the sum of squared differences between two 8-bit images over a
// width x height window. Each image has its own row stride.
static int64_t compute_sse(uint8_t *dgd, int width, int height, int dgd_stride,
                           uint8_t *src, int src_stride) {
  int64_t total = 0;
  int row;
  for (row = 0; row < height; ++row) {
    const int dgd_base = row * dgd_stride;
    const int src_base = row * src_stride;
    int col;
    for (col = 0; col < width; ++col) {
      const int delta = (int)dgd[dgd_base + col] - (int)src[src_base + col];
      total += delta * delta;
    }
  }
  return total;
}

#if CONFIG_AOM_HIGHBITDEPTH
// Returns the sum of squared differences between two 16-bit images over a
// width x height window. Each image has its own row stride.
//
// The difference is squared in 64-bit arithmetic: a 16-bit difference can be
// as large as 65535, whose square does not fit in a 32-bit int.
static int64_t compute_sse_highbd(uint16_t *dgd, int width, int height,
                                  int dgd_stride, uint16_t *src,
                                  int src_stride) {
  int64_t sse = 0;
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int64_t diff =
          (int64_t)dgd[i * dgd_stride + j] - (int64_t)src[i * src_stride + j];
      sse += diff * diff;
    }
  }
  return sse;
}
#endif  // CONFIG_AOM_HIGHBITDEPTH

static void search_domaintxfmrf_restoration(uint8_t *dgd8, int width,
                                            int height, int dgd_stride,
                                            uint8_t *src8, int src_stride,
                                            int bit_depth, int *sigma_r) {
  const int first_p_step = 8;
  const int second_p_range = first_p_step >> 1;
  const int second_p_step = 2;
  const int third_p_range = second_p_step >> 1;
  const int third_p_step = 1;
389
  int p, best_p0, best_p = -1;
390 391 392
  int64_t best_sse = INT64_MAX, sse;
  if (bit_depth == 8) {
    uint8_t *tmp = (uint8_t *)aom_malloc(width * height * sizeof(*tmp));
393
    int32_t *tmpbuf = (int32_t *)aom_malloc(DOMAINTXFMRF_TMPBUF_SIZE);
394 395 396 397
    uint8_t *dgd = dgd8;
    uint8_t *src = src8;
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
398
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, tmp,
399
                                   width, tmpbuf);
400 401 402 403 404 405 406 407 408 409 410
      sse = compute_sse(tmp, width, height, width, src, src_stride);
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
411
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, tmp,
412
                                   width, tmpbuf);
413 414 415 416 417 418 419 420 421 422 423
      sse = compute_sse(tmp, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
424
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, tmp,
425
                                   width, tmpbuf);
426 427 428 429 430 431 432
      sse = compute_sse(tmp, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    aom_free(tmp);
433
    aom_free(tmpbuf);
434 435 436
  } else {
#if CONFIG_AOM_HIGHBITDEPTH
    uint16_t *tmp = (uint16_t *)aom_malloc(width * height * sizeof(*tmp));
437
    int32_t *tmpbuf = (int32_t *)aom_malloc(DOMAINTXFMRF_TMPBUF_SIZE);
438 439 440 441
    uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
442
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
443
                                          bit_depth, tmp, width, tmpbuf);
444 445 446 447 448 449 450 451 452 453 454
      sse = compute_sse_highbd(tmp, width, height, width, src, src_stride);
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
455
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
456
                                          bit_depth, tmp, width, tmpbuf);
457 458 459 460 461 462 463 464 465 466 467
      sse = compute_sse_highbd(tmp, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
468
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
469
                                          bit_depth, tmp, width, tmpbuf);
470 471 472 473 474 475 476
      sse = compute_sse_highbd(tmp, width, height, width, src, src_stride);
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    aom_free(tmp);
477
    aom_free(tmpbuf);
478 479 480 481 482 483 484 485 486
#else
    assert(0);
#endif  // CONFIG_AOM_HIGHBITDEPTH
  }
  *sigma_r = best_p;
}

static double search_domaintxfmrf(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  int filter_level, int partial_frame,
487 488
                                  RestorationInfo *info, double *best_tile_cost,
                                  YV12_BUFFER_CONFIG *dst_frame) {
489
  DomaintxfmrfInfo *domaintxfmrf_info = info->domaintxfmrf_info;
490 491
  double cost_norestore, cost_domaintxfmrf;
  int64_t err;
492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  RestorationInfo rsi;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  rsi.frame_restoration_type = RESTORE_DOMAINTXFMRF;
  rsi.domaintxfmrf_info =
      (DomaintxfmrfInfo *)aom_malloc(sizeof(*rsi.domaintxfmrf_info) * ntiles);
  assert(rsi.domaintxfmrf_info != NULL);

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi.domaintxfmrf_info[tile_idx].level = 0;
  // Compute best Domaintxfm filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
519
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
520
                               h_end - h_start, v_start, v_end - v_start, 1);
521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

    search_domaintxfmrf_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi.domaintxfmrf_info[tile_idx].sigma_r);

    rsi.domaintxfmrf_info[tile_idx].level = 1;
538
    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
539
                               dst_frame);
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567
    bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 1);
    cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_domaintxfmrf >= cost_norestore) {
      domaintxfmrf_info[tile_idx].level = 0;
    } else {
      memcpy(&domaintxfmrf_info[tile_idx], &rsi.domaintxfmrf_info[tile_idx],
             sizeof(domaintxfmrf_info[tile_idx]));
      bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_DOMAINTXFMRF]) >> 4,
          err);
    }
    rsi.domaintxfmrf_info[tile_idx].level = 0;
  }
  // Cost for Domaintxfmrf filtering
  bits = frame_level_restore_bits[rsi.frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB,
                         domaintxfmrf_info[tile_idx].level);
    memcpy(&rsi.domaintxfmrf_info[tile_idx], &domaintxfmrf_info[tile_idx],
           sizeof(domaintxfmrf_info[tile_idx]));
    if (domaintxfmrf_info[tile_idx].level) {
      bits += (DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT);
    }
  }
568
  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
569 570 571 572 573 574 575 576
  cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_free(rsi.domaintxfmrf_info);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_domaintxfmrf;
}

577 578
// Returns the mean of the 8-bit samples in the half-open window
// [v_start, v_end) x [h_start, h_end) of an image with the given row stride.
// The window is expected to be non-empty; the sum is divided by its area.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

588 589 590
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
591 592
  int i, j, k, l;
  double Y[RESTORATION_WIN2];
593 594
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
595 596 597

  memset(M, 0, sizeof(*M) * RESTORATION_WIN2);
  memset(H, 0, sizeof(*H) * RESTORATION_WIN2 * RESTORATION_WIN2);
598 599
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -RESTORATION_HALFWIN; k <= RESTORATION_HALFWIN; k++) {
        for (l = -RESTORATION_HALFWIN; l <= RESTORATION_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < RESTORATION_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * RESTORATION_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < RESTORATION_WIN2; ++l) {
          double value = Y[k] * Y[l];
          H[k * RESTORATION_WIN2 + l] += value;
          H[l * RESTORATION_WIN2 + k] += value;
        }
      }
    }
  }
}

Yaowu Xu's avatar
Yaowu Xu committed
621
#if CONFIG_AOM_HIGHBITDEPTH
622 623
// Returns the mean of the 16-bit samples in the half-open window
// [v_start, v_end) x [h_start, h_end) of an image with the given row stride.
// The window is expected to be non-empty; the sum is divided by its area.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  double avg = 0;
  int i, j;
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  avg = (double)sum / ((v_end - v_start) * (h_end - h_start));
  return avg;
}

633 634 635 636
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
637 638 639 640
  int i, j, k, l;
  double Y[RESTORATION_WIN2];
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
641 642
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
643 644 645

  memset(M, 0, sizeof(*M) * RESTORATION_WIN2);
  memset(H, 0, sizeof(*H) * RESTORATION_WIN2 * RESTORATION_WIN2);
646 647
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
      for (k = -RESTORATION_HALFWIN; k <= RESTORATION_HALFWIN; k++) {
        for (l = -RESTORATION_HALFWIN; l <= RESTORATION_HALFWIN; l++) {
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
      for (k = 0; k < RESTORATION_WIN2; ++k) {
        M[k] += Y[k] * X;
        H[k * RESTORATION_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < RESTORATION_WIN2; ++l) {
          double value = Y[k] * Y[l];
          H[k * RESTORATION_WIN2 + l] += value;
          H[l * RESTORATION_WIN2 + k] += value;
        }
      }
    }
  }
}
Yaowu Xu's avatar
Yaowu Xu committed
668
#endif  // CONFIG_AOM_HIGHBITDEPTH
669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690

// Solves Ax = b, where x and b are column vectors
//
// A is an n x n row-major matrix with row pitch `stride`. A and b are
// modified in place (row swaps and elimination); the solution is written to
// x. Returns 1 on success and 0 if a diagonal entry is smaller than 1e-10 in
// magnitude, in which case the system is treated as singular.
//
// A tiny pivot is rejected before it is used as a divisor, so a singular
// system returns 0 without writing anything to x.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Partial pivoting
  // (a single bottom-up pass over adjacent rows, comparing the signed values
  // in column 0 only)
  for (i = n - 1; i > 0; i--) {
    if (A[(i - 1) * stride] < A[i * stride]) {
      for (j = 0; j < n; j++) {
        c = A[i * stride + j];
        A[i * stride + j] = A[(i - 1) * stride + j];
        A[(i - 1) * stride + j] = c;
      }
      c = b[i];
      b[i] = b[i - 1];
      b[i - 1] = c;
    }
  }
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Row k is final at this point; the same threshold is applied to it
    // during backward substitution, so rejecting it here does not change
    // which systems are reported as singular.
    if (fabs(A[k * stride + k]) < 1e-10) return 0;
    for (i = k; i < n - 1; i++) {
      c = A[(i + 1) * stride + k] / A[k * stride + k];
      for (j = 0; j < n; j++) A[(i + 1) * stride + j] -= c * A[k * stride + j];
      b[i + 1] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

// Folds a tap index onto the first half of a symmetric filter: indices at or
// beyond RESTORATION_HALFWIN1 are mirrored to RESTORATION_WIN - 1 - i, the
// others are returned unchanged.
static INLINE int wrap_index(int i) {
  if (i >= RESTORATION_HALFWIN1) return RESTORATION_WIN - 1 - i;
  return i;
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[RESTORATION_WIN];
  double A[RESTORATION_WIN], B[RESTORATION_WIN2];
Aamir Anis's avatar
Aamir Anis committed
714
  int w, w2;
715 716
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
717
  for (i = 0; i < RESTORATION_WIN; i++) {
718 719 720 721 722
    for (j = 0; j < RESTORATION_WIN; ++j) {
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
723 724
  for (i = 0; i < RESTORATION_WIN; i++) {
    for (j = 0; j < RESTORATION_WIN; j++) {
725 726 727 728 729 730
      int k, l;
      for (k = 0; k < RESTORATION_WIN; ++k)
        for (l = 0; l < RESTORATION_WIN; ++l) {
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
          B[ll * RESTORATION_HALFWIN1 + kk] +=
731 732
              Hc[j * RESTORATION_WIN + i][k * RESTORATION_WIN2 + l] * b[i] *
              b[j];
733 734 735
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
736 737 738 739
  // Normalization enforcement in the system of equations itself
  w = RESTORATION_WIN;
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
740 741
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
742 743 744 745 746 747 748 749 750
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
751
    }
Aamir Anis's avatar
Aamir Anis committed
752
    memcpy(a, S, w * sizeof(*a));
753 754 755 756 757 758 759 760
  }
}

// Fix vector a, update vector b.
//
// Counterpart of update_a_sep_sym(): builds a reduced linear system for the
// symmetric filter b while a is held fixed, and solves it with linsolve():
//  - A accumulates Mc[i][j] * a[j], folded onto index wrap_index(i).
//  - B accumulates Hc[i * WIN + j][k * WIN2 + l] * a[k] * a[l], folded onto
//    row wrap_index(j) / column wrap_index(i), with RESTORATION_HALFWIN1
//    entries per row.
// The centre tap is then eliminated from the system so that the resulting
// taps sum to 1. On success all RESTORATION_WIN taps of b are overwritten;
// if linsolve() returns 0, b is left untouched. a is only read.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
  double S[RESTORATION_WIN];
  double A[RESTORATION_WIN], B[RESTORATION_WIN2];
  const int w = RESTORATION_WIN;
  const int w2 = RESTORATION_HALFWIN1;

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));

  for (i = 0; i < RESTORATION_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < RESTORATION_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }

  for (i = 0; i < RESTORATION_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < RESTORATION_WIN; j++) {
      const int jj = wrap_index(j);
      int k, l;
      for (k = 0; k < RESTORATION_WIN; ++k)
        for (l = 0; l < RESTORATION_WIN; ++l)
          B[jj * RESTORATION_HALFWIN1 + ii] +=
              Hc[i * RESTORATION_WIN + j][k * RESTORATION_WIN2 + l] * a[k] *
              a[l];
    }
  }

  // Normalization enforcement in the system of equations itself
  for (i = 0; i < w2 - 1; ++i)
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);

  if (linsolve(w2 - 1, B, w2, A, S)) {
    // Mirror the solved outer taps and derive the centre tap so that the
    // full filter sums to 1.
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
    }
    memcpy(b, S, w * sizeof(*b));
  }
}

// Decomposes the statistics M / H into a pair of symmetric 1-D filters a and
// b by alternating refinement.
//
// M is addressed as RESTORATION_WIN rows of RESTORATION_WIN doubles, and H
// through a table of RESTORATION_WIN2 pointers into the flat array. Both
// filters start from the same fixed 7-tap kernel and are then refined by nine
// alternating passes of update_a_sep_sym() / update_b_sep_sym().
//
// a and b must each have room for RESTORATION_WIN doubles; both are
// overwritten. Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[RESTORATION_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  int i, j, iter;
  double *Hc[RESTORATION_WIN2];
  double *Mc[RESTORATION_WIN];

  // Row / block pointer tables so the update routines can index M and H in 2-D.
  for (i = 0; i < RESTORATION_WIN; i++) {
    Mc[i] = M + i * RESTORATION_WIN;
    for (j = 0; j < RESTORATION_WIN; j++) {
      Hc[i * RESTORATION_WIN + j] =
          H + i * RESTORATION_WIN * RESTORATION_WIN2 + j * RESTORATION_WIN;
    }
  }

  memcpy(a, init_filt, sizeof(*a) * RESTORATION_WIN);
  memcpy(b, init_filt, sizeof(*b) * RESTORATION_WIN);

  // Fixed number of alternating refinement passes (iter = 1..9).
  for (iter = 1; iter < 10; ++iter) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

// Computes the function x'*A*x - x'*b for the learned filters, and compares
// against identity filters; Final score is defined as the difference between
// the function values
831
static double compute_score(double *M, double *H, int *vfilt, int *hfilt) {
Aamir Anis's avatar
Aamir Anis committed
832 833 834 835 836 837 838 839 840 841
  double ab[RESTORATION_WIN * RESTORATION_WIN];
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
  int w;
  double a[RESTORATION_WIN], b[RESTORATION_WIN];
  w = RESTORATION_WIN;
  a[RESTORATION_HALFWIN] = b[RESTORATION_HALFWIN] = 1.0;
  for (i = 0; i < RESTORATION_HALFWIN; ++i) {
842 843 844 845
    a[i] = a[RESTORATION_WIN - i - 1] =
        (double)vfilt[i] / RESTORATION_FILT_STEP;
    b[i] = b[RESTORATION_WIN - i - 1] =
        (double)hfilt[i] / RESTORATION_FILT_STEP;
Aamir Anis's avatar
Aamir Anis committed
846 847 848 849
    a[RESTORATION_HALFWIN] -= 2 * a[i];
    b[RESTORATION_HALFWIN] -= 2 * b[i];
  }
  for (k = 0; k < w; ++k) {
850
    for (l = 0; l < w; ++l) ab[k * w + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
851 852 853
  }
  for (k = 0; k < w * w; ++k) {
    P += ab[k] * M[k];
854
    for (l = 0; l < w * w; ++l) Q += ab[k] * H[k * w * w + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
855 856 857 858 859 860 861 862 863 864
  }
  Score = Q - 2 * P;

  iP = M[(w * w) >> 1];
  iQ = H[((w * w) >> 1) * w * w + ((w * w) >> 1)];
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

// Converts the first RESTORATION_HALFWIN taps of the floating-point filter f
// into integers in fi: each tap is scaled by RESTORATION_FILT_STEP and
// rounded with RINT, then the three taps are clipped to their per-tap
// WIENER_FILT_TAPn_MINV / WIENER_FILT_TAPn_MAXV ranges.
static void quantize_sym_filter(double *f, int *fi) {
  int tap;
  for (tap = 0; tap < RESTORATION_HALFWIN; ++tap)
    fi[tap] = RINT(f[tap] * RESTORATION_FILT_STEP);

  // The clipping below is written out for the 7-tap (3 half-tap) filter.
  fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
  fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
  fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
}

static double search_wiener(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                            int filter_level, int partial_frame,
878 879
                            RestorationInfo *info, double *best_tile_cost,
                            YV12_BUFFER_CONFIG *dst_frame) {
880
  WienerInfo *wiener_info = info->wiener_info;
Yaowu Xu's avatar
Yaowu Xu committed
881
  AV1_COMMON *const cm = &cpi->common;
882
  RestorationInfo rsi;
883 884
  int64_t err;
  int bits;
885
  double cost_wiener, cost_norestore;
886 887 888 889 890 891 892 893 894
  MACROBLOCK *x = &cpi->td.mb;
  double M[RESTORATION_WIN2];
  double H[RESTORATION_WIN2 * RESTORATION_WIN2];
  double vfilterd[RESTORATION_WIN], hfilterd[RESTORATION_WIN];
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  const int width = cm->width;
  const int height = cm->height;
  const int src_stride = src->y_stride;
  const int dgd_stride = dgd->y_stride;
Aamir Anis's avatar
Aamir Anis committed
895
  double score;
896
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
897
  int h_start, h_end, v_start, v_end;
898
  int i;
899

900 901
  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
902 903 904 905 906 907
  assert(width == dgd->y_crop_width);
  assert(height == dgd->y_crop_height);
  assert(width == src->y_crop_width);
  assert(height == src->y_crop_height);

  //  Make a copy of the unfiltered / processed recon buffer
Yaowu Xu's avatar
Yaowu Xu committed
908 909 910 911
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
912

913
  rsi.frame_restoration_type = RESTORE_WIENER;
914 915 916
  rsi.wiener_info = (WienerInfo *)aom_malloc(sizeof(*rsi.wiener_info) * ntiles);
  assert(rsi.wiener_info != NULL);

917 918
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx)
    rsi.wiener_info[tile_idx].level = 0;
919 920 921

  // Compute best Wiener filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
922 923 924
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 0, 0, &h_start, &h_end,
                             &v_start, &v_end);
925
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
926
                               h_end - h_start, v_start, v_end - v_start, 1);
927 928
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
929
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
930
    best_tile_cost[tile_idx] = DBL_MAX;
931 932 933 934

    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, width, height, 1, 1, &h_start, &h_end,
                             &v_start, &v_end);
Yaowu Xu's avatar
Yaowu Xu committed
935
#if CONFIG_AOM_HIGHBITDEPTH
936 937 938 939
    if (cm->use_highbitdepth)
      compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
                           v_start, v_end, dgd_stride, src_stride, M, H);
    else
Yaowu Xu's avatar
Yaowu Xu committed
940
#endif  // CONFIG_AOM_HIGHBITDEPTH
941 942 943
      compute_stats(dgd->y_buffer, src->y_buffer, h_start, h_end, v_start,
                    v_end, dgd_stride, src_stride, M, H);

944
    wiener_info[tile_idx].level = 1;
945
    if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
946
      wiener_info[tile_idx].level = 0;
947 948
      continue;
    }
949 950
    quantize_sym_filter(vfilterd, rsi.wiener_info[tile_idx].vfilter);
    quantize_sym_filter(hfilterd, rsi.wiener_info[tile_idx].hfilter);
951 952 953 954

    // Filter score computes the value of the function x'*A*x - x'*b for the
    // learned filter and compares it against identity filer. If there is no
    // reduction in the function, the filter is reverted back to identity
955 956
    score = compute_score(M, H, rsi.wiener_info[tile_idx].vfilter,
                          rsi.wiener_info[tile_idx].hfilter);
957
    if (score > 0.0) {
958
      wiener_info[tile_idx].level = 0;
959 960
      continue;
    }
961

962
    rsi.wiener_info[tile_idx].level = 1;
963
    err = try_restoration_tile(src, cpi, &rsi, 1, partial_frame, tile_idx, 0, 0,
964
                               dst_frame);
965 966 967
    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
968
    if (cost_wiener >= cost_norestore) {
969 970 971 972 973 974 975
      wiener_info[tile_idx].level = 0;
    } else {
      wiener_info[tile_idx].level = 1;
      for (i = 0; i < RESTORATION_HALFWIN; ++i) {
        wiener_info[tile_idx].vfilter[i] = rsi.wiener_info[tile_idx].vfilter[i];
        wiener_info[tile_idx].hfilter[i] = rsi.wiener_info[tile_idx].hfilter[i];
      }
976 977 978 979 980
      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
    }
981
    rsi.wiener_info[tile_idx].level = 0;
982
  }
983
  // Cost for Wiener filtering
984 985
  bits = frame_level_restore_bits[rsi.frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
986
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
987 988 989
    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, wiener_info[tile_idx].level);
    rsi.wiener_info[tile_idx].level = wiener_info[tile_idx].level;
    if (wiener_info[tile_idx].level) {
990
      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
991 992 993 994 995
      for (i = 0; i < RESTORATION_HALFWIN; ++i) {
        rsi.wiener_info[tile_idx].vfilter[i] = wiener_info[tile_idx].vfilter[i];
        rsi.wiener_info[tile_idx].hfilter[i] = wiener_info[tile_idx].hfilter[i];
      }
    }
Aamir Anis's avatar
Aamir Anis committed
996
  }
997
  err = try_restoration_frame(src, cpi, &rsi, 1, partial_frame, dst_frame);
998
  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
999

1000
  aom_free(rsi.wiener_info);
1001

Yaowu Xu's avatar
Yaowu Xu committed
1002
  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
1003
  return cost_wiener;
1004 1005
}

1006 1007
static double search_norestore(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                               int filter_level, int partial_frame,
1008 1009
                               RestorationInfo *info, double *best_tile_cost,
                               YV12_BUFFER_CONFIG *dst_frame) {
1010 1011
  double err, cost_norestore;
  int bits;
1012
  MACROBLOCK *x = &cpi->td.mb;
1013 1014
  AV1_COMMON *const cm = &cpi->common;
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
1015 1016 1017
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
1018
  (void)info;
1019
  (void)dst_frame;
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030

  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
1031
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
1032
                               h_end - h_start, v_start, v_end - v_start, 1);
1033
    best_tile_cost[tile_idx] =
1034 1035
        RDCOST_DBL(x->rdmult, x->rddiv,
                   (cpi->switchable_restore_cost[RESTORE_NONE] >> 4), err);
1036 1037
  }
  // RD cost associated with no restoration
1038
  err = sse_restoration_tile(src, cm->frame_to_show, cm, 0, cm->width, 0,