diff --git a/include/opus_defines.h b/include/opus_defines.h index 3ca90188965a129209165f185404850524ca54bc..b1ead564a6ef7f256a8a14d65619fba9b0bb51cf 100644 --- a/include/opus_defines.h +++ b/include/opus_defines.h @@ -171,6 +171,8 @@ extern "C" { #define OPUS_GET_IN_DTX_REQUEST 4049 #define OPUS_SET_DRED_DURATION_REQUEST 4050 #define OPUS_GET_DRED_DURATION_REQUEST 4051 +#define OPUS_SET_DNN_BLOB_REQUEST 4052 +/*#define OPUS_GET_DNN_BLOB_REQUEST 4053 */ /** Defines for the presence of extended APIs. */ #define OPUS_HAVE_OPUS_PROJECTION_H @@ -179,6 +181,7 @@ extern "C" { #define __opus_check_int(x) (((void)((x) == (opus_int32)0)), (opus_int32)(x)) #define __opus_check_int_ptr(ptr) ((ptr) + ((ptr) - (opus_int32*)(ptr))) #define __opus_check_uint_ptr(ptr) ((ptr) + ((ptr) - (opus_uint32*)(ptr))) +#define __opus_check_uint8_ptr(ptr) ((ptr) + ((ptr) - (opus_uint8*)(ptr))) #define __opus_check_val16_ptr(ptr) ((ptr) + ((ptr) - (opus_val16*)(ptr))) /** @endcond */ @@ -629,6 +632,10 @@ extern "C" { * @hideinitializer */ #define OPUS_GET_DRED_DURATION(x) OPUS_GET_DRED_DURATION_REQUEST, __opus_check_int_ptr(x) +/** Provide external DNN weights from binary object (only when explicitly built without the weights) + * @hideinitializer */ +#define OPUS_SET_DNN_BLOB(data, len) OPUS_SET_DNN_BLOB_REQUEST, __opus_check_uint8_ptr(data), __opus_check_int(len) + /**@}*/ diff --git a/src/opus_decoder.c b/src/opus_decoder.c index 9774f8907eb95da72eb4ab77770895a5214765b0..7df39c4fdff42586d6aa16b5b3b62bbbcbf801dd 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -995,6 +995,19 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...) ret = celt_decoder_ctl(celt_dec, OPUS_GET_PHASE_INVERSION_DISABLED(value)); } break; +#ifdef USE_WEIGHTS_FILE + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + return lpcnet_plc_load_model(&st->lpcnet, data, len); + } + break; +#endif default: /*fprintf(stderr, "unknown opus_decoder_ctl() request: %d", request);*/ ret = OPUS_UNIMPLEMENTED; diff --git a/src/opus_demo.c b/src/opus_demo.c index f1417f7e6203cea1062bba605765c6901b184a20..7594c5cc377dbb9cd4087fee00f16fe79f7a129a 100644 --- a/src/opus_demo.c +++ b/src/opus_demo.c @@ -42,6 +42,50 @@ #define MAX_PACKET 1500 +#ifdef USE_WEIGHTS_FILE +# if __unix__ +# include <fcntl.h> +# include <sys/mman.h> +# include <unistd.h> +# include <sys/stat.h> +/* When available, mmap() is preferable to reading the file, as it leads to + better resource utilization, especially if multiple processes are using the same + file (mapping will be shared in cache). */ +unsigned char *load_blob(const char *filename, int *len) { + int fd; + unsigned char *data; + struct stat st; + stat(filename, &st); + *len = st.st_size; + fd = open(filename, O_RDONLY); + data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0); + close(fd); + return data; +} +void free_blob(unsigned char *blob, int len) { + munmap(blob, len); +} +# else +unsigned char *load_blob(const char *filename, int *len) { + FILE *file; + unsigned char *data; + file = fopen(filename, "r"); + fseek(file, 0L, SEEK_END); + *len = ftell(file); + fseek(file, 0L, SEEK_SET); + if (*len <= 0) return NULL; + data = malloc(*len); + *len = fread(data, 1, *len, file); + return data; +} +void free_blob(unsigned char *blob, int len) { + free(blob); + (void)len; +} +# endif +#endif + + void print_usage( char* argv[] ) { fprintf(stderr, "Usage: %s [-e] <application> <sampling rate (Hz)> <channels (1/2)> " @@ -270,6 +314,12 @@ int main(int argc, char *argv[]) int lost_count=0; FILE *packet_loss_file=NULL; int dred_duration=0; +#ifdef USE_WEIGHTS_FILE + int blob_len; + unsigned char *blob_data; + const char *filename = "weights_blob.bin"; + blob_data = load_blob(filename, &blob_len); +#endif if (argc < 5 ) { @@ -567,8 +617,9 @@ int main(int argc, char *argv[]) goto failure; } } - - +#ifdef USE_WEIGHTS_FILE + opus_decoder_ctl(dec, OPUS_SET_DNN_BLOB(blob_data, blob_len)); +#endif switch(bandwidth) { case OPUS_BANDWIDTH_NARROWBAND: @@ -928,5 +979,8 @@ failure: free(in); free(out); free(fbytes); +#ifdef USE_WEIGHTS_FILE + free_blob(blob_data, blob_len); +#endif return ret; }