From 50966eecc5024b3b648889319153d534a147cc71 Mon Sep 17 00:00:00 2001 From: Jan Buethe <jbuethe@amazon.de> Date: Wed, 19 Oct 2022 14:43:12 +0000 Subject: [PATCH] bugfixes in nfec encoder --- dnn/nfec_enc.c | 60 ++++++++++++++++++++++++++++++++++++++++++--- dnn/nfec_enc_demo.c | 12 ++++----- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/dnn/nfec_enc.c b/dnn/nfec_enc.c index 1957ab9ee..d524e8211 100644 --- a/dnn/nfec_enc.c +++ b/dnn/nfec_enc.c @@ -2,50 +2,104 @@ #include "nnet.h" #include "nfec_enc_data.h" +//#define DEBUG + +#ifdef DEBUG +#include <stdio.h> +#endif + void nfec_encode_dframe(struct NFECEncState *enc_state, float *latents, float *initial_state, const float *input) { float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE]; int output_index = 0; int input_index = 0; +#ifdef DEBUG + static FILE *fids[8] = {NULL}; + static FILE *fpre = NULL; + int i; + char filename[256]; + + for (i=0; i < 8; i ++) + { + if (fids[i] == NULL) + { + sprintf(filename, "x%d.f32", i + 1); + fids[i] = fopen(filename, "wb"); + } + } + if (fpre == NULL) + { + fpre = fopen("x_pre.f32", "wb"); + } +#endif + /* run encoder stack and concatenate output in buffer*/ compute_dense(&enc_dense1, &buffer[output_index], input); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE1_OUT_SIZE, fids[0]); +#endif input_index = output_index; output_index += ENC_DENSE1_OUT_SIZE; - compute_gru3(&enc_dense2, enc_state->dense2_state, &buffer[input_index]); + compute_gru2(&enc_dense2, enc_state->dense2_state, &buffer[input_index]); memcpy(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE * sizeof(float)); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE2_OUT_SIZE, fids[1]); +#endif input_index = output_index; output_index += ENC_DENSE2_OUT_SIZE; compute_dense(&enc_dense3, &buffer[output_index], &buffer[input_index]); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE3_OUT_SIZE, fids[2]); +#endif input_index = output_index; output_index += ENC_DENSE3_OUT_SIZE; - compute_gru3(&enc_dense4, enc_state->dense4_state, &buffer[input_index]); + compute_gru2(&enc_dense4, enc_state->dense4_state, &buffer[input_index]); memcpy(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE * sizeof(float)); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE4_OUT_SIZE, fids[3]); +#endif input_index = output_index; output_index += ENC_DENSE4_OUT_SIZE; compute_dense(&enc_dense5, &buffer[output_index], &buffer[input_index]); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE5_OUT_SIZE, fids[4]); +#endif input_index = output_index; output_index += ENC_DENSE5_OUT_SIZE; - compute_gru3(&enc_dense6, enc_state->dense6_state, &buffer[input_index]); + compute_gru2(&enc_dense6, enc_state->dense6_state, &buffer[input_index]); memcpy(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE * sizeof(float)); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE6_OUT_SIZE, fids[5]); +#endif input_index = output_index; output_index += ENC_DENSE6_OUT_SIZE; compute_dense(&enc_dense7, &buffer[output_index], &buffer[input_index]); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE7_OUT_SIZE, fids[6]); +#endif input_index = output_index; output_index += ENC_DENSE7_OUT_SIZE; compute_dense(&enc_dense8, &buffer[output_index], &buffer[input_index]); +#ifdef DEBUG + fwrite(&buffer[output_index], sizeof(buffer[0]), ENC_DENSE8_OUT_SIZE, fids[7]); +#endif output_index += ENC_DENSE8_OUT_SIZE; /* compute latents from concatenated input buffer */ +#ifdef DEBUG + fwrite(buffer, sizeof(buffer[0]), bits_dense.nb_inputs, fpre); +#endif compute_conv1d(&bits_dense, latents, enc_state->bits_dense_state, buffer); + /* next, calculate initial state */ compute_dense(&gdense1, &buffer[output_index], buffer); input_index = output_index; diff --git a/dnn/nfec_enc_demo.c b/dnn/nfec_enc_demo.c index addc52dd8..809c90bd5 100644 --- a/dnn/nfec_enc_demo.c +++ b/dnn/nfec_enc_demo.c @@ -12,8 +12,8 @@ void usage() int main(int argc, char **argv) { struct NFECEncState enc_state; - float feature_buffer[32]; - float dframe[2 * 20]; + float feature_buffer[36]; + float dframe[2 * NFEC_NUM_FEATURES]; float latents[80]; float initial_state[24]; int index = 0; @@ -41,16 +41,16 @@ int main(int argc, char **argv) } states_fid = fopen(argv[3], "wb"); - if (fid == NULL) + if (states_fid == NULL) { fprintf(stderr, "could not open states file %s\n", argv[3]); usage(); } - while (fread(feature_buffer, sizeof(float), 32, fid) == 32) + while (fread(feature_buffer, sizeof(float), 36, fid) == 36) { - memcpy(&dframe[16 * index++], feature_buffer, 16*sizeof(float)); + memcpy(&dframe[NFEC_NUM_FEATURES * index++], feature_buffer, NFEC_NUM_FEATURES*sizeof(float)); if (index == 2) { @@ -66,4 +66,4 @@ int main(int argc, char **argv) fclose(latents_fid); } -/* gcc -DDISABLE_DOT_PROD -DDISABLE_NEON nfec_enc_demo.c nfec_enc.c nnet.c nfec_enc_data.c kiss99.c -o nfec_enc_demo */ \ No newline at end of file +/* gcc -DDISABLE_DOT_PROD -DDISABLE_NEON nfec_enc_demo.c nfec_enc.c nnet.c nfec_enc_data.c kiss99.c -g -o nfec_enc_demo */ \ No newline at end of file -- GitLab