From 8177ecd1437d7cc0c7d6874d849c8053c31abfad Mon Sep 17 00:00:00 2001
From: fbossen <frank@bossentech.com>
Date: Fri, 31 Aug 2018 12:31:10 -0230
Subject: [PATCH] Add search for NEARESTMV mode (#518)

Move MV stack construction out of encoding function such that the
value of the nearest MV can be determined before encoding
---
 src/context.rs | 56 +++++++++++++++++++++++++++++---------------------
 src/encoder.rs | 20 ++++++++++++------
 src/predict.rs |  1 +
 src/rdo.rs     | 14 ++++++++-----
 4 files changed, 57 insertions(+), 34 deletions(-)

diff --git a/src/context.rs b/src/context.rs
index 74a09740..71db3a61 100755
--- a/src/context.rs
+++ b/src/context.rs
@@ -2099,11 +2099,14 @@ impl ContextWriter {
     }
   }
 
-  fn has_tr(&mut self, bo: &BlockOffset) -> bool {
+  fn has_tr(&mut self, bo: &BlockOffset, bsize: BlockSize, is_sec_rect: bool) -> bool {
     let sb_mi_size = BlockSize::MI_SIZE_WIDE[BLOCK_64X64 as usize]; /* Assume 64x64 for now */
     let mask_row = bo.y & LOCAL_BLOCK_MASK;
     let mask_col = bo.x & LOCAL_BLOCK_MASK;
-    let mut bs = cmp::max(self.bc.at(bo).n4_w, self.bc.at(bo).n4_h);
+    let target_n4_w = bsize.width_mi();
+    let target_n4_h = bsize.height_mi();
+
+    let mut bs = target_n4_w.max(target_n4_h);
 
     if bs > BlockSize::MI_SIZE_WIDE[BLOCK_64X64 as usize] {
       return false;
@@ -2127,20 +2130,20 @@ impl ContextWriter {
 
     /* The left hand of two vertical rectangles always has a top right (as the
      * block above will have been decoded) */
-    let blk = &self.bc.at(bo);
-    if (blk.n4_w < blk.n4_h) && !blk.is_sec_rect {
+    if (target_n4_w < target_n4_h) && !is_sec_rect {
       has_tr = true;
     }
 
     /* The bottom of two horizontal rectangles never has a top right (as the block
      * to the right won't have been decoded) */
-    if (blk.n4_w > blk.n4_h) && blk.is_sec_rect {
+    if (target_n4_w > target_n4_h) && is_sec_rect {
       has_tr = false;
     }
 
     /* The bottom left square of a Vertical A (in the old format) does
      * not have a top right as it is decoded before the right hand
      * rectangle of the partition */
+/*
     if blk.partition == PartitionType::PARTITION_VERT_A {
       if blk.n4_w == blk.n4_h {
         if (mask_row & bs) != 0 {
@@ -2148,6 +2151,7 @@ impl ContextWriter {
         }
       }
     }
+*/
 
     has_tr
   }
@@ -2197,9 +2201,9 @@ impl ContextWriter {
 
   fn scan_row_mbmi(&mut self, bo: &BlockOffset, row_offset: isize, max_row_offs: isize,
                    processed_rows: &mut isize, ref_frame: usize,
-                   mv_stack: &mut Vec<CandidateMV>, newmv_count: &mut usize) -> bool {
+                   mv_stack: &mut Vec<CandidateMV>, newmv_count: &mut usize, bsize: BlockSize) -> bool {
     let bc = &self.bc;
-    let target_n4_w = bc.at(bo).n4_w;
+    let target_n4_w = bsize.width_mi();
 
     let end_mi = cmp::min(cmp::min(target_n4_w, bc.cols - bo.x),
                           BlockSize::MI_SIZE_WIDE[BLOCK_64X64 as usize]);
@@ -2250,9 +2254,10 @@ impl ContextWriter {
 
   fn scan_col_mbmi(&mut self, bo: &BlockOffset, col_offset: isize, max_col_offs: isize,
                    processed_cols: &mut isize, ref_frame: usize,
-                   mv_stack: &mut Vec<CandidateMV>, newmv_count: &mut usize) -> bool {
+                   mv_stack: &mut Vec<CandidateMV>, newmv_count: &mut usize, bsize: BlockSize) -> bool {
     let bc = &self.bc;
-    let target_n4_h = bc.at(bo).n4_h;
+
+    let target_n4_h = bsize.height_mi();
 
     let end_mi = cmp::min(cmp::min(target_n4_h, bc.rows - bo.y),
                           BlockSize::MI_SIZE_HIGH[BLOCK_64X64 as usize]);
@@ -2317,14 +2322,18 @@ impl ContextWriter {
     }
   }
 
-  fn setup_mvref_list(&mut self, bo: &BlockOffset, ref_frame: usize, mv_stack: &mut Vec<CandidateMV>) -> usize {
+  fn setup_mvref_list(&mut self, bo: &BlockOffset, ref_frame: usize, mv_stack: &mut Vec<CandidateMV>,
+                      bsize: BlockSize, is_sec_rect: bool) -> usize {
     let (_rf, _rf_num) = self.get_mvref_ref_frames(INTRA_FRAME);
 
+    let target_n4_h = bsize.height_mi();
+    let target_n4_w = bsize.width_mi();
+
     let mut max_row_offs = 0 as isize;
-    let row_adj = (self.bc.at(bo).n4_h < BlockSize::MI_SIZE_HIGH[BLOCK_8X8 as usize]) && (bo.y & 0x01) != 0x0;
+    let row_adj = (target_n4_h < BlockSize::MI_SIZE_HIGH[BLOCK_8X8 as usize]) && (bo.y & 0x01) != 0x0;
 
     let mut max_col_offs = 0 as isize;
-    let col_adj = (self.bc.at(bo).n4_w < BlockSize::MI_SIZE_WIDE[BLOCK_8X8 as usize]) && (bo.x & 0x01) != 0x0;
+    let col_adj = (target_n4_w < BlockSize::MI_SIZE_WIDE[BLOCK_8X8 as usize]) && (bo.x & 0x01) != 0x0;
 
     let mut processed_rows = 0 as isize;
     let mut processed_cols = 0 as isize;
@@ -2336,7 +2345,7 @@ impl ContextWriter {
       max_row_offs = -2 * MVREF_ROW_COLS as isize + row_adj as isize;
 
       // limit max offset for small blocks
-      if self.bc.at(bo).n4_h < BlockSize::MI_SIZE_HIGH[BLOCK_8X8 as usize] {
+      if target_n4_h < BlockSize::MI_SIZE_HIGH[BLOCK_8X8 as usize] {
         max_row_offs = -2 * 2 + row_adj as isize;
       }
 
@@ -2348,7 +2357,7 @@ impl ContextWriter {
       max_col_offs = -2 * MVREF_ROW_COLS as isize + col_adj as isize;
 
       // limit max offset for small blocks
-      if self.bc.at(bo).n4_w < BlockSize::MI_SIZE_WIDE[BLOCK_8X8 as usize] {
+      if target_n4_w < BlockSize::MI_SIZE_WIDE[BLOCK_8X8 as usize] {
         max_col_offs = -2 * 2 + col_adj as isize;
       }
 
@@ -2361,17 +2370,16 @@ impl ContextWriter {
 
     if max_row_offs.abs() >= 1 {
       let found_match = self.scan_row_mbmi(bo, -1, max_row_offs, &mut processed_rows, ref_frame, mv_stack,
-                                           &mut newmv_count);
+                                           &mut newmv_count, bsize);
       row_match |= found_match;
     }
     if max_col_offs.abs() >= 1 {
       let found_match = self.scan_col_mbmi(bo, -1, max_col_offs, &mut processed_cols, ref_frame, mv_stack,
-                                           &mut newmv_count);
+                                           &mut newmv_count, bsize);
       col_match |= found_match;
     }
-    if self.has_tr(bo) {
-      let n4_w = self.bc.at(bo).n4_w;
-      let found_match = self.scan_blk_mbmi(&bo.with_offset(n4_w as isize, -1), ref_frame, mv_stack,
+    if self.has_tr(bo, bsize, is_sec_rect) {
+      let found_match = self.scan_blk_mbmi(&bo.with_offset(target_n4_w as isize, -1), ref_frame, mv_stack,
                                            &mut newmv_count);
       row_match |= found_match;
     }
@@ -2392,19 +2400,21 @@ impl ContextWriter {
 
       if row_offset.abs() <= max_row_offs.abs() && row_offset.abs() > processed_rows {
         let found_match = self.scan_row_mbmi(bo, row_offset, max_row_offs, &mut processed_rows, ref_frame, mv_stack,
-                                             &mut far_newmv_count);
+                                             &mut far_newmv_count, bsize);
         row_match |= found_match;
       }
 
       if col_offset.abs() <= max_col_offs.abs() && col_offset.abs() > processed_cols {
         let found_match = self.scan_col_mbmi(bo, col_offset, max_col_offs, &mut processed_cols, ref_frame, mv_stack,
-                                             &mut far_newmv_count);
+                                             &mut far_newmv_count, bsize);
         col_match |= found_match;
       }
     }
 
     let total_match = if row_match { 1 } else { 0 } + if col_match { 1 } else { 0 };
 
+    assert!(total_match >= nearest_match);
+
     let mode_context = match nearest_match {
                          0 =>  cmp::min(total_match, 1) + (total_match << REFMV_OFFSET) ,
                          1 =>  3 - cmp::min(newmv_count, 1) + ((2 + total_match) << REFMV_OFFSET) ,
@@ -2424,7 +2434,7 @@ impl ContextWriter {
   }
 
   pub fn find_mvrefs(&mut self, bo: &BlockOffset, ref_frame: usize,
-                     mv_stack: &mut Vec<CandidateMV>) -> usize {
+                     mv_stack: &mut Vec<CandidateMV>, bsize: BlockSize, is_sec_rect: bool) -> usize {
     if ref_frame < REF_FRAMES {
       if ref_frame != INTRA_FRAME {
         /* TODO: convert global mv to an mv here */
@@ -2439,7 +2449,7 @@ impl ContextWriter {
       /* TODO: Set the zeromv ref to 0 */
     }
 
-    let mode_context = self.setup_mvref_list(bo, ref_frame, mv_stack);
+    let mode_context = self.setup_mvref_list(bo, ref_frame, mv_stack, bsize, is_sec_rect);
     mode_context
   }
 
diff --git a/src/encoder.rs b/src/encoder.rs
index 2ee1b306..0a71de73 100644
--- a/src/encoder.rs
+++ b/src/encoder.rs
@@ -1291,7 +1291,8 @@ pub fn encode_block_b(seq: &Sequence, fi: &FrameInvariants, fs: &mut FrameState,
                  luma_mode: PredictionMode, chroma_mode: PredictionMode,
                  ref_frame: usize, mv: MotionVector,
                  bsize: BlockSize, bo: &BlockOffset, skip: bool, bit_depth: usize,
-                 cfl: CFLParams, tx_size: TxSize, tx_type: TxType) {
+                 cfl: CFLParams, tx_size: TxSize, tx_type: TxType,
+                 mode_context: usize, mv_stack: &Vec<CandidateMV>) {
     let is_inter = !luma_mode.is_intra();
     if is_inter { assert!(luma_mode == chroma_mode); };
     let sb_size = if seq.use_128x128_superblock {
@@ -1315,8 +1316,6 @@ pub fn encode_block_b(seq: &Sequence, fi: &FrameInvariants, fs: &mut FrameState,
             cw.bc.set_motion_vector(bo, bsize, mv);
             cw.write_ref_frames(w, bo);
 
-            let mut mv_stack = Vec::new();
-            let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack);
             //let mode_context = if bo.x == 0 && bo.y == 0 { 0 } else if bo.x ==0 || bo.y == 0 { 51 } else { 85 };
             // NOTE: Until rav1e supports other inter modes than GLOBALMV
             cw.write_inter_mode(w, luma_mode, mode_context);
@@ -1632,6 +1631,9 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
         let mut cdef_coded = cw.bc.cdef_coded;
         rd_cost = mode_decision.rd_cost + cost;
 
+        let mut mv_stack = Vec::new();
+        let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
+
         let (tx_size, tx_type) =
           rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
 
@@ -1639,7 +1641,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
                                    bsize, bo, skip);
         encode_block_b(seq, fi, fs, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
                        mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl,
-                       tx_size, tx_type);
+                       tx_size, tx_type, mode_context, &mv_stack);
 
         best_decision = mode_decision;
     }
@@ -1696,6 +1698,9 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
             let skip = best_decision.skip;
             let mut cdef_coded = cw.bc.cdef_coded;
 
+            let mut mv_stack = Vec::new();
+            let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
+
             let (tx_size, tx_type) =
                 rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
 
@@ -1703,7 +1708,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
                                        bsize, bo, skip);
             encode_block_b(seq, fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
                           mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl,
-                          tx_size, tx_type);
+                          tx_size, tx_type, mode_context, &mv_stack);
         }
     }
 
@@ -1783,12 +1788,15 @@ fn encode_partition_topdown(seq: &Sequence, fi: &FrameInvariants, fs: &mut Frame
             let (tx_size, tx_type) =
                 rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
 
+            let mut mv_stack = Vec::new();
+            let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
+
             // FIXME: every final block that has gone through the RDO decision process is encoded twice
             cdef_coded = encode_block_a(seq, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
                          bsize, bo, skip);
             encode_block_b(seq, fi, fs, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
                           mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl,
-                          tx_size, tx_type);
+                          tx_size, tx_type, mode_context, &mv_stack);
         },
         PartitionType::PARTITION_SPLIT => {
             if rdo_output.part_modes.len() >= 4 {
diff --git a/src/predict.rs b/src/predict.rs
index 61694f78..1471271a 100755
--- a/src/predict.rs
+++ b/src/predict.rs
@@ -43,6 +43,7 @@ pub static RAV1E_INTRA_MODES_MINIMAL: &'static [PredictionMode] = &[
 
 pub static RAV1E_INTER_MODES: &'static [PredictionMode] = &[
   PredictionMode::GLOBALMV,
+  PredictionMode::NEARESTMV,
   PredictionMode::NEWMV,
 ];
 
diff --git a/src/rdo.rs b/src/rdo.rs
index 33afe77c..e109722f 100755
--- a/src/rdo.rs
+++ b/src/rdo.rs
@@ -264,6 +264,9 @@ pub fn rdo_mode_decision(
   }
   mode_set.extend_from_slice(intra_mode_set);
 
+  let mut mv_stack = Vec::new();
+  let mode_context = cw.find_mvrefs(bo, LAST_FRAME, &mut mv_stack, bsize, false);
+
   for &luma_mode in &mode_set {
     assert!(fi.frame_type == FrameType::INTER || luma_mode.is_intra());
 
@@ -278,10 +281,10 @@ pub fn rdo_mode_decision(
     }
 
     let ref_frame = if luma_mode.is_intra() { INTRA_FRAME } else { LAST_FRAME };
-    let mv = if luma_mode != PredictionMode::NEWMV {
-      MotionVector { row: 0, col: 0 }
-    } else {
-      motion_estimation(fi, fs, bsize, bo, ref_frame)
+    let mv = match luma_mode {
+      PredictionMode::NEWMV => motion_estimation(fi, fs, bsize, bo, ref_frame),
+      PredictionMode::NEARESTMV => if mv_stack.len() > 0 { mv_stack[0].this_mv } else { MotionVector { row: 0, col: 0 } },
+      _ => MotionVector { row: 0, col: 0 }
     };
 
     let (tx_size, tx_type) =
@@ -310,7 +313,7 @@ pub fn rdo_mode_decision(
 
         encode_block_a(seq, cw, wr, bsize, bo, skip);
         encode_block_b(seq, fi, fs, cw, wr, luma_mode, chroma_mode,
-          ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl, tx_size, tx_type);
+          ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl, tx_size, tx_type, mode_context, &mv_stack);
 
         let cost = wr.tell_frac() - tell;
         let rd = compute_rd_cost(
@@ -341,6 +344,7 @@ pub fn rdo_mode_decision(
   }
 
   cw.bc.set_mode(bo, bsize, best_mode_luma);
+  cw.bc.set_motion_vector(bo, bsize, best_mv);
 
   assert!(best_rd >= 0_f64);
 
-- 
GitLab