From feb32828877ea5e8723ea2a446eb20d7b3fba426 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin <jmvalin@amazon.com> Date: Mon, 30 Oct 2023 14:08:07 -0400 Subject: [PATCH] Don't try to use models that aren't loaded --- celt/celt_decoder.c | 4 ++-- dnn/lpcnet_plc.c | 11 ++++++++--- dnn/lpcnet_private.h | 1 + silk/PLC.c | 2 +- silk/dred_encoder.c | 6 +++++- silk/dred_encoder.h | 5 +++-- src/opus_decoder.c | 13 ++++++++++--- src/opus_encoder.c | 26 +++++++++++++------------- 8 files changed, 43 insertions(+), 25 deletions(-) diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c index c0c997f0a..ac1b4fedc 100644 --- a/celt/celt_decoder.c +++ b/celt/celt_decoder.c @@ -721,7 +721,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM if (loss_duration == 0) { #ifdef ENABLE_DEEP_PLC - update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C); + if (lpcnet->loaded) update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C); #endif st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch); } else { @@ -914,7 +914,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM } while (++c<C); #ifdef ENABLE_DEEP_PLC - if (st->complexity >= 5 || lpcnet->fec_fill_pos > 0) { + if (lpcnet->loaded && (st->complexity >= 5 || lpcnet->fec_fill_pos > 0)) { float overlap_mem; int samples_needed16k; celt_sig *buf; diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index a124cdd00..de3ab1a79 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -57,8 +57,10 @@ int lpcnet_plc_init(LPCNetPLCState *st) { fargan_init(&st->fargan); lpcnet_encoder_init(&st->enc); st->analysis_pos = PLC_BUF_SIZE; + st->loaded = 0; #ifndef USE_WEIGHTS_FILE ret = init_plc_model(&st->model, lpcnet_plc_arrays); + if (ret == 0) st->loaded = 1; #else ret = 0; #endif @@ -75,11 +77,12 @@ int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len free(list); if (ret == 0) { ret = lpcnet_encoder_load_model(&st->enc, data, len); - } else return -1; + } if (ret == 0) { - return fargan_load_model(&st->fargan, data, len); + ret = fargan_load_model(&st->fargan, data, len); } - else return -1; + if (ret == 0) st->loaded = 1; + return ret; } void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) { @@ -105,6 +108,7 @@ static void compute_plc_pred(LPCNetPLCState *st, float *out, const float *in) { float zeros[3*PLC_MAX_RNN_NEURONS] = {0}; float dense_out[PLC_DENSE1_OUT_SIZE]; PLCNetState *net = &st->plc_net; + celt_assert(st->loaded); _lpcnet_compute_dense(&st->model.plc_dense1, dense_out, in); compute_gruB(&st->model.plc_gru1, zeros, net->plc_gru1_state, dense_out); compute_gruB(&st->model.plc_gru2, zeros, net->plc_gru2_state, net->plc_gru1_state); @@ -152,6 +156,7 @@ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) { static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6}; int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) { int i; + celt_assert(st->loaded); if (st->blend == 0) { int count = 0; while (st->analysis_pos + FRAME_SIZE <= PLC_BUF_SIZE) { diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 4f328ad22..30931b1d0 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -47,6 +47,7 @@ struct LPCNetPLCState { PLCModel model; FARGANState fargan; LPCNetEncState enc; + int loaded; int arch; #define LPCNET_PLC_RESET_START fec diff --git a/silk/PLC.c b/silk/PLC.c index 1e5248232..b35bf750a 100644 --- a/silk/PLC.c +++ b/silk/PLC.c @@ -397,7 +397,7 @@ static OPUS_INLINE void silk_PLC_conceal( frame[ i ] = (opus_int16)silk_SAT16( silk_SAT16( silk_RSHIFT_ROUND( silk_SMULWW( sLPC_Q14_ptr[ MAX_LPC_ORDER + i ], prevGain_Q10[ 1 ] ), 8 ) ) ); } #ifdef ENABLE_DEEP_PLC - if ( lpcnet != NULL && psDec->sPLC.fs_kHz == 16 ) { + if ( lpcnet != NULL && lpcnet->loaded && psDec->sPLC.fs_kHz == 16 ) { int run_deep_plc = psDec->sPLC.enable_deep_plc || lpcnet->fec_fill_pos != 0; if( run_deep_plc ) { for( k = 0; k < psDec->nb_subfr; k += 2 ) { diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index af7f9d94d..9b005a63f 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -57,6 +57,7 @@ int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len) if (ret == 0) { ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len); } + if (ret == 0) enc->loaded = 1; return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; } @@ -74,8 +75,9 @@ void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels) { enc->Fs = Fs; enc->channels = channels; + enc->loaded = 0; #ifndef USE_WEIGHTS_FILE - init_rdovaeenc(&enc->model, rdovaeenc_arrays); + if (init_rdovaeenc(&enc->model, rdovaeenc_arrays) == 0) enc->loaded = 1; #endif dred_encoder_reset(enc); } @@ -85,6 +87,7 @@ static void dred_process_frame(DREDEnc *enc) float feature_buffer[2 * 36]; float input_buffer[2*DRED_NUM_FEATURES] = {0}; + celt_assert(enc->loaded); /* shift latents buffer */ OPUS_MOVE(enc->latents_buffer + DRED_LATENT_DIM, enc->latents_buffer, (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM); @@ -184,6 +187,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex { int curr_offset16k; int frame_size16k = frame_size * 16000 / enc->Fs; + celt_assert(enc->loaded); curr_offset16k = 40 + extra_delay*16000/enc->Fs - enc->input_buffer_fill; enc->dred_offset = (int)floor((curr_offset16k+20.f)/40.f); enc->latent_offset = 0; diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h index 2b77d581c..abeaac7f9 100644 --- a/silk/dred_encoder.h +++ b/silk/dred_encoder.h @@ -40,6 +40,9 @@ typedef struct { RDOVAEEnc model; + LPCNetEncState lpcnet_enc_state; + RDOVAEEncState rdovae_enc; + int loaded; opus_int32 Fs; int channels; @@ -53,8 +56,6 @@ typedef struct { float state_buffer[DRED_STATE_DIM]; float initial_state[DRED_STATE_DIM]; float resample_mem[RESAMPLING_ORDER + 1]; - LPCNetEncState lpcnet_enc_state; - RDOVAEEncState rdovae_enc; } DREDEnc; int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len); diff --git a/src/opus_decoder.c b/src/opus_decoder.c index 67b1cfd39..999c6fe04 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1042,7 +1042,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...) { goto bad_arg; } - return lpcnet_plc_load_model(&st->lpcnet, data, len); + ret = lpcnet_plc_load_model(&st->lpcnet, data, len); } break; #endif @@ -1156,6 +1156,7 @@ struct OpusDREDDecoder { #ifdef ENABLE_DRED RDOVAEDec model; #endif + int loaded; int arch; opus_uint32 magic; }; @@ -1188,19 +1189,23 @@ int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int parse_weights(&list, data, len); ret = init_rdovaedec(&dec->model, list); free(list); + if (ret == 0) dec->loaded = 1; return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; } #endif int opus_dred_decoder_init(OpusDREDDecoder *dec) { + int ret = 0; + dec->loaded = 0; #if defined(ENABLE_DRED) && !defined(USE_WEIGHTS_FILE) - init_rdovaedec(&dec->model, rdovaedec_arrays); + ret = init_rdovaedec(&dec->model, rdovaedec_arrays); + if (ret == 0) dec->loaded = 1; #endif dec->arch = opus_select_arch(); /* To make sure nobody forgets to init, use a magic number. */ dec->magic = 0xD8EDDEC0; - return OPUS_OK; + return (ret == 0) ? OPUS_OK : OPUS_UNIMPLEMENTED; } OpusDREDDecoder *opus_dred_decoder_create(int *error) @@ -1378,6 +1383,7 @@ int opus_dred_parse(OpusDREDDecoder *dred_dec, OpusDRED *dred, const unsigned ch const unsigned char *payload; opus_int32 payload_len; VALIDATE_DRED_DECODER(dred_dec); + if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED; dred->process_stage = -1; payload_len = dred_find_payload(data, len, &payload); if (payload_len < 0) @@ -1412,6 +1418,7 @@ int opus_dred_process(OpusDREDDecoder *dred_dec, const OpusDRED *src, OpusDRED * if (dred_dec == NULL || src == NULL || dst == NULL || (src->process_stage != 1 && src->process_stage != 2)) return OPUS_BAD_ARG; VALIDATE_DRED_DECODER(dred_dec); + if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED; if (src != dst) OPUS_COPY(dst, src, 1); if (dst->process_stage == 2) diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 5ed4b1870..27b3196a3 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -1713,7 +1713,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_ #endif #ifdef ENABLE_DRED - if ( st->dred_duration > 0 ) { + if ( st->dred_duration > 0 && st->dred_encoder.loaded ) { /* DRED Encoder */ dred_compute_latents( &st->dred_encoder, &pcm_buf[total_buffer*st->channels], frame_size, total_buffer ); } else { @@ -2255,7 +2255,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_ ret += 1+redundancy_bytes; apply_padding = !st->use_vbr; #ifdef ENABLE_DRED - if (st->dred_duration > 0) { + if (st->dred_duration > 0 && st->dred_encoder.loaded) { opus_extension_data extension; unsigned char buf[DRED_MAX_DATA_SIZE]; int dred_chunks; @@ -2893,17 +2893,17 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...) } break; #ifdef USE_WEIGHTS_FILE - case OPUS_SET_DNN_BLOB_REQUEST: - { - const unsigned char *data = va_arg(ap, const unsigned char *); - opus_int32 len = va_arg(ap, opus_int32); - if(len<0 || data == NULL) - { - goto bad_arg; - } - return dred_encoder_load_model(&st->dred_encoder, data, len); - } - break; + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + ret = dred_encoder_load_model(&st->dred_encoder, data, len); + } + break; #endif case CELT_GET_MODE_REQUEST: { -- GitLab