From f867f61e8b024bc3f1684ff8c1b4fc0ac21af97b Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Thu, 1 Jun 2023 23:06:45 -0400
Subject: [PATCH] Convert RDOVAE to blob format

---
 dnn/Makefile.am            |  2 +-
 dnn/dred_rdovae.c          | 43 +++++++++++++++++++-------------------
 dnn/dred_rdovae_dec.c      | 30 +++++++++++++-------------
 dnn/dred_rdovae_dec.h      |  6 +++---
 dnn/dred_rdovae_enc.c      | 25 +++++++++++-----------
 dnn/dred_rdovae_enc.h      |  4 ++--
 dnn/include/dred_rdovae.h  | 24 +++++++++++----------
 dnn/nnet.h                 |  2 ++
 dnn/write_lpcnet_weights.c |  4 ++++
 9 files changed, 76 insertions(+), 64 deletions(-)

diff --git a/dnn/Makefile.am b/dnn/Makefile.am
index c4b8b7681..0a2b06532 100644
--- a/dnn/Makefile.am
+++ b/dnn/Makefile.am
@@ -1,6 +1,6 @@
 ACLOCAL_AMFLAGS = -I m4
 
-AM_CFLAGS = -I$(top_srcdir)/include $(DEPS_CFLAGS)
+AM_CFLAGS = -I$(top_srcdir)/include -I$(top_srcdir)/ $(DEPS_CFLAGS)
 
 dist_doc_DATA = COPYING AUTHORS README README.md
 
diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c
index 047efb17d..9805d0fda 100644
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -35,16 +35,17 @@
 #include "dred_rdovae_dec.h"
 #include "dred_rdovae_stats_data.h"
 
-void DRED_rdovae_decode_all(float *features, const float *state, const float *latents, int nb_latents)
+void DRED_rdovae_decode_all(const RDOVAEDec *model, float *features, const float *state, const float *latents, int nb_latents)
 {
     int i;
-    RDOVAEDec dec;
+    RDOVAEDecState dec;
     memset(&dec, 0, sizeof(dec));
-    DRED_rdovae_dec_init_states(&dec, state);
+    DRED_rdovae_dec_init_states(&dec, model, state);
     for (i = 0; i < 2*nb_latents; i += 2)
     {
         DRED_rdovae_decode_qframe(
             &dec,
+            model,
             &features[2*i*DRED_NUM_FEATURES],
             &latents[(i/2)*DRED_LATENT_DIM]);
     }
@@ -52,65 +53,65 @@ void DRED_rdovae_decode_all(float *features, const float *state, const float *la
 
 size_t DRED_rdovae_get_enc_size()
 {
-    return sizeof(RDOVAEEnc);
+    return sizeof(RDOVAEEncState);
 }
 
 size_t DRED_rdovae_get_dec_size()
 {
-    return sizeof(RDOVAEDec);
+    return sizeof(RDOVAEDecState);
 }
 
-void DRED_rdovae_init_encoder(RDOVAEEnc *enc_state)
+void DRED_rdovae_init_encoder(RDOVAEEncState *enc_state)
 {
     memset(enc_state, 0, sizeof(*enc_state));
 
 }
 
-void DRED_rdovae_init_decoder(RDOVAEDec *dec_state)
+void DRED_rdovae_init_decoder(RDOVAEDecState *dec_state)
 {
     memset(dec_state, 0, sizeof(*dec_state));
 }
 
 
-RDOVAEEnc * DRED_rdovae_create_encoder()
+RDOVAEEncState * DRED_rdovae_create_encoder()
 {
-    RDOVAEEnc *enc;
-    enc = (RDOVAEEnc*) calloc(sizeof(*enc), 1);
+    RDOVAEEncState *enc;
+    enc = (RDOVAEEncState*) calloc(sizeof(*enc), 1);
     DRED_rdovae_init_encoder(enc);
     return enc;
 }
 
-RDOVAEDec * DRED_rdovae_create_decoder()
+RDOVAEDecState * DRED_rdovae_create_decoder()
 {
-    RDOVAEDec *dec;
-    dec = (RDOVAEDec*) calloc(sizeof(*dec), 1);
+    RDOVAEDecState *dec;
+    dec = (RDOVAEDecState*) calloc(sizeof(*dec), 1);
     DRED_rdovae_init_decoder(dec);
     return dec;
 }
 
-void DRED_rdovae_destroy_decoder(RDOVAEDec* dec)
+void DRED_rdovae_destroy_decoder(RDOVAEDecState* dec)
 {
     free(dec);
 }
 
-void DRED_rdovae_destroy_encoder(RDOVAEEnc* enc)
+void DRED_rdovae_destroy_encoder(RDOVAEEncState* enc)
 {
     free(enc);
 }
 
-void DRED_rdovae_encode_dframe(RDOVAEEnc *enc_state, float *latents, float *initial_state, const float *input)
+void DRED_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input)
 {
-    dred_rdovae_encode_dframe(enc_state, latents, initial_state, input);
+    dred_rdovae_encode_dframe(enc_state, model, latents, initial_state, input);
 }
 
-void DRED_rdovae_dec_init_states(RDOVAEDec *h, const float * initial_state)
+void DRED_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state)
 {
-    dred_rdovae_dec_init_states(h, initial_state);
+    dred_rdovae_dec_init_states(h, model, initial_state);
 }
 
-void DRED_rdovae_decode_qframe(RDOVAEDec *h, float *qframe, const float *z)
+void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float *z)
 {
-    dred_rdovae_decode_qframe(h, qframe, z);
+    dred_rdovae_decode_qframe(h, model, qframe, z);
 }
 
 
diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c
index 965629c1f..3cf2d69aa 100644
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -35,19 +35,21 @@
 
 
 void dred_rdovae_dec_init_states(
-    RDOVAEDec *h,            /* io: state buffer handle */
+    RDOVAEDecState *h,            /* io: state buffer handle */
+    const RDOVAEDec *model,
     const float *initial_state  /* i: initial state */
     )
 {
     /* initialize GRU states from initial state */
-    _lpcnet_compute_dense(&state1, h->dense2_state, initial_state);
-    _lpcnet_compute_dense(&state2, h->dense4_state, initial_state);
-    _lpcnet_compute_dense(&state3, h->dense6_state, initial_state);
+    _lpcnet_compute_dense(&model->state1, h->dense2_state, initial_state);
+    _lpcnet_compute_dense(&model->state2, h->dense4_state, initial_state);
+    _lpcnet_compute_dense(&model->state3, h->dense6_state, initial_state);
 }
 
 
 void dred_rdovae_decode_qframe(
-    RDOVAEDec *dec_state,       /* io: state buffer handle */
+    RDOVAEDecState *dec_state,       /* io: state buffer handle */
+    const RDOVAEDec *model,
     float *qframe,              /* o: quadruple feature frame (four concatenated frames in reverse order) */
     const float *input          /* i: latent vector */
     )
@@ -58,39 +60,39 @@ void dred_rdovae_decode_qframe(
     float zero_vector[1024] = {0};
 
     /* run encoder stack and concatenate output in buffer*/
-    _lpcnet_compute_dense(&dec_dense1, &buffer[output_index], input);
+    _lpcnet_compute_dense(&model->dec_dense1, &buffer[output_index], input);
     input_index = output_index;
     output_index += DEC_DENSE1_OUT_SIZE;
 
-    compute_gruB(&dec_dense2, zero_vector, dec_state->dense2_state, &buffer[input_index]);
+    compute_gruB(&model->dec_dense2, zero_vector, dec_state->dense2_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], dec_state->dense2_state, DEC_DENSE2_OUT_SIZE);
     input_index = output_index;
     output_index += DEC_DENSE2_OUT_SIZE;
 
-    _lpcnet_compute_dense(&dec_dense3, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->dec_dense3, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += DEC_DENSE3_OUT_SIZE;
 
-    compute_gruB(&dec_dense4, zero_vector, dec_state->dense4_state, &buffer[input_index]);
+    compute_gruB(&model->dec_dense4, zero_vector, dec_state->dense4_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], dec_state->dense4_state, DEC_DENSE4_OUT_SIZE);
     input_index = output_index;
     output_index += DEC_DENSE4_OUT_SIZE;
 
-    _lpcnet_compute_dense(&dec_dense5, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->dec_dense5, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += DEC_DENSE5_OUT_SIZE;
 
-    compute_gruB(&dec_dense6, zero_vector, dec_state->dense6_state, &buffer[input_index]);
+    compute_gruB(&model->dec_dense6, zero_vector, dec_state->dense6_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], dec_state->dense6_state, DEC_DENSE6_OUT_SIZE);
     input_index = output_index;
     output_index += DEC_DENSE6_OUT_SIZE;
 
-    _lpcnet_compute_dense(&dec_dense7, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->dec_dense7, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += DEC_DENSE7_OUT_SIZE;
 
-    _lpcnet_compute_dense(&dec_dense8, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->dec_dense8, &buffer[output_index], &buffer[input_index]);
     output_index += DEC_DENSE8_OUT_SIZE;
 
-    _lpcnet_compute_dense(&dec_final, qframe, buffer);
+    _lpcnet_compute_dense(&model->dec_final, qframe, buffer);
 }
diff --git a/dnn/dred_rdovae_dec.h b/dnn/dred_rdovae_dec.h
index 055571003..008551b5d 100644
--- a/dnn/dred_rdovae_dec.h
+++ b/dnn/dred_rdovae_dec.h
@@ -38,7 +38,7 @@ struct RDOVAEDecStruct {
     float dense6_state[DEC_DENSE2_STATE_SIZE];
 };
 
-void dred_rdovae_dec_init_states(RDOVAEDec *h, const float * initial_state);
-void dred_rdovae_decode_qframe(RDOVAEDec *h, float *qframe, const float * z);
+void dred_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state);
+void dred_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);
 
-#endif
\ No newline at end of file
+#endif
diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c
index a8dad8540..9fb93cd81 100644
--- a/dnn/dred_rdovae_enc.c
+++ b/dnn/dred_rdovae_enc.c
@@ -36,7 +36,8 @@
 #include "common.h"
 
 void dred_rdovae_encode_dframe(
-    RDOVAEEnc *enc_state,           /* io: encoder state */
+    RDOVAEEncState *enc_state,           /* io: encoder state */
+    const RDOVAEEnc *model,
     float *latents,                 /* o: latent vector */
     float *initial_state,           /* o: initial state */
     const float *input              /* i: double feature frame (concatenated) */
@@ -48,47 +49,47 @@ void dred_rdovae_encode_dframe(
     float zero_vector[1024] = {0};
 
     /* run encoder stack and concatenate output in buffer*/
-    _lpcnet_compute_dense(&enc_dense1, &buffer[output_index], input);
+    _lpcnet_compute_dense(&model->enc_dense1, &buffer[output_index], input);
     input_index = output_index;
     output_index += ENC_DENSE1_OUT_SIZE;
 
-    compute_gruB(&enc_dense2, zero_vector, enc_state->dense2_state, &buffer[input_index]);
+    compute_gruB(&model->enc_dense2, zero_vector, enc_state->dense2_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE);
     input_index = output_index;
     output_index += ENC_DENSE2_OUT_SIZE;
 
-    _lpcnet_compute_dense(&enc_dense3, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->enc_dense3, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += ENC_DENSE3_OUT_SIZE;
 
-    compute_gruB(&enc_dense4, zero_vector, enc_state->dense4_state, &buffer[input_index]);
+    compute_gruB(&model->enc_dense4, zero_vector, enc_state->dense4_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE);
     input_index = output_index;
     output_index += ENC_DENSE4_OUT_SIZE;
 
-    _lpcnet_compute_dense(&enc_dense5, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->enc_dense5, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += ENC_DENSE5_OUT_SIZE;
 
-    compute_gruB(&enc_dense6, zero_vector, enc_state->dense6_state, &buffer[input_index]);
+    compute_gruB(&model->enc_dense6, zero_vector, enc_state->dense6_state, &buffer[input_index]);
     RNN_COPY(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE);
     input_index = output_index;
     output_index += ENC_DENSE6_OUT_SIZE;
 
-    _lpcnet_compute_dense(&enc_dense7, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->enc_dense7, &buffer[output_index], &buffer[input_index]);
     input_index = output_index;
     output_index += ENC_DENSE7_OUT_SIZE;
 
-    _lpcnet_compute_dense(&enc_dense8, &buffer[output_index], &buffer[input_index]);
+    _lpcnet_compute_dense(&model->enc_dense8, &buffer[output_index], &buffer[input_index]);
     output_index += ENC_DENSE8_OUT_SIZE;
 
     /* compute latents from concatenated input buffer */
-    compute_conv1d(&bits_dense, latents, enc_state->bits_dense_state, buffer);
+    compute_conv1d(&model->bits_dense, latents, enc_state->bits_dense_state, buffer);
 
 
     /* next, calculate initial state */
-    _lpcnet_compute_dense(&gdense1, &buffer[output_index], buffer);
+    _lpcnet_compute_dense(&model->gdense1, &buffer[output_index], buffer);
     input_index = output_index;
-    _lpcnet_compute_dense(&gdense2, initial_state, &buffer[input_index]);
+    _lpcnet_compute_dense(&model->gdense2, initial_state, &buffer[input_index]);
 
 }
diff --git a/dnn/dred_rdovae_enc.h b/dnn/dred_rdovae_enc.h
index 8328a3e69..70ff6adca 100644
--- a/dnn/dred_rdovae_enc.h
+++ b/dnn/dred_rdovae_enc.h
@@ -39,7 +39,7 @@ struct RDOVAEEncStruct {
     float bits_dense_state[BITS_DENSE_STATE_SIZE];
 };
 
-void dred_rdovae_encode_dframe(RDOVAEEnc *enc_state, float *latents, float *initial_state, const float *input);
+void dred_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input);
 
 
-#endif
\ No newline at end of file
+#endif
diff --git a/dnn/include/dred_rdovae.h b/dnn/include/dred_rdovae.h
index ce61cd713..f2c3235e0 100644
--- a/dnn/include/dred_rdovae.h
+++ b/dnn/include/dred_rdovae.h
@@ -32,29 +32,31 @@
 
 #include "opus_types.h"
 
-typedef struct RDOVAEDecStruct RDOVAEDec;
-typedef struct RDOVAEEncStruct RDOVAEEnc;
+typedef struct RDOVAEDec RDOVAEDec;
+typedef struct RDOVAEEnc RDOVAEEnc;
+typedef struct RDOVAEDecStruct RDOVAEDecState;
+typedef struct RDOVAEEncStruct RDOVAEEncState;
 
-void DRED_rdovae_decode_all(float *features, const float *state, const float *latents, int nb_latents);
+void DRED_rdovae_decode_all(const RDOVAEDec *model, float *features, const float *state, const float *latents, int nb_latents);
 
 
 size_t DRED_rdovae_get_enc_size(void);
 
 size_t DRED_rdovae_get_dec_size(void);
 
-RDOVAEDec * DRED_rdovae_create_decoder(void);
-RDOVAEEnc * DRED_rdovae_create_encoder(void);
-void DRED_rdovae_destroy_decoder(RDOVAEDec* h);
-void DRED_rdovae_destroy_encoder(RDOVAEEnc* h);
+RDOVAEDecState * DRED_rdovae_create_decoder(void);
+RDOVAEEncState * DRED_rdovae_create_encoder(void);
+void DRED_rdovae_destroy_decoder(RDOVAEDecState* h);
+void DRED_rdovae_destroy_encoder(RDOVAEEncState* h);
 
 
-void DRED_rdovae_init_encoder(RDOVAEEnc *enc_state);
+void DRED_rdovae_init_encoder(RDOVAEEncState *enc_state);
 
-void DRED_rdovae_encode_dframe(RDOVAEEnc *enc_state, float *latents, float *initial_state, const float *input);
+void DRED_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input);
 
-void DRED_rdovae_dec_init_states(RDOVAEDec *h, const float * initial_state);
+void DRED_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state);
 
-void DRED_rdovae_decode_qframe(RDOVAEDec *h, float *qframe, const float * z);
+void DRED_rdovae_decode_qframe(RDOVAEDecState *h, const RDOVAEDec *model, float *qframe, const float * z);
 
 const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
 const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
diff --git a/dnn/nnet.h b/dnn/nnet.h
index cdd4ac09a..136de559c 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -148,6 +148,8 @@ int sample_from_pdf(const float *pdf, int N, float exp_boost, float pdf_floor);
 
 extern const WeightArray lpcnet_arrays[];
 extern const WeightArray lpcnet_plc_arrays[];
+extern const WeightArray rdovae_enc_arrays[];
+extern const WeightArray rdovae_dec_arrays[];
 
 int mdense_init(MDenseLayer *layer, const WeightArray *arrays,
   const char *bias,
diff --git a/dnn/write_lpcnet_weights.c b/dnn/write_lpcnet_weights.c
index f5eb88261..15c20837c 100644
--- a/dnn/write_lpcnet_weights.c
+++ b/dnn/write_lpcnet_weights.c
@@ -41,6 +41,8 @@
 #endif
 #include "nnet_data.c"
 #include "plc_data.c"
+#include "dred_rdovae_enc_data.c"
+#include "dred_rdovae_dec_data.c"
 
 void write_weights(const WeightArray *list, FILE *fout)
 {
@@ -69,6 +71,8 @@ int main(void)
   FILE *fout = fopen("weights_blob.bin", "w");
   write_weights(lpcnet_arrays, fout);
   write_weights(lpcnet_plc_arrays, fout);
+  write_weights(rdovae_enc_arrays, fout);
+  write_weights(rdovae_dec_arrays, fout);
   fclose(fout);
   return 0;
 }
-- 
GitLab