Commit 3202bf70 authored by Josh Holmer's avatar Josh Holmer Committed by Thomas Daede

Implement lookahead queue

Allows rav1e to read frames into the frame queue ahead of the current
one, to allow for future features to be added that depend on a lookahead
pass. Currently the lookahead distance is hardcoded to 10. In the
future, this would probably be configurable as a CLI option.
parent be245fc0
......@@ -17,6 +17,10 @@ use self::EncoderStatus::*;
use std::{cmp, fmt, io};
use std::collections::BTreeMap;
use std::sync::Arc;
use util::Fixed;
use std::collections::BTreeSet;
const LOOKAHEAD_FRAMES: u64 = 10;
// TODO: use the num crate?
#[derive(Clone, Copy, Debug)]
......@@ -291,14 +295,6 @@ impl Config {
}
pub fn new_context(&self) -> Context {
let seq = Sequence::new(&self.frame_info);
let fi = FrameInvariants::new(
self.frame_info.width,
self.frame_info.height,
self.enc,
seq,
);
#[cfg(feature = "aom")]
unsafe {
av1_rtcd();
......@@ -306,32 +302,40 @@ impl Config {
}
Context {
fi,
frame_count: 0,
frames_to_be_coded: 0,
idx: 0,
frames_processed: 0,
frame_q: BTreeMap::new(),
frame_data: BTreeMap::new(),
keyframes: BTreeSet::new(),
packet_data: Vec::new(),
segment_start_idx: 0,
segment_start_frame: 0,
frame_types: BTreeMap::new(),
keyframe_detector: SceneChangeDetector::new(&self.frame_info),
config: *self,
}
}
}
/// Encoder context: owns the input frame queue, per-frame invariants,
/// and segmentation/keyframe state for an in-progress encode.
///
/// NOTE(review): this diff view interleaves pre- and post-commit fields.
/// `fi` and `frame_types` appear to be superseded by `frame_data` and
/// `keyframes` in this commit — confirm against the applied tree.
pub struct Context {
fi: FrameInvariants,
// timebase: Rational,
// Number of frames queued so far via send_frame().
frame_count: u64,
// Total frames in the encode; 0 means the count is not yet known.
frames_to_be_coded: u64,
// Index (into frame_data) of the next frame to encode.
idx: u64,
// Count of output packets produced so far.
frames_processed: u64,
/// Maps frame *number* to frames
frame_q: BTreeMap<u64, Option<Arc<Frame>>>, // packet_q: VecDeque<Packet>
/// Maps frame *idx* to frame data
frame_data: BTreeMap<u64, FrameInvariants>,
/// A list of keyframe *numbers* in this encode. Needed so that we don't
/// need to keep all of the frame_data in memory for the whole life of the encode.
keyframes: BTreeSet<u64>,
// Accumulates encoded bytes until a shown frame flushes them as a packet.
packet_data: Vec<u8>,
// idx at which the current segment (keyframe group) starts.
segment_start_idx: u64,
// Frame number at which the current segment starts.
segment_start_frame: u64,
frame_types: BTreeMap<u64, FrameType>,
// Scene-change detector used to decide KEY vs INTER frames.
keyframe_detector: SceneChangeDetector,
config: Config,
}
#[derive(Clone, Copy, Debug)]
......@@ -369,7 +373,11 @@ impl fmt::Display for Packet {
impl Context {
/// Allocates a new frame sized for this encode, wrapped in an `Arc`.
pub fn new_frame(&self) -> Arc<Frame> {
// NOTE(review): the two expressions below are the pre- and post-commit
// bodies left interleaved by the diff; only the second (config-driven,
// 8-aligned dimensions) version appears to survive this commit — confirm.
Arc::new(Frame::new(self.fi.padded_w, self.fi.padded_h, self.fi.sequence.chroma_sampling))
Arc::new(Frame::new(
self.config.frame_info.width.align_power_of_two(3),
self.config.frame_info.height.align_power_of_two(3),
self.config.frame_info.chroma_sampling
))
}
pub fn send_frame<F>(&mut self, frame: F) -> Result<(), EncoderStatus>
......@@ -378,7 +386,6 @@ impl Context {
{
let idx = self.frame_count;
self.frame_q.insert(idx, frame.into());
self.save_frame_type(idx);
self.frame_count = self.frame_count + 1;
Ok(())
}
......@@ -391,6 +398,10 @@ impl Context {
self.frames_to_be_coded = frames_to_be_coded;
}
/// Returns true while more input should be queued: the encode is not yet
/// complete and the newest queued frame number is still within
/// `LOOKAHEAD_FRAMES` of the number of frames already processed.
pub fn needs_more_lookahead(&self) -> bool {
  let newest_queued = self.frame_q.keys().next_back().cloned().unwrap_or(0);
  self.needs_more_frames(self.frame_count)
    && self.frames_processed + LOOKAHEAD_FRAMES > newest_queued
}
/// Returns true if the encode still expects more frames.
/// A `frames_to_be_coded` of 0 means the total is not yet known.
pub fn needs_more_frames(&self, frame_count: u64) -> bool {
  match self.frames_to_be_coded {
    0 => true,
    limit => frame_count < limit,
  }
}
......@@ -422,14 +433,14 @@ impl Context {
Ok(buf)
}
sequence_header_inner(&self.fi.sequence).unwrap()
sequence_header_inner(&self.frame_data[&0].sequence).unwrap()
}
fn next_keyframe(&self) -> u64 {
let next_detected = self.frame_types.iter()
.find(|(&i, &ty)| ty == FrameType::KEY && i > self.segment_start_frame)
.map(|(&i, _)| i);
let next_limit = self.segment_start_frame + self.fi.config.max_key_frame_interval;
let next_detected = self.frame_data.values()
.find(|fi| fi.frame_type == FrameType::KEY && fi.number > self.segment_start_frame)
.map(|fi| fi.number);
let next_limit = self.segment_start_frame + self.config.enc.max_key_frame_interval;
if next_detected.is_none() {
return next_limit;
}
......@@ -437,12 +448,33 @@ impl Context {
}
/// Computes the `FrameInvariants` for frame index `idx` and caches them in
/// `frame_data`, returning whether the properties could be fully determined.
///
/// The invariants are stored even on failure, because the next frame's
/// properties are cloned from the previous `frame_data` entry.
fn set_frame_properties(&mut self, idx: u64) -> Result<(), ()> {
  // Direct match instead of pre-computing the unit Result via
  // `as_ref().map(..).map_err(..)` — clearer and avoids the extra binding.
  match self.get_frame_properties(idx) {
    Ok(fi) => {
      self.frame_data.insert(idx, fi);
      Ok(())
    }
    Err(fi) => {
      self.frame_data.insert(idx, fi);
      Err(())
    }
  }
}
fn get_frame_properties(&mut self, idx: u64) -> Result<FrameInvariants, FrameInvariants> {
if idx == 0 {
// The first frame will always be a key frame
self.fi = FrameInvariants::new_key_frame(&self.fi,0);
return Ok(());
let fi = FrameInvariants::new_key_frame(
&FrameInvariants::new(
self.config.frame_info.width,
self.config.frame_info.height,
self.config.enc,
Sequence::new(&self.config.frame_info)
),
0
);
return Ok(fi);
}
let mut fi = self.frame_data[&(idx - 1)].clone();
// Initially set up the frame as an inter frame.
// We need to determine what the frame number is before we can
// look up the frame type. If reordering is enabled, the idx
......@@ -450,54 +482,53 @@ impl Context {
let idx_in_segment = idx - self.segment_start_idx;
if idx_in_segment > 0 {
let next_keyframe = self.next_keyframe();
let (fi, success) = FrameInvariants::new_inter_frame(
&self.fi,
let (fi_temp, success) = FrameInvariants::new_inter_frame(
&fi,
self.segment_start_frame,
idx_in_segment,
next_keyframe
);
self.fi = fi;
fi = fi_temp;
if !success {
if !self.fi.inter_cfg.unwrap().reorder
|| ((idx_in_segment - 1) % self.fi.inter_cfg.unwrap().group_len == 0
&& self.fi.number == (next_keyframe - 1))
if !fi.inter_cfg.unwrap().reorder
|| ((idx_in_segment - 1) % fi.inter_cfg.unwrap().group_len == 0
&& fi.number == (next_keyframe - 1))
{
self.segment_start_idx = idx;
self.segment_start_frame = next_keyframe;
self.fi.number = next_keyframe;
fi.number = next_keyframe;
} else {
return Err(());
return Err(fi);
}
}
}
// Now that we know the frame number, look up the correct frame type
let frame_type = self.frame_types.get(&self.fi.number).cloned();
if let Some(frame_type) = frame_type {
if frame_type == FrameType::KEY {
self.segment_start_idx = idx;
self.segment_start_frame = self.fi.number;
}
self.fi.frame_type = frame_type;
let frame_type = self.determine_frame_type(fi.number);
if frame_type == FrameType::KEY {
self.segment_start_idx = idx;
self.segment_start_frame = fi.number;
self.keyframes.insert(fi.number);
}
fi.frame_type = frame_type;
let idx_in_segment = idx - self.segment_start_idx;
if idx_in_segment == 0 {
self.fi = FrameInvariants::new_key_frame(&self.fi, self.segment_start_frame);
} else {
let next_keyframe = self.next_keyframe();
let (fi, success) = FrameInvariants::new_inter_frame(
&self.fi,
self.segment_start_frame,
idx_in_segment,
next_keyframe
);
self.fi = fi;
if !success {
return Err(());
}
let idx_in_segment = idx - self.segment_start_idx;
if idx_in_segment == 0 {
fi = FrameInvariants::new_key_frame(&fi, self.segment_start_frame);
} else {
let next_keyframe = self.next_keyframe();
let (fi_temp, success) = FrameInvariants::new_inter_frame(
&fi,
self.segment_start_frame,
idx_in_segment,
next_keyframe
);
fi = fi_temp;
if !success {
return Err(fi);
}
}
Ok(())
Ok(fi)
}
pub fn receive_packet(&mut self) -> Result<Packet, EncoderStatus> {
......@@ -507,56 +538,59 @@ impl Context {
idx = self.idx;
}
if !self.needs_more_frames(self.fi.number) {
if !self.needs_more_frames(self.frame_data.get(&idx).unwrap().number) {
self.idx += 1;
return Err(EncoderStatus::EnoughData)
}
if self.fi.show_existing_frame {
let fi = self.frame_data.get_mut(&idx).unwrap();
if fi.show_existing_frame {
self.idx += 1;
let mut fs = FrameState::new(&self.fi);
let mut fs = FrameState::new(fi);
let data = encode_frame(&mut self.fi, &mut fs);
let data = encode_frame(fi, &mut fs);
let rec = if self.fi.show_frame { Some(fs.rec) } else { None };
let rec = if fi.show_frame { Some(fs.rec) } else { None };
let mut psnr = None;
if self.fi.config.show_psnr {
if self.config.enc.show_psnr {
if let Some(ref rec) = rec {
psnr = Some(calculate_frame_psnr(&*fs.input, rec, self.fi.sequence.bit_depth));
psnr = Some(calculate_frame_psnr(&*fs.input, rec, fi.sequence.bit_depth));
}
}
Ok(Packet { data, rec, number: self.fi.number, frame_type: self.fi.frame_type, psnr })
self.frames_processed += 1;
Ok(Packet { data, rec, number: fi.number, frame_type: fi.frame_type, psnr })
} else {
if let Some(f) = self.frame_q.remove(&self.fi.number) {
if let Some(f) = self.frame_q.get(&fi.number) {
self.idx += 1;
if let Some(frame) = f {
let mut fs = FrameState::new_with_frame(&self.fi, frame.clone());
let mut fs = FrameState::new_with_frame(fi, frame.clone());
let data = encode_frame(&mut self.fi, &mut fs);
let data = encode_frame(fi, &mut fs);
self.packet_data.extend(data);
fs.rec.pad(self.fi.width, self.fi.height);
fs.rec.pad(fi.width, fi.height);
// TODO avoid the clone by having rec Arc.
let rec = if self.fi.show_frame { Some(fs.rec.clone()) } else { None };
let rec = if fi.show_frame { Some(fs.rec.clone()) } else { None };
update_rec_buffer(&mut self.fi, fs);
update_rec_buffer(fi, fs);
if self.fi.show_frame {
if fi.show_frame {
let data = self.packet_data.clone();
self.packet_data = Vec::new();
let mut psnr = None;
if self.fi.config.show_psnr {
if self.config.enc.show_psnr {
if let Some(ref rec) = rec {
psnr = Some(calculate_frame_psnr(&*frame, rec, self.fi.sequence.bit_depth));
psnr = Some(calculate_frame_psnr(&*frame, rec, fi.sequence.bit_depth));
}
}
Ok(Packet { data, rec, number: self.fi.number, frame_type: self.fi.frame_type, psnr })
self.frames_processed += 1;
Ok(Packet { data, rec, number: fi.number, frame_type: fi.frame_type, psnr })
} else {
Err(EncoderStatus::NeedMoreData)
}
......@@ -569,46 +603,56 @@ impl Context {
}
}
/// Frees queued frames and per-frame data that can no longer be referenced.
///
/// Drops `frame_q` entries with frame number below `cur_frame`, and
/// `frame_data` entries below `self.idx - 1` (the immediately previous
/// entry is kept because the next frame's invariants are cloned from it).
pub fn garbage_collect(&mut self, cur_frame: u64) {
  if cur_frame == 0 {
    return;
  }
  // split_off(&k) keeps keys >= k in the returned map, so reassigning
  // drops everything older in one O(log n) operation instead of
  // cur_frame individual removals.
  self.frame_q = self.frame_q.split_off(&cur_frame);
  if self.idx < 2 {
    return;
  }
  self.frame_data = self.frame_data.split_off(&(self.idx - 1));
}
/// Signals end of input by queueing a `None` sentinel frame.
/// The sentinel lets `receive_packet` drain the remaining queued frames.
pub fn flush(&mut self) {
  self.frame_q.insert(self.frame_count, None);
  self.frame_count += 1;
}
// Computes the frame type for frame `idx` (scene-change detection) and
// caches it in `frame_types`.
// NOTE(review): this appears to be pre-commit code removed by this diff;
// the new flow calls determine_frame_type from get_frame_properties instead.
fn save_frame_type(&mut self, idx: u64) {
let frame_type = self.determine_frame_type(idx);
self.frame_types.insert(idx, frame_type);
}
// Decides whether `frame_number` should be a KEY or INTER frame, using the
// configured min/max keyframe intervals and the scene-change detector.
//
// NOTE(review): this span interleaves the pre-commit (idx/frame_types-based)
// and post-commit (frame_number/keyframes-based) versions of the function
// line-by-line; only one of each duplicated pair exists in the applied tree.
fn determine_frame_type(&mut self, idx: u64) -> FrameType {
if idx == 0 {
fn determine_frame_type(&mut self, frame_number: u64) -> FrameType {
// Frame 0 is always a keyframe.
if frame_number == 0 {
return FrameType::KEY;
}
let prev_keyframe = *self.frame_types.iter().rfind(|(_, &ty)| ty == FrameType::KEY).unwrap().0;
let frame = self.frame_q.get(&idx).cloned().unwrap();
// Most recent keyframe strictly before this frame; 0 if none recorded.
let prev_keyframe = self.keyframes.iter()
.rfind(|&&keyframe| keyframe < frame_number)
.cloned()
.unwrap_or(0);
// A frame missing from the queue is treated as a keyframe.
let frame = match self.frame_q.get(&frame_number).cloned() {
Some(frame) => frame,
None => { return FrameType::KEY; }
};
// `None` here is the flush sentinel; it falls through to INTER below.
if let Some(frame) = frame {
let distance = idx - prev_keyframe;
if distance < self.fi.config.min_key_frame_interval {
if distance + 1 == self.fi.config.min_key_frame_interval {
let distance = frame_number - prev_keyframe;
// Inside the minimum interval a keyframe is never placed, but the
// detector is still primed on the last frame of the interval.
if distance < self.config.enc.min_key_frame_interval {
if distance + 1 == self.config.enc.min_key_frame_interval {
// Run the detector for the current frame, so that it will contain this frame's information
// to compare against the next frame. We can ignore the results for this frame.
self.keyframe_detector.detect_scene_change(frame, idx as usize);
self.keyframe_detector.detect_scene_change(frame, frame_number as usize);
}
return FrameType::INTER;
}
// At or beyond the maximum interval a keyframe is forced.
if distance >= self.fi.config.max_key_frame_interval {
if distance >= self.config.enc.max_key_frame_interval {
return FrameType::KEY;
}
// Otherwise, let the scene-change detector decide.
if self.keyframe_detector.detect_scene_change(frame, idx as usize) {
if self.keyframe_detector.detect_scene_change(frame, frame_number as usize) {
return FrameType::KEY;
}
}
FrameType::INTER
}
}
impl fmt::Display for Context {
  /// Renders the current encoder position as "Frame <number> - <type>".
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.write_fmt(format_args!(
      "Frame {} - {}",
      self.fi.number, self.fi.frame_type
    ))
  }
}
......@@ -17,6 +17,7 @@ use std::io::prelude::*;
use std::sync::Arc;
use std::time::Instant;
use y4m;
use y4m::Colorspace;
pub struct EncoderIO {
pub input: Box<dyn Read>,
......@@ -257,6 +258,80 @@ impl fmt::Display for FrameSummary {
}
}
/// Reads y4m frames into the context's queue until the lookahead is filled.
///
/// On end of input (or a read error) the total frame count is fixed and the
/// context is flushed; the same happens when no more frames are needed.
fn read_frame_batch(ctx: &mut Context, y4m_dec: &mut y4m::Decoder<'_, Box<dyn Read>>, y4m_details: Y4MDetails) {
  while ctx.needs_more_lookahead() {
    match y4m_dec.read_frame() {
      Ok(y4m_frame) => {
        // Validate the bit depth up front, before spending time copying
        // planes (the original checked only after the copy).
        match y4m_details.bits {
          8 | 10 | 12 => {}
          _ => panic!("unknown input bit depth!")
        }
        let mut input = ctx.new_frame();
        {
          // The Arc was just created, so get_mut cannot fail here.
          let input = Arc::get_mut(&mut input).unwrap();
          input.planes[0].copy_from_raw_u8(
            y4m_frame.get_y_plane(),
            y4m_details.width * y4m_details.bytes,
            y4m_details.bytes
          );
          // Chroma planes are half-pitch (4:2:0 subsampling assumed here —
          // NOTE(review): confirm against the decoder's colorspace).
          input.planes[1].copy_from_raw_u8(
            y4m_frame.get_u_plane(),
            y4m_details.width * y4m_details.bytes / 2,
            y4m_details.bytes
          );
          input.planes[2].copy_from_raw_u8(
            y4m_frame.get_v_plane(),
            y4m_details.width * y4m_details.bytes / 2,
            y4m_details.bytes
          );
        }
        // Errors from send_frame are deliberately ignored (best-effort).
        let _ = ctx.send_frame(Some(input));
      }
      _ => {
        // End of input: lock in the total frame count and flush.
        let frames_to_be_coded = ctx.get_frame_count();
        ctx.set_frames_to_be_coded(frames_to_be_coded);
        ctx.flush();
        return;
      }
    }
  }
  // Lookahead is satisfied; if the encode is finished, flush the queue.
  if !ctx.needs_more_frames(ctx.get_frame_count()) {
    ctx.flush();
  }
}
/// Metadata captured once from the y4m input stream, so the decoder does not
/// have to be re-queried for every frame.
#[derive(Debug, Clone, Copy)]
struct Y4MDetails {
width: usize,
height: usize,
// Bit depth as reported by the decoder.
bits: usize,
// Bytes per sample, from the decoder.
bytes: usize,
color_space: Colorspace,
// Bit depth implied by the colorspace.
bit_depth: usize,
}
impl Y4MDetails {
  /// Captures the stream metadata from a y4m decoder.
  fn new(y4m_dec: &mut y4m::Decoder<'_, Box<dyn Read>>) -> Self {
    let color_space = y4m_dec.get_colorspace();
    Y4MDetails {
      width: y4m_dec.get_width(),
      height: y4m_dec.get_height(),
      bits: y4m_dec.get_bit_depth(),
      bytes: y4m_dec.get_bytes_per_sample(),
      bit_depth: color_space.get_bit_depth(),
      color_space,
    }
  }
}
// Encode and write a frame.
// Returns frame information in a `Result`.
pub fn process_frame(
......@@ -264,142 +339,94 @@ pub fn process_frame(
y4m_dec: &mut y4m::Decoder<'_, Box<dyn Read>>,
mut y4m_enc: Option<&mut y4m::Encoder<'_, Box<dyn Write>>>
) -> Result<Vec<FrameSummary>, ()> {
let width = y4m_dec.get_width();
let height = y4m_dec.get_height();
let y4m_bits = y4m_dec.get_bit_depth();
let y4m_bytes = y4m_dec.get_bytes_per_sample();
let y4m_color_space = y4m_dec.get_colorspace();
let bit_depth = y4m_color_space.get_bit_depth();
if ctx.needs_more_frames(ctx.get_frame_count()) {
match y4m_dec.read_frame() {
Ok(y4m_frame) => {
let y4m_y = y4m_frame.get_y_plane();
let y4m_u = y4m_frame.get_u_plane();
let y4m_v = y4m_frame.get_v_plane();
let mut input = ctx.new_frame();
{
let input = Arc::get_mut(&mut input).unwrap();
input.planes[0].copy_from_raw_u8(&y4m_y, width * y4m_bytes, y4m_bytes);
input.planes[1].copy_from_raw_u8(
&y4m_u,
width * y4m_bytes / 2,
y4m_bytes
);
input.planes[2].copy_from_raw_u8(
&y4m_v,
width * y4m_bytes / 2,
y4m_bytes
);
}
match y4m_bits {
8 | 10 | 12 => {}
_ => panic!("unknown input bit depth!")
}
let _ = ctx.send_frame(input);
}
_ => {
let frames_to_be_coded = ctx.get_frame_count();
ctx.set_frames_to_be_coded(frames_to_be_coded);
ctx.flush();
}
}
} else {
ctx.flush();
};
let y4m_details = Y4MDetails::new(y4m_dec);
let mut frame_summaries = Vec::new();
loop {
let pkt_wrapped = ctx.receive_packet();
match pkt_wrapped {
Ok(pkt) => {
write_ivf_frame(output_file, pkt.number as u64, pkt.data.as_ref());
if let Some(y4m_enc_uw) = y4m_enc.as_mut() {
if let Some(ref rec) = pkt.rec {
let pitch_y = if bit_depth > 8 { width * 2 } else { width };
let chroma_sampling_period = map_y4m_color_space(y4m_color_space).0.sampling_period();
let (pitch_uv, height_uv) = (
pitch_y / chroma_sampling_period.0,
height / chroma_sampling_period.1
);
let (mut rec_y, mut rec_u, mut rec_v) = (
vec![128u8; pitch_y * height],
vec![128u8; pitch_uv * height_uv],
vec![128u8; pitch_uv * height_uv]
);
let (stride_y, stride_u, stride_v) = (
rec.planes[0].cfg.stride,
rec.planes[1].cfg.stride,
rec.planes[2].cfg.stride
);
for (line, line_out) in rec.planes[0]
.data_origin()
.chunks(stride_y)
.zip(rec_y.chunks_mut(pitch_y))
{
if bit_depth > 8 {
unsafe {
line_out.copy_from_slice(slice::from_raw_parts::<u8>(
line.as_ptr() as (*const u8),
pitch_y
));
}
} else {
line_out.copy_from_slice(
&line.iter().map(|&v| v as u8).collect::<Vec<u8>>()[..pitch_y]
);
}
read_frame_batch(ctx, y4m_dec, y4m_details);
let pkt_wrapped = ctx.receive_packet();
if let Ok(pkt) = pkt_wrapped {
write_ivf_frame(output_file, pkt.number as u64, pkt.data.as_ref());
if let Some(y4m_enc_uw) = y4m_enc.as_mut() {
if let Some(ref rec) = pkt.rec {
let pitch_y = if y4m_details.bit_depth > 8 { y4m_details.width * 2 } else { y4m_details.width };
let chroma_sampling_period = map_y4m_color_space(y4m_details.color_space).0.sampling_period();
let (pitch_uv, height_uv) = (
pitch_y / chroma_sampling_period.0,
y4m_details.height / chroma_sampling_period.1
);
let (mut rec_y, mut rec_u, mut rec_v) = (
vec![128u8; pitch_y * y4m_details.height],
vec![128u8; pitch_uv * height_uv],
vec![128u8; pitch_uv * height_uv]
);
let (stride_y, stride_u, stride_v) = (
rec.planes[0].cfg.stride,
rec.planes[1].cfg.stride,
rec.planes[2].cfg.stride
);
for (line, line_out) in rec.planes[0]
.data_origin()
.chunks(stride_y)
.zip(rec_y.chunks_mut(pitch_y))
{
if y4m_details.bit_depth > 8 {
unsafe {
line_out.copy_from_slice(slice::from_raw_parts::<u8>(
line.as_ptr() as (*const u8),
pitch_y
));
}
for (line, line_out) in rec.planes[1]
.data_origin()
.chunks(stride_u)
.zip(rec_u.chunks_mut(pitch_uv))
{
if bit_depth > 8 {
unsafe {
line_out.copy_from_slice(slice::from_raw_parts::<u8>(
line.as_ptr() as (*const u8),
pitch_uv
));
}
} else {
line_out.copy_from_slice(
&line.iter().map(|&v| v as u8).collect::<Vec<u8>>()[..pitch_uv]
);
}
} else {
line_out.copy_from_slice(
&line.iter().map(|&v| v as u8).collect::<Vec<u8>>()[..pitch_y]
);
}
}
for (line, line_out) in rec.planes[1]
.data_origin()
.chunks(stride_u)
.zip(rec_u.chunks_mut(pitch_uv))
{
if y4m_details.bit_depth > 8 {
unsafe {
line_out.copy_from_slice(slice::from_raw_parts::<u8>(
line.as_ptr() as (*const u8),
pitch_uv
));
}
for (line, line_out) in rec.planes[2]
.data_origin()
.chunks(stride_v)
.zip(rec_v.chunks_mut(pitch_uv))
{
if bit_depth > 8 {
unsafe {
line_out.copy_from_slice(slice::from_raw_parts::<u8>(
line.as_ptr() as (*const u8),
pitch_uv
));
}
} else {
line_out.copy_from_slice(
&line.iter().map(|&v| v as u8).collect::<Vec<u8>>()[..pitch_uv]
);
}
} else {
line_out.copy_from_slice(
&line.iter().map(|&v| v as u8).collect::<Vec<u8>>()[..pitch_uv]
);
}