Skip to content
Snippets Groups Projects
Unverified Commit 0459a572 authored by Jan Buethe's avatar Jan Buethe
Browse files

updated PitchDNN export script

parent ce286958
No related branches found
No related tags found
No related merge requests found
Pipeline #4099 passed
......@@ -44,7 +44,7 @@ args = parser.parse_args()
import torch
import numpy as np
from models import large_if_ccode
from models import PitchDNN
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
......@@ -52,39 +52,51 @@ def c_export(args, model):
message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
enc_writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='nnpitch')
enc_writer.header.write(
writer = CWriter(os.path.join(args.output_dir, "neural_pitch_data"), message=message, model_struct_name='PitchDNN')
writer.header.write(
f"""
#include "opus_types.h"
"""
)
# encoder
encoder_dense_layers = [
('initial' , 'initial', 'TANH'),
('upsample' , 'upsample', 'TANH')
layers = [
('if_upsample.0', "dense_if_upsampler_1"),
('if_upsample.2', "dense_if_upsampler_2"),
('conv.1', "conv2d_1"),
('conv.4', "conv2d_2"),
('conv.7', "conv2d_3"),
('downsample.0', "dense_downsampler"),
("upsample.0", "dense_final_upsampler")
]
for name, export_name, _ in encoder_dense_layers:
for name, export_name in layers:
layer = model.get_submodule(name)
dump_torch_weights(enc_writer, layer, name=export_name, verbose=True)
dump_torch_weights(writer, layer, name=export_name, verbose=True)
encoder_gru_layers = [
('gru' , 'gru', 'TANH'),
gru_layers = [
("GRU", "gru_1"),
]
enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name, _ in encoder_gru_layers])
max_rnn_units = max([dump_torch_weights(writer, model.get_submodule(name), export_name, verbose=True, input_sparse=False, quantize=False)
for name, export_name in gru_layers])
writer.header.write(
f"""
#define PITCH_DNN_MAX_RNN_UNITS {max_rnn_units}
"""
)
del enc_writer
writer.close()
if __name__ == "__main__":
os.makedirs(args.output_dir, exist_ok=True)
model = large_if_ccode()
checkpoint = torch.load(args.checkpoint ,map_location='cpu')
model = PitchDNN()
checkpoint = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
c_export(args, model)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment