diff --git a/include/opus.h b/include/opus.h index 41244f23fc1fa5314da30b11a5136f7644e33b1a..a52daa22138b7ab9e537cc59ff28058e53e12204 100644 --- a/include/opus.h +++ b/include/opus.h @@ -547,7 +547,18 @@ OPUS_EXPORT int opus_dred_decoder_init(OpusDREDDecoder *dec); */ OPUS_EXPORT void opus_dred_decoder_destroy(OpusDREDDecoder *dec); - +/** Perform a CTL function on an Opus DRED decoder. + * + * Generally the request and subsequent arguments are generated + * by a convenience macro. + * @param st <tt>OpusDREDDecoder*</tt>: DRED Decoder state. + * @param request This and all remaining parameters should be replaced by one + * of the convenience macros in @ref opus_genericctls or + * @ref opus_decoderctls. + * @see opus_genericctls + * @see opus_decoderctls + */ +OPUS_EXPORT int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...); /** Gets the size of an <code>OpusDRED</code> structure. * @returns The size in bytes. diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index afec129d266b0f25137bf14db9885e174c59ea3a..c2628c7f897dd55026aa70ecedeff2befce6f6da 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -44,6 +44,16 @@ #include "float_cast.h" #include "os_support.h" +int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len) +{ + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_rdovaeenc(&enc->model, list); + free(list); + return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; +} + void dred_encoder_reset(DREDEnc* enc) { RNN_CLEAR((char*)&enc->DREDENC_RESET_START, diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h index 30e639a9b7cc345e5d3cf48adf943980896953fd..439ef6540115dd864bbd70a9cb5206d2b540a34e 100644 --- a/silk/dred_encoder.h +++ b/silk/dred_encoder.h @@ -54,7 +54,7 @@ typedef struct { RDOVAEEncState rdovae_enc; } DREDEnc; - +int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len); void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels); void dred_encoder_reset(DREDEnc* enc); diff --git a/src/opus_decoder.c b/src/opus_decoder.c index d28a052a125ab513cb737f60945caac7a30c9bd2..aad378f03a6b1e4bf0bbd44d50a4275497e1d3ce 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1141,6 +1141,16 @@ int opus_dred_decoder_get_size(void) return sizeof(OpusDREDDecoder); } +int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int len) +{ + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_rdovaedec(&dec->model, list); + free(list); + return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; +} + int opus_dred_decoder_init(OpusDREDDecoder *dec) { #ifndef USE_WEIGHTS_FILE @@ -1180,7 +1190,47 @@ void opus_dred_decoder_destroy(OpusDREDDecoder *dec) free(dec); } +int opus_dred_decoder_ctl(OpusDREDDecoder *dred_dec, int request, ...) +{ +#ifdef ENABLE_NEURAL_FEC + int ret = OPUS_OK; + va_list ap; + va_start(ap, request); + (void)dred_dec; + switch (request) + { +# ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_decoder_load_model(dred_dec, data, len); + } + break; +# endif + default: + /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/ + ret = OPUS_UNIMPLEMENTED; + break; + } + va_end(ap); + return ret; +# ifdef USE_WEIGHTS_FILE +bad_arg: + va_end(ap); + return OPUS_BAD_ARG; +# endif +#else + (void)dred_dec; + (void)request; + return OPUS_UNIMPLEMENTED; +#endif +} #ifdef ENABLE_NEURAL_FEC static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload) diff --git a/src/opus_demo.c b/src/opus_demo.c index 563d3c5b136b983cb9dc107fdfa3a2bed23eb3bd..b48845c9e2b821acec080af855d70cc8e06e4435 100644 --- a/src/opus_demo.c +++ b/src/opus_demo.c @@ -617,9 +617,6 @@ int main(int argc, char *argv[]) goto failure; } } -#ifdef USE_WEIGHTS_FILE - opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); -#endif switch(bandwidth) { case OPUS_BANDWIDTH_NARROWBAND: @@ -684,6 +681,11 @@ int main(int argc, char *argv[]) } dred_dec = opus_dred_decoder_create(&err); dred = opus_dred_alloc(&err); +#ifdef USE_WEIGHTS_FILE + opus_encoder_ctl(enc, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); + opus_dred_decoder_ctl(dred_dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); +#endif while (!stop) { if (delayed_celt) diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 3738c84881414f92a7f955c269278ccf5710af5c..f6d3bc5855420ecfe190a7e415794bfc6286bdba 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -2847,6 +2847,19 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...) } } break; +#ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return dred_encoder_load_model(&st->dred_encoder, data, len); + } + break; +#endif case CELT_GET_MODE_REQUEST: { const CELTMode ** value = va_arg(ap, const CELTMode**);