From 0459a572f592fb07376c480c1ebbf04c16090211 Mon Sep 17 00:00:00 2001 From: Jan Buethe <jbuethe@amazon.de> Date: Fri, 29 Sep 2023 15:34:59 +0200 Subject: [PATCH] updated PitchDNN export script --- .../export_neuralpitch_weights.py | 46 ++++++++++++------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/dnn/torch/neural-pitch/export_neuralpitch_weights.py b/dnn/torch/neural-pitch/export_neuralpitch_weights.py index a56784a99..9f20ec9e7 100644 --- a/dnn/torch/neural-pitch/export_neuralpitch_weights.py +++ b/dnn/torch/neural-pitch/export_neuralpitch_weights.py @@ -44,7 +44,7 @@ args = parser.parse_args() import torch import numpy as np -from models import large_if_ccode +from models import PitchDNN from wexchange.torch import dump_torch_weights from wexchange.c_export import CWriter, print_vector @@ -52,39 +52,51 @@ def c_export(args, model): message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}" - enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch') - enc_writer.header.write( + writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN') + writer.header.write( f""" #include "opus_types.h" """ ) - - # encoder - encoder_dense_layers = [ - ('initial' , 'initial', 'TANH'), - ('upsample' , 'upsample', 'TANH') + layers = [ + ('if_upsample.0', "dense_if_upsampler_1"), + ('if_upsample.2', "dense_if_upsampler_2"), + ('conv.1', "conv2d_1"), + ('conv.4', "conv2d_2"), + ('conv.7', "conv2d_3"), + ('downsample.0', "dense_downsampler"), + ("upsample.0", "dense_final_upsampler") ] - for name, export_name, _ in encoder_dense_layers: + + for name, export_name in layers: layer = model.get_submodule(name) - dump_torch_weights(enc_writer, layer, name=export_name, verbose=True) + dump_torch_weights(writer, layer, name=export_name, verbose=True) - encoder_gru_layers = [ - ('gru' , 'gru', 'TANH'), + gru_layers = [ + ("GRU", "gru_1"), ] - enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False) - for name, export_name, _ in encoder_gru_layers]) + max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False) + for name, export_name in gru_layers]) + + writer.header.write( +f""" + +#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units} + +""" + ) - del enc_writer + writer.close() if __name__ == "__main__": os.makedirs(args.output_dir, exist_ok=True) - model = large_if_ccode() - checkpoint = torch.load(args.checkpoint ,map_location='cpu') + model = PitchDNN() + checkpoint = torch.load(args.checkpoint, map_location='cpu') model.load_state_dict(checkpoint['state_dict']) c_export(args, model) -- GitLab