pickrst.c 58.7 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20 21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26 27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
#include "av1/encoder/encoder.h"
32
#include "av1/encoder/mathutils.h"
33 34
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
35

36
// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
37
// When set to RESTORE_TYPES we allow switchable.
38
static const RestorationType force_restore_type = RESTORE_TYPES;
39 40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
typedef double (*search_restore_type)(const YV12_BUFFER_CONFIG *src,
44
                                      AV1_COMP *cpi, int partial_frame,
45
                                      int plane, RestorationInfo *info,
46
                                      RestorationType *rest_level,
47 48
                                      double *best_tile_cost,
                                      YV12_BUFFER_CONFIG *dst_frame);
49

50
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
51 52

static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
53 54
                                    const YV12_BUFFER_CONFIG *dst,
                                    const AV1_COMMON *cm, int h_start,
55
                                    int width, int v_start, int height,
56 57
                                    int components_pattern) {
  int64_t filt_err = 0;
58 59 60 61
  (void)cm;
  // Y and UV components cannot be mixed
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);
62
#if CONFIG_HIGHBITDEPTH
63
  if (cm->use_highbitdepth) {
64 65 66 67 68
    if ((components_pattern >> AOM_PLANE_Y) & 1) {
      filt_err +=
          aom_highbd_get_y_sse_part(src, dst, h_start, width, v_start, height);
    }
    if ((components_pattern >> AOM_PLANE_U) & 1) {
69 70
      filt_err +=
          aom_highbd_get_u_sse_part(src, dst, h_start, width, v_start, height);
71 72
    }
    if ((components_pattern >> AOM_PLANE_V) & 1) {
73 74
      filt_err +=
          aom_highbd_get_v_sse_part(src, dst, h_start, width, v_start, height);
75 76
    }
    return filt_err;
77
  }
78
#endif  // CONFIG_HIGHBITDEPTH
79 80 81 82
  if ((components_pattern >> AOM_PLANE_Y) & 1) {
    filt_err += aom_get_y_sse_part(src, dst, h_start, width, v_start, height);
  }
  if ((components_pattern >> AOM_PLANE_U) & 1) {
83
    filt_err += aom_get_u_sse_part(src, dst, h_start, width, v_start, height);
84 85
  }
  if ((components_pattern >> AOM_PLANE_V) & 1) {
86
    filt_err += aom_get_v_sse_part(src, dst, h_start, width, v_start, height);
87
  }
88 89 90
  return filt_err;
}

91 92
// Whole-frame counterpart of sse_restoration_tile(): sum of squared
// differences between `src` and `dst` over each plane selected by
// `components_pattern` (bit 1 << AOM_PLANE_* per plane).
static int64_t sse_restoration_frame(AV1_COMMON *const cm,
                                     const YV12_BUFFER_CONFIG *src,
                                     const YV12_BUFFER_CONFIG *dst,
                                     int components_pattern) {
  const int want_y = (components_pattern >> AOM_PLANE_Y) & 1;
  const int want_u = (components_pattern >> AOM_PLANE_U) & 1;
  const int want_v = (components_pattern >> AOM_PLANE_V) & 1;
  int64_t sse = 0;

#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    if (want_y) sse += aom_highbd_get_y_sse(src, dst);
    if (want_u) sse += aom_highbd_get_u_sse(src, dst);
    if (want_v) sse += aom_highbd_get_v_sse(src, dst);
    return sse;
  }
#else
  (void)cm;
#endif  // CONFIG_HIGHBITDEPTH

  if (want_y) sse += aom_get_y_sse(src, dst);
  if (want_u) sse += aom_get_u_sse(src, dst);
  if (want_v) sse += aom_get_v_sse(src, dst);
  return sse;
}

124 125
// Runs loop restoration described by `rsi` on the planes selected by
// `components_pattern` of cm->frame_to_show, writing into `dst_frame`, then
// returns the SSE against `src` measured only inside restoration tile
// `tile_idx`. Note that the whole selected planes are filtered, not just
// that tile.
static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
                                    AV1_COMP *const cpi, RestorationInfo *rsi,
                                    int components_pattern, int partial_frame,
                                    int tile_idx,
                                    YV12_BUFFER_CONFIG *dst_frame) {
  AV1_COMMON *const cm = &cpi->common;
  const int luma_only = (components_pattern == 1);
  const int plane_w = luma_only ? src->y_crop_width : src->uv_crop_width;
  const int plane_h = luma_only ? src->y_crop_height : src->uv_crop_height;
  int rtile_w, rtile_h, rtiles_x, rtiles_y;

  // Y and UV components cannot be mixed.
  assert(components_pattern == 1 || components_pattern == 2 ||
         components_pattern == 4 || components_pattern == 6);

  // Luma takes its tile size from rst_info[0]; any chroma pattern from
  // rst_info[1]. Only the grid geometry is needed, not the tile count.
  (void)av1_get_rest_ntiles(
      plane_w, plane_h,
      cm->rst_info[components_pattern > 1].restoration_tilesize, &rtile_w,
      &rtile_h, &rtiles_x, &rtiles_y);

  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);

  const RestorationTileLimits lim = av1_get_rest_tile_limits(
      tile_idx, rtiles_x, rtiles_y, rtile_w, rtile_h, plane_w, plane_h);
  return sse_restoration_tile(src, dst_frame, cm, lim.h_start,
                              lim.h_end - lim.h_start, lim.v_start,
                              lim.v_end - lim.v_start, components_pattern);
}

static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
Yaowu Xu's avatar
Yaowu Xu committed
162
                                     AV1_COMP *const cpi, RestorationInfo *rsi,
163
                                     int components_pattern, int partial_frame,
164
                                     YV12_BUFFER_CONFIG *dst_frame) {
Yaowu Xu's avatar
Yaowu Xu committed
165
  AV1_COMMON *const cm = &cpi->common;
166
  int64_t filt_err;
167 168
  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, components_pattern,
                             partial_frame, dst_frame);
169
  filt_err = sse_restoration_frame(cm, src, dst_frame, components_pattern);
170 171 172
  return filt_err;
}

173 174
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
175
                                    int dat_stride, int use_highbitdepth,
176 177
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
178 179 180 181
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
182
  if (!use_highbitdepth) {
183 184 185 186 187 188 189 190
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
191
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
207
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
208 209 210 211 212
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
213 214 215 216 217
    }
  }
  return err;
}

218 219
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
static int64_t finer_search_pixel_proj_error(
220
    const uint8_t *src8, int width, int height, int src_stride,
221
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
222
    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
223
  int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
224 225
                                     dat_stride, use_highbitdepth, flt1,
                                     flt1_stride, flt2, flt2_stride, xqd);
226 227 228 229 230 231 232 233 234 235 236 237
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int skip = 0;
      do {
        if (xqd[p] - s >= tap_min[p]) {
          xqd[p] -= s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
238 239
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
          if (err2 > err) {
            xqd[p] += s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (xqd[p] + s <= tap_max[p]) {
          xqd[p] += s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
256 257
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
          if (err2 > err) {
            xqd[p] -= s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

274
static void get_proj_subspace(const uint8_t *src8, int width, int height,
275
                              int src_stride, uint8_t *dat8, int dat_stride,
276 277 278
                              int use_highbitdepth, int32_t *flt1,
                              int flt1_stride, int32_t *flt2, int flt2_stride,
                              int *xq) {
279 280 281 282 283 284 285
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

286 287
  aom_clear_system_state();

288 289 290
  // Default
  xq[0] = 0;
  xq[1] = 0;
291
  if (!use_highbitdepth) {
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
341
  xqd[0] = xq[0];
342
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
343
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
344 345 346 347
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

static void search_selfguided_restoration(uint8_t *dat8, int width, int height,
348
                                          int dat_stride, const uint8_t *src8,
349
                                          int src_stride, int use_highbitdepth,
350 351
                                          int bit_depth, int pu_width,
                                          int pu_height, int *eps, int *xqd,
352
                                          int32_t *rstbuf) {
353
  int32_t *flt1 = rstbuf;
354
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
355
  int ep, bestep = 0;
356 357
  int64_t err, besterr = -1;
  int exqd[2], bestxqd[2] = { 0, 0 };
358 359
  int flt1_stride = ((width + 7) & ~7) + 8;
  int flt2_stride = ((width + 7) & ~7) + 8;
360 361 362 363
  assert(pu_width == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_width == RESTORATION_PROC_UNIT_SIZE);
  assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_height == RESTORATION_PROC_UNIT_SIZE);
364 365 366
#if !CONFIG_HIGHBITDEPTH
  (void)bit_depth;
#endif
367

368 369
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
370
#if CONFIG_HIGHBITDEPTH
371
    if (use_highbitdepth) {
372
      uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
373 374 375 376 377 378 379
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint16_t *dat_p = dat + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
380
#if USE_HIGHPASS_IN_SGRPROJ
381 382 383
          av1_highpass_filter_highbd(dat_p, w, h, dat_stride, flt1_p,
                                     flt1_stride, sgr_params[ep].corner,
                                     sgr_params[ep].edge);
384
#else
385
          av1_selfguided_restoration_highbd(
386
              dat_p, w, h, dat_stride, flt1_p, flt1_stride, bit_depth,
387
              sgr_params[ep].r1, sgr_params[ep].e1);
388
#endif  // USE_HIGHPASS_IN_SGRPROJ
389
          av1_selfguided_restoration_highbd(
390
              dat_p, w, h, dat_stride, flt2_p, flt2_stride, bit_depth,
391
              sgr_params[ep].r2, sgr_params[ep].e2);
392
        }
393
    } else {
394
#endif
395 396 397 398 399 400 401
      for (int i = 0; i < height; i += pu_height)
        for (int j = 0; j < width; j += pu_width) {
          const int w = AOMMIN(pu_width, width - j);
          const int h = AOMMIN(pu_height, height - i);
          uint8_t *dat_p = dat8 + i * dat_stride + j;
          int32_t *flt1_p = flt1 + i * flt1_stride + j;
          int32_t *flt2_p = flt2 + i * flt2_stride + j;
402
#if USE_HIGHPASS_IN_SGRPROJ
403 404
          av1_highpass_filter(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
                              sgr_params[ep].corner, sgr_params[ep].edge);
405
#else
406
        av1_selfguided_restoration(dat_p, w, h, dat_stride, flt1_p, flt1_stride,
407
                                   sgr_params[ep].r1, sgr_params[ep].e1);
408
#endif  // USE_HIGHPASS_IN_SGRPROJ
409 410
          av1_selfguided_restoration(dat_p, w, h, dat_stride, flt2_p,
                                     flt2_stride, sgr_params[ep].r2,
411
                                     sgr_params[ep].e2);
412
        }
413
#if CONFIG_HIGHBITDEPTH
414
    }
415
#endif
416
    aom_clear_system_state();
417
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
418 419
                      use_highbitdepth, flt1, flt1_stride, flt2, flt2_stride,
                      exq);
420
    aom_clear_system_state();
421
    encode_xq(exq, exqd);
422 423 424
    err = finer_search_pixel_proj_error(
        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
        flt1, flt1_stride, flt2, flt2_stride, 2, exqd);
425 426 427 428 429 430 431 432 433 434 435 436
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
  *eps = bestep;
  xqd[0] = bestxqd[0];
  xqd[1] = bestxqd[1];
}

437 438 439 440 441 442 443 444 445 446 447 448 449 450
// Number of bits (before scaling by AV1_PROB_COST_SHIFT) needed to signal
// `sgrproj_info` given the previously coded `ref_sgrproj_info`:
// SGRPROJ_PARAMS_BITS for the parameter-set index, plus a finite
// sub-exponential code for each projection coefficient relative to the
// reference coefficient.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  const int xqd0_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
      sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  const int xqd1_bits = aom_count_primitive_refsubexpfin(
      SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
      ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
      sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  return SGRPROJ_PARAMS_BITS + xqd0_bits + xqd1_bits;
}

451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
// State shared by the per-rtile restoration searches of one plane; filled
// in by init_rest_search_ctxt().
struct rest_search_ctxt {
  // Frame the restored result is compared against (SSE reference).
  const YV12_BUFFER_CONFIG *src;
  AV1_COMP *cpi;
  // Start of this plane in cm->frame_to_show (the frame being restored).
  uint8_t *dgd_buffer;
  // Start of this plane in `src`.
  const uint8_t *src_buffer;
  int dgd_stride;
  int src_stride;
  int partial_frame;
  // Receives the per-rtile parameters the search selects.
  RestorationInfo *info;
  // Per-rtile restoration type chosen by the search.
  RestorationType *type;
  // Per-rtile cost recorded by the search (DBL_MAX when nothing was kept).
  double *best_tile_cost;
  // AOM_PLANE_Y, AOM_PLANE_U or AOM_PLANE_V.
  int plane;
  // Cropped dimensions of this plane.
  int plane_width;
  int plane_height;
  // Number of restoration tiles across / down this plane.
  int nrtiles_x;
  int nrtiles_y;
  // Output buffer for trial loop-restoration passes.
  YV12_BUFFER_CONFIG *dst_frame;
};

// Fills in `ctxt` for a restoration search on `plane` and returns the number
// of restoration tiles (rtiles) in that plane.
//
// Besides recording the arguments, this locates the plane inside `src` and
// inside cm->frame_to_show (the degraded frame, "dgd"). It also stores the
// rtile grid derived from cm->rst_info[plane].restoration_tilesize.
static INLINE int init_rest_search_ctxt(
    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int partial_frame, int plane,
    RestorationInfo *info, RestorationType *type, double *best_tile_cost,
    YV12_BUFFER_CONFIG *dst_frame, struct rest_search_ctxt *ctxt) {
  const AV1_COMMON *const cm = &cpi->common;
  const YV12_BUFFER_CONFIG *const dgd = cm->frame_to_show;

  ctxt->src = src;
  ctxt->cpi = cpi;
  ctxt->partial_frame = partial_frame;
  ctxt->info = info;
  ctxt->type = type;
  ctxt->best_tile_cost = best_tile_cost;
  ctxt->plane = plane;
  ctxt->dst_frame = dst_frame;

  if (plane == AOM_PLANE_Y) {
    // Source and degraded frames must agree on the luma size.
    assert(src->y_crop_width == dgd->y_crop_width);
    assert(src->y_crop_height == dgd->y_crop_height);
    ctxt->plane_width = src->y_crop_width;
    ctxt->plane_height = src->y_crop_height;
    ctxt->src_buffer = src->y_buffer;
    ctxt->src_stride = src->y_stride;
    ctxt->dgd_buffer = dgd->y_buffer;
    ctxt->dgd_stride = dgd->y_stride;
  } else {
    const int is_u = (plane == AOM_PLANE_U);
    // Source and degraded frames must agree on the chroma size.
    assert(src->uv_crop_width == dgd->uv_crop_width);
    assert(src->uv_crop_height == dgd->uv_crop_height);
    ctxt->plane_width = src->uv_crop_width;
    ctxt->plane_height = src->uv_crop_height;
    ctxt->src_buffer = is_u ? src->u_buffer : src->v_buffer;
    ctxt->src_stride = src->uv_stride;
    ctxt->dgd_buffer = is_u ? dgd->u_buffer : dgd->v_buffer;
    ctxt->dgd_stride = dgd->uv_stride;
  }

  return av1_get_rest_ntiles(ctxt->plane_width, ctxt->plane_height,
                             cm->rst_info[plane].restoration_tilesize, NULL,
                             NULL, &ctxt->nrtiles_x, &ctxt->nrtiles_y);
}
512

513
// Callback invoked by foreach_rtile_in_tile() once per restoration tile. It
// receives the search context, the rtile's raster index within the plane,
// the rtile's pixel limits, and the caller's opaque `arg`.
typedef void (*rtile_visitor_t)(const struct rest_search_ctxt *ctxt,
                                int rtile_idx,
                                const RestorationTileLimits *limits,
                                void *arg);
516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550

static void foreach_rtile_in_tile(const struct rest_search_ctxt *ctxt,
                                  int tile_row, int tile_col,
                                  rtile_visitor_t fun, void *arg) {
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  const RestorationInfo *rsi = ctxt->cpi->rst_search;

  const int tile_width_y = cm->tile_width * MI_SIZE;
  const int tile_height_y = cm->tile_height * MI_SIZE;

  const int tile_width =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_width_y, cm->subsampling_x)
                        : tile_width_y;
  const int tile_height =
      (ctxt->plane > 0) ? ROUND_POWER_OF_TWO(tile_height_y, cm->subsampling_y)
                        : tile_height_y;

  const int rtile_size = rsi->restoration_tilesize;
  const int rtiles_per_tile_x = tile_width * MI_SIZE / rtile_size;
  const int rtiles_per_tile_y = tile_height * MI_SIZE / rtile_size;

  const int rtile_row0 = rtiles_per_tile_y * tile_row;
  const int rtile_row1 =
      AOMMIN(rtile_row0 + rtiles_per_tile_y, ctxt->nrtiles_y);

  const int rtile_col0 = rtiles_per_tile_x * tile_col;
  const int rtile_col1 =
      AOMMIN(rtile_col0 + rtiles_per_tile_x, ctxt->nrtiles_x);

  const int rtile_width = AOMMIN(tile_width, rtile_size);
  const int rtile_height = AOMMIN(tile_height, rtile_size);

  for (int rtile_row = rtile_row0; rtile_row < rtile_row1; ++rtile_row) {
    for (int rtile_col = rtile_col0; rtile_col < rtile_col1; ++rtile_col) {
      const int rtile_idx = rtile_row * ctxt->nrtiles_x + rtile_col;
551 552 553 554
      RestorationTileLimits limits = av1_get_rest_tile_limits(
          rtile_idx, ctxt->nrtiles_x, ctxt->nrtiles_y, rtile_width,
          rtile_height, ctxt->plane_width, ctxt->plane_height);
      fun(ctxt, rtile_idx, &limits, arg);
555
    }
556
  }
557 558 559
}

static void search_sgrproj_for_rtile(const struct rest_search_ctxt *ctxt,
560 561 562
                                     int rtile_idx,
                                     const RestorationTileLimits *limits,
                                     void *arg) {
563 564 565 566 567 568 569
  const MACROBLOCK *const x = &ctxt->cpi->td.mb;
  const AV1_COMMON *const cm = &ctxt->cpi->common;
  RestorationInfo *rsi = ctxt->cpi->rst_search;
  SgrprojInfo *sgrproj_info = ctxt->info->sgrproj_info;

  SgrprojInfo *ref_sgrproj_info = (SgrprojInfo *)arg;

570 571 572 573
  int64_t err =
      sse_restoration_tile(ctxt->src, cm->frame_to_show, cm, limits->h_start,
                           limits->h_end - limits->h_start, limits->v_start,
                           limits->v_end - limits->v_start, (1 << ctxt->plane));
574 575 576 577 578 579 580
  // #bits when a tile is not restored
  int bits = av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 0);
  double cost_norestore = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  ctxt->best_tile_cost[rtile_idx] = DBL_MAX;

  RestorationInfo *plane_rsi = &rsi[ctxt->plane];
  SgrprojInfo *rtile_sgrproj_info = &plane_rsi->sgrproj_info[rtile_idx];
581 582
  uint8_t *dgd_start =
      ctxt->dgd_buffer + limits->v_start * ctxt->dgd_stride + limits->h_start;
583
  const uint8_t *src_start =
584
      ctxt->src_buffer + limits->v_start * ctxt->src_stride + limits->h_start;
585

586
  search_selfguided_restoration(
587 588
      dgd_start, limits->h_end - limits->h_start,
      limits->v_end - limits->v_start, ctxt->dgd_stride, src_start,
589
      ctxt->src_stride,
590
#if CONFIG_HIGHBITDEPTH
591
      cm->use_highbitdepth, cm->bit_depth,
592
#else
593
      0, 8,
594
#endif  // CONFIG_HIGHBITDEPTH
595 596 597
      rsi[ctxt->plane].procunit_width, rsi[ctxt->plane].procunit_height,
      &rtile_sgrproj_info->ep, rtile_sgrproj_info->xqd,
      cm->rst_internal.tmpbuf);
598 599
  plane_rsi->restoration_type[rtile_idx] = RESTORE_SGRPROJ;
  err = try_restoration_tile(ctxt->src, ctxt->cpi, rsi, (1 << ctxt->plane),
600
                             ctxt->partial_frame, rtile_idx, ctxt->dst_frame);
601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635
  bits =
      count_sgrproj_bits(&plane_rsi->sgrproj_info[rtile_idx], ref_sgrproj_info)
      << AV1_PROB_COST_SHIFT;
  bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, 1);
  double cost_sgrproj = RDCOST_DBL(x->rdmult, (bits >> 4), err);
  if (cost_sgrproj >= cost_norestore) {
    ctxt->type[rtile_idx] = RESTORE_NONE;
  } else {
    ctxt->type[rtile_idx] = RESTORE_SGRPROJ;
    *ref_sgrproj_info = sgrproj_info[rtile_idx] =
        plane_rsi->sgrproj_info[rtile_idx];
    ctxt->best_tile_cost[rtile_idx] = err;
  }
  plane_rsi->restoration_type[rtile_idx] = RESTORE_NONE;
}

// Evaluates self-guided projection (SGRPROJ) restoration for `plane` and
// returns the RD cost of the resulting frame.
//
// 1. Every rtile is searched individually (search_sgrproj_for_rtile) while
//    all other rtiles are unfiltered. The search runs coding tile by coding
//    tile, and the coefficient reference is reset at the start of each
//    coding tile.
// 2. The per-rtile decisions in `type` / info->sgrproj_info are copied into
//    cpi->rst_search[plane], and the signalling bits are summed. This sum
//    uses a single coefficient-reference chain across the whole plane.
// 3. The plane is filtered into `dst_frame`, and the cost is computed from
//    the frame SSE and those bits.
static double search_sgrproj(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                             int partial_frame, int plane,
                             RestorationInfo *info, RestorationType *type,
                             double *best_tile_cost,
                             YV12_BUFFER_CONFIG *dst_frame) {
  const AV1_COMMON *const cm = &cpi->common;
  RestorationInfo *const plane_rsi = &cpi->rst_search[plane];
  struct rest_search_ctxt ctxt;
  const int nrtiles =
      init_rest_search_ctxt(src, cpi, partial_frame, plane, info, type,
                            best_tile_cost, dst_frame, &ctxt);

  // Start with every rtile unfiltered so each one is tried on its own.
  plane_rsi->frame_restoration_type = RESTORE_SGRPROJ;
  for (int i = 0; i < nrtiles; ++i) {
    plane_rsi->restoration_type[i] = RESTORE_NONE;
  }

  // Extend the degraded plane's borders by SGRPROJ_BORDER_HORZ/VERT
  // (presumably so the filters can read past the plane edges).
#if CONFIG_HIGHBITDEPTH
  if (cm->use_highbitdepth) {
    extend_frame_highbd(CONVERT_TO_SHORTPTR(ctxt.dgd_buffer), ctxt.plane_width,
                        ctxt.plane_height, ctxt.dgd_stride, SGRPROJ_BORDER_HORZ,
                        SGRPROJ_BORDER_VERT);
  } else {
#endif  // CONFIG_HIGHBITDEPTH
    extend_frame(ctxt.dgd_buffer, ctxt.plane_width, ctxt.plane_height,
                 ctxt.dgd_stride, SGRPROJ_BORDER_HORZ, SGRPROJ_BORDER_VERT);
#if CONFIG_HIGHBITDEPTH
  }
#endif  // CONFIG_HIGHBITDEPTH

  // Per-rtile search, one (encoder/decoder) coding tile at a time.
  for (int tile_row = 0; tile_row < cm->tile_rows; ++tile_row) {
    for (int tile_col = 0; tile_col < cm->tile_cols; ++tile_col) {
      SgrprojInfo tile_ref;
      set_default_sgrproj(&tile_ref);
      foreach_rtile_in_tile(&ctxt, tile_row, tile_col, search_sgrproj_for_rtile,
                            &tile_ref);
    }
  }

  // Signalling cost of the chosen per-rtile decisions.
  SgrprojInfo ref_sgrproj_info;
  set_default_sgrproj(&ref_sgrproj_info);
  int bits = frame_level_restore_bits[plane_rsi->frame_restoration_type]
             << AV1_PROB_COST_SHIFT;
  for (int i = 0; i < nrtiles; ++i) {
    const RestorationType rtype = type[i];
    bits += av1_cost_bit(RESTORE_NONE_SGRPROJ_PROB, rtype != RESTORE_NONE);
    plane_rsi->sgrproj_info[i] = info->sgrproj_info[i];
    if (rtype == RESTORE_SGRPROJ) {
      bits += count_sgrproj_bits(&plane_rsi->sgrproj_info[i],
                                 &ref_sgrproj_info)
              << AV1_PROB_COST_SHIFT;
      ref_sgrproj_info = plane_rsi->sgrproj_info[i];
    }
    plane_rsi->restoration_type[i] = rtype;
  }

  const double err = try_restoration_frame(src, cpi, cpi->rst_search,
                                           (1 << plane), partial_frame,
                                           dst_frame);
  return RDCOST_DBL(cpi->td.mb.rdmult, (bits >> 4), err);
}

680 681
// Mean sample value of `src` over rows [v_start, v_end) x columns
// [h_start, h_end), where consecutive rows are `stride` samples apart.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  const int count = (v_end - v_start) * (h_end - h_start);
  uint64_t total = 0;

  aom_clear_system_state();
  for (int row = v_start; row < v_end; ++row) {
    const uint8_t *const line = src + row * stride;
    for (int col = h_start; col < h_end; ++col) {
      total += line[col];
    }
  }
  return (double)total / count;
}

692 693 694 695
static void compute_stats(int wiener_win, const uint8_t *dgd,
                          const uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
696
  int i, j, k, l;
697
  double Y[WIENER_WIN2];
698 699
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
700 701
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
702

703 704
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
705 706
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
707 708
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
709 710
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
711 712 713 714
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
715
      for (k = 0; k < wiener_win2; ++k) {
716
        M[k] += Y[k] * X;
717 718
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
719 720 721
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
722
          H[k * wiener_win2 + l] += Y[k] * Y[l];
723 724 725 726
        }
      }
    }
  }
727 728 729
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
730 731
    }
  }
732 733
}

734
#if CONFIG_HIGHBITDEPTH
735
// Returns the arithmetic mean of the 16-bit samples of src over rows
// [v_start, v_end) and columns [h_start, h_end). stride is in samples.
// The sum is accumulated in 64 bits, so the divisor must be non-zero for a
// finite result.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t total = 0;
  int row, col;

  aom_clear_system_state();
  for (row = v_start; row < v_end; ++row) {
    for (col = h_start; col < h_end; ++col) {
      total += src[row * stride + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

747 748 749 750
static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
                                 const uint8_t *src8, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, double *M, double *H) {
751
  int i, j, k, l;
752
  double Y[WIENER_WIN2];
753 754
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
755 756
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
757 758
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
759

760 761
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
762 763
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
764 765
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
766 767
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
768 769 770 771
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
772
      for (k = 0; k < wiener_win2; ++k) {
773
        M[k] += Y[k] * X;
774 775
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
776 777 778
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
779
          H[k * wiener_win2 + l] += Y[k] * Y[l];
780 781 782 783
        }
      }
    }
  }
784 785 786
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
787 788
    }
  }
789
}
790
#endif  // CONFIG_HIGHBITDEPTH
791

792 793 794
// Folds a tap index i in [0, wiener_win) onto the first half of a symmetric
// window, [0, wiener_win / 2]. Index i and its mirror wiener_win - 1 - i map to
// the same slot, and the centre tap maps to itself.
static INLINE int wrap_index(int i, int wiener_win) {
  const int half_plus_one = (wiener_win >> 1) + 1;
  if (i < half_plus_one) return i;
  return wiener_win - 1 - i;
}

// Fix vector b, update vector a
798 799
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
800
  int i, j;
801
  double S[WIENER_WIN];
802
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
803 804
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
805 806
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
807 808 809
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
810 811 812
      A[jj] += Mc[i][j] * b[i];
    }
  }
813 814
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
815
      int k, l;
816 817 818 819 820 821
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
822 823 824
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
825
  // Normalization enforcement in the system of equations itself
826
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
827
    A[i] -=
828 829 830 831 832 833 834 835 836 837 838 839 840 841 842
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
843
    }
844
    memcpy(a, S, wiener_win * sizeof(*a));
845 846 847 848
  }
}

// Fix vector a, update vector b
849 850
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
851
  int i, j;
852
  double S[WIENER_WIN];
853
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
854 855
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
856 857
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
858 859 860
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
861 862
  }

863 864 865 866
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
867
      int k, l;
868 869 870 871
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
872 873
    }
  }
Aamir Anis's avatar
Aamir Anis committed
874
  // Normalization enforcement in the system of equations itself
875
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
876
    A[i] -=
877 878 879 880 881 882 883 884 885 886 887 888 889 890 891
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
892
    }
893
    memcpy(b, S, wiener_win * sizeof(*b));
894 895 896
  }
}

897 898
static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
                                    double *a, double *b) {
899 900 901 902
  static const int init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
903
  };
904 905
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
906 907