pickrst.c 51 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

19
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
26
#include "av1/common/restoration.h"
27

28
#include "av1/encoder/av1_quantize.h"
29 30 31
#include "av1/encoder/encoder.h"
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
32

33
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
34 35
                                      AV1_COMP *cpi, int partial_frame,
                                      RestorationInfo *info,
36
                                      RestorationType *rest_level,
37 38
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
39

40
// Rate (in bits) charged for signalling the frame-level restoration type,
// indexed by RestorationType. The table differs depending on whether the
// domain-transform filter is compiled in.
const int frame_level_restore_bits[RESTORE_TYPES] = {
#if USE_DOMAINTXFMRF
  2, 2, 3, 3, 2
#else
  2, 2, 2, 2
#endif  // USE_DOMAINTXFMRF
};
45 46

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
47 48
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
49
                                    int width, int v_start, int height,
50 51
                                    int components_pattern) {
  int64_t filt_err = 0;
52 53 54 55
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
56 57
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
58 59 60 61 62
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
63 64
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
65 66
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
67 68
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
69 70
    }
    return filt_err;
71 72
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
73 74 75 76
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
77
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
78 79
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
80
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
81
  }
82 83 84
  return filt_err;
}

85 86
// Whole-frame sum of squared errors between src and dst, summed over every
// plane whose bit (AOM_PLANE_Y / AOM_PLANE_U / AOM_PLANE_V) is set in
// components_pattern. Uses the high bit-depth SSE helpers when the frame is
// stored in high bit depth.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  int64_t filt_err = 0;
  (void)cm;  // Only read when high bit-depth support is compiled in.
#if CONFIG_AOM_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err += aom_highbd_get_y_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
      filt_err += aom_highbd_get_u_sse(src, dst);
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
      filt_err += aom_highbd_get_v_sse(src, dst);
    }
    return filt_err;
  }
#endif  // CONFIG_AOM_HIGHBITDEPTH
  // Every plane accumulates, so the result does not depend on the order in
  // which planes are tested.
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
    filt_err += aom_get_u_sse(src, dst);
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
    filt_err += aom_get_v_sse(src, dst);
  }
  return filt_err;
}

118 119
// Applies loop restoration described by rsi to the whole frame (writing into
// dst_frame), then returns the SSE against src restricted to the (sub)tile
// selected by tile_idx / subtile_idx / subtile_bits. components_pattern == 1
// selects the luma plane; any other allowed value selects chroma dimensions.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx, int subtile_idx,
                                    int subtile_bits,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int is_luma = (components_pattern == 1);
  const int plane_w = is_luma ? src->y_crop_width : src->uv_crop_width;
  const int plane_h = is_luma ? src->y_crop_height : src->uv_crop_height;
  int tile_w, tile_h, tiles_x, tiles_y;
  int x0, x1, y0, y1;

  // Y and UV components cannot be mixed.
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Only the tile geometry is needed; the tile count itself is unused.
  (void)av1_get_rest_ntiles(plane_w, plane_h, &tile_w, &tile_h, &tiles_x,
                            &tiles_y);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, tiles_x,
                           tiles_y, tile_w, tile_h, plane_w, plane_h, 0, 0,
                           &x0, &x1, &y0, &y1);
  return sse_restoration_tile(src, dst_frame, cm, x0, x1 - x0, y0, y1 - y0,
                              components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
157
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
158
                                     int components_pattern, int partial_frame,
159
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
160
  AV1_COMMON *const cm = &cpi->common;
161
  int64_t filt_err;
162 163
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
164
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
165 166 167
  return filt_err;
}

168 169 170 171 172
static int64_t get_pixel_proj_error(uint8_t *src8, int width, int height,
                                    int src_stride, uint8_t *dat8,
                                    int dat_stride, int bit_depth,
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
173 174 175 176
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
        const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
208 209 210 211 212
    }
  }
  return err;
}

213 214 215 216
static void get_proj_subspace(uint8_t *src8, int width, int height,
                              int src_stride, uint8_t *dat8, int dat_stride,
                              int bit_depth, int32_t *flt1, int flt1_stride,
                              int32_t *flt2, int flt2_stride, int *xq) {
217 218 219 220 221 222 223 224 225
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

  xq[0] = -(1 << SGRPROJ_PRJ_BITS) / 4;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0];
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
  if (bit_depth == 8) {
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

// Converts projection weights xq into the signalled form xqd. xqd[0] is -xq[0]
// clamped to [SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0]. xqd[1] is
// (1 << SGRPROJ_PRJ_BITS) + xqd[0] - xq[1], clamped to
// [SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1].
void encode_xq(int *xq, int *xqd) {
  const int first = clamp(-xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
  const int second = clamp((1 << SGRPROJ_PRJ_BITS) + first - xq[1],
                           SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
  xqd[0] = first;
  xqd[1] = second;
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
                                          int dat_stride, uint8_t *src8,
                                          int src_stride, int bit_depth,
285 286
                                          int *eps, int *xqd, int32_t *rstbuf) {
  int32_t *flt1 = rstbuf;
287
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
288
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
289 290 291
  int i, j, ep, bestep = 0;
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
292

293 294 295 296 297 298
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
    if (bit_depth > 8) {
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
299 300
          flt1[i * width + j] = (int32_t)dat[i * dat_stride + j];
          flt2[i * width + j] = (int32_t)dat[i * dat_stride + j];
301 302 303 304 305 306 307 308
        }
      }
    } else {
      uint8_t *dat = dat8;
      for (i = 0; i < height; ++i) {
        for (j = 0; j < width; ++j) {
          const int k = i * width + j;
          const int l = i * dat_stride + j;
309 310
          flt1[k] = (int32_t)dat[l];
          flt2[k] = (int32_t)dat[l];
311 312 313 314 315 316 317
        }
      }
    }
    av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                               sgr_params[ep].r1, sgr_params[ep].e1, tmpbuf2);
    av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                               sgr_params[ep].r2, sgr_params[ep].e2, tmpbuf2);
318 319
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                      bit_depth, flt1, width, flt2, width, exq);
320
    encode_xq(exq, exqd);
321 322 323
    err =
        get_pixel_proj_error(src8, width, height, src_stride, dat8, dat_stride,
                             bit_depth, flt1, width, flt2, width, exqd);
324 325 326 327 328 329 330 331 332 333 334 335 336
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

// Searches self-guided projection (sgrproj) parameters for every luma tile.
// For each tile, the parameters found by search_selfguided_restoration are
// kept (in info->sgrproj_info and type[]) only if their RD cost beats leaving
// the tile unrestored. best_tile_cost[] receives the tile's sgrproj RD cost, or
// DBL_MAX when the tile is left unrestored. Finally the frame is filtered with
// the chosen per-tile decisions, and the frame-level RD cost is returned.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, RestorationInfo *info,
                             RestorationType *type, double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  SgrprojInfo *sgrproj_info = info->sgrproj_info;
  double err, cost_norestore, cost_sgrproj;
  int bits;
  MACROBLOCK *x = &cpi->td.mb;
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *dgd = cm->frame_to_show;
  // Working restoration settings used for the per-tile trials and for the
  // final whole-frame evaluation.
  RestorationInfo *rsi = &cpi->rst_search[0];
  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
  int h_start, h_end, v_start, v_end;
  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
                                         &tile_height, &nhtiles, &nvtiles);

  rsi->frame_restoration_type = RESTORE_SGRPROJ;

  // Start with every tile unrestored; each trial below enables only the tile
  // under test and disables it again afterwards.
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Compute best Sgrproj filters for each tile
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
                             tile_height, cm->width, cm->height, 0, 0, &h_start,
                             &h_end, &v_start, &v_end);
    // Distortion and RD cost when the tile is not restored.
    err = sse_restoration_tile(src, dgd, cm, h_start, h_end - h_start, v_start,
                               v_end - v_start, 1);
    bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
    cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    best_tile_cost[tile_idx] = DBL_MAX;
    search_selfguided_restoration(
        dgd->y_buffer + v_start * dgd->y_stride + h_start, h_end - h_start,
        v_end - v_start, dgd->y_stride,
        src->y_buffer + v_start * src->y_stride + h_start, src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &rsi->sgrproj_info[tile_idx].ep, rsi->sgrproj_info[tile_idx].xqd,
        cm->rst_internal.tmpbuf);
    // Distortion and RD cost with only this tile restored.
    rsi->restoration_type[tile_idx] = RESTORE_SGRPROJ;
    err = try_restoration_tile(src, cpi, rsi, 1, partial_frame, tile_idx, 0, 0,
                               dst_frame);
    bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
    cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
    if (cost_sgrproj >= cost_norestore) {
      type[tile_idx] = RESTORE_NONE;
    } else {
      type[tile_idx] = RESTORE_SGRPROJ;
      memcpy(&sgrproj_info[tile_idx], &rsi->sgrproj_info[tile_idx],
             sizeof(sgrproj_info[tile_idx]));
      bits = SGRPROJ_BITS << AV1_PROB_COST_SHIFT;
      best_tile_cost[tile_idx] = RDCOST_DBL(
          x->rdmult, x->rddiv,
          (bits + cpi->switchable_restore_cost[RESTORE_SGRPROJ]) >> 4, err);
    }
    rsi->restoration_type[tile_idx] = RESTORE_NONE;
  }
  // Cost for Sgrproj filtering: frame-level type bits, one on/off flag per
  // tile, plus the parameter bits of every restored tile.
  bits = frame_level_restore_bits[rsi->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
    bits +=
        av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, type[tile_idx] != RESTORE_NONE);
    memcpy(&rsi->sgrproj_info[tile_idx], &sgrproj_info[tile_idx],
           sizeof(sgrproj_info[tile_idx]));
    if (type[tile_idx] == RESTORE_SGRPROJ) {
      bits += (SGRPROJ_BITS << AV1_PROB_COST_SHIFT);
    }
    rsi->restoration_type[tile_idx] = type[tile_idx];
  }
  err = try_restoration_frame(src, cpi, rsi, 1, partial_frame, dst_frame);
  cost_sgrproj = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);

  return cost_sgrproj;
}

417
#if USE_DOMAINTXFMRF
418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
// Sum of squared differences between two 8-bit width x height pixel blocks,
// each addressed with its own row stride.
static int64_t compute_sse(uint8_t *dgd, int width, int height, int dgd_stride,
                           uint8_t *src, int src_stride) {
  int64_t total = 0;
  int row;
  for (row = 0; row < height; ++row) {
    const uint8_t *a = dgd + row * dgd_stride;
    const uint8_t *b = src + row * src_stride;
    int col;
    for (col = 0; col < width; ++col) {
      const int delta = (int)a[col] - (int)b[col];
      total += delta * delta;
    }
  }
  return total;
}

#if CONFIG_AOM_HIGHBITDEPTH
// Sum of squared differences between two 16-bit-per-sample width x height
// pixel blocks, each addressed with its own row stride. The difference is
// squared in 64 bits: for full 16-bit samples (diff up to 65535) the square
// exceeds INT_MAX, which would be signed-overflow UB in `int`.
static int64_t compute_sse_highbd(uint16_t *dgd, int width, int height,
                                  int dgd_stride, uint16_t *src,
                                  int src_stride) {
  int64_t sse = 0;
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int64_t diff =
          (int64_t)dgd[i * dgd_stride + j] - (int64_t)src[i * src_stride + j];
      sse += diff * diff;
    }
  }
  return sse;
}
#endif  // CONFIG_AOM_HIGHBITDEPTH

static void search_domaintxfmrf_restoration(uint8_t *dgd8, int width,
                                            int height, int dgd_stride,
                                            uint8_t *src8, int src_stride,
452
                                            int bit_depth, int *sigma_r,
453
                                            uint8_t *fltbuf, int32_t *tmpbuf) {
454 455 456 457 458
  const int first_p_step = 8;
  const int second_p_range = first_p_step >> 1;
  const int second_p_step = 2;
  const int third_p_range = second_p_step >> 1;
  const int third_p_step = 1;
459
  int p, best_p0, best_p = -1;
460 461
  int64_t best_sse = INT64_MAX, sse;
  if (bit_depth == 8) {
462
    uint8_t *flt = fltbuf;
463 464 465 466
    uint8_t *dgd = dgd8;
    uint8_t *src = src8;
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
467
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
468
                                   width, tmpbuf);
469
      sse = compute_sse(flt, width, height, width, src, src_stride);
470 471 472 473 474 475 476 477 478 479
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
480
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
481
                                   width, tmpbuf);
482
      sse = compute_sse(flt, width, height, width, src, src_stride);
483 484 485 486 487 488 489 490 491 492
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
493
      av1_domaintxfmrf_restoration(dgd, width, height, dgd_stride, p, flt,
494
                                   width, tmpbuf);
495
      sse = compute_sse(flt, width, height, width, src, src_stride);
496 497 498 499 500 501 502
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
  } else {
#if CONFIG_AOM_HIGHBITDEPTH
503
    uint16_t *flt = (uint16_t *)fltbuf;
504 505 506 507
    uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
    uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    // First phase
    for (p = first_p_step / 2; p < DOMAINTXFMRF_PARAMS; p += first_p_step) {
508
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
509 510
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
511 512 513 514 515 516 517 518 519 520
      if (sse < best_sse || best_p == -1) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Second Phase
    best_p0 = best_p;
    for (p = best_p0 - second_p_range; p <= best_p0 + second_p_range;
         p += second_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
521
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
522 523
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
524 525 526 527 528 529 530 531 532 533
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
    // Third Phase
    best_p0 = best_p;
    for (p = best_p0 - third_p_range; p <= best_p0 + third_p_range;
         p += third_p_step) {
      if (p < 0 || p == best_p || p >= DOMAINTXFMRF_PARAMS) continue;
534
      av1_domaintxfmrf_restoration_highbd(dgd, width, height, dgd_stride, p,
535 536
                                          bit_depth, flt, width, tmpbuf);
      sse = compute_sse_highbd(flt, width, height, width, src, src_stride);
537 538 539 540 541 542 543 544 545 546 547 548 549
      if (sse < best_sse) {
        best_p = p;
        best_sse = sse;
      }
    }
#else
    assert(0);
#endif  // CONFIG_AOM_HIGHBITDEPTH
  }
  *sigma_r = best_p;
}

// Chooses, for every luma restoration tile, whether to apply the
// domain-transform recursive filter. The parameter comes from
// search_domaintxfmrf_restoration, and a tile is restored only when that
// lowers its RD cost versus leaving it alone. Chosen parameters go to
// info->domaintxfmrf_info and decisions to type[]. best_tile_cost[] holds the
// restored tile's RD cost, or DBL_MAX when the tile is left unrestored. Returns
// the RD cost of filtering the frame with the chosen configuration.
static double search_domaintxfmrf(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  int partial_frame, RestorationInfo *info,
                                  RestorationType *type, double *best_tile_cost,
                                  YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const mb = &cpi->td.mb;
  const YV12_BUFFER_CONFIG *const degraded = cm->frame_to_show;
  DomaintxfmrfInfo *const chosen = info->domaintxfmrf_info;
  RestorationInfo *const trial = &cpi->rst_search[0];
  const int param_bits = DOMAINTXFMRF_PARAMS_BITS << AV1_PROB_COST_SHIFT;
  int tw, th, tiles_x, tiles_y;
  const int ntiles =
      av1_get_rest_ntiles(cm->width, cm->height, &tw, &th, &tiles_x, &tiles_y);
  int64_t sse;
  int rate;
  int t;

  trial->frame_restoration_type = RESTORE_DOMAINTXFMRF;
  for (t = 0; t < ntiles; ++t) trial->restoration_type[t] = RESTORE_NONE;

  for (t = 0; t < ntiles; ++t) {
    int x0, x1, y0, y1;
    double cost_off, cost_on;

    av1_get_rest_tile_limits(t, 0, 0, tiles_x, tiles_y, tw, th, cm->width,
                             cm->height, 0, 0, &x0, &x1, &y0, &y1);

    // Baseline: the tile left unfiltered.
    sse = sse_restoration_tile(src, cm->frame_to_show, cm, x0, x1 - x0, y0,
                               y1 - y0, 1);
    rate = av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 0);
    cost_off = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);
    best_tile_cost[t] = DBL_MAX;

    search_domaintxfmrf_restoration(
        degraded->y_buffer + y0 * degraded->y_stride + x0, x1 - x0, y1 - y0,
        degraded->y_stride, src->y_buffer + y0 * src->y_stride + x0,
        src->y_stride,
#if CONFIG_AOM_HIGHBITDEPTH
        cm->bit_depth,
#else
        8,
#endif  // CONFIG_AOM_HIGHBITDEPTH
        &trial->domaintxfmrf_info[t].sigma_r, cpi->extra_rstbuf,
        cm->rst_internal.tmpbuf);

    // Candidate: only this tile filtered, with the parameter just found.
    trial->restoration_type[t] = RESTORE_DOMAINTXFMRF;
    sse = try_restoration_tile(src, cpi, trial, 1, partial_frame, t, 0, 0,
                               dst_frame);
    rate = param_bits + av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB, 1);
    cost_on = RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);

    if (cost_on >= cost_off) {
      type[t] = RESTORE_NONE;
    } else {
      type[t] = RESTORE_DOMAINTXFMRF;
      memcpy(&chosen[t], &trial->domaintxfmrf_info[t], sizeof(chosen[t]));
      best_tile_cost[t] = RDCOST_DBL(
          mb->rdmult, mb->rddiv,
          (param_bits + cpi->switchable_restore_cost[RESTORE_DOMAINTXFMRF]) >>
              4,
          sse);
    }
    trial->restoration_type[t] = RESTORE_NONE;
  }

  // Frame-level rate: type bits, one on/off flag per tile, and parameter bits
  // for each restored tile.
  rate = frame_level_restore_bits[trial->frame_restoration_type]
         << AV1_PROB_COST_SHIFT;
  for (t = 0; t < ntiles; ++t) {
    rate += av1_cost_bit(RESTORE_NONE_DOMAINTXFMRF_PROB,
                         type[t] != RESTORE_NONE);
    memcpy(&trial->domaintxfmrf_info[t], &chosen[t], sizeof(chosen[t]));
    if (type[t] == RESTORE_DOMAINTXFMRF) rate += param_bits;
    trial->restoration_type[t] = type[t];
  }
  sse = try_restoration_frame(src, cpi, trial, 1, partial_frame, dst_frame);
  return RDCOST_DBL(mb->rdmult, mb->rddiv, (rate >> 4), sse);
}
633
#endif  // USE_DOMAINTXFMRF
634

635 636
// Mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src (row stride `stride`). Returns 0 for an empty
// region instead of dividing by zero and producing NaN.
static double find_average(uint8_t *src, int h_start, int h_end, int v_start,
                           int v_end, int stride) {
  uint64_t sum = 0;
  int64_t count;
  int i, j;
  if (h_end <= h_start || v_end <= v_start) return 0;
  // 64-bit count so very large regions cannot overflow `int`.
  count = (int64_t)(v_end - v_start) * (h_end - h_start);
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  return (double)sum / (double)count;
}

646 647 648
// Accumulates Wiener-filter statistics over rows [v_start, v_end) and columns
// [h_start, h_end). Both dgd and src samples have the mean of dgd over the
// region subtracted.
//   M[a]                   += win[a] * src_pixel   (WIENER_WIN2 entries)
//   H[a * WIENER_WIN2 + b] += win[a] * win[b]      (WIENER_WIN2^2, row-major)
// win is the WIENER_WIN x WIENER_WIN neighbourhood of dgd around each pixel,
// gathered column by column. dgd is read up to WIENER_HALFWIN samples outside
// the region on every side.
static void compute_stats(uint8_t *dgd, uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
  double win[WIENER_WIN2];
  const double mean =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
  int row, col, a, b;

  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);

  for (row = v_start; row < v_end; row++) {
    for (col = h_start; col < h_end; col++) {
      const double target = (double)src[row * src_stride + col] - mean;
      double *w = win;
      int dx, dy;
      // Column offset in the outer loop, row offset in the inner loop.
      for (dx = -WIENER_HALFWIN; dx <= WIENER_HALFWIN; dx++) {
        for (dy = -WIENER_HALFWIN; dy <= WIENER_HALFWIN; dy++) {
          *w++ = (double)dgd[(row + dy) * dgd_stride + (col + dx)] - mean;
        }
      }
      for (a = 0; a < WIENER_WIN2; ++a) {
        double *h_row = H + a * WIENER_WIN2;
        M[a] += win[a] * target;
        // H is symmetric: accumulate the diagonal and upper triangle only;
        // the lower triangle is mirrored after the pixel loops.
        for (b = a; b < WIENER_WIN2; ++b) h_row[b] += win[a] * win[b];
      }
    }
  }
  for (a = 0; a < WIENER_WIN2; ++a) {
    for (b = a + 1; b < WIENER_WIN2; ++b) {
      H[b * WIENER_WIN2 + a] = H[a * WIENER_WIN2 + b];
    }
  }
}

Yaowu Xu's avatar
Yaowu Xu committed
685
#if CONFIG_AOM_HIGHBITDEPTH
686 687
// Mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src (row stride `stride`). Returns 0 for an empty
// region instead of dividing by zero and producing NaN.
static double find_average_highbd(uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  int64_t count;
  int i, j;
  if (h_end <= h_start || v_end <= v_start) return 0;
  // 64-bit count so very large regions cannot overflow `int`.
  count = (int64_t)(v_end - v_start) * (h_end - h_start);
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  return (double)sum / (double)count;
}

697 698 699 700
static void compute_stats_highbd(uint8_t *dgd8, uint8_t *src8, int h_start,
                                 int h_end, int v_start, int v_end,
                                 int dgd_stride, int src_stride, double *M,
                                 double *H) {
701
  int i, j, k, l;
702
  double Y[WIENER_WIN2];
703 704
  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
705 706
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
707

708 709
  memset(M, 0, sizeof(*M) * WIENER_WIN2);
  memset(H, 0, sizeof(*H) * WIENER_WIN2 * WIENER_WIN2);
710 711
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
712 713
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
714 715
      for (k = -WIENER_HALFWIN; k <= WIENER_HALFWIN; k++) {
        for (l = -WIENER_HALFWIN; l <= WIENER_HALFWIN; l++) {
716 717 718 719
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
720
      for (k = 0; k < WIENER_WIN2; ++k) {
721
        M[k] += Y[k] * X;
722 723
        H[k * WIENER_WIN2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < WIENER_WIN2; ++l) {
724 725 726 727
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
          H[k * WIENER_WIN2 + l] += Y[k] * Y[l];
728 729 730 731
        }
      }
    }
  }
732 733 734 735 736
  for (k = 0; k < WIENER_WIN2; ++k) {
    for (l = k + 1; l < WIENER_WIN2; ++l) {
      H[l * WIENER_WIN2 + k] = H[k * WIENER_WIN2 + l];
    }
  }
737
}
Yaowu Xu's avatar
Yaowu Xu committed
738
#endif  // CONFIG_AOM_HIGHBITDEPTH
739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760

// Solves A x = b for x by Gaussian elimination with partial pivoting.
//   n      - system size; A is n x n stored row-major with row stride `stride`.
//   A, b   - overwritten: A becomes upper triangular and b is transformed to
//            match (rows may be swapped).
//   x      - receives the solution (n entries).
// Returns 1 on success, or 0 when a pivot smaller than 1e-10 in magnitude is
// met (the system is treated as singular and x is left incomplete).
static int linsolve(int n, double *A, int stride, double *b, double *x) {
  int i, j, k;
  double c;
  // Forward elimination
  for (k = 0; k < n; k++) {
    // Partial pivoting: bring the largest-magnitude entry of column k (among
    // rows k..n-1) onto the diagonal before dividing by it.
    int pivot = k;
    for (i = k + 1; i < n; i++) {
      if (fabs(A[i * stride + k]) > fabs(A[pivot * stride + k])) pivot = i;
    }
    if (fabs(A[pivot * stride + k]) < 1e-10) return 0;
    if (pivot != k) {
      for (j = 0; j < n; j++) {
        c = A[k * stride + j];
        A[k * stride + j] = A[pivot * stride + j];
        A[pivot * stride + j] = c;
      }
      c = b[k];
      b[k] = b[pivot];
      b[pivot] = c;
    }
    for (i = k + 1; i < n; i++) {
      c = A[i * stride + k] / A[k * stride + k];
      for (j = k; j < n; j++) A[i * stride + j] -= c * A[k * stride + j];
      b[i] -= c * b[k];
    }
  }
  // Backward substitution (every diagonal entry was checked above).
  for (i = n - 1; i >= 0; i--) {
    c = 0;
    for (j = i + 1; j <= n - 1; j++) c += A[i * stride + j] * x[j];
    x[i] = (b[i] - c) / A[i * stride + i];
  }
  return 1;
}

static INLINE int wrap_index(int i) {
776
  return (i >= WIENER_HALFWIN1 ? WIENER_WIN - 1 - i : i);
777 778 779 780 781
}

// Fix vector b, update vector a
static void update_a_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  int i, j;
782 783
  double S[WIENER_WIN];
  double A[WIENER_WIN], B[WIENER_WIN2];
Aamir Anis's avatar
Aamir Anis committed
784
  int w, w2;
785 786
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
787 788
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; ++j) {
789 790 791 792
      const int jj = wrap_index(j);
      A[jj] += Mc[i][j] * b[i];
    }
  }
793 794
  for (i = 0; i < WIENER_WIN; i++) {
    for (j = 0; j < WIENER_WIN; j++) {
795
      int k, l;
796 797
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l) {
798 799
          const int kk = wrap_index(k);
          const int ll = wrap_index(l);
800 801
          B[ll * WIENER_HALFWIN1 + kk] +=
              Hc[j * WIENER_WIN + i][k * WIENER_WIN2 + l] * b[i] * b[j];
802 803 804
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
805
  // Normalization enforcement in the system of equations itself
806
  w = WIENER_WIN;
Aamir Anis's avatar
Aamir Anis committed
807 808
  w2 = (w >> 1) + 1;
  for (i = 0; i < w2 - 1; ++i)
809 810
    A[i] -=
        A[w2 - 1] * 2 + B[i * w2 + w2 - 1] - 2 * B[(w2 - 1) * w2 + (w2 - 1)];
Aamir Anis's avatar
Aamir Anis committed
811 812 813 814 815 816 817 818 819
  for (i = 0; i < w2 - 1; ++i)
    for (j = 0; j < w2 - 1; ++j)
      B[i * w2 + j] -= 2 * (B[i * w2 + (w2 - 1)] + B[(w2 - 1) * w2 + j] -
                            2 * B[(w2 - 1) * w2 + (w2 - 1)]);
  if (linsolve(w2 - 1, B, w2, A, S)) {
    S[w2 - 1] = 1.0;
    for (i = w2; i < w; ++i) {
      S[i] = S[w - 1 - i];
      S[w2 - 1] -= 2 * S[i];
820
    }
Aamir Anis's avatar
Aamir Anis committed
821
    memcpy(a, S, w * sizeof(*a));
822 823 824 825 826 827
  }
}

// One alternating least-squares step: keep filter a fixed and re-estimate
// filter b.
//
// Mirrors update_a_sep_sym with the roles of a and b exchanged. A and B are
// accumulated in the folded (half-length) tap space via wrap_index. The
// centre tap is eliminated with centre = 1 - 2 * (sum of the other folded
// taps), so the full filter's taps sum to 1. b is overwritten only when
// linsolve succeeds.
static void update_b_sep_sym(double **Mc, double **Hc, double *a, double *b) {
  double A[WIENER_WIN], B[WIENER_WIN2];
  double S[WIENER_WIN];
  const int w = WIENER_WIN;
  const int w2 = WIENER_HALFWIN1;
  const int ctr = w2 - 1;  // index of the centre tap in folded space
  int p, q;

  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));

  for (p = 0; p < WIENER_WIN; p++) {
    double *const acc = &A[wrap_index(p)];
    for (q = 0; q < WIENER_WIN; q++) *acc += Mc[p][q] * a[q];
  }

  for (p = 0; p < WIENER_WIN; p++) {
    for (q = 0; q < WIENER_WIN; q++) {
      const double *const hc = Hc[p * WIENER_WIN + q];
      double *const acc = &B[wrap_index(q) * WIENER_HALFWIN1 + wrap_index(p)];
      int k, l;
      for (k = 0; k < WIENER_WIN; ++k)
        for (l = 0; l < WIENER_WIN; ++l)
          *acc += hc[k * WIENER_WIN2 + l] * a[k] * a[l];
    }
  }

  // Substitute the centre-tap constraint into the system itself.
  for (p = 0; p < ctr; ++p)
    A[p] -= A[ctr] * 2 + B[p * w2 + ctr] - 2 * B[ctr * w2 + ctr];
  for (p = 0; p < ctr; ++p)
    for (q = 0; q < ctr; ++q)
      B[p * w2 + q] -=
          2 * (B[p * w2 + ctr] + B[ctr * w2 + q] - 2 * B[ctr * w2 + ctr]);

  if (!linsolve(ctr, B, w2, A, S)) return;

  // Rebuild the full symmetric filter from the solved half.
  S[ctr] = 1.0;
  for (p = w2; p < w; ++p) {
    S[p] = S[w - 1 - p];
    S[ctr] -= 2 * S[p];
  }
  memcpy(b, S, w * sizeof(*b));
}

869 870
// Approximates the 2-D Wiener statistics (M, H) with a separable pair of
// symmetric WIENER_WIN-tap filters a and b.
//
// Both filters are seeded with the same fixed initial filter. They are then
// refined by nine rounds of alternating updates: a with b fixed, then b with
// a fixed. Always returns 1.
static int wiener_decompose_sep_sym(double *M, double *H, double *a,
                                    double *b) {
  static const double init_filt[WIENER_WIN] = {
    0.035623, -0.127154, 0.211436, 0.760190, 0.211436, -0.127154, 0.035623,
  };
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
  int r, c, pass;

  // Build row-pointer views of the flat statistics buffers:
  //   Mc[r]                  -> M + r * WIENER_WIN
  //   Hc[r * WIENER_WIN + c] -> H + r * WIENER_WIN * WIENER_WIN2
  //                               + c * WIENER_WIN
  for (r = 0; r < WIENER_WIN; r++) {
    double *const h_block = H + r * WIENER_WIN * WIENER_WIN2;
    Mc[r] = M + r * WIENER_WIN;
    for (c = 0; c < WIENER_WIN; c++)
      Hc[r * WIENER_WIN + c] = h_block + c * WIENER_WIN;
  }

  memcpy(a, init_filt, sizeof(*a) * WIENER_WIN);
  memcpy(b, init_filt, sizeof(*b) * WIENER_WIN);

  for (pass = 0; pass < 9; ++pass) {
    update_a_sep_sym(Mc, Hc, a, b);
    update_b_sep_sym(Mc, Hc, a, b);
  }
  return 1;
}

896
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
897 898
// against identity filters; Final score is defined as the difference between
// the function values
899 900
static double compute_score(double *M, double *H, InterpKernel vfilt,
                            InterpKernel hfilt) {
901
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
902 903 904 905
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
906 907 908 909 910 911 912
  double a[WIENER_WIN], b[WIENER_WIN];
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
913
  }
914 915
  for (k = 0; k < WIENER_WIN; ++k) {
    for (l = 0; l < WIENER_WIN; ++l) ab[k * WIENER_WIN + l] = a[l] * b[k];
Aamir Anis's avatar
Aamir Anis committed
916
  }
917
  for (k = 0; k < WIENER_WIN2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
918
    P += ab[k] * M[k];
919 920
    for (l = 0; l < WIENER_WIN2; ++l)
      Q += ab[k] * H[k