diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c index c0c997f0a8d2bd4bed32c71c34e668bcbbbed2dd..ac1b4fedc129d05e4e0a8b436f16257aec046e77 100644 --- a/celt/celt_decoder.c +++ b/celt/celt_decoder.c @@ -721,7 +721,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM if (loss_duration == 0) { #ifdef ENABLE_DEEP_PLC - update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C); + if (lpcnet->loaded) update_plc_state(lpcnet, decode_mem, &st->plc_preemphasis_mem, C); #endif st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch); } else { @@ -914,7 +914,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM } while (++c<C); #ifdef ENABLE_DEEP_PLC - if (st->complexity >= 5 || lpcnet->fec_fill_pos > 0) { + if (lpcnet->loaded && (st->complexity >= 5 || lpcnet->fec_fill_pos > 0)) { float overlap_mem; int samples_needed16k; celt_sig *buf; diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index a124cdd001c23d5f9505ba9d4a45eb3ae3935d26..de3ab1a793d15699c28a5c45c48bb39088ee28c9 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -57,8 +57,10 @@ int lpcnet_plc_init(LPCNetPLCState *st) { fargan_init(&st->fargan); lpcnet_encoder_init(&st->enc); st->analysis_pos = PLC_BUF_SIZE; + st->loaded = 0; #ifndef USE_WEIGHTS_FILE ret = init_plc_model(&st->model, lpcnet_plc_arrays); + if (ret == 0) st->loaded = 1; #else ret = 0; #endif @@ -75,11 +77,12 @@ int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len free(list); if (ret == 0) { ret = lpcnet_encoder_load_model(&st->enc, data, len); - } else return -1; + } if (ret == 0) { - return fargan_load_model(&st->fargan, data, len); + ret = fargan_load_model(&st->fargan, data, len); } - else return -1; + if (ret == 0) st->loaded = 1; + return ret; } void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) { @@ -105,6 +108,7 @@ static void compute_plc_pred(LPCNetPLCState *st, float *out, const float *in) { float zeros[3*PLC_MAX_RNN_NEURONS] = {0}; float dense_out[PLC_DENSE1_OUT_SIZE]; PLCNetState *net = &st->plc_net; + celt_assert(st->loaded); _lpcnet_compute_dense(&st->model.plc_dense1, dense_out, in); compute_gruB(&st->model.plc_gru1, zeros, net->plc_gru1_state, dense_out); compute_gruB(&st->model.plc_gru2, zeros, net->plc_gru2_state, net->plc_gru1_state); @@ -152,6 +156,7 @@ int lpcnet_plc_update(LPCNetPLCState *st, opus_int16 *pcm) { static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6}; int lpcnet_plc_conceal(LPCNetPLCState *st, opus_int16 *pcm) { int i; + celt_assert(st->loaded); if (st->blend == 0) { int count = 0; while (st->analysis_pos + FRAME_SIZE <= PLC_BUF_SIZE) { diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 4f328ad22e87319496b6c2057537feae977aabaa..30931b1d01cb7118c574fb5a9f5710274d826e4f 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -47,6 +47,7 @@ struct LPCNetPLCState { PLCModel model; FARGANState fargan; LPCNetEncState enc; + int loaded; int arch; #define LPCNET_PLC_RESET_START fec diff --git a/silk/PLC.c b/silk/PLC.c index 1e5248232148d0ad81085d8ef4f97142f922ee36..b35bf750a0a891198b2b8ace3d7ff253771b8613 100644 --- a/silk/PLC.c +++ b/silk/PLC.c @@ -397,7 +397,7 @@ static OPUS_INLINE void silk_PLC_conceal( frame[ i ] = (opus_int16)silk_SAT16( silk_SAT16( silk_RSHIFT_ROUND( silk_SMULWW( sLPC_Q14_ptr[ MAX_LPC_ORDER + i ], prevGain_Q10[ 1 ] ), 8 ) ) ); } #ifdef ENABLE_DEEP_PLC - if ( lpcnet != NULL && psDec->sPLC.fs_kHz == 16 ) { + if ( lpcnet != NULL && lpcnet->loaded && psDec->sPLC.fs_kHz == 16 ) { int run_deep_plc = psDec->sPLC.enable_deep_plc || lpcnet->fec_fill_pos != 0; if( run_deep_plc ) { for( k = 0; k < psDec->nb_subfr; k += 2 ) { diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c index af7f9d94dfffff213b34b6b7bbe0caeac89d0ae5..9b005a63ffda01afdf77fd4a4943ca48cee5237e 100644 --- a/silk/dred_encoder.c +++ b/silk/dred_encoder.c @@ -57,6 +57,7 @@ int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len) if (ret == 0) { ret = lpcnet_encoder_load_model(&enc->lpcnet_enc_state, data, len); } + if (ret == 0) enc->loaded = 1; return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; } @@ -74,8 +75,9 @@ void dred_encoder_init(DREDEnc* enc, opus_int32 Fs, int channels) { enc->Fs = Fs; enc->channels = channels; + enc->loaded = 0; #ifndef USE_WEIGHTS_FILE - init_rdovaeenc(&enc->model, rdovaeenc_arrays); + if (init_rdovaeenc(&enc->model, rdovaeenc_arrays) == 0) enc->loaded = 1; #endif dred_encoder_reset(enc); } @@ -85,6 +87,7 @@ static void dred_process_frame(DREDEnc *enc) float feature_buffer[2 * 36]; float input_buffer[2*DRED_NUM_FEATURES] = {0}; + celt_assert(enc->loaded); /* shift latents buffer */ OPUS_MOVE(enc->latents_buffer + DRED_LATENT_DIM, enc->latents_buffer, (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM); @@ -184,6 +187,7 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex { int curr_offset16k; int frame_size16k = frame_size * 16000 / enc->Fs; + celt_assert(enc->loaded); curr_offset16k = 40 + extra_delay*16000/enc->Fs - enc->input_buffer_fill; enc->dred_offset = (int)floor((curr_offset16k+20.f)/40.f); enc->latent_offset = 0; diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h index 2b77d581ccbf0bcbb5b2fac5239811b76ce0e3d2..abeaac7f97bf8fd299540effa09004260a141e94 100644 --- a/silk/dred_encoder.h +++ b/silk/dred_encoder.h @@ -40,6 +40,9 @@ typedef struct { RDOVAEEnc model; + LPCNetEncState lpcnet_enc_state; + RDOVAEEncState rdovae_enc; + int loaded; opus_int32 Fs; int channels; @@ -53,8 +56,6 @@ typedef struct { float state_buffer[DRED_STATE_DIM]; float initial_state[DRED_STATE_DIM]; float resample_mem[RESAMPLING_ORDER + 1]; - LPCNetEncState lpcnet_enc_state; - RDOVAEEncState rdovae_enc; } DREDEnc; int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len); diff --git a/src/opus_decoder.c b/src/opus_decoder.c index 67b1cfd39a50d88d0896064747d581bba979b04a..999c6fe041d745f1bd8c7caef614080b46785f3b 100644 --- a/src/opus_decoder.c +++ b/src/opus_decoder.c @@ -1042,7 +1042,7 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...) { goto bad_arg; } - return lpcnet_plc_load_model(&st->lpcnet, data, len); + ret = lpcnet_plc_load_model(&st->lpcnet, data, len); } break; #endif @@ -1156,6 +1156,7 @@ struct OpusDREDDecoder { #ifdef ENABLE_DRED RDOVAEDec model; #endif + int loaded; int arch; opus_uint32 magic; }; @@ -1188,19 +1189,23 @@ int dred_decoder_load_model(OpusDREDDecoder *dec, const unsigned char *data, int parse_weights(&list, data, len); ret = init_rdovaedec(&dec->model, list); free(list); + if (ret == 0) dec->loaded = 1; return (ret == 0) ? OPUS_OK : OPUS_BAD_ARG; } #endif int opus_dred_decoder_init(OpusDREDDecoder *dec) { + int ret = 0; + dec->loaded = 0; #if defined(ENABLE_DRED) && !defined(USE_WEIGHTS_FILE) - init_rdovaedec(&dec->model, rdovaedec_arrays); + ret = init_rdovaedec(&dec->model, rdovaedec_arrays); + if (ret == 0) dec->loaded = 1; #endif dec->arch = opus_select_arch(); /* To make sure nobody forgets to init, use a magic number. */ dec->magic = 0xD8EDDEC0; - return OPUS_OK; + return (ret == 0) ? OPUS_OK : OPUS_UNIMPLEMENTED; } OpusDREDDecoder *opus_dred_decoder_create(int *error) @@ -1378,6 +1383,7 @@ int opus_dred_parse(OpusDREDDecoder *dred_dec, OpusDRED *dred, const unsigned ch const unsigned char *payload; opus_int32 payload_len; VALIDATE_DRED_DECODER(dred_dec); + if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED; dred->process_stage = -1; payload_len = dred_find_payload(data, len, &payload); if (payload_len < 0) @@ -1412,6 +1418,7 @@ int opus_dred_process(OpusDREDDecoder *dred_dec, const OpusDRED *src, OpusDRED * if (dred_dec == NULL || src == NULL || dst == NULL || (src->process_stage != 1 && src->process_stage != 2)) return OPUS_BAD_ARG; VALIDATE_DRED_DECODER(dred_dec); + if (!dred_dec->loaded) return OPUS_UNIMPLEMENTED; if (src != dst) OPUS_COPY(dst, src, 1); if (dst->process_stage == 2) diff --git a/src/opus_encoder.c b/src/opus_encoder.c index 5ed4b1870060c85c898c7aef574e421bdf32ff54..27b3196a381250009ff1607d74d3263aa2602842 100644 --- a/src/opus_encoder.c +++ b/src/opus_encoder.c @@ -1713,7 +1713,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_ #endif #ifdef ENABLE_DRED - if ( st->dred_duration > 0 ) { + if ( st->dred_duration > 0 && st->dred_encoder.loaded ) { /* DRED Encoder */ dred_compute_latents( &st->dred_encoder, &pcm_buf[total_buffer*st->channels], frame_size, total_buffer ); } else { @@ -2255,7 +2255,7 @@ opus_int32 opus_encode_native(OpusEncoder *st, const opus_val16 *pcm, int frame_ ret += 1+redundancy_bytes; apply_padding = !st->use_vbr; #ifdef ENABLE_DRED - if (st->dred_duration > 0) { + if (st->dred_duration > 0 && st->dred_encoder.loaded) { opus_extension_data extension; unsigned char buf[DRED_MAX_DATA_SIZE]; int dred_chunks; @@ -2893,17 +2893,17 @@ int opus_encoder_ctl(OpusEncoder *st, int request, ...) } break; #ifdef USE_WEIGHTS_FILE - case OPUS_SET_DNN_BLOB_REQUEST: - { - const unsigned char *data = va_arg(ap, const unsigned char *); - opus_int32 len = va_arg(ap, opus_int32); - if(len<0 || data == NULL) - { - goto bad_arg; - } - return dred_encoder_load_model(&st->dred_encoder, data, len); - } - break; + case OPUS_SET_DNN_BLOB_REQUEST: + { + const unsigned char *data = va_arg(ap, const unsigned char *); + opus_int32 len = va_arg(ap, opus_int32); + if(len<0 || data == NULL) + { + goto bad_arg; + } + ret = dred_encoder_load_model(&st->dred_encoder, data, len); + } + break; #endif case CELT_GET_MODE_REQUEST: {