diff --git a/dnn/torch/rdovae/README.md b/dnn/torch/rdovae/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..14359d82d0831ec22953eb40121b1784450f17c4
--- /dev/null
+++ b/dnn/torch/rdovae/README.md
@@ -0,0 +1,24 @@
+# Rate-Distortion-Optimized Variational Auto-Encoder
+
+## Setup
+The python code requires python >= 3.6 and has been tested with python 3.6 and python 3.10. To install requirements run
+```
+python -m pip install -r requirements.txt
+```
+
+## Training
+To generate training data use dump date from the main LPCNet repo
+```
+./dump_data -train 16khz_speech_input.s16 features.f32 data.s16
+```
+
+To train the model, simply run
+```
+python train_rdovae.py features.f32 output_folder
+```
+
+To train on CUDA device add `--cuda-visible-devices idx`.
+
+
+## ToDo
+- Upload checkpoints and add URLs
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b43704451fd9b8675dffe0bec29dbdbea58472
--- /dev/null
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -0,0 +1,256 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
+parser.add_argument('output_dir', type=str, help='output folder')
+parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')
+
+args = parser.parse_args()
+
+import torch
+import numpy as np
+
+from rdovae import RDOVAE
+from wexchange.torch import dump_torch_weights
+from wexchange.c_export import CWriter, print_vector
+
+
+def dump_statistical_model(writer, qembedding):
+    w = qembedding.weight.detach()
+    levels, dim = w.shape
+    N = dim // 6
+
+    print("printing statistical model")
+    quant_scales    = torch.nn.functional.softplus(w[:, : N]).numpy()
+    dead_zone       = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
+    r               = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
+    p0              = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+    p0              = 1 - r ** (0.5 + 0.5 * p0)
+
+    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
+    dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
+    r_q15           = np.round(r * 2**15).astype(np.uint16)
+    p0_q15          = np.round(p0 * 2**15).astype(np.uint16)
+
+    print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
+    print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
+    print_vector(writer.source, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
+
+    writer.header.write(
+f"""
+extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
+extern const opus_uint16 dred_r_q15[{levels * N}];
+extern const opus_uint16 dred_p0_q15[{levels * N}];
+
+"""
+    )
+
+
+def c_export(args, model):
+    
+    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
+    
+    enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
+    dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
+    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
+    constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)
+    
+    # some custom includes
+    for writer in [enc_writer, dec_writer, stats_writer]:
+        writer.header.write(
+f"""
+#include "opus_types.h"
+
+#include "dred_rdovae_constants.h"
+
+#include "nnet.h"
+"""
+        )
+
+    # encoder
+    encoder_dense_layers = [
+        ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH'), 
+        ('core_encoder.module.dense_2'       , 'enc_dense3',   'TANH'),
+        ('core_encoder.module.dense_3'       , 'enc_dense5',   'TANH'),
+        ('core_encoder.module.dense_4'       , 'enc_dense7',   'TANH'),
+        ('core_encoder.module.dense_5'       , 'enc_dense8',   'TANH'),
+        ('core_encoder.module.state_dense_1' , 'gdense1'    ,   'TANH'),
+        ('core_encoder.module.state_dense_2' , 'gdense2'    ,   'TANH')
+    ]
+    
+    for name, export_name, activation in encoder_dense_layers:
+        layer = model.get_submodule(name)
+        dump_torch_weights(enc_writer, layer, name=export_name, activation=activation, verbose=True)
+  
+  
+    encoder_gru_layers = [    
+        ('core_encoder.module.gru_1'         , 'enc_dense2',   'TANH'),
+        ('core_encoder.module.gru_2'         , 'enc_dense4',   'TANH'),
+        ('core_encoder.module.gru_3'         , 'enc_dense6',   'TANH')
+    ]
+ 
+    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers])
+ 
+    
+    encoder_conv_layers = [   
+        ('core_encoder.module.conv1'         , 'bits_dense' ,   'LINEAR') 
+    ]
+    
+    enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_conv_layers])    
+
+    
+    del enc_writer
+    
+    # decoder
+    decoder_dense_layers = [
+        ('core_decoder.module.gru_1_init'    , 'state1',        'TANH'),
+        ('core_decoder.module.gru_2_init'    , 'state2',        'TANH'),
+        ('core_decoder.module.gru_3_init'    , 'state3',        'TANH'),
+        ('core_decoder.module.dense_1'       , 'dec_dense1',    'TANH'),
+        ('core_decoder.module.dense_2'       , 'dec_dense3',    'TANH'),
+        ('core_decoder.module.dense_3'       , 'dec_dense5',    'TANH'),
+        ('core_decoder.module.dense_4'       , 'dec_dense7',    'TANH'),
+        ('core_decoder.module.dense_5'       , 'dec_dense8',    'TANH'),
+        ('core_decoder.module.output'        , 'dec_final',     'LINEAR')
+    ]
+
+    for name, export_name, activation in decoder_dense_layers:
+        layer = model.get_submodule(name)
+        dump_torch_weights(dec_writer, layer, name=export_name, activation=activation, verbose=True)
+        
+
+    decoder_gru_layers = [
+        ('core_decoder.module.gru_1'         , 'dec_dense2',    'TANH'),
+        ('core_decoder.module.gru_2'         , 'dec_dense4',    'TANH'),
+        ('core_decoder.module.gru_3'         , 'dec_dense6',    'TANH')
+    ]
+    
+    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers])
+        
+    del dec_writer
+    
+    # statistical model
+    qembedding = model.statistical_model.quant_embedding
+    dump_statistical_model(stats_writer, qembedding)
+    
+    del stats_writer
+    
+    # constants
+    constants_writer.header.write(
+f"""
+#define DRED_NUM_FEATURES {model.feature_dim}
+
+#define DRED_LATENT_DIM {model.latent_dim}
+
+#define DRED_STATE_DIME {model.state_dim}
+
+#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}
+
+#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
+
+#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
+
+#define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs}
+
+#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}
+
+#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}
+
+"""
+    )
+    
+    del constants_writer
+
+
+def numpy_export(args, model):
+    
+    exchange_name_to_name = {
+        'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
+        'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
+        'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
+        'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
+        'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
+        'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
+        'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
+        'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
+        'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
+        'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
+        'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
+        'statistical_model_embedding'   : 'statistical_model.quant_embedding',
+        'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
+        'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
+        'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
+        'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
+        'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
+        'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
+        'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
+        'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
+        'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
+        'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
+        'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
+        'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
+    }
+    
+    name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()}
+    
+    for name, exchange_name in name_to_exchange_name.items():
+        print(f"printing layer {name}...")
+        dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))
+
+
+if __name__ == "__main__":
+    
+    
+    os.makedirs(args.output_dir, exist_ok=True)
+    
+    
+    # load model from checkpoint
+    checkpoint = torch.load(args.checkpoint, map_location='cpu')
+    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+    missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+    if len(missing_keys) > 0:
+        raise ValueError(f"error: missing keys in state dict")
+
+    if len(unmatched_keys) > 0:
+        print(f"warning: the following keys were unmatched {unmatched_keys}")
+    
+    if args.format == 'C':
+        c_export(args, model)
+    elif args.format == 'numpy':
+        numpy_export(args, model)
+    else:
+        raise ValueError(f'error: unknown export format {args.format}')
\ No newline at end of file
diff --git a/dnn/torch/rdovae/fec_encoder.py b/dnn/torch/rdovae/fec_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..291c0628bbb8d821e87c36a99a6fa53e893093a8
--- /dev/null
+++ b/dnn/torch/rdovae/fec_encoder.py
@@ -0,0 +1,213 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe and Jean-Marc Valin */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import subprocess
+import argparse
+
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')
+
+parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
+parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
+parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
+parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
+parser.add_argument('output', type=str, help='output file (will be extended with .fec)')
+
+parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
+parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
+parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
+parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
+parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')
+
+args = parser.parse_args()
+
+import numpy as np
+from scipy.io import wavfile
+import torch
+
+from rdovae import RDOVAE
+from packets import write_fec_packets
+
+torch.set_num_threads(4)
+
+checkpoint = torch.load(args.checkpoint, map_location="cpu")
+model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+model.load_state_dict(checkpoint['state_dict'], strict=False)
+model.to("cpu")
+
+lpc_order = 16
+
+## prepare input signal
+# SILK frame size is 20ms and LPCNet subframes are 10ms
+subframe_size = 160
+frame_size = 2 * subframe_size
+
+# 91 samples delay to align with SILK decoded frames
+silk_delay = 91
+
+# prepend zeros to have enough history to produce the first package
+zero_history = (args.num_redundancy_frames - 1) * frame_size
+
+# dump data has a (feature) delay of 10ms
+dump_data_delay = 160
+
+total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay
+
+# load signal
+if args.input.endswith('.raw') or args.input.endswith('.pcm'):
+    signal = np.fromfile(args.input, dtype='int16')
+    
+elif args.input.endswith('.wav'):
+    fs, signal = wavfile.read(args.input)
+else:
+    raise ValueError(f'unknown input signal format: {args.input}')
+
+# fill up last frame with zeros
+padded_signal_length = len(signal) + total_delay
+tail = padded_signal_length % frame_size
+right_padding = (frame_size - tail) % frame_size
+    
+signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))
+
+padded_signal_file  = os.path.splitext(args.input)[0] + '_padded.raw'
+signal.tofile(padded_signal_file)
+
+# write signal and call dump_data to create features
+
+feature_file = os.path.splitext(args.input)[0] + '_features.f32'
+command = f"{args.dump_data} -test {padded_signal_file} {feature_file}"
+r = subprocess.run(command, shell=True)
+if r.returncode != 0:
+    raise RuntimeError(f"command '{command}' failed with exit code {r.returncode}")
+
+# load features
+nb_features = model.feature_dim + lpc_order
+nb_used_features = model.feature_dim
+
+# load features
+features = np.fromfile(feature_file, dtype='float32')
+num_subframes = len(features) // nb_features
+num_subframes = 2 * (num_subframes // 2)
+num_frames = num_subframes // 2
+
+features = np.reshape(features, (1, -1, nb_features))
+features = features[:, :, :nb_used_features]
+features = features[:, :num_subframes, :]
+
+# quant_ids in reverse decoding order
+quant_ids = torch.round((args.q1 + (args.q0 - args.q1) * torch.arange(args.num_redundancy_frames // 2) / (args.num_redundancy_frames // 2 - 1))).long()
+
+print(f"using quantization levels {quant_ids}...")
+
+# convert input to torch tensors
+features = torch.from_numpy(features)
+
+
+# run encoder
+print("running fec encoder...")
+with torch.no_grad():
+
+    # encoding
+    z, states, state_size = model.encode(features)
+
+
+    # decoder on packet chunks
+    input_length = args.num_redundancy_frames // 2
+    offset = args.num_redundancy_frames - 1
+
+    packets = []
+    packet_sizes = []
+
+    for i in range(offset, num_frames):
+        print(f"processing frame {i - offset}...")
+        # quantize / unquantize latent vectors
+        zi = torch.clone(z[:, i - 2 * input_length + 2: i + 1 : 2, :])
+        zi, rates = model.quantize(zi, quant_ids)
+        zi = model.unquantize(zi, quant_ids)
+        
+        features = model.decode(zi, states[:, i : i + 1, :])
+        packets.append(features.squeeze(0).numpy())
+        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
+        packet_sizes.append(packet_size)
+
+
+# write packets
+packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
+write_fec_packets(packet_file, packets, packet_sizes)
+
+
+print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")
+
+# assemble features according to loss file
+if args.lossfile != None:
+    num_packets = len(packets)
+    loss = np.loadtxt(args.lossfile, dtype='int16')
+    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')
+    foffset = -2
+    ptr = 0
+    count = 2
+    for i in range(num_packets):
+        if (loss[i] == 0) or (i == num_packets - 1):
+            
+            fec_out[ptr:ptr+count,:] = packets[i][foffset:, :]
+
+            ptr    += count
+            foffset = -2
+            count   = 2
+        else:
+            count   += 2
+            foffset -= 2
+
+    fec_out_full = np.zeros((fec_out.shape[0], 36), dtype=np.float32)
+    fec_out_full[:, : fec_out.shape[-1]] = fec_out
+
+    fec_out_full.tofile(packet_file[:-4] + f'_fec.f32')
+    
+    
+if args.debug_output:
+    import itertools
+
+    batches = [4]
+    offsets = [0, 2 * args.num_redundancy_frames - 4]
+        
+    # sanity checks
+    # 1. concatenate features at offset 0
+    for batch, offset in itertools.product(batches, offsets):
+
+        stop = packets[0].shape[1] - offset
+        test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch//2]], axis=0)
+
+        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
+        test_features_full[:, :nb_used_features] = test_features[:, :]
+
+        print(f"writing debug output {packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32'}")
+        test_features_full.tofile(packet_file[:-4] + f'_torch_batch{batch}_offset{offset}.f32')
+
diff --git a/dnn/torch/rdovae/import_rdovae_weights.py b/dnn/torch/rdovae/import_rdovae_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba05018cb4c1d39f4082c66cc8133f02fb213ba
--- /dev/null
+++ b/dnn/torch/rdovae/import_rdovae_weights.py
@@ -0,0 +1,143 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+import argparse
+
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('exchange_folder', type=str, help='exchange folder path')
+parser.add_argument('output', type=str, help='path to output model checkpoint')
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
+model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
+model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
+model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)
+
+args = parser.parse_args()
+
+import torch
+from rdovae import RDOVAE
+from wexchange.torch import load_torch_weights
+
+exchange_name_to_name = {
+    'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
+    'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
+    'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
+    'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
+    'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
+    'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
+    'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
+    'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
+    'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
+    'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
+    'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
+    'statistical_model_embedding'   : 'statistical_model.quant_embedding',
+    'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
+    'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
+    'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
+    'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
+    'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
+    'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
+    'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
+    'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
+    'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
+    'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
+    'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
+    'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
+}
+
+if __name__ == "__main__":
+    checkpoint = dict()
+
+    # parameters
+    num_features    = args.num_features
+    latent_dim      = args.latent_dim
+    quant_levels    = args.quant_levels
+    cond_size       = args.cond_size
+    cond_size2      = args.cond_size2
+    state_dim       = args.state_dim
+    
+
+    # model
+    checkpoint['model_args']    = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
+    checkpoint['model_kwargs']  = {'state_dim': state_dim}
+    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+    dense_layer_names = [
+        'encoder_stack_layer1_dense',
+        'encoder_stack_layer3_dense',
+        'encoder_stack_layer5_dense',
+        'encoder_stack_layer7_dense',
+        'encoder_stack_layer8_dense',
+        'encoder_state_layer1_dense',
+        'encoder_state_layer2_dense',
+        'decoder_state1_dense',      
+        'decoder_state2_dense',      
+        'decoder_state3_dense',      
+        'decoder_stack_layer1_dense',
+        'decoder_stack_layer3_dense',
+        'decoder_stack_layer5_dense',
+        'decoder_stack_layer7_dense',
+        'decoder_stack_layer8_dense',
+        'decoder_stack_layer9_dense'
+    ]
+
+    gru_layer_names = [
+        'encoder_stack_layer2_gru',
+        'encoder_stack_layer4_gru',
+        'encoder_stack_layer6_gru',
+        'decoder_stack_layer2_gru',
+        'decoder_stack_layer4_gru',
+        'decoder_stack_layer6_gru' 
+    ]
+
+    conv1d_layer_names = [
+        'encoder_stack_layer9_conv'
+    ]
+
+    embedding_layer_names = [
+        'statistical_model_embedding'
+    ]
+
+    for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names:
+        print(f"loading weights for layer {exchange_name_to_name[name]}")
+        layer = model.get_submodule(exchange_name_to_name[name])
+        load_torch_weights(os.path.join(args.exchange_folder, name), layer)
+
+    checkpoint['state_dict'] = model.state_dict()
+
+    torch.save(checkpoint, args.output)
\ No newline at end of file
diff --git a/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl
new file mode 100644
index 0000000000000000000000000000000000000000..cfeebae5bafd31e7e674e73120d45c76ab177d3a
Binary files /dev/null and b/dnn/torch/rdovae/libs/wexchange-1.0-py3-none-any.whl differ
diff --git a/dnn/torch/rdovae/packets/__init__.py b/dnn/torch/rdovae/packets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb71ab3d55691d2cd53137ce9be0708b250edda1
--- /dev/null
+++ b/dnn/torch/rdovae/packets/__init__.py
@@ -0,0 +1 @@
+from .fec_packets import write_fec_packets, read_fec_packets
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.c b/dnn/torch/rdovae/packets/fec_packets.c
new file mode 100644
index 0000000000000000000000000000000000000000..376fb4f169851600ab3c0e1760f47c18e9ac5c8c
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.c
@@ -0,0 +1,142 @@
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#include <stdio.h>
+#include <inttypes.h>
+
+#include "fec_packets.h"
+
+int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index)
+{
+
+    int16_t version;
+    int16_t header_size;
+    int16_t num_packets;
+    int16_t packet_size;
+    int16_t subframe_size;
+    int16_t subframes_per_packet;
+    int16_t num_features;
+    long offset;
+
+    FILE *fid = fopen(filename, "rb");
+    
+    /* read header */
+    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
+    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
+    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
+    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
+    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
+    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
+    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
+
+    /* check if indices are valid */
+    if (packet_index >= num_packets || subframe_index >= subframes_per_packet)
+    {
+        fprintf(stderr, "get_fec_frame: index out of bounds\n");
+        goto error;
+    }
+
+    /* calculate offset in file (+ 2 is for rate) */
+    offset = header_size + packet_index * packet_size + 2 + subframe_index * subframe_size;
+    fseek(fid, offset, SEEK_SET);
+
+    /* read features */
+    if (fread(features, sizeof(*features), num_features, fid) != num_features) goto error;
+
+    fclose(fid);
+    return 0;
+
+error:
+    fclose(fid);
+    return 1;
+}
+
+int get_fec_rate(const char * const filename, int packet_index)
+{
+    int16_t version;
+    int16_t header_size;
+    int16_t num_packets;
+    int16_t packet_size;
+    int16_t subframe_size;
+    int16_t subframes_per_packet;
+    int16_t num_features;
+    long offset;
+    int16_t rate;
+
+    FILE *fid = fopen(filename, "rb");
+    
+    /* read header */
+    if (fread(&version, sizeof(version), 1, fid) != 1) goto error;
+    if (fread(&header_size, sizeof(header_size), 1, fid) != 1) goto error;
+    if (fread(&num_packets, sizeof(num_packets), 1, fid) != 1) goto error;
+    if (fread(&packet_size, sizeof(packet_size), 1, fid) != 1) goto error;
+    if (fread(&subframe_size, sizeof(subframe_size), 1, fid) != 1) goto error;
+    if (fread(&subframes_per_packet, sizeof(subframes_per_packet), 1, fid) != 1) goto error;
+    if (fread(&num_features, sizeof(num_features), 1, fid) != 1) goto error;
+
+    /* check if indices are valid */
+    if (packet_index >= num_packets)
+    {
+        fprintf(stderr, "get_fec_rate: index out of bounds\n");
+        goto error;
+    }
+
+    /* calculate offset in file (+ 2 is for rate) */
+    offset = header_size + packet_index * packet_size;
+    fseek(fid, offset, SEEK_SET);
+
+    /* read rate */
+    if (fread(&rate, sizeof(rate), 1, fid) != 1) goto error;
+
+    fclose(fid);
+    return (int) rate;
+
+error:
+    fclose(fid);
+    return -1;
+}
+
+#if 0
+int main()
+{
+    float features[20];
+    int i;
+
+    if (get_fec_frame("../test.fec", &features[0], 0, 127))
+    {
+        return 1;
+    }
+
+    for (i = 0; i < 20; i ++)
+    {
+        printf("%d %f\n", i, features[i]);
+    }
+
+    printf("rate: %d\n", get_fec_rate("../test.fec", 0));
+
+}
+#endif
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.h b/dnn/torch/rdovae/packets/fec_packets.h
new file mode 100644
index 0000000000000000000000000000000000000000..35d355428a3ddc15ece26659298473a4975c026f
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+#ifndef _FEC_PACKETS_H
+#define _FEC_PACKETS_H
+
+int get_fec_frame(const char * const filename, float *features, int packet_index, int subframe_index);
+int get_fec_rate(const char * const filename, int packet_index);
+
+#endif
\ No newline at end of file
diff --git a/dnn/torch/rdovae/packets/fec_packets.py b/dnn/torch/rdovae/packets/fec_packets.py
new file mode 100644
index 0000000000000000000000000000000000000000..14bed1f8cea0e7cac413b2627f43b3e42e2f7ca5
--- /dev/null
+++ b/dnn/torch/rdovae/packets/fec_packets.py
@@ -0,0 +1,108 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import numpy as np
+
+
+
+def write_fec_packets(filename, packets, rates=None):
+    """ writes packets in binary format """
+    
+    assert np.dtype(np.float32).itemsize == 4
+    assert np.dtype(np.int16).itemsize == 2
+    
+    # derive some sizes 
+    num_packets             = len(packets)
+    subframes_per_packet    = packets[0].shape[-2]
+    num_features            = packets[0].shape[-1]
+    
+    # size of float is 4
+    subframe_size           = num_features * 4
+    packet_size             = subframe_size * subframes_per_packet + 2 # two bytes for rate
+    
+    version = 1
+    # header size (version, header_size, num_packets, packet_size, subframe_size, subrames_per_packet, num_features)
+    header_size = 14
+    
+    with open(filename, 'wb') as f:
+        
+        # header
+        f.write(np.int16(version).tobytes())
+        f.write(np.int16(header_size).tobytes())
+        f.write(np.int16(num_packets).tobytes())
+        f.write(np.int16(packet_size).tobytes())
+        f.write(np.int16(subframe_size).tobytes())
+        f.write(np.int16(subframes_per_packet).tobytes())
+        f.write(np.int16(num_features).tobytes())
+        
+        # packets
+        for i, packet in enumerate(packets):
+            if type(rates) == type(None):
+                rate = 0
+            else:
+                rate = rates[i]
+            
+            f.write(np.int16(rate).tobytes())
+            
+            features = np.flip(packet, axis=-2)
+            f.write(features.astype(np.float32).tobytes())
+            
+        
+def read_fec_packets(filename):
+    """ reads packets from binary format """
+    
+    assert np.dtype(np.float32).itemsize == 4
+    assert np.dtype(np.int16).itemsize == 2
+    
+    with open(filename, 'rb') as f:
+        
+        # header
+        version                 = np.frombuffer(f.read(2), dtype=np.int16).item()
+        header_size             = np.frombuffer(f.read(2), dtype=np.int16).item()
+        num_packets             = np.frombuffer(f.read(2), dtype=np.int16).item()
+        packet_size             = np.frombuffer(f.read(2), dtype=np.int16).item()
+        subframe_size           = np.frombuffer(f.read(2), dtype=np.int16).item()
+        subframes_per_packet    = np.frombuffer(f.read(2), dtype=np.int16).item()
+        num_features            = np.frombuffer(f.read(2), dtype=np.int16).item()
+        
+        dummy_features          = np.zeros((subframes_per_packet, num_features), dtype=np.float32)
+        
+        # packets
+        rates = []
+        packets = []
+        for i in range(num_packets):
+                     
+            rate = np.frombuffer(f.read(2), dtype=np.int16).item
+            rates.append(rate)
+            
+            features = np.reshape(np.frombuffer(f.read(subframe_size * subframes_per_packet), dtype=np.float32), dummy_features.shape)
+            packet = np.flip(features, axis=-2)
+            packets.append(packet)
+            
+    return packets
\ No newline at end of file
diff --git a/dnn/torch/rdovae/rdovae/__init__.py b/dnn/torch/rdovae/rdovae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b945addec72168fa6c7b7535ee2f1c20089a6f15
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/__init__.py
@@ -0,0 +1,2 @@
+from .rdovae import RDOVAE, distortion_loss, hard_rate_estimate, soft_rate_estimate
+from .dataset import RDOVAEDataset
diff --git a/dnn/torch/rdovae/rdovae/dataset.py b/dnn/torch/rdovae/rdovae/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..99630d8b924736e59849d15f88cf8c4c45224685
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/dataset.py
@@ -0,0 +1,68 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import torch
+import numpy as np
+
+class RDOVAEDataset(torch.utils.data.Dataset):
+    def __init__(self,
+                feature_file,
+                sequence_length,
+                num_used_features=20,
+                num_features=36,
+                lambda_min=0.0002,
+                lambda_max=0.0135,
+                quant_levels=16,
+                enc_stride=2):
+        
+        self.sequence_length = sequence_length
+        self.lambda_min = lambda_min
+        self.lambda_max = lambda_max
+        self.enc_stride = enc_stride
+        self.quant_levels = quant_levels
+        self.denominator = (quant_levels - 1) / np.log(lambda_max / lambda_min)
+
+        if sequence_length % enc_stride:
+            raise ValueError(f"RDOVAEDataset.__init__: enc_stride {enc_stride} does not divide sequence length {sequence_length}")
+        
+        self.features = np.reshape(np.fromfile(feature_file, dtype=np.float32), (-1, num_features))
+        self.features = self.features[:, :num_used_features]
+        self.num_sequences = self.features.shape[0] // sequence_length
+
+    def __len__(self):
+        return self.num_sequences
+
+    def __getitem__(self, index):
+        features = self.features[index * self.sequence_length: (index + 1) * self.sequence_length, :]
+        q_ids = np.random.randint(0, self.quant_levels, (1)).astype(np.int64)
+        q_ids = np.repeat(q_ids, self.sequence_length // self.enc_stride, axis=0)
+        rate_lambda = self.lambda_min * np.exp(q_ids.astype(np.float32) / self.denominator).astype(np.float32)
+
+        return features, rate_lambda, q_ids
+
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45d2b8c3b5a329ce863160f814b44ac06023e38
--- /dev/null
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -0,0 +1,614 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+""" Pytorch implementations of rate distortion optimized variational autoencoder """
+
+import math as m
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+# Quantization and rate related utily functions
+
+def soft_pvq(x, k):
+    """ soft pyramid vector quantizer """
+
+    # L2 normalization
+    x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
+    
+
+    with torch.no_grad():
+        # quantization loop, no need to track gradients here
+        x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
+
+        # set initial scaling factor to k
+        scale_factor = k
+        x_scaled = scale_factor * x_norm1
+        x_quant = torch.round(x_scaled)
+
+        # we aim for ||x_quant||_L1 = k
+        for _ in range(10):
+            # remove signs and calculate L1 norm
+            abs_x_quant = torch.abs(x_quant)
+            abs_x_scaled = torch.abs(x_scaled)
+            l1_x_quant = torch.sum(abs_x_quant, axis=-1)
+
+            # increase, where target is too small and decrease, where target is too large
+            plus  = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+            minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
+            factor = torch.where(l1_x_quant > k, minus, plus)
+            factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
+            scale_factor = scale_factor * factor.unsqueeze(-1)
+
+            # update x
+            x_scaled = scale_factor * x_norm1
+            x_quant = torch.round(x_quant)
+
+    # L2 normalization of quantized x
+    x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
+    quantization_error = x_quant_norm2 - x_norm2
+
+    return x_norm2 + quantization_error.detach()
+
+def cache_parameters(func):
+    cache = dict()
+    def cached_func(*args):
+        if args in cache:
+            return cache[args]
+        else:
+            cache[args] = func(*args)
+        
+        return cache[args]
+    return cached_func
+        
+@cache_parameters
+def pvq_codebook_size(n, k):
+    
+    if k == 0:
+        return 1
+    
+    if n == 0:
+        return 0
+    
+    return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
+
+
+def soft_rate_estimate(z, r, reduce=True):
+    """ rate approximation with dependent theta Eq. (7)"""
+
+    rate = torch.sum(
+        - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
+        dim=-1
+    )
+
+    if reduce:
+        rate = torch.mean(rate)
+
+    return rate
+
+
+def hard_rate_estimate(z, r, theta, reduce=True):
+    """ hard rate approximation """
+
+    z_q = torch.round(z)
+    p0 = 1 - r ** (0.5 + 0.5 * theta)
+    alpha = torch.relu(1 - torch.abs(z_q)) ** 2
+    rate = - torch.sum(
+        (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6) 
+        + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
+        dim=-1
+    )
+
+    if reduce:
+        rate = torch.mean(rate)
+
+    return rate
+
+
+
+def soft_dead_zone(x, dead_zone):
+    """ approximates application of a dead zone to x """
+    d = dead_zone * 0.05
+    return x - d * torch.tanh(x / (0.1 + d))
+
+
+def hard_quantize(x):
+    """ round with copy gradient trick """
+    return x + (torch.round(x) - x).detach()
+
+
+def noise_quantize(x):
+    """ simulates quantization with addition of random uniform noise """
+    return x + (torch.rand_like(x) - 0.5)
+
+
+# loss functions
+
+
+def distortion_loss(y_true, y_pred, rate_lambda=None):
+    """ custom distortion loss for LPCNet features """
+    
+    if y_true.size(-1) != 20:
+        raise ValueError('distortion loss is designed to work with 20 features')
+
+    ceps_error   = y_pred[..., :18] - y_true[..., :18]
+    pitch_error  = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
+    corr_error   = y_pred[..., 19:] - y_true[..., 19:]
+    pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2
+
+    loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)
+
+    if type(rate_lambda) != type(None):
+        loss = loss / torch.sqrt(rate_lambda)
+
+    loss = torch.mean(loss)
+        
+    return loss
+
+
+# sampling functions
+
+import random
+
+
+def random_split(start, stop, num_splits=3, min_len=3):
+    get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
+    candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
+    
+    while get_min_len(candidate) < min_len: 
+        candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
+    
+    return candidate
+
+
+
+# weight initialization and clipping
+def init_weights(module):
+    
+    if isinstance(module, nn.GRU):
+        for p in module.named_parameters():
+            if p[0].startswith('weight_hh_'):
+                nn.init.orthogonal_(p[1])
+
+    
+def weight_clip_factory(max_value):
+    """ weight clipping function concerning sum of abs values of adjecent weights """
+    def clip_weight_(w):
+        stop = w.size(1)
+        # omit last column if stop is odd
+        if stop % 2:
+            stop -= 1
+        max_values = max_value * torch.ones_like(w[:, :stop])
+        factor = max_value / torch.maximum(max_values,
+                                 torch.repeat_interleave(
+                                     torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
+                                     2,
+                                     1))
+        with torch.no_grad():
+            w[:, :stop] *= factor
+    
+    def clip_weights(module):
+        if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
+            for name, w in module.named_parameters():
+                if name.startswith('weight'):
+                    clip_weight_(w)
+    
+    return clip_weights
+
+# RDOVAE module and submodules
+
+
+class CoreEncoder(nn.Module):
+    STATE_HIDDEN = 128
+    FRAMES_PER_STEP = 2
+    CONV_KERNEL_SIZE = 4
+    
+    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
+        """ core encoder for RDOVAE
+        
+            Computes latents, initial states, and rate estimates from features and lambda parameter
+        
+        """
+
+        super(CoreEncoder, self).__init__()
+
+        # hyper parameters
+        self.feature_dim        = feature_dim
+        self.output_dim         = output_dim
+        self.cond_size          = cond_size
+        self.cond_size2         = cond_size2
+        self.state_size         = state_size
+
+        # derived parameters
+        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
+        self.conv_input_channels =  5 * cond_size + 3 * cond_size2
+
+        # layers
+        self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
+        self.gru_1   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+        self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
+        self.gru_2   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+        self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
+        self.gru_3   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
+        self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
+        self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
+        self.conv1   = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
+
+        self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
+
+        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
+
+        # initialize weights
+        self.apply(init_weights)
+
+
+    def forward(self, features):
+
+        # reshape features
+        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))
+
+        batch = x.size(0)
+        device = x.device
+
+        # run encoding layer stack
+        x1      = torch.tanh(self.dense_1(x))
+        x2, _   = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
+        x3      = torch.tanh(self.dense_2(x2))
+        x4, _   = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
+        x5      = torch.tanh(self.dense_3(x4))
+        x6, _   = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
+        x7      = torch.tanh(self.dense_4(x6))
+        x8      = torch.tanh(self.dense_5(x7))
+
+        # concatenation of all hidden layer outputs
+        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+        
+        # init state for decoder
+        states = torch.tanh(self.state_dense_1(x9))
+        states = torch.tanh(self.state_dense_2(states))
+
+        # latent representation via convolution
+        x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
+        z = self.conv1(x9).permute(0, 2, 1)
+
+        return z, states
+
+
+
+
+class CoreDecoder(nn.Module):
+
+    FRAMES_PER_STEP = 4
+
+    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
+        """ core decoder for RDOVAE
+        
+            Computes features from latents, initial state, and quantization index
+        
+        """
+
+        super(CoreDecoder, self).__init__()
+
+        # hyper parameters
+        self.input_dim  = input_dim
+        self.output_dim = output_dim
+        self.cond_size  = cond_size
+        self.cond_size2 = cond_size2
+        self.state_size = state_size
+
+        self.input_size = self.input_dim
+        
+        self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
+
+        # layers
+        self.dense_1    = nn.Linear(self.input_size, cond_size2)
+        self.gru_1      = nn.GRU(cond_size2, cond_size, batch_first=True)
+        self.dense_2    = nn.Linear(cond_size, cond_size2)
+        self.gru_2      = nn.GRU(cond_size2, cond_size, batch_first=True)
+        self.dense_3    = nn.Linear(cond_size, cond_size2)
+        self.gru_3      = nn.GRU(cond_size2, cond_size, batch_first=True)
+        self.dense_4    = nn.Linear(cond_size, cond_size2)
+        self.dense_5    = nn.Linear(cond_size2, cond_size2)
+
+        self.output  = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
+
+
+        self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
+        self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
+        self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
+
+        # initialize weights
+        self.apply(init_weights)
+
+    def forward(self, z, initial_state):
+        
+        gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
+        gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
+        gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
+
+        # run decoding layer stack
+        x1  = torch.tanh(self.dense_1(z))
+        x2, _ = self.gru_1(x1, gru_1_state)
+        x3  = torch.tanh(self.dense_2(x2))
+        x4, _ = self.gru_2(x3, gru_2_state)
+        x5  = torch.tanh(self.dense_3(x4))
+        x6, _ = self.gru_3(x5, gru_3_state)
+        x7  = torch.tanh(self.dense_4(x6))
+        x8  = torch.tanh(self.dense_5(x7))
+        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+
+        # output layer and reshaping
+        x10 = self.output(x9)
+        features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
+
+        return features
+
+
+class StatisticalModel(nn.Module):
+    def __init__(self, quant_levels, latent_dim):
+        """ Statistical model for latent space
+        
+            Computes scaling, deadzone, r, and theta 
+        
+        """
+
+        super(StatisticalModel, self).__init__()
+
+        # copy parameters
+        self.latent_dim     = latent_dim
+        self.quant_levels   = quant_levels
+        self.embedding_dim  = 6 * latent_dim
+
+        # quantization embedding
+        self.quant_embedding    = nn.Embedding(quant_levels, self.embedding_dim)
+        
+        # initialize embedding to 0
+        with torch.no_grad():
+            self.quant_embedding.weight[:] = 0
+
+
+    def forward(self, quant_ids):
+        """ takes quant_ids and returns statistical model parameters"""
+
+        x = self.quant_embedding(quant_ids)
+
+        # CAVE: theta_soft is not used anymore. Kick it out?
+        quant_scale = F.softplus(x[..., 0 * self.latent_dim : 1 * self.latent_dim])
+        dead_zone   = F.softplus(x[..., 1 * self.latent_dim : 2 * self.latent_dim])
+        theta_soft  = torch.sigmoid(x[..., 2 * self.latent_dim : 3 * self.latent_dim])
+        r_soft      = torch.sigmoid(x[..., 3 * self.latent_dim : 4 * self.latent_dim])
+        theta_hard  = torch.sigmoid(x[..., 4 * self.latent_dim : 5 * self.latent_dim])
+        r_hard      = torch.sigmoid(x[..., 5 * self.latent_dim : 6 * self.latent_dim])
+        
+
+        return {
+            'quant_embedding'   : x,
+            'quant_scale'       : quant_scale,
+            'dead_zone'         : dead_zone,
+            'r_hard'            : r_hard,
+            'theta_hard'        : theta_hard,
+            'r_soft'            : r_soft,
+            'theta_soft'        : theta_soft
+        }
+
+
+class RDOVAE(nn.Module):
+    def __init__(self,
+                 feature_dim,
+                 latent_dim,
+                 quant_levels,
+                 cond_size,
+                 cond_size2,
+                 state_dim=24,
+                 split_mode='split',
+                 clip_weights=True,
+                 pvq_num_pulses=82,
+                 state_dropout_rate=0):
+
+        super(RDOVAE, self).__init__()
+
+        self.feature_dim    = feature_dim
+        self.latent_dim     = latent_dim
+        self.quant_levels   = quant_levels
+        self.cond_size      = cond_size
+        self.cond_size2     = cond_size2
+        self.split_mode     = split_mode
+        self.state_dim      = state_dim
+        self.pvq_num_pulses = pvq_num_pulses
+        self.state_dropout_rate = state_dropout_rate
+        
+        # submodules encoder and decoder share the statistical model
+        self.statistical_model = StatisticalModel(quant_levels, latent_dim)
+        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
+        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
+        
+        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
+        self.dec_stride = CoreDecoder.FRAMES_PER_STEP
+       
+        if clip_weights:
+            self.weight_clip_fn = weight_clip_factory(0.496)
+        else:
+            self.weight_clip_fn = None
+        
+        if self.dec_stride % self.enc_stride != 0:
+            raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride")
+    
+    def clip_weights(self):
+        if not type(self.weight_clip_fn) == type(None):
+            self.apply(self.weight_clip_fn)
+            
+    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
+        
+        enc_stride = self.enc_stride
+        dec_stride = self.dec_stride
+
+        stride = dec_stride // enc_stride
+        
+        chunks = []
+
+        for offset in range(stride):
+            # start is the smalles number = offset mod stride that decodes to a valid range
+            start = offset
+            while enc_stride * (start + 1) - dec_stride < 0:
+                start += stride
+
+            # check if start is a valid index
+            if start >= z_frames:
+                raise ValueError("get_decoder_chunks_generic: range too small")
+
+            # stop is the smallest number outside [0, num_enc_frames] that's congruent to offset mod stride
+            stop = z_frames - (z_frames % stride) + offset
+            while stop < z_frames:
+                stop += stride
+
+            # calculate split points
+            length = (stop - start)
+            if mode == 'split':
+                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
+            elif mode == 'random_split':
+                split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
+            else:
+                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")
+
+
+            for i in range(chunks_per_offset):
+                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
+                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
+                # provided that i - j = 1 mod stride
+                chunks.append({
+                    'z_start'         : split_points[i],
+                    'z_stop'          : split_points[i + 1] - stride + 1,
+                    'z_stride'        : stride,
+                    'features_start'  : enc_stride * (split_points[i] + 1) - dec_stride,
+                    'features_stop'   : enc_stride * (split_points[i + 1] - stride + 1)
+                })
+
+        return chunks
+
+
+    def forward(self, features, q_id):
+
+        # calculate statistical model from quantization ID
+        statistical_model = self.statistical_model(q_id)
+
+        # run encoder
+        z, states = self.core_encoder(features)
+
+        # scaling, dead-zone and quantization
+        z = z * statistical_model['quant_scale']
+        z = soft_dead_zone(z, statistical_model['dead_zone'])
+
+        # quantization
+        z_q = hard_quantize(z) / statistical_model['quant_scale']
+        z_n = noise_quantize(z) / statistical_model['quant_scale']
+        states_q = soft_pvq(states, self.pvq_num_pulses)
+        
+        if self.state_dropout_rate > 0:
+            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
+            mask = torch.ones_like(states_q)
+            mask[drop] = 0
+            states_q = states_q * mask
+
+        # decoder
+        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)
+
+        outputs_hq = []
+        outputs_sq = []
+        for chunk in chunks:
+            # decoder with hard quantized input
+            z_dec_reverse       = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
+            dec_initial_state   = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
+            features_reverse = self.core_decoder(z_dec_reverse,  dec_initial_state)
+            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
+
+
+            # decoder with soft quantized input
+            z_dec_reverse       = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :],  [1])
+            features_reverse    = self.core_decoder(z_dec_reverse, dec_initial_state)
+            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))          
+
+        return {
+            'outputs_hard_quant' : outputs_hq,
+            'outputs_soft_quant' : outputs_sq,
+            'z'                 : z,
+            'statistical_model' : statistical_model
+        }
+
+    def encode(self, features):
+        """ encoder with quantization and rate estimation """
+        
+        z, states = self.core_encoder(features)
+        
+        # quantization of initial states
+        states = soft_pvq(states, self.pvq_num_pulses)     
+        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))
+        
+        return z, states, state_size
+
+    def decode(self, z, initial_state):
+        """ decoder (flips sequences by itself) """
+        
+        z_reverse       = torch.flip(z, [1])
+        features_reverse = self.core_decoder(z_reverse, initial_state)
+        features = torch.flip(features_reverse, [1])
+        
+        return features
+        
+    def quantize(self, z, q_ids):
+        """ quantization of latent vectors """
+
+        stats = self.statistical_model(q_ids)
+
+        zq = z * stats['quant_scale']
+        zq = soft_dead_zone(zq, stats['dead_zone'])
+        zq = torch.round(zq)
+
+        sizes = hard_rate_estimate(zq, stats['r_hard'], stats['theta_hard'], reduce=False)
+
+        return zq, sizes
+
+    def unquantize(self, zq, q_ids):
+        """ re-scaling of latent vector """
+
+        stats = self.statistical_model(q_ids)
+
+        z = zq / stats['quant_scale']
+
+        return z
+    
+    def freeze_model(self):
+
+        # freeze all parameters
+        for p in self.parameters():
+            p.requires_grad = False
+        
+        for p in self.statistical_model.parameters():
+            p.requires_grad = True
+
diff --git a/dnn/torch/rdovae/requirements.txt b/dnn/torch/rdovae/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..668c8462855e77990cace86649bd28637a8eb652
--- /dev/null
+++ b/dnn/torch/rdovae/requirements.txt
@@ -0,0 +1,5 @@
+numpy
+scipy
+torch
+tqdm
+libs/wexchange-1.0-py3-none-any.whl
\ No newline at end of file
diff --git a/dnn/torch/rdovae/train_rdovae.py b/dnn/torch/rdovae/train_rdovae.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ccf2eb0e3e01f20c9d51fdb44398db7ffaf5a9
--- /dev/null
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -0,0 +1,270 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+import os
+import argparse
+
+import torch
+import tqdm
+
+from rdovae import RDOVAE, RDOVAEDataset, distortion_loss, hard_rate_estimate, soft_rate_estimate
+
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('features', type=str, help='path to feature file in .f32 format')
+parser.add_argument('output', type=str, help='path to output folder')
+
+parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: ''", default="")
+
+
+model_group = parser.add_argument_group(title="model parameters")
+model_group.add_argument('--latent-dim', type=int, help="number of symbols produces by encoder, default: 80", default=80)
+model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
+model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
+model_group.add_argument('--state-dim', type=int, help="dimensionality of transfered state, default: 24", default=24)
+model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 16", default=16)
+model_group.add_argument('--lambda-min', type=float, help="minimal value for rate lambda, default: 0.0002", default=2e-4)
+model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
+model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
+model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+
+training_group = parser.add_argument_group(title="training parameters")
+training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
+training_group.add_argument('--lr', type=float, help='learning rate, default: 3e-4', default=3e-4)
+training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 100', default=100)
+training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by 4, default: 256', default=256)
+training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
+training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
+training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
+training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
+training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
+
+args = parser.parse_args()
+
+# set visible devices
+os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
+
+# checkpoints
+checkpoint_dir = os.path.join(args.output, 'checkpoints')
+checkpoint = dict()
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# training parameters
+batch_size = args.batch_size
+lr = args.lr
+epochs = args.epochs
+sequence_length = args.sequence_length
+lr_decay_factor = args.lr_decay_factor
+split_mode = args.split_mode
+# not exposed
+adam_betas = [0.9, 0.99]
+adam_eps = 1e-8
+
+checkpoint['batch_size'] = batch_size
+checkpoint['lr'] = lr
+checkpoint['lr_decay_factor'] = lr_decay_factor 
+checkpoint['split_mode'] = split_mode
+checkpoint['epochs'] = epochs
+checkpoint['sequence_length'] = sequence_length
+checkpoint['adam_betas'] = adam_betas
+
+# logging
+log_interval = 10
+
+# device
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+# model parameters
+cond_size  = args.cond_size
+cond_size2 = args.cond_size2
+latent_dim = args.latent_dim
+quant_levels = args.quant_levels
+lambda_min = args.lambda_min
+lambda_max = args.lambda_max
+state_dim = args.state_dim
+# not expsed
+num_features = 20
+
+
+# training data
+feature_file = args.features
+
+# model
+checkpoint['model_args']    = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
+checkpoint['model_kwargs']  = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
+model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
+
+if type(args.initial_checkpoint) != type(None):
+    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+    model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+checkpoint['state_dict']    = model.state_dict()
+
+if args.train_decoder_only:
+    if args.initial_checkpoint is None:
+        print("warning: training decoder only without providing initial checkpoint")
+        
+    for p in model.core_encoder.module.parameters():
+        p.requires_grad = False
+        
+    for p in model.statistical_model.parameters():
+        p.requires_grad = False
+
+# dataloader
+checkpoint['dataset_args'] = (feature_file, sequence_length, num_features, 36)
+checkpoint['dataset_kwargs'] = {'lambda_min': lambda_min, 'lambda_max': lambda_max, 'enc_stride': model.enc_stride, 'quant_levels': quant_levels}
+dataset = RDOVAEDataset(*checkpoint['dataset_args'], **checkpoint['dataset_kwargs'])
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
+
+
+
+# optimizer
+params = [p for p in model.parameters() if p.requires_grad]
+optimizer = torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps)
+
+
+# learning rate scheduler
+scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay_factor * x))
+
+if __name__ == '__main__':
+
+    # push model to device
+    model.to(device)
+
+    # training loop
+
+    for epoch in range(1, epochs + 1):
+
+        print(f"training epoch {epoch}...")
+
+        # running stats
+        running_rate_loss       = 0
+        running_soft_dist_loss  = 0
+        running_hard_dist_loss  = 0
+        running_hard_rate_loss  = 0
+        running_soft_rate_loss  = 0
+        running_total_loss      = 0
+        running_rate_metric     = 0
+        previous_total_loss     = 0
+        running_first_frame_loss = 0
+
+        with tqdm.tqdm(dataloader, unit='batch') as tepoch:
+            for i, (features, rate_lambda, q_ids) in enumerate(tepoch):
+
+                # zero out gradients
+                optimizer.zero_grad()
+                
+                # push inputs to device
+                features    = features.to(device)
+                q_ids       = q_ids.to(device)
+                rate_lambda = rate_lambda.to(device)
+
+                
+                rate_lambda_upsamp = torch.repeat_interleave(rate_lambda, 2, 1)
+                
+                # run model
+                model_output = model(features, q_ids)
+
+                # collect outputs
+                z                   = model_output['z']
+                outputs_hard_quant  = model_output['outputs_hard_quant']
+                outputs_soft_quant  = model_output['outputs_soft_quant']
+                statistical_model   = model_output['statistical_model']
+
+                # rate loss
+                hard_rate = hard_rate_estimate(z, statistical_model['r_hard'], statistical_model['theta_hard'], reduce=False)
+                soft_rate = soft_rate_estimate(z, statistical_model['r_soft'], reduce=False)
+                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * soft_rate)
+                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * hard_rate)
+                rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
+                hard_rate_metric = torch.mean(hard_rate)
+
+                ## distortion losses
+
+                # hard quantized decoder input
+                distortion_loss_hard_quant = torch.zeros_like(rate_loss)
+                for dec_features, start, stop in outputs_hard_quant:
+                    distortion_loss_hard_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_hard_quant)
+
+                first_frame_loss = torch.zeros_like(rate_loss)
+                for dec_features, start, stop in outputs_hard_quant:
+                    first_frame_loss += distortion_loss(features[..., stop-4 : stop, :], dec_features[..., -4:, :], rate_lambda_upsamp[..., stop - 4 : stop]) / len(outputs_hard_quant)
+
+                # soft quantized decoder input
+                distortion_loss_soft_quant = torch.zeros_like(rate_loss)
+                for dec_features, start, stop in outputs_soft_quant:
+                    distortion_loss_soft_quant += distortion_loss(features[..., start : stop, :], dec_features, rate_lambda_upsamp[..., start : stop]) / len(outputs_soft_quant)
+
+                # total loss
+                total_loss = rate_loss + (distortion_loss_hard_quant + distortion_loss_soft_quant) / 2
+                
+                if args.enable_first_frame_loss:
+                    total_loss = total_loss + 0.5 * torch.relu(first_frame_loss - distortion_loss_hard_quant)
+                
+
+                total_loss.backward()
+                
+                optimizer.step()
+                
+                model.clip_weights()
+                
+                scheduler.step()
+
+                # collect running stats
+                running_hard_dist_loss  += float(distortion_loss_hard_quant.detach().cpu())
+                running_soft_dist_loss  += float(distortion_loss_soft_quant.detach().cpu())
+                running_rate_loss       += float(rate_loss.detach().cpu())
+                running_rate_metric     += float(hard_rate_metric.detach().cpu())
+                running_total_loss      += float(total_loss.detach().cpu())
+                running_first_frame_loss += float(first_frame_loss.detach().cpu())
+                running_soft_rate_loss += float(soft_rate_loss.detach().cpu())
+                running_hard_rate_loss += float(hard_rate_loss.detach().cpu())
+
+                if (i + 1) % log_interval == 0:
+                    current_loss = (running_total_loss - previous_total_loss) / log_interval
+                    tepoch.set_postfix(
+                        current_loss=current_loss,
+                        total_loss=running_total_loss / (i + 1),
+                        dist_hq=running_hard_dist_loss / (i + 1),
+                        dist_sq=running_soft_dist_loss / (i + 1),
+                        rate_loss=running_rate_loss / (i + 1),
+                        rate=running_rate_metric / (i + 1),
+                        ffloss=running_first_frame_loss / (i + 1),
+                        rateloss_hard=running_hard_rate_loss / (i + 1),
+                        rateloss_soft=running_soft_rate_loss / (i + 1)
+                    )
+                    previous_total_loss = running_total_loss
+
+        # save checkpoint
+        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
+        checkpoint['state_dict'] = model.state_dict()
+        checkpoint['loss'] = running_total_loss / len(dataloader)
+        checkpoint['epoch'] = epoch
+        torch.save(checkpoint, checkpoint_path)