restoration.c 57.6 KB
Newer Older
1
/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 *
 */

#include <math.h>

Yaowu Xu's avatar
Yaowu Xu committed
15 16
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
17
#include "./aom_scale_rtcd.h"
18 19
#include "av1/common/onyxc_int.h"
#include "av1/common/restoration.h"
Yaowu Xu's avatar
Yaowu Xu committed
20 21
#include "aom_dsp/aom_dsp_common.h"
#include "aom_mem/aom_mem.h"
22
#include "aom_ports/mem.h"
23

24 25 26
/* Lookup table filled by GenDomainTxfmRFVtable(), indexed as
   [iteration][parameter index][0..255]. */
static int domaintxfmrf_vtable[DOMAINTXFMRF_ITERS][DOMAINTXFMRF_PARAMS][256];

/* Parameter values for the domain-transform filter. GenDomainTxfmRFVtable()
   divides entry j by DOMAINTXFMRF_SIGMA_SCALE to obtain sigma_r for the
   j-th column of domaintxfmrf_vtable. */
static const int domaintxfmrf_params[DOMAINTXFMRF_PARAMS] = {
  32,  40,  48,  56,  64,  68,  72,  76,
  80,  82,  84,  86,  88,  90,  92,  94,
  96,  97,  98,  99,  100, 101, 102, 103,
  104, 105, 106, 107, 108, 109, 110, 111,
  112, 113, 114, 115, 116, 117, 118, 119,
  120, 121, 122, 123, 124, 125, 126, 127,
  128, 130, 132, 134, 136, 138, 140, 142,
  146, 150, 154, 158, 162, 166, 170, 174,
};

34 35
/* Self-guided filter parameter sets. apply_selfguided_restoration() selects
   one with the per-tile `ep` index (sgr_params[ep]). Fields, in declaration
   order: r1, eps1, r2, eps2, i.e. the box radius and eps used by the first
   and by the second self-guided pass. */
const sgr_params_type sgr_params[SGRPROJ_PARAMS] = {
  /* 0 */ { 2, 25, 1, 11 },
  /* 1 */ { 2, 35, 1, 12 },
  /* 2 */ { 2, 45, 1, 13 },
  /* 3 */ { 2, 55, 1, 14 },
  /* 4 */ { 2, 65, 1, 15 },
  /* 5 */ { 3, 50, 2, 25 },
  /* 6 */ { 3, 60, 2, 35 },
  /* 7 */ { 3, 70, 2, 45 },
};

clang-format's avatar
clang-format committed
40 41
/* Signature shared by the whole-frame 8-bit restoration filters in this file
   (loop_wiener_filter, loop_sgrproj_filter): the input region is described
   by data8/width/height/stride, per-frame state comes from rst, and output
   goes to dst8 with row pitch dst_stride. */
typedef void (*restore_func_type)(uint8_t *data8,
                                  int width,
                                  int height,
                                  int stride,
                                  RestorationInternal *rst,
                                  uint8_t *dst8,
                                  int dst_stride);

#if CONFIG_AOM_HIGHBITDEPTH
/* High-bit-depth variant of restore_func_type; bit_depth is passed
   explicitly between rst and dst8. */
typedef void (*restore_func_highbd_type)(uint8_t *data8,
                                         int width,
                                         int height,
                                         int stride,
                                         RestorationInternal *rst,
                                         int bit_depth,
                                         uint8_t *dst8,
                                         int dst_stride);
#endif  // CONFIG_AOM_HIGHBITDEPTH
49

50 51 52 53 54 55 56 57
/* (Re)allocate the per-tile arrays of rst_info for a width x height frame.

   Each array is resized with aom_realloc(), so an existing allocation is
   reused/resized rather than leaked. Only wiener_info is zero-filled; the
   contents of restoration_type, sgrproj_info and domaintxfmrf_info are left
   as aom_realloc() returns them.

   Returns the number of restoration tiles reported by av1_get_rest_ntiles(). */
int av1_alloc_restoration_struct(RestorationInfo *rst_info, int width,
                                 int height) {
  const int ntiles = av1_get_rest_ntiles(width, height, NULL, NULL, NULL, NULL);
  rst_info->restoration_type = (RestorationType *)aom_realloc(
      rst_info->restoration_type, sizeof(*rst_info->restoration_type) * ntiles);
  // Check this array like the others below: every one of them is later
  // indexed by tile, so a NULL here is just as fatal.
  assert(rst_info->restoration_type != NULL);
  rst_info->wiener_info = (WienerInfo *)aom_realloc(
      rst_info->wiener_info, sizeof(*rst_info->wiener_info) * ntiles);
  assert(rst_info->wiener_info != NULL);
  memset(rst_info->wiener_info, 0, sizeof(*rst_info->wiener_info) * ntiles);
  rst_info->sgrproj_info = (SgrprojInfo *)aom_realloc(
      rst_info->sgrproj_info, sizeof(*rst_info->sgrproj_info) * ntiles);
  assert(rst_info->sgrproj_info != NULL);
  rst_info->domaintxfmrf_info = (DomaintxfmrfInfo *)aom_realloc(
      rst_info->domaintxfmrf_info,
      sizeof(*rst_info->domaintxfmrf_info) * ntiles);
  assert(rst_info->domaintxfmrf_info != NULL);
  return ntiles;
}

/* Release every per-tile array owned by rst_info and reset the pointers to
   NULL, so a later av1_alloc_restoration_struct() (which calls aom_realloc
   on these pointers) starts from scratch instead of touching freed memory. */
void av1_free_restoration_struct(RestorationInfo *rst_info) {
  aom_free(rst_info->domaintxfmrf_info);
  aom_free(rst_info->sgrproj_info);
  aom_free(rst_info->wiener_info);
  aom_free(rst_info->restoration_type);

  rst_info->domaintxfmrf_info = NULL;
  rst_info->sgrproj_info = NULL;
  rst_info->wiener_info = NULL;
  rst_info->restoration_type = NULL;
}

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
/* Fill domaintxfmrf_vtable.

   For iteration i and parameter index j, entry k is
     RINT(DOMAINTXFMRF_VTABLE_PREC * A_i ^ (1 + k * sigma_s / sigma_r_j))
   where sigma_s = sqrt(2),
         A_i = exp(-DOMAINTXFMRF_MULT / (sigma_s * 2^(DOMAINTXFMRF_ITERS - i - 1)))
   and   sigma_r_j = domaintxfmrf_params[j] / DOMAINTXFMRF_SIGMA_SCALE.

   Declared with an explicit (void) parameter list: an empty () is an
   old-style, non-prototype declarator in C before C23. */
static void GenDomainTxfmRFVtable(void) {
  int i, j;
  const double sigma_s = sqrt(2.0);
  for (i = 0; i < DOMAINTXFMRF_ITERS; ++i) {
    const int nm = (1 << (DOMAINTXFMRF_ITERS - i - 1));
    const double A = exp(-DOMAINTXFMRF_MULT / (sigma_s * nm));
    for (j = 0; j < DOMAINTXFMRF_PARAMS; ++j) {
      const double sigma_r =
          (double)domaintxfmrf_params[j] / DOMAINTXFMRF_SIGMA_SCALE;
      const double scale = sigma_s / sigma_r;
      int k;
      for (k = 0; k < 256; ++k) {
        domaintxfmrf_vtable[i][j][k] =
            RINT(DOMAINTXFMRF_VTABLE_PREC * pow(A, 1.0 + k * scale));
      }
    }
  }
}

99
/* Precompute this module's static lookup tables (currently only
   domaintxfmrf_vtable, via GenDomainTxfmRFVtable()). Uses a (void)
   prototype instead of an old-style empty parameter list. */
void av1_loop_restoration_precal(void) { GenDomainTxfmRFVtable(); }
100

101
/* Record the key-frame flag for the frame about to be restored. */
static void loop_restoration_init(RestorationInternal *rst, int kf) {
  RestorationInternal *const state = rst;
  state->keyframe = kf;
}

105
/* Replicate the edge pixels of a width x height 8-bit plane into a border of
   WIENER_HALFWIN pixels on every side.

   First every row is widened by copying its first/last pixel outward, then
   the widened top row is copied into the WIENER_HALFWIN rows above the plane
   and the widened bottom row into the WIENER_HALFWIN rows below it.
   NOTE(review): writes outside [0, width) x [0, height); the caller's buffer
   must already have that much border around `data` — confirm at call sites. */
void extend_frame(uint8_t *data, int width, int height, int stride) {
  const int padded_width = width + 2 * WIENER_HALFWIN;
  uint8_t *const first_row = data - WIENER_HALFWIN;
  uint8_t *const last_row = first_row + (height - 1) * stride;
  int row;

  // Left and right borders.
  for (row = 0; row < height; ++row) {
    uint8_t *const line = data + row * stride;
    memset(line - WIENER_HALFWIN, line[0], WIENER_HALFWIN);
    memset(line + width, line[width - 1], WIENER_HALFWIN);
  }

  // Top border (including corners) from the widened first row.
  for (row = 1; row <= WIENER_HALFWIN; ++row)
    memcpy(first_row - row * stride, first_row, padded_width);

  // Bottom border (including corners) from the widened last row.
  for (row = 1; row <= WIENER_HALFWIN; ++row)
    memcpy(last_row + row * stride, last_row, padded_width);
}

123 124 125 126
/* Copy the pixels of one restoration (sub)tile from data to dst unchanged.
   The tile rectangle comes from av1_get_rest_tile_limits() using the tile
   geometry stored in rst; each row of the rectangle is copied with memcpy. */
static void loop_copy_tile(uint8_t *data, int tile_idx, int subtile_idx,
                           int subtile_bits, int width, int height, int stride,
                           RestorationInternal *rst, uint8_t *dst,
                           int dst_stride) {
  int h_start, h_end, v_start, v_end;
  int row;

  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, rst->nhtiles,
                           rst->nvtiles, rst->tile_width, rst->tile_height,
                           width, height, 0, 0, &h_start, &h_end, &v_start,
                           &v_end);

  for (row = v_start; row < v_end; ++row) {
    const uint8_t *const src_row = data + row * stride;
    uint8_t *const dst_row = dst + row * dst_stride;
    memcpy(dst_row + h_start, src_row + h_start, h_end - h_start);
  }
}

139 140
static void loop_wiener_filter_tile(uint8_t *data, int tile_idx, int width,
                                    int height, int stride,
141
                                    RestorationInternal *rst, uint8_t *dst,
142
                                    int dst_stride) {
143 144
  const int tile_width = rst->tile_width;
  const int tile_height = rst->tile_height;
145 146
  int i, j;
  int h_start, h_end, v_start, v_end;
147
  if (rst->rsi->restoration_type[tile_idx] == RESTORE_NONE) {
148 149 150 151
    loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
                   dst_stride);
    return;
  }
152
  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
153
                           tile_width, tile_height, width, height, 0, 0,
154
                           &h_start, &h_end, &v_start, &v_end);
155 156 157 158 159 160 161 162
  // Convolve the whole tile (done in blocks here to match the requirements
  // of the vectorized convolve functions, but the result is equivalent)
  for (i = v_start; i < v_end; i += MAX_SB_SIZE)
    for (j = h_start; j < h_end; j += MAX_SB_SIZE) {
      int w = AOMMIN(MAX_SB_SIZE, (h_end - j + 15) & ~15);
      int h = AOMMIN(MAX_SB_SIZE, (v_end - i + 15) & ~15);
      const uint8_t *data_p = data + i * stride + j;
      uint8_t *dst_p = dst + i * dst_stride + j;
163 164 165
      aom_convolve8_add_src(data_p, stride, dst_p, dst_stride,
                            rst->rsi->wiener_info[tile_idx].hfilter, 16,
                            rst->rsi->wiener_info[tile_idx].vfilter, 16, w, h);
166 167 168
    }
}

clang-format's avatar
clang-format committed
169
static void loop_wiener_filter(uint8_t *data, int width, int height, int stride,
170 171 172 173
                               RestorationInternal *rst, uint8_t *dst,
                               int dst_stride) {
  int tile_idx;
  extend_frame(data, width, height, stride);
174
  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
175 176
    loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst, dst,
                            dst_stride);
177
  }
178
}
179

180 181
/* Windowed box sums over an int32 image.

   For every pixel, compute the sum of the source values (sqr == 0) or of
   their squares (sqr != 0) over the (2r + 1) x (2r + 1) window centred on
   it. The window is clipped at the image border, so border pixels sum fewer
   terms. boxsum1/2/3 specialise r = 1, 2, 3; boxsumr() handles any other r.

   Each specialised version works in two passes: a vertical pass from src
   into dst, followed by an in-place horizontal pass over dst. Both passes
   keep a sliding window of 2r + 1 values in local variables. The image must
   be at least 2r + 1 pixels wide and tall.
*/
static void boxsum1(int32_t *src, int width, int height, int src_stride,
                    int sqr, int32_t *dst, int dst_stride) {
  int row, col;

  // Vertical pass: dst(row, col) = sum of src(row-1..row+1, col), squared
  // first when sqr is set, clipped to [0, height).
  for (col = 0; col < width; ++col) {
    const int32_t *const s = src + col;
    int32_t *const d = dst + col;
    int above = s[0];
    int here = s[src_stride];
    int below = s[2 * src_stride];
    if (sqr) {
      above *= above;
      here *= here;
      below *= below;
    }
    d[0] = above + here;
    for (row = 1; row < height - 2; ++row) {
      // above/here/below hold rows row-1, row, row+1.
      d[row * dst_stride] = above + here + below;
      above = here;
      here = below;
      below = s[(row + 2) * src_stride];
      if (sqr) below *= below;
    }
    d[row * dst_stride] = above + here + below;
    d[(row + 1) * dst_stride] = here + below;
  }

  // Horizontal pass, in place on dst. Each new value is read two columns
  // ahead of the column being written, so the window only ever sees
  // vertical-pass results.
  for (row = 0; row < height; ++row) {
    int32_t *const line = dst + row * dst_stride;
    int left = line[0];
    int centre = line[1];
    int right = line[2];
    line[0] = left + centre;
    for (col = 1; col < width - 2; ++col) {
      // left/centre/right hold dst columns col-1, col, col+1.
      line[col] = left + centre + right;
      left = centre;
      centre = right;
      right = line[col + 2];
    }
    line[col] = left + centre + right;
    line[col + 1] = centre + right;
  }
}

/* Radius-2 (5 x 5 window) box sum of src, or of src squared when sqr is
   set, written to dst. Windows are clipped at the image border; width and
   height must both be at least 5. */
static void boxsum2(int32_t *src, int width, int height, int src_stride,
                    int sqr, int32_t *dst, int dst_stride) {
  int row, col;

  // Vertical pass: five-row sliding window t0..t4 per column.
  for (col = 0; col < width; ++col) {
    const int32_t *const s = src + col;
    int32_t *const d = dst + col;
    int t0 = s[0];
    int t1 = s[src_stride];
    int t2 = s[2 * src_stride];
    int t3 = s[3 * src_stride];
    int t4 = s[4 * src_stride];
    if (sqr) {
      t0 *= t0;
      t1 *= t1;
      t2 *= t2;
      t3 *= t3;
      t4 *= t4;
    }
    d[0] = t0 + t1 + t2;
    d[dst_stride] = t0 + t1 + t2 + t3;
    for (row = 2; row < height - 3; ++row) {
      // t0..t4 hold rows row-2 .. row+2.
      d[row * dst_stride] = t0 + t1 + t2 + t3 + t4;
      t0 = t1;
      t1 = t2;
      t2 = t3;
      t3 = t4;
      t4 = s[(row + 3) * src_stride];
      if (sqr) t4 *= t4;
    }
    d[row * dst_stride] = t0 + t1 + t2 + t3 + t4;
    d[(row + 1) * dst_stride] = t1 + t2 + t3 + t4;
    d[(row + 2) * dst_stride] = t2 + t3 + t4;
  }

  // Horizontal pass, in place on dst; values are read three columns ahead of
  // the column being written.
  for (row = 0; row < height; ++row) {
    int32_t *const line = dst + row * dst_stride;
    int t0 = line[0];
    int t1 = line[1];
    int t2 = line[2];
    int t3 = line[3];
    int t4 = line[4];
    line[0] = t0 + t1 + t2;
    line[1] = t0 + t1 + t2 + t3;
    for (col = 2; col < width - 3; ++col) {
      // t0..t4 hold dst columns col-2 .. col+2.
      line[col] = t0 + t1 + t2 + t3 + t4;
      t0 = t1;
      t1 = t2;
      t2 = t3;
      t3 = t4;
      t4 = line[col + 3];
    }
    line[col] = t0 + t1 + t2 + t3 + t4;
    line[col + 1] = t1 + t2 + t3 + t4;
    line[col + 2] = t2 + t3 + t4;
  }
}

339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
/* Radius-3 (7 x 7 window) box sum of src, or of src squared when sqr is
   set, written to dst. Windows are clipped at the image border; width and
   height must both be at least 7. */
static void boxsum3(int32_t *src, int width, int height, int src_stride,
                    int sqr, int32_t *dst, int dst_stride) {
  int row, col;

  // Vertical pass: seven-row sliding window t0..t6 per column.
  for (col = 0; col < width; ++col) {
    const int32_t *const s = src + col;
    int32_t *const d = dst + col;
    int t0 = s[0];
    int t1 = s[1 * src_stride];
    int t2 = s[2 * src_stride];
    int t3 = s[3 * src_stride];
    int t4 = s[4 * src_stride];
    int t5 = s[5 * src_stride];
    int t6 = s[6 * src_stride];
    if (sqr) {
      t0 *= t0;
      t1 *= t1;
      t2 *= t2;
      t3 *= t3;
      t4 *= t4;
      t5 *= t5;
      t6 *= t6;
    }
    d[0] = t0 + t1 + t2 + t3;
    d[dst_stride] = t0 + t1 + t2 + t3 + t4;
    d[2 * dst_stride] = t0 + t1 + t2 + t3 + t4 + t5;
    for (row = 3; row < height - 4; ++row) {
      // t0..t6 hold rows row-3 .. row+3.
      d[row * dst_stride] = t0 + t1 + t2 + t3 + t4 + t5 + t6;
      t0 = t1;
      t1 = t2;
      t2 = t3;
      t3 = t4;
      t4 = t5;
      t5 = t6;
      t6 = s[(row + 4) * src_stride];
      if (sqr) t6 *= t6;
    }
    d[row * dst_stride] = t0 + t1 + t2 + t3 + t4 + t5 + t6;
    d[(row + 1) * dst_stride] = t1 + t2 + t3 + t4 + t5 + t6;
    d[(row + 2) * dst_stride] = t2 + t3 + t4 + t5 + t6;
    d[(row + 3) * dst_stride] = t3 + t4 + t5 + t6;
  }

  // Horizontal pass, in place on dst; values are read four columns ahead of
  // the column being written.
  for (row = 0; row < height; ++row) {
    int32_t *const line = dst + row * dst_stride;
    int t0 = line[0];
    int t1 = line[1];
    int t2 = line[2];
    int t3 = line[3];
    int t4 = line[4];
    int t5 = line[5];
    int t6 = line[6];
    line[0] = t0 + t1 + t2 + t3;
    line[1] = t0 + t1 + t2 + t3 + t4;
    line[2] = t0 + t1 + t2 + t3 + t4 + t5;
    for (col = 3; col < width - 4; ++col) {
      // t0..t6 hold dst columns col-3 .. col+3.
      line[col] = t0 + t1 + t2 + t3 + t4 + t5 + t6;
      t0 = t1;
      t1 = t2;
      t2 = t3;
      t3 = t4;
      t4 = t5;
      t5 = t6;
      t6 = line[col + 4];
    }
    line[col] = t0 + t1 + t2 + t3 + t4 + t5 + t6;
    line[col + 1] = t1 + t2 + t3 + t4 + t5 + t6;
    line[col + 2] = t2 + t3 + t4 + t5 + t6;
    line[col + 3] = t3 + t4 + t5 + t6;
  }
}

// Generic version for any r. To be removed after experiments are done.
/* Box sum with arbitrary radius r, computed using prefix sums.

   1. A column-wise running sum of src (or src^2 when sqr is set) goes into a
      temporary width x height plane.
   2. That running sum is differenced into dst to produce the vertical
      (2r + 1)-row window sums.
   3. A row-wise running sum of dst goes into the temporary plane.
   4. That running sum is differenced back into dst.

   Windows are clipped at the border. The image must be at least 2r + 1
   pixels on each axis.

   The row-wise running sum now indexes dst with dst_stride. It previously
   used src_stride, which read the wrong rows whenever the two strides
   differed. */
static void boxsumr(int32_t *src, int width, int height, int src_stride, int r,
                    int sqr, int32_t *dst, int dst_stride) {
  int32_t *tmp = aom_malloc(width * height * sizeof(*tmp));
  int tmp_stride = width;
  int i, j;
  assert(tmp != NULL);
  // Column-wise prefix sums: tmp(i, j) = sum of src(0..i, j) [squared].
  if (sqr) {
    for (j = 0; j < width; ++j) tmp[j] = src[j] * src[j];
    for (j = 0; j < width; ++j)
      for (i = 1; i < height; ++i)
        tmp[i * tmp_stride + j] =
            tmp[(i - 1) * tmp_stride + j] +
            src[i * src_stride + j] * src[i * src_stride + j];
  } else {
    memcpy(tmp, src, sizeof(*tmp) * width);
    for (j = 0; j < width; ++j)
      for (i = 1; i < height; ++i)
        tmp[i * tmp_stride + j] =
            tmp[(i - 1) * tmp_stride + j] + src[i * src_stride + j];
  }
  // Vertical window sums: top rows (window clipped above), interior rows,
  // bottom rows (window clipped below).
  for (i = 0; i <= r; ++i)
    memcpy(&dst[i * dst_stride], &tmp[(i + r) * tmp_stride],
           sizeof(*tmp) * width);
  for (i = r + 1; i < height - r; ++i)
    for (j = 0; j < width; ++j)
      dst[i * dst_stride + j] =
          tmp[(i + r) * tmp_stride + j] - tmp[(i - r - 1) * tmp_stride + j];
  for (i = height - r; i < height; ++i)
    for (j = 0; j < width; ++j)
      dst[i * dst_stride + j] = tmp[(height - 1) * tmp_stride + j] -
                                tmp[(i - r - 1) * tmp_stride + j];

  // Row-wise prefix sums of the vertical sums held in dst (dst_stride rows).
  for (i = 0; i < height; ++i) tmp[i * tmp_stride] = dst[i * dst_stride];
  for (i = 0; i < height; ++i)
    for (j = 1; j < width; ++j)
      tmp[i * tmp_stride + j] =
          tmp[i * tmp_stride + j - 1] + dst[i * dst_stride + j];

  // Horizontal window sums: left columns, interior, right columns.
  for (j = 0; j <= r; ++j)
    for (i = 0; i < height; ++i)
      dst[i * dst_stride + j] = tmp[i * tmp_stride + j + r];
  for (j = r + 1; j < width - r; ++j)
    for (i = 0; i < height; ++i)
      dst[i * dst_stride + j] =
          tmp[i * tmp_stride + j + r] - tmp[i * tmp_stride + j - r - 1];
  for (j = width - r; j < width; ++j)
    for (i = 0; i < height; ++i)
      dst[i * dst_stride + j] =
          tmp[i * tmp_stride + width - 1] - tmp[i * tmp_stride + j - r - 1];
  aom_free(tmp);
}

484 485 486 487 488 489
/* Dispatch to the box-sum implementation for radius r: the unrolled
   boxsum1/2/3 for r = 1, 2, 3 and the generic boxsumr() otherwise. */
static void boxsum(int32_t *src, int width, int height, int src_stride, int r,
                   int sqr, int32_t *dst, int dst_stride) {
  switch (r) {
    case 1:
      boxsum1(src, width, height, src_stride, sqr, dst, dst_stride);
      break;
    case 2:
      boxsum2(src, width, height, src_stride, sqr, dst, dst_stride);
      break;
    case 3:
      boxsum3(src, width, height, src_stride, sqr, dst, dst_stride);
      break;
    default:
      boxsumr(src, width, height, src_stride, r, sqr, dst, dst_stride);
      break;
  }
}

/* Fill num (width x height, row pitch num_stride) with the number of pixels
   inside the border-clipped (2r + 1) x (2r + 1) window around each position.
   This is the divisor matching boxsum() with the same r.

   Filled in four steps:
   1. the four (r + 1) x (r + 1) corner regions;
   2. the left and right border columns;
   3. the top and bottom border rows;
   4. the interior.

   The counts are only correct when width and height are both >= 2r + 1. */
static void boxnum(int width, int height, int r, int8_t *num, int num_stride) {
  const int diameter = 2 * r + 1;
  const int last_row = height - 1;
  const int last_col = width - 1;
  int y, x;

  // Corners: (r + 1 + y) rows by (r + 1 + x) columns, mirrored four ways.
  for (y = 0; y <= r; ++y) {
    int8_t *const top = num + y * num_stride;
    int8_t *const bottom = num + (last_row - y) * num_stride;
    for (x = 0; x <= r; ++x) {
      const int8_t count = (r + 1 + y) * (r + 1 + x);
      top[x] = count;
      top[last_col - x] = count;
      bottom[x] = count;
      bottom[last_col - x] = count;
    }
  }

  // Left/right borders: full window height, clipped width.
  for (x = 0; x <= r; ++x) {
    const int count = diameter * (r + 1 + x);
    for (y = r + 1; y < height - r; ++y) {
      num[y * num_stride + x] = count;
      num[y * num_stride + (last_col - x)] = count;
    }
  }

  // Top/bottom borders: full window width, clipped height.
  for (y = 0; y <= r; ++y) {
    const int count = diameter * (r + 1 + y);
    int8_t *const top = num + y * num_stride;
    int8_t *const bottom = num + (last_row - y) * num_stride;
    for (x = r + 1; x < width - r; ++x) {
      top[x] = count;
      bottom[x] = count;
    }
  }

  // Interior: the whole window fits.
  for (y = r + 1; y < height - r; ++y)
    for (x = r + 1; x < width - r; ++x)
      num[y * num_stride + x] = diameter * diameter;
}

/* Turn the two coded projection coefficients xqd[] into the blend weights
   xq[] used by apply_selfguided_restoration():
     xq[0] = -xqd[0]
     xq[1] = (1 << SGRPROJ_PRJ_BITS) - xq[0] - xqd[1] */
void decode_xq(int *xqd, int *xq) {
  const int first = -xqd[0];
  xq[0] = first;
  xq[1] = (1 << SGRPROJ_PRJ_BITS) - first - xqd[1];
}

#define APPROXIMATE_SGR 1
534
void av1_selfguided_restoration(int32_t *dgd, int width, int height, int stride,
535 536 537
                                int bit_depth, int r, int eps,
                                int32_t *tmpbuf) {
  int32_t *A = tmpbuf;
538
  int32_t *B = A + RESTORATION_TILEPELS_MAX;
539 540 541 542
  int8_t num[RESTORATION_TILEPELS_MAX];
  int i, j;
  eps <<= 2 * (bit_depth - 8);

543 544 545 546 547
  // Don't filter tiles with dimensions < 5 on any axis
  if ((width < 5) || (height < 5)) return;

  boxsum(dgd, width, height, stride, r, 0, B, width);
  boxsum(dgd, width, height, stride, r, 1, A, width);
548
  boxnum(width, height, r, num, width);
549 550
  // The following loop is optimized assuming r <= 2. If we allow
  // r > 2, then the loop will need modifying.
551
  assert(r <= 3);
552 553 554 555
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int k = i * width + j;
      const int n = num[k];
556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
      // Assuming that we only allow up to 12-bit depth and r <= 2,
      // we calculate p = n^2 * Var(n-pixel block of original image)
      // (where n = 2 * r + 1 <= 5).
      //
      // There is an inequality which gives a bound on the variance:
      // https://en.wikipedia.org/wiki/Popoviciu's_inequality_on_variances
      // In this case, since each pixel is in the range [0, 2^12),
      // the variance is at most 1/4 * (2^12)^2 = 2^22.
      // Then p <= 25^2 * 2^22 < 2^32, and also q <= p + 25^2 * 68 < 2^32.
      //
      // The point of all this is to guarantee that q < 2^32, so that
      // platforms with a 64-bit by 32-bit divide unit (eg, x86)
      // can do the division by q more efficiently.
      const uint32_t p = (uint32_t)((uint64_t)A[k] * n - (uint64_t)B[k] * B[k]);
      const uint32_t q = (uint32_t)(p + n * n * eps);
      assert((uint64_t)A[k] * n - (uint64_t)B[k] * B[k] < (25 * 25U << 22));
      A[k] = (int32_t)(((uint64_t)p << SGRPROJ_SGR_BITS) + (q >> 1)) / q;
573 574 575 576 577 578 579 580 581 582
      B[k] = ((SGRPROJ_SGR - A[k]) * B[k] + (n >> 1)) / n;
    }
  }
#if APPROXIMATE_SGR
  i = 0;
  j = 0;
  {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
583
    const int32_t a =
584
        3 * A[k] + 2 * A[k + 1] + 2 * A[k + width] + A[k + width + 1];
585
    const int32_t b =
586
        3 * B[k] + 2 * B[k + 1] + 2 * B[k + width] + B[k + width + 1];
587
    const int32_t v =
588 589 590 591 592 593 594 595 596
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  i = 0;
  j = width - 1;
  {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
597
    const int32_t a =
598
        3 * A[k] + 2 * A[k - 1] + 2 * A[k + width] + A[k + width - 1];
599
    const int32_t b =
600
        3 * B[k] + 2 * B[k - 1] + 2 * B[k + width] + B[k + width - 1];
601
    const int32_t v =
602 603 604 605 606 607 608 609 610
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  i = height - 1;
  j = 0;
  {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
611
    const int32_t a =
612
        3 * A[k] + 2 * A[k + 1] + 2 * A[k - width] + A[k - width + 1];
613
    const int32_t b =
614
        3 * B[k] + 2 * B[k + 1] + 2 * B[k - width] + B[k - width + 1];
615
    const int32_t v =
616 617 618 619 620 621 622 623 624
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  i = height - 1;
  j = width - 1;
  {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
625
    const int32_t a =
626
        3 * A[k] + 2 * A[k - 1] + 2 * A[k - width] + A[k - width - 1];
627
    const int32_t b =
628
        3 * B[k] + 2 * B[k - 1] + 2 * B[k - width] + B[k - width - 1];
629
    const int32_t v =
630 631 632 633 634 635 636 637
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  i = 0;
  for (j = 1; j < width - 1; ++j) {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
638
    const int32_t a = A[k] + 2 * (A[k - 1] + A[k + 1]) + A[k + width] +
639
                      A[k + width - 1] + A[k + width + 1];
640
    const int32_t b = B[k] + 2 * (B[k - 1] + B[k + 1]) + B[k + width] +
641
                      B[k + width - 1] + B[k + width + 1];
642
    const int32_t v =
643 644 645 646 647 648 649 650
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  i = height - 1;
  for (j = 1; j < width - 1; ++j) {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
651
    const int32_t a = A[k] + 2 * (A[k - 1] + A[k + 1]) + A[k - width] +
652
                      A[k - width - 1] + A[k - width + 1];
653
    const int32_t b = B[k] + 2 * (B[k - 1] + B[k + 1]) + B[k - width] +
654
                      B[k - width - 1] + B[k - width + 1];
655
    const int32_t v =
656 657 658 659 660 661 662 663
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  j = 0;
  for (i = 1; i < height - 1; ++i) {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
664
    const int32_t a = A[k] + 2 * (A[k - width] + A[k + width]) + A[k + 1] +
665
                      A[k - width + 1] + A[k + width + 1];
666
    const int32_t b = B[k] + 2 * (B[k - width] + B[k + width]) + B[k + 1] +
667
                      B[k - width + 1] + B[k + width + 1];
668
    const int32_t v =
669 670 671 672 673 674 675 676
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  j = width - 1;
  for (i = 1; i < height - 1; ++i) {
    const int k = i * width + j;
    const int l = i * stride + j;
    const int nb = 3;
677
    const int32_t a = A[k] + 2 * (A[k - width] + A[k + width]) + A[k - 1] +
678
                      A[k - width - 1] + A[k + width - 1];
679
    const int32_t b = B[k] + 2 * (B[k - width] + B[k + width]) + B[k - 1] +
680
                      B[k - width - 1] + B[k + width - 1];
681
    const int32_t v =
682 683 684 685 686 687 688 689
        (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
    dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
  }
  for (i = 1; i < height - 1; ++i) {
    for (j = 1; j < width - 1; ++j) {
      const int k = i * width + j;
      const int l = i * stride + j;
      const int nb = 5;
690
      const int32_t a =
691 692 693 694
          (A[k] + A[k - 1] + A[k + 1] + A[k - width] + A[k + width]) * 4 +
          (A[k - 1 - width] + A[k - 1 + width] + A[k + 1 - width] +
           A[k + 1 + width]) *
              3;
695
      const int32_t b =
696 697 698 699
          (B[k] + B[k - 1] + B[k + 1] + B[k - width] + B[k + width]) * 4 +
          (B[k - 1 - width] + B[k - 1 + width] + B[k + 1 - width] +
           B[k + 1 + width]) *
              3;
700
      const int32_t v =
701 702 703 704 705 706
          (((a * dgd[l] + b) << SGRPROJ_RST_BITS) + (1 << nb) / 2) >> nb;
      dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
    }
  }
#else
  if (r > 1) boxnum(width, height, r = 1, num, width);
707 708
  boxsum(A, width, height, width, r, 0, A, width);
  boxsum(B, width, height, width, r, 0, B, width);
709 710 711 712 713
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int k = i * width + j;
      const int l = i * stride + j;
      const int n = num[k];
714
      const int32_t v =
715 716 717 718 719 720 721
          (((A[k] * dgd[l] + B[k]) << SGRPROJ_RST_BITS) + (n >> 1)) / n;
      dgd[l] = ROUND_POWER_OF_TWO(v, SGRPROJ_SGR_BITS);
    }
  }
#endif  // APPROXIMATE_SGR
}

722
static void apply_selfguided_restoration(uint8_t *dat, int width, int height,
723
                                         int stride, int bit_depth, int eps,
724
                                         int *xqd, uint8_t *dst, int dst_stride,
725
                                         int32_t *tmpbuf) {
726
  int xq[2];
727
  int32_t *flt1 = tmpbuf;
728
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
729
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
730
  int i, j;
731
  assert(width * height <= RESTORATION_TILEPELS_MAX);
732 733 734 735 736 737 738 739 740 741 742 743 744 745 746
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      flt1[i * width + j] = dat[i * stride + j];
      flt2[i * width + j] = dat[i * stride + j];
    }
  }
  av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                             sgr_params[eps].r1, sgr_params[eps].e1, tmpbuf2);
  av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                             sgr_params[eps].r2, sgr_params[eps].e2, tmpbuf2);
  decode_xq(xqd, xq);
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int k = i * width + j;
      const int l = i * stride + j;
747 748 749 750
      const int m = i * dst_stride + j;
      const int32_t u = ((int32_t)dat[l] << SGRPROJ_RST_BITS);
      const int32_t f1 = (int32_t)flt1[k] - u;
      const int32_t f2 = (int32_t)flt2[k] - u;
751 752 753
      const int64_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
      const int16_t w =
          (int16_t)ROUND_POWER_OF_TWO(v, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
754
      dst[m] = clip_pixel(w);
755 756 757 758 759 760
    }
  }
}

static void loop_sgrproj_filter_tile(uint8_t *data, int tile_idx, int width,
                                     int height, int stride,
761 762
                                     RestorationInternal *rst, uint8_t *dst,
                                     int dst_stride) {
763 764
  const int tile_width = rst->tile_width;
  const int tile_height = rst->tile_height;
765
  int h_start, h_end, v_start, v_end;
766
  uint8_t *data_p, *dst_p;
767

768
  if (rst->rsi->restoration_type[tile_idx] == RESTORE_NONE) {
769 770 771 772
    loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
                   dst_stride);
    return;
  }
773 774 775 776
  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
                           tile_width, tile_height, width, height, 0, 0,
                           &h_start, &h_end, &v_start, &v_end);
  data_p = data + h_start + v_start * stride;
777
  dst_p = dst + h_start + v_start * dst_stride;
778 779 780
  apply_selfguided_restoration(data_p, h_end - h_start, v_end - v_start, stride,
                               8, rst->rsi->sgrproj_info[tile_idx].ep,
                               rst->rsi->sgrproj_info[tile_idx].xqd, dst_p,
781
                               dst_stride, rst->tmpbuf);
782 783 784 785
}

// Applies the self-guided filter to every restoration tile of an 8-bit frame,
// in tile-index order, writing the result into dst.
static void loop_sgrproj_filter(uint8_t *data, int width, int height,
                                int stride, RestorationInternal *rst,
                                uint8_t *dst, int dst_stride) {
  int t = 0;
  while (t < rst->ntiles) {
    loop_sgrproj_filter_tile(data, t, width, height, stride, rst, dst,
                             dst_stride);
    ++t;
  }
}

794 795 796 797 798 799 800
// Horizontal half of one domain-transform recursive-filter iteration on a
// single row: a left-to-right recursive pass followed by a right-to-left one.
//
// vtable: weight table for this iteration/parameter, indexed by an absolute
//         pixel difference. A weight v mixes the running value (weight v)
//         with the current sample (weight DOMAINTXFMRF_VTABLE_PREC - v).
// diff:   this row's horizontal differences; diff[j] belongs to the column
//         pair (j, j + 1).
// row:    width samples at unit scale on entry, scaled by
//         DOMAINTXFMRF_VTABLE_PREC on exit.
static void apply_domaintxfmrf_row(const int *vtable, const uint8_t *diff,
                                   int width, int32_t *row) {
  int j, acc;
  // left to right
  acc = row[0] * DOMAINTXFMRF_VTABLE_PREC;
  row[0] = acc;
  for (j = 1; j < width; ++j) {
    const int v = vtable[diff[j - 1]];  // Left absolute difference
    acc = row[j] * (DOMAINTXFMRF_VTABLE_PREC - v) +
          ROUND_POWER_OF_TWO(v * acc, DOMAINTXFMRF_VTABLE_PRECBITS);
    row[j] = acc;
  }
  // right to left, continuing from the last left-to-right value
  for (j = width - 2; j >= 0; --j) {
    const int v = vtable[diff[j]];  // Right absolute difference
    acc = ROUND_POWER_OF_TWO(row[j] * (DOMAINTXFMRF_VTABLE_PREC - v) + acc * v,
                             DOMAINTXFMRF_VTABLE_PRECBITS);
    row[j] = acc;
  }
}

// Runs iteration `iter` of the domain-transform recursive filter in place on
// a width x height plane `dat` (pitch dat_stride).
//
// Every row is first filtered horizontally (left-to-right, then
// right-to-left). From the second row on, each row is then blended
// top-to-bottom with the already-finished row above it. A final bottom-to-top
// pass per column also rounds every sample from DOMAINTXFMRF_VTABLE_PREC
// scale back to unit scale, so `dat` is at unit scale again on return.
//
// diff_right / diff_down hold absolute differences between horizontally /
// vertically adjacent source pixels, packed with pitch `width`. They select
// the weights from domaintxfmrf_vtable[iter][param].
static void apply_domaintxfmrf(int iter, int param, uint8_t *diff_right,
                               uint8_t *diff_down, int width, int height,
                               int32_t *dat, int dat_stride) {
  const int *vtable = domaintxfmrf_vtable[iter][param];
  int i, j, acc;

  // An empty plane has nothing to filter. Without this guard, height == 0
  // would make the row passes below read and write row -1 of `dat`.
  if (width <= 0 || height <= 0) return;

  for (i = 0; i < height; ++i) {
    int32_t *row = dat + i * dat_stride;
    apply_domaintxfmrf_row(vtable, diff_right + i * width, width, row);
    if (i > 0) {
      // top to bottom (the first row has no row above it)
      const int32_t *above = row - dat_stride;
      const uint8_t *diff = diff_down + (i - 1) * width;  // Upward differences
      for (j = 0; j < width; ++j) {
        const int v = vtable[diff[j]];
        acc = ROUND_POWER_OF_TWO(
            row[j] * (DOMAINTXFMRF_VTABLE_PREC - v) + above[j] * v,
            DOMAINTXFMRF_VTABLE_PRECBITS);
        row[j] = acc;
      }
    }
  }

  // bottom to top + output rounding back to unit scale
  for (j = 0; j < width; ++j) {
    acc = dat[(height - 1) * dat_stride + j];
    dat[(height - 1) * dat_stride + j] =
        ROUND_POWER_OF_TWO(acc, DOMAINTXFMRF_VTABLE_PRECBITS);
    for (i = height - 2; i >= 0; --i) {
      const int in = dat[i * dat_stride + j];
      const int v = vtable[diff_down[i * width + j]];  // Downward difference
      acc = ROUND_POWER_OF_TWO(in * (DOMAINTXFMRF_VTABLE_PREC - v) + acc * v,
                               DOMAINTXFMRF_VTABLE_PRECBITS);
      dat[i * dat_stride + j] =
          ROUND_POWER_OF_TWO(acc, DOMAINTXFMRF_VTABLE_PRECBITS);
    }
  }
}

void av1_domaintxfmrf_restoration(uint8_t *dgd, int width, int height,
878
                                  int stride, int param, uint8_t *dst,
879 880
                                  int dst_stride, int32_t *tmpbuf) {
  int32_t *dat = tmpbuf;
881 882
  uint8_t *diff_right = (uint8_t *)(tmpbuf + RESTORATION_TILEPELS_MAX);
  uint8_t *diff_down = diff_right + RESTORATION_TILEPELS_MAX;
883
  int i, j, t;
884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900

  for (i = 0; i < height; ++i) {
    int cur_px = dgd[i * stride];
    for (j = 0; j < width - 1; ++j) {
      const int next_px = dgd[i * stride + j + 1];
      diff_right[i * width + j] = abs(cur_px - next_px);
      cur_px = next_px;
    }
  }
  for (j = 0; j < width; ++j) {
    int cur_px = dgd[j];
    for (i = 0; i < height - 1; ++i) {
      const int next_px = dgd[(i + 1) * stride + j];
      diff_down[i * width + j] = abs(cur_px - next_px);
      cur_px = next_px;
    }
  }
901 902 903 904 905
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      dat[i * width + j] = dgd[i * stride + j];
    }
  }
906

907
  for (t = 0; t < DOMAINTXFMRF_ITERS; ++t) {
908 909
    apply_domaintxfmrf(t, param, diff_right, diff_down, width, height, dat,
                       width);
910 911 912
  }
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
913
      dst[i * dst_stride + j] = clip_pixel(dat[i * width + j]);
914 915 916 917 918 919 920
    }
  }
}

// Domain-transform-filters one restoration tile of an 8-bit frame into dst.
// A tile signalled as RESTORE_NONE is copied through unchanged. Otherwise the
// tile's pixel bounds are resolved with av1_get_rest_tile_limits() and
// av1_domaintxfmrf_restoration() runs on that sub-rectangle with the tile's
// signalled sigma_r, using rst->tmpbuf as scratch.
static void loop_domaintxfmrf_filter_tile(uint8_t *data, int tile_idx,
                                          int width, int height, int stride,
                                          RestorationInternal *rst,
                                          uint8_t *dst, int dst_stride) {
  int h_start, h_end, v_start, v_end;
  uint8_t *src_tile, *dst_tile;

  if (rst->rsi->restoration_type[tile_idx] == RESTORE_NONE) {
    loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
                   dst_stride);
    return;
  }

  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
                           rst->tile_width, rst->tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  src_tile = data + v_start * stride + h_start;
  dst_tile = dst + v_start * dst_stride + h_start;
  av1_domaintxfmrf_restoration(src_tile, h_end - h_start, v_end - v_start,
                               stride,
                               rst->rsi->domaintxfmrf_info[tile_idx].sigma_r,
                               dst_tile, dst_stride, (int32_t *)rst->tmpbuf);
}

// Applies the domain-transform filter to every restoration tile of an 8-bit
// frame, in tile-index order, writing the result into dst.
static void loop_domaintxfmrf_filter(uint8_t *data, int width, int height,
                                     int stride, RestorationInternal *rst,
                                     uint8_t *dst, int dst_stride) {
  int t = 0;
  while (t < rst->ntiles) {
    loop_domaintxfmrf_filter_tile(data, t, width, height, stride, rst, dst,
                                  dst_stride);
    ++t;
  }
}

951 952
// Restores an 8-bit frame whose restoration type is chosen per tile.
// The frame border is first extended in place with extend_frame(). Each tile
// is then dispatched on its signalled restoration type. A tile whose type is
// not one of the four handled below gets nothing written to dst.
static void loop_switchable_filter(uint8_t *data, int width, int height,
                                   int stride, RestorationInternal *rst,
                                   uint8_t *dst, int dst_stride) {
  int tile_idx;
  extend_frame(data, width, height, stride);
  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
    switch (rst->rsi->restoration_type[tile_idx]) {
      case RESTORE_NONE:
        loop_copy_tile(data, tile_idx, 0, 0, width, height, stride, rst, dst,
                       dst_stride);
        break;
      case RESTORE_WIENER:
        loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst,
                                dst, dst_stride);
        break;
      case RESTORE_SGRPROJ:
        loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst,
                                 dst, dst_stride);
        break;
      case RESTORE_DOMAINTXFMRF:
        loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride,
                                      rst, dst, dst_stride);
        break;
      default: break;  // unhandled type: leave this tile of dst untouched
    }
  }
}

Yaowu Xu's avatar
Yaowu Xu committed
973
#if CONFIG_AOM_HIGHBITDEPTH
974
// Pads a 16-bit frame in place by WIENER_HALFWIN samples on every side using
// edge replication. Each row's first and last samples are repeated outwards,
// then the (already padded) first and last rows are copied into the
// WIENER_HALFWIN rows above and below. The caller must provide that border
// around `data` (pitch `stride`).
void extend_frame_highbd(uint16_t *data, int width, int height, int stride) {
  const size_t padded_bytes = (width + 2 * WIENER_HALFWIN) * sizeof(uint16_t);
  uint16_t *const top_row = data - WIENER_HALFWIN;  // row 0 incl. left pad
  int r, c;

  // Horizontal padding, row by row.
  for (r = 0; r < height; ++r) {
    uint16_t *row = data + r * stride;
    uint16_t fill = row[0];
    for (c = 1; c <= WIENER_HALFWIN; ++c) row[-c] = fill;
    fill = row[width - 1];
    for (c = 0; c < WIENER_HALFWIN; ++c) row[width + c] = fill;
  }

  // Vertical padding: replicate the padded first row upwards...
  for (r = -WIENER_HALFWIN; r < 0; ++r) {
    memcpy(top_row + r * stride, top_row, padded_bytes);
  }
  // ...and the padded last row downwards.
  for (r = height; r < height + WIENER_HALFWIN; ++r) {
    memcpy(top_row + r * stride, top_row + (height - 1) * stride,
           padded_bytes);
  }
}

994 995 996 997
// Copies, unchanged, the region of a 16-bit frame that
// av1_get_rest_tile_limits() reports for tile_idx / subtile_idx /
// subtile_bits, from data (pitch stride) to dst (pitch dst_stride).
static void loop_copy_tile_highbd(uint16_t *data, int tile_idx, int subtile_idx,
                                  int subtile_bits, int width, int height,
                                  int stride, RestorationInternal *rst,
                                  uint16_t *dst, int dst_stride) {
  int h_start, h_end, v_start, v_end;
  int row;
  size_t row_bytes;

  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, rst->nhtiles,
                           rst->nvtiles, rst->tile_width, rst->tile_height,
                           width, height, 0, 0, &h_start, &h_end, &v_start,
                           &v_end);
  row_bytes = (h_end - h_start) * sizeof(*dst);
  for (row = v_start; row < v_end; ++row) {
    memcpy(dst + row * dst_stride + h_start, data + row * stride + h_start,
           row_bytes);
  }
}

1010 1011 1012
// Wiener-filters one restoration tile of a 16-bit frame into dst.
// A tile signalled as RESTORE_NONE is copied through unchanged. Otherwise the
// tile is convolved with its signalled horizontal/vertical taps via
// aom_highbd_convolve8_add_src(), in blocks of at most MAX_SB_SIZE on a side.
// The blocking exists only to satisfy the vectorised convolve functions; the
// result is the same as convolving the whole tile. Each block extent is
// rounded up to a multiple of 16 (capped at MAX_SB_SIZE). Assuming the
// convolve writes the full block, the last block in a row or column can
// therefore touch up to 15 samples past h_end / v_end.
static void loop_wiener_filter_tile_highbd(uint16_t *data, int tile_idx,
                                           int width, int height, int stride,
                                           RestorationInternal *rst,
                                           int bit_depth, uint16_t *dst,
                                           int dst_stride) {
  int h_start, h_end, v_start, v_end;
  int y, x;

  if (rst->rsi->restoration_type[tile_idx] == RESTORE_NONE) {
    loop_copy_tile_highbd(data, tile_idx, 0, 0, width, height, stride, rst, dst,
                          dst_stride);
    return;
  }

  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
                           rst->tile_width, rst->tile_height, width, height, 0,
                           0, &h_start, &h_end, &v_start, &v_end);
  for (y = v_start; y < v_end; y += MAX_SB_SIZE) {
    const int block_h = AOMMIN(MAX_SB_SIZE, (v_end - y + 15) & ~15);
    for (x = h_start; x < h_end; x += MAX_SB_SIZE) {
      const int block_w = AOMMIN(MAX_SB_SIZE, (h_end - x + 15) & ~15);
      aom_highbd_convolve8_add_src(
          CONVERT_TO_BYTEPTR(data + y * stride + x), stride,
          CONVERT_TO_BYTEPTR(dst + y * dst_stride + x), dst_stride,
          rst->rsi->wiener_info[tile_idx].hfilter, 16,
          rst->rsi->wiener_info[tile_idx].vfilter, 16, block_w, block_h,
          bit_depth);
    }
  }
}

1043 1044
// Wiener-filters a whole 16-bit frame. data8 and dst8 are converted with
// CONVERT_TO_SHORTPTR to the underlying uint16_t planes. The source border is
// extended in place, then every restoration tile is filtered into dst in
// tile-index order.
static void loop_wiener_filter_highbd(uint8_t *data8, int width, int height,
                                      int stride, RestorationInternal *rst,
                                      int bit_depth, uint8_t *dst8,
                                      int dst_stride) {
  uint16_t *const src = CONVERT_TO_SHORTPTR(data8);
  uint16_t *const out = CONVERT_TO_SHORTPTR(dst8);
  int t = 0;

  extend_frame_highbd(src, width, height, stride);
  while (t < rst->ntiles) {
    loop_wiener_filter_tile_highbd(src, t, width, height, stride, rst,
                                   bit_depth, out, dst_stride);
    ++t;
  }
}

1057 1058 1059
static void apply_selfguided_restoration_highbd(
    uint16_t *dat, int width, int height, int stride, int bit_depth, int eps,
    int *xqd, uint16_t *dst, int dst_stride, int32_t *tmpbuf) {
1060
  int xq[2];
1061
  int32_t *flt1 = tmpbuf;
1062
  int32_t *flt2 = flt1 + RESTORATION_TILEPELS_MAX;
1063
  int32_t *tmpbuf2 = flt2 + RESTORATION_TILEPELS_MAX;
1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092
  int i, j;
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      assert(i * width + j < RESTORATION_TILEPELS_MAX);
      flt1[i * width + j] = dat[i * stride + j];
      flt2[i * width + j] = dat[i * stride + j];
    }
  }
  av1_selfguided_restoration(flt1, width, height, width, bit_depth,
                             sgr_params[eps].r1, sgr_params[eps].e1, tmpbuf2);
  av1_selfguided_restoration(flt2, width, height, width, bit_depth,
                             sgr_params[eps].r2, sgr_params[eps].e2, tmpbuf2);
  decode_xq(xqd, xq);
  for (i = 0; i < height; ++i) {
    for (j = 0; j < width; ++j) {
      const int k = i * width + j;
      const int l = i * stride + j;
      const int m = i * dst_stride + j;
      const int32_t u = ((int32_t)dat[l] << SGRPROJ_RST_BITS);
      const int32_t f1 = (int32_t)flt1[k] - u;
      const int32_t f2 = (int32_t)flt2[k] - u;
      const int64_t v = xq[</