diff --git a/dnn/torch/rdovae/README.md b/dnn/torch/rdovae/README.md new file mode 100644 index 0000000000000000000000000000000000000000..14359d82d0831ec22953eb40121b1784450f17c4 --- /dev/null +++ b/dnn/torch/rdovae/README.md @@ -0,0 +1,24 @@ +# Rate-Distortion-Optimized Variational Auto-Encoder + +## Setup +The python code requires python >= 3.6 and has been tested with python 3.6 and python 3.10. To install requirements run +``` +python -m pip install -r requirements.txt +``` + +## Training +To generate training data use dump date from the main LPCNet repo +``` +./dump_data -train 16khz_speech_input.s16 features.f32 data.s16 +``` + +To train the model, simply run +``` +python train_rdovae.py features.f32 output_folder +``` + +To train on CUDA device add `--cuda-visible-devices idx`. + + +## ToDo +- Upload checkpoints and add URLs diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..35b43704451fd9b8675dffe0bec29dbdbea58472 --- /dev/null +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -0,0 +1,256 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint') +parser.add_argument('output_dir', type=str, help='output folder') +parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C') + +args = parser.parse_args() + +import torch +import numpy as np + +from rdovae import RDOVAE +from wexchange.torch import dump_torch_weights +from wexchange.c_export import CWriter, print_vector + + +def dump_statistical_model(writer, qembedding): + w = qembedding.weight.detach() + levels, dim = w.shape + N = dim // 6 + + print("printing statistical model") + quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy() + dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy() + r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy() + p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy() + p0 = 1 - r ** (0.5 + 0.5 * p0) + + quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16) + dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16) + r_q15 = np.round(r * 2**15).astype(np.uint16) + p0_q15 = np.round(p0 * 2**15).astype(np.uint16) + + print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False) + print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False) + print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False) + print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False) + + writer.header.write( +f""" +extern const opus_uint16 dred_quant_scales_q8[{levels * N}]; +extern const opus_uint16 dred_dead_zone_q10[{levels * N}]; +extern const opus_uint16 dred_r_q15[{levels * N}]; +extern const opus_uint16 dred_p0_q15[{levels * N}]; + +""" + ) + + +def c_export(args, model): + + message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" + + enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message) + dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message) + stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message) + constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True) + + # some custom includes + for writer in [enc_writer, dec_writer, stats_writer]: + writer.header.write( +f""" +#include "opus_types.h" + +#include "dred_rdovae_constants.h" + +#include "nnet.h" +""" + ) + + # encoder + encoder_dense_layers = [ + ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'), + ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'), + ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'), + ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'), + ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'), + ('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'), + ('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH') + ] + + for name, export_name, activation in encoder_dense_layers: + layer = model.get_submodule(name) + dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True) + + + encoder_gru_layers = [ + ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'), + ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'), + ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH') + ] + + enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers]) + + + encoder_conv_layers = [ + ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR') + ] + + enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers]) + + + del enc_writer + + # decoder + decoder_dense_layers = [ + ('core_decoder.module.gru_1_init' , 'state1', 'TANH'), + ('core_decoder.module.gru_2_init' , 'state2', 'TANH'), + ('core_decoder.module.gru_3_init' , 'state3', 'TANH'), + ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'), + ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'), + ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'), + ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'), + ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'), + ('core_decoder.module.output' , 'dec_final', 'LINEAR') + ] + + for name, export_name, activation in decoder_dense_layers: + layer = model.get_submodule(name) + dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True) + + + decoder_gru_layers = [ + ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'), + ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'), + ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH') + ] + + dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers]) + + del dec_writer + + # statistical model + qembedding = model.statistical_model.quant_embedding + dump_statistical_model(stats_writer, qembedding) + + del stats_writer + + # constants + constants_writer.header.write( +f""" +#define DRED_NUM_FEATURES {model.feature_dim} + +#define DRED_LATENT_DIM {model.latent_dim} + +#define DRED_STATE_DIME {model.state_dim} + +#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels} + +#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)} + +#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs} + +#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs} + +#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs} + +#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units} + +""" + ) + + del constants_writer + + +def numpy_export(args, model): + + exchange_name_to_name = { + 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1', + 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2', + 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3', + 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4', + 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5', + 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1', + 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2', + 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1', + 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2', + 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3', + 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1', + 'statistical_model_embedding' : 'statistical_model.quant_embedding', + 'decoder_state1_dense' : 'core_decoder.module.gru_1_init', + 'decoder_state2_dense' : 'core_decoder.module.gru_2_init', + 'decoder_state3_dense' : 'core_decoder.module.gru_3_init', + 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1', + 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2', + 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3', + 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4', + 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5', + 'decoder_stack_layer9_dense' : 'core_decoder.module.output', + 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1', + 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2', + 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3' + } + + name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()} + + for name, exchange_name in name_to_exchange_name.items(): + print(f"printing layer {name}...") + dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name)) + + +if __name__ == "__main__": + + + os.makedirs(args.output_dir, exist_ok=True) + + + # load model from checkpoint + checkpoint = torch.load(args.checkpoint, map_location='cpu') + model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False) + + if len(missing_keys) > 0: + raise ValueError(f"error: missing keys in state dict") + + if len(unmatched_keys) > 0: + print(f"warning: the following keys were unmatched {unmatched_keys}") + + if args.format == 'C': + c_export(args, model) + elif args.format == 'numpy': + numpy_export(args, model) + else: + raise ValueError(f'error: unknown export format {args.format}') \ No newline at end of file diff --git a/dnn/torch/rdovae/fec_encoder.py b/dnn/torch/rdovae/fec_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..291c0628bbb8d821e87c36a99a6fa53e893093a8 --- /dev/null +++ b/dnn/torch/rdovae/fec_encoder.py @@ -0,0 +1,213 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe and Jean-Marc Valin */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import subprocess +import argparse + +os.environ['CUDA_VISIBLE_DEVICES'] = "" + +parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames') + +parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)') +parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint') +parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame') +parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame') +parser.add_argument('output', type=str, help='output file (will be extended with .fec)') + +parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)') +parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)') +parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)") +parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)') +parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk') + +args = parser.parse_args() + +import numpy as np +from scipy.io import wavfile +import torch + +from rdovae import RDOVAE +from packets import write_fec_packets + +torch.set_num_threads(4) + +checkpoint = torch.load(args.checkpoint, map_location="cpu") +model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) +model.load_state_dict(checkpoint['state_dict'], strict=False) +model.to("cpu") + +lpc_order = 16 + +## prepare input signal +# SILK frame size is 20ms and LPCNet subframes are 10ms +subframe_size = 160 +frame_size = 2 * subframe_size + +# 91 samples delay to align with SILK decoded frames +silk_delay = 91 + +# prepend zeros to have enough history to produce the first package +zero_history = (args.num_redundancy_frames - 1) * frame_size + +# dump data has a (feature) delay of 10ms +dump_data_delay = 160 + +total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay + +# load signal +if args.input.endswith('.raw') or args.input.endswith('.pcm'): + signal = np.fromfile(args.input, dtype='int16') + +elif args.input.endswith('.wav'): + fs, signal = wavfile.read(args.input) +else: + raise ValueError(f'unknown input signal format: {args.input}') + +# fill up last frame with zeros +padded_signal_length = len(signal) + total_delay +tail = padded_signal_length % frame_size +right_padding = (frame_size - tail) % frame_size + +signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16))) + +padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw' +signal.tofile(padded_signal_file) + +# write signal and call dump_data to create features + +feature_file = os.path.splitext(args.input)[0] + '_features.f32' +command = f"{args.dump_data} -test {padded_signal_file} {feature_file}" +r = subprocess.run(command, shell=True) +if r.returncode != 0: + raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}") + +# load features +nb_features = model.feature_dim + lpc_order +nb_used_features = model.feature_dim + +# load features +features = np.fromfile(feature_file, dtype='float32') +num_subframes = len(features) // nb_features +num_subframes = 2 * (num_subframes // 2) +num_frames = num_subframes // 2 + +features = np.reshape(features, (1, -1, nb_features)) +features = features[:, :, :nb_used_features] +features = features[:, :num_subframes, :] + +# quant_ids in reverse decoding order +quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long() + +print(f"using quantization levels {quant_ids}...") + +# convert input to torch tensors +features = torch.from_numpy(features) + + +# run encoder +print("running fec encoder...") +with torch.no_grad(): + + # encoding + z, states, state_size = model.encode(features) + + + # decoder on packet chunks + input_length = args.num_redundancy_frames // 2 + offset = args.num_redundancy_frames - 1 + + packets = [] + packet_sizes = [] + + for i in range(offset, num_frames): + print(f"processing frame {i - offset}...") + # quantize / unquantize latent vectors + zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :]) + zi, rates = model.quantize(zi, quant_ids) + zi = model.unquantize(zi, quant_ids) + + features = model.decode(zi, states[:, i : i + 1, :]) + packets.append(features.squeeze(0).numpy()) + packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8) + packet_sizes.append(packet_size) + + +# write packets +packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output +write_fec_packets(packet_file, packets, packet_sizes) + + +print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps") + +# assemble features according to loss file +if args.lossfile != None: + num_packets = len(packets) + loss = np.loadtxt(args.lossfile, dtype='int16') + fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32') + foffset = -2 + ptr = 0 + count = 2 + for i in range(num_packets): + if (loss[i] == 0) or (i == num_packets - 1): + + fec_out[ptr:ptr+count,:] = packets[i][foffset:, :] + + ptr += count + foffset = -2 + count = 2 + else: + count += 2 + foffset -= 2 + + fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32) + fec_out_full[:, : fec_out.shape[-1]] = fec_out + + fec_out_full.tofile(packet_file[:-4] + f'_fec.f32') + + +if args.debug_output: + import itertools + + batches = [4] + offsets = [0, 2 * args.num_redundancy_frames - 4] + + # sanity checks + # 1. concatenate features at offset 0 + for batch, offset in itertools.product(batches, offsets): + + stop = packets[0].shape[1] - offset + test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0) + + test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32) + test_features_full[:, :nb_used_features] = test_features[:, :] + + print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}") + test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32') + diff --git a/dnn/torch/rdovae/import_rdovae_weights.py b/dnn/torch/rdovae/import_rdovae_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..eba05018cb4c1d39f4082c66cc8133f02fb213ba --- /dev/null +++ b/dnn/torch/rdovae/import_rdovae_weights.py @@ -0,0 +1,143 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = "" + +import argparse + + + +parser = argparse.ArgumentParser() + +parser.add_argument('exchange_folder', type=str, help='exchange folder path') +parser.add_argument('output', type=str, help='path to output model checkpoint') + +model_group = parser.add_argument_group(title="model parameters") +model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20) +model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80) +model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) +model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256) +model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24) +model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40) + +args = parser.parse_args() + +import torch +from rdovae import RDOVAE +from wexchange.torch import load_torch_weights + +exchange_name_to_name = { + 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1', + 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2', + 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3', + 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4', + 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5', + 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1', + 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2', + 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1', + 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2', + 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3', + 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1', + 'statistical_model_embedding' : 'statistical_model.quant_embedding', + 'decoder_state1_dense' : 'core_decoder.module.gru_1_init', + 'decoder_state2_dense' : 'core_decoder.module.gru_2_init', + 'decoder_state3_dense' : 'core_decoder.module.gru_3_init', + 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1', + 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2', + 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3', + 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4', + 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5', + 'decoder_stack_layer9_dense' : 'core_decoder.module.output', + 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1', + 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2', + 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3' +} + +if __name__ == "__main__": + checkpoint = dict() + + # parameters + num_features = args.num_features + latent_dim = args.latent_dim + quant_levels = args.quant_levels + cond_size = args.cond_size + cond_size2 = args.cond_size2 + state_dim = args.state_dim + + + # model + checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) + checkpoint['model_kwargs'] = {'state_dim': state_dim} + model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + + dense_layer_names = [ + 'encoder_stack_layer1_dense', + 'encoder_stack_layer3_dense', + 'encoder_stack_layer5_dense', + 'encoder_stack_layer7_dense', + 'encoder_stack_layer8_dense', + 'encoder_state_layer1_dense', + 'encoder_state_layer2_dense', + 'decoder_state1_dense', + 'decoder_state2_dense', + 'decoder_state3_dense', + 'decoder_stack_layer1_dense', + 'decoder_stack_layer3_dense', + 'decoder_stack_layer5_dense', + 'decoder_stack_layer7_dense', + 'decoder_stack_layer8_dense', + 'decoder_stack_layer9_dense' + ] + + gru_layer_names = [ + 'encoder_stack_layer2_gru', + 'encoder_stack_layer4_gru', + 'encoder_stack_layer6_gru', + 'decoder_stack_layer2_gru', + 'decoder_stack_layer4_gru', + 'decoder_stack_layer6_gru' + ] + + conv1d_layer_names = [ + 'encoder_stack_layer9_conv' + ] + + embedding_layer_names = [ + 'statistical_model_embedding' + ] + + for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names: + print(f"loading weights for layer {exchange_name_to_name[name]}") + layer = model.get_submodule(exchange_name_to_name[name]) + load_torch_weights(os.path.join(args.exchange_folder, name), layer) + + checkpoint['state_dict'] = model.state_dict() + + torch.save(checkpoint, args.output) \ No newline at end of file diff --git a/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..cfeebae5bafd31e7e674e73120d45c76ab177d3a Binary files /dev/null and b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl differ diff --git a/dnn/torch/rdovae/packets/__init__.py b/dnn/torch/rdovae/packets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb71ab3d55691d2cd53137ce9be0708b250edda1 --- /dev/null +++ b/dnn/torch/rdovae/packets/__init__.py @@ -0,0 +1 @@ +from .fec_packets import write_fec_packets, read_fec_packets \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.c b/dnn/torch/rdovae/packets/fec_packets.c new file mode 100644 index 0000000000000000000000000000000000000000..376fb4f169851600ab3c0e1760f47c18e9ac5c8c --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.c @@ -0,0 +1,142 @@ +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include <stdio.h> +#include <inttypes.h> + +#include "fec_packets.h" + +int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index) +{ + + int16_t version; + int16_t header_size; + int16_t num_packets; + int16_t packet_size; + int16_t subframe_size; + int16_t subframes_per_packet; + int16_t num_features; + long offset; + + FILE *fid = fopen(filename, "rb"); + + /* read header */ + if (fread(&version, sizeof(version), 1, fid) != 1) goto error; + if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error; + if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error; + if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error; + if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error; + if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error; + if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error; + + /* check if indices are valid */ + if (packet_index >= num_packets || subframe_index >= subframes_per_packet) + { + fprintf(stderr, "get_fec_frame: index out of bounds\n"); + goto error; + } + + /* calculate offset in file (+ 2 is for rate) */ + offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size; + fseek(fid, offset, SEEK_SET); + + /* read features */ + if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error; + + fclose(fid); + return 0; + +error: + fclose(fid); + return 1; +} + +int get_fec_rate(const char * const filename, int packet_index) +{ + int16_t version; + int16_t header_size; + int16_t num_packets; + int16_t packet_size; + int16_t subframe_size; + int16_t subframes_per_packet; + int16_t num_features; + long offset; + int16_t rate; + + FILE *fid = fopen(filename, "rb"); + + /* read header */ + if (fread(&version, sizeof(version), 1, fid) != 1) goto error; + if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error; + if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error; + if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error; + if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error; + if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error; + if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error; + + /* check if indices are valid */ + if (packet_index >= num_packets) + { + fprintf(stderr, "get_fec_rate: index out of bounds\n"); + goto error; + } + + /* calculate offset in file (+ 2 is for rate) */ + offset = header_size + packet_index * packet_size; + fseek(fid, offset, SEEK_SET); + + /* read rate */ + if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error; + + fclose(fid); + return (int) rate; + +error: + fclose(fid); + return -1; +} + +#if 0 +int main() +{ + float features[20]; + int i; + + if (get_fec_frame("../test.fec", &features[0], 0, 127)) + { + return 1; + } + + for (i = 0; i < 20; i ++) + { + printf("%d %f\n", i, features[i]); + } + + printf("rate: %d\n", get_fec_rate("../test.fec", 0)); + +} +#endif \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.h b/dnn/torch/rdovae/packets/fec_packets.h new file mode 100644 index 0000000000000000000000000000000000000000..35d355428a3ddc15ece26659298473a4975c026f --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef _FEC_PACKETS_H +#define _FEC_PACKETS_H + +int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index); +int get_fec_rate(const char * const filename, int packet_index); + +#endif \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.py b/dnn/torch/rdovae/packets/fec_packets.py new file mode 100644 index 0000000000000000000000000000000000000000..14bed1f8cea0e7cac413b2627f43b3e42e2f7ca5 --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.py @@ -0,0 +1,108 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import numpy as np + + + +def write_fec_packets(filename, packets, rates=None): + """ writes packets in binary format """ + + assert np.dtype(np.float32).itemsize == 4 + assert np.dtype(np.int16).itemsize == 2 + + # derive some sizes + num_packets = len(packets) + subframes_per_packet = packets[0].shape[-2] + num_features = packets[0].shape[-1] + + # size of float is 4 + subframe_size = num_features * 4 + packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate + + version = 1 + # header size (version, header_size, num_packets, packet_size, subframe_size, subrames_per_packet, num_features) + header_size = 14 + + with open(filename, 'wb') as f: + + # header + f.write(np.int16(version).tobytes()) + f.write(np.int16(header_size).tobytes()) + f.write(np.int16(num_packets).tobytes()) + f.write(np.int16(packet_size).tobytes()) + f.write(np.int16(subframe_size).tobytes()) + f.write(np.int16(subframes_per_packet).tobytes()) + f.write(np.int16(num_features).tobytes()) + + # packets + for i, packet in enumerate(packets): + if type(rates) == type(None): + rate = 0 + else: + rate = rates[i] + + f.write(np.int16(rate).tobytes()) + + features = np.flip(packet, axis=-2) + f.write(features.astype(np.float32).tobytes()) + + +def read_fec_packets(filename): + """ reads packets from binary format """ + + assert np.dtype(np.float32).itemsize == 4 + assert np.dtype(np.int16).itemsize == 2 + + with open(filename, 'rb') as f: + + # header + version = np.frombuffer(f.read(2), dtype=np.int16).item() + header_size = np.frombuffer(f.read(2), dtype=np.int16).item() + num_packets = np.frombuffer(f.read(2), dtype=np.int16).item() + packet_size = np.frombuffer(f.read(2), dtype=np.int16).item() + subframe_size = np.frombuffer(f.read(2), dtype=np.int16).item() + subframes_per_packet = np.frombuffer(f.read(2), dtype=np.int16).item() + num_features = np.frombuffer(f.read(2), dtype=np.int16).item() + + dummy_features = np.zeros((subframes_per_packet, num_features), dtype=np.float32) + + # packets + rates = [] + packets = [] + for i in range(num_packets): + + rate = np.frombuffer(f.read(2), dtype=np.int16).item + rates.append(rate) + + features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape) + packet = np.flip(features, axis=-2) + packets.append(packet) + + return packets \ No newline at end of file diff --git a/dnn/torch/rdovae/rdovae/__init__.py b/dnn/torch/rdovae/rdovae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b945addec72168fa6c7b7535ee2f1c20089a6f15 --- /dev/null +++ b/dnn/torch/rdovae/rdovae/__init__.py @@ -0,0 +1,2 @@ +from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate +from .dataset import RDOVAEDataset diff --git a/dnn/torch/rdovae/rdovae/dataset.py b/dnn/torch/rdovae/rdovae/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99630d8b924736e59849d15f88cf8c4c45224685 --- /dev/null +++ b/dnn/torch/rdovae/rdovae/dataset.py @@ -0,0 +1,68 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import torch +import numpy as np + +class RDOVAEDataset(torch.utils.data.Dataset): + def __init__(self, + feature_file, + sequence_length, + num_used_features=20, + num_features=36, + lambda_min=0.0002, + lambda_max=0.0135, + quant_levels=16, + enc_stride=2): + + self.sequence_length = sequence_length + self.lambda_min = lambda_min + self.lambda_max = lambda_max + self.enc_stride = enc_stride + self.quant_levels = quant_levels + self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min) + + if sequence_length % enc_stride: + raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}") + + self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features)) + self.features = self.features[:, :num_used_features] + self.num_sequences = self.features.shape[0] // sequence_length + + def __len__(self): + return self.num_sequences + + def __getitem__(self, index): + features = self.features[index * self.sequence_length: (index + 1) * self.sequence_length, :] + q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64) + q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0) + rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32) + + return features, rate_lambda, q_ids + diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py new file mode 100644 index 0000000000000000000000000000000000000000..b45d2b8c3b5a329ce863160f814b44ac06023e38 --- /dev/null +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -0,0 +1,614 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +""" Pytorch implementations of rate distortion optimized variational autoencoder """ + +import math as m + +import torch +from torch import nn +import torch.nn.functional as F + +# Quantization and rate related utily functions + +def soft_pvq(x, k): + """ soft pyramid vector quantizer """ + + # L2 normalization + x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True)) + + + with torch.no_grad(): + # quantization loop, no need to track gradients here + x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True) + + # set initial scaling factor to k + scale_factor = k + x_scaled = scale_factor * x_norm1 + x_quant = torch.round(x_scaled) + + # we aim for ||x_quant||_L1 = k + for _ in range(10): + # remove signs and calculate L1 norm + abs_x_quant = torch.abs(x_quant) + abs_x_scaled = torch.abs(x_scaled) + l1_x_quant = torch.sum(abs_x_quant, axis=-1) + + # increase, where target is too small and decrease, where target is too large + plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values + minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values + factor = torch.where(l1_x_quant > k, minus, plus) + factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor) + scale_factor = scale_factor * factor.unsqueeze(-1) + + # update x + x_scaled = scale_factor * x_norm1 + x_quant = torch.round(x_quant) + + # L2 normalization of quantized x + x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True)) + quantization_error = x_quant_norm2 - x_norm2 + + return x_norm2 + quantization_error.detach() + +def cache_parameters(func): + cache = dict() + def cached_func(*args): + if args in cache: + return cache[args] + else: + cache[args] = func(*args) + + return cache[args] + return cached_func + +@cache_parameters +def pvq_codebook_size(n, k): + + if k == 0: + return 1 + + if n == 0: + return 0 + + return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1) + + +def soft_rate_estimate(z, r, reduce=True): + """ rate approximation with dependent theta Eq. (7)""" + + rate = torch.sum( + - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6), + dim=-1 + ) + + if reduce: + rate = torch.mean(rate) + + return rate + + +def hard_rate_estimate(z, r, theta, reduce=True): + """ hard rate approximation """ + + z_q = torch.round(z) + p0 = 1 - r ** (0.5 + 0.5 * theta) + alpha = torch.relu(1 - torch.abs(z_q)) ** 2 + rate = - torch.sum( + (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6) + + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)), + dim=-1 + ) + + if reduce: + rate = torch.mean(rate) + + return rate + + + +def soft_dead_zone(x, dead_zone): + """ approximates application of a dead zone to x """ + d = dead_zone * 0.05 + return x - d * torch.tanh(x / (0.1 + d)) + + +def hard_quantize(x): + """ round with copy gradient trick """ + return x + (torch.round(x) - x).detach() + + +def noise_quantize(x): + """ simulates quantization with addition of random uniform noise """ + return x + (torch.rand_like(x) - 0.5) + + +# loss functions + + +def distortion_loss(y_true, y_pred, rate_lambda=None): + """ custom distortion loss for LPCNet features """ + + if y_true.size(-1) != 20: + raise ValueError('distortion loss is designed to work with 20 features') + + ceps_error = y_pred[..., :18] - y_true[..., :18] + pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19]) + corr_error = y_pred[..., 19:] - y_true[..., 19:] + pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2 + + loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1) + + if type(rate_lambda) != type(None): + loss = loss / torch.sqrt(rate_lambda) + + loss = torch.mean(loss) + + return loss + + +# sampling functions + +import random + + +def random_split(start, stop, num_splits=3, min_len=3): + get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)]) + candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] + + while get_min_len(candidate) < min_len: + candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] + + return candidate + + + +# weight initialization and clipping +def init_weights(module): + + if isinstance(module, nn.GRU): + for p in module.named_parameters(): + if p[0].startswith('weight_hh_'): + nn.init.orthogonal_(p[1]) + + +def weight_clip_factory(max_value): + """ weight clipping function concerning sum of abs values of adjecent weights """ + def clip_weight_(w): + stop = w.size(1) + # omit last column if stop is odd + if stop % 2: + stop -= 1 + max_values = max_value * torch.ones_like(w[:, :stop]) + factor = max_value / torch.maximum(max_values, + torch.repeat_interleave( + torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]), + 2, + 1)) + with torch.no_grad(): + w[:, :stop] *= factor + + def clip_weights(module): + if isinstance(module, nn.GRU) or isinstance(module, nn.Linear): + for name, w in module.named_parameters(): + if name.startswith('weight'): + clip_weight_(w) + + return clip_weights + +# RDOVAE module and submodules + + +class CoreEncoder(nn.Module): + STATE_HIDDEN = 128 + FRAMES_PER_STEP = 2 + CONV_KERNEL_SIZE = 4 + + def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24): + """ core encoder for RDOVAE + + Computes latents, initial states, and rate estimates from features and lambda parameter + + """ + + super(CoreEncoder, self).__init__() + + # hyper parameters + self.feature_dim = feature_dim + self.output_dim = output_dim + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.state_size = state_size + + # derived parameters + self.input_dim = self.FRAMES_PER_STEP * self.feature_dim + self.conv_input_channels = 5 * cond_size + 3 * cond_size2 + + # layers + self.dense_1 = nn.Linear(self.input_dim, self.cond_size2) + self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_2 = nn.Linear(self.cond_size, self.cond_size2) + self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_3 = nn.Linear(self.cond_size, self.cond_size2) + self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_4 = nn.Linear(self.cond_size, self.cond_size) + self.dense_5 = nn.Linear(self.cond_size, self.cond_size) + self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid') + + self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN) + + self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size) + + # initialize weights + self.apply(init_weights) + + + def forward(self, features): + + # reshape features + x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2))) + + batch = x.size(0) + device = x.device + + # run encoding layer stack + x1 = torch.tanh(self.dense_1(x)) + x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device)) + x3 = torch.tanh(self.dense_2(x2)) + x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device)) + x5 = torch.tanh(self.dense_3(x4)) + x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device)) + x7 = torch.tanh(self.dense_4(x6)) + x8 = torch.tanh(self.dense_5(x7)) + + # concatenation of all hidden layer outputs + x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + + # init state for decoder + states = torch.tanh(self.state_dense_1(x9)) + states = torch.tanh(self.state_dense_2(states)) + + # latent representation via convolution + x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0]) + z = self.conv1(x9).permute(0, 2, 1) + + return z, states + + + + +class CoreDecoder(nn.Module): + + FRAMES_PER_STEP = 4 + + def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24): + """ core decoder for RDOVAE + + Computes features from latents, initial state, and quantization index + + """ + + super(CoreDecoder, self).__init__() + + # hyper parameters + self.input_dim = input_dim + self.output_dim = output_dim + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.state_size = state_size + + self.input_size = self.input_dim + + self.concat_size = 4 * self.cond_size + 4 * self.cond_size2 + + # layers + self.dense_1 = nn.Linear(self.input_size, cond_size2) + self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_2 = nn.Linear(cond_size, cond_size2) + self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_3 = nn.Linear(cond_size, cond_size2) + self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_4 = nn.Linear(cond_size, cond_size2) + self.dense_5 = nn.Linear(cond_size2, cond_size2) + + self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim) + + + self.gru_1_init = nn.Linear(self.state_size, self.cond_size) + self.gru_2_init = nn.Linear(self.state_size, self.cond_size) + self.gru_3_init = nn.Linear(self.state_size, self.cond_size) + + # initialize weights + self.apply(init_weights) + + def forward(self, z, initial_state): + + gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2)) + gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2)) + gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2)) + + # run decoding layer stack + x1 = torch.tanh(self.dense_1(z)) + x2, _ = self.gru_1(x1, gru_1_state) + x3 = torch.tanh(self.dense_2(x2)) + x4, _ = self.gru_2(x3, gru_2_state) + x5 = torch.tanh(self.dense_3(x4)) + x6, _ = self.gru_3(x5, gru_3_state) + x7 = torch.tanh(self.dense_4(x6)) + x8 = torch.tanh(self.dense_5(x7)) + x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + + # output layer and reshaping + x10 = self.output(x9) + features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP)) + + return features + + +class StatisticalModel(nn.Module): + def __init__(self, quant_levels, latent_dim): + """ Statistical model for latent space + + Computes scaling, deadzone, r, and theta + + """ + + super(StatisticalModel, self).__init__() + + # copy parameters + self.latent_dim = latent_dim + self.quant_levels = quant_levels + self.embedding_dim = 6 * latent_dim + + # quantization embedding + self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim) + + # initialize embedding to 0 + with torch.no_grad(): + self.quant_embedding.weight[:] = 0 + + + def forward(self, quant_ids): + """ takes quant_ids and returns statistical model parameters""" + + x = self.quant_embedding(quant_ids) + + # CAVE: theta_soft is not used anymore. Kick it out? + quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim]) + dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim]) + theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim]) + r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim]) + theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim]) + r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim]) + + + return { + 'quant_embedding' : x, + 'quant_scale' : quant_scale, + 'dead_zone' : dead_zone, + 'r_hard' : r_hard, + 'theta_hard' : theta_hard, + 'r_soft' : r_soft, + 'theta_soft' : theta_soft + } + + +class RDOVAE(nn.Module): + def __init__(self, + feature_dim, + latent_dim, + quant_levels, + cond_size, + cond_size2, + state_dim=24, + split_mode='split', + clip_weights=True, + pvq_num_pulses=82, + state_dropout_rate=0): + + super(RDOVAE, self).__init__() + + self.feature_dim = feature_dim + self.latent_dim = latent_dim + self.quant_levels = quant_levels + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.split_mode = split_mode + self.state_dim = state_dim + self.pvq_num_pulses = pvq_num_pulses + self.state_dropout_rate = state_dropout_rate + + # submodules encoder and decoder share the statistical model + self.statistical_model = StatisticalModel(quant_levels, latent_dim) + self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim)) + self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim)) + + self.enc_stride = CoreEncoder.FRAMES_PER_STEP + self.dec_stride = CoreDecoder.FRAMES_PER_STEP + + if clip_weights: + self.weight_clip_fn = weight_clip_factory(0.496) + else: + self.weight_clip_fn = None + + if self.dec_stride % self.enc_stride != 0: + raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride") + + def clip_weights(self): + if not type(self.weight_clip_fn) == type(None): + self.apply(self.weight_clip_fn) + + def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4): + + enc_stride = self.enc_stride + dec_stride = self.dec_stride + + stride = dec_stride // enc_stride + + chunks = [] + + for offset in range(stride): + # start is the smalles number = offset mod stride that decodes to a valid range + start = offset + while enc_stride * (start + 1) - dec_stride < 0: + start += stride + + # check if start is a valid index + if start >= z_frames: + raise ValueError("get_decoder_chunks_generic: range too small") + + # stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride + stop = z_frames - (z_frames % stride) + offset + while stop < z_frames: + stop += stride + + # calculate split points + length = (stop - start) + if mode == 'split': + split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop] + elif mode == 'random_split': + split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)] + else: + raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}") + + + for i in range(chunks_per_offset): + # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop) + # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j) + # provided that i - j = 1 mod stride + chunks.append({ + 'z_start' : split_points[i], + 'z_stop' : split_points[i + 1] - stride + 1, + 'z_stride' : stride, + 'features_start' : enc_stride * (split_points[i] + 1) - dec_stride, + 'features_stop' : enc_stride * (split_points[i + 1] - stride + 1) + }) + + return chunks + + + def forward(self, features, q_id): + + # calculate statistical model from quantization ID + statistical_model = self.statistical_model(q_id) + + # run encoder + z, states = self.core_encoder(features) + + # scaling, dead-zone and quantization + z = z * statistical_model['quant_scale'] + z = soft_dead_zone(z, statistical_model['dead_zone']) + + # quantization + z_q = hard_quantize(z) / statistical_model['quant_scale'] + z_n = noise_quantize(z) / statistical_model['quant_scale'] + states_q = soft_pvq(states, self.pvq_num_pulses) + + if self.state_dropout_rate > 0: + drop = torch.rand(states_q.size(0)) < self.state_dropout_rate + mask = torch.ones_like(states_q) + mask[drop] = 0 + states_q = states_q * mask + + # decoder + chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode) + + outputs_hq = [] + outputs_sq = [] + for chunk in chunks: + # decoder with hard quantized input + z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1]) + dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :] + features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state) + outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop'])) + + + # decoder with soft quantized input + z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1]) + features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state) + outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop'])) + + return { + 'outputs_hard_quant' : outputs_hq, + 'outputs_soft_quant' : outputs_sq, + 'z' : z, + 'statistical_model' : statistical_model + } + + def encode(self, features): + """ encoder with quantization and rate estimation """ + + z, states = self.core_encoder(features) + + # quantization of initial states + states = soft_pvq(states, self.pvq_num_pulses) + state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses)) + + return z, states, state_size + + def decode(self, z, initial_state): + """ decoder (flips sequences by itself) """ + + z_reverse = torch.flip(z, [1]) + features_reverse = self.core_decoder(z_reverse, initial_state) + features = torch.flip(features_reverse, [1]) + + return features + + def quantize(self, z, q_ids): + """ quantization of latent vectors """ + + stats = self.statistical_model(q_ids) + + zq = z * stats['quant_scale'] + zq = soft_dead_zone(zq, stats['dead_zone']) + zq = torch.round(zq) + + sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False) + + return zq, sizes + + def unquantize(self, zq, q_ids): + """ re-scaling of latent vector """ + + stats = self.statistical_model(q_ids) + + z = zq / stats['quant_scale'] + + return z + + def freeze_model(self): + + # freeze all parameters + for p in self.parameters(): + p.requires_grad = False + + for p in self.statistical_model.parameters(): + p.requires_grad = True + diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..668c8462855e77990cace86649bd28637a8eb652 --- /dev/null +++ b/dnn/torch/rdovae/requirements.txt @@ -0,0 +1,5 @@ +numpy +scipy +torch +tqdm +libs/wexchange-1.0-py3-none-any.whl \ No newline at end of file diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py new file mode 100644 index 0000000000000000000000000000000000000000..68ccf2eb0e3e01f20c9d51fdb44398db7ffaf5a9 --- /dev/null +++ b/dnn/torch/rdovae/train_rdovae.py @@ -0,0 +1,270 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +import torch +import tqdm + +from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate + + +parser = argparse.ArgumentParser() + +parser.add_argument('features', type=str, help='path to feature file in .f32 format') +parser.add_argument('output', type=str, help='path to output folder') + +parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="") + + +model_group = parser.add_argument_group(title="model parameters") +model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80) +model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) +model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256) +model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24) +model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16) +model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4) +model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104) +model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82) +model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0) + +training_group = parser.add_argument_group(title="training parameters") +training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32) +training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4) +training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100) +training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256) +training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5) +training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split') +training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames') +training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None) +training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only') + +args = parser.parse_args() + +# set visible devices +os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices + +# checkpoints +checkpoint_dir = os.path.join(args.output, 'checkpoints') +checkpoint = dict() +os.makedirs(checkpoint_dir, exist_ok=True) + +# training parameters +batch_size = args.batch_size +lr = args.lr +epochs = args.epochs +sequence_length = args.sequence_length +lr_decay_factor = args.lr_decay_factor +split_mode = args.split_mode +# not exposed +adam_betas = [0.9, 0.99] +adam_eps = 1e-8 + +checkpoint['batch_size'] = batch_size +checkpoint['lr'] = lr +checkpoint['lr_decay_factor'] = lr_decay_factor +checkpoint['split_mode'] = split_mode +checkpoint['epochs'] = epochs +checkpoint['sequence_length'] = sequence_length +checkpoint['adam_betas'] = adam_betas + +# logging +log_interval = 10 + +# device +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# model parameters +cond_size = args.cond_size +cond_size2 = args.cond_size2 +latent_dim = args.latent_dim +quant_levels = args.quant_levels +lambda_min = args.lambda_min +lambda_max = args.lambda_max +state_dim = args.state_dim +# not expsed +num_features = 20 + + +# training data +feature_file = args.features + +# model +checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) +checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate} +model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + +if type(args.initial_checkpoint) != type(None): + checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') + model.load_state_dict(checkpoint['state_dict'], strict=False) + +checkpoint['state_dict'] = model.state_dict() + +if args.train_decoder_only: + if args.initial_checkpoint is None: + print("warning: training decoder only without providing initial checkpoint") + + for p in model.core_encoder.module.parameters(): + p.requires_grad = False + + for p in model.statistical_model.parameters(): + p.requires_grad = False + +# dataloader +checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36) +checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels} +dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs']) +dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4) + + + +# optimizer +params = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps) + + +# learning rate scheduler +scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +if __name__ == '__main__': + + # push model to device + model.to(device) + + # training loop + + for epoch in range(1, epochs + 1): + + print(f"training epoch {epoch}...") + + # running stats + running_rate_loss = 0 + running_soft_dist_loss = 0 + running_hard_dist_loss = 0 + running_hard_rate_loss = 0 + running_soft_rate_loss = 0 + running_total_loss = 0 + running_rate_metric = 0 + previous_total_loss = 0 + running_first_frame_loss = 0 + + with tqdm.tqdm(dataloader, unit='batch') as tepoch: + for i, (features, rate_lambda, q_ids) in enumerate(tepoch): + + # zero out gradients + optimizer.zero_grad() + + # push inputs to device + features = features.to(device) + q_ids = q_ids.to(device) + rate_lambda = rate_lambda.to(device) + + + rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1) + + # run model + model_output = model(features, q_ids) + + # collect outputs + z = model_output['z'] + outputs_hard_quant = model_output['outputs_hard_quant'] + outputs_soft_quant = model_output['outputs_soft_quant'] + statistical_model = model_output['statistical_model'] + + # rate loss + hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False) + soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False) + soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate) + hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate) + rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss) + hard_rate_metric = torch.mean(hard_rate) + + ## distortion losses + + # hard quantized decoder input + distortion_loss_hard_quant = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_hard_quant: + distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant) + + first_frame_loss = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_hard_quant: + first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant) + + # soft quantized decoder input + distortion_loss_soft_quant = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_soft_quant: + distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant) + + # total loss + total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2 + + if args.enable_first_frame_loss: + total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant) + + + total_loss.backward() + + optimizer.step() + + model.clip_weights() + + scheduler.step() + + # collect running stats + running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu()) + running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu()) + running_rate_loss += float(rate_loss.detach().cpu()) + running_rate_metric += float(hard_rate_metric.detach().cpu()) + running_total_loss += float(total_loss.detach().cpu()) + running_first_frame_loss += float(first_frame_loss.detach().cpu()) + running_soft_rate_loss += float(soft_rate_loss.detach().cpu()) + running_hard_rate_loss += float(hard_rate_loss.detach().cpu()) + + if (i + 1) % log_interval == 0: + current_loss = (running_total_loss - previous_total_loss) / log_interval + tepoch.set_postfix( + current_loss=current_loss, + total_loss=running_total_loss / (i + 1), + dist_hq=running_hard_dist_loss / (i + 1), + dist_sq=running_soft_dist_loss / (i + 1), + rate_loss=running_rate_loss / (i + 1), + rate=running_rate_metric / (i + 1), + ffloss=running_first_frame_loss / (i + 1), + rateloss_hard=running_hard_rate_loss / (i + 1), + rateloss_soft=running_soft_rate_loss / (i + 1) + ) + previous_total_loss = running_total_loss + + # save checkpoint + checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth') + checkpoint['state_dict'] = model.state_dict() + checkpoint['loss'] = running_total_loss / len(dataloader) + checkpoint['epoch'] = epoch + torch.save(checkpoint, checkpoint_path)