Commit 4d226d8d authored by fbossen's avatar fbossen Committed by Yushin Cho

Add rate of partition type in RDO decision (#799)

parent 68cb35b0
......@@ -483,6 +483,10 @@ pub fn rdo_mode_decision(
if skip { tx_type = TxType::DCT_DCT; };
if bsize >= BlockSize::BLOCK_8X8 && bsize.is_sqr() {
cw.write_partition(wr, bo, PartitionType::PARTITION_NONE, bsize);
}
encode_block_a(seq, fs, cw, wr, bsize, bo, skip);
let tx_dist =
encode_block_b(
......@@ -834,6 +838,8 @@ pub fn rdo_partition_decision(
let mut best_pred_modes = cached_block.part_modes.clone();
let cw_checkpoint = cw.checkpoint();
let w_pre_checkpoint = w_pre_cdef.checkpoint();
let w_post_checkpoint = w_post_cdef.checkpoint();
for &partition in RAV1E_PARTITION_TYPES {
// Do not re-encode results we already have
......@@ -841,6 +847,8 @@ pub fn rdo_partition_decision(
continue;
}
let mut cost: f64 = 0.0;
let mut rd: f64;
let mut child_modes = std::vec::Vec::new();
......@@ -867,6 +875,14 @@ pub fn rdo_partition_decision(
if subsize == BlockSize::BLOCK_INVALID {
continue;
}
if bsize >= BlockSize::BLOCK_8X8 {
let w: &mut dyn Writer = if cw.bc.cdef_coded {w_post_cdef} else {w_pre_cdef};
let tell = w.tell_frac();
cw.write_partition(w, bo, partition, bsize);
cost = (w.tell_frac() - tell) as f64 * get_lambda(fi, seq.bit_depth)/ ((1 << OD_BITRES) as f64);
}
//pmv = best_pred_modes[0].mvs[0];
assert!(best_pred_modes.len() <= 4);
......@@ -918,6 +934,11 @@ pub fn rdo_partition_decision(
let is_compound = ref_frames[1] != NONE_FRAME;
let mode_context = cw.find_mvrefs(bo, ref_frames, &mut mv_stack, subsize, false, fi, is_compound);
if subsize >= BlockSize::BLOCK_8X8 && subsize.is_sqr() {
let w: &mut dyn Writer = if cw.bc.cdef_coded {w_post_cdef} else {w_pre_cdef};
cw.write_partition(w, bo, PartitionType::PARTITION_NONE, subsize);
}
cdef_coded = encode_block_a(seq, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
subsize, bo, skip);
encode_block_b(seq, fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
......@@ -936,7 +957,12 @@ pub fn rdo_partition_decision(
}
}
rd = child_modes.iter().map(|m| m.rd_cost).sum::<f64>();
cw.rollback(&cw_checkpoint);
w_pre_cdef.rollback(&w_pre_checkpoint);
w_post_cdef.rollback(&w_post_checkpoint);
rd = cost + child_modes.iter().map(|m| m.rd_cost).sum::<f64>();
if rd < best_rd {
best_rd = rd;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment