From a8cb719d05b2a45d4cf12ca1a61755cc5d904e55 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Tue, 6 Jun 2023 17:19:12 -0400
Subject: [PATCH] Add blob loading for DRED encoder and decoder

---
Reviewer notes (text in this section is ignored by git am):
 - opus_dred_decoder_ctl(): the OPUS_SET_DNN_BLOB case returned from inside
   the switch without calling va_end(); it now sets ret and leaves through
   the common exit. The same pattern in the opus_encoder_ctl() hunk is left
   untouched because the rest of that function is not part of this patch.
 - dred_encoder_load_model()/dred_decoder_load_model(): the model is only
   initialized when parse_weights() reports success.
 - opus.h: the @param name now matches the prototype and @returns was added.
 - Context whitespace was reconstructed; apply with --ignore-whitespace if
   needed. The post-image blob ids on the index lines predate these edits.

 include/opus.h      | 14 +++++++++++-
 silk/dred_encoder.c | 14 ++++++++++++
 silk/dred_encoder.h |  2 +-
 src/opus_decoder.c  | 55 +++++++++++++++++++++++++++++++++++++++++++++
 src/opus_demo.c     |  8 ++++---
 src/opus_encoder.c  | 13 +++++++++++
 6 files changed, 101 insertions(+), 5 deletions(-)

diff --git a/include/opus.h b/include/opus.h
index 41244f23f..a52daa221 100644
--- a/include/opus.h
+++ b/include/opus.h
@@ -547,7 +547,19 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec);
   */
 OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec);
 
-
+/** Perform a CTL function on an Opus DRED decoder.
+  *
+  * Generally the request and subsequent arguments are generated
+  * by a convenience macro.
+  * @param dred_dec <tt>OpusDREDDecoder*</tt>: DRED Decoder state.
+  * @param request This and all remaining parameters should be replaced by one
+  *                of the convenience macros in @ref opus_genericctls or
+  *                @ref opus_decoderctls.
+  * @returns OPUS_OK on success, OPUS_BAD_ARG or OPUS_UNIMPLEMENTED otherwise.
+  * @see opus_genericctls
+  * @see opus_decoderctls
+  */
+OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...);
 
 /** Gets the size of an <code>OpusDRED</code> structure.
   * @returns The size in bytes.
diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index afec129d2..c2628c7f8 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,20 @@
 #include "float_cast.h"
 #include "os_support.h"
 
+/* Parse a weight blob and initialize the DRED encoder model from it.
+   Returns OPUS_OK on success, OPUS_BAD_ARG if the blob is rejected. */
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
+{
+    WeightArray *list = NULL;
+    int ret = -1;
+    /* NOTE(review): assumes parse_weights() returns <0 on failure; confirm. */
+    if (parse_weights(&list, data, len) >= 0) {
+        ret = init_rdovaeenc(&enc->model, list);
+    }
+    free(list);
+    return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
 void dred_encoder_reset(DREDEnc* enc)
 {
     RNN_CLEAR((char*)&enc->DREDENC_RESET_START,
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 30e639a9b..439ef6540 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -54,7 +54,7 @@ typedef struct {
     RDOVAEEncState rdovae_enc;
 } DREDEnc;
 
-
+int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len);
 void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels);
 void dred_encoder_reset(DREDEnc* enc);
 
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index d28a052a1..aad378f03 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -1141,6 +1141,20 @@ int opus_dred_decoder_get_size(void)
   return sizeof(OpusDREDDecoder);
 }
 
+/* Parse a weight blob and initialize the DRED decoder model from it.
+   Returns OPUS_OK on success, OPUS_BAD_ARG if the blob is rejected. */
+int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len)
+{
+    WeightArray *list = NULL;
+    int ret = -1;
+    /* NOTE(review): assumes parse_weights() returns <0 on failure; confirm. */
+    if (parse_weights(&list, data, len) >= 0) {
+        ret = init_rdovaedec(&dec->model, list);
+    }
+    free(list);
+    return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG;
+}
+
 int opus_dred_decoder_init(OpusDREDDecoder *dec)
 {
 #ifndef USE_WEIGHTS_FILE
@@ -1180,7 +1194,48 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec)
   free(dec);
 }
 
+int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...)
+{
+#ifdef ENABLE_NEURAL_FEC
+   int ret = OPUS_OK;
+   va_list ap;
 
+   va_start(ap, request);
+   (void)dred_dec;
+   switch (request)
+   {
+# ifdef USE_WEIGHTS_FILE
+      case OPUS_SET_DNN_BLOB_REQUEST:
+      {
+          const unsigned char *data = va_arg(ap, const unsigned char *);
+          opus_int32 len = va_arg(ap, opus_int32);
+          if(len<0 || data == NULL)
+          {
+             goto bad_arg;
+          }
+          /* Leave through the common exit so that va_end() always runs. */
+          ret = dred_decoder_load_model(dred_dec, data, len);
+      }
+      break;
+# endif
+   default:
+     /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/
+     ret = OPUS_UNIMPLEMENTED;
+     break;
+   }
+   va_end(ap);
+   return ret;
+# ifdef USE_WEIGHTS_FILE
+bad_arg:
+   va_end(ap);
+   return OPUS_BAD_ARG;
+# endif
+#else
+  (void)dred_dec;
+  (void)request;
+  return OPUS_UNIMPLEMENTED;
+#endif
+}
 
 #ifdef ENABLE_NEURAL_FEC
 static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload)
diff --git a/src/opus_demo.c b/src/opus_demo.c
index 563d3c5b1..b48845c9e 100644
--- a/src/opus_demo.c
+++ b/src/opus_demo.c
@@ -617,9 +617,6 @@ int main(int argc, char *argv[])
            goto failure;
         }
     }
-#ifdef USE_WEIGHTS_FILE
-    opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
-#endif
     switch(bandwidth)
     {
     case OPUS_BANDWIDTH_NARROWBAND:
@@ -684,6 +681,11 @@ int main(int argc, char *argv[])
     }
     dred_dec = opus_dred_decoder_create(&err);
     dred = opus_dred_alloc(&err);
+#ifdef USE_WEIGHTS_FILE
+    opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+    opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+    opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len));
+#endif
     while (!stop)
     {
         if (delayed_celt)
diff --git a/src/opus_encoder.c b/src/opus_encoder.c
index 3738c8488..f6d3bc585 100644
--- a/src/opus_encoder.c
+++ b/src/opus_encoder.c
@@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...)
             }
         }
         break;
+#ifdef USE_WEIGHTS_FILE
+        case OPUS_SET_DNN_BLOB_REQUEST:
+        {
+            const unsigned char *data = va_arg(ap, const unsigned char *);
+            opus_int32 len = va_arg(ap, opus_int32);
+            if(len<0 || data == NULL)
+            {
+               goto bad_arg;
+            }
+            return dred_encoder_load_model(&st->dred_encoder, data, len);
+        }
+        break;
+#endif
         case CELT_GET_MODE_REQUEST:
         {
             const CELTMode ** value = va_arg(ap, const CELTMode**);
-- 
GitLab