Unverified Commit 58e3084b authored by fbossen's avatar fbossen Committed by GitHub

Frame reordering (#629)

* Add code to reorder frames during encoding

* Use BTreeMap instead of VecDeque to hold input frames

* Implement sign bias in MV prediction

* Disable reuse of learned prob and change ref frame semantics

Use LAST2 instead of ALTREF when the second reference frame is also
in the past
parent fce17a24
......@@ -11,7 +11,7 @@ use context::CDFContext;
use encoder::*;
use partition::*;
use std::collections::VecDeque;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::Arc;
......@@ -85,7 +85,7 @@ impl Config {
aom_dsp_rtcd();
}
Context { fi, seq, frame_q: VecDeque::new() }
Context { fi, seq, frame_count: 0, idx: 0, frame_q: BTreeMap::new() }
}
}
......@@ -93,7 +93,9 @@ pub struct Context {
fi: FrameInvariants,
seq: Sequence,
// timebase: Ratio,
frame_q: VecDeque<Option<Arc<Frame>>> // packet_q: VecDeque<Packet>
frame_count: u64,
idx: u64,
frame_q: BTreeMap<u64, Option<Arc<Frame>>> // packet_q: VecDeque<Packet>
}
#[derive(Clone, Copy, Debug)]
......@@ -110,8 +112,8 @@ pub enum EncoderStatus {
pub struct Packet {
pub data: Vec<u8>,
pub rec: Frame,
pub number: usize,
pub rec: Option<Frame>,
pub number: u64,
pub frame_type: FrameType
}
......@@ -136,7 +138,8 @@ impl Context {
where
F: Into<Option<Arc<Frame>>>
{
self.frame_q.push_back(frame.into());
self.frame_q.insert(self.frame_count, frame.into());
self.frame_count = self.frame_count + 1;
Ok(())
}
......@@ -170,100 +173,161 @@ impl Context {
sequence_header_inner(&self.seq).unwrap()
}
pub fn receive_packet(&mut self) -> Result<Packet, EncoderStatus> {
let f = self.frame_q.pop_front().ok_or(EncoderStatus::NeedMoreData)?;
if let Some(frame) = f {
let mut fs = FrameState {
input: frame,
rec: Frame::new(self.fi.padded_w, self.fi.padded_h),
qc: Default::default(),
cdfs: CDFContext::new(0),
deblock: Default::default()
};
let frame_number_in_segment = self.fi.number % 30;
pub fn frame_properties(&mut self, idx: u64) -> bool {
let key_frame_interval: u64 = 30;
let reorder = false;
let multiref = reorder || self.fi.config.speed <= 2;
let pyramid_depth = if reorder { 1 } else { 0 };
let group_src_len = 1 << pyramid_depth;
let group_len = group_src_len + if reorder { pyramid_depth } else { 0 };
let segment_len = 1 + (key_frame_interval - 1 + group_src_len - 1) / group_src_len * group_len;
let idx_in_segment = idx % segment_len;
let segment_idx = idx / segment_len;
if idx_in_segment == 0 {
self.fi.frame_type = FrameType::KEY;
self.fi.intra_only = true;
self.fi.order_hint = 0;
self.fi.refresh_frame_flags = ALL_REF_FRAMES_MASK;
self.fi.show_frame = true;
self.fi.show_existing_frame = false;
self.fi.frame_to_show_map_idx = 0;
let q_boost = 15;
self.fi.base_q_idx = (self.fi.config.quantizer.max(1 + q_boost).min(255 + q_boost) - q_boost) as u8;
self.fi.primary_ref_frame = PRIMARY_REF_NONE;
self.fi.number = segment_idx * key_frame_interval;
for i in 0..INTER_REFS_PER_FRAME {
self.fi.ref_frames[i] = 0;
}
} else {
let idx_in_group = (idx_in_segment - 1) % group_len;
let group_idx = (idx_in_segment - 1) / group_len;
self.fi.order_hint = frame_number_in_segment as u32;
self.fi.frame_type = FrameType::INTER;
self.fi.intra_only = false;
self.fi.frame_type = if frame_number_in_segment == 0 {
FrameType::KEY
} else {
FrameType::INTER
};
let slot_idx = frame_number_in_segment % REF_FRAMES as u64;
self.fi.order_hint = (group_src_len * group_idx +
if reorder && idx_in_group < pyramid_depth {
group_src_len >> idx_in_group
} else {
idx_in_group - pyramid_depth + 1
}) as u32;
if self.fi.order_hint >= key_frame_interval as u32 {
return false;
}
self.fi.refresh_frame_flags = if self.fi.frame_type == FrameType::KEY {
ALL_REF_FRAMES_MASK
let slot_idx = self.fi.order_hint % REF_FRAMES as u32;
self.fi.show_frame = !reorder || idx_in_group >= pyramid_depth;
self.fi.show_existing_frame = self.fi.show_frame && reorder &&
(idx_in_group - pyramid_depth + 1).count_ones() == 1 &&
idx_in_group != pyramid_depth;
self.fi.frame_to_show_map_idx = slot_idx;
self.fi.refresh_frame_flags = if self.fi.show_existing_frame {
0
} else {
1 << slot_idx
};
self.fi.intra_only = self.fi.frame_type == FrameType::KEY
|| self.fi.frame_type == FrameType::INTRA_ONLY;
// self.fi.use_prev_frame_mvs =
// !(self.fi.intra_only || self.fi.error_resilient);
let use_multiple_ref_frames = self.fi.config.speed <= 2;
let log_boost_frequency = if use_multiple_ref_frames {
2 // Higher quality frame every 4 frames
let lvl = if !reorder {
0
} else if idx_in_group < pyramid_depth {
idx_in_group
} else {
0 // No boosting with single reference frame
pyramid_depth - (idx_in_group - pyramid_depth + 1).trailing_zeros() as u64
};
assert!(log_boost_frequency >= 0 && log_boost_frequency <= 2);
let boost_frequency = 1 << log_boost_frequency;
self.fi.base_q_idx = if self.fi.frame_type == FrameType::KEY {
let q_boost = 15;
self.fi.config.quantizer.max(1 + q_boost).min(255 + q_boost) - q_boost
} else if slot_idx & (boost_frequency - 1) == 0 {
self.fi.config.quantizer.max(1).min(255)
} else {
let q_drop = 15;
self.fi.config.quantizer.min(255 - q_drop) + q_drop
} as u8;
let q_drop = 15 * lvl as usize;
self.fi.base_q_idx = (self.fi.config.quantizer.min(255 - q_drop) + q_drop) as u8;
let first_ref_frame = LAST_FRAME;
let second_ref_frame =
if use_multiple_ref_frames { ALTREF_FRAME } else { NONE_FRAME };
let second_ref_frame = if !multiref {
NONE_FRAME
} else if !reorder || idx_in_group == 0 {
LAST2_FRAME
} else {
ALTREF_FRAME
};
let ref_in_previous_group = LAST3_FRAME;
self.fi.primary_ref_frame =
if self.fi.intra_only || self.fi.error_resilient {
PRIMARY_REF_NONE
} else {
(first_ref_frame - LAST_FRAME) as u32
};
self.fi.primary_ref_frame = (ref_in_previous_group - LAST_FRAME) as u32;
for i in 0..INTER_REFS_PER_FRAME {
self.fi.ref_frames[i] = if i == second_ref_frame - LAST_FRAME {
(REF_FRAMES + slot_idx as usize - 2) & boost_frequency as usize
(slot_idx as u64 + if lvl == 0 { 6 * group_src_len } else { group_src_len >> lvl }) & 7
} else if i == ref_in_previous_group - LAST_FRAME {
(slot_idx as u64 - group_src_len) & 7
} else {
(REF_FRAMES + slot_idx as usize - 1) & (REF_FRAMES - 1)
};
(slot_idx as u64 - (group_src_len >> lvl)) & 7
} as usize;
}
let data = encode_frame(&mut self.seq, &mut self.fi, &mut fs);
self.fi.number = segment_idx * key_frame_interval + self.fi.order_hint as u64;
}
let number = self.fi.number as usize;
true
}
self.fi.number += 1;
pub fn receive_packet(&mut self) -> Result<Packet, EncoderStatus> {
let mut idx = self.idx;
while !self.frame_properties(idx) {
self.idx = self.idx + 1;
idx = self.idx;
}
fs.rec.pad();
if self.fi.show_existing_frame {
self.idx = self.idx + 1;
// TODO avoid the clone by having rec Arc.
let rec = fs.rec.clone();
let mut fs = FrameState {
input: Arc::new(Frame::new(self.fi.padded_w, self.fi.padded_h)), // dummy
rec: Frame::new(self.fi.padded_w, self.fi.padded_h),
qc: Default::default(),
cdfs: CDFContext::new(0),
deblock: Default::default(),
};
let data = encode_frame(&mut self.seq, &mut self.fi, &mut fs);
update_rec_buffer(&mut self.fi, fs);
// TODO avoid the clone by having rec Arc.
let rec = if self.fi.show_frame { Some(fs.rec.clone()) } else { None };
Ok(Packet { data, rec, number, frame_type: self.fi.frame_type })
Ok(Packet { data, rec, number: self.fi.number, frame_type: self.fi.frame_type })
} else {
unimplemented!("Flushing not implemented")
if let Some(f) = self.frame_q.remove(&self.fi.number) {
self.idx = self.idx + 1;
if let Some(frame) = f {
let mut fs = FrameState {
input: frame,
rec: Frame::new(self.fi.padded_w, self.fi.padded_h),
qc: Default::default(),
cdfs: CDFContext::new(0),
deblock: Default::default(),
};
let data = encode_frame(&mut self.seq, &mut self.fi, &mut fs);
fs.rec.pad();
// TODO avoid the clone by having rec Arc.
let rec = if self.fi.show_frame { Some(fs.rec.clone()) } else { None };
update_rec_buffer(&mut self.fi, fs);
Ok(Packet { data, rec, number: self.fi.number, frame_type: self.fi.frame_type })
} else {
Err(EncoderStatus::NeedMoreData)
}
} else {
Err(EncoderStatus::NeedMoreData)
}
}
}
pub fn flush(&mut self) {
self.frame_q.push_back(None);
self.frame_q.insert(self.frame_count, None);
self.frame_count = self.frame_count + 1;
}
}
......
......@@ -92,6 +92,7 @@ pub fn parse_cli() -> (EncoderIO, EncoderConfig, usize) {
}
/// Encode and write a frame.
/// returns whether done with sequence
pub fn process_frame(
ctx: &mut Context, output_file: &mut dyn Write,
y4m_dec: &mut y4m::Decoder<'_, Box<dyn Read>>,
......@@ -107,6 +108,7 @@ pub fn process_frame(
let y4m_bytes = y4m_dec.get_bytes_per_sample();
let bit_depth = y4m_dec.get_colorspace().get_bit_depth();
let read_frame =
match y4m_dec.read_frame() {
Ok(y4m_frame) => {
let y4m_y = y4m_frame.get_y_plane();
......@@ -134,10 +136,24 @@ pub fn process_frame(
}
let _ = ctx.send_frame(input);
let pkt = ctx.receive_packet().unwrap();
true
}
_ => {
ctx.flush();
false
}
};
let y4m_enc_uw = y4m_enc.unwrap();
let mut has_data = true;
while has_data {
let pkt_wrapped = ctx.receive_packet();
match pkt_wrapped {
Ok(pkt) => {
eprintln!("{}", pkt);
write_ivf_frame(output_file, pkt.number as u64, pkt.data.as_ref());
if let Some(mut y4m_enc) = y4m_enc {
if let Some(rec) = pkt.rec {
let pitch_y = if bit_depth > 8 { width * 2 } else { width };
let pitch_uv = pitch_y / 2;
......@@ -148,12 +164,12 @@ pub fn process_frame(
);
let (stride_y, stride_u, stride_v) = (
pkt.rec.planes[0].cfg.stride,
pkt.rec.planes[1].cfg.stride,
pkt.rec.planes[2].cfg.stride
rec.planes[0].cfg.stride,
rec.planes[1].cfg.stride,
rec.planes[2].cfg.stride
);
for (line, line_out) in pkt.rec.planes[0]
for (line, line_out) in rec.planes[0]
.data_origin()
.chunks(stride_y)
.zip(rec_y.chunks_mut(pitch_y))
......@@ -171,7 +187,7 @@ pub fn process_frame(
);
}
}
for (line, line_out) in pkt.rec.planes[1]
for (line, line_out) in rec.planes[1]
.data_origin()
.chunks(stride_u)
.zip(rec_u.chunks_mut(pitch_uv))
......@@ -189,7 +205,7 @@ pub fn process_frame(
);
}
}
for (line, line_out) in pkt.rec.planes[2]
for (line, line_out) in rec.planes[2]
.data_origin()
.chunks(stride_v)
.zip(rec_v.chunks_mut(pitch_uv))
......@@ -209,11 +225,11 @@ pub fn process_frame(
}
let rec_frame = y4m::Frame::new([&rec_y, &rec_u, &rec_v], None);
y4m_enc.write_frame(&rec_frame).unwrap();
y4m_enc_uw.write_frame(&rec_frame).unwrap();
}
true
},
_ => { has_data = false; }
}
_ => false
}
read_frame
}
......@@ -28,6 +28,7 @@ use util::msb;
use std::*;
use entropymode::*;
use token_cdfs::*;
use encoder::FrameInvariants;
use self::REF_CONTEXTS;
use self::SINGLE_REFS;
......@@ -1928,9 +1929,9 @@ impl ContextWriter {
cmp::max(col_offset, -(mi_col as isize))
}
fn find_matching_mv(&self, blk: &Block, mv_stack: &mut Vec<CandidateMV>) -> bool {
fn find_matching_mv(&self, mv: &MotionVector, mv_stack: &mut Vec<CandidateMV>) -> bool {
for mv_cand in mv_stack {
if blk.mv[0].row == mv_cand.this_mv.row && blk.mv[0].col == mv_cand.this_mv.col {
if mv.row == mv_cand.this_mv.row && mv.col == mv_cand.this_mv.col {
return true;
}
}
......@@ -1976,13 +1977,26 @@ impl ContextWriter {
}
}
fn add_extra_mv_candidate(&self, blk: &Block, mv_stack: &mut Vec<CandidateMV>) {
fn add_extra_mv_candidate(
&self,
blk: &Block,
ref_frame: usize,
mv_stack: &mut Vec<CandidateMV>,
fi: &FrameInvariants
) {
for cand_list in 0..2 {
if blk.ref_frames[cand_list] > INTRA_FRAME {
if !self.find_matching_mv(blk, mv_stack) {
let mut mv = blk.mv[0];
if fi.ref_frame_sign_bias[blk.ref_frames[cand_list] - LAST_FRAME] !=
fi.ref_frame_sign_bias[ref_frame - LAST_FRAME] {
mv.row = -mv.row;
mv.col = -mv.col;
}
if !self.find_matching_mv(&mv, mv_stack) {
let mv_cand = CandidateMV {
this_mv: blk.mv[0],
comp_mv: blk.mv[1],
this_mv: mv,
comp_mv: mv,
weight: 2
};
mv_stack.push(mv_cand);
......@@ -2115,7 +2129,7 @@ impl ContextWriter {
}
fn setup_mvref_list(&mut self, bo: &BlockOffset, ref_frame: usize, mv_stack: &mut Vec<CandidateMV>,
bsize: BlockSize, is_sec_rect: bool) -> usize {
bsize: BlockSize, is_sec_rect: bool, fi: &FrameInvariants) -> usize {
let (_rf, _rf_num) = self.get_mvref_ref_frames(INTRA_FRAME);
let target_n4_h = bsize.height_mi();
......@@ -2239,7 +2253,7 @@ impl ContextWriter {
};
let blk = &self.bc.at(&rbo);
self.add_extra_mv_candidate(blk, mv_stack);
self.add_extra_mv_candidate(blk, ref_frame, mv_stack, fi);
idx += if pass == 0 {
blk.n4_w
......@@ -2270,7 +2284,8 @@ impl ContextWriter {
}
pub fn find_mvrefs(&mut self, bo: &BlockOffset, ref_frame: usize,
mv_stack: &mut Vec<CandidateMV>, bsize: BlockSize, is_sec_rect: bool) -> usize {
mv_stack: &mut Vec<CandidateMV>, bsize: BlockSize, is_sec_rect: bool,
fi: &FrameInvariants) -> usize {
if ref_frame < REF_FRAMES {
if ref_frame != INTRA_FRAME {
/* TODO: convert global mv to an mv here */
......@@ -2289,7 +2304,7 @@ impl ContextWriter {
return 0;
}
let mode_context = self.setup_mvref_list(bo, ref_frame, mv_stack, bsize, is_sec_rect);
let mode_context = self.setup_mvref_list(bo, ref_frame, mv_stack, bsize, is_sec_rect, fi);
mode_context
}
......
......@@ -246,6 +246,12 @@ impl Sequence {
separate_uv_delta_q: false,
}
}
/// Signed, wrapped distance between two order hints.
///
/// Takes the plain difference `a - b` and folds it into the range
/// `[-m, m)` with `m = 1 << order_hint_bits_minus_1`: the bits below
/// `m` are kept as the magnitude and the `m` bit is treated as the
/// sign bit. A positive result therefore means `a` lies ahead of `b`
/// within that wrapped window.
pub fn get_relative_dist(&self, a: u32, b: u32) -> i32 {
    let delta = a as i32 - b as i32;
    // Bit that acts as the sign of the wrapped difference.
    let sign_bit = 1 << self.order_hint_bits_minus_1;
    let low_bits = delta & (sign_bit - 1);
    let sign_part = delta & sign_bit;
    low_bits - sign_part
}
}
use std::sync::Arc;
......@@ -331,6 +337,7 @@ pub struct FrameInvariants {
pub allow_high_precision_mv: bool,
pub frame_type: FrameType,
pub show_existing_frame: bool,
pub frame_to_show_map_idx: u32,
pub use_reduced_tx_set: bool,
pub reference_mode: ReferenceMode,
pub use_prev_frame_mvs: bool,
......@@ -358,6 +365,7 @@ pub struct FrameInvariants {
pub delta_q_present: bool,
pub config: EncoderConfig,
pub ref_frames: [usize; INTER_REFS_PER_FRAME],
pub ref_frame_sign_bias: [bool; INTER_REFS_PER_FRAME],
pub rec_buffer: ReferenceFramesSet,
pub base_q_idx: u8,
}
......@@ -399,6 +407,7 @@ impl FrameInvariants {
allow_high_precision_mv: false,
frame_type: FrameType::KEY,
show_existing_frame: false,
frame_to_show_map_idx: 0,
use_reduced_tx_set,
reference_mode: ReferenceMode::SINGLE,
use_prev_frame_mvs: false,
......@@ -424,6 +433,7 @@ impl FrameInvariants {
delta_q_present: false,
config,
ref_frames: [0; INTER_REFS_PER_FRAME],
ref_frame_sign_bias: [false; INTER_REFS_PER_FRAME],
rec_buffer: ReferenceFramesSet::new(),
base_q_idx: config.quantizer as u8,
}
......@@ -438,6 +448,7 @@ impl FrameInvariants {
deblock: Default::default(),
}
}
}
impl fmt::Display for FrameInvariants{
......@@ -694,7 +705,7 @@ impl<W: io::Write> UncompressedHeader for BitWriter<W, BigEndian> {
} else {
if fi.show_existing_frame {
self.write_bit(true)?; // show_existing_frame=1
self.write(3, 0)?; // show last frame
self.write(3, fi.frame_to_show_map_idx)?;
//TODO:
/* temporal_point_info();
......@@ -706,6 +717,7 @@ impl<W: io::Write> UncompressedHeader for BitWriter<W, BigEndian> {
// write display_frame_id;
}*/
self.write_bit(true)?; // trailing bit
self.byte_align()?;
return Ok((()));
}
......@@ -1692,7 +1704,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
rd_cost = mode_decision.rd_cost + cost;
let mut mv_stack = Vec::new();
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false, fi);
let (tx_size, tx_type) =
rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
......@@ -1770,7 +1782,7 @@ fn encode_partition_bottomup(seq: &Sequence, fi: &FrameInvariants, fs: &mut Fram
let mut cdef_coded = cw.bc.cdef_coded;
let mut mv_stack = Vec::new();
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false, fi);
let (tx_size, tx_type) =
rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
......@@ -1863,7 +1875,7 @@ fn encode_partition_topdown(seq: &Sequence, fi: &FrameInvariants, fs: &mut Frame
rdo_tx_size_type(seq, fi, fs, cw, bsize, bo, mode_luma, ref_frame, mv, skip);
let mut mv_stack = Vec::new();
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false);
let mode_context = cw.find_mvrefs(bo, ref_frame, &mut mv_stack, bsize, false, fi);
if !mode_luma.is_intra() && mode_luma != PredictionMode::GLOBALMV {
mode_luma = PredictionMode::NEWMV;
......@@ -2020,13 +2032,27 @@ pub fn encode_frame(sequence: &mut Sequence, fi: &mut FrameInvariants, fs: &mut
if fi.show_existing_frame {
//write_uncompressed_header(&mut packet, sequence, fi).unwrap();
write_obus(&mut packet, sequence, fi, fs).unwrap();
match fi.rec_buffer.frames[0] {
match fi.rec_buffer.frames[fi.frame_to_show_map_idx as usize] {
Some(ref rec) => for p in 0..3 {
fs.rec.planes[p].data.copy_from_slice(rec.frame.planes[p].data.as_slice());
},
None => (),
}
} else {
if !fi.intra_only {
for i in 0..INTER_REFS_PER_FRAME {
fi.ref_frame_sign_bias[i] =
if !sequence.enable_order_hint {
false
} else if let Some(ref rec) = fi.rec_buffer.frames[fi.ref_frames[i]] {
let hint = rec.order_hint;
sequence.get_relative_dist(hint, fi.order_hint) > 0
} else {
false
};
}
}
let bit_depth = sequence.bit_depth;
let tile = encode_tile(sequence, fi, fs, bit_depth); // actually tile group
......
......@@ -307,6 +307,8 @@ pub fn rdo_mode_decision(
if fi.frame_type == FrameType::INTER {
for i in LAST_FRAME..NONE_FRAME {
// Don't search LAST3 since it's used only for probs
if i == LAST3_FRAME { continue; }
if !ref_slot_set.contains(&fi.ref_frames[i - LAST_FRAME]) {
ref_frame_set.push(i);
ref_slot_set.push(fi.ref_frames[i - LAST_FRAME]);
......@@ -321,7 +323,7 @@ pub fn rdo_mode_decision(
for (i, &ref_frame) in ref_frame_set.iter().enumerate() {
let mut mvs: Vec<CandidateMV> = Vec::new();
mode_contexts.push(cw.find_mvrefs(bo, ref_frame, &mut mvs, bsize, false));
mode_contexts.push(cw.find_mvrefs(bo, ref_frame, &mut mvs, bsize, false, fi));
if fi.frame_type == FrameType::INTER {
for &x in RAV1E_INTER_MODES_MINIMAL {
......
......@@ -254,71 +254,80 @@ fn encode_decode(
fill_frame(&mut ra, Arc::get_mut(&mut input).unwrap());
let _ = ctx.send_frame(input);
let pkt = ctx.receive_packet().unwrap();
println!("Encoded packet {}", pkt.number);
rec_fifo.push_back(pkt.rec.clone());
let packet = pkt.data;
let mut done = false;
let mut corrupted_count = 0;
unsafe {
println!("Decoding frame {}", pkt.number);
let ret = aom_codec_decode(
&mut dec.dec,
packet.as_ptr(),
packet.len(),
ptr::null_mut()
);
println!("Decoded. -> {}", ret);
if ret != 0 {
use std::ffi::CStr;
let error_msg = aom_codec_error(&mut dec.dec);
println!(
" Decode codec_decode failed: {}",
CStr::from_ptr(error_msg).to_string_lossy()
);
let detail = aom_codec_error_detail(&mut dec.dec);
if !detail.is_null() {
println!(
" Decode codec_decode failed {}",
CStr::from_ptr(detail).to_string_lossy()
);
while !done {
let res = ctx.receive_packet();
if let Ok(pkt) = res {
println!("Encoded packet {}", pkt.number);
if let Some(pkt_rec) = pkt.rec {
rec_fifo.push_back(pkt_rec.clone());
}
corrupted_count += 1;
}
let packet = pkt.data;
if ret == 0 {