From a8cb719d05b2a45d4cf12ca1a61755cc5d904e55 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Tue, 6 Jun 2023 17:19:12 -0400
Subject: [PATCH] Add blob loading for DRED encoder and decoder

---
 include/opus.h      | 13 +++++++++++-
 silk/dred_encoder.c | 10 +++++++++
 silk/dred_encoder.h |  2 +-
 src/opus_decoder.c  | 50 +++++++++++++++++++++++++++++++++++++++++++++
 src/opus_demo.c     |  8 +++++---
 src/opus_encoder.c  | 13 ++++++++++++
 6 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/include/opus.h b/include/opus.h
index 41244f23f..a52daa221 100644
--- a/include/opus.h
+++ b/include/opus.h
@@ -547,7 +547,18 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec);
   */
 OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec);
 
-
+/** Perform a CTL function on an Opus DRED decoder.
+  *
+  * Generally the request and subsequent arguments are generated
+  * by a convenience macro.
+  * @param st <tt>OpusDREDDecoder*</tt>: DRED Decoder state.
+  * @param request This and all remaining parameters should be replaced by one
+  *                of the convenience macros in @ref opus_genericctls or
+  *                @ref opus_decoderctls.
+  * @see opus_genericctls
+  * @see opus_decoderctls
+  */
+OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...);
 
 /** Gets the size of an <code>OpusDRED</code> structure.
   * @returns The size in bytes.
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index afec129d2..c2628c7f8 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,16 @@
 #include "float_cast.h"
 #include "os_support.h"
 
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
+{
+    WeightArray *list;
+    int ret;
+    parse_weights(&list, data, len);
+    ret = init_rdovaeenc(&enc->model, list);
+    free(list);
+    return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
 void dred_encoder_reset(DREDEnc* enc)
 {
     RNN_CLEAR((char*)&enc->DREDENC_RESET_START,
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 30e639a9b..439ef6540 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -54,7 +54,7 @@ typedef struct {
     RDOVAEEncState rdovae_enc;
 } DREDEnc;
 
-
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
 void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels);
 void dred_encoder_reset(DREDEnc* enc);
 
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index d28a052a1..aad378f03 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void)
   return sizeof(OpusDREDDecoder);
 }
 
+int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len)
+{
+    WeightArray *list;
+    int ret;
+    parse_weights(&list, data, len);
+    ret = init_rdovaedec(&dec->model, list);
+    free(list);
+    return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
 int opus_dred_decoder_init(OpusDREDDecoder *dec)
 {
 #ifndef USE_WEIGHTS_FILE
@@ -1180,7 +1190,47 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec)
    free(dec);
 }
 
+int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...)
+{
+#ifdef ENABLE_NEURAL_FEC
+   int ret = OPUS_OK;
+   va_list ap;
 
+   va_start(ap, request);
+   (void)dred_dec;
+   switch (request)
+   {
+# ifdef USE_WEIGHTS_FILE
+   case OPUS_SET_DNN_BLOB_REQUEST:
+   {
+      const unsigned char *data = va_arg(ap, const unsigned char *);
+      opus_int32 len = va_arg(ap, opus_int32);
+      if(len<0 || data == NULL)
+      {
+         goto bad_arg;
+      }
+      return dred_decoder_load_model(dred_dec, data, len);
+   }
+   break;
+# endif
+   default:
+     /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/
+     ret = OPUS_UNIMPLEMENTED;
+     break;
+  }
+  va_end(ap);
+  return ret;
+# ifdef USE_WEIGHTS_FILE
+bad_arg:
+  va_end(ap);
+  return OPUS_BAD_ARG;
+# endif
+#else
+  (void)dred_dec;
+  (void)request;
+  return OPUS_UNIMPLEMENTED;
+#endif
+}
 
 #ifdef ENABLE_NEURAL_FEC
 static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload)
diff --git a/src/opus_demo.c b/src/opus_demo.c
index 563d3c5b1..b48845c9e 100644
--- a/src/opus_demo.c
+++ b/src/opus_demo.c
@@ -617,9 +617,6 @@ int main(int argc, char *argv[])
           goto failure;
        }
     }
-#ifdef USE_WEIGHTS_FILE
-    opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
-#endif
     switch(bandwidth)
     {
     case OPUS_BANDWIDTH_NARROWBAND:
@@ -684,6 +681,11 @@ int main(int argc, char *argv[])
     }
     dred_dec = opus_dred_decoder_create(&err);
     dred = opus_dred_alloc(&err);
+#ifdef USE_WEIGHTS_FILE
+    opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+    opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+    opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+#endif
     while (!stop)
     {
         if (delayed_celt)
diff --git a/src/opus_encoder.c b/src/opus_encoder.c
index 3738c8488..f6d3bc585 100644
--- a/src/opus_encoder.c
+++ b/src/opus_encoder.c
@@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...)
             }
         }
         break;
+#ifdef USE_WEIGHTS_FILE
+   case OPUS_SET_DNN_BLOB_REQUEST:
+   {
+       const unsigned char *data = va_arg(ap, const unsigned char *);
+       opus_int32 len = va_arg(ap, opus_int32);
+       if(len<0 || data == NULL)
+       {
+          goto bad_arg;
+       }
+       return dred_encoder_load_model(&st->dred_encoder, data, len);
+   }
+   break;
+#endif
         case CELT_GET_MODE_REQUEST:
         {
            const CELTMode ** value = va_arg(ap, const CELTMode**);
-- 
GitLab