Commit 85e88293, authored and committed by Thomas Daede

Add tx_dist based rate estimate.

parent 3e7c7c5d
...@@ -35,6 +35,7 @@ aom-sys = { version = "0.1.0", optional = true } ...@@ -35,6 +35,7 @@ aom-sys = { version = "0.1.0", optional = true }
scan_fmt = { version = "0.1.3", optional = true } scan_fmt = { version = "0.1.3", optional = true }
ivf = { path = "ivf/", optional = true } ivf = { path = "ivf/", optional = true }
rayon = "1.0" rayon = "1.0"
bincode = "=1.0.1"
[target.'cfg(target_arch = "x86_64")'.build-dependencies] [target.'cfg(target_arch = "x86_64")'.build-dependencies]
nasm-rs = { version = "0.1.4", optional = true } nasm-rs = { version = "0.1.4", optional = true }
......
...@@ -78,6 +78,7 @@ pub struct EncoderConfig { ...@@ -78,6 +78,7 @@ pub struct EncoderConfig {
pub pass: Option<u8>, pub pass: Option<u8>,
pub show_psnr: bool, pub show_psnr: bool,
pub stats_file: Option<PathBuf>, pub stats_file: Option<PathBuf>,
pub train_rdo: bool,
} }
impl Default for EncoderConfig { impl Default for EncoderConfig {
...@@ -112,6 +113,7 @@ impl EncoderConfig { ...@@ -112,6 +113,7 @@ impl EncoderConfig {
pass: None, pass: None,
show_psnr: false, show_psnr: false,
stats_file: None, stats_file: None,
train_rdo: false
} }
} }
} }
...@@ -123,6 +125,7 @@ pub struct SpeedSettings { ...@@ -123,6 +125,7 @@ pub struct SpeedSettings {
pub fast_deblock: bool, pub fast_deblock: bool,
pub reduced_tx_set: bool, pub reduced_tx_set: bool,
pub tx_domain_distortion: bool, pub tx_domain_distortion: bool,
pub tx_domain_rate: bool,
pub encode_bottomup: bool, pub encode_bottomup: bool,
pub rdo_tx_decision: bool, pub rdo_tx_decision: bool,
pub prediction_modes: PredictionModesSetting, pub prediction_modes: PredictionModesSetting,
...@@ -139,6 +142,7 @@ impl Default for SpeedSettings { ...@@ -139,6 +142,7 @@ impl Default for SpeedSettings {
fast_deblock: false, fast_deblock: false,
reduced_tx_set: false, reduced_tx_set: false,
tx_domain_distortion: false, tx_domain_distortion: false,
tx_domain_rate: false,
encode_bottomup: false, encode_bottomup: false,
rdo_tx_decision: false, rdo_tx_decision: false,
prediction_modes: PredictionModesSetting::Simple, prediction_modes: PredictionModesSetting::Simple,
...@@ -157,6 +161,7 @@ impl SpeedSettings { ...@@ -157,6 +161,7 @@ impl SpeedSettings {
fast_deblock: Self::fast_deblock_preset(speed), fast_deblock: Self::fast_deblock_preset(speed),
reduced_tx_set: Self::reduced_tx_set_preset(speed), reduced_tx_set: Self::reduced_tx_set_preset(speed),
tx_domain_distortion: Self::tx_domain_distortion_preset(speed), tx_domain_distortion: Self::tx_domain_distortion_preset(speed),
tx_domain_rate: Self::tx_domain_rate_preset(speed),
encode_bottomup: Self::encode_bottomup_preset(speed), encode_bottomup: Self::encode_bottomup_preset(speed),
rdo_tx_decision: Self::rdo_tx_decision_preset(speed), rdo_tx_decision: Self::rdo_tx_decision_preset(speed),
prediction_modes: Self::prediction_modes_preset(speed), prediction_modes: Self::prediction_modes_preset(speed),
...@@ -199,6 +204,10 @@ impl SpeedSettings { ...@@ -199,6 +204,10 @@ impl SpeedSettings {
true true
} }
/// Preset value of the `tx_domain_rate` speed setting for a given speed level.
/// The speed level is currently ignored: the setting is off for every preset.
fn tx_domain_rate_preset(_speed: usize) -> bool {
false
}
fn encode_bottomup_preset(speed: usize) -> bool { fn encode_bottomup_preset(speed: usize) -> bool {
speed == 0 speed == 0
} }
......
...@@ -245,6 +245,10 @@ pub fn parse_cli() -> CliOptions { ...@@ -245,6 +245,10 @@ pub fn parse_cli() -> CliOptions {
.long("speed-test") .long("speed-test")
.takes_value(true) .takes_value(true)
) )
.arg(
Arg::with_name("train-rdo")
.long("train-rdo")
)
.subcommand(SubCommand::with_name("advanced") .subcommand(SubCommand::with_name("advanced")
.setting(AppSettings::Hidden) .setting(AppSettings::Hidden)
.about("Advanced features") .about("Advanced features")
...@@ -313,6 +317,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig { ...@@ -313,6 +317,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig {
} }
}); });
let bitrate = maybe_bitrate.unwrap_or(0); let bitrate = maybe_bitrate.unwrap_or(0);
let train_rdo = matches.is_present("train-rdo");
if quantizer == 0 { if quantizer == 0 {
unimplemented!("Lossless encoding not yet implemented"); unimplemented!("Lossless encoding not yet implemented");
} else if quantizer > 255 { } else if quantizer > 255 {
...@@ -422,7 +427,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig { ...@@ -422,7 +427,7 @@ fn parse_config(matches: &ArgMatches<'_>) -> EncoderConfig {
}; };
cfg.tune = matches.value_of("TUNE").unwrap().parse().unwrap(); cfg.tune = matches.value_of("TUNE").unwrap().parse().unwrap();
cfg.low_latency = matches.value_of("LOW_LATENCY").unwrap().parse().unwrap(); cfg.low_latency = matches.value_of("LOW_LATENCY").unwrap().parse().unwrap();
cfg.train_rdo = train_rdo;
cfg cfg
} }
...@@ -455,6 +460,9 @@ fn apply_speed_test_cfg(cfg: &mut EncoderConfig, setting: &str) { ...@@ -455,6 +460,9 @@ fn apply_speed_test_cfg(cfg: &mut EncoderConfig, setting: &str) {
"tx_domain_distortion" => { "tx_domain_distortion" => {
cfg.speed_settings.tx_domain_distortion = true; cfg.speed_settings.tx_domain_distortion = true;
}, },
"tx_domain_rate" => {
cfg.speed_settings.tx_domain_rate = true;
},
"encode_bottomup" => { "encode_bottomup" => {
cfg.speed_settings.encode_bottomup = true; cfg.speed_settings.encode_bottomup = true;
}, },
......
...@@ -78,6 +78,8 @@ pub trait Writer { ...@@ -78,6 +78,8 @@ pub trait Writer {
fn checkpoint(&mut self) -> WriterCheckpoint; fn checkpoint(&mut self) -> WriterCheckpoint;
/// Restore saved position in coding/recording from a checkpoint /// Restore saved position in coding/recording from a checkpoint
fn rollback(&mut self, _: &WriterCheckpoint); fn rollback(&mut self, _: &WriterCheckpoint);
/// Add additional bits from rate estimators without coding a real symbol
fn add_bits_frac(&mut self, bits_frac: u32);
} }
/// StorageBackend is an internal trait used to tie a specific Writer /// StorageBackend is an internal trait used to tie a specific Writer
...@@ -103,6 +105,9 @@ pub struct WriterBase<S> { ...@@ -103,6 +105,9 @@ pub struct WriterBase<S> {
cnt: i16, cnt: i16,
/// Debug enable flag /// Debug enable flag
debug: bool, debug: bool,
/// Extra offset added to tell() and tell_frac() to approximate costs
/// of actually coding a symbol
fake_bits_frac: u32,
/// Use-specific storage /// Use-specific storage
s: S s: S
} }
...@@ -298,7 +303,7 @@ impl<S> WriterBase<S> { ...@@ -298,7 +303,7 @@ impl<S> WriterBase<S> {
/// Internal constructor called by the subtypes that implement the /// Internal constructor called by the subtypes that implement the
/// actual encoder and Recorder. /// actual encoder and Recorder.
fn new(storage: S) -> Self { fn new(storage: S) -> Self {
WriterBase { rng: 0x8000, cnt: -9, debug: false, s: storage } WriterBase { rng: 0x8000, cnt: -9, debug: false, fake_bits_frac: 0, s: storage }
} }
/// Compute low and range values from token cdf values and local state /// Compute low and range values from token cdf values and local state
...@@ -478,6 +483,10 @@ where ...@@ -478,6 +483,10 @@ where
fn bit(&mut self, bit: u16) { fn bit(&mut self, bit: u16) {
self.bool(bit == 1, 16384); self.bool(bit == 1, 16384);
} }
/// Adds `bits_frac` to the running `fake_bits_frac` offset without coding a
/// real symbol. `tell()` and `tell_frac()` both include that offset in the
/// value they report.
// NOTE(review): `tell()` adds `fake_bits_frac >> 8` while `tell_frac()` adds
// the unshifted value on top of a `tell()` result that already contains the
// shifted part -- confirm the units and the apparent double counting are intended.
fn add_bits_frac(&mut self, bits_frac: u32) {
self.fake_bits_frac += bits_frac
}
/// Encode a literal bitstring, bit by bit in MSB order, with flat /// Encode a literal bitstring, bit by bit in MSB order, with flat
/// probability. /// probability.
/// 'bits': Length of bitstring /// 'bits': Length of bitstring
...@@ -721,7 +730,7 @@ where ...@@ -721,7 +730,7 @@ where
fn tell(&mut self) -> u32 { fn tell(&mut self) -> u32 {
// The 10 here counteracts the offset of -9 baked into cnt, and adds 1 extra // The 10 here counteracts the offset of -9 baked into cnt, and adds 1 extra
// bit, which we reserve for terminating the stream. // bit, which we reserve for terminating the stream.
(((self.stream_bytes() * 8) as i32) + (self.cnt as i32) + 10) as u32 (((self.stream_bytes() * 8) as i32) + (self.cnt as i32) + 10) as u32 + (self.fake_bits_frac >> 8)
} }
/// Returns the number of bits "used" by the encoded symbols so far. /// Returns the number of bits "used" by the encoded symbols so far.
/// This same number can be computed in either the encoder or the /// This same number can be computed in either the encoder or the
...@@ -731,7 +740,7 @@ where ...@@ -731,7 +740,7 @@ where
/// This will always be slightly larger than the exact value (e.g., all /// This will always be slightly larger than the exact value (e.g., all
/// rounding error is in the positive direction). /// rounding error is in the positive direction).
fn tell_frac(&mut self) -> u32 { fn tell_frac(&mut self) -> u32 {
Self::frac_compute(self.tell(), self.rng as u32) Self::frac_compute(self.tell(), self.rng as u32) + self.fake_bits_frac
} }
/// Save current point in coding/recording to a checkpoint that can /// Save current point in coding/recording to a checkpoint that can
/// be restored later. A WriterCheckpoint can be generated for an /// be restored later. A WriterCheckpoint can be generated for an
......
...@@ -29,10 +29,13 @@ use crate::partition::PartitionType::*; ...@@ -29,10 +29,13 @@ use crate::partition::PartitionType::*;
use crate::header::*; use crate::header::*;
use bitstream_io::{BitWriter, BigEndian}; use bitstream_io::{BitWriter, BigEndian};
use bincode::{serialize, deserialize};
use std; use std;
use std::{fmt, io, mem}; use std::{fmt, io, mem};
use std::io::Write; use std::io::Write;
use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use std::fs::File;
extern { extern {
pub fn av1_rtcd(); pub fn av1_rtcd();
...@@ -396,6 +399,7 @@ pub struct FrameState<T: Pixel> { ...@@ -396,6 +399,7 @@ pub struct FrameState<T: Pixel> {
pub segmentation: SegmentationState, pub segmentation: SegmentationState,
pub restoration: RestorationState, pub restoration: RestorationState,
pub frame_mvs: Vec<Vec<MotionVector>>, pub frame_mvs: Vec<Vec<MotionVector>>,
pub t: RDOTracker,
} }
impl<T: Pixel> FrameState<T> { impl<T: Pixel> FrameState<T> {
...@@ -422,7 +426,8 @@ impl<T: Pixel> FrameState<T> { ...@@ -422,7 +426,8 @@ impl<T: Pixel> FrameState<T> {
deblock: Default::default(), deblock: Default::default(),
segmentation: Default::default(), segmentation: Default::default(),
restoration: rs, restoration: rs,
frame_mvs: vec![vec![MotionVector::default(); fi.w_in_b * fi.h_in_b]; REF_FRAMES] frame_mvs: vec![vec![MotionVector::default(); fi.w_in_b * fi.h_in_b]; REF_FRAMES],
t: RDOTracker::new()
} }
} }
} }
...@@ -538,6 +543,7 @@ pub struct FrameInvariants<T: Pixel> { ...@@ -538,6 +543,7 @@ pub struct FrameInvariants<T: Pixel> {
pub me_lambda: f64, pub me_lambda: f64,
pub me_range_scale: u8, pub me_range_scale: u8,
pub use_tx_domain_distortion: bool, pub use_tx_domain_distortion: bool,
pub use_tx_domain_rate: bool,
pub inter_cfg: Option<InterPropsConfig>, pub inter_cfg: Option<InterPropsConfig>,
pub enable_early_exit: bool, pub enable_early_exit: bool,
} }
...@@ -562,6 +568,7 @@ impl<T: Pixel> FrameInvariants<T> { ...@@ -562,6 +568,7 @@ impl<T: Pixel> FrameInvariants<T> {
let min_partition_size = config.speed_settings.min_block_size; let min_partition_size = config.speed_settings.min_block_size;
let use_reduced_tx_set = config.speed_settings.reduced_tx_set; let use_reduced_tx_set = config.speed_settings.reduced_tx_set;
let use_tx_domain_distortion = config.tune == Tune::Psnr && config.speed_settings.tx_domain_distortion; let use_tx_domain_distortion = config.tune == Tune::Psnr && config.speed_settings.tx_domain_distortion;
let use_tx_domain_rate = config.speed_settings.tx_domain_rate;
let w_in_b = 2 * config.width.align_power_of_two_and_shift(3); // MiCols, ((width+7)/8)<<3 >> MI_SIZE_LOG2 let w_in_b = 2 * config.width.align_power_of_two_and_shift(3); // MiCols, ((width+7)/8)<<3 >> MI_SIZE_LOG2
let h_in_b = 2 * config.height.align_power_of_two_and_shift(3); // MiRows, ((height+7)/8)<<3 >> MI_SIZE_LOG2 let h_in_b = 2 * config.height.align_power_of_two_and_shift(3); // MiRows, ((height+7)/8)<<3 >> MI_SIZE_LOG2
...@@ -617,6 +624,7 @@ impl<T: Pixel> FrameInvariants<T> { ...@@ -617,6 +624,7 @@ impl<T: Pixel> FrameInvariants<T> {
me_lambda: 0.0, me_lambda: 0.0,
me_range_scale: 1, me_range_scale: 1,
use_tx_domain_distortion, use_tx_domain_distortion,
use_tx_domain_rate,
inter_cfg: None, inter_cfg: None,
enable_early_exit: true, enable_early_exit: true,
config, config,
...@@ -981,9 +989,14 @@ pub fn encode_tx_block<T: Pixel>( ...@@ -981,9 +989,14 @@ pub fn encode_tx_block<T: Pixel>(
let coded_tx_size = av1_get_coded_tx_size(tx_size).area(); let coded_tx_size = av1_get_coded_tx_size(tx_size).area();
fs.qc.quantize(coeffs, qcoeffs, coded_tx_size); fs.qc.quantize(coeffs, qcoeffs, coded_tx_size);
let has_coeff = cw.write_coeffs_lv_map(w, p, bo, &qcoeffs, mode, tx_size, tx_type, plane_bsize, xdec, ydec, let tell_coeffs = w.tell_frac();
fi.use_reduced_tx_set); let has_coeff = if !for_rdo_use || rdo_type.needs_coeff_rate() {
cw.write_coeffs_lv_map(w, p, bo, &qcoeffs, mode, tx_size, tx_type, plane_bsize, xdec, ydec,
fi.use_reduced_tx_set)
} else {
true
};
let cost_coeffs = w.tell_frac() - tell_coeffs;
// Reconstruct // Reconstruct
dequantize(qidx, qcoeffs, rcoeffs, tx_size, fi.sequence.bit_depth, fi.dc_delta_q[p], fi.ac_delta_q[p]); dequantize(qidx, qcoeffs, rcoeffs, tx_size, fi.sequence.bit_depth, fi.dc_delta_q[p], fi.ac_delta_q[p]);
...@@ -1006,6 +1019,15 @@ pub fn encode_tx_block<T: Pixel>( ...@@ -1006,6 +1019,15 @@ pub fn encode_tx_block<T: Pixel>(
let tx_dist_scale_rounding_offset = 1 << (tx_dist_scale_bits - 1); let tx_dist_scale_rounding_offset = 1 << (tx_dist_scale_bits - 1);
tx_dist = (tx_dist + tx_dist_scale_rounding_offset) >> tx_dist_scale_bits; tx_dist = (tx_dist + tx_dist_scale_rounding_offset) >> tx_dist_scale_bits;
} }
if fi.config.train_rdo {
fs.t.add_rate(fi.base_q_idx, tx_size, tx_dist as u64, cost_coeffs as u64);
}
if rdo_type == RDOType::TxDistEstRate {
// look up rate and distortion in table
let estimated_rate = estimate_rate(fi.base_q_idx, tx_size, tx_dist as u64);
w.add_bits_frac(estimated_rate as u32);
}
(has_coeff, tx_dist) (has_coeff, tx_dist)
} }
...@@ -1446,7 +1468,8 @@ pub fn write_tx_tree<T: Pixel>( ...@@ -1446,7 +1468,8 @@ pub fn write_tx_tree<T: Pixel>(
pub fn encode_block_with_modes<T: Pixel>( pub fn encode_block_with_modes<T: Pixel>(
fi: &FrameInvariants<T>, fs: &mut FrameState<T>, fi: &FrameInvariants<T>, fs: &mut FrameState<T>,
cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer, cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer,
bsize: BlockSize, bo: &BlockOffset, mode_decision: &RDOPartitionOutput bsize: BlockSize, bo: &BlockOffset, mode_decision: &RDOPartitionOutput,
rdo_type: RDOType
) { ) {
let (mode_luma, mode_chroma) = let (mode_luma, mode_chroma) =
(mode_decision.pred_mode_luma, mode_decision.pred_mode_chroma); (mode_decision.pred_mode_luma, mode_decision.pred_mode_chroma);
...@@ -1469,7 +1492,7 @@ pub fn encode_block_with_modes<T: Pixel>( ...@@ -1469,7 +1492,7 @@ pub fn encode_block_with_modes<T: Pixel>(
bsize, bo, skip); bsize, bo, skip);
encode_block_b(fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef}, encode_block_b(fi, fs, cw, if cdef_coded {w_post_cdef} else {w_pre_cdef},
mode_luma, mode_chroma, ref_frames, mvs, bsize, bo, skip, cfl, mode_luma, mode_chroma, ref_frames, mvs, bsize, bo, skip, cfl,
tx_size, tx_type, mode_context, &mv_stack, RDOType::PixelDistRealRate, false); tx_size, tx_type, mode_context, &mv_stack, rdo_type, false);
} }
fn encode_partition_bottomup<T: Pixel>( fn encode_partition_bottomup<T: Pixel>(
...@@ -1478,6 +1501,7 @@ fn encode_partition_bottomup<T: Pixel>( ...@@ -1478,6 +1501,7 @@ fn encode_partition_bottomup<T: Pixel>(
bo: &BlockOffset, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5], bo: &BlockOffset, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5],
ref_rd_cost: f64 ref_rd_cost: f64
) -> (RDOOutput) { ) -> (RDOOutput) {
let rdo_type = RDOType::PixelDistRealRate;
let mut rd_cost = std::f64::MAX; let mut rd_cost = std::f64::MAX;
let mut best_rd = std::f64::MAX; let mut best_rd = std::f64::MAX;
let mut rdo_output = RDOOutput { let mut rdo_output = RDOOutput {
...@@ -1536,7 +1560,7 @@ fn encode_partition_bottomup<T: Pixel>( ...@@ -1536,7 +1560,7 @@ fn encode_partition_bottomup<T: Pixel>(
if !can_split { if !can_split {
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, bsize, bo, encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, bsize, bo,
&mode_decision); &mode_decision, rdo_type);
} }
} }
...@@ -1652,7 +1676,7 @@ fn encode_partition_bottomup<T: Pixel>( ...@@ -1652,7 +1676,7 @@ fn encode_partition_bottomup<T: Pixel>(
let offset = mode.bo.clone(); let offset = mode.bo.clone();
// FIXME: redundant block re-encode // FIXME: redundant block re-encode
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef,
mode.bsize, &offset, &mode); mode.bsize, &offset, &mode, rdo_type);
} }
} }
} }
...@@ -1686,6 +1710,7 @@ fn encode_partition_topdown<T: Pixel>( ...@@ -1686,6 +1710,7 @@ fn encode_partition_topdown<T: Pixel>(
let bsw = bsize.width_mi(); let bsw = bsize.width_mi();
let bsh = bsize.height_mi(); let bsh = bsize.height_mi();
let is_square = bsize.is_sqr(); let is_square = bsize.is_sqr();
let rdo_type = RDOType::PixelDistRealRate;
// Always split if the current partition is too large // Always split if the current partition is too large
let must_split = (bo.x + bsw as usize > fi.w_in_b || let must_split = (bo.x + bsw as usize > fi.w_in_b ||
...@@ -1726,7 +1751,7 @@ fn encode_partition_topdown<T: Pixel>( ...@@ -1726,7 +1751,7 @@ fn encode_partition_topdown<T: Pixel>(
partition_types.push(PartitionType::PARTITION_SPLIT); partition_types.push(PartitionType::PARTITION_SPLIT);
} }
rdo_output = rdo_partition_decision(fi, fs, cw, rdo_output = rdo_partition_decision(fi, fs, cw,
w_pre_cdef, w_post_cdef, bsize, bo, &rdo_output, pmvs, &partition_types); w_pre_cdef, w_post_cdef, bsize, bo, &rdo_output, pmvs, &partition_types, rdo_type);
partition = rdo_output.part_type; partition = rdo_output.part_type;
} else { } else {
// Blocks of sizes below the supported range are encoded directly // Blocks of sizes below the supported range are encoded directly
...@@ -2059,7 +2084,6 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec ...@@ -2059,7 +2084,6 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec
if fs.deblock.levels[0] != 0 || fs.deblock.levels[1] != 0 { if fs.deblock.levels[0] != 0 || fs.deblock.levels[1] != 0 {
deblock_filter_frame(fs, &mut cw.bc, fi.sequence.bit_depth); deblock_filter_frame(fs, &mut cw.bc, fi.sequence.bit_depth);
} }
{
// Until the loop filters are pipelined, we'll need to keep // Until the loop filters are pipelined, we'll need to keep
// around a copy of both the pre- and post-cdef frame. // around a copy of both the pre- and post-cdef frame.
let pre_cdef_frame = fs.rec.clone(); let pre_cdef_frame = fs.rec.clone();
...@@ -2072,6 +2096,17 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec ...@@ -2072,6 +2096,17 @@ fn encode_tile<T: Pixel>(fi: &FrameInvariants<T>, fs: &mut FrameState<T>) -> Vec
if fi.sequence.enable_restoration { if fi.sequence.enable_restoration {
fs.restoration.lrf_filter_frame(&mut fs.rec, &pre_cdef_frame, &fi); fs.restoration.lrf_filter_frame(&mut fs.rec, &pre_cdef_frame, &fi);
} }
if fi.config.train_rdo {
eprintln!("train rdo");
if let Ok(mut file) = File::open("rdo.dat") {
let mut data = vec![];
file.read_to_end(&mut data).unwrap();
fs.t.merge_in(&deserialize(data.as_slice()).unwrap());
}
let mut rdo_file = File::create("rdo.dat").unwrap();
rdo_file.write_all(&serialize(&fs.t).unwrap()).unwrap();
fs.t.print_code();
} }
fs.cdfs = cw.fc; fs.cdfs = cw.fc;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#[macro_use] #[macro_use]
extern crate serde_derive; extern crate serde_derive;
extern crate bincode;
#[cfg(all(test, feature="decode_test_dav1d"))] #[cfg(all(test, feature="decode_test_dav1d"))]
extern crate dav1d_sys; extern crate dav1d_sys;
...@@ -30,6 +31,7 @@ pub mod transform; ...@@ -30,6 +31,7 @@ pub mod transform;
pub mod quantize; pub mod quantize;
pub mod predict; pub mod predict;
pub mod rdo; pub mod rdo;
pub mod rdo_tables;
#[macro_use] #[macro_use]
pub mod util; pub mod util;
pub mod context; pub mod context;
......
...@@ -32,16 +32,19 @@ use crate::Tune; ...@@ -32,16 +32,19 @@ use crate::Tune;
use crate::write_tx_blocks; use crate::write_tx_blocks;
use crate::write_tx_tree; use crate::write_tx_tree;
use crate::util::{CastFromPrimitive, Pixel}; use crate::util::{CastFromPrimitive, Pixel};
use crate::rdo_tables::*;
use std; use std;
use std::cmp; use std::cmp;
use std::vec::Vec; use std::vec::Vec;
use crate::partition::PartitionType::*; use crate::partition::PartitionType::*;
#[derive(Copy,Clone)] #[derive(Copy,Clone,PartialEq)]
pub enum RDOType { pub enum RDOType {
PixelDistRealRate, PixelDistRealRate,
TxDistRealRate TxDistRealRate,
TxDistEstRate,
Train
} }
impl RDOType { impl RDOType {
...@@ -50,7 +53,18 @@ impl RDOType { ...@@ -50,7 +53,18 @@ impl RDOType {
// Pixel-domain distortion and exact ec rate // Pixel-domain distortion and exact ec rate
RDOType::PixelDistRealRate => false, RDOType::PixelDistRealRate => false,
// Tx-domain distortion and exact ec rate // Tx-domain distortion and exact ec rate
RDOType::TxDistRealRate => true RDOType::TxDistRealRate => true,
// Tx-domain distortion and txdist-based rate
RDOType::TxDistEstRate => true,
RDOType::Train => true,
}
}
pub fn needs_coeff_rate(self) -> bool {
match self {
RDOType::PixelDistRealRate => true,
RDOType::TxDistRealRate => true,
RDOType::TxDistEstRate => false,
RDOType::Train => true,
} }
} }
} }
...@@ -77,6 +91,86 @@ pub struct RDOPartitionOutput { ...@@ -77,6 +91,86 @@ pub struct RDOPartitionOutput {
pub tx_type: TxType, pub tx_type: TxType,
} }
/// Accumulates (distortion, rate) training samples used to generate the
/// `RDO_RATE_TABLE` lookup table.
///
/// Both fields are indexed as `[quantizer bin][TxSize as usize][distortion bin]`.
/// Note that the derived `Default` produces empty vectors; use
/// `RDOTracker::new()` to get correctly sized, zeroed accumulators.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct RDOTracker {
/// Sum of the rates recorded in each bin.
rate_bins: Vec<Vec<Vec<u64>>>,
/// Number of samples recorded in each bin.
rate_counts: Vec<Vec<Vec<u64>>>,
}
impl RDOTracker {
  /// Creates a tracker whose accumulators are all zero, sized
  /// `RDO_QUANT_BINS x TxSize::TX_SIZES_ALL x RDO_NUM_BINS`.
  pub fn new() -> RDOTracker {
    let zeroed =
      || vec![vec![vec![0u64; RDO_NUM_BINS]; TxSize::TX_SIZES_ALL]; RDO_QUANT_BINS];
    RDOTracker { rate_bins: zeroed(), rate_counts: zeroed() }
  }

  /// Adds every element of `old` to the element of `new` at the same index.
  /// Elements past the end of the shorter slice are left untouched.
  fn merge_array(new: &mut [u64], old: &[u64]) {
    for (n, o) in new.iter_mut().zip(old) {
      *n += o;
    }
  }

  /// Row-wise version of `merge_array` for two-dimensional accumulators.
  fn merge_2d_array(new: &mut [Vec<u64>], old: &[Vec<u64>]) {
    for (n, o) in new.iter_mut().zip(old) {
      RDOTracker::merge_array(n, o);
    }
  }

  /// Plane-wise version of `merge_2d_array` for three-dimensional accumulators.
  fn merge_3d_array(new: &mut [Vec<Vec<u64>>], old: &[Vec<Vec<u64>>]) {
    for (n, o) in new.iter_mut().zip(old) {
      RDOTracker::merge_2d_array(n, o);
    }
  }

  /// Adds the rate sums and sample counts of `input` into `self`.
  pub fn merge_in(&mut self, input: &RDOTracker) {
    RDOTracker::merge_3d_array(&mut self.rate_bins, &input.rate_bins);
    RDOTracker::merge_3d_array(&mut self.rate_counts, &input.rate_counts);
  }

  /// Records one training sample: `rate` observed for a transform block of
  /// size `ts`, coded at `qindex`, whose fast distortion was `fast_distortion`.
  ///
  /// Samples with zero distortion are ignored. The distortion bin is
  /// `(fast_distortion - RATE_EST_BIN_SIZE / 2) / RATE_EST_BIN_SIZE`, clamped
  /// to the range `0..RDO_NUM_BINS`.
  pub fn add_rate(&mut self, qindex: u8, ts: TxSize, fast_distortion: u64, rate: u64) {
    if fast_distortion == 0 {
      return;
    }
    let bs_index = ts as usize;
    let q_bin_idx = (qindex as usize) / RDO_QUANT_DIV;
    // Saturate instead of going through a signed subtraction: a distortion
    // smaller than half a bin used to wrap around when cast back to u64 and
    // was then clamped into the *last* bin instead of the first one.
    let bin_idx = (fast_distortion.saturating_sub(RATE_EST_BIN_SIZE / 2) / RATE_EST_BIN_SIZE)
      .min((RDO_NUM_BINS - 1) as u64) as usize;
    self.rate_counts[q_bin_idx][bs_index][bin_idx] += 1;
    self.rate_bins[q_bin_idx][bs_index][bin_idx] += rate;
  }

  /// Prints the collected statistics to stdout as Rust source defining
  /// `RDO_RATE_TABLE`. Each entry is the mean rate of its bin; bins with
  /// 100 samples or fewer are emitted as the placeholder `99999`.
  pub fn print_code(&self) {
    println!("pub static RDO_RATE_TABLE: [[[u64; RDO_NUM_BINS]; TxSize::TX_SIZES_ALL]; RDO_QUANT_BINS] = [");
    for q_bin in 0..RDO_QUANT_BINS {
      print!("[");
      for bs_index in 0..TxSize::TX_SIZES_ALL {
        print!("[");
        let totals = self.rate_bins[q_bin][bs_index].iter();
        let counts = self.rate_counts[q_bin][bs_index].iter();
        for (rate_total, rate_count) in totals.zip(counts) {
          if *rate_count > 100 {
            print!("{},", rate_total / rate_count);
          } else {
            print!("99999,");
          }
        }
        println!("],");
      }
      println!("],");
    }
    println!("];");
  }
}
/// Estimates a rate from `RDO_RATE_TABLE` for a transform block of size `ts`
/// coded at `qindex` with the given fast distortion.
///
/// The two table entries that bracket `fast_distortion` are linearly
/// interpolated, using a slope kept in fixed point with 8 fractional bits.
/// Distortions past the last bin are extrapolated from the final pair of
/// entries, and a negative result is clamped to zero.
pub fn estimate_rate(qindex: u8, ts: TxSize, fast_distortion: u64) -> u64 {
  let rates = &RDO_RATE_TABLE[(qindex as usize) / RDO_QUANT_DIV][ts as usize];

  // Lower bracket is capped so that an upper neighbour always exists.
  let lo = (fast_distortion / RATE_EST_BIN_SIZE).min((RDO_NUM_BINS - 2) as u64);
  let hi = (lo + 1).min((RDO_NUM_BINS - 1) as u64);

  let (x_lo, x_hi) = ((lo * RATE_EST_BIN_SIZE) as i64, (hi * RATE_EST_BIN_SIZE) as i64);
  let (y_lo, y_hi) = (rates[lo as usize] as i64, rates[hi as usize] as i64);

  let slope_q8 = ((y_hi - y_lo) << 8) / (x_hi - x_lo);
  let interpolated = y_lo + (((fast_distortion as i64 - x_lo) * slope_q8) >> 8);
  interpolated.max(0) as u64
}
#[allow(unused)] #[allow(unused)]
fn cdef_dist_wxh_8x8<T: Pixel>( fn cdef_dist_wxh_8x8<T: Pixel>(
src1: &PlaneSlice<'_, T>, src2: &PlaneSlice<'_, T>, bit_depth: usize src1: &PlaneSlice<'_, T>, src2: &PlaneSlice<'_, T>, bit_depth: usize
...@@ -370,7 +464,11 @@ pub fn rdo_mode_decision<T: Pixel>( ...@@ -370,7 +464,11 @@ pub fn rdo_mode_decision<T: Pixel>(
let mut fwdref = None; let mut fwdref = None;
let mut bwdref = None; let mut bwdref = None;
let rdo_type = if fi.use_tx_domain_distortion { let rdo_type = if fi.config.train_rdo {
RDOType::Train
} else if fi.use_tx_domain_rate {
RDOType::TxDistEstRate
} else if fi.use_tx_domain_distortion {
RDOType::TxDistRealRate RDOType::TxDistRealRate
} else { } else {
RDOType::PixelDistRealRate RDOType::PixelDistRealRate
...@@ -722,8 +820,8 @@ pub fn rdo_mode_decision<T: Pixel>( ...@@ -722,8 +820,8 @@ pub fn rdo_mode_decision<T: Pixel>(
let wr: &mut dyn Writer = &mut WriterCounter::new(); let wr: &mut dyn Writer = &mut WriterCounter::new();
let tell = wr.tell_frac(); let tell = wr.tell_frac();
encode_block_a(&fi.sequence, fs, cw, wr, bsize, bo, best.skip); encode_block_a(&fi.sequence, fs, cw, wr, bsize, bo, best.skip);
encode_block_b( let _ = encode_block_b(
fi, fi,
fs, fs,
cw, cw,
...@@ -996,7 +1094,7 @@ pub fn rdo_partition_decision<T: Pixel>( ...@@ -996,7 +1094,7 @@ pub fn rdo_partition_decision<T: Pixel>(
cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer, cw: &mut ContextWriter, w_pre_cdef: &mut dyn Writer, w_post_cdef: &mut dyn Writer,
bsize: BlockSize, bo: &BlockOffset, bsize: BlockSize, bo: &BlockOffset,
cached_block: &RDOOutput, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5], cached_block: &RDOOutput, pmvs: &[[Option<MotionVector>; REF_FRAMES]; 5],
partition_types: &[PartitionType], partition_types: &[PartitionType], rdo_type: RDOType
) -> RDOOutput { ) -> RDOOutput {
let mut best_partition = cached_block.part_type; let mut best_partition = cached_block.part_type;
let mut best_rd = cached_block.rd_cost; let mut best_rd = cached_block.rd_cost;
...@@ -1089,7 +1187,7 @@ pub fn rdo_partition_decision<T: Pixel>( ...@@ -1089,7 +1187,7 @@ pub fn rdo_partition_decision<T: Pixel>(
cw.write_partition(w, offset, PartitionType::PARTITION_NONE, subsize); cw.write_partition(w, offset, PartitionType::PARTITION_NONE, subsize);
} }
encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, subsize, encode_block_with_modes(fi, fs, cw, w_pre_cdef, w_post_cdef, subsize,
offset, &mode_decision); offset, &mode_decision, rdo_type);
child_modes.push(mode_decision); child_modes.push(mode_decision);
} }
} }
...@@ -1335,3 +1433,8 @@ pub fn rdo_loop_decision<T: Pixel>(sbo: &SuperBlockOffset, fi: &FrameInvariants< ...@@ -1335,3 +1433,8 @@ pub fn rdo_loop_decision<T: Pixel>(sbo: &SuperBlockOffset, fi: &FrameInvariants<
} }
} }
} }
// With zero distortion the interpolation term vanishes, so the estimate must
// equal the first table entry. The expected index assumes `TxSize::TX_4X4`
// casts to 0.
#[test]
fn estimate_rate_test() {
assert_eq!(estimate_rate(0, TxSize::TX_4X4, 0), RDO_RATE_TABLE[0][0][0]);
}
This diff is collapsed.
#!/bin/bash
# Collects RDO rate-training statistics: encodes every .y4m clip in
# ~/sets/objective-1-fast at eight quantizer values with --train-rdo,
# then plots the result with gnuplot.

# Abort as soon as any command fails.
set -e
# Start from an empty statistics file; with --train-rdo the encoder merges an
# existing rdo.dat into its own counters before rewriting it.
rm -f rdo.dat
cargo build --release
# GNU parallel runs one single-threaded encode per (clip, quantizer) pair:
# `:::: -` reads the clip list from stdin as {1}, `::: 16 ... 240` supplies {2}.
# NOTE(review): the encoder's read-then-rewrite of rdo.dat shows no locking in
# this change, so concurrent jobs may overwrite each other's updates -- confirm.
ls ~/sets/objective-1-fast/*.y4m | parallel target/release/rav1e --threads 1 --quantizer {2} -o /dev/null --train-rdo {1} :::: - ::: 16 48 80 112 144 176 208 240
# Render rdo.plt; -p keeps the plot window open after gnuplot exits.
gnuplot rdo.plt -p