From 27663d364188711e5f662304357cb532e689bfe2 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Thu, 21 Sep 2023 12:20:11 -0400
Subject: [PATCH] Using a DenseNet for DRED

---
 dnn/dred_rdovae_dec.c                     | 100 +++++++++------
 dnn/dred_rdovae_dec.h                     |  14 ++-
 dnn/dred_rdovae_enc.c                     |  90 ++++++++------
 dnn/dred_rdovae_enc.h                     |  15 ++-
 dnn/nnet.c                                |  19 +++
 dnn/nnet.h                                |   1 +
 dnn/torch/rdovae/export_rdovae_weights.py |  54 ++++----
 dnn/torch/rdovae/rdovae/rdovae.py         | 143 +++++++++++++---------
 8 files changed, 268 insertions(+), 168 deletions(-)

diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c
index c79723b35..1fef375ad 100644
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -33,16 +33,36 @@
 #include "dred_rdovae_constants.h"
 #include "os_support.h"
 
+static void conv1_cond_init(float *mem, int len, int dilation, int *init)
+{
+    if (!*init) {
+        int i;
+        for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len);
+    }
+    *init = 1;
+}
+
 void dred_rdovae_dec_init_states(
     RDOVAEDecState *h,            /* io: state buffer handle */
     const RDOVAEDec *model,
     const float *initial_state  /* i: initial state */
     )
 {
-    /* initialize GRU states from initial state */
-    compute_generic_dense(&model->state1, h->dense2_state, initial_state, ACTIVATION_TANH);
-    compute_generic_dense(&model->state2, h->dense4_state, initial_state, ACTIVATION_TANH);
-    compute_generic_dense(&model->state3, h->dense6_state, initial_state, ACTIVATION_TANH);
+    float hidden[DEC_HIDDEN_INIT_OUT_SIZE];
+    float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE];
+    int counter=0;
+    compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH);
+    compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH);
+    OPUS_COPY(h->gru1_state, state_init, DEC_GRU1_STATE_SIZE);
+    counter += DEC_GRU1_STATE_SIZE;
+    OPUS_COPY(h->gru2_state, &state_init[counter], DEC_GRU2_STATE_SIZE);
+    counter += DEC_GRU2_STATE_SIZE;
+    OPUS_COPY(h->gru3_state, &state_init[counter], DEC_GRU3_STATE_SIZE);
+    counter += DEC_GRU3_STATE_SIZE;
+    OPUS_COPY(h->gru4_state, &state_init[counter], DEC_GRU4_STATE_SIZE);
+    counter += DEC_GRU4_STATE_SIZE;
+    OPUS_COPY(h->gru5_state, &state_init[counter], DEC_GRU5_STATE_SIZE);
+    h->initialized = 0;
 }
 
 
@@ -53,44 +73,48 @@ void dred_rdovae_decode_qframe(
     const float *input          /* i: latent vector */
     )
 {
-    float buffer[DEC_DENSE1_OUT_SIZE + DEC_DENSE2_OUT_SIZE + DEC_DENSE3_OUT_SIZE + DEC_DENSE4_OUT_SIZE + DEC_DENSE5_OUT_SIZE + DEC_DENSE6_OUT_SIZE + DEC_DENSE7_OUT_SIZE + DEC_DENSE8_OUT_SIZE];
+    float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE
+                 + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE];
     int output_index = 0;
-    int input_index = 0;
 
     /* run encoder stack and concatenate output in buffer*/
     compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH);
-    input_index = output_index;
     output_index += DEC_DENSE1_OUT_SIZE;
 
-    compute_generic_gru(&model->dec_dense2_input, &model->dec_dense2_recurrent, dec_state->dense2_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], dec_state->dense2_state, DEC_DENSE2_OUT_SIZE);
-    input_index = output_index;
-    output_index += DEC_DENSE2_OUT_SIZE;
-
-    compute_generic_dense(&model->dec_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += DEC_DENSE3_OUT_SIZE;
-
-    compute_generic_gru(&model->dec_dense4_input, &model->dec_dense4_recurrent, dec_state->dense4_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], dec_state->dense4_state, DEC_DENSE4_OUT_SIZE);
-    input_index = output_index;
-    output_index += DEC_DENSE4_OUT_SIZE;
-
-    compute_generic_dense(&model->dec_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += DEC_DENSE5_OUT_SIZE;
-
-    compute_generic_gru(&model->dec_dense6_input, &model->dec_dense6_recurrent, dec_state->dense6_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], dec_state->dense6_state, DEC_DENSE6_OUT_SIZE);
-    input_index = output_index;
-    output_index += DEC_DENSE6_OUT_SIZE;
-
-    compute_generic_dense(&model->dec_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += DEC_DENSE7_OUT_SIZE;
-
-    compute_generic_dense(&model->dec_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    output_index += DEC_DENSE8_OUT_SIZE;
-
-    compute_generic_dense(&model->dec_final, qframe, buffer, ACTIVATION_LINEAR);
+    compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
+    OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+    output_index += DEC_GRU1_OUT_SIZE;
+    conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
+    compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += DEC_CONV1_OUT_SIZE;
+
+    compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
+    OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+    output_index += DEC_GRU2_OUT_SIZE;
+    conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
+    compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += DEC_CONV2_OUT_SIZE;
+
+    compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
+    OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+    output_index += DEC_GRU3_OUT_SIZE;
+    conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
+    compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += DEC_CONV3_OUT_SIZE;
+
+    compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
+    OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+    output_index += DEC_GRU4_OUT_SIZE;
+    conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
+    compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += DEC_CONV4_OUT_SIZE;
+
+    compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
+    OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+    output_index += DEC_GRU5_OUT_SIZE;
+    conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
+    compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += DEC_CONV5_OUT_SIZE;
+
+    compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR);
 }
diff --git a/dnn/dred_rdovae_dec.h b/dnn/dred_rdovae_dec.h
index 008551b5d..4e039cf27 100644
--- a/dnn/dred_rdovae_dec.h
+++ b/dnn/dred_rdovae_dec.h
@@ -33,9 +33,17 @@
 #include "dred_rdovae_stats_data.h"
 
 struct RDOVAEDecStruct {
-    float dense2_state[DEC_DENSE2_STATE_SIZE];
-    float dense4_state[DEC_DENSE2_STATE_SIZE];
-    float dense6_state[DEC_DENSE2_STATE_SIZE];
+  int initialized;
+  float gru1_state[DEC_GRU1_STATE_SIZE];
+  float gru2_state[DEC_GRU2_STATE_SIZE];
+  float gru3_state[DEC_GRU3_STATE_SIZE];
+  float gru4_state[DEC_GRU4_STATE_SIZE];
+  float gru5_state[DEC_GRU5_STATE_SIZE];
+  float conv1_state[DEC_CONV1_STATE_SIZE];
+  float conv2_state[DEC_CONV2_STATE_SIZE];
+  float conv3_state[DEC_CONV3_STATE_SIZE];
+  float conv4_state[DEC_CONV4_STATE_SIZE];
+  float conv5_state[DEC_CONV5_STATE_SIZE];
 };
 
 void dred_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state);
diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c
index 9361af17b..98ffba8c8 100644
--- a/dnn/dred_rdovae_enc.c
+++ b/dnn/dred_rdovae_enc.c
@@ -35,6 +35,15 @@
 #include "dred_rdovae_enc.h"
 #include "os_support.h"
 
+static void conv1_cond_init(float *mem, int len, int dilation, int *init)
+{
+    if (!*init) {
+        int i;
+        for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len);
+    }
+    *init = 1;
+}
+
 void dred_rdovae_encode_dframe(
     RDOVAEEncState *enc_state,           /* io: encoder state */
     const RDOVAEEnc *model,
@@ -43,52 +52,53 @@ void dred_rdovae_encode_dframe(
     const float *input              /* i: double feature frame (concatenated) */
     )
 {
-    float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE];
+    float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
+               + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
+    float state_hidden[GDENSE1_OUT_SIZE];
     int output_index = 0;
-    int input_index = 0;
 
     /* run encoder stack and concatenate output in buffer*/
     compute_generic_dense(&model->enc_dense1, &buffer[output_index], input, ACTIVATION_TANH);
-    input_index = output_index;
     output_index += ENC_DENSE1_OUT_SIZE;
 
-    compute_generic_gru(&model->enc_dense2_input, &model->enc_dense2_recurrent, enc_state->dense2_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE);
-    input_index = output_index;
-    output_index += ENC_DENSE2_OUT_SIZE;
-
-    compute_generic_dense(&model->enc_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += ENC_DENSE3_OUT_SIZE;
-
-    compute_generic_gru(&model->enc_dense4_input, &model->enc_dense4_recurrent, enc_state->dense4_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE);
-    input_index = output_index;
-    output_index += ENC_DENSE4_OUT_SIZE;
-
-    compute_generic_dense(&model->enc_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += ENC_DENSE5_OUT_SIZE;
-
-    compute_generic_gru(&model->enc_dense6_input, &model->enc_dense6_recurrent, enc_state->dense6_state, &buffer[input_index]);
-    OPUS_COPY(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE);
-    input_index = output_index;
-    output_index += ENC_DENSE6_OUT_SIZE;
-
-    compute_generic_dense(&model->enc_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    input_index = output_index;
-    output_index += ENC_DENSE7_OUT_SIZE;
-
-    compute_generic_dense(&model->enc_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH);
-    output_index += ENC_DENSE8_OUT_SIZE;
-
-    /* compute latents from concatenated input buffer */
-    compute_generic_conv1d(&model->bits_dense, latents, enc_state->bits_dense_state, buffer, BITS_DENSE_IN_SIZE, ACTIVATION_LINEAR);
-
+    compute_generic_gru(&model->enc_gru1_input, &model->enc_gru1_recurrent, enc_state->gru1_state, buffer);
+    OPUS_COPY(&buffer[output_index], enc_state->gru1_state, ENC_GRU1_OUT_SIZE);
+    output_index += ENC_GRU1_OUT_SIZE;
+    conv1_cond_init(enc_state->conv1_state, output_index, 1, &enc_state->initialized);
+    compute_generic_conv1d(&model->enc_conv1, &buffer[output_index], enc_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
+    output_index += ENC_CONV1_OUT_SIZE;
+
+    compute_generic_gru(&model->enc_gru2_input, &model->enc_gru2_recurrent, enc_state->gru2_state, buffer);
+    OPUS_COPY(&buffer[output_index], enc_state->gru2_state, ENC_GRU2_OUT_SIZE);
+    output_index += ENC_GRU2_OUT_SIZE;
+    conv1_cond_init(enc_state->conv2_state, output_index, 2, &enc_state->initialized);
+    compute_generic_conv1d_dilation(&model->enc_conv2, &buffer[output_index], enc_state->conv2_state, buffer, output_index, 2, ACTIVATION_TANH);
+    output_index += ENC_CONV2_OUT_SIZE;
+
+    compute_generic_gru(&model->enc_gru3_input, &model->enc_gru3_recurrent, enc_state->gru3_state, buffer);
+    OPUS_COPY(&buffer[output_index], enc_state->gru3_state, ENC_GRU3_OUT_SIZE);
+    output_index += ENC_GRU3_OUT_SIZE;
+    conv1_cond_init(enc_state->conv3_state, output_index, 2, &enc_state->initialized);
+    compute_generic_conv1d_dilation(&model->enc_conv3, &buffer[output_index], enc_state->conv3_state, buffer, output_index, 2, ACTIVATION_TANH);
+    output_index += ENC_CONV3_OUT_SIZE;
+
+    compute_generic_gru(&model->enc_gru4_input, &model->enc_gru4_recurrent, enc_state->gru4_state, buffer);
+    OPUS_COPY(&buffer[output_index], enc_state->gru4_state, ENC_GRU4_OUT_SIZE);
+    output_index += ENC_GRU4_OUT_SIZE;
+    conv1_cond_init(enc_state->conv4_state, output_index, 2, &enc_state->initialized);
+    compute_generic_conv1d_dilation(&model->enc_conv4, &buffer[output_index], enc_state->conv4_state, buffer, output_index, 2, ACTIVATION_TANH);
+    output_index += ENC_CONV4_OUT_SIZE;
+
+    compute_generic_gru(&model->enc_gru5_input, &model->enc_gru5_recurrent, enc_state->gru5_state, buffer);
+    OPUS_COPY(&buffer[output_index], enc_state->gru5_state, ENC_GRU5_OUT_SIZE);
+    output_index += ENC_GRU5_OUT_SIZE;
+    conv1_cond_init(enc_state->conv5_state, output_index, 2, &enc_state->initialized);
+    compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH);
+    output_index += ENC_CONV5_OUT_SIZE;
+
+    compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR);
 
     /* next, calculate initial state */
-    compute_generic_dense(&model->gdense1, &buffer[output_index], buffer, ACTIVATION_TANH);
-    input_index = output_index;
-    compute_generic_dense(&model->gdense2, initial_state, &buffer[input_index], ACTIVATION_TANH);
-
+    compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH);
+    compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR);
 }
diff --git a/dnn/dred_rdovae_enc.h b/dnn/dred_rdovae_enc.h
index 70ff6adca..832bd7379 100644
--- a/dnn/dred_rdovae_enc.h
+++ b/dnn/dred_rdovae_enc.h
@@ -33,10 +33,17 @@
 #include "dred_rdovae_enc_data.h"
 
 struct RDOVAEEncStruct {
-    float dense2_state[3 * ENC_DENSE2_STATE_SIZE];
-    float dense4_state[3 * ENC_DENSE4_STATE_SIZE];
-    float dense6_state[3 * ENC_DENSE6_STATE_SIZE];
-    float bits_dense_state[BITS_DENSE_STATE_SIZE];
+    int initialized;
+    float gru1_state[ENC_GRU1_STATE_SIZE];
+    float gru2_state[ENC_GRU2_STATE_SIZE];
+    float gru3_state[ENC_GRU3_STATE_SIZE];
+    float gru4_state[ENC_GRU4_STATE_SIZE];
+    float gru5_state[ENC_GRU5_STATE_SIZE];
+    float conv1_state[ENC_CONV1_STATE_SIZE];
+    float conv2_state[2*ENC_CONV2_STATE_SIZE];
+    float conv3_state[2*ENC_CONV3_STATE_SIZE];
+    float conv4_state[2*ENC_CONV4_STATE_SIZE];
+    float conv5_state[2*ENC_CONV5_STATE_SIZE];
 };
 
 void dred_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input);
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 3661ba77f..d5ef904ec 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -366,6 +366,25 @@ void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem,
    OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
 }
 
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation)
+{
+   float tmp[MAX_CONV_INPUTS_ALL];
+   int ksize = layer->nb_inputs/input_size;
+   int i;
+   celt_assert(input != output);
+   celt_assert(layer->nb_inputs <= MAX_CONV_INPUTS_ALL);
+   if (dilation==1) OPUS_COPY(tmp, mem, layer->nb_inputs-input_size);
+   else for (i=0;i<ksize-1;i++) OPUS_COPY(&tmp[i*input_size], &mem[i*input_size*dilation], input_size);
+   OPUS_COPY(&tmp[layer->nb_inputs-input_size], input, input_size);
+   compute_linear(layer, output, tmp);
+   compute_activation(output, output, layer->nb_outputs, activation);
+   if (dilation==1) OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size);
+   else {
+     OPUS_COPY(mem, &mem[input_size], input_size*dilation*(ksize-1)-input_size);
+     OPUS_COPY(&mem[input_size*dilation*(ksize-1)-input_size], input, input_size);
+   }
+}
+
 void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input)
 {
    LinearLayer matrix;
diff --git a/dnn/nnet.h b/dnn/nnet.h
index 16ce82bab..9ed20b028 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -145,6 +145,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in);
 void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation);
 void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in);
 void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation);
+void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation);
 void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation);
 
 void compute_activation(float *output, const float *input, int N, int activation);
diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py
index f9c1db815..c2cc61bde 100644
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -116,10 +116,7 @@ f"""
     # encoder
     encoder_dense_layers = [
         ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH'),
-        ('core_encoder.module.dense_2'       , 'enc_dense3',   'TANH'),
-        ('core_encoder.module.dense_3'       , 'enc_dense5',   'TANH'),
-        ('core_encoder.module.dense_4'       , 'enc_dense7',   'TANH'),
-        ('core_encoder.module.dense_5'       , 'enc_dense8',   'TANH'),
+        ('core_encoder.module.z_dense'       , 'enc_zdense',   'LINEAR'),
         ('core_encoder.module.state_dense_1' , 'gdense1'    ,   'TANH'),
         ('core_encoder.module.state_dense_2' , 'gdense2'    ,   'TANH')
     ]
@@ -130,9 +127,11 @@ f"""
 
 
     encoder_gru_layers = [
-        ('core_encoder.module.gru_1'         , 'enc_dense2',   'TANH'),
-        ('core_encoder.module.gru_2'         , 'enc_dense4',   'TANH'),
-        ('core_encoder.module.gru_3'         , 'enc_dense6',   'TANH')
+        ('core_encoder.module.gru1'       , 'enc_gru1',   'TANH'),
+        ('core_encoder.module.gru2'       , 'enc_gru2',   'TANH'),
+        ('core_encoder.module.gru3'       , 'enc_gru3',   'TANH'),
+        ('core_encoder.module.gru4'       , 'enc_gru4',   'TANH'),
+        ('core_encoder.module.gru5'       , 'enc_gru5',   'TANH'),
     ]
 
     enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
@@ -140,7 +139,11 @@ f"""
 
 
     encoder_conv_layers = [
-        ('core_encoder.module.conv1'         , 'bits_dense' ,   'LINEAR')
+        ('core_encoder.module.conv1.conv'       , 'enc_conv1',   'TANH'),
+        ('core_encoder.module.conv2.conv'       , 'enc_conv2',   'TANH'),
+        ('core_encoder.module.conv3.conv'       , 'enc_conv3',   'TANH'),
+        ('core_encoder.module.conv4.conv'       , 'enc_conv4',   'TANH'),
+        ('core_encoder.module.conv5.conv'       , 'enc_conv5',   'TANH'),
     ]
 
     enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in encoder_conv_layers])
@@ -150,15 +153,10 @@ f"""
 
     # decoder
     decoder_dense_layers = [
-        ('core_decoder.module.gru_1_init'    , 'state1',        'TANH'),
-        ('core_decoder.module.gru_2_init'    , 'state2',        'TANH'),
-        ('core_decoder.module.gru_3_init'    , 'state3',        'TANH'),
-        ('core_decoder.module.dense_1'       , 'dec_dense1',    'TANH'),
-        ('core_decoder.module.dense_2'       , 'dec_dense3',    'TANH'),
-        ('core_decoder.module.dense_3'       , 'dec_dense5',    'TANH'),
-        ('core_decoder.module.dense_4'       , 'dec_dense7',    'TANH'),
-        ('core_decoder.module.dense_5'       , 'dec_dense8',    'TANH'),
-        ('core_decoder.module.output'        , 'dec_final',     'LINEAR')
+        ('core_decoder.module.dense_1'       , 'dec_dense1',   'TANH'),
+        ('core_decoder.module.output'       , 'dec_output',   'LINEAR'),
+        ('core_decoder.module.hidden_init'  , 'dec_hidden_init',        'TANH'),
+        ('core_decoder.module.gru_init'    , 'dec_gru_init',        'TANH'),
     ]
 
     for name, export_name, _ in decoder_dense_layers:
@@ -167,14 +165,26 @@ f"""
 
 
     decoder_gru_layers = [
-        ('core_decoder.module.gru_1'         , 'dec_dense2',    'TANH'),
-        ('core_decoder.module.gru_2'         , 'dec_dense4',    'TANH'),
-        ('core_decoder.module.gru_3'         , 'dec_dense6',    'TANH')
+        ('core_decoder.module.gru1'         , 'dec_gru1',    'TANH'),
+        ('core_decoder.module.gru2'         , 'dec_gru2',    'TANH'),
+        ('core_decoder.module.gru3'         , 'dec_gru3',    'TANH'),
+        ('core_decoder.module.gru4'         , 'dec_gru4',    'TANH'),
+        ('core_decoder.module.gru5'         , 'dec_gru5',    'TANH'),
     ]
 
     dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True)
                              for name, export_name, _ in decoder_gru_layers])
 
+    decoder_conv_layers = [
+        ('core_decoder.module.conv1.conv'       , 'dec_conv1',   'TANH'),
+        ('core_decoder.module.conv2.conv'       , 'dec_conv2',   'TANH'),
+        ('core_decoder.module.conv3.conv'       , 'dec_conv3',   'TANH'),
+        ('core_decoder.module.conv4.conv'       , 'dec_conv4',   'TANH'),
+        ('core_decoder.module.conv5.conv'       , 'dec_conv5',   'TANH'),
+    ]
+
+    dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in decoder_conv_layers])
+
     del dec_writer
 
     # statistical model
@@ -196,7 +206,7 @@ f"""
 
 #define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}
 
-#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs}
+#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}
 
 #define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs}
 
@@ -268,4 +278,4 @@ if __name__ == "__main__":
     elif args.format == 'numpy':
         numpy_export(args, model)
     else:
-        raise ValueError(f'error: unknown export format {args.format}')
\ No newline at end of file
+        raise ValueError(f'error: unknown export format {args.format}')
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index 0dc943ec8..b126d4c44 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -224,6 +224,17 @@ def weight_clip_factory(max_value):
 
 # RDOVAE module and submodules
 
+class MyConv(nn.Module):
+    def __init__(self, input_dim, output_dim, dilation=1):
+        super(MyConv, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.dilation=dilation
+        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
+    def forward(self, x, state=None):
+        device = x.device
+        conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
+        return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
 
 class CoreEncoder(nn.Module):
     STATE_HIDDEN = 128
@@ -248,22 +259,28 @@ class CoreEncoder(nn.Module):
 
         # derived parameters
         self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
-        self.conv_input_channels =  5 * cond_size + 3 * cond_size2
 
         # layers
-        self.dense_1 = nn.Linear(self.input_dim, self.cond_size2)
-        self.gru_1   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
-        self.dense_2 = nn.Linear(self.cond_size, self.cond_size2)
-        self.gru_2   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
-        self.dense_3 = nn.Linear(self.cond_size, self.cond_size2)
-        self.gru_3   = nn.GRU(self.cond_size2, self.cond_size, batch_first=True)
-        self.dense_4 = nn.Linear(self.cond_size, self.cond_size)
-        self.dense_5 = nn.Linear(self.cond_size, self.cond_size)
-        self.conv1   = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid')
-
-        self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN)
+        self.dense_1 = nn.Linear(self.input_dim, 64)
+        self.gru1 = nn.GRU(64, 64, batch_first=True)
+        self.conv1 = MyConv(128, 96)
+        self.gru2 = nn.GRU(224, 64, batch_first=True)
+        self.conv2 = MyConv(288, 96, dilation=2)
+        self.gru3 = nn.GRU(384, 64, batch_first=True)
+        self.conv3 = MyConv(448, 96, dilation=2)
+        self.gru4 = nn.GRU(544, 64, batch_first=True)
+        self.conv4 = MyConv(608, 96, dilation=2)
+        self.gru5 = nn.GRU(704, 64, batch_first=True)
+        self.conv5 = MyConv(768, 96, dilation=2)
+
+        self.z_dense = nn.Linear(864, self.output_dim)
+
+
+        self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)
 
         self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
+        nb_params = sum(p.numel() for p in self.parameters())
+        print(f"encoder: {nb_params} weights")
 
         # initialize weights
         self.apply(init_weights)
@@ -278,25 +295,22 @@ class CoreEncoder(nn.Module):
         device = x.device
 
         # run encoding layer stack
-        x1      = torch.tanh(self.dense_1(x))
-        x2, _   = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device))
-        x3      = torch.tanh(self.dense_2(x2))
-        x4, _   = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device))
-        x5      = torch.tanh(self.dense_3(x4))
-        x6, _   = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device))
-        x7      = torch.tanh(self.dense_4(x6))
-        x8      = torch.tanh(self.dense_5(x7))
-
-        # concatenation of all hidden layer outputs
-        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+        x = torch.tanh(self.dense_1(x))
+        x = torch.cat([x, self.gru1(x)[0]], -1)
+        x = torch.cat([x, self.conv1(x)], -1)
+        x = torch.cat([x, self.gru2(x)[0]], -1)
+        x = torch.cat([x, self.conv2(x)], -1)
+        x = torch.cat([x, self.gru3(x)[0]], -1)
+        x = torch.cat([x, self.conv3(x)], -1)
+        x = torch.cat([x, self.gru4(x)[0]], -1)
+        x = torch.cat([x, self.conv4(x)], -1)
+        x = torch.cat([x, self.gru5(x)[0]], -1)
+        x = torch.cat([x, self.conv5(x)], -1)
+        z = self.z_dense(x)
 
         # init state for decoder
-        states = torch.tanh(self.state_dense_1(x9))
-        states = torch.tanh(self.state_dense_2(states))
-
-        # latent representation via convolution
-        x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0])
-        z = self.conv1(x9).permute(0, 2, 1)
+        states = torch.tanh(self.state_dense_1(x))
+        states = self.state_dense_2(states)
 
         return z, states
 
@@ -325,47 +339,54 @@ class CoreDecoder(nn.Module):
 
         self.input_size = self.input_dim
 
-        self.concat_size = 4 * self.cond_size + 4 * self.cond_size2
-
         # layers
-        self.dense_1    = nn.Linear(self.input_size, cond_size2)
-        self.gru_1      = nn.GRU(cond_size2, cond_size, batch_first=True)
-        self.dense_2    = nn.Linear(cond_size, cond_size2)
-        self.gru_2      = nn.GRU(cond_size2, cond_size, batch_first=True)
-        self.dense_3    = nn.Linear(cond_size, cond_size2)
-        self.gru_3      = nn.GRU(cond_size2, cond_size, batch_first=True)
-        self.dense_4    = nn.Linear(cond_size, cond_size2)
-        self.dense_5    = nn.Linear(cond_size2, cond_size2)
-
-        self.output  = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim)
-
-
-        self.gru_1_init = nn.Linear(self.state_size, self.cond_size)
-        self.gru_2_init = nn.Linear(self.state_size, self.cond_size)
-        self.gru_3_init = nn.Linear(self.state_size, self.cond_size)
-
+        self.dense_1    = nn.Linear(self.input_size, 96)
+        self.gru1 = nn.GRU(96, 96, batch_first=True)
+        self.conv1 = MyConv(192, 32)
+        self.gru2 = nn.GRU(224, 96, batch_first=True)
+        self.conv2 = MyConv(320, 32)
+        self.gru3 = nn.GRU(352, 96, batch_first=True)
+        self.conv3 = MyConv(448, 32)
+        self.gru4 = nn.GRU(480, 96, batch_first=True)
+        self.conv4 = MyConv(576, 32)
+        self.gru5 = nn.GRU(608, 96, batch_first=True)
+        self.conv5 = MyConv(704, 32)
+        self.output  = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
+
+        self.hidden_init = nn.Linear(self.state_size, 128)
+        self.gru_init = nn.Linear(128, 480)
+
+        nb_params = sum(p.numel() for p in self.parameters())
+        print(f"decoder: {nb_params} weights")
         # initialize weights
         self.apply(init_weights)
 
     def forward(self, z, initial_state):
 
-        gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2))
-        gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2))
-        gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2))
+        hidden = torch.tanh(self.hidden_init(initial_state))
+        gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
+        h1_state = gru_state[:,:,:96].contiguous()
+        h2_state = gru_state[:,:,96:192].contiguous()
+        h3_state = gru_state[:,:,192:288].contiguous()
+        h4_state = gru_state[:,:,288:384].contiguous()
+        h5_state = gru_state[:,:,384:].contiguous()
 
         # run decoding layer stack
-        x1  = torch.tanh(self.dense_1(z))
-        x2, _ = self.gru_1(x1, gru_1_state)
-        x3  = torch.tanh(self.dense_2(x2))
-        x4, _ = self.gru_2(x3, gru_2_state)
-        x5  = torch.tanh(self.dense_3(x4))
-        x6, _ = self.gru_3(x5, gru_3_state)
-        x7  = torch.tanh(self.dense_4(x6))
-        x8  = torch.tanh(self.dense_5(x7))
-        x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1)
+        x = torch.tanh(self.dense_1(z))
+
+        x = torch.cat([x, self.gru1(x, h1_state)[0]], -1)
+        x = torch.cat([x, self.conv1(x)], -1)
+        x = torch.cat([x, self.gru2(x, h2_state)[0]], -1)
+        x = torch.cat([x, self.conv2(x)], -1)
+        x = torch.cat([x, self.gru3(x, h3_state)[0]], -1)
+        x = torch.cat([x, self.conv3(x)], -1)
+        x = torch.cat([x, self.gru4(x, h4_state)[0]], -1)
+        x = torch.cat([x, self.conv4(x)], -1)
+        x = torch.cat([x, self.gru5(x, h5_state)[0]], -1)
+        x = torch.cat([x, self.conv5(x)], -1)
 
         # output layer and reshaping
-        x10 = self.output(x9)
+        x10 = self.output(x)
         features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
 
         return features
@@ -466,7 +487,7 @@ class RDOVAE(nn.Module):
         if not type(self.weight_clip_fn) == type(None):
             self.apply(self.weight_clip_fn)
 
-    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
+    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24):
 
         enc_stride = self.enc_stride
         dec_stride = self.dec_stride
-- 
GitLab