From a9b44d75e42528227f9f0c82d6e0e5c60774ffb5 Mon Sep 17 00:00:00 2001
From: Kyle Siefring <kylesiefring@gmail.com>
Date: Fri, 28 Sep 2018 13:53:09 -0400
Subject: [PATCH] Fix inter for 12-bit input. (#623)

* Fix inter for 12-bit input.

12-bit is a special case that rounds differently.

See 7.11.3.2. Rounding variable derivation process in the spec.
---
 src/partition.rs | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/partition.rs b/src/partition.rs
index 1243b9b5..249bcbb5 100644
--- a/src/partition.rs
+++ b/src/partition.rs
@@ -931,6 +931,12 @@ impl PredictionMode {
         let max_sample_val = ((1 << bit_depth) - 1) as i32;
         let y_filter_idx = if height <= 4 { 4 } else { 0 };
         let x_filter_idx = if width <= 4 { 4 } else { 0 };
+        let shifts = {
+          let shift_offset = if bit_depth == 12 { 2 } else { 0 };
+          (3 + shift_offset, 11 - shift_offset)
+        };
+        let round_shift =
+          |n, shift| -> i32 { (n + (1 << (shift - 1))) >> shift };
 
         match (col_frac, row_frac) {
           (0, 0) => {
@@ -962,7 +968,7 @@ impl PredictionMode {
                     * SUBPEL_FILTERS[y_filter_idx][row_frac as usize][k];
                 }
                 let output_index = r * stride + c;
-                let val = ((sum + 64) >> 7).max(0).min(max_sample_val);
+                let val = round_shift(sum, 7).max(0).min(max_sample_val);
                 slice[output_index] = val as u16;
               }
             }
@@ -983,7 +989,9 @@ impl PredictionMode {
                 }
                 let output_index = r * stride + c;
                 let val =
-                  ((((sum + 4) >> 3) + 8) >> 4).max(0).min(max_sample_val);
+                  round_shift(round_shift(sum, shifts.0), shifts.1 - 7)
+                    .max(0)
+                    .min(max_sample_val);
                 slice[output_index] = val as u16;
               }
             }
@@ -1005,7 +1013,7 @@ impl PredictionMode {
                     sum += s[r * ref_stride + (c + k)] as i32 * SUBPEL_FILTERS
                       [x_filter_idx][col_frac as usize][k];
                   }
-                  let val = (sum + 4) >> 3;
+                  let val = round_shift(sum, shifts.0);
                   intermediate[8 * r + (c - cg)] = val as i16;
                 }
               }
@@ -1018,7 +1026,8 @@ impl PredictionMode {
                       * SUBPEL_FILTERS[y_filter_idx][row_frac as usize][k];
                   }
                   let output_index = r * stride + c;
-                  let val = ((sum + 1024) >> 11).max(0).min(max_sample_val);
+                  let val =
+                    round_shift(sum, shifts.1).max(0).min(max_sample_val);
                   slice[output_index] = val as u16;
                 }
               }
-- 
GitLab