From a9b44d75e42528227f9f0c82d6e0e5c60774ffb5 Mon Sep 17 00:00:00 2001 From: Kyle Siefring <kylesiefring@gmail.com> Date: Fri, 28 Sep 2018 13:53:09 -0400 Subject: [PATCH] Fix inter for 12-bit input. (#623) * Fix inter for 12-bit input. 12-bit is a special case that rounds differently. See 7.11.3.2. Rounding variable derivation process in the spec. --- src/partition.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/partition.rs b/src/partition.rs index 1243b9b5..249bcbb5 100644 --- a/src/partition.rs +++ b/src/partition.rs @@ -931,6 +931,12 @@ impl PredictionMode { let max_sample_val = ((1 << bit_depth) - 1) as i32; let y_filter_idx = if height <= 4 { 4 } else { 0 }; let x_filter_idx = if width <= 4 { 4 } else { 0 }; + let shifts = { + let shift_offset = if bit_depth == 12 { 2 } else { 0 }; + (3 + shift_offset, 11 - shift_offset) + }; + let round_shift = + |n, shift| -> i32 { (n + (1 << (shift - 1))) >> shift }; match (col_frac, row_frac) { (0, 0) => { @@ -962,7 +968,7 @@ impl PredictionMode { * SUBPEL_FILTERS[y_filter_idx][row_frac as usize][k]; } let output_index = r * stride + c; - let val = ((sum + 64) >> 7).max(0).min(max_sample_val); + let val = round_shift(sum, 7).max(0).min(max_sample_val); slice[output_index] = val as u16; } } @@ -983,7 +989,9 @@ impl PredictionMode { } let output_index = r * stride + c; let val = - ((((sum + 4) >> 3) + 8) >> 4).max(0).min(max_sample_val); + round_shift(round_shift(sum, shifts.0), shifts.1 - 7) + .max(0) + .min(max_sample_val); slice[output_index] = val as u16; } } @@ -1005,7 +1013,7 @@ impl PredictionMode { sum += s[r * ref_stride + (c + k)] as i32 * SUBPEL_FILTERS [x_filter_idx][col_frac as usize][k]; } - let val = (sum + 4) >> 3; + let val = round_shift(sum, shifts.0); intermediate[8 * r + (c - cg)] = val as i16; } } @@ -1018,7 +1026,8 @@ impl PredictionMode { * SUBPEL_FILTERS[y_filter_idx][row_frac as usize][k]; } let output_index = r * stride + c; - let val = ((sum + 1024) >> 11).max(0).min(max_sample_val); + let val = + round_shift(sum, shifts.1).max(0).min(max_sample_val); slice[output_index] = val as u16; } } -- GitLab