Commit bf3d4964 authored by Cheng Chen's avatar Cheng Chen

JNT_COMP: add SIMD and interface for high bit-depth

Add high bit-depth macro definitions:
highbd_jnt_sad
highbd_8(10/12)_jnt_sub_pixel_avg.

Add SIMD functions:
aom_highbd_jnt_comp_avg_pred_sse2
aom_highbd_jnt_comp_avg_upsampled_pred_sse2

This patch also solves the seg fault caused by low bit-depth and
high bit-depth paths

BUG=aomedia:967
BUG=aomedia:944

Change-Id: Iea69f114e81ca226a30d84a540ad846f1b94b8d6
parent a50f9f57
......@@ -926,6 +926,9 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
specialize "aom_highbd_sad${w}x${h}", qw/sse2/;
specialize "aom_highbd_sad${w}x${h}_avg", qw/sse2/;
}
if (aom_config("CONFIG_JNT_COMP") eq "yes") {
add_proto qw/unsigned int/, "aom_highbd_jnt_sad${w}x${h}_avg", "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred, const JNT_COMP_PARAMS* jcp_param";
}
}
specialize qw/aom_highbd_sad128x128 avx2/;
specialize qw/aom_highbd_sad128x64 avx2/;
......@@ -1166,6 +1169,11 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
specialize qw/aom_highbd_upsampled_pred sse2/;
add_proto qw/void aom_highbd_comp_avg_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, int subsample_x_q3, int subsample_y_q3, const uint8_t *ref8, int ref_stride, int bd";
specialize qw/aom_highbd_comp_avg_upsampled_pred sse2/;
if (aom_config("CONFIG_JNT_COMP") eq "yes") {
add_proto qw/void aom_highbd_jnt_comp_avg_upsampled_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, int subsample_x_q3, int subsample_y_q3, const uint8_t *ref8, int ref_stride, int bd, const JNT_COMP_PARAMS *jcp_param";
specialize qw/aom_highbd_jnt_comp_avg_upsampled_pred sse2/;
}
}
#
......@@ -1324,6 +1332,10 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
specialize "aom_highbd_${bd}_sub_pixel_variance${w}x${h}", "sse4_1";
specialize "aom_highbd_${bd}_sub_pixel_avg_variance${w}x${h}", "sse4_1";
}
if (aom_config("CONFIG_JNT_COMP") eq "yes") {
add_proto qw/uint32_t/, "aom_highbd_${bd}_jnt_sub_pixel_avg_variance${w}x${h}", "const uint8_t *src_ptr, int source_stride, int xoffset, int yoffset, const uint8_t *ref_ptr, int ref_stride, uint32_t *sse, const uint8_t *second_pred, const JNT_COMP_PARAMS* jcp_param";
}
}
}
} # CONFIG_HIGHBITDEPTH
......@@ -1564,6 +1576,11 @@ if (aom_config("CONFIG_AV1_ENCODER") eq "yes") {
add_proto qw/void aom_highbd_comp_avg_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride";
if (aom_config("CONFIG_JNT_COMP") eq "yes") {
add_proto qw/void aom_highbd_jnt_comp_avg_pred/, "uint16_t *comp_pred, const uint8_t *pred8, int width, int height, const uint8_t *ref8, int ref_stride, const JNT_COMP_PARAMS *jcp_param";
specialize qw/aom_highbd_jnt_comp_avg_pred sse2/;
}
#
# Subpixel Variance
#
......
......@@ -247,6 +247,29 @@ static INLINE unsigned int highbd_sadb(const uint8_t *a8, int a_stride,
return sad;
}
#if CONFIG_JNT_COMP
#define highbd_sadMxN(m, n) \
unsigned int aom_highbd_sad##m##x##n##_c(const uint8_t *src, int src_stride, \
const uint8_t *ref, \
int ref_stride) { \
return highbd_sad(src, src_stride, ref, ref_stride, m, n); \
} \
unsigned int aom_highbd_sad##m##x##n##_avg_c( \
const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
const uint8_t *second_pred) { \
uint16_t comp_pred[m * n]; \
aom_highbd_comp_avg_pred(comp_pred, second_pred, m, n, ref, ref_stride); \
return highbd_sadb(src, src_stride, comp_pred, m, m, n); \
} \
unsigned int aom_highbd_jnt_sad##m##x##n##_avg_c( \
const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
const uint8_t *second_pred, const JNT_COMP_PARAMS *jcp_param) { \
uint16_t comp_pred[m * n]; \
aom_highbd_jnt_comp_avg_pred(comp_pred, second_pred, m, n, ref, \
ref_stride, jcp_param); \
return highbd_sadb(src, src_stride, comp_pred, m, m, n); \
}
#else
#define highbd_sadMxN(m, n) \
unsigned int aom_highbd_sad##m##x##n##_c(const uint8_t *src, int src_stride, \
const uint8_t *ref, \
......@@ -260,6 +283,7 @@ static INLINE unsigned int highbd_sadb(const uint8_t *a8, int a_stride,
aom_highbd_comp_avg_pred_c(comp_pred, second_pred, m, n, ref, ref_stride); \
return highbd_sadb(src, src_stride, comp_pred, m, m, n); \
}
#endif // CONFIG_JNT_COMP
#define highbd_sadMxNxK(m, n, k) \
void aom_highbd_sad##m##x##n##x##k##_c( \
......
This diff is collapsed.
......@@ -15,6 +15,8 @@
#include "./aom_config.h"
#include "./aom_dsp_rtcd.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_ports/mem.h"
#include "./av1_rtcd.h"
......@@ -707,3 +709,98 @@ void aom_highbd_comp_avg_upsampled_pred_sse2(uint16_t *comp_pred,
pred += 8;
}
}
#if CONFIG_JNT_COMP
static void highbd_compute_jnt_comp_avg(__m128i *p0, __m128i *p1,
const __m128i *w0, const __m128i *w1,
const __m128i *r, void *const result) {
__m128i mult0 = _mm_mullo_epi16(*p0, *w0);
__m128i mult1 = _mm_mullo_epi16(*p1, *w1);
__m128i sum = _mm_add_epi16(mult0, mult1);
__m128i round = _mm_add_epi16(sum, *r);
__m128i shift = _mm_srai_epi16(round, DIST_PRECISION_BITS);
xx_storeu_128(result, shift);
}
void aom_highbd_jnt_comp_avg_pred_sse2(uint16_t *comp_pred,
const uint8_t *pred8, int width,
int height, const uint8_t *ref8,
int ref_stride,
const JNT_COMP_PARAMS *jcp_param) {
int i;
const uint16_t wt0 = (uint16_t)jcp_param->fwd_offset;
const uint16_t wt1 = (uint16_t)jcp_param->bck_offset;
const __m128i w0 = _mm_set_epi16(wt0, wt0, wt0, wt0, wt0, wt0, wt0, wt0);
const __m128i w1 = _mm_set_epi16(wt1, wt1, wt1, wt1, wt1, wt1, wt1, wt1);
const uint16_t round = ((1 << DIST_PRECISION_BITS) >> 1);
const __m128i r =
_mm_set_epi16(round, round, round, round, round, round, round, round);
uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
if (width >= 8) {
// Read 8 pixels one row at a time
assert(!(width & 7));
for (i = 0; i < height; ++i) {
int j;
for (j = 0; j < width; j += 8) {
__m128i p0 = xx_loadu_128(ref);
__m128i p1 = xx_loadu_128(pred);
highbd_compute_jnt_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
comp_pred += 8;
pred += 8;
}
ref += ref_stride - width;
}
} else {
// Read 4 pixels two rows at a time
assert(!(width & 3));
for (i = 0; i < height; i += 2) {
__m128i p0_0 = xx_loadl_64(ref + 0 * ref_stride);
__m128i p0_1 = xx_loadl_64(ref + 1 * ref_stride);
__m128i p0 = _mm_unpacklo_epi64(p0_0, p0_1);
__m128i p1 = xx_loadu_128(pred);
highbd_compute_jnt_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
comp_pred += 8;
pred += 8;
ref += 2 * ref_stride;
}
}
}
void aom_highbd_jnt_comp_avg_upsampled_pred_sse2(
uint16_t *comp_pred, const uint8_t *pred8, int width, int height,
int subpel_x_q3, int subpel_y_q3, const uint8_t *ref8, int ref_stride,
int bd, const JNT_COMP_PARAMS *jcp_param) {
uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
int n;
int i;
aom_highbd_upsampled_pred(comp_pred, width, height, subpel_x_q3, subpel_y_q3,
ref8, ref_stride, bd);
assert(!(width * height & 7));
n = width * height >> 3;
const uint16_t wt0 = (uint16_t)jcp_param->fwd_offset;
const uint16_t wt1 = (uint16_t)jcp_param->bck_offset;
const __m128i w0 = _mm_set_epi16(wt0, wt0, wt0, wt0, wt0, wt0, wt0, wt0);
const __m128i w1 = _mm_set_epi16(wt1, wt1, wt1, wt1, wt1, wt1, wt1, wt1);
const uint16_t round = ((1 << DIST_PRECISION_BITS) >> 1);
const __m128i r =
_mm_set_epi16(round, round, round, round, round, round, round, round);
for (i = 0; i < n; i++) {
__m128i p0 = xx_loadu_128(comp_pred);
__m128i p1 = xx_loadu_128(pred);
highbd_compute_jnt_comp_avg(&p0, &p1, &w0, &w1, &r, comp_pred);
comp_pred += 8;
pred += 8;
}
}
#endif // CONFIG_JNT_COMP
This diff is collapsed.
......@@ -369,12 +369,19 @@ static unsigned int setup_center_error(
if (second_pred != NULL) {
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
DECLARE_ALIGNED(16, uint16_t, comp_pred16[MAX_SB_SQUARE]);
if (mask)
if (mask) {
aom_highbd_comp_mask_pred(comp_pred16, second_pred, w, h, y + offset,
y_stride, mask, mask_stride, invert_mask);
else
aom_highbd_comp_avg_pred(comp_pred16, second_pred, w, h, y + offset,
y_stride);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.use_jnt_comp_avg)
aom_highbd_jnt_comp_avg_pred(comp_pred16, second_pred, w, h,
y + offset, y_stride, &xd->jcp_param);
else
#endif // CONFIG_JNT_COMP
aom_highbd_comp_avg_pred(comp_pred16, second_pred, w, h, y + offset,
y_stride);
}
besterr =
vfp->vf(CONVERT_TO_BYTEPTR(comp_pred16), w, src, src_stride, sse1);
} else {
......@@ -685,14 +692,22 @@ static int upsampled_pref_error(const MACROBLOCKD *xd,
if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
DECLARE_ALIGNED(16, uint16_t, pred16[MAX_SB_SQUARE]);
if (second_pred != NULL) {
if (mask)
if (mask) {
aom_highbd_comp_mask_upsampled_pred(
pred16, second_pred, w, h, subpel_x_q3, subpel_y_q3, y, y_stride,
mask, mask_stride, invert_mask, xd->bd);
else
aom_highbd_comp_avg_upsampled_pred(pred16, second_pred, w, h,
subpel_x_q3, subpel_y_q3, y,
y_stride, xd->bd);
} else {
#if CONFIG_JNT_COMP
if (xd->jcp_param.use_jnt_comp_avg)
aom_highbd_jnt_comp_avg_upsampled_pred(
pred16, second_pred, w, h, subpel_x_q3, subpel_y_q3, y, y_stride,
xd->bd, &xd->jcp_param);
else
#endif // CONFIG_JNT_COMP
aom_highbd_comp_avg_upsampled_pred(pred16, second_pred, w, h,
subpel_x_q3, subpel_y_q3, y,
y_stride, xd->bd);
}
} else {
aom_highbd_upsampled_pred(pred16, w, h, subpel_x_q3, subpel_y_q3, y,
y_stride, xd->bd);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment