/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
13
#include <float.h>
14 15 16
#include <limits.h>
#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
17
#include "./aom_scale_rtcd.h"
18

Yaowu Xu's avatar
Yaowu Xu committed
19
#include "aom_dsp/aom_dsp_common.h"
20 21
#include "aom_dsp/binary_codes_writer.h"
#include "aom_dsp/psnr.h"
Yaowu Xu's avatar
Yaowu Xu committed
22
#include "aom_mem/aom_mem.h"
23
#include "aom_ports/mem.h"
24
#include "aom_ports/system_state.h"
25

26 27
#include "av1/common/onyxc_int.h"
#include "av1/common/quant_common.h"
28
#include "av1/common/restoration.h"
29

30
#include "av1/encoder/av1_quantize.h"
31
#include "av1/encoder/encoder.h"
32
#include "av1/encoder/mathutils.h"
33 34
#include "av1/encoder/picklpf.h"
#include "av1/encoder/pickrst.h"
35

36
// When set to RESTORE_WIENER or RESTORE_SGRPROJ only those are allowed.
37
// When set to RESTORE_TYPES we allow switchable.
38
static const RestorationType force_restore_type = RESTORE_TYPES;
39 40

// Number of Wiener iterations
41
#define NUM_WIENER_ITERS 5
42

43
const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
44

45 46 47 48 49 50 51
typedef int64_t (*sse_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                      const YV12_BUFFER_CONFIG *b);
typedef int64_t (*sse_part_extractor_type)(const YV12_BUFFER_CONFIG *a,
                                           const YV12_BUFFER_CONFIG *b,
                                           int hstart, int width, int vstart,
                                           int height);

52
#define NUM_EXTRACTORS (3 * (1 + 1))
53 54 55

static const sse_part_extractor_type sse_part_extractors[NUM_EXTRACTORS] = {
  aom_get_y_sse_part,        aom_get_u_sse_part,
56 57
  aom_get_v_sse_part,        aom_highbd_get_y_sse_part,
  aom_highbd_get_u_sse_part, aom_highbd_get_v_sse_part,
58
};
59

60 61 62 63 64 65 66 67
static int64_t sse_restoration_tile(const RestorationTileLimits *limits,
                                    const YV12_BUFFER_CONFIG *src,
                                    const YV12_BUFFER_CONFIG *dst, int plane,
                                    int highbd) {
  return sse_part_extractors[3 * highbd + plane](
      src, dst, limits->h_start, limits->h_end - limits->h_start,
      limits->v_start, limits->v_end - limits->v_start);
}
68

69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
// Search results for a single restoration unit. The search context holds an
// array of these, indexed by restoration unit index.
typedef struct {
  // The best coefficients for Wiener or Sgrproj restoration
  WienerInfo wiener;
  SgrprojInfo sgrproj;

  // The sum of squared errors for this rtype (indexed by RestorationType,
  // e.g. sse[RESTORE_NONE], sse[RESTORE_SGRPROJ]).
  int64_t sse[RESTORE_SWITCHABLE_TYPES];

  // The rtype to use for this unit given a frame rtype as
  // index. Indices: WIENER, SGRPROJ, SWITCHABLE.
  // Entries are stored at (frame rtype - 1), e.g. the sgrproj search writes
  // best_rtype[RESTORE_SGRPROJ - 1].
  RestorationType best_rtype[RESTORE_TYPES - 1];
} RestUnitSearchInfo;

// State shared by the per-tile and per-unit search callbacks for one plane.
// The fixed fields are filled in by init_rsc(); the accumulators are cleared
// by reset_rsc(); the reference filter parameters are reset by rsc_on_tile().
typedef struct {
  // Source frame and the scratch frame that receives restored output.
  const YV12_BUFFER_CONFIG *src;
  YV12_BUFFER_CONFIG *dst;

  const AV1_COMMON *cm;
  const MACROBLOCK *x;
  // Plane being searched, and its cropped dimensions taken from src.
  int plane;
  int plane_width;
  int plane_height;
  // Per-restoration-unit results, indexed by restoration unit index.
  RestUnitSearchInfo *rusi;

  // Plane buffers and strides of the degraded frame (cm->frame_to_show) and
  // of the source frame.
  uint8_t *dgd_buffer;
  int dgd_stride;
  const uint8_t *src_buffer;
  int src_stride;

  // sse and bits are initialised by reset_rsc in search_rest_type
  int64_t sse;
  int64_t bits;
  // tile_stripe0 is set by rsc_on_tile: 0 for the first tile row, otherwise
  // cm->rst_end_stripe[tile_row - 1].
  int tile_y0, tile_stripe0;

  // sgrproj and wiener are reset to their defaults by rsc_on_tile each time
  // it is called; the sgrproj search updates sgrproj whenever it selects
  // RESTORE_SGRPROJ for a unit, and uses it as the reference when counting
  // bits.
  SgrprojInfo sgrproj;
  WienerInfo wiener;
} RestSearchCtxt;

// Tile-start hook: priv must point to a RestSearchCtxt. Resets the reference
// sgrproj/wiener parameters to their defaults and records the first stripe of
// the tile row (0 for the top row, otherwise the end stripe of the row above).
static void rsc_on_tile(int tile_row, int tile_col, void *priv) {
  RestSearchCtxt *const ctxt = (RestSearchCtxt *)priv;
  (void)tile_col;

  set_default_sgrproj(&ctxt->sgrproj);
  set_default_wiener(&ctxt->wiener);

  if (tile_row == 0) {
    ctxt->tile_stripe0 = 0;
  } else {
    ctxt->tile_stripe0 = ctxt->cm->rst_end_stripe[tile_row - 1];
  }
}

// Clears the accumulated SSE and bit totals of the search context.
static void reset_rsc(RestSearchCtxt *rsc) {
  rsc->bits = 0;
  rsc->sse = 0;
}

// Fills in the fixed fields of the search context for one plane: the frames
// involved, the plane dimensions, and the buffer/stride pairs of the source
// and of the degraded frame (cm->frame_to_show). The source and the degraded
// frame are asserted to have the same cropped size.
static void init_rsc(const YV12_BUFFER_CONFIG *src, const AV1_COMMON *cm,
                     const MACROBLOCK *x, int plane, RestUnitSearchInfo *rusi,
                     YV12_BUFFER_CONFIG *dst, RestSearchCtxt *rsc) {
  const YV12_BUFFER_CONFIG *const degraded = cm->frame_to_show;
  // Chroma planes share index 1 in the per-plane-type arrays.
  const int uv = (plane != AOM_PLANE_Y);

  assert(src->crop_widths[uv] == degraded->crop_widths[uv]);
  assert(src->crop_heights[uv] == degraded->crop_heights[uv]);

  rsc->src = src;
  rsc->dst = dst;
  rsc->cm = cm;
  rsc->x = x;
  rsc->plane = plane;
  rsc->rusi = rusi;

  rsc->plane_width = src->crop_widths[uv];
  rsc->plane_height = src->crop_heights[uv];

  rsc->src_buffer = src->buffers[plane];
  rsc->src_stride = src->strides[uv];
  rsc->dgd_buffer = degraded->buffers[plane];
  rsc->dgd_stride = degraded->strides[uv];
}

static int64_t try_restoration_tile(const RestSearchCtxt *rsc,
148
                                    const RestorationTileLimits *limits,
149 150 151 152
                                    const AV1PixelRect *tile_rect,
                                    const RestorationUnitInfo *rui) {
  const AV1_COMMON *const cm = rsc->cm;
  const int plane = rsc->plane;
153
  const int is_uv = plane > 0;
154
  const RestorationInfo *rsi = &cm->rst_info[plane];
155 156 157
  RestorationLineBuffers rlbs;
  const int bit_depth = cm->bit_depth;
  const int highbd = cm->use_highbitdepth;
158

159
  const YV12_BUFFER_CONFIG *fts = cm->frame_to_show;
160

161
  av1_loop_restoration_filter_unit(
162
      limits, rui, &rsi->boundaries, &rlbs, tile_rect, rsc->tile_stripe0,
163
#if CONFIG_LOOPFILTERING_ACROSS_TILES
164 165 166 167
#if CONFIG_LOOPFILTERING_ACROSS_TILES_EXT
      cm->loop_filter_across_tiles_v_enabled,
      cm->loop_filter_across_tiles_h_enabled,
#else
168
      cm->loop_filter_across_tiles_enabled,
169 170
#endif  // CONFIG_LOOPFILTERING_ACROSS_TILES_EXT
#endif  // CONFIG_LOOPFILTERING_ACROSS_TILES
171
      is_uv && cm->subsampling_x, is_uv && cm->subsampling_y, highbd, bit_depth,
172 173
      fts->buffers[plane], fts->strides[is_uv], rsc->dst->buffers[plane],
      rsc->dst->strides[is_uv], cm->rst_tmpbuf);
174

175
  return sse_restoration_tile(limits, rsc->src, rsc->dst, plane, highbd);
176 177
}

178 179
static int64_t get_pixel_proj_error(const uint8_t *src8, int width, int height,
                                    int src_stride, const uint8_t *dat8,
180
                                    int dat_stride, int use_highbitdepth,
181 182
                                    int32_t *flt1, int flt1_stride,
                                    int32_t *flt2, int flt2_stride, int *xqd) {
183 184 185 186
  int i, j;
  int64_t err = 0;
  int xq[2];
  decode_xq(xqd, xq);
187
  if (!use_highbitdepth) {
188 189 190 191 192 193 194 195
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
196
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const int32_t u =
            (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
        const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
David Barker's avatar
David Barker committed
212
        const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
213 214 215 216 217
        const int32_t e =
            ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
            src[i * src_stride + j];
        err += e * e;
      }
218 219 220 221 222
    }
  }
  return err;
}

223 224
#define USE_SGRPROJ_REFINEMENT_SEARCH 1
static int64_t finer_search_pixel_proj_error(
225
    const uint8_t *src8, int width, int height, int src_stride,
226
    const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
227
    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
228
  int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
229 230
                                     dat_stride, use_highbitdepth, flt1,
                                     flt1_stride, flt2, flt2_stride, xqd);
231 232 233 234 235 236 237 238 239 240 241 242
  (void)start_step;
#if USE_SGRPROJ_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MIN1 };
  int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
  for (int s = start_step; s >= 1; s >>= 1) {
    for (int p = 0; p < 2; ++p) {
      int skip = 0;
      do {
        if (xqd[p] - s >= tap_min[p]) {
          xqd[p] -= s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
243 244
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
          if (err2 > err) {
            xqd[p] += s;
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
        if (xqd[p] + s <= tap_max[p]) {
          xqd[p] += s;
          err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
261 262
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
          if (err2 > err) {
            xqd[p] -= s;
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
    }
  }
#endif  // USE_SGRPROJ_REFINEMENT_SEARCH
  return err;
}

279
static void get_proj_subspace(const uint8_t *src8, int width, int height,
280 281 282 283
                              int src_stride, const uint8_t *dat8,
                              int dat_stride, int use_highbitdepth,
                              int32_t *flt1, int flt1_stride, int32_t *flt2,
                              int flt2_stride, int *xq) {
284 285 286 287 288 289 290
  int i, j;
  double H[2][2] = { { 0, 0 }, { 0, 0 } };
  double C[2] = { 0, 0 };
  double Det;
  double x[2];
  const int size = width * height;

291 292
  aom_clear_system_state();

293 294 295
  // Default
  xq[0] = 0;
  xq[1] = 0;
296
  if (!use_highbitdepth) {
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
    const uint8_t *src = src8;
    const uint8_t *dat = dat8;
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    const uint16_t *dat = CONVERT_TO_SHORTPTR(dat8);
    for (i = 0; i < height; ++i) {
      for (j = 0; j < width; ++j) {
        const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
        const double s =
            (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
        const double f1 = (double)flt1[i * flt1_stride + j] - u;
        const double f2 = (double)flt2[i * flt2_stride + j] - u;
        H[0][0] += f1 * f1;
        H[1][1] += f2 * f2;
        H[0][1] += f1 * f2;
        C[0] += f1 * s;
        C[1] += f2 * s;
      }
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
    }
  }
  H[0][0] /= size;
  H[0][1] /= size;
  H[1][1] /= size;
  H[1][0] = H[0][1];
  C[0] /= size;
  C[1] /= size;
  Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
  if (Det < 1e-8) return;  // ill-posed, return default values
  x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
  x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
  xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
  xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
}

void encode_xq(int *xq, int *xqd) {
346
  xqd[0] = xq[0];
347
  xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
348
  xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
349 350 351
  xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
}

352 353 354 355
// Runs the self-guided filter on a single block, producing both filter
// outputs in flt1 and flt2. Calls the av1_selfguided_restoration_c variant
// when CONFIG_FAST_SGR is set and av1_selfguided_restoration otherwise.
static void sgr_filter_block(const sgr_params_type *params, const uint8_t *dat8,
                             int width, int height, int dat_stride,
                             int use_highbd, int bit_depth, int32_t *flt1,
                             int32_t *flt2, int flt_stride) {
#if !CONFIG_FAST_SGR
  av1_selfguided_restoration(dat8, width, height, dat_stride, flt1, flt2,
                             flt_stride, params, bit_depth, use_highbd);
#else
  av1_selfguided_restoration_c(dat8, width, height, dat_stride, flt1, flt2,
                               flt_stride, params, bit_depth, use_highbd);
#endif  // !CONFIG_FAST_SGR
}

// Apply the self-guided filter across an entire restoration unit, one
// processing block of at most pu_width x pu_height samples at a time. Blocks
// on the right and bottom edges are shortened to fit the unit.
static void apply_sgr(const sgr_params_type *params, const uint8_t *dat8,
                      int width, int height, int dat_stride, int use_highbd,
                      int bit_depth, int pu_width, int pu_height, int32_t *flt1,
                      int32_t *flt2, int flt_stride) {
  for (int row = 0; row < height; row += pu_height) {
    const int rows_left = height - row;
    const int block_h = AOMMIN(pu_height, rows_left);
    const uint8_t *const dat_row = dat8 + row * dat_stride;
    int32_t *const out1_row = flt1 + row * flt_stride;
    int32_t *const out2_row = flt2 + row * flt_stride;

    // Walk the stripe in blocks of width pu_width.
    for (int col = 0; col < width; col += pu_width) {
      const int cols_left = width - col;
      const int block_w = AOMMIN(pu_width, cols_left);
      sgr_filter_block(params, dat_row + col, block_w, block_h, dat_stride,
                       use_highbd, bit_depth, out1_row + col, out2_row + col,
                       flt_stride);
    }
  }
}

385
static SgrprojInfo search_selfguided_restoration(
386 387 388
    const uint8_t *dat8, int width, int height, int dat_stride,
    const uint8_t *src8, int src_stride, int use_highbitdepth, int bit_depth,
    int pu_width, int pu_height, int32_t *rstbuf) {
389
  int32_t *flt1 = rstbuf;
390
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
391
  int ep, bestep = 0;
392
  int64_t besterr = -1;
393
  int exqd[2], bestxqd[2] = { 0, 0 };
394
  int flt_stride = ((width + 7) & ~7) + 8;
395 396 397 398
  assert(pu_width == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_width == RESTORATION_PROC_UNIT_SIZE);
  assert(pu_height == (RESTORATION_PROC_UNIT_SIZE >> 1) ||
         pu_height == RESTORATION_PROC_UNIT_SIZE);
399

400 401
  for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
    int exq[2];
402 403 404
    apply_sgr(&sgr_params[ep], dat8, width, height, dat_stride,
              use_highbitdepth, bit_depth, pu_width, pu_height, flt1, flt2,
              flt_stride);
405
    aom_clear_system_state();
406
    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
407
                      use_highbitdepth, flt1, flt_stride, flt2, flt_stride,
408
                      exq);
409
    aom_clear_system_state();
410
    encode_xq(exq, exqd);
411
    int64_t err = finer_search_pixel_proj_error(
412
        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
413
        flt1, flt_stride, flt2, flt_stride, 2, exqd);
414 415 416 417 418 419 420
    if (besterr == -1 || err < besterr) {
      bestep = ep;
      besterr = err;
      bestxqd[0] = exqd[0];
      bestxqd[1] = exqd[1];
    }
  }
421 422 423 424 425 426

  SgrprojInfo ret;
  ret.ep = bestep;
  ret.xqd[0] = bestxqd[0];
  ret.xqd[1] = bestxqd[1];
  return ret;
427 428
}

429 430 431 432 433 434 435 436 437 438 439 440 441 442
// Returns the number of bits needed to code sgrproj_info given
// ref_sgrproj_info as the reference: SGRPROJ_PARAMS_BITS plus one
// reference-based sub-exponential code per projection coefficient.
static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                              SgrprojInfo *ref_sgrproj_info) {
  // Both the reference and the coded value are offset so the range starts at 0.
  const int range0 = SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1;
  const int ref0 = ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0;
  const int val0 = sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0;

  const int range1 = SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1;
  const int ref1 = ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1;
  const int val1 = sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1;

  int total = SGRPROJ_PARAMS_BITS;
  total += aom_count_primitive_refsubexpfin(range0, SGRPROJ_PRJ_SUBEXP_K, ref0,
                                            val0);
  total += aom_count_primitive_refsubexpfin(range1, SGRPROJ_PRJ_SUBEXP_K, ref1,
                                            val1);
  return total;
}

443
static void search_sgrproj(const RestorationTileLimits *limits,
444 445
                           const AV1PixelRect *tile, int rest_unit_idx,
                           void *priv) {
446 447 448 449 450
  RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
  RestUnitSearchInfo *rusi = &rsc->rusi[rest_unit_idx];

  const MACROBLOCK *const x = rsc->x;
  const AV1_COMMON *const cm = rsc->cm;
451 452 453
  const int highbd = cm->use_highbitdepth;
  const int bit_depth = cm->bit_depth;

454
  uint8_t *dgd_start =
455
      rsc->dgd_buffer + limits->v_start * rsc->dgd_stride + limits->h_start;
456
  const uint8_t *src_start =
457
      rsc->src_buffer + limits->v_start * rsc->src_stride + limits->h_start;
458

459 460 461 462 463 464
  const int is_uv = rsc->plane > 0;
  const int ss_x = is_uv && cm->subsampling_x;
  const int ss_y = is_uv && cm->subsampling_y;
  const int procunit_width = RESTORATION_PROC_UNIT_SIZE >> ss_x;
  const int procunit_height = RESTORATION_PROC_UNIT_SIZE >> ss_y;

465
  rusi->sgrproj = search_selfguided_restoration(
466
      dgd_start, limits->h_end - limits->h_start,
467
      limits->v_end - limits->v_start, rsc->dgd_stride, src_start,
468 469
      rsc->src_stride, highbd, bit_depth, procunit_width, procunit_height,
      cm->rst_tmpbuf);
470 471 472 473 474

  RestorationUnitInfo rui;
  rui.restoration_type = RESTORE_SGRPROJ;
  rui.sgrproj_info = rusi->sgrproj;

475
  rusi->sse[RESTORE_SGRPROJ] = try_restoration_tile(rsc, limits, tile, &rui);
476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493

  const int64_t bits_none = x->sgrproj_restore_cost[0];
  const int64_t bits_sgr = x->sgrproj_restore_cost[1] +
                           (count_sgrproj_bits(&rusi->sgrproj, &rsc->sgrproj)
                            << AV1_PROB_COST_SHIFT);

  double cost_none =
      RDCOST_DBL(x->rdmult, bits_none >> 4, rusi->sse[RESTORE_NONE]);
  double cost_sgr =
      RDCOST_DBL(x->rdmult, bits_sgr >> 4, rusi->sse[RESTORE_SGRPROJ]);

  RestorationType rtype =
      (cost_sgr < cost_none) ? RESTORE_SGRPROJ : RESTORE_NONE;
  rusi->best_rtype[RESTORE_SGRPROJ - 1] = rtype;

  rsc->sse += rusi->sse[rtype];
  rsc->bits += (cost_sgr < cost_none) ? bits_sgr : bits_none;
  if (cost_sgr < cost_none) rsc->sgrproj = rusi->sgrproj;
494 495
}

496 497
// Returns the mean of the 8-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src, where stride is the distance between rows.
static double find_average(const uint8_t *src, int h_start, int h_end,
                           int v_start, int v_end, int stride) {
  uint64_t total = 0;

  aom_clear_system_state();
  for (int row = v_start; row < v_end; ++row) {
    const int row_offset = row * stride;
    for (int col = h_start; col < h_end; ++col) {
      total += src[row_offset + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

508 509 510 511
static void compute_stats(int wiener_win, const uint8_t *dgd,
                          const uint8_t *src, int h_start, int h_end,
                          int v_start, int v_end, int dgd_stride,
                          int src_stride, double *M, double *H) {
512
  int i, j, k, l;
513
  double Y[WIENER_WIN2];
514 515
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
516 517
  const double avg =
      find_average(dgd, h_start, h_end, v_start, v_end, dgd_stride);
518

519 520
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
521 522
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
523 524
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
525 526
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
527 528 529 530
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
531
      assert(idx == wiener_win2);
532
      for (k = 0; k < wiener_win2; ++k) {
533
        M[k] += Y[k] * X;
534 535
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
536 537 538
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
539
          H[k * wiener_win2 + l] += Y[k] * Y[l];
540 541 542 543
        }
      }
    }
  }
544 545 546
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
547 548
    }
  }
549 550
}

551
// Returns the mean of the 16-bit samples in rows [v_start, v_end) and columns
// [h_start, h_end) of src, where stride is the distance between rows.
static double find_average_highbd(const uint16_t *src, int h_start, int h_end,
                                  int v_start, int v_end, int stride) {
  uint64_t total = 0;

  aom_clear_system_state();
  for (int row = v_start; row < v_end; ++row) {
    const int row_offset = row * stride;
    for (int col = h_start; col < h_end; ++col) {
      total += src[row_offset + col];
    }
  }
  return (double)total / ((v_end - v_start) * (h_end - h_start));
}

563 564 565 566
static void compute_stats_highbd(int wiener_win, const uint8_t *dgd8,
                                 const uint8_t *src8, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, double *M, double *H) {
567
  int i, j, k, l;
568
  double Y[WIENER_WIN2];
569 570
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = (wiener_win >> 1);
571 572
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
573 574
  const double avg =
      find_average_highbd(dgd, h_start, h_end, v_start, v_end, dgd_stride);
575

576 577
  memset(M, 0, sizeof(*M) * wiener_win2);
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
578 579
  for (i = v_start; i < v_end; i++) {
    for (j = h_start; j < h_end; j++) {
580 581
      const double X = (double)src[i * src_stride + j] - avg;
      int idx = 0;
582 583
      for (k = -wiener_halfwin; k <= wiener_halfwin; k++) {
        for (l = -wiener_halfwin; l <= wiener_halfwin; l++) {
584 585 586 587
          Y[idx] = (double)dgd[(i + l) * dgd_stride + (j + k)] - avg;
          idx++;
        }
      }
588
      assert(idx == wiener_win2);
589
      for (k = 0; k < wiener_win2; ++k) {
590
        M[k] += Y[k] * X;
591 592
        H[k * wiener_win2 + k] += Y[k] * Y[k];
        for (l = k + 1; l < wiener_win2; ++l) {
593 594 595
          // H is a symmetric matrix, so we only need to fill out the upper
          // triangle here. We can copy it down to the lower triangle outside
          // the (i, j) loops.
596
          H[k * wiener_win2 + l] += Y[k] * Y[l];
597 598 599 600
        }
      }
    }
  }
601 602 603
  for (k = 0; k < wiener_win2; ++k) {
    for (l = k + 1; l < wiener_win2; ++l) {
      H[l * wiener_win2 + k] = H[k * wiener_win2 + l];
604 605
    }
  }
606 607
}

608 609 610
// Maps a tap index of a symmetric wiener_win-tap filter onto its first half:
// indices up to and including the centre map to themselves, later ones to
// their mirror image (wiener_win - 1 - i).
static INLINE int wrap_index(int i, int wiener_win) {
  const int first_mirrored = (wiener_win >> 1) + 1;
  if (i < first_mirrored) return i;
  return wiener_win - 1 - i;
}

// Fix vector b, update vector a
614 615
static void update_a_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
616
  int i, j;
617
  double S[WIENER_WIN];
618
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
619 620
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
621 622
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
623 624 625
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; ++j) {
      const int jj = wrap_index(j, wiener_win);
626 627 628
      A[jj] += Mc[i][j] * b[i];
    }
  }
629 630
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
631
      int k, l;
632 633 634 635 636 637
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l) {
          const int kk = wrap_index(k, wiener_win);
          const int ll = wrap_index(l, wiener_win);
          B[ll * wiener_halfwin1 + kk] +=
              Hc[j * wiener_win + i][k * wiener_win2 + l] * b[i] * b[j];
638 639 640
        }
    }
  }
Aamir Anis's avatar
Aamir Anis committed
641
  // Normalization enforcement in the system of equations itself
642
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
643
    A[i] -=
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
659
    }
660
    memcpy(a, S, wiener_win * sizeof(*a));
661 662 663 664
  }
}

// Fix vector a, update vector b
665 666
static void update_b_sep_sym(int wiener_win, double **Mc, double **Hc,
                             double *a, double *b) {
667
  int i, j;
668
  double S[WIENER_WIN];
669
  double A[WIENER_HALFWIN1], B[WIENER_HALFWIN1 * WIENER_HALFWIN1];
670 671
  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin1 = (wiener_win >> 1) + 1;
672 673
  memset(A, 0, sizeof(A));
  memset(B, 0, sizeof(B));
674 675 676
  for (i = 0; i < wiener_win; i++) {
    const int ii = wrap_index(i, wiener_win);
    for (j = 0; j < wiener_win; j++) A[ii] += Mc[i][j] * a[j];
677 678
  }

679 680 681 682
  for (i = 0; i < wiener_win; i++) {
    for (j = 0; j < wiener_win; j++) {
      const int ii = wrap_index(i, wiener_win);
      const int jj = wrap_index(j, wiener_win);
683
      int k, l;
684 685 686 687
      for (k = 0; k < wiener_win; ++k)
        for (l = 0; l < wiener_win; ++l)
          B[jj * wiener_halfwin1 + ii] +=
              Hc[i * wiener_win + j][k * wiener_win2 + l] * a[k] * a[l];
688 689
    }
  }
Aamir Anis's avatar
Aamir Anis committed
690
  // Normalization enforcement in the system of equations itself
691
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
692
    A[i] -=
693 694 695 696 697 698 699 700 701 702 703 704 705 706 707
        A[wiener_halfwin1 - 1] * 2 +
        B[i * wiener_halfwin1 + wiener_halfwin1 - 1] -
        2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 + (wiener_halfwin1 - 1)];
  for (i = 0; i < wiener_halfwin1 - 1; ++i)
    for (j = 0; j < wiener_halfwin1 - 1; ++j)
      B[i * wiener_halfwin1 + j] -=
          2 * (B[i * wiener_halfwin1 + (wiener_halfwin1 - 1)] +
               B[(wiener_halfwin1 - 1) * wiener_halfwin1 + j] -
               2 * B[(wiener_halfwin1 - 1) * wiener_halfwin1 +
                     (wiener_halfwin1 - 1)]);
  if (linsolve(wiener_halfwin1 - 1, B, wiener_halfwin1, A, S)) {
    S[wiener_halfwin1 - 1] = 1.0;
    for (i = wiener_halfwin1; i < wiener_win; ++i) {
      S[i] = S[wiener_win - 1 - i];
      S[wiener_halfwin1 - 1] -= 2 * S[i];
708
    }
709
    memcpy(b, S, wiener_win * sizeof(*b));
710 711 712
  }
}

713 714
static int wiener_decompose_sep_sym(int wiener_win, double *M, double *H,
                                    double *a, double *b) {
715 716 717 718
  static const int init_filt[WIENER_WIN] = {
    WIENER_FILT_TAP0_MIDV, WIENER_FILT_TAP1_MIDV, WIENER_FILT_TAP2_MIDV,
    WIENER_FILT_TAP3_MIDV, WIENER_FILT_TAP2_MIDV, WIENER_FILT_TAP1_MIDV,
    WIENER_FILT_TAP0_MIDV,
719
  };
720 721
  double *Hc[WIENER_WIN2];
  double *Mc[WIENER_WIN];
722 723 724 725 726
  int i, j, iter;
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
  for (i = 0; i < wiener_win; i++) {
    a[i] = b[i] = (double)init_filt[i + plane_off] / WIENER_FILT_STEP;
727
  }
728 729 730 731 732 733
  for (i = 0; i < wiener_win; i++) {
    Mc[i] = M + i * wiener_win;
    for (j = 0; j < wiener_win; j++) {
      Hc[i * wiener_win + j] =
          H + i * wiener_win * wiener_win2 + j * wiener_win;
    }
734
  }
735 736

  iter = 1;
737
  while (iter < NUM_WIENER_ITERS) {
738 739
    update_a_sep_sym(wiener_win, Mc, Hc, a, b);
    update_b_sep_sym(wiener_win, Mc, Hc, a, b);
740 741
    iter++;
  }
742
  return 1;
743 744
}

745
// Computes the function x'*H*x - x'*M for the learned 2D filter x, and compares
Aamir Anis's avatar
Aamir Anis committed
746 747
// against identity filters; Final score is defined as the difference between
// the function values
748 749
static double compute_score(int wiener_win, double *M, double *H,
                            InterpKernel vfilt, InterpKernel hfilt) {
750
  double ab[WIENER_WIN * WIENER_WIN];
Aamir Anis's avatar
Aamir Anis committed
751 752 753 754
  int i, k, l;
  double P = 0, Q = 0;
  double iP = 0, iQ = 0;
  double Score, iScore;
755
  double a[WIENER_WIN], b[WIENER_WIN];
756 757
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
  const int wiener_win2 = wiener_win * wiener_win;
758 759 760

  aom_clear_system_state();

761 762 763 764 765 766
  a[WIENER_HALFWIN] = b[WIENER_HALFWIN] = 1.0;
  for (i = 0; i < WIENER_HALFWIN; ++i) {
    a[i] = a[WIENER_WIN - i - 1] = (double)vfilt[i] / WIENER_FILT_STEP;
    b[i] = b[WIENER_WIN - i - 1] = (double)hfilt[i] / WIENER_FILT_STEP;
    a[WIENER_HALFWIN] -= 2 * a[i];
    b[WIENER_HALFWIN] -= 2 * b[i];
Aamir Anis's avatar
Aamir Anis committed
767
  }
768
  memset(ab, 0, sizeof(ab));
769 770 771
  for (k = 0; k < wiener_win; ++k) {
    for (l = 0; l < wiener_win; ++l)
      ab[k * wiener_win + l] = a[l + plane_off] * b[k + plane_off];
Aamir Anis's avatar
Aamir Anis committed
772
  }
773
  for (k = 0; k < wiener_win2; ++k) {
Aamir Anis's avatar
Aamir Anis committed
774
    P += ab[k] * M[k];
775 776
    for (l = 0; l < wiener_win2; ++l)
      Q += ab[k] * H[k * wiener_win2 + l] * ab[l];
Aamir Anis's avatar
Aamir Anis committed
777 778 779
  }
  Score = Q - 2 * P;

780 781
  iP = M[wiener_win2 >> 1];
  iQ = H[(wiener_win2 >> 1) * wiener_win2 + (wiener_win2 >> 1)];
Aamir Anis's avatar
Aamir Anis committed
782 783 784 785 786
  iScore = iQ - 2 * iP;

  return Score - iScore;
}

787
// Quantizes the symmetric real-valued filter f into the integer kernel fi.
//
// The first (wiener_win >> 1) taps of f are scaled by WIENER_FILT_STEP and
// rounded with RINT. For the full WIENER_WIN-tap case the three taps are
// clipped in place to their per-tap ranges. Otherwise the two rounded taps
// are moved one slot towards the centre (clipped to the TAP1/TAP2 ranges) and
// the outermost tap is forced to zero. The kernel is then mirrored and the
// centre tap is derived from the side taps.
static void quantize_sym_filter(int wiener_win, double *f, InterpKernel fi) {
  const int num_side_taps = wiener_win >> 1;

  for (int t = 0; t < num_side_taps; ++t) {
    fi[t] = RINT(f[t] * WIENER_FILT_STEP);
  }

  if (wiener_win != WIENER_WIN) {
    // Reduced window: shift inwards, highest index first so that fi[1] is
    // consumed before it is overwritten.
    fi[2] = CLIP(fi[1], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
    fi[1] = CLIP(fi[0], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[0] = 0;
  } else {
    // Full 7-tap window: clip each tap against its own range.
    fi[0] = CLIP(fi[0], WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP0_MAXV);
    fi[1] = CLIP(fi[1], WIENER_FILT_TAP1_MINV, WIENER_FILT_TAP1_MAXV);
    fi[2] = CLIP(fi[2], WIENER_FILT_TAP2_MINV, WIENER_FILT_TAP2_MAXV);
  }

  // The centre tap carries an implicit +WIENER_FILT_STEP, so only the negated
  // sum of the (doubled) side taps is stored.
  fi[3] = -2 * (fi[0] + fi[1] + fi[2]);

  // Mirror the side taps to enforce symmetry.
  fi[WIENER_WIN - 3] = fi[2];
  fi[WIENER_WIN - 2] = fi[1];
  fi[WIENER_WIN - 1] = fi[0];
}

811
// Returns the sum of aom_count_primitive_refsubexpfin() over the side taps of
// the vertical and then the horizontal filter in wiener_info, each counted
// relative to the matching tap in ref_wiener_info.
//
// Taps 1 and 2 are always counted; tap 0 is counted only when wiener_win
// equals WIENER_WIN. Both the value and its reference are offset by the tap's
// minimum before being passed on.
static int count_wiener_bits(int wiener_win, WienerInfo *wiener_info,
                             WienerInfo *ref_wiener_info) {
  // Per-tap parameters, indexed by tap position.
  const int tap_lo[] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                         WIENER_FILT_TAP2_MINV };
  const int tap_span[] = { WIENER_FILT_TAP0_MAXV - WIENER_FILT_TAP0_MINV + 1,
                           WIENER_FILT_TAP1_MAXV - WIENER_FILT_TAP1_MINV + 1,
                           WIENER_FILT_TAP2_MAXV - WIENER_FILT_TAP2_MINV + 1 };
  const int tap_k[] = { WIENER_FILT_TAP0_SUBEXP_K, WIENER_FILT_TAP1_SUBEXP_K,
                        WIENER_FILT_TAP2_SUBEXP_K };
  const int num_taps = (int)(sizeof(tap_lo) / sizeof(tap_lo[0]));
  // Tap 0 is only included for the full-size window.
  const int first_tap = (wiener_win == WIENER_WIN) ? 0 : 1;
  int total = 0;

  for (int p = first_tap; p < num_taps; ++p) {
    total += aom_count_primitive_refsubexpfin(
        tap_span[p], tap_k[p], ref_wiener_info->vfilter[p] - tap_lo[p],
        wiener_info->vfilter[p] - tap_lo[p]);
  }
  for (int p = first_tap; p < num_taps; ++p) {
    total += aom_count_primitive_refsubexpfin(
        tap_span[p], tap_k[p], ref_wiener_info->hfilter[p] - tap_lo[p],
        wiener_info->hfilter[p] - tap_lo[p]);
  }
  return total;
}

849
// Compile-time switch: the step-wise filter tap refinement in
// finer_tile_search_wiener() is guarded by #if USE_WIENER_REFINEMENT_SEARCH.
#define USE_WIENER_REFINEMENT_SEARCH 1
850 851 852 853 854
static int64_t finer_tile_search_wiener(const RestSearchCtxt *rsc,
                                        const RestorationTileLimits *limits,
                                        const AV1PixelRect *tile,
                                        RestorationUnitInfo *rui,
                                        int wiener_win) {
855
  const int plane_off = (WIENER_WIN - wiener_win) >> 1;
856
  int64_t err = try_restoration_tile(rsc, limits, tile, rui);
857 858 859 860 861 862
#if USE_WIENER_REFINEMENT_SEARCH
  int64_t err2;
  int tap_min[] = { WIENER_FILT_TAP0_MINV, WIENER_FILT_TAP1_MINV,
                    WIENER_FILT_TAP2_MINV };
  int tap_max[] = { WIENER_FILT_TAP0_MAXV, WIENER_FILT_TAP1_MAXV,
                    WIENER_FILT_TAP2_MAXV };
863

864
  WienerInfo *plane_wiener = &rui->wiener_info;
865

866
  // printf("err  pre = %"PRId64"\n", err);
867
  const int start_step = 4;
868
  for (int s = start_step; s >= 1; s >>= 1) {
869
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
870 871
      int skip = 0;
      do {
872 873 874 875
        if (plane_wiener->hfilter[p] - s >= tap_min[p]) {
          plane_wiener->hfilter[p] -= s;
          plane_wiener->hfilter[WIENER_WIN - p - 1] -= s;
          plane_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
876
          err2 = try_restoration_tile(rsc, limits, tile, rui);
877
          if (err2 > err) {
878 879 880
            plane_wiener->hfilter[p] += s;
            plane_wiener->hfilter[WIENER_WIN - p - 1] += s;
            plane_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
881 882 883 884 885 886 887 888 889 890 891
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
892 893 894 895
        if (plane_wiener->hfilter[p] + s <= tap_max[p]) {
          plane_wiener->hfilter[p] += s;
          plane_wiener->hfilter[WIENER_WIN - p - 1] += s;
          plane_wiener->hfilter[WIENER_HALFWIN] -= 2 * s;
896
          err2 = try_restoration_tile(rsc, limits, tile, rui);
897
          if (err2 > err) {
898 899 900
            plane_wiener->hfilter[p] -= s;
            plane_wiener->hfilter[WIENER_WIN - p - 1] -= s;
            plane_wiener->hfilter[WIENER_HALFWIN] += 2 * s;
901 902 903 904 905 906 907 908
          } else {
            err = err2;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
909
    }
910
    for (int p = plane_off; p < WIENER_HALFWIN; ++p) {
911 912
      int skip = 0;
      do {
913 914 915 916
        if (plane_wiener->vfilter[p] - s >= tap_min[p]) {
          plane_wiener->vfilter[p] -= s;
          plane_wiener->vfilter[WIENER_WIN - p - 1] -= s;
          plane_wiener->vfilter[WIENER_HALFWIN] += 2 * s;
917
          err2 = try_restoration_tile(rsc, limits, tile, rui);
918
          if (err2 > err) {
919 920 921
            plane_wiener->vfilter[p] += s;
            plane_wiener->vfilter[WIENER_WIN - p - 1] += s;
            plane_wiener->vfilter[WIENER_HALFWIN] -= 2 * s;
922 923 924 925 926 927 928 929 930 931 932
          } else {
            err = err2;
            skip = 1;
            // At the highest step size continue moving in the same direction
            if (s == start_step) continue;
          }
        }
        break;
      } while (1);
      if (skip) break;
      do {
933 934 935 936
        if (plane_wiener->vfilter[p] + s <= tap_max[p]) {
          plane_wiener->vfilter[p