From 1e639ba3ecfcf8df4b0081c601e833c609f26eda Mon Sep 17 00:00:00 2001
From: David Michael Barr <b@rr-dav.id.au>
Date: Fri, 24 Aug 2018 08:44:22 +0900
Subject: [PATCH] Implement Chroma-from-Luma (#492)

* Add the chroma-from-luma predictor and a test
* Add benchmarks for CfL predictor
* Implement CfL signalling
* Plumb from predict_intra to pred_cfl
* Add struct for CfL parameters
* Compute subsampled luma AC and plumb it with CfL alpha
* Plumb CfL parameters up to RDO
* Add CfL to RDO loop, no parameter search
---
 benches/bench.rs               |   8 ++-
 benches/comparative/mod.rs     |   4 +-
 benches/comparative/predict.rs |  27 ++++++++-
 benches/predict.rs             |  13 +++++
 src/context.rs                 | 103 ++++++++++++++++++++++++++++++++-
 src/encoder.rs                 |  85 ++++++++++++++++++++++-----
 src/partition.rs               |  36 +++++++++---
 src/predict.rs                 |  79 +++++++++++++++++++++++++
 src/rdo.rs                     |  10 +++-
 9 files changed, 337 insertions(+), 28 deletions(-)

diff --git a/benches/bench.rs b/benches/bench.rs
index e5bac35a..62c9af9d 100755
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -69,6 +69,7 @@ fn write_b_bench(b: &mut Bencher, tx_size: TxSize, qindex: usize) {
 
   let sbx = 0;
   let sby = 0;
+  let ac = &[0i16; 32 * 32];
 
   b.iter(|| {
     for &mode in RAV1E_INTRA_MODES {
@@ -96,7 +97,9 @@ fn write_b_bench(b: &mut Bencher, tx_size: TxSize, qindex: usize) {
               tx_size.block_size(),
               &po,
               false,
-              8
+              8,
+              ac,
+              0
             );
           }
         }
@@ -113,7 +116,8 @@ benchmark_group!(
   predict::intra_paeth_4x4,
   predict::intra_smooth_4x4,
   predict::intra_smooth_h_4x4,
-  predict::intra_smooth_v_4x4
+  predict::intra_smooth_v_4x4,
+  predict::intra_cfl_4x4
 );
 
 #[cfg(feature = "comparative_bench")]
diff --git a/benches/comparative/mod.rs b/benches/comparative/mod.rs
index 691e4ddb..23d7ebfb 100755
--- a/benches/comparative/mod.rs
+++ b/benches/comparative/mod.rs
@@ -26,5 +26,7 @@ benchmark_group!(
   predict::intra_smooth_h_4x4_native,
   predict::intra_smooth_h_4x4_aom,
   predict::intra_smooth_v_4x4_native,
-  predict::intra_smooth_v_4x4_aom
+  predict::intra_smooth_v_4x4_aom,
+  predict::intra_cfl_4x4_native,
+  predict::intra_cfl_4x4_aom
 );
diff --git a/benches/comparative/predict.rs b/benches/comparative/predict.rs
index f2d76589..a07e03a3 100755
--- a/benches/comparative/predict.rs
+++ b/benches/comparative/predict.rs
@@ -11,7 +11,7 @@ use bencher::*;
 use comparative::libc;
 use predict as predict_native;
 use predict::*;
-use rand::{ChaChaRng, SeedableRng};
+use rand::{ChaChaRng, Rng, SeedableRng};
 
 extern {
   fn highbd_dc_predictor(
@@ -48,6 +48,11 @@ extern {
     dst: *mut u16, stride: libc::ptrdiff_t, bw: libc::c_int, bh: libc::c_int,
     above: *const u16, left: *const u16, bd: libc::c_int
   );
+
+  fn cfl_predict_hbd_c(
+    ac_buf_q3: *const i16, dst: *mut u16, stride: libc::ptrdiff_t,
+    alpha_q3: libc::c_int, bd: libc::c_int, bw: libc::c_int, bh: libc::c_int
+  );
 }
 
 fn predict_intra_4x4_aom(
@@ -122,3 +127,23 @@ pub fn intra_smooth_v_4x4_native(b: &mut Bencher) {
 pub fn intra_smooth_v_4x4_aom(b: &mut Bencher) {
   predict_intra_4x4_aom(b, highbd_smooth_v_predictor);
 }
+
+pub fn intra_cfl_4x4_native(b: &mut Bencher) {
+  predict_native::intra_cfl_4x4(b);
+}
+
+pub fn intra_cfl_4x4_aom(b: &mut Bencher) {
+  let mut rng = ChaChaRng::from_seed([0; 32]);
+  let (mut block, _above_context, _left_context) = generate_block(&mut rng);
+  let ac: Vec<i16> = (0..(32 * 32)).map(|_| rng.gen()).collect();
+  let alpha = -1 as i16;
+
+  b.iter(|| {
+    for _ in 0..MAX_ITER {
+      unsafe {
+        cfl_predict_hbd_c(ac.as_ptr(), block.as_mut_ptr(),
+          BLOCK_SIZE.width() as libc::ptrdiff_t, alpha as libc::c_int, 8, 4, 4);
+      }
+    }
+  })
+}
diff --git a/benches/predict.rs b/benches/predict.rs
index 74c7c95f..8e030911 100755
--- a/benches/predict.rs
+++ b/benches/predict.rs
@@ -100,3 +100,16 @@ pub fn intra_smooth_v_4x4(b: &mut Bencher) {
     }
   })
 }
+
+pub fn intra_cfl_4x4(b: &mut Bencher) {
+  let mut rng = ChaChaRng::from_seed([0; 32]);
+  let (mut block, _above, _left) = generate_block(&mut rng);
+  let ac: Vec<i16> = (0..(32 * 32)).map(|_| rng.gen()).collect();
+  let alpha = -1 as i16;
+
+  b.iter(|| {
+    for _ in 0..MAX_ITER {
+      Block4x4::pred_cfl(&mut block, BLOCK_SIZE.width(), &ac, alpha, 8);
+    }
+  })
+}
diff --git a/src/context.rs b/src/context.rs
index 061c8d25..b86ff5d4 100755
--- a/src/context.rs
+++ b/src/context.rs
@@ -53,6 +53,10 @@ const MAX_TX_SQUARE: usize = MAX_TX_SIZE * MAX_TX_SIZE;
 pub const INTRA_MODES: usize = 13;
 const UV_INTRA_MODES: usize = 14;
 
+const CFL_JOINT_SIGNS: usize = 8;
+const CFL_ALPHA_CONTEXTS: usize = 6;
+const CFL_ALPHABET_SIZE: usize = 16;
+
 const BLOCK_SIZE_GROUPS: usize = 4;
 const MAX_ANGLE_DELTA: usize = 3;
 const DIRECTIONAL_MODES: usize = 8;
@@ -782,7 +786,6 @@ pub struct NMVContext {
   comps: [NMVComponent; 2],
 }
 
-
 extern "C" {
   static default_partition_cdf:
     [[u16; EXT_PARTITION_TYPES + 1]; PARTITION_CONTEXTS];
@@ -790,6 +793,8 @@ extern "C" {
     [[[u16; INTRA_MODES + 1]; KF_MODE_CONTEXTS]; KF_MODE_CONTEXTS];
   static default_if_y_mode_cdf: [[u16; INTRA_MODES + 1]; BLOCK_SIZE_GROUPS];
   static default_uv_mode_cdf: [[[u16; UV_INTRA_MODES + 1]; INTRA_MODES]; 2];
+  static default_cfl_sign_cdf: [u16; CFL_JOINT_SIGNS + 1];
+  static default_cfl_alpha_cdf: [[u16; CFL_ALPHABET_SIZE + 1]; CFL_ALPHA_CONTEXTS];
   static default_newmv_cdf: [[u16; 2 + 1]; NEWMV_MODE_CONTEXTS];
   static default_zeromv_cdf: [[u16; 2 + 1]; GLOBALMV_MODE_CONTEXTS];
   static default_refmv_cdf: [[u16; 2 + 1]; REFMV_MODE_CONTEXTS];
@@ -855,6 +860,8 @@ pub struct CDFContext {
   kf_y_cdf: [[[u16; INTRA_MODES + 1]; KF_MODE_CONTEXTS]; KF_MODE_CONTEXTS],
   y_mode_cdf: [[u16; INTRA_MODES + 1]; BLOCK_SIZE_GROUPS],
   uv_mode_cdf: [[[u16; UV_INTRA_MODES + 1]; INTRA_MODES]; 2],
+  cfl_sign_cdf: [u16; CFL_JOINT_SIGNS + 1],
+  cfl_alpha_cdf: [[u16; CFL_ALPHABET_SIZE + 1]; CFL_ALPHA_CONTEXTS],
   newmv_cdf: [[u16; 2 + 1]; NEWMV_MODE_CONTEXTS],
   zeromv_cdf: [[u16; 2 + 1]; GLOBALMV_MODE_CONTEXTS],
   refmv_cdf: [[u16; 2 + 1]; REFMV_MODE_CONTEXTS],
@@ -904,6 +911,8 @@ impl CDFContext {
       kf_y_cdf: default_kf_y_mode_cdf,
       y_mode_cdf: default_if_y_mode_cdf,
       uv_mode_cdf: default_uv_mode_cdf,
+      cfl_sign_cdf: default_cfl_sign_cdf,
+      cfl_alpha_cdf: default_cfl_alpha_cdf,
       newmv_cdf: default_newmv_cdf,
       zeromv_cdf: default_zeromv_cdf,
       refmv_cdf: default_refmv_cdf,
@@ -950,6 +959,12 @@ impl CDFContext {
     let uv_mode_cdf_start =
       self.uv_mode_cdf.first().unwrap().as_ptr() as usize;
     let uv_mode_cdf_end = uv_mode_cdf_start + size_of_val(&self.uv_mode_cdf);
+    let cfl_sign_cdf_start = self.cfl_sign_cdf.as_ptr() as usize;
+    let cfl_sign_cdf_end = cfl_sign_cdf_start + size_of_val(&self.cfl_sign_cdf);
+    let cfl_alpha_cdf_start =
+      self.cfl_alpha_cdf.first().unwrap().as_ptr() as usize;
+    let cfl_alpha_cdf_end =
+      cfl_alpha_cdf_start + size_of_val(&self.cfl_alpha_cdf);
     let intra_tx_cdf_start =
       self.intra_tx_cdf.first().unwrap().as_ptr() as usize;
     let intra_tx_cdf_end =
@@ -1029,6 +1044,8 @@ impl CDFContext {
       ("kf_y_cdf", kf_y_cdf_start, kf_y_cdf_end),
       ("y_mode_cdf", y_mode_cdf_start, y_mode_cdf_end),
       ("uv_mode_cdf", uv_mode_cdf_start, uv_mode_cdf_end),
+      ("cfl_sign_cdf", cfl_sign_cdf_start, cfl_sign_cdf_end),
+      ("cfl_alpha_cdf", cfl_alpha_cdf_start, cfl_alpha_cdf_end),
       ("intra_tx_cdf", intra_tx_cdf_start, intra_tx_cdf_end),
       ("inter_tx_cdf", inter_tx_cdf_start, inter_tx_cdf_end),
       ("skip_cdfs", skip_cdfs_start, skip_cdfs_end),
@@ -1065,6 +1082,41 @@ mod test {
     let f = &cdf.partition_cdf[2];
     cdf_map.lookup(f.as_ptr() as usize);
   }
+
+  use super::CFLSign;
+  use super::CFLSign::*;
+
+  static cfl_alpha_signs: [[CFLSign; 2]; 8] = [
+    [ CFL_SIGN_ZERO, CFL_SIGN_NEG ],
+    [ CFL_SIGN_ZERO, CFL_SIGN_POS ],
+    [ CFL_SIGN_NEG, CFL_SIGN_ZERO ],
+    [ CFL_SIGN_NEG, CFL_SIGN_NEG ],
+    [ CFL_SIGN_NEG, CFL_SIGN_POS ],
+    [ CFL_SIGN_POS, CFL_SIGN_ZERO ],
+    [ CFL_SIGN_POS, CFL_SIGN_NEG ],
+    [ CFL_SIGN_POS, CFL_SIGN_POS ]
+  ];
+
+  static cfl_context: [[usize; 8]; 2] = [
+    [ 0, 0, 0, 1, 2, 3, 4, 5 ],
+    [ 0, 3, 0, 1, 4, 0, 2, 5 ]
+  ];
+
+  #[test]
+  fn cfl_joint_sign() {
+    use super::*;
+
+    let cfl = &mut CFLParams::new();
+    for (joint_sign, &signs) in cfl_alpha_signs.iter().enumerate() {
+      cfl.sign = signs;
+      assert!(cfl.joint_sign() as usize == joint_sign);
+      for uv in 0..2 {
+        if signs[uv] != CFL_SIGN_ZERO {
+          assert!(cfl.context(uv) == cfl_context[uv][joint_sign]);
+        }
+      }
+    }
+  }
 }
 
 const SUPERBLOCK_TO_PLANE_SHIFT: usize = MAX_SB_SIZE_LOG2;
@@ -1613,6 +1665,47 @@ impl BlockContext {
   }
 }
 
+#[derive(Copy, Clone, PartialEq)]
+pub enum CFLSign {
+  CFL_SIGN_ZERO = 0,
+  CFL_SIGN_NEG = 1,
+  CFL_SIGN_POS = 2
+}
+
+use context::CFLSign::*;
+const CFL_SIGNS: usize = 3;
+static cfl_sign_value: [i32; CFL_SIGNS] = [ 0, -1, 1 ];
+
+#[derive(Copy, Clone)]
+pub struct CFLParams {
+  sign: [CFLSign; 2],
+  scale: [u8; 2]
+}
+
+impl CFLParams {
+  pub fn new() -> CFLParams {
+    CFLParams {
+      sign: [CFL_SIGN_NEG, CFL_SIGN_ZERO],
+      scale: [1, 0]
+    }
+  }
+  pub fn joint_sign(&self) -> u32 {
+    assert!(self.sign[0] != CFL_SIGN_ZERO || self.sign[1] != CFL_SIGN_ZERO);
+    (self.sign[0] as u32) * (CFL_SIGNS as u32) + (self.sign[1] as u32) - 1
+  }
+  pub fn context(&self, uv: usize) -> usize {
+    assert!(self.sign[uv] != CFL_SIGN_ZERO);
+    (self.sign[uv] as usize - 1) * CFL_SIGNS + (self.sign[1 - uv] as usize)
+  }
+  pub fn index(&self, uv: usize) -> u32 {
+    assert!(self.sign[uv] != CFL_SIGN_ZERO && self.scale[uv] != 0);
+    (self.scale[uv] - 1) as u32
+  }
+  pub fn alpha(&self, uv: usize) -> i32 {
+    cfl_sign_value[self.sign[uv] as usize] * (self.scale[uv] as i32)
+  }
+}
+
 #[derive(Debug, Default)]
 struct FieldMap {
   map: Vec<(&'static str, usize, usize)>
@@ -1820,6 +1913,14 @@ impl ContextWriter {
       symbol_with_update!(self, w, uv_mode as u32, &mut cdf[..UV_INTRA_MODES]);
     }
   }
+  pub fn write_cfl_alphas(&mut self, w: &mut dyn Writer, cfl: &CFLParams) {
+    symbol_with_update!(self, w, cfl.joint_sign(), &mut self.fc.cfl_sign_cdf);
+    for uv in 0..2 {
+      if cfl.sign[uv] != CFL_SIGN_ZERO {
+        symbol_with_update!(self, w, cfl.index(uv), &mut self.fc.cfl_alpha_cdf[cfl.context(uv)]);
+      }
+    }
+  }
   pub fn write_angle_delta(&mut self, w: &mut dyn Writer, angle: i8, mode: PredictionMode) {
     symbol_with_update!(
       self,
diff --git a/src/encoder.rs b/src/encoder.rs
index f6c248a7..8c6345c8 100644
--- a/src/encoder.rs
+++ b/src/encoder.rs
@@ -1148,14 +1148,17 @@ fn diff(dst: &mut [i16], src1: &PlaneSlice<'_>, src2: &PlaneSlice<'_>, width: us
 // For a transform block,
 // predict, transform, quantize, write coefficients to a bitstream,
 // dequantize, inverse-transform.
-pub fn encode_tx_block(fi: &FrameInvariants, fs: &mut FrameState, cw: &mut ContextWriter, w: &mut dyn Writer,
-                  p: usize, bo: &BlockOffset, mode: PredictionMode, tx_size: TxSize, tx_type: TxType,
-                  plane_bsize: BlockSize, po: &PlaneOffset, skip: bool, bit_depth: usize) -> bool {
+pub fn encode_tx_block(
+  fi: &FrameInvariants, fs: &mut FrameState, cw: &mut ContextWriter,
+  w: &mut dyn Writer, p: usize, bo: &BlockOffset, mode: PredictionMode,
+  tx_size: TxSize, tx_type: TxType, plane_bsize: BlockSize, po: &PlaneOffset,
+  skip: bool, bit_depth: usize, ac: &[i16], alpha: i16
+) -> bool {
     let rec = &mut fs.rec.planes[p];
     let PlaneConfig { stride, xdec, ydec, .. } = fs.input.planes[p].cfg;
 
     if mode.is_intra() {
-      mode.predict_intra(&mut rec.mut_slice(po), tx_size, bit_depth);
+      mode.predict_intra(&mut rec.mut_slice(po), tx_size, bit_depth, &ac, alpha);
     }
 
     if skip { return false; }
@@ -1199,7 +1202,8 @@ pub fn encode_block_b(fi: &FrameInvariants, fs: &mut FrameState,
                  cw: &mut ContextWriter, w: &mut dyn Writer,
                  luma_mode: PredictionMode, chroma_mode: PredictionMode,
                  ref_frame: usize, mv: MotionVector,
-                 bsize: BlockSize, bo: &BlockOffset, skip: bool, bit_depth: usize) {
+                 bsize: BlockSize, bo: &BlockOffset, skip: bool, bit_depth: usize,
+                 cfl: &CFLParams) {
     let is_inter = !luma_mode.is_intra();
     if is_inter { assert!(luma_mode == chroma_mode); };
 
@@ -1264,6 +1268,10 @@ pub fn encode_block_b(fi: &FrameInvariants, fs: &mut FrameState,
 
     if has_chroma(bo, bsize, xdec, ydec) && !is_inter {
         cw.write_intra_uv_mode(w, chroma_mode, luma_mode, bsize);
+        if chroma_mode.is_cfl() {
+          assert!(bsize.cfl_allowed());
+          cw.write_cfl_alphas(w, cfl);
+        }
         if chroma_mode.is_directional() && bsize >= BlockSize::BLOCK_8X8 {
             cw.write_angle_delta(w, 0, chroma_mode);
         }
@@ -1344,18 +1352,52 @@ pub fn encode_block_b(fi: &FrameInvariants, fs: &mut FrameState,
       }
       write_tx_tree(fi, fs, cw, w, luma_mode, bo, bsize, tx_size, tx_type, skip, bit_depth); // i.e. var-tx if inter mode
     } else {
-      write_tx_blocks(fi, fs, cw, w, luma_mode, chroma_mode, bo, bsize, tx_size, tx_type, skip, bit_depth);
+      write_tx_blocks(fi, fs, cw, w, luma_mode, chroma_mode, bo, bsize, tx_size, tx_type, skip, bit_depth, cfl);
+    }
+}
+
+fn luma_ac(
+  ac: &mut [i16], fs: &mut FrameState, bo: &BlockOffset, bsize: BlockSize
+) {
+  let PlaneConfig { xdec, ydec, .. } = fs.input.planes[1].cfg;
+  let plane_bsize = get_plane_block_size(bsize, xdec, ydec);
+  let po = bo.plane_offset(&fs.input.planes[0].cfg);
+  let rec = &fs.rec.planes[0];
+  let luma = &rec.slice(&po);
+
+  let mut sum: i32 = 0;
+  for sub_y in 0..plane_bsize.height() {
+    for sub_x in 0..plane_bsize.width() {
+      let y = sub_y << ydec;
+      let x = sub_x << xdec;
+      let sample = ((luma.p(x, y)
+        + luma.p(x + 1, y)
+        + luma.p(x, y + 1)
+        + luma.p(x + 1, y + 1))
+        << 1) as i16;
+      ac[sub_y * 32 + sub_x] = sample;
+      sum += sample as i32;
+    }
+  }
+  let shift = plane_bsize.width_log2() + plane_bsize.height_log2();
+  let average = ((sum + (1 << (shift - 1))) >> shift) as i16;
+  for sub_y in 0..plane_bsize.height() {
+    for sub_x in 0..plane_bsize.width() {
+      ac[sub_y * 32 + sub_x] -= average;
     }
+  }
 }
 
 pub fn write_tx_blocks(fi: &FrameInvariants, fs: &mut FrameState,
                        cw: &mut ContextWriter, w: &mut dyn Writer,
                        luma_mode: PredictionMode, chroma_mode: PredictionMode, bo: &BlockOffset,
-                       bsize: BlockSize, tx_size: TxSize, tx_type: TxType, skip: bool, bit_depth: usize) {
+                       bsize: BlockSize, tx_size: TxSize, tx_type: TxType, skip: bool, bit_depth: usize,
+                       cfl: &CFLParams) {
     let bw = bsize.width_mi() / tx_size.width_mi();
     let bh = bsize.height_mi() / tx_size.height_mi();
 
     let PlaneConfig { xdec, ydec, .. } = fs.input.planes[1].cfg;
+    let ac = &mut [0i16; 32 * 32];
 
     fs.qc.update(fi.config.quantizer, tx_size, luma_mode.is_intra(), bit_depth);
 
@@ -1367,7 +1409,10 @@ pub fn write_tx_blocks(fi: &FrameInvariants, fs: &mut FrameState,
             };
 
             let po = tx_bo.plane_offset(&fs.input.planes[0].cfg);
-            encode_tx_block(fi, fs, cw, w, 0, &tx_bo, luma_mode, tx_size, tx_type, bsize, &po, skip, bit_depth);
+            encode_tx_block(
+              fi, fs, cw, w, 0, &tx_bo, luma_mode, tx_size, tx_type, bsize, &po,
+              skip, bit_depth, ac, 0,
+            );
         }
     }
 
@@ -1392,11 +1437,16 @@ pub fn write_tx_blocks(fi: &FrameInvariants, fs: &mut FrameState,
 
     let plane_bsize = get_plane_block_size(bsize, xdec, ydec);
 
+    if chroma_mode.is_cfl() {
+      luma_ac(ac, fs, bo, bsize);
+    }
+
     if bw_uv > 0 && bh_uv > 0 {
         let uv_tx_type = uv_intra_mode_to_tx_type_context(chroma_mode);
         fs.qc.update(fi.config.quantizer, uv_tx_size, true, bit_depth);
 
         for p in 1..3 {
+            let alpha = cfl.alpha(p - 1) as i16;
             for by in 0..bh_uv {
                 for bx in 0..bw_uv {
                     let tx_bo =
@@ -1412,7 +1462,7 @@ pub fn write_tx_blocks(fi: &FrameInvariants, fs: &mut FrameState,
                     po.y += by * uv_tx_size.height();
 
                     encode_tx_block(fi, fs, cw, w, p, &tx_bo, chroma_mode, uv_tx_size, uv_tx_type,
-                                    plane_bsize, &po, skip, bit_depth);
+                                    plane_bsize, &po, skip, bit_depth, ac, alpha);
                 }
             }
         }
@@ -1428,11 +1478,15 @@ pub fn write_tx_tree(fi: &FrameInvariants, fs: &mut FrameState, cw: &mut Context
     let bh = bsize.height_mi() / tx_size.height_mi();
 
     let PlaneConfig { xdec, ydec, .. } = fs.input.planes[1].cfg;
+    let ac = &[0i16; 32 * 32];
 
     fs.qc.update(fi.config.quantizer, tx_size, luma_mode.is_intra(), bit_depth);
 
     let po = bo.plane_offset(&fs.input.planes[0].cfg);
-    let has_coeff = encode_tx_block(fi, fs, cw, w, 0, &bo, luma_mode, tx_size, tx_type, bsize, &po, skip, bit_depth);
+    let has_coeff = encode_tx_block(
+      fi, fs, cw, w, 0, &bo, luma_mode, tx_size, tx_type, bsize, &po, skip,
+      bit_depth, ac, 0,
+    );
 
     // these are only valid for 4:2:0
     let uv_tx_size = match bsize {
@@ -1469,7 +1523,7 @@ pub fn write_tx_tree(fi: &FrameInvariants, fs: &mut FrameState, cw: &mut Context
             let po = bo.plane_offset(&fs.input.planes[p].cfg);
 
             encode_tx_block(fi, fs, cw, w, p, &tx_bo, luma_mode, uv_tx_size, uv_tx_type,
-                            plane_bsize, &po, skip, bit_depth);
+                            plane_bsize, &po, skip, bit_depth, ac, 0);
         }
     }
 }
@@ -1521,6 +1575,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
         }
         let mode_decision = rdo_mode_decision(seq, fi, fs, cw, bsize, bo).part_modes[0].clone();
         let (mode_luma, mode_chroma) = (mode_decision.pred_mode_luma, mode_decision.pred_mode_chroma);
+        let cfl = &CFLParams::new();
         let ref_frame = mode_decision.ref_frame;
         let mv = mode_decision.mv;
         let skip = mode_decision.skip;
@@ -1530,7 +1585,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
         cdef_coded = encode_block_a(seq, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
                                    bsize, bo, skip);
         encode_block_b(fi, fs, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
-                       mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth);
+                       mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl);
 
         best_decision = mode_decision;
     }
@@ -1575,6 +1630,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
 
             // FIXME: redundant block re-encode
             let (mode_luma, mode_chroma) = (best_decision.pred_mode_luma, best_decision.pred_mode_chroma);
+            let cfl = &CFLParams::new();
             let ref_frame = best_decision.ref_frame;
             let mv = best_decision.mv;
             let skip = best_decision.skip;
@@ -1582,7 +1638,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
             cdef_coded = encode_block_a(seq, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
                                        bsize, bo, skip);
             encode_block_b(fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
-                          mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth);
+                          mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl);
         }
     }
 
@@ -1653,6 +1709,7 @@ fn encode_partition_topdown(seq: &Sequence, fi: &FrameInvariants, fs: &mut Frame
                 };
 
             let (mode_luma, mode_chroma) = (part_decision.pred_mode_luma, part_decision.pred_mode_chroma);
+            let cfl = &CFLParams::new();
             let skip = part_decision.skip;
             let ref_frame = part_decision.ref_frame;
             let mv = part_decision.mv;
@@ -1662,7 +1719,7 @@ fn encode_partition_topdown(seq: &Sequence, fi: &FrameInvariants, fs: &mut Frame
             cdef_coded = encode_block_a(seq, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
                          bsize, bo, skip);
             encode_block_b(fi, fs, cw, if cdef_coded  {w_post_cdef} else {w_pre_cdef},
-                          mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth);
+                          mode_luma, mode_chroma, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl);
         },
         PartitionType::PARTITION_SPLIT => {
             if rdo_output.part_modes.len() >= 4 {
diff --git a/src/partition.rs b/src/partition.rs
index 27d174ce..211ec7b1 100755
--- a/src/partition.rs
+++ b/src/partition.rs
@@ -144,6 +144,10 @@ impl BlockSize {
   pub fn is_sqr(self) -> bool {
     self.width_log2() == self.height_log2()
   }
+
+  pub fn is_sub8x8(self) -> bool {
+    self.width_log2().min(self.height_log2()) < 3
+  }
 }
 
 /// Transform Size
@@ -334,6 +338,7 @@ pub enum PredictionMode {
   SMOOTH_V_PRED,
   SMOOTH_H_PRED,
   PAETH_PRED,
+  UV_CFL_PRED,
   NEARESTMV,
   NEARMV,
   GLOBALMV,
@@ -409,20 +414,30 @@ use plane::*;
 use predict::*;
 
 impl PredictionMode {
-  pub fn predict_intra<'a>(self, dst: &'a mut PlaneMutSlice<'a>, tx_size: TxSize, bit_depth: usize) {
+  pub fn predict_intra<'a>(
+    self, dst: &'a mut PlaneMutSlice<'a>, tx_size: TxSize, bit_depth: usize,
+    ac: &[i16], alpha: i16
+  ) {
     assert!(self.is_intra());
 
     match tx_size {
-      TxSize::TX_4X4 => self.predict_intra_inner::<Block4x4>(dst, bit_depth),
-      TxSize::TX_8X8 => self.predict_intra_inner::<Block8x8>(dst, bit_depth),
-      TxSize::TX_16X16 => self.predict_intra_inner::<Block16x16>(dst, bit_depth),
-      TxSize::TX_32X32 => self.predict_intra_inner::<Block32x32>(dst, bit_depth),
+      TxSize::TX_4X4 =>
+        self.predict_intra_inner::<Block4x4>(dst, bit_depth, ac, alpha),
+      TxSize::TX_8X8 =>
+        self.predict_intra_inner::<Block8x8>(dst, bit_depth, ac, alpha),
+      TxSize::TX_16X16 =>
+        self.predict_intra_inner::<Block16x16>(dst, bit_depth, ac, alpha),
+      TxSize::TX_32X32 =>
+        self.predict_intra_inner::<Block32x32>(dst, bit_depth, ac, alpha),
       _ => unimplemented!()
     }
   }
 
   #[inline(always)]
-  fn predict_intra_inner<'a, B: Intra>(self, dst: &'a mut PlaneMutSlice<'a>, bit_depth: usize) {
+  fn predict_intra_inner<'a, B: Intra>(
+    self, dst: &'a mut PlaneMutSlice<'a>, bit_depth: usize,
+    ac: &[i16], alpha: i16
+  ) {
     // above and left arrays include above-left sample
     // above array includes above-right samples
     // left array includes below-left samples
@@ -499,7 +514,7 @@ impl PredictionMode {
     let left_slice = &left[1..B::H + 1];
 
     match self {
-      PredictionMode::DC_PRED => match (x, y) {
+      PredictionMode::DC_PRED | PredictionMode::UV_CFL_PRED => match (x, y) {
         (0, 0) => B::pred_dc_128(slice, stride, bit_depth),
         (_, 0) => B::pred_dc_left(slice, stride, above_slice, left_slice, bit_depth),
         (0, _) => B::pred_dc_top(slice, stride, above_slice, left_slice, bit_depth),
@@ -525,12 +540,19 @@ impl PredictionMode {
         B::pred_smooth_v(slice, stride, above_slice, left_slice),
       _ => unimplemented!()
     }
+    if self == PredictionMode::UV_CFL_PRED {
+      B::pred_cfl(slice, stride, &ac, alpha, bit_depth);
+    }
   }
 
   pub fn is_intra(self) -> bool {
     return self < PredictionMode::NEARESTMV;
   }
 
+  pub fn is_cfl(self) -> bool {
+    self == PredictionMode::UV_CFL_PRED
+  }
+
   pub fn is_directional(self) -> bool {
     self >= PredictionMode::V_PRED && self <= PredictionMode::D63_PRED
   }
diff --git a/src/predict.rs b/src/predict.rs
index 7ae0211e..85633f8a 100755
--- a/src/predict.rs
+++ b/src/predict.rs
@@ -147,6 +147,12 @@ extern {
     dst: *mut u16, stride: libc::ptrdiff_t, bw: libc::c_int, bh: libc::c_int,
     above: *const u16, left: *const u16, bd: libc::c_int
   );
+
+  #[cfg(test)]
+  fn cfl_predict_hbd_c(
+    ac_buf_q3: *const i16, dst: *mut u16, stride: libc::ptrdiff_t,
+    alpha_q3: libc::c_int, bd: libc::c_int, bw: libc::c_int, bh: libc::c_int
+  );
 }
 
 pub trait Dim {
@@ -182,6 +188,17 @@ impl Dim for Block32x32 {
   const H: usize = 32;
 }
 
+#[inline(always)]
+fn get_scaled_luma_q0(alpha_q3: i16, ac_pred_q3: i16) -> i32 {
+  let scaled_luma_q6 = (alpha_q3 as i32) * (ac_pred_q3 as i32);
+  let abs_scaled_luma_q0 = (scaled_luma_q6.abs() + 32) >> 6;
+  if scaled_luma_q6 < 0 {
+    -abs_scaled_luma_q0
+  } else {
+    abs_scaled_luma_q0
+  }
+}
+
 pub trait Intra: Dim {
   #[cfg_attr(feature = "comparative_bench", inline(never))]
   fn pred_dc(output: &mut [u16], stride: usize, above: &[u16], left: &[u16]) {
@@ -408,6 +425,24 @@ pub trait Intra: Dim {
       }
     }
   }
+
+  #[cfg_attr(feature = "comparative_bench", inline(never))]
+  fn pred_cfl(
+    output: &mut [u16], stride: usize, ac: &[i16], alpha: i16,
+    bit_depth: usize
+  ) {
+    let sample_max = (1 << bit_depth) - 1;
+    let avg = output[0] as i32;
+
+    for (line, luma) in
+      output.chunks_mut(stride).zip(ac.chunks(32)).take(Self::H)
+    {
+      for (v, &l) in line[..Self::W].iter_mut().zip(luma[..Self::W].iter()) {
+        *v =
+          (avg + get_scaled_luma_q0(alpha, l)).max(0).min(sample_max) as u16;
+      }
+    }
+  }
 }
 
 pub trait Inter: Dim {
@@ -551,6 +586,22 @@ pub mod test {
     }
   }
 
+  pub fn pred_cfl_4x4(
+    output: &mut [u16], stride: usize, ac: &[i16], alpha: i16, bd: i32
+  ) {
+    unsafe {
+      cfl_predict_hbd_c(
+        ac.as_ptr(),
+        output.as_mut_ptr(),
+        stride as libc::ptrdiff_t,
+        alpha as libc::c_int,
+        bd,
+        4,
+        4
+      );
+    }
+  }
+
   fn do_dc_pred(ra: &mut ChaChaRng) -> (Vec<u16>, Vec<u16>) {
     let (above, left, mut o1, mut o2) = setup_pred(ra);
 
@@ -615,6 +666,31 @@ pub mod test {
     (o1, o2)
   }
 
+  fn setup_cfl_pred(
+    ra: &mut ChaChaRng
+  ) -> (Vec<u16>, Vec<u16>, Vec<i16>, i16, Vec<u16>, Vec<u16>) {
+    let o1 = vec![0u16; 32 * 32];
+    let o2 = vec![0u16; 32 * 32];
+    let above: Vec<u16> = (0..32).map(|_| ra.gen()).collect();
+    let left: Vec<u16> = (0..32).map(|_| ra.gen()).collect();
+    let ac: Vec<i16> = (0..(32 * 32)).map(|_| ra.gen()).collect();
+    let alpha = -1 as i16;
+
+    (above, left, ac, alpha, o1, o2)
+  }
+
+  fn do_cfl_pred(ra: &mut ChaChaRng) -> (Vec<u16>, Vec<u16>) {
+    let (above, left, ac, alpha, mut o1, mut o2) = setup_cfl_pred(ra);
+
+    pred_dc_4x4(&mut o1, 32, &above[..4], &left[..4]);
+    Block4x4::pred_dc(&mut o2, 32, &above[..4], &left[..4]);
+
+    pred_cfl_4x4(&mut o1, 32, &ac, alpha, 8);
+    Block4x4::pred_cfl(&mut o2, 32, &ac, alpha, 8);
+
+    (o1, o2)
+  }
+
   fn assert_same(o2: Vec<u16>) {
     for l in o2.chunks(32).take(4) {
       for v in l[..4].windows(2) {
@@ -647,6 +723,9 @@ pub mod test {
 
       let (o1, o2) = do_smooth_v_pred(&mut ra);
       assert_eq!(o1, o2);
+
+      let (o1, o2) = do_cfl_pred(&mut ra);
+      assert_eq!(o1, o2);
     }
   }
 
diff --git a/src/rdo.rs b/src/rdo.rs
index e0fbb005..c0ff4530 100755
--- a/src/rdo.rs
+++ b/src/rdo.rs
@@ -228,6 +228,10 @@ pub fn rdo_mode_decision(
       mode_set_chroma.push(PredictionMode::DC_PRED);
     }
 
+    if is_chroma_block && luma_mode.is_intra() && bsize.cfl_allowed() && !bsize.is_sub8x8() {
+      mode_set_chroma.push(PredictionMode::UV_CFL_PRED);
+    }
+
     let ref_frame = if luma_mode.is_intra() { INTRA_FRAME } else { LAST_FRAME };
     let mv = if luma_mode != PredictionMode::NEWMV {
       MotionVector { row: 0, col: 0 }
@@ -236,6 +240,7 @@ pub fn rdo_mode_decision(
     };
 
     // Find the best chroma prediction mode for the current luma prediction mode
+    let cfl = &CFLParams::new();
     for &chroma_mode in &mode_set_chroma {
       for &skip in &[false, true] {
         // Don't skip when using intra modes
@@ -246,7 +251,7 @@ pub fn rdo_mode_decision(
 
 
         encode_block_a(seq, cw, wr, bsize, bo, skip);
-        encode_block_b(fi, fs, cw, wr, luma_mode, chroma_mode, ref_frame, mv, bsize, bo, skip, seq.bit_depth);
+        encode_block_b(fi, fs, cw, wr, luma_mode, chroma_mode, ref_frame, mv, bsize, bo, skip, seq.bit_depth, cfl);
 
         let cost = wr.tell_frac() - tell;
         let rd = compute_rd_cost(
@@ -326,8 +331,9 @@ pub fn rdo_tx_type_decision(
         fi, fs, cw, wr, mode, bo, bsize, tx_size, tx_type, false, bit_depth
       );
     }  else {
+      let cfl = &CFLParams::new();
       write_tx_blocks(
-        fi, fs, cw, wr, mode, mode, bo, bsize, tx_size, tx_type, false, bit_depth
+        fi, fs, cw, wr, mode, mode, bo, bsize, tx_size, tx_type, false, bit_depth, cfl
       );
     }
 
-- 
GitLab