diff --git a/dnn/Makefile.am b/dnn/Makefile.am index 29e1ed606f94895f62f1067f19c28f200195c461..c4b8b7681847a04c54b5bdaea81ea38ba82b1166 100644 --- a/dnn/Makefile.am +++ b/dnn/Makefile.am @@ -62,7 +62,7 @@ dump_data_SOURCES = common.c dump_data.c burg.c freq.c kiss_fft.c pitch.c lpcnet dump_data_LDADD = $(LIBM) dump_data_CFLAGS = $(AM_CFLAGS) -dump_weights_blob_SOURCES = nnet_data.c plc_data.c write_lpcnet_weights.c +dump_weights_blob_SOURCES = write_lpcnet_weights.c dump_weights_blob_LDADD = $(LIBM) dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS diff --git a/dnn/autogen.sh b/dnn/autogen.sh index c69a4e616103b242a250e714ef3af6f2cc2fa993..8ae205c69f2dd570a6e48d5ec906f356538298d0 100755 --- a/dnn/autogen.sh +++ b/dnn/autogen.sh @@ -6,7 +6,7 @@ srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" #SHA1 of the first commit compatible with the current model -commit=399be7c +commit=859bfae ./download_model.sh $commit echo "Updating build configuration files for lpcnet, please wait...." diff --git a/dnn/include/lpcnet.h b/dnn/include/lpcnet.h index e8718fcb3ab23062aa7756bba5fcb2ed33131b22..067a43288a027c2565e44198f52ec4e0ce8153c9 100644 --- a/dnn/include/lpcnet.h +++ b/dnn/include/lpcnet.h @@ -199,4 +199,7 @@ LPCNET_EXPORT void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st); +LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len); +LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len); + #endif diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 914f4476ab6955fb4c46ec15ba1689a4a9b698dd..455a5ed32e89bf3a8de11b7477a7d4b13197168a 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -183,11 +183,25 @@ LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet) lpcnet->sampling_logit_table[i] = -log((1-prob)/prob); } kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string)); +#ifndef USE_WEIGHTS_FILE ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays); +#else + ret = 0; +#endif celt_assert(ret == 0); return ret; } +LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len) { + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_lpcnet_model(&st->model, list); + free(list); + if (ret == 0) return 0; + else return -1; +} + LPCNET_EXPORT LPCNetState *lpcnet_create() { diff --git a/dnn/lpcnet_demo.c b/dnn/lpcnet_demo.c index 3fd6993e14003c1f89e07a52b1ff78ce62be534f..b88176f4d9d7647811414e757f9573513ae71766 100644 --- a/dnn/lpcnet_demo.c +++ b/dnn/lpcnet_demo.c @@ -34,6 +34,49 @@ #include "lpcnet.h" #include "freq.h" +#ifdef USE_WEIGHTS_FILE +# if __unix__ +# include <fcntl.h> +# include <sys/mman.h> +# include <unistd.h> +# include <sys/stat.h> +/* When available, mmap() is preferable to reading the file, as it leads to + better resource utilization, especially if multiple processes are using the same + file (mapping will be shared in cache). */ +unsigned char *load_blob(const char *filename, int *len) { + int fd; + unsigned char *data; + struct stat st; + stat(filename, &st); + *len = st.st_size; + fd = open(filename, O_RDONLY); + data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0); + close(fd); + return data; +} +void free_blob(unsigned char *blob, int len) { + munmap(blob, len); +} +# else +unsigned char *load_blob(const char *filename, int *len) { + FILE *file; + unsigned char *data; + file = fopen(filename, "r"); + fseek(file, 0L, SEEK_END); + *len = ftell(file); + fseek(file, 0L, SEEK_SET); + if (*len <= 0) return NULL; + data = malloc(*len); + *len = fread(data, 1, *len, file); + return data; +} +void free_blob(unsigned char *blob, int len) { + free(blob); + (void)len; +} +# endif +#endif + #define MODE_ENCODE 0 #define MODE_DECODE 1 #define MODE_FEATURES 2 @@ -64,6 +107,11 @@ int main(int argc, char **argv) { FILE *plc_file = NULL; const char *plc_options; int plc_flags=-1; +#ifdef USE_WEIGHTS_FILE + int len; + unsigned char *data; + const char *filename = "weights_blob.bin"; +#endif if (argc < 4) usage(); if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE; else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE; @@ -109,7 +157,9 @@ int main(int argc, char **argv) { fprintf(stderr, "Can't open %s\n", argv[3]); exit(1); } - +#ifdef USE_WEIGHTS_FILE + data = load_blob(filename, &len); +#endif if (mode == MODE_ENCODE) { LPCNetEncState *net; net = lpcnet_encoder_create(); @@ -152,6 +202,9 @@ int main(int argc, char **argv) { } else if (mode == MODE_SYNTHESIS) { LPCNetState *net; net = lpcnet_create(); +#ifdef USE_WEIGHTS_FILE + lpcnet_load_model(net, data, len); +#endif while (1) { float in_features[NB_TOTAL_FEATURES]; float features[NB_FEATURES]; @@ -207,5 +260,8 @@ int main(int argc, char **argv) { } fclose(fin); fclose(fout); +#ifdef USE_WEIGHTS_FILE + free_blob(data, len); +#endif return 0; } diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 2d7554ce1d1cf854a18366b1cb6844a40f242ca4..cd9d69201c5b604b539af027c9043e83ae421388 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -68,11 +68,25 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { return -1; } st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER); +#ifndef USE_WEIGHTS_FILE ret = init_plc_model(&st->model, lpcnet_plc_arrays); +#else + ret = 0; +#endif celt_assert(ret == 0); return ret; } +LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len) { + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_plc_model(&st->model, list); + free(list); + if (ret == 0) return 0; + else return -1; +} + LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) { LPCNetPLCState *st; st = calloc(sizeof(*st), 1); diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 14b2aeee643c0ce6e228a964ae46e11916ca484f..5db3c637217a37a97907b9751867633c992c9c1c 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -131,4 +131,6 @@ int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, f void process_single_frame(LPCNetEncState *st, FILE *ffeat); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); + +int parse_weights(WeightArray **list, const unsigned char *data, int len); #endif diff --git a/dnn/write_lpcnet_weights.c b/dnn/write_lpcnet_weights.c index feba72110e095229b88804933edc1d51d56b4b6b..6bcdb72edf052452ab65a5f8f587912059452e42 100644 --- a/dnn/write_lpcnet_weights.c +++ b/dnn/write_lpcnet_weights.c @@ -31,8 +31,14 @@ #include <stdio.h> #include "nnet.h" -extern const WeightArray lpcnet_arrays[]; -extern const WeightArray lpcnet_plc_arrays[]; +/* This is a bit of a hack because we need to build nnet_data.c and plc_data.c without USE_WEIGHTS_FILE, + but USE_WEIGHTS_FILE is defined in config.h. */ +#undef HAVE_CONFIG_H +#ifdef USE_WEIGHTS_FILE +#undef USE_WEIGHTS_FILE +#endif +#include "nnet_data.c" +#include "plc_data.c" void write_weights(const WeightArray *list, FILE *fout) {