From fa7b432eed4e9fc11d5f2bfe10c4f54d89dd1788 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin <jmvalin@amazon.com> Date: Sun, 28 May 2023 01:53:20 -0400 Subject: [PATCH] Initial blob loading support --- dnn/Makefile.am | 2 +- dnn/autogen.sh | 2 +- dnn/include/lpcnet.h | 3 ++ dnn/lpcnet.c | 14 +++++++++ dnn/lpcnet_demo.c | 58 +++++++++++++++++++++++++++++++++++++- dnn/lpcnet_plc.c | 14 +++++++++ dnn/lpcnet_private.h | 2 ++ dnn/write_lpcnet_weights.c | 10 +++++-- 8 files changed, 100 insertions(+), 5 deletions(-) diff --git a/dnn/Makefile.am b/dnn/Makefile.am index 29e1ed606..c4b8b7681 100644 --- a/dnn/Makefile.am +++ b/dnn/Makefile.am @@ -62,7 +62,7 @@ dump_data_SOURCES = common.c dump_data.c burg.c freq.c kiss_fft.c pitch.c lpcnet dump_data_LDADD = $(LIBM) dump_data_CFLAGS = $(AM_CFLAGS) -dump_weights_blob_SOURCES = nnet_data.c plc_data.c write_lpcnet_weights.c +dump_weights_blob_SOURCES = write_lpcnet_weights.c dump_weights_blob_LDADD = $(LIBM) dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS diff --git a/dnn/autogen.sh b/dnn/autogen.sh index c69a4e616..8ae205c69 100755 --- a/dnn/autogen.sh +++ b/dnn/autogen.sh @@ -6,7 +6,7 @@ srcdir=`dirname $0` test -n "$srcdir" && cd "$srcdir" #SHA1 of the first commit compatible with the current model -commit=399be7c +commit=859bfae ./download_model.sh $commit echo "Updating build configuration files for lpcnet, please wait...." diff --git a/dnn/include/lpcnet.h b/dnn/include/lpcnet.h index e8718fcb3..067a43288 100644 --- a/dnn/include/lpcnet.h +++ b/dnn/include/lpcnet.h @@ -199,4 +199,7 @@ LPCNET_EXPORT void lpcnet_plc_fec_add(LPCNetPLCState *st, const float *features) LPCNET_EXPORT void lpcnet_plc_fec_clear(LPCNetPLCState *st); +LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len); +LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len); + #endif diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 914f4476a..455a5ed32 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -183,11 +183,25 @@ LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet) lpcnet->sampling_logit_table[i] = -log((1-prob)/prob); } kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string)); +#ifndef USE_WEIGHTS_FILE ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays); +#else + ret = 0; +#endif celt_assert(ret == 0); return ret; } +LPCNET_EXPORT int lpcnet_load_model(LPCNetState *st, const unsigned char *data, int len) { + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_lpcnet_model(&st->model, list); + free(list); + if (ret == 0) return 0; + else return -1; +} + LPCNET_EXPORT LPCNetState *lpcnet_create() { diff --git a/dnn/lpcnet_demo.c b/dnn/lpcnet_demo.c index 3fd6993e1..b88176f4d 100644 --- a/dnn/lpcnet_demo.c +++ b/dnn/lpcnet_demo.c @@ -34,6 +34,49 @@ #include "lpcnet.h" #include "freq.h" +#ifdef USE_WEIGHTS_FILE +# if __unix__ +# include <fcntl.h> +# include <sys/mman.h> +# include <unistd.h> +# include <sys/stat.h> +/* When available, mmap() is preferable to reading the file, as it leads to + better resource utilization, especially if multiple processes are using the same + file (mapping will be shared in cache). */ +unsigned char *load_blob(const char *filename, int *len) { + int fd; + unsigned char *data; + struct stat st; + stat(filename, &st); + *len = st.st_size; + fd = open(filename, O_RDONLY); + data = mmap(NULL, *len, PROT_READ, MAP_SHARED, fd, 0); + close(fd); + return data; +} +void free_blob(unsigned char *blob, int len) { + munmap(blob, len); +} +# else +unsigned char *load_blob(const char *filename, int *len) { + FILE *file; + unsigned char *data; + file = fopen(filename, "r"); + fseek(file, 0L, SEEK_END); + *len = ftell(file); + fseek(file, 0L, SEEK_SET); + if (*len <= 0) return NULL; + data = malloc(*len); + *len = fread(data, 1, *len, file); + return data; +} +void free_blob(unsigned char *blob, int len) { + free(blob); + (void)len; +} +# endif +#endif + #define MODE_ENCODE 0 #define MODE_DECODE 1 #define MODE_FEATURES 2 @@ -64,6 +107,11 @@ int main(int argc, char **argv) { FILE *plc_file = NULL; const char *plc_options; int plc_flags=-1; +#ifdef USE_WEIGHTS_FILE + int len; + unsigned char *data; + const char *filename = "weights_blob.bin"; +#endif if (argc < 4) usage(); if (strcmp(argv[1], "-encode") == 0) mode=MODE_ENCODE; else if (strcmp(argv[1], "-decode") == 0) mode=MODE_DECODE; @@ -109,7 +157,9 @@ int main(int argc, char **argv) { fprintf(stderr, "Can't open %s\n", argv[3]); exit(1); } - +#ifdef USE_WEIGHTS_FILE + data = load_blob(filename, &len); +#endif if (mode == MODE_ENCODE) { LPCNetEncState *net; net = lpcnet_encoder_create(); @@ -152,6 +202,9 @@ int main(int argc, char **argv) { } else if (mode == MODE_SYNTHESIS) { LPCNetState *net; net = lpcnet_create(); +#ifdef USE_WEIGHTS_FILE + lpcnet_load_model(net, data, len); +#endif while (1) { float in_features[NB_TOTAL_FEATURES]; float features[NB_FEATURES]; @@ -207,5 +260,8 @@ int main(int argc, char **argv) { } fclose(fin); fclose(fout); +#ifdef USE_WEIGHTS_FILE + free_blob(data, len); +#endif return 0; } diff --git a/dnn/lpcnet_plc.c b/dnn/lpcnet_plc.c index 2d7554ce1..cd9d69201 100644 --- a/dnn/lpcnet_plc.c +++ b/dnn/lpcnet_plc.c @@ -68,11 +68,25 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { return -1; } st->remove_dc = !!(options&LPCNET_PLC_DC_FILTER); +#ifndef USE_WEIGHTS_FILE ret = init_plc_model(&st->model, lpcnet_plc_arrays); +#else + ret = 0; +#endif celt_assert(ret == 0); return ret; } +LPCNET_EXPORT int lpcnet_plc_load_model(LPCNetPLCState *st, const unsigned char *data, int len) { + WeightArray *list; + int ret; + parse_weights(&list, data, len); + ret = init_plc_model(&st->model, list); + free(list); + if (ret == 0) return 0; + else return -1; +} + LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options) { LPCNetPLCState *st; st = calloc(sizeof(*st), 1); diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index 14b2aeee6..5db3c6372 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -131,4 +131,6 @@ int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, f void process_single_frame(LPCNetEncState *st, FILE *ffeat); void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features); + +int parse_weights(WeightArray **list, const unsigned char *data, int len); #endif diff --git a/dnn/write_lpcnet_weights.c b/dnn/write_lpcnet_weights.c index feba72110..6bcdb72ed 100644 --- a/dnn/write_lpcnet_weights.c +++ b/dnn/write_lpcnet_weights.c @@ -31,8 +31,14 @@ #include <stdio.h> #include "nnet.h" -extern const WeightArray lpcnet_arrays[]; -extern const WeightArray lpcnet_plc_arrays[]; +/* This is a bit of a hack because we need to build nnet_data.c and plc_data.c without USE_WEIGHTS_FILE, + but USE_WEIGHTS_FILE is defined in config.h. */ +#undef HAVE_CONFIG_H +#ifdef USE_WEIGHTS_FILE +#undef USE_WEIGHTS_FILE +#endif +#include "nnet_data.c" +#include "plc_data.c" void write_weights(const WeightArray *list, FILE *fout) { -- GitLab