diff --git a/dnn/include/lpcnet.h b/dnn/include/lpcnet.h index 7db3776fcdc4aba12c4dbb03bc04fbfd350d8509..65d08fb85af9f5274cebfa8f3418d85476a2f346 100644 --- a/dnn/include/lpcnet.h +++ b/dnn/include/lpcnet.h @@ -176,11 +176,18 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *st); */ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *st, const float *features, short *output, int N); + +#define LPCNET_PLC_CAUSAL 0 +#define LPCNET_PLC_NONCAUSAL 1 +#define LPCNET_PLC_CODEC 2 + +#define LPCNET_PLC_DC_FILTER 4 + LPCNET_EXPORT int lpcnet_plc_get_size(void); -LPCNET_EXPORT void lpcnet_plc_init(LPCNetPLCState *st); +LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options); -LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(void); +LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options); LPCNET_EXPORT void lpcnet_plc_destroy(LPCNetPLCState *st); diff --git a/dnn/lpcnet_demo.c b/dnn/lpcnet_demo.c index 0797079124f18765516bb1f0703ed039b0838d5e..85009529d29ee6e3b48dc643d07c351cdc9c9acf 100644 --- a/dnn/lpcnet_demo.c +++ b/dnn/lpcnet_demo.c @@ -40,45 +40,70 @@ #define MODE_SYNTHESIS 3 #define MODE_PLC 4 +void usage(void) { + fprintf(stderr, "usage: lpcnet_demo -encode <input.pcm> <compressed.lpcnet>\n"); + fprintf(stderr, " lpcnet_demo -decode <compressed.lpcnet> <output.pcm>\n"); + fprintf(stderr, " lpcnet_demo -features <input.pcm> <features.f32>\n"); + fprintf(stderr, " lpcnet_demo -synthesis <features.f32> <output.pcm>\n"); + fprintf(stderr, " lpcnet_demo -plc <plc_options> <percent> <input.pcm> <output.pcm>\n"); + fprintf(stderr, " lpcnet_demo -plc_file <plc_options> <percent> <input.pcm> <output.pcm>\n\n"); + fprintf(stderr, " plc_options:\n"); + fprintf(stderr, " causal: normal (causal) PLC\n"); + fprintf(stderr, " causal_dc: normal (causal) PLC with DC offset compensation\n"); + fprintf(stderr, " noncausal: non-causal PLC\n"); + fprintf(stderr, " noncausal_dc: non-causal PLC with DC offset compensation\n"); + exit(1); +} + int main(int argc, char **argv) { int mode; int plc_percent=0; FILE *fin, *fout; FILE *plc_file = NULL; - if (argc != 4 && !(argc == 5 && (strcmp(argv[1], "-plc") == 0 || strcmp(argv[1], "-plc_file") == 0))) - { - fprintf(stderr, "usage: lpcnet_demo -encode <input.pcm> <compressed.lpcnet>\n"); - fprintf(stderr, " lpcnet_demo -decode <compressed.lpcnet> <output.pcm>\n"); - fprintf(stderr, " lpcnet_demo -features <input.pcm> <features.f32>\n"); - fprintf(stderr, " lpcnet_demo -synthesis <features.f32> <output.pcm>\n"); - fprintf(stderr, " lpcnet_demo -plc <percent> <input.pcm> <output.pcm>\n"); - return 0; - } + const char *plc_options; + int plc_flags=-1; + if (argc < 4) usage(); if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE; else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE; else if (strcmp(argv[1], "-features") == 0) mode=MODE_FEATURES; else if (strcmp(argv[1], "-synthesis") == 0) mode=MODE_SYNTHESIS; else if (strcmp(argv[1], "-plc") == 0) { mode=MODE_PLC; - plc_percent = atoi(argv[2]); - argv++; + plc_options = argv[2]; + plc_percent = atoi(argv[3]); + argv+=2; + argc-=2; } else if (strcmp(argv[1], "-plc_file") == 0) { mode=MODE_PLC; - plc_file = fopen(argv[2], "r"); - argv++; + plc_options = argv[2]; + plc_file = fopen(argv[3], "r"); + if (!plc_file) { + fprintf(stderr, "Can't open %s\n", argv[3]); + exit(1); + } + argv+=2; + argc-=2; } else { - exit(1); + usage(); + } + if (mode == MODE_PLC) { + if (strcmp(plc_options, "causal")==0) plc_flags = LPCNET_PLC_CAUSAL; + else if (strcmp(plc_options, "causal_dc")==0) plc_flags = LPCNET_PLC_CAUSAL | LPCNET_PLC_DC_FILTER; + else if (strcmp(plc_options, "noncausal")==0) plc_flags = LPCNET_PLC_NONCAUSAL; + else if (strcmp(plc_options, "noncausal_dc")==0) plc_flags = LPCNET_PLC_NONCAUSAL | LPCNET_PLC_DC_FILTER; + else usage(); } + if (argc != 4) usage(); fin = fopen(argv[2], "rb"); if (fin == NULL) { - fprintf(stderr, "Can't open %s\n", argv[2]); - exit(1); + fprintf(stderr, "Can't open %s\n", argv[2]); + exit(1); } fout = fopen(argv[3], "wb"); if (fout == NULL) { - fprintf(stderr, "Can't open %s\n", argv[3]); - exit(1); + fprintf(stderr, "Can't open %s\n", argv[3]); + exit(1); } if (mode == MODE_ENCODE) { @@ -140,7 +165,7 @@ int main(int argc, char **argv) { int count=0; int loss=0; LPCNetPLCState *net; - net = lpcnet_plc_create(); + net = lpcnet_plc_create(plc_flags); while (1) { size_t ret; ret = fread(pcm, sizeof(pcm[0]), FRAME_SIZE, fin); diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 77507ba459a587f0c6b0dc58ec7ae3c43a8b3537..9b2680ef1d4b0ce400df1d01ef3830a5b2be76db 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -36,7 +36,7 @@ LPCNET_EXPORT int lpcnet_plc_get_size() { return sizeof(LPCNetPLCState); } -LPCNET_EXPORT void lpcnet_plc_init(LPCNetPLCState *st) { +LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { RNN_CLEAR(st, 1); lpcnet_init(&st->lpcnet); lpcnet_encoder_init(&st->enc); @@ -45,16 +45,28 @@ LPCNET_EXPORT void lpcnet_plc_init(LPCNetPLCState *st) { st->skip_analysis = 0; st->blend = 0; st->loss_count = 0; - st->enable_blending = 1; st->dc_mem = 0; - st->remove_dc = 1; st->queued_update = 0; + if ((options&0x3) == LPCNET_PLC_CAUSAL) { + st->enable_blending = 1; + st->non_causal = 0; + } else if ((options&0x3) == LPCNET_PLC_NONCAUSAL) { + st->enable_blending = 1; + st->non_causal = 1; + } else if ((options&0x3) == LPCNET_PLC_CODEC) { + st->enable_blending = 0; + st->non_causal = 0; + } else { + return -1; + } + st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER); + return 0; } -LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create() { +LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) { LPCNetPLCState *st; st = calloc(sizeof(*st), 1); - lpcnet_plc_init(st); + lpcnet_plc_init(st, options); return st; } @@ -81,12 +93,10 @@ void clear_state(LPCNetPLCState *st) { #define DC_CONST 0.003 -#if 1 - /* In this causal version of the code, the DNN model implemented by compute_plc_pred() needs to generate two feature vectors to conceal the first lost packet.*/ -LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { +static int lpcnet_plc_update_causal(LPCNetPLCState *st, short *pcm) { int i; float x[FRAME_SIZE]; short output[FRAME_SIZE]; @@ -168,7 +178,7 @@ LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { } static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6}; -LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { +static int lpcnet_plc_conceal_causal(LPCNetPLCState *st, short *pcm) { int i; short output[FRAME_SIZE]; float zeros[2*NB_BANDS+NB_FEATURES+1] = {0}; @@ -212,8 +222,6 @@ LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { return 0; } -#else - /* In this non-causal version of the code, the DNN model implemented by compute_plc_pred() is always called once per frame. We process audio up to the current position minus TRAINING_OFFSET. */ @@ -224,7 +232,7 @@ void process_queued_update(LPCNetPLCState *st) { } } -LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { +static int lpcnet_plc_update_non_causal(LPCNetPLCState *st, short *pcm) { int i; float x[FRAME_SIZE]; short pcm_save[FRAME_SIZE]; @@ -320,8 +328,7 @@ LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { return 0; } -static const float att_table[10] = {0, 0, -.2, -.2, -.4, -.4, -.8, -.8, -1.6, -1.6}; -LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { +static int lpcnet_plc_conceal_non_causal(LPCNetPLCState *st, short *pcm) { int i; float x[FRAME_SIZE]; float zeros[2*NB_BANDS+NB_FEATURES+1] = {0}; @@ -364,4 +371,13 @@ LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { return 0; } -#endif + +LPCNET_EXPORT int lpcnet_plc_update(LPCNetPLCState *st, short *pcm) { + if (st->non_causal) return lpcnet_plc_update_non_causal(st, pcm); + else return lpcnet_plc_update_causal(st, pcm); +} + +LPCNET_EXPORT int lpcnet_plc_conceal(LPCNetPLCState *st, short *pcm) { + if (st->non_causal) return lpcnet_plc_conceal_non_causal(st, pcm); + else return lpcnet_plc_conceal_causal(st, pcm); +} diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index ea06f006aa608cb13943d9bc79ecb39ee8145641..d6c9b615ce3bee2ecb501e67ca3e319aa8f96ced 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -79,6 +79,7 @@ struct LPCNetPLCState { int loss_count; PLCNetState plc_net; int enable_blending; + int non_causal; double dc_mem; double syn_dc; int remove_dc;