/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20 21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26 27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
#include "av1/encoder/encoder.h"
32
#include "av1/encoder/mathutils.h"
33 34
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
35

36
// When set to RESTORE_WIENER or RESTORE_SGRPROJ, only that type is allowed.
// When set to RESTORE_TYPES, switching between types is allowed.
static const RestorationType force_restore_type = RESTORE_TYPES;
39 40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
44
                                      AV1_COMP *cpi, int partial_frame,
45
                                      int plane, RestorationInfo *info,
46
                                      RestorationType *rest_level,
47 48
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
49

50
// Bits charged once per frame for signalling the frame-level restoration
// type; indexed by the frame restoration type. Callers in this file shift the
// value by AV1_PROB_COST_SHIFT before adding it to a bit cost.
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51 52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53 54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56 57
                                    int components_pattern) {
  int64_t filt_err = 0;
58 59 60 61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64 65 66 67 68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69 70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71 72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73 74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75 76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79 80 81 82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84 85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88 89 90
  return filt_err;
}

91 92
// Returns the whole-frame sum of squared errors between `src` and `dst`,
// accumulated over the planes whose AOM_PLANE_* bit is set in
// `components_pattern`.
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t total = 0;

#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) total += aom_highbd_get_y_sse(src, dst);
    if (want_u) total += aom_highbd_get_u_sse(src, dst);
    if (want_v) total += aom_highbd_get_v_sse(src, dst);
    return total;
  }
#else
  (void)cm;
#endif  // CONFIG_HIGHBITDEPTH

  if (want_y) total += aom_get_y_sse(src, dst);
  if (want_u) total += aom_get_u_sse(src, dst);
  if (want_v) total += aom_get_v_sse(src, dst);
  return total;
}

124 125
// Runs loop restoration on cm->frame_to_show with the settings in `rsi`,
// writing into `dst_frame`, then returns the squared error against `src`
// measured only inside restoration tile `tile_idx` of the selected plane(s).
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int is_chroma = components_pattern > 1;
  int tile_width, tile_height, nhtiles, nvtiles;

  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Pattern 1 is luma only; every other accepted pattern is chroma.
  const int width =
      (components_pattern == 1) ? src->y_crop_width : src->uv_crop_width;
  const int height =
      (components_pattern == 1) ? src->y_crop_height : src->uv_crop_height;

  // Only the tile geometry outputs are needed here, not the tile count.
  const int ntiles = av1_get_rest_ntiles(
      width, height, cm->rst_info[is_chroma].restoration_tilesize, &tile_width,
      &tile_height, &nhtiles, &nvtiles);
  (void)ntiles;

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);

#if CONFIG_STRIPED_LOOP_RESTORATION
  const RestorationTileLimits limits = av1_get_rest_tile_limits(
      tile_idx, nhtiles, nvtiles, tile_width, tile_height, width, height,
      is_chroma ? cm->subsampling_y : 0);
#else
  const RestorationTileLimits limits = av1_get_rest_tile_limits(
      tile_idx, nhtiles, nvtiles, tile_width, tile_height, width, height);
#endif

  return sse_restoration_tile(src, dst_frame, cm, limits.h_start,
                              limits.h_end - limits.h_start, limits.v_start,
                              limits.v_end - limits.v_start,
                              components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
167
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
168
                                     int components_pattern, int partial_frame,
169
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
170
  AV1_COMMON *const cm = &cpi->common;
171
  int64_t filt_err;
172 173
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
174
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
175 176 177
  return filt_err;
}

178 179
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
180
                                    int dat_stride, int use_highbitdepth,
181 182
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
183 184 185 186
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
187
  if (!use_highbitdepth) {
188 189 190 191 192 193 194 195
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
196
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
212
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
213 214 215 216 217
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
218 219 220 221 222
    }
  }
  return err;
}

223 224
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
// Refines the two projection parameters in `xqd` in place by a coordinate
// search and returns the error of the final parameters.
//
// The step size starts at `start_step` and is halved down to 1. For each
// parameter the search first tries moving down by the step, then up; a move is
// kept when it does not increase the error. Moves are clamped to the
// SGRPROJ_PRJ_MIN*/MAX* range. With USE_SGRPROJ_REFINEMENT_SEARCH disabled the
// function only evaluates the error of the incoming parameters.
static int64_t finer_search_pixel_proj_error(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
  int64_t best = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int lower[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int upper[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int step = start_step; step >= 1; step >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int moved_down = 0;

      // Try decreasing xqd[p].
      while (xqd[p] - step >= lower[p]) {
        xqd[p] -= step;
        const int64_t trial = get_pixel_proj_error(
            src8, width, height, src_stride, dat8, dat_stride,
            use_highbitdepth, flt1, flt1_stride, flt2, flt2_stride, xqd);
        if (trial > best) {
          xqd[p] += step;  // worse: undo the move
          break;
        }
        best = trial;
        moved_down = 1;
        // At the highest step size continue moving in the same direction
        if (step != start_step) break;
      }

      // A successful downward move ends the per-parameter loop for this step
      // size (this leaves the `p` loop, not just the upward attempt).
      if (moved_down) break;

      // Try increasing xqd[p].
      while (xqd[p] + step <= upper[p]) {
        xqd[p] += step;
        const int64_t trial = get_pixel_proj_error(
            src8, width, height, src_stride, dat8, dat_stride,
            use_highbitdepth, flt1, flt1_stride, flt2, flt2_stride, xqd);
        if (trial > best) {
          xqd[p] -= step;  // worse: undo the move
          break;
        }
        best = trial;
        // At the highest step size continue moving in the same direction
        if (step != start_step) break;
      }
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return best;
}

279
static void get_proj_subspace(const uint8_t *src8, int width, int height,
280
                              int src_stride, uint8_t *dat8, int dat_stride,
281 282 283
                              int use_highbitdepth, int32_t *flt1,
                              int flt1_stride, int32_t *flt2, int flt2_stride,
                              int *xq) {
284 285 286 287 288 289 290
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

291 292
  aom_clear_system_state();

293 294 295
  // Default
  xq[0] = 0;
  xq[1] = 0;
296
  if (!use_highbitdepth) {
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
346
  xqd[0] = xq[0];
347
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
348
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
349 350 351 352
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
353
                                          int dat_stride, const uint8_t *src8,
354
                                          int src_stride, int use_highbitdepth,
355 356
                                          int bit_depth, int pu_width,
                                          int pu_height, int *eps, int *xqd,
357
                                          int32_t *rstbuf) {
358
  int32_t *flt1 = rstbuf;
359
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
360
  int ep, bestep = 0;
361 362
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
363 364
  int flt1_stride = ((width + 7) & ~7) + 8;
  int flt2_stride = ((width + 7) & ~7) + 8;
365 366 367 368
  assert(pu_width == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_width == RESTORATION_PROC_UNIT_SIZE);
  assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_height == RESTORATION_PROC_UNIT_SIZE);
369 370 371
#if !CONFIG_HIGHBITDEPTH
  (void)bit_depth;
#endif
372

373 374
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
375
#if CONFIG_HIGHBITDEPTH
376
    if (use_highbitdepth) {
377
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
378 379 380 381 382 383 384
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint16_t *dat_p = dat + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
385
#if USE_HIGHPASS_IN_SGRPROJ
386 387 388
          av1_highpass_filter_highbd(dat_p, w, h, dat_stride, flt1_p,
                                     flt1_stride, sgr_params[ep].corner,
                                     sgr_params[ep].edge);
389
#else
390
          av1_selfguided_restoration_highbd(
391
              dat_p, w, h, dat_stride, flt1_p, flt1_stride, bit_depth,
392
              sgr_params[ep].r1, sgr_params[ep].e1);
393
#endif  // USE_HIGHPASS_IN_SGRPROJ
394
          av1_selfguided_restoration_highbd(
395
              dat_p, w, h, dat_stride, flt2_p, flt2_stride, bit_depth,
396
              sgr_params[ep].r2, sgr_params[ep].e2);
397
        }
398
    } else {
399
#endif
400 401 402 403 404 405 406
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint8_t *dat_p = dat8 + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
407
#if USE_HIGHPASS_IN_SGRPROJ
408 409
          av1_highpass_filter(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
                              sgr_params[ep].corner, sgr_params[ep].edge);
410
#else
411
        av1_selfguided_restoration(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
412
                                   sgr_params[ep].r1, sgr_params[ep].e1);
413
#endif  // USE_HIGHPASS_IN_SGRPROJ
414 415
          av1_selfguided_restoration(dat_p, w, h, dat_stride, flt2_p,
                                     flt2_stride, sgr_params[ep].r2,
416
                                     sgr_params[ep].e2);
417
        }
418
#if CONFIG_HIGHBITDEPTH
419
    }
420
#endif
421
    aom_clear_system_state();
422
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
423 424
                      use_highbitdepth, flt1, flt1_stride, flt2, flt2_stride,
                      exq);
425
    aom_clear_system_state();
426
    encode_xq(exq, exqd);
427 428 429
    err = finer_search_pixel_proj_error(
        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
        flt1, flt1_stride, flt2, flt2_stride, 2, exqd);
430 431 432 433 434 435 436 437 438 439 440 441
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

442 443 444 445 446 447 448 449 450 451 452 453 454 455
// Returns the number of bits needed to code `sgrproj_info`: a fixed
// SGRPROJ_PARAMS_BITS plus the cost of each of the two xqd values coded
// relative to the corresponding value in `ref_sgrproj_info`.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int xqd0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int xqd1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return SGRPROJ_PARAMS_BITS + xqd0_bits + xqd1_bits;
}

456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
// State shared by the per-plane restoration search and its per-rtile
// visitors. Filled in by init_rest_search_ctxt().
struct rest_search_ctxt {
  // Frame the restored output is compared against.
  const YV12_BUFFER_CONFIG *src;
  // Encoder instance.
  AV1_COMP *cpi;
  // Buffer of the searched plane in cm->frame_to_show.
  uint8_t *dgd_buffer;
  // Buffer of the searched plane in `src`.
  const uint8_t *src_buffer;
  // Strides of dgd_buffer and src_buffer respectively.
  int dgd_stride;
  int src_stride;
  // Forwarded unchanged to try_restoration_tile().
  int partial_frame;
  // Receives the chosen per-rtile filter parameters.
  RestorationInfo *info;
  // Receives the chosen restoration type, indexed by rtile.
  RestorationType *type;
  // Per-rtile value written by the search, indexed by rtile.
  double *best_tile_cost;
  // Plane being searched (AOM_PLANE_Y / AOM_PLANE_U / AOM_PLANE_V).
  int plane;
  // Crop dimensions of the searched plane.
  int plane_width;
  int plane_height;
  // Restoration-tile grid dimensions, as reported by av1_get_rest_ntiles().
  int nrtiles_x;
  int nrtiles_y;
  // Frame buffer handed to the restoration filter as its destination.
  YV12_BUFFER_CONFIG *dst_frame;
};

// Populates `ctxt` for searching `plane`: stores the caller's arguments,
// selects the plane's crop size, buffers and strides from `src` and from
// cm->frame_to_show, and records the restoration-tile grid dimensions.
// Returns the number of restoration tiles for this plane.
static INLINE int init_rest_search_ctxt(
    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
    RestorationInfo *info, RestorationType *type, double *best_tile_cost,
    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
  AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;

  // Arguments carried through verbatim.
  ctxt->src = src;
  ctxt->cpi = cpi;
  ctxt->partial_frame = partial_frame;
  ctxt->info = info;
  ctxt->type = type;
  ctxt->best_tile_cost = best_tile_cost;
  ctxt->plane = plane;
  ctxt->dst_frame = dst_frame;

  if (plane == AOM_PLANE_Y) {
    ctxt->plane_width = src->y_crop_width;
    ctxt->plane_height = src->y_crop_height;
    ctxt->src_buffer = src->y_buffer;
    ctxt->src_stride = src->y_stride;
    ctxt->dgd_buffer = dgd->y_buffer;
    ctxt->dgd_stride = dgd->y_stride;
    // The degraded frame must have the same luma crop size as the source.
    assert(ctxt->plane_width == dgd->y_crop_width);
    assert(ctxt->plane_height == dgd->y_crop_height);
    assert(ctxt->plane_width == src->y_crop_width);
    assert(ctxt->plane_height == src->y_crop_height);
  } else {
    const int is_u = (plane == AOM_PLANE_U);
    ctxt->plane_width = src->uv_crop_width;
    ctxt->plane_height = src->uv_crop_height;
    ctxt->src_stride = src->uv_stride;
    ctxt->dgd_stride = dgd->uv_stride;
    ctxt->src_buffer = is_u ? src->u_buffer : src->v_buffer;
    ctxt->dgd_buffer = is_u ? dgd->u_buffer : dgd->v_buffer;
    // The degraded frame must have the same chroma crop size as the source.
    assert(ctxt->plane_width == dgd->uv_crop_width);
    assert(ctxt->plane_height == dgd->uv_crop_height);
  }

  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
                             cm->rst_info[plane].restoration_tilesize, NULL,
                             NULL, &ctxt->nrtiles_x, &ctxt->nrtiles_y);
}
517

518
// Callback invoked by foreach_rtile_in_tile() once per restoration tile, with
// the tile's index, its pixel limits, and the caller-supplied `arg`.
typedef void (*rtile_visitor_t)(
    const struct rest_search_ctxt *search_ctxt, int rtile_idx,
    const RestorationTileLimits *limits, void *arg);
521 522 523 524 525 526

// Calls `fun(ctxt, rtile_idx, &limits, arg)` for every restoration tile whose
// top-left corner lies inside the coding tile (tile_row, tile_col) of the
// plane described by `ctxt`.
//
// The coding tile's bounds are taken in pixels (mi units * MI_SIZE) and, for
// chroma planes, scaled down by the subsampling shifts. They are then turned
// into a half-open range of restoration-tile rows/columns, clipped to the
// restoration-tile grid recorded in ctxt->nrtiles_x / nrtiles_y.
static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
                                  int tile_row, int tile_col,
                                  rtile_visitor_t fun, void *arg) {
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  TileInfo tile_info;

  av1_tile_set_row(&tile_info, cm, tile_row);
  av1_tile_set_col(&tile_info, cm, tile_col);

  int tile_col_start = tile_info.mi_col_start * MI_SIZE;
  int tile_col_end = tile_info.mi_col_end * MI_SIZE;
  int tile_row_start = tile_info.mi_row_start * MI_SIZE;
  int tile_row_end = tile_info.mi_row_end * MI_SIZE;
  if (ctxt->plane > 0) {
    tile_col_start = ROUND_POWER_OF_TWO(tile_col_start, cm->subsampling_x);
    tile_col_end = ROUND_POWER_OF_TWO(tile_col_end, cm->subsampling_x);
    tile_row_start = ROUND_POWER_OF_TWO(tile_row_start, cm->subsampling_y);
    tile_row_end = ROUND_POWER_OF_TWO(tile_row_end, cm->subsampling_y);
  }

  // Use the restoration tile size of the plane being searched. This is the
  // same value init_rest_search_ctxt() used to derive nrtiles_x / nrtiles_y,
  // so the index ranges below stay consistent with that grid even when luma
  // and chroma use different tile sizes.
  const int rtile_size = cm->rst_info[ctxt->plane].restoration_tilesize;

  // First and one-past-last restoration tile column/row, rounding up.
  const int rtile_col0 = (tile_col_start + rtile_size - 1) / rtile_size;
  const int rtile_col1 =
      AOMMIN((tile_col_end + rtile_size - 1) / rtile_size, ctxt->nrtiles_x);
  const int rtile_row0 = (tile_row_start + rtile_size - 1) / rtile_size;
  const int rtile_row1 =
      AOMMIN((tile_row_end + rtile_size - 1) / rtile_size, ctxt->nrtiles_y);

  // A restoration tile is never larger than the coding tile containing it.
  const int rtile_width = AOMMIN(tile_col_end - tile_col_start, rtile_size);
  const int rtile_height = AOMMIN(tile_row_end - tile_row_start, rtile_size);

  for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
    for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
      RestorationTileLimits limits = av1_get_rest_tile_limits(
          rtile_idx, ctxt->nrtiles_x, ctxt->nrtiles_y, rtile_width,
          rtile_height, ctxt->plane_width,
#if CONFIG_STRIPED_LOOP_RESTORATION
          ctxt->plane_height, ctxt->plane > 0 ? cm->subsampling_y : 0);
#else
          ctxt->plane_height);
#endif
      fun(ctxt, rtile_idx, &limits, arg);
    }
  }
}

static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
571 572 573
                                     int rtile_idx,
                                     const RestorationTileLimits *limits,
                                     void *arg) {
574 575 576 577 578 579 580
  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  RestorationInfo *rsi = ctxt->cpi->rst_search;
  SgrprojInfo *sgrproj_info = ctxt->info->sgrproj_info;

  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;

581 582 583 584
  int64_t err =
      sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, limits->h_start,
                           limits->h_end - limits->h_start, limits->v_start,
                           limits->v_end - limits->v_start, (1 << ctxt->plane));
585 586 587 588 589 590 591
  // #bits when a tile is not restored
  int bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;

  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
  SgrprojInfo *rtile_sgrproj_info = &plane_rsi->sgrproj_info[rtile_idx];
592 593
  uint8_t *dgd_start =
      ctxt->dgd_buffer + limits->v_start * ctxt->dgd_stride + limits->h_start;
594
  const uint8_t *src_start =
595
      ctxt->src_buffer + limits->v_start * ctxt->src_stride + limits->h_start;
596

597
  search_selfguided_restoration(
598 599
      dgd_start, limits->h_end - limits->h_start,
      limits->v_end - limits->v_start, ctxt->dgd_stride, src_start,
600
      ctxt->src_stride,
601
#if CONFIG_HIGHBITDEPTH
602
      cm->use_highbitdepth, cm->bit_depth,
603
#else
604
      0, 8,
605
#endif  // CONFIG_HIGHBITDEPTH
606 607 608
      rsi[ctxt->plane].procunit_width, rsi[ctxt->plane].procunit_height,
      &rtile_sgrproj_info->ep, rtile_sgrproj_info->xqd,
      cm->rst_internal.tmpbuf);
609 610
  plane_rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
  err = try_restoration_tile(ctxt->src, ctxt->cpi, rsi, (1 << ctxt->plane),
611
                             ctxt->partial_frame, rtile_idx, ctxt->dst_frame);
612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646
  bits =
      count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx], ref_sgrproj_info)
      << AV1_PROB_COST_SHIFT;
  bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  if (cost_sgrproj >= cost_norestore) {
    ctxt->type[rtile_idx] = RESTORE_NONE;
  } else {
    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
    *ref_sgrproj_info = sgrproj_info[rtile_idx] =
        plane_rsi->sgrproj_info[rtile_idx];
    ctxt->best_tile_cost[rtile_idx] = err;
  }
  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
}

// Searches self-guided restoration for one plane.
//
// Per restoration tile, the decision and parameters are written to `type[]`,
// `info->sgrproj_info[]` and `best_tile_cost[]` (see
// search_sgrproj_for_rtile). The decisions are then copied into
// cpi->rst_search[plane], the whole frame is restored into `dst_frame`, and
// the resulting frame-level rate-distortion cost is returned.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane,
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  struct rest_search_ctxt ctxt;
  const int nrtiles =
      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
                            best_tile_cost, dst_frame, &ctxt);
  const AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *const plane_rsi = &cpi->rst_search[plane];

  // Start with every tile unrestored so each trial toggles one tile only.
  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (int idx = 0; idx < nrtiles; ++idx) {
    plane_rsi->restoration_type[idx] = RESTORE_NONE;
  }

  // Extend the plane's borders before filtering.
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    extend_frame_highbd(CONVERT_TO_SHORTPTR(ctxt.dgd_buffer), ctxt.plane_width,
                        ctxt.plane_height, ctxt.dgd_stride, SGRPROJ_BORDER_HORZ,
                        SGRPROJ_BORDER_VERT);
  } else {
    extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
                 ctxt.dgd_stride, SGRPROJ_BORDER_HORZ, SGRPROJ_BORDER_VERT);
  }
#else
  extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
               ctxt.dgd_stride, SGRPROJ_BORDER_HORZ, SGRPROJ_BORDER_VERT);
#endif

  // Compute best Sgrproj filters for each rtile, one (encoder/decoder)
  // tile at a time. The coding reference restarts from the default at the
  // beginning of every tile.
  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
      SgrprojInfo tile_ref;
      set_default_sgrproj(&tile_ref);
      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
                            &tile_ref);
    }
  }

  // Cost for Sgrproj filtering
  SgrprojInfo ref_sgrproj_info;
  set_default_sgrproj(&ref_sgrproj_info);

  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
             << AV1_PROB_COST_SHIFT;
  for (int idx = 0; idx < nrtiles; ++idx) {
    const int restored = (type[idx] != RESTORE_NONE);
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, restored);
    plane_rsi->sgrproj_info[idx] = info->sgrproj_info[idx];
    if (type[idx] == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(&plane_rsi->sgrproj_info[idx],
                                 &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      ref_sgrproj_info = plane_rsi->sgrproj_info[idx];
    }
    plane_rsi->restoration_type[idx] = type[idx];
  }

  const double err = try_restoration_frame(src, cpi, cpi->rst_search,
                                           (1 << plane), partial_frame,
                                           dst_frame);
  return RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
}

691 692
// Returns the mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of `src`, where consecutive rows are `stride` samples
// apart.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  uint64_t total = 0;

  aom_clear_system_state();
  for (int row = v_start; row < v_end; row++) {
    for (int col = h_start; col < h_end; col++) {
      total += src[row * stride + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

703 704 705 706
static void compute_stats(int wiener_win, const uint8_t *dgd,
                          const uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
707
  int i, j, k, l;
708
  double Y[WIENER_WIN2];
709 710
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
711 712
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
713

714 715
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
716 717
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
718 719
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
720 721
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
722 723 724 725
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
726
      assert(idx == wiener_win2);
727
      for (k = 0; k < wiener_win2; ++k) {
728
        M[k] += Y[k] * X;
729 730
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
731 732 733
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
734
          H[k * wiener_win2 + l] += Y[k] * Y[l];
735 736 737 738
        }
      }
    }
  }
739 740 741
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
742 743
    }
  }
744 745
}

746
#if CONFIG_HIGHBITDEPTH
747
// Returns the arithmetic mean of the 16-bit samples of src inside the
// rectangle rows [v_start, v_end) x columns [h_start, h_end), where sample
// (i, j) lives at src[i * stride + j].
//
// The samples are summed exactly in a 64-bit integer and divided once at the
// end. An empty rectangle makes the divisor zero, so callers are expected to
// pass a non-empty region.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t sum = 0;
  int i, j;
  // NOTE(review): presumably resets SIMD/FPU state before the floating-point
  // division below -- confirm against aom_ports/system_state.h.
  aom_clear_system_state();
  for (i = v_start; i < v_end; i++)
    for (j = h_start; j < h_end; j++) sum += src[i * stride + j];
  return (double)sum / ((v_end - v_start) * (h_end - h_start));
}

759 760 761 762
static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
                                 const uint8_t *src8, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, double *M, double *H) {
763
  int i, j, k, l;
764
  double Y[WIENER_WIN2];
765 766
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
767 768
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
769 770
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
771

772 773
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
774 775
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
776 777
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
778 779
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
780 781 782 783
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
784
      assert(idx == wiener_win2);
785
      for (k = 0; k < wiener_win2; ++k) {
786
        M[k] += Y[k] * X;
787 788
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
789 790 791
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
792
          H[k * wiener_win2 + l] += Y[k] * Y[l];
793 794 795 796
        }
      }
    }
  }
797 798 799
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
800 801
    }
  }
802
}
803
#endif  // CONFIG_HIGHBITDEPTH
804

805 806 807
// Folds a tap position i of a wiener_win-tap kernel onto the first
// (wiener_win >> 1) + 1 positions: indices below that bound are returned
// unchanged, indices at or above it are mirrored to wiener_win - 1 - i.
static INLINE int wrap_index(int i, int wiener_win) {
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
  return (i >= wiener_halfwin1 ? wiener_win - 1 - i : i);
}

// Fix vector b, update vector a
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
813
  int i, j;
814
  double S[WIENER_WIN];
815
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
816 817
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
818 819
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
820 821 822
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
823 824 825
      A[jj] += Mc[i][j] * b[i];
    }
  }
826 827
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
828
      int k, l;
829 830 831 832 833 834
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
835 836 837
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
838
  // Normalization enforcement in the system of equations itself
839
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
840
    A[i] -=
841 842 843 844 845 846 847 848 849 850 851 852 853 854 855
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
856
    }
857
    memcpy(a, S, wiener_win * sizeof(*a));
858 859 860 861
  }
}

// Fix vector a, update vector b
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
864
  int i, j;
865
  double S[WIENER_WIN];
866
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
867 868
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
869 870
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
871 872 873
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
874 875
  }

876 877 878 879
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
880
      int k, l;
881 882 883 884
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
885 886
    }
  }
Aamir Anis's avatar
Aamir Anis committed
887
  // Normalization enforcement in the system of equations itself
888
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
889
    A[i] -=
890 891 892 893 894 895