From 0459a572f592fb07376c480c1ebbf04c16090211 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Fri, 29 Sep 2023 15:34:59 +0200
Subject: [PATCH] updated PitchDNN export script

---
 .../export_neuralpitch_weights.py             | 46 ++++++++++++-------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/dnn/torch/neural-pitch/export_neuralpitch_weights.py b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
index a56784a99..9f20ec9e7 100644
--- a/dnn/torch/neural-pitch/export_neuralpitch_weights.py
+++ b/dnn/torch/neural-pitch/export_neuralpitch_weights.py
@@ -44,7 +44,7 @@ args = parser.parse_args()
 import torch
 import numpy as np
 
-from models import large_if_ccode
+from models import PitchDNN
 from wexchange.torch import dump_torch_weights
 from wexchange.c_export import CWriter, print_vector
 
@@ -52,39 +52,51 @@ def c_export(args, model):
 
     message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
 
-    enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
-    enc_writer.header.write(
+    writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
+    writer.header.write(
 f"""
 #include "opus_types.h"
 """
         )
 
-
-    # encoder
-    encoder_dense_layers = [
-        ('initial'       , 'initial',   'TANH'),
-        ('upsample'       , 'upsample',   'TANH')
+    layers = [
+        ('if_upsample.0', "dense_if_upsampler_1"),
+        ('if_upsample.2', "dense_if_upsampler_2"),
+        ('conv.1', "conv2d_1"),
+        ('conv.4', "conv2d_2"),
+        ('conv.7', "conv2d_3"),
+        ('downsample.0', "dense_downsampler"),
+        ("upsample.0", "dense_final_upsampler")
     ]
 
-    for name, export_name, _ in encoder_dense_layers:
+
+    for name, export_name in layers:
         layer = model.get_submodule(name)
-        dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
+        dump_torch_weights(writer, layer, name=export_name, verbose=True)
 
 
-    encoder_gru_layers = [
-        ('gru'         , 'gru',   'TANH'),
+    gru_layers = [
+        ("GRU", "gru_1"),
     ]
 
-    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
-                             for name, export_name, _ in encoder_gru_layers])
+    max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
+                             for name, export_name in gru_layers])
+
+    writer.header.write(
+f"""
+
+#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
+
+"""
+        )
 
-    del enc_writer
+    writer.close()
 
 
 if __name__ == "__main__":
 
     os.makedirs(args.output_dir, exist_ok=True)
-    model = large_if_ccode()
-    checkpoint = torch.load(args.checkpoint ,map_location='cpu')
+    model = PitchDNN()
+    checkpoint = torch.load(args.checkpoint, map_location='cpu')
     model.load_state_dict(checkpoint['state_dict'])
     c_export(args, model)
-- 
GitLab