diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 2f5c2511ac06af3038eecbb39fa9e13f4473f1e9..8b90b16df89d76990b51b1dfa47ba194ac0afe0d 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -176,6 +176,14 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *lpcnet) free(lpcnet); } +void lpcnet_reset_signal(LPCNetState *lpcnet) +{ + lpcnet->deemph_mem = 0; + lpcnet->last_exc = lin2ulaw(0.f); + RNN_CLEAR(lpcnet->last_sig, LPC_ORDER); + RNN_CLEAR(lpcnet->nnet.gru_a_state, GRU_A_STATE_SIZE); + RNN_CLEAR(lpcnet->nnet.gru_b_state, GRU_B_STATE_SIZE); +} void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload) { diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 6dfb50534af8fb4a1bbcc969c37367757b7242ca..3fcf431466d4dd00234b5702579b6f163012fa6e 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -32,6 +32,9 @@ #include "lpcnet.h" #include "plc_data.h" +/* Comment this out to have LPCNet update its state on every good packet (slow). */ +#define PLC_SKIP_UPDATES + LPCNET_EXPORT int lpcnet_plc_get_size() { return sizeof(LPCNetPLCState); } @@ -200,8 +203,12 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { } else { if (FEATURES_DELAY > 0) st->plc_net = st->plc_copy[FEATURES_DELAY-1]; fec_rewind(st, FEATURES_DELAY); +#ifdef PLC_SKIP_UPDATES + lpcnet_reset_signal(&st->lpcnet); +#else RNN_COPY(tmp, pcm, FRAME_SIZE-TRAINING_OFFSET); lpcnet_synthesize_tail_impl(&st->lpcnet, tmp, FRAME_SIZE-TRAINING_OFFSET, FRAME_SIZE-TRAINING_OFFSET); +#endif } RNN_COPY(st->pcm, &pcm[FRAME_SIZE-TRAINING_OFFSET], TRAINING_OFFSET); st->pcm_fill = TRAINING_OFFSET; @@ -237,7 +244,16 @@ static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { } else { for (i=0;i<FRAME_SIZE;i++) st->pcm[PLC_BUF_SIZE+i] = pcm[i]; RNN_COPY(output, &st->pcm[0], FRAME_SIZE); +#ifdef PLC_SKIP_UPDATES + { + float lpc[LPC_ORDER]; + float gru_a_condition[3*GRU_A_STATE_SIZE]; + float gru_b_condition[3*GRU_B_STATE_SIZE]; + run_frame_network(&st->lpcnet, gru_a_condition, gru_b_condition, lpc, st->enc.features[0]); + } +#else lpcnet_synthesize_impl(&st->lpcnet, st->enc.features[0], output, FRAME_SIZE, FRAME_SIZE); +#endif RNN_MOVE(st->pcm, &st->pcm[FRAME_SIZE], PLC_BUF_SIZE); } st->loss_count = 0; diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 966ef9b32096a634c1d17579c9ec2d784ea2b2c2..1d2936d69f389de8eaae5c3b7ed7d702887d75f1 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -111,6 +111,7 @@ void compute_frame_features(LPCNetEncState *st, const float *in); void decode_packet(float features[4][NB_TOTAL_FEATURES], float *vq_mem, const unsigned char buf[8]); +void lpcnet_reset_signal(LPCNetState *lpcnet); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); void lpcnet_synthesize_tail_impl(LPCNetState *lpcnet, short *output, int N, int preload); void lpcnet_synthesize_impl(LPCNetState *lpcnet, const float *features, short *output, int N, int preload);