From fdb04d0eef2395d46c5e5c6ea428397ea8401849 Mon Sep 17 00:00:00 2001
From: jbuethe <jbuethe@amazon.de>
Date: Wed, 23 Nov 2022 11:02:29 +0000
Subject: [PATCH] added pytorch implementation of RDOVAE

---
 dnn/torch/rdovae/README.md                |  24 +
 dnn/torch/rdovae/export_rdovae_weights.py | 256 ++++++++
 dnn/torch/rdovae/fec_encoder.py           | 213 ++++++
 dnn/torch/rdovae/import_rdovae_weights.py | 143 ++++
 .../libs/wexchange-1.0-py3-none-any.whl   | Bin 0 -> 7153 bytes
 dnn/torch/rdovae/packets/__init__.py      |   1 +
 dnn/torch/rdovae/packets/fec_packets.c    | 142 ++++
 dnn/torch/rdovae/packets/fec_packets.h    |  34 +
 dnn/torch/rdovae/packets/fec_packets.py   | 108 +++
 dnn/torch/rdovae/rdovae/__init__.py       |   2 +
 dnn/torch/rdovae/rdovae/dataset.py        |  68 ++
 dnn/torch/rdovae/rdovae/rdovae.py         | 614 ++++++++++++++++++
 dnn/torch/rdovae/requirements.txt         |   5 +
 dnn/torch/rdovae/train_rdovae.py          | 270 ++++++++
 14 files changed, 1880 insertions(+)
 create mode 100644 dnn/torch/rdovae/README.md
 create mode 100644 dnn/torch/rdovae/export_rdovae_weights.py
 create mode 100644 dnn/torch/rdovae/fec_encoder.py
 create mode 100644 dnn/torch/rdovae/import_rdovae_weights.py
 create mode 100644 dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl
 create mode 100644 dnn/torch/rdovae/packets/__init__.py
 create mode 100644 dnn/torch/rdovae/packets/fec_packets.c
 create mode 100644 dnn/torch/rdovae/packets/fec_packets.h
 create mode 100644 dnn/torch/rdovae/packets/fec_packets.py
 create mode 100644 dnn/torch/rdovae/rdovae/__init__.py
 create mode 100644 dnn/torch/rdovae/rdovae/dataset.py
 create mode 100644 dnn/torch/rdovae/rdovae/rdovae.py
 create mode 100644 dnn/torch/rdovae/requirements.txt
 create mode 100644 dnn/torch/rdovae/train_rdovae.py

diff --git a/dnn/torch/rdovae/README.md b/dnn/torch/rdovae/README.md
new file mode 100644
index 000000000..14359d82d
--- /dev/null
+++ b/dnn/torch/rdovae/README.md
@@ -0,0 +1,24 @@
+# Rate-Distortion-Optimized Variational Auto-Encoder
+
+## Setup
+The python code requires python >= 3.6 and has been tested with python 3.6 and python 3.10. To install the requirements, run
+```
+python -m pip install -r requirements.txt
+```
+
+## Training
+To generate training data, use `dump_data` from the main LPCNet repo:
+```
+./dump_data -train 16khz_speech_input.s16 features.f32 data.s16
+```
+
+To train the model, simply run
+```
+python train_rdovae.py features.f32 output_folder
+```
+
+To train on a CUDA device, add `--cuda-visible-devices idx`.
+
+
+## ToDo
+- Upload checkpoints and add URLs
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
new file mode 100644
index 000000000..35b437044
--- /dev/null
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -0,0 +1,256 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import os +import argparse + +parser = argparse.ArgumentParser() + +parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint') +parser.add_argument('output_dir', type=str, help='output folder') +parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C') + +args = parser.parse_args() + +import torch +import numpy as np + +from rdovae import RDOVAE +from wexchange.torch import dump_torch_weights +from wexchange.c_export import CWriter, print_vector + + +def dump_statistical_model(writer, qembedding): + w = qembedding.weight.detach() + levels, dim = w.shape + N = dim // 6 + + print("printing statistical model") + quant_scales = torch.nn.functional.softplus(w[:, : N]).numpy() + dead_zone = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy() + r = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy() + p0 = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy() + p0 = 1 - r ** (0.5 + 0.5 * p0) + + quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16) + dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.uint16) + r_q15 = np.round(r * 2**15).astype(np.uint16) + p0_q15 = np.round(p0 * 2**15).astype(np.uint16) + + print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False) + print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False) + print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False) + print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False) + + writer.header.write( +f""" +extern const opus_uint16 dred_quant_scales_q8[{levels * N}]; +extern const opus_uint16 dred_dead_zone_q10[{levels * N}]; +extern const opus_uint16 dred_r_q15[{levels * N}]; +extern const opus_uint16 dred_p0_q15[{levels * N}]; + +""" + ) + + +def c_export(args, model): + + message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" + + enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message) + dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message) + stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message) + constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True) + + # some custom includes + for writer in [enc_writer, dec_writer, stats_writer]: + writer.header.write( +f""" +#include "opus_types.h" + +#include "dred_rdovae_constants.h" + +#include "nnet.h" +""" + ) + + # encoder + encoder_dense_layers = [ + ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'), + ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'), + ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'), + ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'), + ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'), + ('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'), + ('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH') 
+ ] + + for name, export_name, activation in encoder_dense_layers: + layer = model.get_submodule(name) + dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True) + + + encoder_gru_layers = [ + ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'), + ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'), + ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH') + ] + + enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers]) + + + encoder_conv_layers = [ + ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR') + ] + + enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers]) + + + del enc_writer + + # decoder + decoder_dense_layers = [ + ('core_decoder.module.gru_1_init' , 'state1', 'TANH'), + ('core_decoder.module.gru_2_init' , 'state2', 'TANH'), + ('core_decoder.module.gru_3_init' , 'state3', 'TANH'), + ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'), + ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'), + ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'), + ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'), + ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'), + ('core_decoder.module.output' , 'dec_final', 'LINEAR') + ] + + for name, export_name, activation in decoder_dense_layers: + layer = model.get_submodule(name) + dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True) + + + decoder_gru_layers = [ + ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'), + ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'), + ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH') + ] + + dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers]) + + del dec_writer + + # statistical model + qembedding = model.statistical_model.quant_embedding + dump_statistical_model(stats_writer, qembedding) + + del stats_writer + + # constants + constants_writer.header.write( +f""" +#define DRED_NUM_FEATURES {model.feature_dim} + +#define DRED_LATENT_DIM {model.latent_dim} + +#define DRED_STATE_DIME {model.state_dim} + +#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels} + +#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)} + +#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs} + +#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs} + +#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs} + +#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units} + +""" + ) + + del constants_writer + + +def numpy_export(args, model): + + exchange_name_to_name = { + 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1', + 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2', + 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3', + 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4', + 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5', + 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1', + 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2', + 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1', + 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2', + 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3', + 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1', + 
'statistical_model_embedding' : 'statistical_model.quant_embedding', + 'decoder_state1_dense' : 'core_decoder.module.gru_1_init', + 'decoder_state2_dense' : 'core_decoder.module.gru_2_init', + 'decoder_state3_dense' : 'core_decoder.module.gru_3_init', + 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1', + 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2', + 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3', + 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4', + 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5', + 'decoder_stack_layer9_dense' : 'core_decoder.module.output', + 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1', + 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2', + 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3' + } + + name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()} + + for name, exchange_name in name_to_exchange_name.items(): + print(f"printing layer {name}...") + dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name)) + + +if __name__ == "__main__": + + + os.makedirs(args.output_dir, exist_ok=True) + + + # load model from checkpoint + checkpoint = torch.load(args.checkpoint, map_location='cpu') + model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False) + + if len(missing_keys) > 0: + raise ValueError(f"error: missing keys in state dict") + + if len(unmatched_keys) > 0: + print(f"warning: the following keys were unmatched {unmatched_keys}") + + if args.format == 'C': + c_export(args, model) + elif args.format == 'numpy': + numpy_export(args, model) + else: + raise ValueError(f'error: unknown export format {args.format}') \ No newline at end of file diff --git a/dnn/torch/rdovae/fec_encoder.py b/dnn/torch/rdovae/fec_encoder.py new file mode 100644 index 000000000..291c0628b --- /dev/null +++ b/dnn/torch/rdovae/fec_encoder.py @@ -0,0 +1,213 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe and Jean-Marc Valin */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import os +import subprocess +import argparse + +os.environ['CUDA_VISIBLE_DEVICES'] = "" + +parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames') + +parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)') +parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint') +parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame') +parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame') +parser.add_argument('output', type=str, help='output file (will be extended with .fec)') + +parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)') +parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)') +parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)") +parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)') +parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk') + +args = parser.parse_args() + +import numpy as np +from scipy.io import wavfile +import torch + +from rdovae import RDOVAE +from packets import write_fec_packets + +torch.set_num_threads(4) + +checkpoint = torch.load(args.checkpoint, map_location="cpu") +model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) +model.load_state_dict(checkpoint['state_dict'], strict=False) +model.to("cpu") + +lpc_order = 16 + +## prepare input signal +# SILK frame size is 20ms and LPCNet subframes are 10ms +subframe_size = 160 +frame_size = 2 * subframe_size + +# 91 samples delay to align with SILK decoded frames +silk_delay = 91 + +# prepend zeros to have enough history to produce the first package +zero_history = (args.num_redundancy_frames - 1) * frame_size + +# dump data has a (feature) delay of 10ms +dump_data_delay = 160 + +total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay + +# load signal +if args.input.endswith('.raw') or args.input.endswith('.pcm'): + signal = np.fromfile(args.input, dtype='int16') + +elif args.input.endswith('.wav'): + fs, signal = wavfile.read(args.input) +else: + raise ValueError(f'unknown input signal format: {args.input}') + +# fill up last frame with zeros +padded_signal_length = len(signal) + total_delay +tail = padded_signal_length % frame_size +right_padding = (frame_size - tail) % frame_size + +signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16))) + +padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw' +signal.tofile(padded_signal_file) + +# write signal and call dump_data to create features + +feature_file = os.path.splitext(args.input)[0] + '_features.f32' +command = f"{args.dump_data} -test {padded_signal_file} {feature_file}" +r = subprocess.run(command, shell=True) +if r.returncode != 0: + raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}") + +# load features +nb_features = model.feature_dim + lpc_order +nb_used_features = model.feature_dim + +# load features +features = np.fromfile(feature_file, 
dtype='float32') +num_subframes = len(features) // nb_features +num_subframes = 2 * (num_subframes // 2) +num_frames = num_subframes // 2 + +features = np.reshape(features, (1, -1, nb_features)) +features = features[:, :, :nb_used_features] +features = features[:, :num_subframes, :] + +# quant_ids in reverse decoding order +quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long() + +print(f"using quantization levels {quant_ids}...") + +# convert input to torch tensors +features = torch.from_numpy(features) + + +# run encoder +print("running fec encoder...") +with torch.no_grad(): + + # encoding + z, states, state_size = model.encode(features) + + + # decoder on packet chunks + input_length = args.num_redundancy_frames // 2 + offset = args.num_redundancy_frames - 1 + + packets = [] + packet_sizes = [] + + for i in range(offset, num_frames): + print(f"processing frame {i - offset}...") + # quantize / unquantize latent vectors + zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :]) + zi, rates = model.quantize(zi, quant_ids) + zi = model.unquantize(zi, quant_ids) + + features = model.decode(zi, states[:, i : i + 1, :]) + packets.append(features.squeeze(0).numpy()) + packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8) + packet_sizes.append(packet_size) + + +# write packets +packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output +write_fec_packets(packet_file, packets, packet_sizes) + + +print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps") + +# assemble features according to loss file +if args.lossfile != None: + num_packets = len(packets) + loss = np.loadtxt(args.lossfile, dtype='int16') + fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32') + foffset = -2 + ptr = 0 + count = 2 + for i in range(num_packets): + if (loss[i] == 0) or (i == num_packets - 1): + + fec_out[ptr:ptr+count,:] = packets[i][foffset:, :] + + ptr += count + foffset = -2 + count = 2 + else: + count += 2 + foffset -= 2 + + fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32) + fec_out_full[:, : fec_out.shape[-1]] = fec_out + + fec_out_full.tofile(packet_file[:-4] + f'_fec.f32') + + +if args.debug_output: + import itertools + + batches = [4] + offsets = [0, 2 * args.num_redundancy_frames - 4] + + # sanity checks + # 1. 
concatenate features at offset 0 + for batch, offset in itertools.product(batches, offsets): + + stop = packets[0].shape[1] - offset + test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0) + + test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32) + test_features_full[:, :nb_used_features] = test_features[:, :] + + print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}") + test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32') + diff --git a/dnn/torch/rdovae/import_rdovae_weights.py b/dnn/torch/rdovae/import_rdovae_weights.py new file mode 100644 index 000000000..eba05018c --- /dev/null +++ b/dnn/torch/rdovae/import_rdovae_weights.py @@ -0,0 +1,143 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = "" + +import argparse + + + +parser = argparse.ArgumentParser() + +parser.add_argument('exchange_folder', type=str, help='exchange folder path') +parser.add_argument('output', type=str, help='path to output model checkpoint') + +model_group = parser.add_argument_group(title="model parameters") +model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20) +model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80) +model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256) +model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256) +model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24) +model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40) + +args = parser.parse_args() + +import torch +from rdovae import RDOVAE +from wexchange.torch import load_torch_weights + +exchange_name_to_name = { + 'encoder_stack_layer1_dense' : 'core_encoder.module.dense_1', + 'encoder_stack_layer3_dense' : 'core_encoder.module.dense_2', + 'encoder_stack_layer5_dense' : 'core_encoder.module.dense_3', + 'encoder_stack_layer7_dense' : 'core_encoder.module.dense_4', + 'encoder_stack_layer8_dense' : 'core_encoder.module.dense_5', + 'encoder_state_layer1_dense' : 'core_encoder.module.state_dense_1', + 'encoder_state_layer2_dense' : 'core_encoder.module.state_dense_2', + 'encoder_stack_layer2_gru' : 'core_encoder.module.gru_1', + 'encoder_stack_layer4_gru' : 'core_encoder.module.gru_2', + 'encoder_stack_layer6_gru' : 'core_encoder.module.gru_3', + 'encoder_stack_layer9_conv' : 'core_encoder.module.conv1', + 'statistical_model_embedding' : 'statistical_model.quant_embedding', + 'decoder_state1_dense' : 'core_decoder.module.gru_1_init', + 'decoder_state2_dense' : 'core_decoder.module.gru_2_init', + 'decoder_state3_dense' : 'core_decoder.module.gru_3_init', + 'decoder_stack_layer1_dense' : 'core_decoder.module.dense_1', + 'decoder_stack_layer3_dense' : 'core_decoder.module.dense_2', + 'decoder_stack_layer5_dense' : 'core_decoder.module.dense_3', + 'decoder_stack_layer7_dense' : 'core_decoder.module.dense_4', + 'decoder_stack_layer8_dense' : 'core_decoder.module.dense_5', + 'decoder_stack_layer9_dense' : 'core_decoder.module.output', + 'decoder_stack_layer2_gru' : 'core_decoder.module.gru_1', + 'decoder_stack_layer4_gru' : 'core_decoder.module.gru_2', + 'decoder_stack_layer6_gru' : 'core_decoder.module.gru_3' +} + +if __name__ == "__main__": + checkpoint = dict() + + # parameters + num_features = args.num_features + latent_dim = args.latent_dim + quant_levels = args.quant_levels + cond_size = args.cond_size + cond_size2 = args.cond_size2 + state_dim = args.state_dim + + + # model + checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) + checkpoint['model_kwargs'] = {'state_dim': state_dim} + model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + + dense_layer_names = [ + 'encoder_stack_layer1_dense', + 'encoder_stack_layer3_dense', + 'encoder_stack_layer5_dense', + 'encoder_stack_layer7_dense', + 'encoder_stack_layer8_dense', + 'encoder_state_layer1_dense', + 'encoder_state_layer2_dense', + 'decoder_state1_dense', + 'decoder_state2_dense', + 'decoder_state3_dense', + 
'decoder_stack_layer1_dense', + 'decoder_stack_layer3_dense', + 'decoder_stack_layer5_dense', + 'decoder_stack_layer7_dense', + 'decoder_stack_layer8_dense', + 'decoder_stack_layer9_dense' + ] + + gru_layer_names = [ + 'encoder_stack_layer2_gru', + 'encoder_stack_layer4_gru', + 'encoder_stack_layer6_gru', + 'decoder_stack_layer2_gru', + 'decoder_stack_layer4_gru', + 'decoder_stack_layer6_gru' + ] + + conv1d_layer_names = [ + 'encoder_stack_layer9_conv' + ] + + embedding_layer_names = [ + 'statistical_model_embedding' + ] + + for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names: + print(f"loading weights for layer {exchange_name_to_name[name]}") + layer = model.get_submodule(exchange_name_to_name[name]) + load_torch_weights(os.path.join(args.exchange_folder, name), layer) + + checkpoint['state_dict'] = model.state_dict() + + torch.save(checkpoint, args.output) \ No newline at end of file diff --git a/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..cfeebae5bafd31e7e674e73120d45c76ab177d3a GIT binary patch literal 7153 zcmaKx1yoeu+Qx?%V(6hkN<umW0RfSe7+Qu@x^sY`8zcopx*H^<OC&@<B!(^tkuK?w zh7a%kzQ5?_y>n*IS!b=;zxD2O_WPV??^gwaf=UDc05AZA-WJe6X`cJUH&-M#?dRFm z)Xmt!(9YbHOJCp8&JwP#&k1vfu=NZ;2cLE!IJ&r4Rb``;O94Cpm47(*11+HAO>z~A zH~n318uf1+V|`OMn7t$X{{taWTwFcuavW;^3x?sMJ2ec`R-HC9oVs&21mqx^VonZo zhELKudPeV36^hAQ={c|#t2p}TB^!c@wY04q<$MTk1)}b60sVMJx8r~W0DMCIkAWEL zyE<CJO&xDWmL6AW*G`Kc?E3f<CSKIjmz(kTp~XVZTVPDp^E>ne6#%xPoF-2?2Q@2U zOEo5jZPBa7YcU%g+dCGi4EamrZM`7`xP4B@K2R+bV+qX9Ul2KOu?mBvJ_&TywLN&R zD}e2a8&9l5Tp@0KN|CsuV~B*a55gId_u<fQs7cB!G>jBLQqCpP>hvq(rpeY-ucnNd z!^E(aMKXcxW96;}r?pnpESF_x&QSBzyEUMj)pU36OM915zX0Y_y<oZ^5*bvg&+e;E zVPC1(A;ja>e>*BKEnn_p@zvp}``!xzuJvrN#YCNjSvFDIELrlx@qSDSeoS7Dh9UR| zJyHn65xT_uk~TPIVAre4G+SA6@t%j7v$=fJ#E#?Iaq|`S_W3IrO?+)ub(Thu|7PyG zaCB==Sp}zZMcmZF-la5`g;kOUt2FdfEpVZdUS{j{z6?&*Jv!f4tIXVIimY*$=3wGQ z#AvA`Xj|Uq>rU&qug<q$v&KHm0{TwB5px>@08qID0Ni|U&AxwTjj_G0t-amNqw3T0 zw4ahExz1E}NU#b@7h+h_u#$?8zXJn94AiC>`j|q$w?@AMrdm_(gTGy7@Gc4r#*Aq+ z#>=lb^xOA)_P^68sl7Ob%!!ITp^Y>V6teM}{hVpMc2>R(-%tl+DLBGWHxTYe&7iFm zo6TU=xmm8e6UVK4ofR15@2Fo~-6i8`aoq7x+C?UOoD>8SX&9nELAw6Vbyp(!)Yx?) 
z>Q0+VD<&C2o7*WjG}4gz1mYSIxHc5F8$#C)-+0_GixS~wzX^US%iJMgtQw0|rw`^p z^^NJ`+y5-vfI*L;wUqVt#m3AzWGcLt5qsd(48c>_G>&1o!Hd`F-Og;F7x$tHqJ)jb zy<ZZ!>Ga<TBPv{S^N+bi6Bk>P;WqZIN#fyh1nW~}QF_f_8>~pN!_qC2CR5MDR1QEG zwVPO9kxN|7^h*lR1Y7t#H9d5#+t+dItOhrp&nBIn2H}u4E9vrbs6_IYQPdS>{@BW} zBtT^P7LzzmSb1qauoF9Xv9wA6J|aBaHkFZ)5@}P%(Vbui!gI&g>zx;$AxR<SniDF> z>}>F>O+5;tY3avo=3TQg`;Po(qiA+$i_&K)*^UxEqdmYtOUwx!8P~mxvT8H^*0iZj z5gly%)yi4Dxuo{HpTyk9tml@&IO9Q6>${Qz9^Sd7Qyv2Vqy$Y8r3Z0pkGAmByC2wo zFBq?q%eNVoAk6Wj0=|Wm<}dpB)!0+8=A{ULl7WUCTh=t|Fq$bV%||kqdCt8{eG=Sj zIU3(a-!t&1v~OhU)Y%=|;`Dnq^9wZOCu@2@s$0a~m^<ViaVUP2w1I}zN)Yi;sL`HU zej9E$g}09@RQt+J^{(w%^Wn-D-*ru$CXbX5%1!rJ!C3&6f67pqJCkVVh-L3BEr@x@ zUZdCBA3OirLP7Wq`#jl{yUgQ>h=eDokC5Mwb0|Lfyw3dXVfOI+1Z&33>igIaD*Qa7 z?#xUd-@-lz_@G&x%ilXS_J{yNUP*?emOB&Ze$Xo#vx*NFPCXY#oi*b3MY@M-uUDji z?2C3Zw95jc?~xK8e!nszZ;D__vMVFSNwgt4_2*L+iy+!fBWJo0^hmCboif=FR>gVa z!}s`MJNu{7=wUaw5jqx=fE;k{@%D?SVZ>+m0P_sgqbMYzLQYI_duao9c<AWpnM5aX z82DTQuv0C6W9`{+QeYHm|0O=bD<VPlIXobR>;QR!>}uZ#FX3|DWgPM{af%68d1FRT zP0>b-L%1pCTjpp5su73RqK3{#%1ImNbAHo|nmA1qk@Kyse)<F*FN*PJu%f#;gLn!O zmcUAOP8{a84;=PiKeEn8=Eu+YreQFt-z9uc`!0qsIlDgG<QU1xhX^WPdF}4yxbKc- z16@<QBzKXAm;kNEwwuNhg?IYtz4t3g`_pMn8t~O)*F50~Hi9ncAzpFh78(yCDOyJR zI39Ei*TOE?ur5?2gnKR77hhL8m*MqFIfB9K-N`O|gokU+HpEFG*ZI0>BT;#?{P6`m ztxo%FPYx{XZ3z-nP~4-Fc~%%X7KVI!R64mFF)xyMGHg~N8Cj~F<c#ciW)bWR{cgR| zffzv}bM&L{=8Y&z^n?1$%2xYHc#9kklPQf?L=H}U7T1!vP_No3Asf{FrrKq~bncKI z_g<>UJ98sy0nU|POF87ds7ADEJU*b^ey@HNa2lH*c>>t&SaRTa4VEX2s3OL@$e^+? zAZN=TxE!UqoO^QBa$LmXef=f+xT3$o^Rr@ie)nFM*tKTQuCoj-wic0ylvB*badroR z3HcQer|z2|drx(Ng55p=g>?-jM_@p=b==K;$d+i~6Nr;Mw9+|0rVMZD&tx#6`&d>f z3wn3=*Z}34$%T=GP39HFOpaDy!1gDjqmSt-K~~3`bai)Z^=jQv<L_R-hd;eK>L|2S z%HavhmB&Q)#GIY|=+e>toUZaR!=`fv196$K!rX5~xsOi6@UpdikdW9z(8DB`M$(+9 z0Ww<)ndHWG0*O8jut|@ZFSRP^l!p2j=V!WldZY$ZZH}wQav6nmxGI*8tYUq+f{)pc zK}m}J7jwGH($)hk#l=|F$uDi&+Ro=smG!Q=>TQ~GUfyeb-ic`MHQi{c5DTBqG8;x> zu_hogM<*`|^Zuv-ng{27gdKD$;R|w2%kDh4>2#q(sGqD49ojhErK3iJ;WP_-yj|9< z$P_9aEc&E%k43@kN~=BAExLDEO+X!k*Kq#5+bS!MJ%#gFwiZSSrNoC0d9z-gVZIes z{LV?0Wc89UdFexATSxa05$bxACIWgro@E1hJ8B4Hf{}MDmQk%?EkXH$u=s}s`!YjO z09(meqI?jlJa&Aw2m2+xQ|HOO?Zb-dGe5fT9txrA`UXWf4DJs)NZSX8uS-iIuvOx% z%HXaHhsi(={`pd^h-e6Do*sYa4kD2Xyz+>=+Ouafh4%xCW$~SY6vjUEsvj})lx+Kk zhih)s<K9U|yww?+h0hb7%H<3a7WkKwonDJ}>ku(-T`MI#;5VCGnGrtAT%8@lZKwwh zRY7$LlZI0;4a%F`WG~%}8GEnS-EB*AXIoQikf6bVD1r+q24Ct%*-9~ugWd7$m^tBR z;1IzfEbq<S9T?GcD_i@vv-u+hP*8}^LP9S4B6S=#gP>#W2Ax(@iOOZM{MNG^WPTk~ zBuN4nv&qUl(=Orro9X}o%@2JwhkIx&JKnzGc?-G&B&w1mIO$>BB4@d~>^;}#IHvri zX<;MLd@vaWukeDN^VT^e8o(SnS$gZ^7r=7rvE_+Y-Rl!WlXN8t|0G&pqFMKW(WzOc zqq;jlxrZEQFcmK^>FS3WhhCnc_NPML#);trWR>l59WR?}QbkV)n9td3Ak}pwHTSF` zl&5OE&5O9jjGuf)DBpCvx9cRzYpTV1wtjifc0Y{WzU5F2>g#@<o_d*HCYt+g-KG(x z*VB5cC6{Ko@@(X#YQ=I)9UOYKkU}(dy`RIU5pOHi(D*8rxGXH@k!qP@z$byA%8VAT zU4dUd2~XDo>MbuoqI2VqSZ|Ci*>64xZg$&OJ0rwFe%#nnL(Olxkfe5{q!o3r3RrAM zzc)H$zp%O9W#TtX;*(fmRYX!Q6F2^n_l!3<)p4r!6|5_u4sU2)Ha~Xk!=}+fJie3M zk5+|kIjc9znQIb(v%9xU_J)onblPOAk{1O4(8dP<z<<UAH@k6bW15PurnvE&57^j` z#MGoGgJz@G@*6iICC7eL=}0p<ghGsD>QH($E_bMek-#B+=+e+$r0l(><UK@cqF8kr zR?xdN1&h<C1WXN~Pg`WP?h?~%z3ti$<gdI~DlsGoQ$;oR4-brF4tSKEDkXRCGwU9{ z*32i=k0Fr9?Ydx92N8NM^<qE!M15)TnXtV!H82J?H%&<Rc12EF_A(f>S1)Ct%?r^t zO!&er-X?+3rp`_4mzNqbQb;8UZv7mNfird#zr2N#l9j(#w0#P{-b@`AEGtJfyT2OK zx)W{v6ew>q_?EQB&O9YkOvtjmK0a$jSok!-)h$F12w$QHmm|nD{KiNROaq*dD~G;P zof$+KaSmeTU_Y`%A9(hOx7MBMyVFD6OI2-L2fd7K(vl|*XHRwd_N;h(KX{a-$Ds)Z z0!;ncj~QnP(8+Sc7Ov;wpmE6w$AK3}>O0jJ+|G}lHVG4@XbqYG;EI(sj%Db}e61bG z8xhO1d3R)<dmhk~qQb?$nXL^{r1Ah#Ybw59HI-IA7|$o2L}+i}Y8`I3f-bEaa}`&5 z)$Wnq{mP9VOyyV?=#CbSp5G%}aF~J&rK-v+7M2($HJqhFTB(Qyh3I~MDDZjAKLsxl 
zV<xPYsgr*LsXD}tT>X__qQb+6))IHV1{aUY83(U-$kEz3uf0E6SGp8dCpcYTJa@@% z5nU-6Yg39+?_w$p(fUq;R5QR{MUfpj(PU_mMq?O+@?`{@jl*h$`(wmwbu^X)D-XLe zhnCjFGs%|Cg@y)CyoXMBeIjsu`<XMcZHgzrG&}oG#btI;24f9*do~nff}UuS7!$1# zE0o>i4B^bvT8hZ$7Trh)AaW!r*gEtpV-zWRV1?C|HN4ugPv|~1@w#f3Wkg{fc1=c5 z2g|Fb1Bh%mo|No$Hq-hOKB1)#AbpK=+?Us<z4cf_)Aj}v9anj-8}kw;FMC}SgFi2} zaxu8S#O@P`Nqdd_RBq5Ujbup`BgtH(cIBNI6VC#VT1Z+}=2>HL|78%FrYuRr6|Zbg z(KN$qU(w)^*o6~o)8&01LRR?3c4^_@8!X>fpbtpDr+jf4zl{D(YH8i1Jmv2xX76Zh zaXagS5ZXVpUV6w^a|jqZ7aq>P4k%!uf}aP_GNYZJbMAaW<Q4Qxx)C!wlz}K1y*Z{} zGv<TnXsg!bW1U!`mK7gx4cM9Qmk+Bxv&-50B8)s74xkenIvtH-B0oMr{S{Ygumv;* zEr+t>CiUqE|DP27D=zrY-x~pF*6LN7;U@aJ$~H5JK$TXVyld)@?=2e#%-Oo{7?T>- z+Yz<-!vF47bKQX>bu7Ac1(})1@yWAqp4(ymdE|=^joPYYPx;M76gKc`WUI|z0CSQ} zKc(fbVJI60tW(bnhTNl4|D<=h!hpV=6+xJC?x}Zzelf!-B?g2iN<5(%$aU<ix*Mv> zfa#qcFgu4~lDTB$s^AwGxD*q}tn{(blxPzRtDELBAPA5ghz&FCUyed{rxx1rrHT%P zVK}uhix!4*ozS)p8_Hxdbd4A?R$pw=h&B=|2`aeMQJMvLjX@_Ti5b?#V;;&5R-vN4 zbdSnf9`ut#i=^eb8=&&qm_p)tbFnydauhG5X|Cp5g^(Wc)X)QmLmsO1hTIW{!5>%q zKc>gUc@|mvn@jCdE)e2k+lyNF-B(SoXt8Eu?#~yTAeoS+w6qZtGz?HTdFuInmEyOz zSOHyN2xIoaGT!-;Wj}D4b_u`E=j_X<cS){o24x=ZcB^@Pl|-RpH!r805@)k3xf6=g zKC~p(<<fGE+7(B<w-DfJW{4?B?iUC|1gg2N7*kJ~ayT)$t9fKVlG!UeC{gaikn^{m zlb3A{a=lS1>7$lpj4jN2b^pTP4k30jHbgLJNCYL21E=8)bF)R#+rSF-9(oO{J$`|2 z_T^kV{Ix4lRT~s4RIbAvDO&Nn>Px9q2)DLF5e3|M_vjXD@+O*jXk_=>-O$cEo7p^R zc6pP9K;udBVn#w!WRi%dGHmyP+f?f55(#^#W1XFg*5V~g%U;)p>la$pW)7DG7uw`I z_DVArxSTfgJ?8W|;5;HZ^6ipLc^yKqCtz=aw7-h&qS$S7*v*z0f%UBsfv}9{V@-<N zZws?ho}r$I8^6GRThhHzw8c>CGD?d>8;<ey_&f85LU1an8m}X`#$aLm^n2D0FUyB9 z4pQb$!P)h)W40N{eN$D3BySPnH3GL*RJSv;)vW635c(BQRnzK8n?Y)37&zx$))`Zw zh&W;Ia8@{RB^hRdgg4X1)!>Cy;8X=G(+kVIYPDGz?Xt`AM~!ub&zpum5C_(=9OmGc zhd)^-Z?=FYn%C$*pZNNPzC~*yOKu}rrIkK0prIs68&UR{|EM_4fyuwz_3^z)8P4tJ z+woqSQ;ke4pSQL?XqA8OPG%5&s1MbZ$#`&wkR?n1M5Xn?YQ)ZhZG+*n9S7BQ(`U>& z_BmanzI3!xQ-R4e114r9L_W*xZ^7#`m!-PpP*B=cs>4_(N>QJvgr@D|f+9w<O9!Jf z^k12erfUgR^bXW3zEPsJn?`$EkvMobxj9WNo!}gnc4qcmiqh)hQsV03(owy*0o?e~ z#|Oj*DEJPO@f4b(gC|ChS}m+CN=zv_$`i?Grk%Fswo`(v!179WbQaUw98$!U5{O+c z3i%qa+Lg`LF@6l`fSTI^`lKQI&kc=xxG~tZh0H>x^}GuCVoLT}%Z}z}PH+oUYw$p% zOE3~`K3|o#1clsPotE_%?2nou*1_>b!9MTch{V%^F2m{QH17)Eszpye|7jCD9o4m* z<|BYU7T53NT(_X`dv-IvpN;Cjj8juqT3P{XKvuDD2nWK!+K(8hR^uM!opV&~gK)^R zcXL&#$;x-LbB1xD7pV?%5AyO)@QyBldIq?Lj@ZX>m^tM8x)rL_m|0jAmqBt0^=iz8 zeB*t+1CTLAM9I<u@-5Hx^Y#aj^n$Q&!U5eh#{Uu)+#aTHW9njR!wGkTvniL#_RAt5 z%FBP^dWjYd`H%sCx*IM3ZNv7Dr>Q0_sjMdDJlq|XDh5CuJom!-0f;2UeF!i+Xfq>e z2a;RZz^choj=mF0r*z1RBnpW_<sP*1wCMLzUSPMsMZm~7zlT}uXnNFMIq-5SWv^D~ ze!{h^kgOMUw3*%+t3f)Swc!dFC<4TFFX8iD@?2Cd;72)+&AAp<@}h|*q=h#qwsLD{ zAGJT}#pUOXuwF?Tjs|8hi{LB~64iRY>6{-lA>-#_U8c=wD*fzVe-U{a?djf<y(xqa zIL@=sQc`An_ngg%n2<>Esr_UzbMZcK%gh^58tAbG&WkL~tj%7K&L9huv^#U8kQzh{ zM7Y)|+gPmBvMGJD#lzAJT4l{*c(9=Ob{m%Dict%EqU^hpb-(yP(mz8Z1-0*LasE&` zhVM$tF8Li8txhk`7&fDDEGi^5r8ki$p`e!Z+Bs?EwLJHz%aJ@((aU$D>>zb7YO2cL z?eln&$d3u_=x`})k1QJ^&p{&~-LWG?7MUC2yR_rA2hOV>eyEJ6Io<bePCUOE%Pb|c zZ1-77av4-&8RyD)%nsA6*X}9+%Sw(j_U)&aZxu=gVlAA6`&U{FW<eTxc4*2V)MpH9 zfpPE#bzRl(9%k=<a2n`uMwB$wK?BeZ>&xDRF6?2&7uB}G#De0r<2)7{70^*mVf{PW z-sJU8O&r1+F{-``7(YA-HJw8Z@P@#zPI@y6&5Jum(3U7yml%m2uDx2I`$FXcJ;+1R ztk?-s5n?_&Qrg#@1_QrN5p;Z3R)HWP6QTU)Qp3$0{@16*|L5tig@@aSzf~Fj4FCXY z{AGVb{Ih#}tB7%%^|wO7FV@iCS+~jtw^@JN$Nt6gxk<5KL;EMLe|EKRlm52S`$a0i zx=s3h-FKVvxAoaC3jd9S{EFv>@~_p}?fyTb_|ts;>Yt4FTmL_;=xxy7+~O|~&&|^D zRv^E5$J>;@y?K67*lwoRKa@X|>;K+Ef71Tc_g^%^n+n)}qWw$l|3v+%$-hv={{{7% mLjQ^UQ<wiiBAxyh<iAx~1p-9-`G4%2!|_JfEr0G60R9Ko`&IJ* literal 0 HcmV?d00001 diff --git a/dnn/torch/rdovae/packets/__init__.py b/dnn/torch/rdovae/packets/__init__.py new file mode 100644 index 
000000000..fb71ab3d5 --- /dev/null +++ b/dnn/torch/rdovae/packets/__init__.py @@ -0,0 +1 @@ +from .fec_packets import write_fec_packets, read_fec_packets \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.c b/dnn/torch/rdovae/packets/fec_packets.c new file mode 100644 index 000000000..376fb4f16 --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.c @@ -0,0 +1,142 @@ +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include <stdio.h> +#include <inttypes.h> + +#include "fec_packets.h" + +int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index) +{ + + int16_t version; + int16_t header_size; + int16_t num_packets; + int16_t packet_size; + int16_t subframe_size; + int16_t subframes_per_packet; + int16_t num_features; + long offset; + + FILE *fid = fopen(filename, "rb"); + + /* read header */ + if (fread(&version, sizeof(version), 1, fid) != 1) goto error; + if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error; + if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error; + if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error; + if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error; + if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error; + if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error; + + /* check if indices are valid */ + if (packet_index >= num_packets || subframe_index >= subframes_per_packet) + { + fprintf(stderr, "get_fec_frame: index out of bounds\n"); + goto error; + } + + /* calculate offset in file (+ 2 is for rate) */ + offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size; + fseek(fid, offset, SEEK_SET); + + /* read features */ + if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error; + + fclose(fid); + return 0; + +error: + fclose(fid); + return 1; +} + +int get_fec_rate(const char * const filename, int packet_index) +{ + int16_t version; + int16_t header_size; + int16_t num_packets; + int16_t packet_size; + int16_t subframe_size; + int16_t subframes_per_packet; + int16_t num_features; + long offset; + int16_t rate; + + FILE *fid = fopen(filename, 
"rb"); + + /* read header */ + if (fread(&version, sizeof(version), 1, fid) != 1) goto error; + if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error; + if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error; + if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error; + if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error; + if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error; + if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error; + + /* check if indices are valid */ + if (packet_index >= num_packets) + { + fprintf(stderr, "get_fec_rate: index out of bounds\n"); + goto error; + } + + /* calculate offset in file (+ 2 is for rate) */ + offset = header_size + packet_index * packet_size; + fseek(fid, offset, SEEK_SET); + + /* read rate */ + if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error; + + fclose(fid); + return (int) rate; + +error: + fclose(fid); + return -1; +} + +#if 0 +int main() +{ + float features[20]; + int i; + + if (get_fec_frame("../test.fec", &features[0], 0, 127)) + { + return 1; + } + + for (i = 0; i < 20; i ++) + { + printf("%d %f\n", i, features[i]); + } + + printf("rate: %d\n", get_fec_rate("../test.fec", 0)); + +} +#endif \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.h b/dnn/torch/rdovae/packets/fec_packets.h new file mode 100644 index 000000000..35d355428 --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ + +#ifndef _FEC_PACKETS_H +#define _FEC_PACKETS_H + +int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index); +int get_fec_rate(const char * const filename, int packet_index); + +#endif \ No newline at end of file diff --git a/dnn/torch/rdovae/packets/fec_packets.py b/dnn/torch/rdovae/packets/fec_packets.py new file mode 100644 index 000000000..14bed1f8c --- /dev/null +++ b/dnn/torch/rdovae/packets/fec_packets.py @@ -0,0 +1,108 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/
+"""
+
+import numpy as np
+
+
+
+def write_fec_packets(filename, packets, rates=None):
+    """ writes packets in binary format """
+
+    assert np.dtype(np.float32).itemsize == 4
+    assert np.dtype(np.int16).itemsize == 2
+
+    # derive some sizes
+    num_packets = len(packets)
+    subframes_per_packet = packets[0].shape[-2]
+    num_features = packets[0].shape[-1]
+
+    # size of float is 4
+    subframe_size = num_features * 4
+    packet_size = subframe_size * subframes_per_packet + 2 # two bytes for rate
+
+    version = 1
+    # header size (version, header_size, num_packets, packet_size, subframe_size, subframes_per_packet, num_features)
+    header_size = 14
+
+    with open(filename, 'wb') as f:
+
+        # header
+        f.write(np.int16(version).tobytes())
+        f.write(np.int16(header_size).tobytes())
+        f.write(np.int16(num_packets).tobytes())
+        f.write(np.int16(packet_size).tobytes())
+        f.write(np.int16(subframe_size).tobytes())
+        f.write(np.int16(subframes_per_packet).tobytes())
+        f.write(np.int16(num_features).tobytes())
+
+        # packets
+        for i, packet in enumerate(packets):
+            if type(rates) == type(None):
+                rate = 0
+            else:
+                rate = rates[i]
+
+            f.write(np.int16(rate).tobytes())
+
+            features = np.flip(packet, axis=-2)
+            f.write(features.astype(np.float32).tobytes())
+
+
+def read_fec_packets(filename):
+    """ reads packets from binary format """
+
+    assert np.dtype(np.float32).itemsize == 4
+    assert np.dtype(np.int16).itemsize == 2
+
+    with open(filename, 'rb') as f:
+
+        # header
+        version = np.frombuffer(f.read(2), dtype=np.int16).item()
+        header_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+        num_packets = np.frombuffer(f.read(2), dtype=np.int16).item()
+        packet_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+        subframe_size = np.frombuffer(f.read(2), dtype=np.int16).item()
+        subframes_per_packet = np.frombuffer(f.read(2), dtype=np.int16).item()
+        num_features = np.frombuffer(f.read(2), dtype=np.int16).item()
+
+        dummy_features = np.zeros((subframes_per_packet, num_features), dtype=np.float32)
+
+        # packets
+        rates = []
+        packets = []
+        for i in range(num_packets):
+
+            rate = np.frombuffer(f.read(2), dtype=np.int16).item()
+            rates.append(rate)
+
+            features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape)
+            packet = np.flip(features, axis=-2)
+            packets.append(packet)
+
+    return packets
\ No newline at end of file
diff --git a/dnn/torch/rdovae/rdovae/__init__.py b/dnn/torch/rdovae/rdovae/__init__.py
new file mode 100644
index 000000000..b945addec
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/__init__.py
@@ -0,0 +1,2 @@
+from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
+from .dataset import RDOVAEDataset
diff --git a/dnn/torch/rdovae/rdovae/dataset.py b/dnn/torch/rdovae/rdovae/dataset.py
new file mode 100644
index 000000000..99630d8b9
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/dataset.py
@@ -0,0 +1,68 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" + +import torch +import numpy as np + +class RDOVAEDataset(torch.utils.data.Dataset): + def __init__(self, + feature_file, + sequence_length, + num_used_features=20, + num_features=36, + lambda_min=0.0002, + lambda_max=0.0135, + quant_levels=16, + enc_stride=2): + + self.sequence_length = sequence_length + self.lambda_min = lambda_min + self.lambda_max = lambda_max + self.enc_stride = enc_stride + self.quant_levels = quant_levels + self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min) + + if sequence_length % enc_stride: + raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}") + + self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features)) + self.features = self.features[:, :num_used_features] + self.num_sequences = self.features.shape[0] // sequence_length + + def __len__(self): + return self.num_sequences + + def __getitem__(self, index): + features = self.features[index * self.sequence_length: (index + 1) * self.sequence_length, :] + q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64) + q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0) + rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32) + + return features, rate_lambda, q_ids + diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py new file mode 100644 index 000000000..b45d2b8c3 --- /dev/null +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -0,0 +1,614 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+""" PyTorch implementation of the rate-distortion-optimized variational autoencoder """
+
+import math as m
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# Quantization and rate related utility functions
+
+def soft_pvq(x, k):
+    """ soft pyramid vector quantizer """
+
+    # L2 normalization
+    x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
+
+
+    with torch.no_grad():
+        # quantization loop, no need to track gradients here
+        x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
+
+        # set initial scaling factor to k
+        scale_factor = k
+        x_scaled = scale_factor * x_norm1
+        x_quant = torch.round(x_scaled)
+
+        # we aim for ||x_quant||_L1 = k
+        for _ in range(10):
+            # remove signs and calculate L1 norm
+            abs_x_quant = torch.abs(x_quant)
+            abs_x_scaled = torch.abs(x_scaled)
+            l1_x_quant = torch.sum(abs_x_quant, axis=-1)
+
+            # increase, where target is too small and decrease, where target is too large
+            plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+            minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+            factor = torch.where(l1_x_quant > k, minus, plus)
+            factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
+            scale_factor = scale_factor * factor.unsqueeze(-1)
+
+            # update x (re-quantize with the refined scale factor)
+            x_scaled = scale_factor * x_norm1
+            x_quant = torch.round(x_scaled)
+
+    # L2 normalization of quantized x
+    x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
+    quantization_error = x_quant_norm2 - x_norm2
+
+    return x_norm2 + quantization_error.detach()
+
+def cache_parameters(func):
+    cache = dict()
+    def cached_func(*args):
+        if args in cache:
+            return cache[args]
+        else:
+            cache[args] = func(*args)
+
+        return cache[args]
+    return cached_func
+
+@cache_parameters
+def pvq_codebook_size(n, k):
+
+    if k == 0:
+        return 1
+
+    if n == 0:
+        return 0
+
+    return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
+
+
+def soft_rate_estimate(z, r, reduce=True):
+    """ rate approximation with dependent theta Eq. 
(7)""" + + rate = torch.sum( + - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6), + dim=-1 + ) + + if reduce: + rate = torch.mean(rate) + + return rate + + +def hard_rate_estimate(z, r, theta, reduce=True): + """ hard rate approximation """ + + z_q = torch.round(z) + p0 = 1 - r ** (0.5 + 0.5 * theta) + alpha = torch.relu(1 - torch.abs(z_q)) ** 2 + rate = - torch.sum( + (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6) + + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)), + dim=-1 + ) + + if reduce: + rate = torch.mean(rate) + + return rate + + + +def soft_dead_zone(x, dead_zone): + """ approximates application of a dead zone to x """ + d = dead_zone * 0.05 + return x - d * torch.tanh(x / (0.1 + d)) + + +def hard_quantize(x): + """ round with copy gradient trick """ + return x + (torch.round(x) - x).detach() + + +def noise_quantize(x): + """ simulates quantization with addition of random uniform noise """ + return x + (torch.rand_like(x) - 0.5) + + +# loss functions + + +def distortion_loss(y_true, y_pred, rate_lambda=None): + """ custom distortion loss for LPCNet features """ + + if y_true.size(-1) != 20: + raise ValueError('distortion loss is designed to work with 20 features') + + ceps_error = y_pred[..., :18] - y_true[..., :18] + pitch_error = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19]) + corr_error = y_pred[..., 19:] - y_true[..., 19:] + pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2 + + loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1) + + if type(rate_lambda) != type(None): + loss = loss / torch.sqrt(rate_lambda) + + loss = torch.mean(loss) + + return loss + + +# sampling functions + +import random + + +def random_split(start, stop, num_splits=3, min_len=3): + get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)]) + candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] + + while get_min_len(candidate) < min_len: + candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] + + return candidate + + + +# weight initialization and clipping +def init_weights(module): + + if isinstance(module, nn.GRU): + for p in module.named_parameters(): + if p[0].startswith('weight_hh_'): + nn.init.orthogonal_(p[1]) + + +def weight_clip_factory(max_value): + """ weight clipping function concerning sum of abs values of adjecent weights """ + def clip_weight_(w): + stop = w.size(1) + # omit last column if stop is odd + if stop % 2: + stop -= 1 + max_values = max_value * torch.ones_like(w[:, :stop]) + factor = max_value / torch.maximum(max_values, + torch.repeat_interleave( + torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]), + 2, + 1)) + with torch.no_grad(): + w[:, :stop] *= factor + + def clip_weights(module): + if isinstance(module, nn.GRU) or isinstance(module, nn.Linear): + for name, w in module.named_parameters(): + if name.startswith('weight'): + clip_weight_(w) + + return clip_weights + +# RDOVAE module and submodules + + +class CoreEncoder(nn.Module): + STATE_HIDDEN = 128 + FRAMES_PER_STEP = 2 + CONV_KERNEL_SIZE = 4 + + def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24): + """ core encoder for RDOVAE + + Computes latents, initial states, and rate estimates from features and lambda parameter + + """ + + super(CoreEncoder, self).__init__() + + # hyper parameters + self.feature_dim = feature_dim + self.output_dim = 
output_dim + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.state_size = state_size + + # derived parameters + self.input_dim = self.FRAMES_PER_STEP * self.feature_dim + self.conv_input_channels = 5 * cond_size + 3 * cond_size2 + + # layers + self.dense_1 = nn.Linear(self.input_dim, self.cond_size2) + self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_2 = nn.Linear(self.cond_size, self.cond_size2) + self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_3 = nn.Linear(self.cond_size, self.cond_size2) + self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) + self.dense_4 = nn.Linear(self.cond_size, self.cond_size) + self.dense_5 = nn.Linear(self.cond_size, self.cond_size) + self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid') + + self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN) + + self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size) + + # initialize weights + self.apply(init_weights) + + + def forward(self, features): + + # reshape features + x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2))) + + batch = x.size(0) + device = x.device + + # run encoding layer stack + x1 = torch.tanh(self.dense_1(x)) + x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device)) + x3 = torch.tanh(self.dense_2(x2)) + x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device)) + x5 = torch.tanh(self.dense_3(x4)) + x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device)) + x7 = torch.tanh(self.dense_4(x6)) + x8 = torch.tanh(self.dense_5(x7)) + + # concatenation of all hidden layer outputs + x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + + # init state for decoder + states = torch.tanh(self.state_dense_1(x9)) + states = torch.tanh(self.state_dense_2(states)) + + # latent representation via convolution + x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0]) + z = self.conv1(x9).permute(0, 2, 1) + + return z, states + + + + +class CoreDecoder(nn.Module): + + FRAMES_PER_STEP = 4 + + def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24): + """ core decoder for RDOVAE + + Computes features from latents, initial state, and quantization index + + """ + + super(CoreDecoder, self).__init__() + + # hyper parameters + self.input_dim = input_dim + self.output_dim = output_dim + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.state_size = state_size + + self.input_size = self.input_dim + + self.concat_size = 4 * self.cond_size + 4 * self.cond_size2 + + # layers + self.dense_1 = nn.Linear(self.input_size, cond_size2) + self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_2 = nn.Linear(cond_size, cond_size2) + self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_3 = nn.Linear(cond_size, cond_size2) + self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True) + self.dense_4 = nn.Linear(cond_size, cond_size2) + self.dense_5 = nn.Linear(cond_size2, cond_size2) + + self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim) + + + self.gru_1_init = nn.Linear(self.state_size, self.cond_size) + self.gru_2_init = nn.Linear(self.state_size, self.cond_size) + self.gru_3_init = nn.Linear(self.state_size, self.cond_size) + + # initialize weights + self.apply(init_weights) + + def 
forward(self, z, initial_state): + + gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2)) + gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2)) + gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2)) + + # run decoding layer stack + x1 = torch.tanh(self.dense_1(z)) + x2, _ = self.gru_1(x1, gru_1_state) + x3 = torch.tanh(self.dense_2(x2)) + x4, _ = self.gru_2(x3, gru_2_state) + x5 = torch.tanh(self.dense_3(x4)) + x6, _ = self.gru_3(x5, gru_3_state) + x7 = torch.tanh(self.dense_4(x6)) + x8 = torch.tanh(self.dense_5(x7)) + x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + + # output layer and reshaping + x10 = self.output(x9) + features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP)) + + return features + + +class StatisticalModel(nn.Module): + def __init__(self, quant_levels, latent_dim): + """ Statistical model for latent space + + Computes scaling, deadzone, r, and theta + + """ + + super(StatisticalModel, self).__init__() + + # copy parameters + self.latent_dim = latent_dim + self.quant_levels = quant_levels + self.embedding_dim = 6 * latent_dim + + # quantization embedding + self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim) + + # initialize embedding to 0 + with torch.no_grad(): + self.quant_embedding.weight[:] = 0 + + + def forward(self, quant_ids): + """ takes quant_ids and returns statistical model parameters""" + + x = self.quant_embedding(quant_ids) + + # CAVE: theta_soft is not used anymore. Kick it out? + quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim]) + dead_zone = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim]) + theta_soft = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim]) + r_soft = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim]) + theta_hard = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim]) + r_hard = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim]) + + + return { + 'quant_embedding' : x, + 'quant_scale' : quant_scale, + 'dead_zone' : dead_zone, + 'r_hard' : r_hard, + 'theta_hard' : theta_hard, + 'r_soft' : r_soft, + 'theta_soft' : theta_soft + } + + +class RDOVAE(nn.Module): + def __init__(self, + feature_dim, + latent_dim, + quant_levels, + cond_size, + cond_size2, + state_dim=24, + split_mode='split', + clip_weights=True, + pvq_num_pulses=82, + state_dropout_rate=0): + + super(RDOVAE, self).__init__() + + self.feature_dim = feature_dim + self.latent_dim = latent_dim + self.quant_levels = quant_levels + self.cond_size = cond_size + self.cond_size2 = cond_size2 + self.split_mode = split_mode + self.state_dim = state_dim + self.pvq_num_pulses = pvq_num_pulses + self.state_dropout_rate = state_dropout_rate + + # submodules encoder and decoder share the statistical model + self.statistical_model = StatisticalModel(quant_levels, latent_dim) + self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim)) + self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim)) + + self.enc_stride = CoreEncoder.FRAMES_PER_STEP + self.dec_stride = CoreDecoder.FRAMES_PER_STEP + + if clip_weights: + self.weight_clip_fn = weight_clip_factory(0.496) + else: + self.weight_clip_fn = None + + if self.dec_stride % self.enc_stride != 0: + raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride") 
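(Aside, not part of the patch: before the `forward` method further below ties these pieces together, here is a minimal self-contained sketch of how the two quantization helpers defined earlier behave. The input values and the dead-zone strength are arbitrary illustration values, and the two helpers are copied locally so the snippet runs on its own.)

```python
# Self-contained sketch of the quantization helpers used by RDOVAE.forward().
import torch

def soft_dead_zone(x, dead_zone):
    # shrinks small values towards zero so that more latents quantize to 0
    d = dead_zone * 0.05
    return x - d * torch.tanh(x / (0.1 + d))

def hard_quantize(x):
    # rounds in the forward pass, passes gradients straight through
    return x + (torch.round(x) - x).detach()

z = torch.tensor([0.04, 0.3, 1.8, -2.6], requires_grad=True)
z_q = hard_quantize(soft_dead_zone(z, torch.full_like(z, 4.0)))  # dead-zone strength 4.0 is made up
print(z_q)         # tensor([ 0.,  0.,  2., -2.], ...): small values collapse to 0
z_q.sum().backward()
print(z.grad)      # finite everywhere: round() is invisible to autograd,
                   # only the dead-zone shaping contributes to the gradient
```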
+ + def clip_weights(self): + if not type(self.weight_clip_fn) == type(None): + self.apply(self.weight_clip_fn) + + def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4): + + enc_stride = self.enc_stride + dec_stride = self.dec_stride + + stride = dec_stride // enc_stride + + chunks = [] + + for offset in range(stride): + # start is the smalles number = offset mod stride that decodes to a valid range + start = offset + while enc_stride * (start + 1) - dec_stride < 0: + start += stride + + # check if start is a valid index + if start >= z_frames: + raise ValueError("get_decoder_chunks_generic: range too small") + + # stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride + stop = z_frames - (z_frames % stride) + offset + while stop < z_frames: + stop += stride + + # calculate split points + length = (stop - start) + if mode == 'split': + split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop] + elif mode == 'random_split': + split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)] + else: + raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}") + + + for i in range(chunks_per_offset): + # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop) + # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j) + # provided that i - j = 1 mod stride + chunks.append({ + 'z_start' : split_points[i], + 'z_stop' : split_points[i + 1] - stride + 1, + 'z_stride' : stride, + 'features_start' : enc_stride * (split_points[i] + 1) - dec_stride, + 'features_stop' : enc_stride * (split_points[i + 1] - stride + 1) + }) + + return chunks + + + def forward(self, features, q_id): + + # calculate statistical model from quantization ID + statistical_model = self.statistical_model(q_id) + + # run encoder + z, states = self.core_encoder(features) + + # scaling, dead-zone and quantization + z = z * statistical_model['quant_scale'] + z = soft_dead_zone(z, statistical_model['dead_zone']) + + # quantization + z_q = hard_quantize(z) / statistical_model['quant_scale'] + z_n = noise_quantize(z) / statistical_model['quant_scale'] + states_q = soft_pvq(states, self.pvq_num_pulses) + + if self.state_dropout_rate > 0: + drop = torch.rand(states_q.size(0)) < self.state_dropout_rate + mask = torch.ones_like(states_q) + mask[drop] = 0 + states_q = states_q * mask + + # decoder + chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode) + + outputs_hq = [] + outputs_sq = [] + for chunk in chunks: + # decoder with hard quantized input + z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1]) + dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :] + features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state) + outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop'])) + + + # decoder with soft quantized input + z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1]) + features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state) + outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop'])) + + return { + 'outputs_hard_quant' : outputs_hq, + 'outputs_soft_quant' : outputs_sq, + 'z' : z, + 'statistical_model' : statistical_model 
+ } + + def encode(self, features): + """ encoder with quantization and rate estimation """ + + z, states = self.core_encoder(features) + + # quantization of initial states + states = soft_pvq(states, self.pvq_num_pulses) + state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses)) + + return z, states, state_size + + def decode(self, z, initial_state): + """ decoder (flips sequences by itself) """ + + z_reverse = torch.flip(z, [1]) + features_reverse = self.core_decoder(z_reverse, initial_state) + features = torch.flip(features_reverse, [1]) + + return features + + def quantize(self, z, q_ids): + """ quantization of latent vectors """ + + stats = self.statistical_model(q_ids) + + zq = z * stats['quant_scale'] + zq = soft_dead_zone(zq, stats['dead_zone']) + zq = torch.round(zq) + + sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False) + + return zq, sizes + + def unquantize(self, zq, q_ids): + """ re-scaling of latent vector """ + + stats = self.statistical_model(q_ids) + + z = zq / stats['quant_scale'] + + return z + + def freeze_model(self): + + # freeze all parameters + for p in self.parameters(): + p.requires_grad = False + + for p in self.statistical_model.parameters(): + p.requires_grad = True + diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt new file mode 100644 index 000000000..668c84628 --- /dev/null +++ b/dnn/torch/rdovae/requirements.txt @@ -0,0 +1,5 @@ +numpy +scipy +torch +tqdm +libs/wexchange-1.0-py3-none-any.whl \ No newline at end of file diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py new file mode 100644 index 000000000..68ccf2eb0 --- /dev/null +++ b/dnn/torch/rdovae/train_rdovae.py @@ -0,0 +1,270 @@ +""" +/* Copyright (c) 2022 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/
+"""
+
+import os
+import argparse
+
+import torch
+import tqdm
+
+from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--cuda-visible-devices', type=str, help="comma separated list of CUDA visible device indices, default: ''", default="")
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
+model_group.add_argument('--state-dim', type=int, help="dimensionality of transferred state, default: 24", default=24)
+model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
+model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
+model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
+model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
+model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
+training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
+training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
+training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
+
+args = parser.parse_args()
+
+# set visible devices
+os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay_factor = args.lr_decay_factor
+split_mode = args.split_mode
+# not exposed
+adam_betas = [0.9, 0.99]
+adam_eps = 1e-8
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay_factor'] = lr_decay_factor
+checkpoint['split_mode'] = split_mode +checkpoint['epochs'] = epochs +checkpoint['sequence_length'] = sequence_length +checkpoint['adam_betas'] = adam_betas + +# logging +log_interval = 10 + +# device +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# model parameters +cond_size = args.cond_size +cond_size2 = args.cond_size2 +latent_dim = args.latent_dim +quant_levels = args.quant_levels +lambda_min = args.lambda_min +lambda_max = args.lambda_max +state_dim = args.state_dim +# not expsed +num_features = 20 + + +# training data +feature_file = args.features + +# model +checkpoint['model_args'] = (num_features, latent_dim, quant_levels, cond_size, cond_size2) +checkpoint['model_kwargs'] = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate} +model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs']) + +if type(args.initial_checkpoint) != type(None): + checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') + model.load_state_dict(checkpoint['state_dict'], strict=False) + +checkpoint['state_dict'] = model.state_dict() + +if args.train_decoder_only: + if args.initial_checkpoint is None: + print("warning: training decoder only without providing initial checkpoint") + + for p in model.core_encoder.module.parameters(): + p.requires_grad = False + + for p in model.statistical_model.parameters(): + p.requires_grad = False + +# dataloader +checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36) +checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels} +dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs']) +dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4) + + + +# optimizer +params = [p for p in model.parameters() if p.requires_grad] +optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps) + + +# learning rate scheduler +scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x)) + +if __name__ == '__main__': + + # push model to device + model.to(device) + + # training loop + + for epoch in range(1, epochs + 1): + + print(f"training epoch {epoch}...") + + # running stats + running_rate_loss = 0 + running_soft_dist_loss = 0 + running_hard_dist_loss = 0 + running_hard_rate_loss = 0 + running_soft_rate_loss = 0 + running_total_loss = 0 + running_rate_metric = 0 + previous_total_loss = 0 + running_first_frame_loss = 0 + + with tqdm.tqdm(dataloader, unit='batch') as tepoch: + for i, (features, rate_lambda, q_ids) in enumerate(tepoch): + + # zero out gradients + optimizer.zero_grad() + + # push inputs to device + features = features.to(device) + q_ids = q_ids.to(device) + rate_lambda = rate_lambda.to(device) + + + rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1) + + # run model + model_output = model(features, q_ids) + + # collect outputs + z = model_output['z'] + outputs_hard_quant = model_output['outputs_hard_quant'] + outputs_soft_quant = model_output['outputs_soft_quant'] + statistical_model = model_output['statistical_model'] + + # rate loss + hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False) + soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False) + soft_rate_loss = 
torch.mean(torch.sqrt(rate_lambda) * soft_rate) + hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate) + rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss) + hard_rate_metric = torch.mean(hard_rate) + + ## distortion losses + + # hard quantized decoder input + distortion_loss_hard_quant = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_hard_quant: + distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant) + + first_frame_loss = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_hard_quant: + first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant) + + # soft quantized decoder input + distortion_loss_soft_quant = torch.zeros_like(rate_loss) + for dec_features, start, stop in outputs_soft_quant: + distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant) + + # total loss + total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2 + + if args.enable_first_frame_loss: + total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant) + + + total_loss.backward() + + optimizer.step() + + model.clip_weights() + + scheduler.step() + + # collect running stats + running_hard_dist_loss += float(distortion_loss_hard_quant.detach().cpu()) + running_soft_dist_loss += float(distortion_loss_soft_quant.detach().cpu()) + running_rate_loss += float(rate_loss.detach().cpu()) + running_rate_metric += float(hard_rate_metric.detach().cpu()) + running_total_loss += float(total_loss.detach().cpu()) + running_first_frame_loss += float(first_frame_loss.detach().cpu()) + running_soft_rate_loss += float(soft_rate_loss.detach().cpu()) + running_hard_rate_loss += float(hard_rate_loss.detach().cpu()) + + if (i + 1) % log_interval == 0: + current_loss = (running_total_loss - previous_total_loss) / log_interval + tepoch.set_postfix( + current_loss=current_loss, + total_loss=running_total_loss / (i + 1), + dist_hq=running_hard_dist_loss / (i + 1), + dist_sq=running_soft_dist_loss / (i + 1), + rate_loss=running_rate_loss / (i + 1), + rate=running_rate_metric / (i + 1), + ffloss=running_first_frame_loss / (i + 1), + rateloss_hard=running_hard_rate_loss / (i + 1), + rateloss_soft=running_soft_rate_loss / (i + 1) + ) + previous_total_loss = running_total_loss + + # save checkpoint + checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth') + checkpoint['state_dict'] = model.state_dict() + checkpoint['loss'] = running_total_loss / len(dataloader) + checkpoint['epoch'] = epoch + torch.save(checkpoint, checkpoint_path) -- GitLab
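(Illustrative footnote, not part of the patch: each checkpoint saved above stores `model_args`, `model_kwargs`, and `state_dict` next to the training hyper-parameters, so a trained model can be rebuilt without restating the architecture. A minimal reload sketch, assuming the `rdovae` package is importable and using a hypothetical checkpoint path:)

```python
# Reload a checkpoint written by train_rdovae.py; the path below is hypothetical.
import torch

from rdovae import RDOVAE

ckpt = torch.load('output_folder/checkpoints/checkpoint_epoch_100.pth', map_location='cpu')

# the training script stores the constructor arguments next to the weights,
# so the model can be re-instantiated directly from the checkpoint
model = RDOVAE(*ckpt['model_args'], **ckpt['model_kwargs'])
model.load_state_dict(ckpt['state_dict'])
model.eval()
```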