Commit 85e88293 authored by Thomas Daede's avatar Thomas Daede Committed by Thomas Daede

Add tx_dist based rate estimate.

parent 3e7c7c5d
......@@ -35,6 +35,7 @@ aom-sys = { version = "0.1.0", optional = true }
scan_fmt = { version = "0.1.3", optional = true }
ivf = { path = "ivf/", optional = true }
rayon = "1.0"
bincode = "=1.0.1"
[target.'cfg(target_arch = "x86_64")'.build-dependencies]
nasm-rs = { version = "0.1.4", optional = true }
......
......@@ -78,6 +78,7 @@ pub struct EncoderConfig {
pub pass: Option<u8>,
pub show_psnr: bool,
pub stats_file: Option<PathBuf>,
pub train_rdo: bool,
}
impl Default for EncoderConfig {
......@@ -112,6 +113,7 @@ impl EncoderConfig {
pass: None,
show_psnr: false,
stats_file: None,
train_rdo: false
}
}
}
......@@ -123,6 +125,7 @@ pub struct SpeedSettings {
pub fast_deblock: bool,
pub reduced_tx_set: bool,
pub tx_domain_distortion: bool,
pub tx_domain_rate: bool,
pub encode_bottomup: bool,
pub rdo_tx_decision: bool,
pub prediction_modes: PredictionModesSetting,
......@@ -139,6 +142,7 @@ impl Default for SpeedSettings {
fast_deblock: false,
reduced_tx_set: false,
tx_domain_distortion: false,
tx_domain_rate: false,
encode_bottomup: false,
rdo_tx_decision: false,
prediction_modes: PredictionModesSetting::Simple,
......@@ -157,6 +161,7 @@ impl SpeedSettings {
fast_deblock: Self::fast_deblock_preset(speed),
reduced_tx_set: Self::reduced_tx_set_preset(speed),
tx_domain_distortion: Self::tx_domain_distortion_preset(speed),
tx_domain_rate: Self::tx_domain_rate_preset(speed),
encode_bottomup: Self::encode_bottomup_preset(speed),
rdo_tx_decision: Self::rdo_tx_decision_preset(speed),
prediction_modes: Self::prediction_modes_preset(speed),
......@@ -199,6 +204,10 @@ impl SpeedSettings {
true
}
/// Preset for the `tx_domain_rate` speed setting.
///
/// The tx-distortion-based rate estimate is currently switched off for
/// every speed level, so the `_speed` argument is ignored.
fn tx_domain_rate_preset(_speed: usize) -> bool {
  false
}
/// Preset for the `encode_bottomup` speed setting.
///
/// Returns `true` only for speed level 0; every other speed level gets
/// `false`.
fn encode_bottomup_preset(speed: usize) -> bool {
  speed == 0
}
......
......@@ -245,6 +245,10 @@ pub fn parse_cli() -> CliOptions {
.long("speed-test")
.takes_value(true)
)
.arg(
Arg::with_name("train-rdo")
.long("train-rdo")
)
.subcommand(SubCommand::with_name("advanced")
.setting(AppSettings::Hidden)
.about("Advanced features")
......@@ -313,6 +317,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig {
}
});
let bitrate = maybe_bitrate.unwrap_or(0);
let train_rdo = matches.is_present("train-rdo");
if quantizer == 0 {
unimplemented!("Lossless encoding not yet implemented");
} else if quantizer > 255 {
......@@ -422,7 +427,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig {
};
cfg.tune = matches.value_of("TUNE").unwrap().parse().unwrap();
cfg.low_latency = matches.value_of("LOW_LATENCY").unwrap().parse().unwrap();
cfg.train_rdo = train_rdo;
cfg
}
......@@ -455,6 +460,9 @@ fn apply_speed_test_cfg(cfg: &mut EncoderConfig, setting: &str) {
"tx_domain_distortion" => {
cfg.speed_settings.tx_domain_distortion = true;
},
"tx_domain_rate" => {
cfg.speed_settings.tx_domain_rate = true;
},
"encode_bottomup" => {
cfg.speed_settings.encode_bottomup = true;
},
......
......@@ -78,6 +78,8 @@ pub trait Writer {
fn checkpoint(&mut self) -> WriterCheckpoint;
/// Restore saved position in coding/recording from a checkpoint
fn rollback(&mut self, _: &WriterCheckpoint);
/// Add additional bits from rate estimators without coding a real symbol
fn add_bits_frac(&mut self, bits_frac: u32);
}
/// StorageBackend is an internal trait used to tie a specific Writer
......@@ -103,6 +105,9 @@ pub struct WriterBase<S> {
cnt: i16,
/// Debug enable flag
debug: bool,
/// Extra offset added to tell() and tell_frac() to approximate costs
/// of actually coding a symbol
fake_bits_frac: u32,
/// Use-specific storage
s: S
}
......@@ -298,7 +303,7 @@ impl<S> WriterBase<S> {
/// Internal constructor called by the subtypes that implement the
/// actual encoder and Recorder.
fn new(storage: S) -> Self {
WriterBase { rng: 0x8000, cnt: -9, debug: false, s: storage }
WriterBase { rng: 0x8000, cnt: -9, debug: false, fake_bits_frac: 0, s: storage }
}
/// Compute low and range values from token cdf values and local state
......@@ -478,6 +483,10 @@ where
fn bit(&mut self, bit: u16) {
self.bool(bit == 1, 16384);
}
/// Add `bits_frac` to the running `fake_bits_frac` offset without coding
/// any real symbol.
///
/// The accumulated offset is what `tell()` and `tell_frac()` add on top
/// of the actually coded size, so rate estimators can charge an
/// approximate cost to this writer.
fn add_bits_frac(&mut self, bits_frac: u32) {
  self.fake_bits_frac += bits_frac;
}
/// Encode a literal bitstring, bit by bit in MSB order, with flat
/// probability.
/// 'bits': Length of bitstring
......@@ -721,7 +730,7 @@ where
fn tell(&mut self) -> u32 {
// The 10 here counteracts the offset of -9 baked into cnt, and adds 1 extra
// bit, which we reserve for terminating the stream.
(((self.stream_bytes() * 8) as i32) + (self.cnt as i32) + 10) as u32
(((self.stream_bytes() * 8) as i32) + (self.cnt as i32) + 10) as u32 + (self.fake_bits_frac >> 8)
}
/// Returns the number of bits "used" by the encoded symbols so far.
/// This same number can be computed in either the encoder or the
......@@ -731,7 +740,7 @@ where
/// This will always be slightly larger than the exact value (e.g., all
/// rounding error is in the positive direction).
fn tell_frac(&mut self) -> u32 {
Self::frac_compute(self.tell(), self.rng as u32)
Self::frac_compute(self.tell(), self.rng as u32) + self.fake_bits_frac
}
/// Save current point in coding/recording to a checkpoint that can
/// be restored later. A WriterCheckpoint can be generated for an
......
......@@ -29,10 +29,13 @@ use crate::partition::PartitionType::*;
use crate::header::*;
use bitstream_io::{BitWriter, BigEndian};
use bincode::{serialize, deserialize};
use std;
use std::{fmt, io, mem};
use std::io::Write;
use std::io::Read;
use std::sync::Arc;
use std::fs::File;
extern {
pub fn av1_rtcd();
......@@ -396,6 +399,7 @@ pub struct FrameState<T: Pixel> {
pub segmentation: SegmentationState,
pub restoration: RestorationState,
pub frame_mvs: Vec<Vec<MotionVector>>,
pub t: RDOTracker,
}
impl<T: Pixel> FrameState<T> {
......@@ -422,7 +426,8 @@ impl<T: Pixel> FrameState<T> {
deblock: Default::default(),
segmentation: Default::default(),
restoration: rs,
frame_mvs: vec![vec![MotionVector::default(); fi.w_in_b * fi.h_in_b]; REF_FRAMES]
frame_mvs: vec![vec![MotionVector::default(); fi.w_in_b * fi.h_in_b]; REF_FRAMES],
t: RDOTracker::new()
}
}
}
......@@ -538,6 +543,7 @@ pub struct FrameInvariants<T: Pixel> {
pub me_lambda: f64,
pub me_range_scale: u8,
pub use_tx_domain_distortion: bool,
pub use_tx_domain_rate: bool,
pub inter_cfg: Option<InterPropsConfig>,
pub enable_early_exit: bool,
}
......@@ -562,6 +568,7 @@ impl<T: Pixel> FrameInvariants<T> {
let min_partition_size = config.speed_settings.min_block_size;
let use_reduced_tx_set = config.speed_settings.reduced_tx_set;
let use_tx_domain_distortion = config.tune == Tune::Psnr && config.speed_settings.tx_domain_distortion;
let use_tx_domain_rate = config.speed_settings.tx_domain_rate;
let w_in_b = 2 * config.width.align_power_of_two_and_shift(3); // MiCols, ((width+7)/8)<<3 >> MI_SIZE_LOG2
let h_in_b = 2 * config.height.align_power_of_two_and_shift(3); // MiRows, ((height+7)/8)<<3 >> MI_SIZE_LOG2
......@@ -617,6 +624,7 @@ impl<T: Pixel> FrameInvariants<T> {
me_lambda: 0.0,
me_range_scale: 1,
use_tx_domain_distortion,
use_tx_domain_rate,
inter_cfg: None,
enable_early_exit: true,
config,
......@@ -981,9 +989,14 @@ pub fn encode_tx_block<T: Pixel>(
let coded_tx_size = av1_get_coded_tx_size(tx_size).area();
fs.qc.quantize(coeffs, qcoeffs, coded_tx_size);
let has_coeff = cw.write_coeffs_lv_map(w, p, bo, &qcoeffs, mode, tx_size, tx_type, plane_bsize, xdec, ydec,
fi.use_reduced_tx_set);
let tell_coeffs = w.tell_frac();
let has_coeff = if !for_rdo_use || rdo_type.needs_coeff_rate() {
cw.write_coeffs_lv_map(w, p, bo, &qcoeffs, mode, tx_size, tx_type, plane_bsize, xdec, ydec,
fi.use_reduced_tx_set)
} else {
true
};
let cost_coeffs = w.tell_frac() - tell_coeffs;
// Reconstruct
dequantize(qidx, qcoeffs, rcoeffs, tx_size, fi.sequence.bit_depth, fi.dc_delta_q[p], fi.ac_delta_q[p]);
......@@ -1006,6 +1019,15 @@ pub fn encode_tx_block<T: Pixel>(
let tx_dist_scale_rounding_offset = 1 << (tx_dist_scale_bits - 1);
tx_dist = (tx_dist + tx_dist_scale_rounding_offset) >> tx_dist_scale_bits;
}
if fi.config.train_rdo {
fs.t.add_rate(fi.base_q_idx, tx_size, tx_dist as u64, cost_coeffs as u64);
}
if rdo_type == RDOType::TxDistEstRate {
// look up rate and distortion in table
let estimated_rate = estimate_rate(fi.base_q_idx, tx_size, tx_dist as u64);
w.add_bits_frac(estimated_rate as u32);
}
(has_coeff, tx_dist)
}
......@@ -1446,7 +1468,8 @@ pub fn write_tx_tree<T: Pixel>(
pub fn encode_block_with_modes<T: Pixel>(
fi: &FrameInvariants<T>, fs: &mut FrameState<T>,
cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer,
bsize: BlockSize, bo: &BlockOffset, mode_decision: &RDOPartitionOutput
bsize: BlockSize, bo: &BlockOffset, mode_decision: &RDOPartitionOutput,
rdo_type: RDOType
) {
let (mode_luma, mode_chroma) =
(mode_decision.pred_mode_luma, mode_decision.pred_mode_chroma);
......@@ -1469,7 +1492,7 @@ pub fn encode_block_with_modes<T: Pixel>(
bsize, bo, skip);
encode_block_b(fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
mode_luma, mode_chroma, ref_frames, mvs, bsize, bo, skip, cfl,
tx_size, tx_type, mode_context, &mv_stack, RDOType::PixelDistRealRate, false);
tx_size, tx_type, mode_context, &mv_stack, rdo_type, false);
}
fn encode_partition_bottomup<T: Pixel>(
......@@ -1478,6 +1501,7 @@ fn encode_partition_bottomup<T: Pixel>(
bo: &BlockOffset, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5],
ref_rd_cost: f64
) -> (RDOOutput) {
let rdo_type = RDOType::PixelDistRealRate;
let mut rd_cost = std::f64::MAX;
let mut best_rd = std::f64::MAX;
let mut rdo_output = RDOOutput {
......@@ -1536,7 +1560,7 @@ fn encode_partition_bottomup<T: Pixel>(
if !can_split {
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, bsize, bo,
&mode_decision);
&mode_decision, rdo_type);
}
}
......@@ -1652,7 +1676,7 @@ fn encode_partition_bottomup<T: Pixel>(
let offset = mode.bo.clone();
// FIXME: redundant block re-encode
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef,
mode.bsize, &offset, &mode);
mode.bsize, &offset, &mode, rdo_type);
}
}
}
......@@ -1686,6 +1710,7 @@ fn encode_partition_topdown<T: Pixel>(
let bsw = bsize.width_mi();
let bsh = bsize.height_mi();
let is_square = bsize.is_sqr();
let rdo_type = RDOType::PixelDistRealRate;
// Always split if the current partition is too large
let must_split = (bo.x + bsw as usize > fi.w_in_b ||
......@@ -1726,7 +1751,7 @@ fn encode_partition_topdown<T: Pixel>(
partition_types.push(PartitionType::PARTITION_SPLIT);
}
rdo_output = rdo_partition_decision(fi, fs, cw,
w_pre_cdef, w_post_cdef, bsize, bo, &rdo_output, pmvs, &partition_types);
w_pre_cdef, w_post_cdef, bsize, bo, &rdo_output, pmvs, &partition_types, rdo_type);
partition = rdo_output.part_type;
} else {
// Blocks of sizes below the supported range are encoded directly
......@@ -2059,7 +2084,6 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec
if fs.deblock.levels[0] != 0 || fs.deblock.levels[1] != 0 {
deblock_filter_frame(fs, &mut cw.bc, fi.sequence.bit_depth);
}
{
// Until the loop filters are pipelined, we'll need to keep
// around a copy of both the pre- and post-cdef frame.
let pre_cdef_frame = fs.rec.clone();
......@@ -2072,6 +2096,17 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec
if fi.sequence.enable_restoration {
fs.restoration.lrf_filter_frame(&mut fs.rec, &pre_cdef_frame, &fi);
}
if fi.config.train_rdo {
eprintln!("train rdo");
if let Ok(mut file) = File::open("rdo.dat") {
let mut data = vec![];
file.read_to_end(&mut data).unwrap();
fs.t.merge_in(&deserialize(data.as_slice()).unwrap());
}
let mut rdo_file = File::create("rdo.dat").unwrap();
rdo_file.write_all(&serialize(&fs.t).unwrap()).unwrap();
fs.t.print_code();
}
fs.cdfs = cw.fc;
......
......@@ -11,6 +11,7 @@
#[macro_use]
extern crate serde_derive;
extern crate bincode;
#[cfg(all(test, feature="decode_test_dav1d"))]
extern crate dav1d_sys;
......@@ -30,6 +31,7 @@ pub mod transform;
pub mod quantize;
pub mod predict;
pub mod rdo;
pub mod rdo_tables;
#[macro_use]
pub mod util;
pub mod context;
......
......@@ -32,16 +32,19 @@ use crate::Tune;
use crate::write_tx_blocks;
use crate::write_tx_tree;
use crate::util::{CastFromPrimitive, Pixel};
use crate::rdo_tables::*;
use std;
use std::cmp;
use std::vec::Vec;
use crate::partition::PartitionType::*;
#[derive(Copy,Clone)]
#[derive(Copy,Clone,PartialEq)]
pub enum RDOType {
PixelDistRealRate,
TxDistRealRate
TxDistRealRate,
TxDistEstRate,
Train
}
impl RDOType {
......@@ -50,7 +53,18 @@ impl RDOType {
// Pixel-domain distortion and exact ec rate
RDOType::PixelDistRealRate => false,
// Tx-domain distortion and exact ec rate
RDOType::TxDistRealRate => true
RDOType::TxDistRealRate => true,
// Tx-domain distortion and txdist-based rate
RDOType::TxDistEstRate => true,
RDOType::Train => true,
}
}
/// Tells whether this RDO flavour needs the coefficients to be really
/// written so that their exact rate can be measured.
///
/// Only `TxDistEstRate` returns `false`; every other variant returns
/// `true`.
pub fn needs_coeff_rate(self) -> bool {
  match self {
    // Exact entropy-coder rate is required (or, for `Train`, recorded).
    RDOType::PixelDistRealRate
    | RDOType::TxDistRealRate
    | RDOType::Train => true,
    // The rate comes from the tx-distortion estimate instead.
    RDOType::TxDistEstRate => false,
  }
}
}
......@@ -77,6 +91,86 @@ pub struct RDOPartitionOutput {
pub tx_type: TxType,
}
/// Training-time accumulator of observed coefficient rates.
///
/// Samples are bucketed three ways, `[q_bin][tx_size][distortion_bin]`:
/// the quantizer index divided by `RDO_QUANT_DIV`, the transform size cast
/// to `usize`, and a bin derived from the transform-domain distortion.
///
/// Note: the derived `Default` produces *empty* tables; use
/// [`RDOTracker::new`] to get a tracker that `add_rate` can index into.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct RDOTracker {
  /// Sum of all rates recorded in each cell.
  rate_bins: Vec<Vec<Vec<u64>>>,
  /// Number of samples recorded in each cell.
  rate_counts: Vec<Vec<Vec<u64>>>,
}

impl RDOTracker {
  /// Creates a tracker with zeroed
  /// `RDO_QUANT_BINS x TX_SIZES_ALL x RDO_NUM_BINS` tables.
  pub fn new() -> RDOTracker {
    RDOTracker {
      rate_bins: vec![vec![vec![0; RDO_NUM_BINS]; TxSize::TX_SIZES_ALL]; RDO_QUANT_BINS],
      rate_counts: vec![vec![vec![0; RDO_NUM_BINS]; TxSize::TX_SIZES_ALL]; RDO_QUANT_BINS],
    }
  }

  /// Element-wise `new[i] += old[i]` over the common prefix of both slices.
  fn merge_array(new: &mut [u64], old: &[u64]) {
    for (n, o) in new.iter_mut().zip(old.iter()) {
      *n += *o;
    }
  }

  /// Applies `merge_array` row by row.
  fn merge_2d_array(new: &mut [Vec<u64>], old: &[Vec<u64>]) {
    for (n, o) in new.iter_mut().zip(old.iter()) {
      RDOTracker::merge_array(n, o);
    }
  }

  /// Applies `merge_2d_array` plane by plane.
  fn merge_3d_array(new: &mut [Vec<Vec<u64>>], old: &[Vec<Vec<u64>>]) {
    for (n, o) in new.iter_mut().zip(old.iter()) {
      RDOTracker::merge_2d_array(n, o);
    }
  }

  /// Adds the sums and sample counts of `input` into `self`.
  ///
  /// Cells are paired positionally; if the two trackers have different
  /// dimensions only the overlapping cells are merged.
  pub fn merge_in(&mut self, input: &RDOTracker) {
    RDOTracker::merge_3d_array(&mut self.rate_bins, &input.rate_bins);
    RDOTracker::merge_3d_array(&mut self.rate_counts, &input.rate_counts);
  }

  /// Records one `(distortion, rate)` observation.
  ///
  /// Samples with zero distortion are ignored. The distortion bin is
  /// `(fast_distortion - RATE_EST_BIN_SIZE / 2) / RATE_EST_BIN_SIZE`,
  /// clamped to `0..RDO_NUM_BINS`.
  ///
  /// # Panics
  ///
  /// Panics if the quantizer bin or transform size index is outside the
  /// tables (e.g. on a tracker built with `Default`).
  pub fn add_rate(&mut self, qindex: u8, ts: TxSize, fast_distortion: u64, rate: u64) {
    if fast_distortion == 0 {
      return;
    }
    let bs_index = ts as usize;
    let q_bin_idx = (qindex as usize) / RDO_QUANT_DIV;
    // Saturate instead of wrapping: the previous signed-subtract-then-cast
    // turned distortions below half a bin into a huge unsigned value, which
    // the clamp then filed under the *last* bin instead of the first.
    let shifted = fast_distortion.saturating_sub(RATE_EST_BIN_SIZE / 2);
    // Clamp while still in u64 so the narrowing to usize cannot truncate.
    let bin_idx = (shifted / RATE_EST_BIN_SIZE).min((RDO_NUM_BINS - 1) as u64) as usize;
    self.rate_counts[q_bin_idx][bs_index][bin_idx] += 1;
    self.rate_bins[q_bin_idx][bs_index][bin_idx] += rate;
  }

  /// Prints the averaged table to stdout as Rust source for
  /// `RDO_RATE_TABLE`.
  ///
  /// A cell is emitted as `sum / count` only when it holds more than 100
  /// samples; otherwise the placeholder `99999` is printed.
  pub fn print_code(&self) {
    println!("pub static RDO_RATE_TABLE: [[[u64; RDO_NUM_BINS]; TxSize::TX_SIZES_ALL]; RDO_QUANT_BINS] = [");
    for q_bin in 0..RDO_QUANT_BINS {
      print!("[");
      for bs_index in 0..TxSize::TX_SIZES_ALL {
        print!("[");
        let totals = &self.rate_bins[q_bin][bs_index];
        let counts = &self.rate_counts[q_bin][bs_index];
        for (rate_total, rate_count) in totals.iter().zip(counts.iter()) {
          if *rate_count > 100 {
            print!("{},", rate_total / rate_count);
          } else {
            print!("99999,");
          }
        }
        println!("],");
      }
      println!("],");
    }
    println!("];");
  }
}
/// Estimates a coefficient rate from the transform-domain distortion by
/// linear interpolation in `RDO_RATE_TABLE`.
///
/// The table row is selected by `qindex / RDO_QUANT_DIV` and by the
/// transform size cast to `usize`. Within the row, the two neighbouring
/// bins around `fast_distortion` are interpolated; the lower bin is capped
/// at `RDO_NUM_BINS - 2`, so distortions past the end of the table are
/// extrapolated along the last segment. Negative results are clamped to 0.
pub fn estimate_rate(qindex: u8, ts: TxSize, fast_distortion: u64) -> u64 {
  let rates = &RDO_RATE_TABLE[(qindex as usize) / RDO_QUANT_DIV][ts as usize];

  let lo_bin = (fast_distortion / RATE_EST_BIN_SIZE).min((RDO_NUM_BINS - 2) as u64);
  let hi_bin = (lo_bin + 1).min((RDO_NUM_BINS - 1) as u64);

  // Distortion (x) and rate (y) at both ends of the segment.
  let (x0, x1) = ((lo_bin * RATE_EST_BIN_SIZE) as i64, (hi_bin * RATE_EST_BIN_SIZE) as i64);
  let (y0, y1) = (rates[lo_bin as usize] as i64, rates[hi_bin as usize] as i64);

  // Slope is kept with 8 fractional bits and shifted back after the multiply.
  let slope_q8 = ((y1 - y0) << 8) / (x1 - x0);
  let offset = fast_distortion as i64 - x0;
  let interpolated = y0 + ((offset * slope_q8) >> 8);

  interpolated.max(0) as u64
}
#[allow(unused)]
fn cdef_dist_wxh_8x8<T: Pixel>(
src1: &PlaneSlice<'_, T>, src2: &PlaneSlice<'_, T>, bit_depth: usize
......@@ -370,7 +464,11 @@ pub fn rdo_mode_decision<T: Pixel>(
let mut fwdref = None;
let mut bwdref = None;
let rdo_type = if fi.use_tx_domain_distortion {
let rdo_type = if fi.config.train_rdo {
RDOType::Train
} else if fi.use_tx_domain_rate {
RDOType::TxDistEstRate
} else if fi.use_tx_domain_distortion {
RDOType::TxDistRealRate
} else {
RDOType::PixelDistRealRate
......@@ -722,8 +820,8 @@ pub fn rdo_mode_decision<T: Pixel>(
let wr: &mut dyn Writer = &mut WriterCounter::new();
let tell = wr.tell_frac();
encode_block_a(&fi.sequence, fs, cw, wr, bsize, bo, best.skip);
encode_block_b(
encode_block_a(&fi.sequence, fs, cw, wr, bsize, bo, best.skip);
let _ = encode_block_b(
fi,
fs,
cw,
......@@ -996,7 +1094,7 @@ pub fn rdo_partition_decision<T: Pixel>(
cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer,
bsize: BlockSize, bo: &BlockOffset,
cached_block: &RDOOutput, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5],
partition_types: &[PartitionType],
partition_types: &[PartitionType], rdo_type: RDOType
) -> RDOOutput {
let mut best_partition = cached_block.part_type;
let mut best_rd = cached_block.rd_cost;
......@@ -1089,7 +1187,7 @@ pub fn rdo_partition_decision<T: Pixel>(
cw.write_partition(w, offset, PartitionType::PARTITION_NONE, subsize);
}
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, subsize,
offset, &mode_decision);
offset, &mode_decision, rdo_type);
child_modes.push(mode_decision);
}
}
......@@ -1335,3 +1433,8 @@ pub fn rdo_loop_decision<T: Pixel>(sbo: &SuperBlockOffset, fi: &FrameInvariants<
}
}
}
// With zero distortion the interpolation offset is zero, so the estimate
// must be exactly the first entry of the selected table row.
#[test]
fn estimate_rate_test() {
  let estimated = estimate_rate(0, TxSize::TX_4X4, 0);
  let first_entry = RDO_RATE_TABLE[0][0][0];
  assert_eq!(estimated, first_entry);
}
This diff is collapsed.
#!/bin/bash
# Collects RDO rate-training data with rav1e and plots the result.
# Abort on the first failing command.
set -e
# Start from scratch: discard any rdo.dat left over from a previous run.
rm -f rdo.dat
cargo build --release
# Run rav1e with --train-rdo once per (clip, quantizer) pair: clips are read
# from stdin ({1}), quantizers come from the literal list ({2}). The encoded
# output is thrown away.
# NOTE(review): with --train-rdo each rav1e process appears to read, merge and
# rewrite rdo.dat; jobs run concurrently by parallel may overwrite each
# other's data — confirm whether the job count needs limiting.
ls ~/sets/objective-1-fast/*.y4m | parallel target/release/rav1e --threads 1 --quantizer {2} -o /dev/null --train-rdo {1} :::: - ::: 16 48 80 112 144 176 208 240
# Plot using rdo.plt; -p keeps the plot window open after gnuplot exits.
gnuplot rdo.plt -p
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment