From 27663d364188711e5f662304357cb532e689bfe2 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin <jmvalin@amazon.com> Date: Thu, 21 Sep 2023 12:20:11 -0400 Subject: [PATCH] Using a DenseNet for DRED --- dnn/dred_rdovae_dec.c | 100 +++++++++------ dnn/dred_rdovae_dec.h | 14 ++- dnn/dred_rdovae_enc.c | 90 ++++++++------ dnn/dred_rdovae_enc.h | 15 ++- dnn/nnet.c | 19 +++ dnn/nnet.h | 1 + dnn/torch/rdovae/export_rdovae_weights.py | 54 ++++---- dnn/torch/rdovae/rdovae/rdovae.py | 143 +++++++++++++--------- 8 files changed, 268 insertions(+), 168 deletions(-) diff --git a/dnn/dred_rdovae_dec.c b/dnn/dred_rdovae_dec.c index c79723b35..1fef375ad 100644 --- a/dnn/dred_rdovae_dec.c +++ b/dnn/dred_rdovae_dec.c @@ -33,16 +33,36 @@ #include "dred_rdovae_constants.h" #include "os_support.h" +static void conv1_cond_init(float *mem, int len, int dilation, int *init) +{ + if (!*init) { + int i; + for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len); + } + *init = 1; +} + void dred_rdovae_dec_init_states( RDOVAEDecState *h, /* io: state buffer handle */ const RDOVAEDec *model, const float *initial_state /* i: initial state */ ) { - /* initialize GRU states from initial state */ - compute_generic_dense(&model->state1, h->dense2_state, initial_state, ACTIVATION_TANH); - compute_generic_dense(&model->state2, h->dense4_state, initial_state, ACTIVATION_TANH); - compute_generic_dense(&model->state3, h->dense6_state, initial_state, ACTIVATION_TANH); + float hidden[DEC_HIDDEN_INIT_OUT_SIZE]; + float state_init[DEC_GRU1_STATE_SIZE+DEC_GRU2_STATE_SIZE+DEC_GRU3_STATE_SIZE+DEC_GRU4_STATE_SIZE+DEC_GRU5_STATE_SIZE]; + int counter=0; + compute_generic_dense(&model->dec_hidden_init, hidden, initial_state, ACTIVATION_TANH); + compute_generic_dense(&model->dec_gru_init, state_init, hidden, ACTIVATION_TANH); + OPUS_COPY(h->gru1_state, state_init, DEC_GRU1_STATE_SIZE); + counter += DEC_GRU1_STATE_SIZE; + OPUS_COPY(h->gru2_state, &state_init[counter], DEC_GRU2_STATE_SIZE); + counter += DEC_GRU2_STATE_SIZE; + OPUS_COPY(h->gru3_state, &state_init[counter], DEC_GRU3_STATE_SIZE); + counter += DEC_GRU3_STATE_SIZE; + OPUS_COPY(h->gru4_state, &state_init[counter], DEC_GRU4_STATE_SIZE); + counter += DEC_GRU4_STATE_SIZE; + OPUS_COPY(h->gru5_state, &state_init[counter], DEC_GRU5_STATE_SIZE); + h->initialized = 0; } @@ -53,44 +73,48 @@ void dred_rdovae_decode_qframe( const float *input /* i: latent vector */ ) { - float buffer[DEC_DENSE1_OUT_SIZE + DEC_DENSE2_OUT_SIZE + DEC_DENSE3_OUT_SIZE + DEC_DENSE4_OUT_SIZE + DEC_DENSE5_OUT_SIZE + DEC_DENSE6_OUT_SIZE + DEC_DENSE7_OUT_SIZE + DEC_DENSE8_OUT_SIZE]; + float buffer[DEC_DENSE1_OUT_SIZE + DEC_GRU1_OUT_SIZE + DEC_GRU2_OUT_SIZE + DEC_GRU3_OUT_SIZE + DEC_GRU4_OUT_SIZE + DEC_GRU5_OUT_SIZE + + DEC_CONV1_OUT_SIZE + DEC_CONV2_OUT_SIZE + DEC_CONV3_OUT_SIZE + DEC_CONV4_OUT_SIZE + DEC_CONV5_OUT_SIZE]; int output_index = 0; - int input_index = 0; /* run encoder stack and concatenate output in buffer*/ compute_generic_dense(&model->dec_dense1, &buffer[output_index], input, ACTIVATION_TANH); - input_index = output_index; output_index += DEC_DENSE1_OUT_SIZE; - compute_generic_gru(&model->dec_dense2_input, &model->dec_dense2_recurrent, dec_state->dense2_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], dec_state->dense2_state, DEC_DENSE2_OUT_SIZE); - input_index = output_index; - output_index += DEC_DENSE2_OUT_SIZE; - - compute_generic_dense(&model->dec_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += DEC_DENSE3_OUT_SIZE; - - compute_generic_gru(&model->dec_dense4_input, &model->dec_dense4_recurrent, dec_state->dense4_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], dec_state->dense4_state, DEC_DENSE4_OUT_SIZE); - input_index = output_index; - output_index += DEC_DENSE4_OUT_SIZE; - - compute_generic_dense(&model->dec_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += DEC_DENSE5_OUT_SIZE; - - compute_generic_gru(&model->dec_dense6_input, &model->dec_dense6_recurrent, dec_state->dense6_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], dec_state->dense6_state, DEC_DENSE6_OUT_SIZE); - input_index = output_index; - output_index += DEC_DENSE6_OUT_SIZE; - - compute_generic_dense(&model->dec_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += DEC_DENSE7_OUT_SIZE; - - compute_generic_dense(&model->dec_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - output_index += DEC_DENSE8_OUT_SIZE; - - compute_generic_dense(&model->dec_final, qframe, buffer, ACTIVATION_LINEAR); + compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer); + OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE); + output_index += DEC_GRU1_OUT_SIZE; + conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized); + compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH); + output_index += DEC_CONV1_OUT_SIZE; + + compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer); + OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE); + output_index += DEC_GRU2_OUT_SIZE; + conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized); + compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH); + output_index += DEC_CONV2_OUT_SIZE; + + compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer); + OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE); + output_index += DEC_GRU3_OUT_SIZE; + conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized); + compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH); + output_index += DEC_CONV3_OUT_SIZE; + + compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer); + OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE); + output_index += DEC_GRU4_OUT_SIZE; + conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized); + compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH); + output_index += DEC_CONV4_OUT_SIZE; + + compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer); + OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE); + output_index += DEC_GRU5_OUT_SIZE; + conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized); + compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH); + output_index += DEC_CONV5_OUT_SIZE; + + compute_generic_dense(&model->dec_output, qframe, buffer, ACTIVATION_LINEAR); } diff --git a/dnn/dred_rdovae_dec.h b/dnn/dred_rdovae_dec.h index 008551b5d..4e039cf27 100644 --- a/dnn/dred_rdovae_dec.h +++ b/dnn/dred_rdovae_dec.h @@ -33,9 +33,17 @@ #include "dred_rdovae_stats_data.h" struct RDOVAEDecStruct { - float dense2_state[DEC_DENSE2_STATE_SIZE]; - float dense4_state[DEC_DENSE2_STATE_SIZE]; - float dense6_state[DEC_DENSE2_STATE_SIZE]; + int initialized; + float gru1_state[DEC_GRU1_STATE_SIZE]; + float gru2_state[DEC_GRU2_STATE_SIZE]; + float gru3_state[DEC_GRU3_STATE_SIZE]; + float gru4_state[DEC_GRU4_STATE_SIZE]; + float gru5_state[DEC_GRU5_STATE_SIZE]; + float conv1_state[DEC_CONV1_STATE_SIZE]; + float conv2_state[DEC_CONV2_STATE_SIZE]; + float conv3_state[DEC_CONV3_STATE_SIZE]; + float conv4_state[DEC_CONV4_STATE_SIZE]; + float conv5_state[DEC_CONV5_STATE_SIZE]; }; void dred_rdovae_dec_init_states(RDOVAEDecState *h, const RDOVAEDec *model, const float * initial_state); diff --git a/dnn/dred_rdovae_enc.c b/dnn/dred_rdovae_enc.c index 9361af17b..98ffba8c8 100644 --- a/dnn/dred_rdovae_enc.c +++ b/dnn/dred_rdovae_enc.c @@ -35,6 +35,15 @@ #include "dred_rdovae_enc.h" #include "os_support.h" +static void conv1_cond_init(float *mem, int len, int dilation, int *init) +{ + if (!*init) { + int i; + for (i=0;i<dilation;i++) OPUS_CLEAR(&mem[i*len], len); + } + *init = 1; +} + void dred_rdovae_encode_dframe( RDOVAEEncState *enc_state, /* io: encoder state */ const RDOVAEEnc *model, @@ -43,52 +52,53 @@ void dred_rdovae_encode_dframe( const float *input /* i: double feature frame (concatenated) */ ) { - float buffer[ENC_DENSE1_OUT_SIZE + ENC_DENSE2_OUT_SIZE + ENC_DENSE3_OUT_SIZE + ENC_DENSE4_OUT_SIZE + ENC_DENSE5_OUT_SIZE + ENC_DENSE6_OUT_SIZE + ENC_DENSE7_OUT_SIZE + ENC_DENSE8_OUT_SIZE + GDENSE1_OUT_SIZE]; + float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE + + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE]; + float state_hidden[GDENSE1_OUT_SIZE]; int output_index = 0; - int input_index = 0; /* run encoder stack and concatenate output in buffer*/ compute_generic_dense(&model->enc_dense1, &buffer[output_index], input, ACTIVATION_TANH); - input_index = output_index; output_index += ENC_DENSE1_OUT_SIZE; - compute_generic_gru(&model->enc_dense2_input, &model->enc_dense2_recurrent, enc_state->dense2_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], enc_state->dense2_state, ENC_DENSE2_OUT_SIZE); - input_index = output_index; - output_index += ENC_DENSE2_OUT_SIZE; - - compute_generic_dense(&model->enc_dense3, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += ENC_DENSE3_OUT_SIZE; - - compute_generic_gru(&model->enc_dense4_input, &model->enc_dense4_recurrent, enc_state->dense4_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], enc_state->dense4_state, ENC_DENSE4_OUT_SIZE); - input_index = output_index; - output_index += ENC_DENSE4_OUT_SIZE; - - compute_generic_dense(&model->enc_dense5, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += ENC_DENSE5_OUT_SIZE; - - compute_generic_gru(&model->enc_dense6_input, &model->enc_dense6_recurrent, enc_state->dense6_state, &buffer[input_index]); - OPUS_COPY(&buffer[output_index], enc_state->dense6_state, ENC_DENSE6_OUT_SIZE); - input_index = output_index; - output_index += ENC_DENSE6_OUT_SIZE; - - compute_generic_dense(&model->enc_dense7, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - input_index = output_index; - output_index += ENC_DENSE7_OUT_SIZE; - - compute_generic_dense(&model->enc_dense8, &buffer[output_index], &buffer[input_index], ACTIVATION_TANH); - output_index += ENC_DENSE8_OUT_SIZE; - - /* compute latents from concatenated input buffer */ - compute_generic_conv1d(&model->bits_dense, latents, enc_state->bits_dense_state, buffer, BITS_DENSE_IN_SIZE, ACTIVATION_LINEAR); - + compute_generic_gru(&model->enc_gru1_input, &model->enc_gru1_recurrent, enc_state->gru1_state, buffer); + OPUS_COPY(&buffer[output_index], enc_state->gru1_state, ENC_GRU1_OUT_SIZE); + output_index += ENC_GRU1_OUT_SIZE; + conv1_cond_init(enc_state->conv1_state, output_index, 1, &enc_state->initialized); + compute_generic_conv1d(&model->enc_conv1, &buffer[output_index], enc_state->conv1_state, buffer, output_index, ACTIVATION_TANH); + output_index += ENC_CONV1_OUT_SIZE; + + compute_generic_gru(&model->enc_gru2_input, &model->enc_gru2_recurrent, enc_state->gru2_state, buffer); + OPUS_COPY(&buffer[output_index], enc_state->gru2_state, ENC_GRU2_OUT_SIZE); + output_index += ENC_GRU2_OUT_SIZE; + conv1_cond_init(enc_state->conv2_state, output_index, 2, &enc_state->initialized); + compute_generic_conv1d_dilation(&model->enc_conv2, &buffer[output_index], enc_state->conv2_state, buffer, output_index, 2, ACTIVATION_TANH); + output_index += ENC_CONV2_OUT_SIZE; + + compute_generic_gru(&model->enc_gru3_input, &model->enc_gru3_recurrent, enc_state->gru3_state, buffer); + OPUS_COPY(&buffer[output_index], enc_state->gru3_state, ENC_GRU3_OUT_SIZE); + output_index += ENC_GRU3_OUT_SIZE; + conv1_cond_init(enc_state->conv3_state, output_index, 2, &enc_state->initialized); + compute_generic_conv1d_dilation(&model->enc_conv3, &buffer[output_index], enc_state->conv3_state, buffer, output_index, 2, ACTIVATION_TANH); + output_index += ENC_CONV3_OUT_SIZE; + + compute_generic_gru(&model->enc_gru4_input, &model->enc_gru4_recurrent, enc_state->gru4_state, buffer); + OPUS_COPY(&buffer[output_index], enc_state->gru4_state, ENC_GRU4_OUT_SIZE); + output_index += ENC_GRU4_OUT_SIZE; + conv1_cond_init(enc_state->conv4_state, output_index, 2, &enc_state->initialized); + compute_generic_conv1d_dilation(&model->enc_conv4, &buffer[output_index], enc_state->conv4_state, buffer, output_index, 2, ACTIVATION_TANH); + output_index += ENC_CONV4_OUT_SIZE; + + compute_generic_gru(&model->enc_gru5_input, &model->enc_gru5_recurrent, enc_state->gru5_state, buffer); + OPUS_COPY(&buffer[output_index], enc_state->gru5_state, ENC_GRU5_OUT_SIZE); + output_index += ENC_GRU5_OUT_SIZE; + conv1_cond_init(enc_state->conv5_state, output_index, 2, &enc_state->initialized); + compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH); + output_index += ENC_CONV5_OUT_SIZE; + + compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR); /* next, calculate initial state */ - compute_generic_dense(&model->gdense1, &buffer[output_index], buffer, ACTIVATION_TANH); - input_index = output_index; - compute_generic_dense(&model->gdense2, initial_state, &buffer[input_index], ACTIVATION_TANH); - + compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH); + compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR); } diff --git a/dnn/dred_rdovae_enc.h b/dnn/dred_rdovae_enc.h index 70ff6adca..832bd7379 100644 --- a/dnn/dred_rdovae_enc.h +++ b/dnn/dred_rdovae_enc.h @@ -33,10 +33,17 @@ #include "dred_rdovae_enc_data.h" struct RDOVAEEncStruct { - float dense2_state[3 * ENC_DENSE2_STATE_SIZE]; - float dense4_state[3 * ENC_DENSE4_STATE_SIZE]; - float dense6_state[3 * ENC_DENSE6_STATE_SIZE]; - float bits_dense_state[BITS_DENSE_STATE_SIZE]; + int initialized; + float gru1_state[ENC_GRU1_STATE_SIZE]; + float gru2_state[ENC_GRU2_STATE_SIZE]; + float gru3_state[ENC_GRU3_STATE_SIZE]; + float gru4_state[ENC_GRU4_STATE_SIZE]; + float gru5_state[ENC_GRU5_STATE_SIZE]; + float conv1_state[ENC_CONV1_STATE_SIZE]; + float conv2_state[2*ENC_CONV2_STATE_SIZE]; + float conv3_state[2*ENC_CONV3_STATE_SIZE]; + float conv4_state[2*ENC_CONV4_STATE_SIZE]; + float conv5_state[2*ENC_CONV5_STATE_SIZE]; }; void dred_rdovae_encode_dframe(RDOVAEEncState *enc_state, const RDOVAEEnc *model, float *latents, float *initial_state, const float *input); diff --git a/dnn/nnet.c b/dnn/nnet.c index 3661ba77f..d5ef904ec 100644 --- a/dnn/nnet.c +++ b/dnn/nnet.c @@ -366,6 +366,25 @@ void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size); } +void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation) +{ + float tmp[MAX_CONV_INPUTS_ALL]; + int ksize = layer->nb_inputs/input_size; + int i; + celt_assert(input != output); + celt_assert(layer->nb_inputs <= MAX_CONV_INPUTS_ALL); + if (dilation==1) OPUS_COPY(tmp, mem, layer->nb_inputs-input_size); + else for (i=0;i<ksize-1;i++) OPUS_COPY(&tmp[i*input_size], &mem[i*input_size*dilation], input_size); + OPUS_COPY(&tmp[layer->nb_inputs-input_size], input, input_size); + compute_linear(layer, output, tmp); + compute_activation(output, output, layer->nb_outputs, activation); + if (dilation==1) OPUS_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size); + else { + OPUS_COPY(mem, &mem[input_size], input_size*dilation*(ksize-1)-input_size); + OPUS_COPY(&mem[input_size*dilation*(ksize-1)-input_size], input, input_size); + } +} + void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input) { LinearLayer matrix; diff --git a/dnn/nnet.h b/dnn/nnet.h index 16ce82bab..9ed20b028 100644 --- a/dnn/nnet.h +++ b/dnn/nnet.h @@ -145,6 +145,7 @@ void compute_linear(const LinearLayer *linear, float *out, const float *in); void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation); void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in); void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation); +void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation); void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation); void compute_activation(float *output, const float *input, int N, int activation); diff --git a/dnn/torch/rdovae/export_rdovae_weights.py b/dnn/torch/rdovae/export_rdovae_weights.py index f9c1db815..c2cc61bde 100644 --- a/dnn/torch/rdovae/export_rdovae_weights.py +++ b/dnn/torch/rdovae/export_rdovae_weights.py @@ -116,10 +116,7 @@ f""" # encoder encoder_dense_layers = [ ('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'), - ('core_encoder.module.dense_2' , 'enc_dense3', 'TANH'), - ('core_encoder.module.dense_3' , 'enc_dense5', 'TANH'), - ('core_encoder.module.dense_4' , 'enc_dense7', 'TANH'), - ('core_encoder.module.dense_5' , 'enc_dense8', 'TANH'), + ('core_encoder.module.z_dense' , 'enc_zdense', 'LINEAR'), ('core_encoder.module.state_dense_1' , 'gdense1' , 'TANH'), ('core_encoder.module.state_dense_2' , 'gdense2' , 'TANH') ] @@ -130,9 +127,11 @@ f""" encoder_gru_layers = [ - ('core_encoder.module.gru_1' , 'enc_dense2', 'TANH'), - ('core_encoder.module.gru_2' , 'enc_dense4', 'TANH'), - ('core_encoder.module.gru_3' , 'enc_dense6', 'TANH') + ('core_encoder.module.gru1' , 'enc_gru1', 'TANH'), + ('core_encoder.module.gru2' , 'enc_gru2', 'TANH'), + ('core_encoder.module.gru3' , 'enc_gru3', 'TANH'), + ('core_encoder.module.gru4' , 'enc_gru4', 'TANH'), + ('core_encoder.module.gru5' , 'enc_gru5', 'TANH'), ] enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True) @@ -140,7 +139,11 @@ f""" encoder_conv_layers = [ - ('core_encoder.module.conv1' , 'bits_dense' , 'LINEAR') + ('core_encoder.module.conv1.conv' , 'enc_conv1', 'TANH'), + ('core_encoder.module.conv2.conv' , 'enc_conv2', 'TANH'), + ('core_encoder.module.conv3.conv' , 'enc_conv3', 'TANH'), + ('core_encoder.module.conv4.conv' , 'enc_conv4', 'TANH'), + ('core_encoder.module.conv5.conv' , 'enc_conv5', 'TANH'), ] enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in encoder_conv_layers]) @@ -150,15 +153,10 @@ f""" # decoder decoder_dense_layers = [ - ('core_decoder.module.gru_1_init' , 'state1', 'TANH'), - ('core_decoder.module.gru_2_init' , 'state2', 'TANH'), - ('core_decoder.module.gru_3_init' , 'state3', 'TANH'), - ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'), - ('core_decoder.module.dense_2' , 'dec_dense3', 'TANH'), - ('core_decoder.module.dense_3' , 'dec_dense5', 'TANH'), - ('core_decoder.module.dense_4' , 'dec_dense7', 'TANH'), - ('core_decoder.module.dense_5' , 'dec_dense8', 'TANH'), - ('core_decoder.module.output' , 'dec_final', 'LINEAR') + ('core_decoder.module.dense_1' , 'dec_dense1', 'TANH'), + ('core_decoder.module.output' , 'dec_output', 'LINEAR'), + ('core_decoder.module.hidden_init' , 'dec_hidden_init', 'TANH'), + ('core_decoder.module.gru_init' , 'dec_gru_init', 'TANH'), ] for name, export_name, _ in decoder_dense_layers: @@ -167,14 +165,26 @@ f""" decoder_gru_layers = [ - ('core_decoder.module.gru_1' , 'dec_dense2', 'TANH'), - ('core_decoder.module.gru_2' , 'dec_dense4', 'TANH'), - ('core_decoder.module.gru_3' , 'dec_dense6', 'TANH') + ('core_decoder.module.gru1' , 'dec_gru1', 'TANH'), + ('core_decoder.module.gru2' , 'dec_gru2', 'TANH'), + ('core_decoder.module.gru3' , 'dec_gru3', 'TANH'), + ('core_decoder.module.gru4' , 'dec_gru4', 'TANH'), + ('core_decoder.module.gru5' , 'dec_gru5', 'TANH'), ] dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=True) for name, export_name, _ in decoder_gru_layers]) + decoder_conv_layers = [ + ('core_decoder.module.conv1.conv' , 'dec_conv1', 'TANH'), + ('core_decoder.module.conv2.conv' , 'dec_conv2', 'TANH'), + ('core_decoder.module.conv3.conv' , 'dec_conv3', 'TANH'), + ('core_decoder.module.conv4.conv' , 'dec_conv4', 'TANH'), + ('core_decoder.module.conv5.conv' , 'dec_conv5', 'TANH'), + ] + + dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=False) for name, export_name, _ in decoder_conv_layers]) + del dec_writer # statistical model @@ -196,7 +206,7 @@ f""" #define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)} -#define DRED_MAX_CONV_INPUTS {enc_max_conv_inputs} +#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)} #define DRED_ENC_MAX_RNN_NEURONS {enc_max_conv_inputs} @@ -268,4 +278,4 @@ if __name__ == "__main__": elif args.format == 'numpy': numpy_export(args, model) else: - raise ValueError(f'error: unknown export format {args.format}') \ No newline at end of file + raise ValueError(f'error: unknown export format {args.format}') diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py index 0dc943ec8..b126d4c44 100644 --- a/dnn/torch/rdovae/rdovae/rdovae.py +++ b/dnn/torch/rdovae/rdovae/rdovae.py @@ -224,6 +224,17 @@ def weight_clip_factory(max_value): # RDOVAE module and submodules +class MyConv(nn.Module): + def __init__(self, input_dim, output_dim, dilation=1): + super(MyConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.dilation=dilation + self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation) + def forward(self, x, state=None): + device = x.device + conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1) + return torch.tanh(self.conv(conv_in)).permute(0, 2, 1) class CoreEncoder(nn.Module): STATE_HIDDEN = 128 @@ -248,22 +259,28 @@ class CoreEncoder(nn.Module): # derived parameters self.input_dim = self.FRAMES_PER_STEP * self.feature_dim - self.conv_input_channels = 5 * cond_size + 3 * cond_size2 # layers - self.dense_1 = nn.Linear(self.input_dim, self.cond_size2) - self.gru_1 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) - self.dense_2 = nn.Linear(self.cond_size, self.cond_size2) - self.gru_2 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) - self.dense_3 = nn.Linear(self.cond_size, self.cond_size2) - self.gru_3 = nn.GRU(self.cond_size2, self.cond_size, batch_first=True) - self.dense_4 = nn.Linear(self.cond_size, self.cond_size) - self.dense_5 = nn.Linear(self.cond_size, self.cond_size) - self.conv1 = nn.Conv1d(self.conv_input_channels, self.output_dim, kernel_size=self.CONV_KERNEL_SIZE, padding='valid') - - self.state_dense_1 = nn.Linear(self.conv_input_channels, self.STATE_HIDDEN) + self.dense_1 = nn.Linear(self.input_dim, 64) + self.gru1 = nn.GRU(64, 64, batch_first=True) + self.conv1 = MyConv(128, 96) + self.gru2 = nn.GRU(224, 64, batch_first=True) + self.conv2 = MyConv(288, 96, dilation=2) + self.gru3 = nn.GRU(384, 64, batch_first=True) + self.conv3 = MyConv(448, 96, dilation=2) + self.gru4 = nn.GRU(544, 64, batch_first=True) + self.conv4 = MyConv(608, 96, dilation=2) + self.gru5 = nn.GRU(704, 64, batch_first=True) + self.conv5 = MyConv(768, 96, dilation=2) + + self.z_dense = nn.Linear(864, self.output_dim) + + + self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN) self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size) + nb_params = sum(p.numel() for p in self.parameters()) + print(f"encoder: {nb_params} weights") # initialize weights self.apply(init_weights) @@ -278,25 +295,22 @@ class CoreEncoder(nn.Module): device = x.device # run encoding layer stack - x1 = torch.tanh(self.dense_1(x)) - x2, _ = self.gru_1(x1, torch.zeros((1, batch, self.cond_size)).to(device)) - x3 = torch.tanh(self.dense_2(x2)) - x4, _ = self.gru_2(x3, torch.zeros((1, batch, self.cond_size)).to(device)) - x5 = torch.tanh(self.dense_3(x4)) - x6, _ = self.gru_3(x5, torch.zeros((1, batch, self.cond_size)).to(device)) - x7 = torch.tanh(self.dense_4(x6)) - x8 = torch.tanh(self.dense_5(x7)) - - # concatenation of all hidden layer outputs - x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + x = torch.tanh(self.dense_1(x)) + x = torch.cat([x, self.gru1(x)[0]], -1) + x = torch.cat([x, self.conv1(x)], -1) + x = torch.cat([x, self.gru2(x)[0]], -1) + x = torch.cat([x, self.conv2(x)], -1) + x = torch.cat([x, self.gru3(x)[0]], -1) + x = torch.cat([x, self.conv3(x)], -1) + x = torch.cat([x, self.gru4(x)[0]], -1) + x = torch.cat([x, self.conv4(x)], -1) + x = torch.cat([x, self.gru5(x)[0]], -1) + x = torch.cat([x, self.conv5(x)], -1) + z = self.z_dense(x) # init state for decoder - states = torch.tanh(self.state_dense_1(x9)) - states = torch.tanh(self.state_dense_2(states)) - - # latent representation via convolution - x9 = F.pad(x9.permute(0, 2, 1), [self.CONV_KERNEL_SIZE - 1, 0]) - z = self.conv1(x9).permute(0, 2, 1) + states = torch.tanh(self.state_dense_1(x)) + states = self.state_dense_2(states) return z, states @@ -325,47 +339,54 @@ class CoreDecoder(nn.Module): self.input_size = self.input_dim - self.concat_size = 4 * self.cond_size + 4 * self.cond_size2 - # layers - self.dense_1 = nn.Linear(self.input_size, cond_size2) - self.gru_1 = nn.GRU(cond_size2, cond_size, batch_first=True) - self.dense_2 = nn.Linear(cond_size, cond_size2) - self.gru_2 = nn.GRU(cond_size2, cond_size, batch_first=True) - self.dense_3 = nn.Linear(cond_size, cond_size2) - self.gru_3 = nn.GRU(cond_size2, cond_size, batch_first=True) - self.dense_4 = nn.Linear(cond_size, cond_size2) - self.dense_5 = nn.Linear(cond_size2, cond_size2) - - self.output = nn.Linear(self.concat_size, self.FRAMES_PER_STEP * self.output_dim) - - - self.gru_1_init = nn.Linear(self.state_size, self.cond_size) - self.gru_2_init = nn.Linear(self.state_size, self.cond_size) - self.gru_3_init = nn.Linear(self.state_size, self.cond_size) - + self.dense_1 = nn.Linear(self.input_size, 96) + self.gru1 = nn.GRU(96, 96, batch_first=True) + self.conv1 = MyConv(192, 32) + self.gru2 = nn.GRU(224, 96, batch_first=True) + self.conv2 = MyConv(320, 32) + self.gru3 = nn.GRU(352, 96, batch_first=True) + self.conv3 = MyConv(448, 32) + self.gru4 = nn.GRU(480, 96, batch_first=True) + self.conv4 = MyConv(576, 32) + self.gru5 = nn.GRU(608, 96, batch_first=True) + self.conv5 = MyConv(704, 32) + self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim) + + self.hidden_init = nn.Linear(self.state_size, 128) + self.gru_init = nn.Linear(128, 480) + + nb_params = sum(p.numel() for p in self.parameters()) + print(f"decoder: {nb_params} weights") # initialize weights self.apply(init_weights) def forward(self, z, initial_state): - gru_1_state = torch.tanh(self.gru_1_init(initial_state).permute(1, 0, 2)) - gru_2_state = torch.tanh(self.gru_2_init(initial_state).permute(1, 0, 2)) - gru_3_state = torch.tanh(self.gru_3_init(initial_state).permute(1, 0, 2)) + hidden = torch.tanh(self.hidden_init(initial_state)) + gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2)) + h1_state = gru_state[:,:,:96].contiguous() + h2_state = gru_state[:,:,96:192].contiguous() + h3_state = gru_state[:,:,192:288].contiguous() + h4_state = gru_state[:,:,288:384].contiguous() + h5_state = gru_state[:,:,384:].contiguous() # run decoding layer stack - x1 = torch.tanh(self.dense_1(z)) - x2, _ = self.gru_1(x1, gru_1_state) - x3 = torch.tanh(self.dense_2(x2)) - x4, _ = self.gru_2(x3, gru_2_state) - x5 = torch.tanh(self.dense_3(x4)) - x6, _ = self.gru_3(x5, gru_3_state) - x7 = torch.tanh(self.dense_4(x6)) - x8 = torch.tanh(self.dense_5(x7)) - x9 = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), dim=-1) + x = torch.tanh(self.dense_1(z)) + + x = torch.cat([x, self.gru1(x, h1_state)[0]], -1) + x = torch.cat([x, self.conv1(x)], -1) + x = torch.cat([x, self.gru2(x, h2_state)[0]], -1) + x = torch.cat([x, self.conv2(x)], -1) + x = torch.cat([x, self.gru3(x, h3_state)[0]], -1) + x = torch.cat([x, self.conv3(x)], -1) + x = torch.cat([x, self.gru4(x, h4_state)[0]], -1) + x = torch.cat([x, self.conv4(x)], -1) + x = torch.cat([x, self.gru5(x, h5_state)[0]], -1) + x = torch.cat([x, self.conv5(x)], -1) # output layer and reshaping - x10 = self.output(x9) + x10 = self.output(x) features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP)) return features @@ -466,7 +487,7 @@ class RDOVAE(nn.Module): if not type(self.weight_clip_fn) == type(None): self.apply(self.weight_clip_fn) - def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4): + def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 24): enc_stride = self.enc_stride dec_stride = self.dec_stride -- GitLab