pickrst.c 55.4 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29 30 31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33 34 35
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *cpi, int filter_level,
                                      int partial_frame, RestorationInfo *info,
36
                                      RestorationType *rest_level,
37 38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
// Bit cost of signalling the frame-level restoration type, indexed by the
// RestorationType value stored in frame_restoration_type. The extra entry
// exists only when the domain-transform filter is compiled in.
#if USE_DOMAINTXFMRF
const int frame_level_restore_bits[RESTORE_TYPES] = {
  2, 2, 3, 3, 2,
};
#else
const int frame_level_restore_bits[RESTORE_TYPES] = {
  2, 2, 2, 2,
};
#endif  // USE_DOMAINTXFMRF
45 46

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
47 48
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
49
                                    int width, int v_start, int height,
50 51
                                    int components_pattern) {
  int64_t filt_err = 0;
52 53 54 55
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
56 57
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
58 59 60 61 62
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
63 64
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
65 66
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
67 68
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
69 70
    }
    return filt_err;
71 72
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
73 74 75 76
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
77
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
78 79
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
80
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
81
  }
82 83 84
  return filt_err;
}

85 86
// Whole-frame sum of squared errors between src and dst, summed over each
// plane selected in components_pattern (bit AOM_PLANE_Y / _U / _V).
// cm is consulted only to choose the high-bitdepth SSE helpers.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t total = 0;

#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) total += aom_highbd_get_y_sse(src, dst);
    if (want_u) total += aom_highbd_get_u_sse(src, dst);
    if (want_v) total += aom_highbd_get_v_sse(src, dst);
    return total;
  }
#else
  (void)cm;
#endif  // CONFIG_AOM_HIGHBITDEPTH

  if (want_y) total += aom_get_y_sse(src, dst);
  if (want_u) total += aom_get_u_sse(src, dst);
  if (want_v) total += aom_get_v_sse(src, dst);
  return total;
}

118 119
// Applies the restoration described by rsi to the whole frame (output in
// dst_frame), then returns the squared error against src over just the
// (sub)tile selected by tile_idx / subtile_idx / subtile_bits.
// components_pattern == 1 uses luma dimensions; the other permitted values
// (2, 4, 6) use chroma dimensions.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int is_luma = (components_pattern == 1);
  const int plane_w = is_luma ? src->y_crop_width : src->uv_crop_width;
  const int plane_h = is_luma ? src->y_crop_height : src->uv_crop_height;
  int tile_w, tile_h, tiles_x, tiles_y;
  int x0, x1, y0, y1;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Called for the tile geometry it writes back; the tile count is unused.
  (void)av1_get_rest_ntiles(plane_w, plane_h, &tile_w, &tile_h, &tiles_x,
                            &tiles_y);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, tiles_x,
                           tiles_y, tile_w, tile_h, plane_w, plane_h, 0, 0,
                           &x0, &x1, &y0, &y1);
  return sse_restoration_tile(src, dst_frame, cm, x0, x1 - x0, y0, y1 - y0,
                              components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
157
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
158
                                     int components_pattern, int partial_frame,
159
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
160
  AV1_COMMON *const cm = &cpi->common;
161
  int64_t filt_err;
162 163
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
164
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
165 166 167
  return filt_err;
}

168 169 170 171 172
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
173 174 175 176
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
208 209 210 211 212
    }
  }
  return err;
}

213 214 215 216
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
217 218 219 220 221 222 223 224 225
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

// Converts projection coefficients xq into the clamped differential form xqd
// (the form decode_xq reads back):
//   xqd[0] = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0)
//   xqd[1] = clamp(2^SGRPROJ_PRJ_BITS + xqd[0] - xq[1],
//                  SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1)
// Note that xqd[1] is derived from the already-clamped xqd[0].
void encode_xq(int *xq, int *xqd) {
  const int d0 = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  const int d1 = clamp((1 << SGRPROJ_PRJ_BITS) + d0 - xq[1], SGRPROJ_PRJ_MIN1,
                       SGRPROJ_PRJ_MAX1);
  xqd[0] = d0;
  xqd[1] = d1;
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
285 286
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
287
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
288
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
289 290 291
  int i, j, ep, bestep = 0;
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
292

293 294 295 296 297 298
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
299 300
          flt1[i * width + j] = (int32_t)dat[i * dat_stride + j];
          flt2[i * width + j] = (int32_t)dat[i * dat_stride + j];
301 302 303 304 305 306 307 308
        }
      }
    } else {
      uint8_t *dat = dat8;
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          const int k = i * width + j;
          const int l = i * dat_stride + j;
309 310
          flt1[k] = (int32_t)dat[l];
          flt2[k] = (int32_t)dat[l];
311 312 313 314 315 316 317
        }
      }
    }
    av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
    av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                               sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
318 319
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
320
    encode_xq(exq, exqd);
321 322 323
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
324 325 326 327 328 329 330 331 332 333 334 335 336 337
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int filter_level, int partial_frame,
338 339
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
340
                             YV12_BUFFER_CONFIG *dst_frame) {
341 342 343 344 345 346
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
347
  RestorationInfo *rsi = &cpi->rst_search[0];
348 349
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
350
  // Allocate for the src buffer at high precision
351 352 353 354 355 356 357 358
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

359
  rsi->frame_restoration_type = RESTORE_SGRPROJ;
360

361 362 363
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
364 365 366 367 368
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
369
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
370
                               h_end - h_start, v_start, v_end - v_start, 1);
371 372 373 374 375 376 377 378 379 380 381 382 383
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
384
        &rsi->sgrproj_info[tile_idx].ep, rsi->sgrproj_info[tile_idx].xqd,
385
        cm->rst_internal.tmpbuf);
386
    rsi->restoration_type[tile_idx] = RESTORE_SGRPROJ;
387
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
388
                               dst_frame);
389 390 391 392
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
393
      type[tile_idx] = RESTORE_NONE;
394
    } else {
395
      type[tile_idx] = RESTORE_SGRPROJ;
396
      memcpy(&sgrproj_info[tile_idx], &rsi->sgrproj_info[tile_idx],
397 398 399 400 401 402
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
403
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
404 405
  }
  // Cost for Sgrproj filtering
406
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
407 408 409
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
410
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
411
    memcpy(&rsi->sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
412
           sizeof(sgrproj_info[tile_idx]));
413
    if (type[tile_idx] == RESTORE_SGRPROJ) {
414 415
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
416
    rsi->restoration_type[tile_idx] = type[tile_idx];
417
  }
418
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
419 420 421 422 423 424
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_sgrproj;
}

425
#if USE_DOMAINTXFMRF
426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
// Sum of squared differences between two 8-bit width x height blocks, each
// addressed with its own stride.
static int64_t compute_sse(uint8_t *dgd, int width, int height, int dgd_stride,
                           uint8_t *src, int src_stride) {
  int64_t total = 0;
  int row;
  for (row = 0; row < height; ++row) {
    const uint8_t *a = dgd + row * dgd_stride;
    const uint8_t *b = src + row * src_stride;
    int col;
    for (col = 0; col < width; ++col) {
      const int d = a[col] - b[col];
      total += d * d;
    }
  }
  return total;
}

#if CONFIG_AOM_HIGHBITDEPTH
// Sum of squared differences between two 16-bit-sample width x height blocks,
// each addressed with its own stride.
//
// The difference is widened to 64 bits before squaring. With full-range
// uint16_t inputs a difference can reach 65535, and 65535^2 does not fit in
// an int (signed overflow is undefined behavior).
static int64_t compute_sse_highbd(const uint16_t *dgd, int width, int height,
                                  int dgd_stride, const uint16_t *src,
                                  int src_stride) {
  int64_t sse = 0;
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int64_t diff =
          (int64_t)dgd[i * dgd_stride + j] - (int64_t)src[i * src_stride + j];
      sse += diff * diff;
    }
  }
  return sse;
}
#endif  // CONFIG_AOM_HIGHBITDEPTH

static void search_domaintxfmrf_restoration(uint8_t *dgd8, int width,
                                            int height, int dgd_stride,
                                            uint8_t *src8, int src_stride,
460
                                            int bit_depth, int *sigma_r,
461
                                            uint8_t *fltbuf, int32_t *tmpbuf) {
462 463 464 465 466
  const int first_p_step = 8;
  const int second_p_range = first_p_step >> 1;
  const int second_p_step = 2;
  const int third_p_range = second_p_step >> 1;
  const int third_p_step = 1;
467
  int p, best_p0, best_p = -1;
468 469
  int64_t best_sse = INT64_MAX, sse;
  if (bit_depth == 8) {
470
    uint8_t *flt = fltbuf;
471 472 473 474
    uint8_t *dgd = dgd8;
    uint8_t *src = src8;
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
475
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
476
                                   width, tmpbuf);
477
      sse = compute_sse(flt, width, height, width, src, src_stride);
478 479 480 481 482 483 484 485 486 487
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
488
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
489
                                   width, tmpbuf);
490
      sse = compute_sse(flt, width, height, width, src, src_stride);
491 492 493 494 495 496 497 498 499 500
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
501
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
502
                                   width, tmpbuf);
503
      sse = compute_sse(flt, width, height, width, src, src_stride);
504 505 506 507 508 509 510
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
  } else {
#if CONFIG_AOM_HIGHBITDEPTH
511
    uint16_t *flt = (uint16_t *)fltbuf;
512 513 514 515
    uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
516
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
517 518
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
519 520 521 522 523 524 525 526 527 528
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
529
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
530 531
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
532 533 534 535 536 537 538 539 540 541
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
542
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
543 544
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
545 546 547 548 549 550 551 552 553 554 555 556 557 558
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
#else
    assert(0);
#endif  // CONFIG_AOM_HIGHBITDEPTH
  }
  *sigma_r = best_p;
}

static double search_domaintxfmrf(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  int filter_level, int partial_frame,
559 560
                                  RestorationInfo *info, RestorationType *type,
                                  double *best_tile_cost,
561
                                  YV12_BUFFER_CONFIG *dst_frame) {
562
  DomaintxfmrfInfo *domaintxfmrf_info = info->domaintxfmrf_info;
563 564
  double cost_norestore, cost_domaintxfmrf;
  int64_t err;
565 566 567 568
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
569
  RestorationInfo *rsi = &cpi->rst_search[0];
570 571 572 573 574 575 576 577 578 579
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);
  //  Make a copy of the unfiltered / processed recon buffer
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                        1, partial_frame);
  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);

580
  rsi->frame_restoration_type = RESTORE_DOMAINTXFMRF;
581

582 583 584
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
585 586 587 588 589
  // Compute best Domaintxfm filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
590
    err = sse_restoration_tile(src, cm->frame_to_show, cm, h_start,
591
                               h_end - h_start, v_start, v_end - v_start, 1);
592 593 594 595 596 597 598 599 600 601 602 603 604 605
    // #bits when a tile is not restored
    bits = av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;

    search_domaintxfmrf_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
606 607
        &rsi->domaintxfmrf_info[tile_idx].sigma_r, cpi->extra_rstbuf,
        cm->rst_internal.tmpbuf);
608

609
    rsi->restoration_type[tile_idx] = RESTORE_DOMAINTXFMRF;
610
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
611
                               dst_frame);
612 613 614 615
    bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 1);
    cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_domaintxfmrf >= cost_norestore) {
616
      type[tile_idx] = RESTORE_NONE;
617
    } else {
618
      type[tile_idx] = RESTORE_DOMAINTXFMRF;
619
      memcpy(&domaintxfmrf_info[tile_idx], &rsi->domaintxfmrf_info[tile_idx],
620 621 622 623 624 625 626
             sizeof(domaintxfmrf_info[tile_idx]));
      bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_DOMAINTXFMRF]) >> 4,
          err);
    }
627
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
628 629
  }
  // Cost for Domaintxfmrf filtering
630
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
631 632 633
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB,
634
                         type[tile_idx] != RESTORE_NONE);
635
    memcpy(&rsi->domaintxfmrf_info[tile_idx], &domaintxfmrf_info[tile_idx],
636
           sizeof(domaintxfmrf_info[tile_idx]));
637
    if (type[tile_idx] == RESTORE_DOMAINTXFMRF) {
638 639
      bits += (DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT);
    }
640
    rsi->restoration_type[tile_idx] = type[tile_idx];
641
  }
642
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
643 644 645 646 647
  cost_domaintxfmrf = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
  return cost_domaintxfmrf;
}
648
#endif  // USE_DOMAINTXFMRF
649

650 651
// Mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src (row stride `stride`).
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row;
  for (row = v_start; row < v_end; ++row) {
    const uint8_t *line = src + row * stride;
    int col;
    for (col = h_start; col < h_end; ++col) {
      total += line[col];
    }
  }
  return (double)total / count;
}

661 662 663
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
664
  int i, j, k, l;
665
  double Y[WIENER_WIN2];
666 667
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
668

669 670
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
671 672
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
673 674
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
675 676
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
677 678 679 680
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
681
      for (k = 0; k < WIENER_WIN2; ++k) {
682
        M[k] += Y[k] * X;
683 684
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
685 686 687 688
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
689 690 691 692
        }
      }
    }
  }
693 694 695 696 697
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
698 699
}

Yaowu Xu's avatar
Yaowu Xu committed
700
#if CONFIG_AOM_HIGHBITDEPTH
701 702
// Mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src (row stride `stride`).
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;
  int row;
  for (row = v_start; row < v_end; ++row) {
    const uint16_t *line = src + row * stride;
    int col;
    for (col = h_start; col < h_end; ++col) {
      total += line[col];
    }
  }
  return (double)total / count;
}

712 713 714 715
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
716
  int i, j, k, l;
717
  double Y[WIENER_WIN2];
718 719
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
720 721
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
722

723 724
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
725 726
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
727 728
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
729 730
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
731 732 733 734
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
735
      for (k = 0; k < WIENER_WIN2; ++k) {
736
        M[k] += Y[k] * X;
737 738
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
739 740 741 742
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
743 744 745 746
        }
      }
    }
  }
747 748 749 750 751
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
752
}
Yaowu Xu's avatar
Yaowu Xu committed
753
#endif  // CONFIG_AOM_HIGHBITDEPTH
754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775

// Solves A * x = b for x by Gaussian elimination with partial pivoting.
//
// A is an n x n matrix stored row-major with a row stride of `stride`
// doubles (stride >= n). Columns n .. stride-1 are never read or written.
// b and x are vectors of n doubles. A and b are overwritten (rows may be
// swapped and reduced).
// Returns 1 on success. Returns 0 if the best available pivot has magnitude
// below 1e-10, i.e. the system is treated as singular; x is not written in
// that case.
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  // Forward elimination
  for (k = 0; k < n - 1; k++) {
    // Partial pivoting: bring the row whose entry in column k has the
    // largest magnitude into row k. Pivoting must be redone for every column
    // (not just column 0), otherwise elimination can divide by a zero or
    // tiny pivot.
    int p = k;
    for (i = k + 1; i < n; i++) {
      if (fabs(A[i * stride + k]) > fabs(A[p * stride + k])) p = i;
    }
    if (fabs(A[p * stride + k]) < 1e-10) return 0;
    if (p != k) {
      double t;
      for (j = 0; j < n; j++) {
        t = A[k * stride + j];
        A[k * stride + j] = A[p * stride + j];
        A[p * stride + j] = t;
      }
      t = b[k];
      b[k] = b[p];
      b[p] = t;
    }
    for (i = k + 1; i < n; i++) {
      const double c = A[i * stride + k] / A[k * stride + k];
      for (j = k; j < n; j++) A[i * stride + j] -= c * A[k * stride + j];
      b[i] -= c * b[k];
    }
  }
  // Backward substitution
  for (i = n - 1; i >= 0; i--) {
    double c = 0;
    if (fabs(A[i * stride + i]) < 1e-10) return 0;
    for (j = i + 1; j < n; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
791
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
792 793 794 795 796
}

// Solves the folded, normalised system shared by update_a_sep_sym() and
// update_b_sep_sym().
//
// On entry:
//   - A holds WIENER_HALFWIN1 linear terms for the folded half-filter s;
//     s[WIENER_HALFWIN1 - 1] is the centre tap.
//   - B holds the matching WIENER_HALFWIN1 x WIENER_HALFWIN1 quadratic terms,
//     row stride WIENER_HALFWIN1.
//
// The centre tap is eliminated with the constraint that the full symmetric
// filter sums to 1:
//   s_c = 1 - 2 * sum_{i < c} s_i,   c = WIENER_HALFWIN1 - 1.
// The remaining c unknowns are then found with linsolve(). For a symmetric
// B, this is the stationarity condition of s^T B s - 2 A^T s under that
// constraint.
//
// On success, the full WIENER_WIN-tap symmetric filter is written to `filt`.
// If linsolve() reports a singular system, `filt` is left unchanged.
// A and B are used as scratch and are clobbered.
static void solve_sep_sym_normalized(double *A, double *B, double *filt) {
  const int w = WIENER_WIN;
  const int w2 = WIENER_HALFWIN1;
  const int c = w2 - 1;  // index of the centre tap in the folded system
  double S[WIENER_WIN];
  int i, j;

  // Substitute the normalisation constraint for the centre tap.
  for (i = 0; i < c; ++i)
    A[i] -= A[c] * 2 + B[i * w2 + c] - 2 * B[c * w2 + c];
  for (i = 0; i < c; ++i)
    for (j = 0; j < c; ++j)
      B[i * w2 + j] -=
          2 * (B[i * w2 + c] + B[c * w2 + j] - 2 * B[c * w2 + c]);
  if (!linsolve(c, B, w2, A, S)) return;

  // Mirror the solved half into the full filter and recover the centre tap
  // so that all taps sum to 1.
  S[c] = 1.0;
  for (i = w2; i < w; ++i) {
    S[i] = S[w - 1 - i];
    S[c] -= 2 * S[i];
  }
  memcpy(filt, S, w * sizeof(*filt));
}

// Fix vector b, update vector a.
// With b held fixed, the objective is quadratic in a. Its linear terms (from
// Mc) and quadratic terms (from Hc) are folded onto the symmetric half of a
// via wrap_index(), then solved by solve_sep_sym_normalized().
// `a` is overwritten only if that solve succeeds.
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j, k, l;
  double A[WIENER_WIN], B[WIENER_WIN2];

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
      for (k = 0; k < WIENER_WIN; ++k) {
        const int kk = wrap_index(k);
        for (l = 0; l < WIENER_WIN; ++l) {
          const int ll = wrap_index(l);
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
        }
      }
    }
  }
  solve_sep_sym_normalized(A, B, a);
}

// Fix vector a, update vector b.
// Mirror image of update_a_sep_sym(): with a held fixed, the terms that are
// quadratic in b are folded onto its symmetric half and solved by
// solve_sep_sym_normalized(). `b` is overwritten only if that solve succeeds.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j, k, l;
  double A[WIENER_WIN], B[WIENER_WIN2];

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) A[ii] += Mc[i][j] * a[j];
  }
  for (i = 0; i < WIENER_WIN; i++) {
    const int ii = wrap_index(i);
    for (j = 0; j < WIENER_WIN; j++) {
      const int jj = wrap_index(j);
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          B[jj * WIENER_HALFWIN1 + ii] +=
              Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l] * a[k] * a[l];
    }
  }
  solve_sep_sym_normalized(A, B, b);
}

884 885
// Approximates the 2-D Wiener filter described by the statistics (M, H) by a
// separable, symmetric pair of WIENER_WIN-tap filters a and b.
//
// Both filters start from the fixed `init_filt` below. They are then refined
// by 9 alternating rounds: update_a_sep_sym() (b held fixed) followed by
// update_b_sep_sym() (a held fixed).
//
// Indexing helpers:
//   Mc[i][j] == M[i * WIENER_WIN + j]
//   Hc[i * WIENER_WIN + j][k * WIENER_WIN2 + l]
//       == H[(i * WIENER_WIN + k) * WIENER_WIN2 + j * WIENER_WIN + l]
//
// Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Mc[WIENER_WIN];
  double *Hc[WIENER_WIN2];
  int r, c, pass;

  for (r = 0; r < WIENER_WIN; ++r) {
    Mc[r] = M + r * WIENER_WIN;
    for (c = 0; c < WIENER_WIN; ++c)
      Hc[r * WIENER_WIN + c] = H + (r * WIENER_WIN2 + c) * WIENER_WIN;
  }

  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  // Alternate between refining a with b fixed and b with a fixed.
  for (pass = 1; pass < 10; ++pass) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

911
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
912 913
// against identity filters; Final score is defined as the difference between
// the function values
914 915
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
916
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
917 918 919 920
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
921 922 923 924 925 926 927
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
928
  }
929 930
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed