From fd1fc693aa0fc7c4f0d7e0d8868541a4bc16ee4a Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin <jmvalin@jmvalin.ca> Date: Mon, 1 Apr 2019 15:22:00 -0400 Subject: [PATCH] adaptation flag to avoid training the sample rate network --- dnn/lpcnet.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dnn/lpcnet.py b/dnn/lpcnet.py index f7a982351..effd6398d 100644 --- a/dnn/lpcnet.py +++ b/dnn/lpcnet.py @@ -113,7 +113,7 @@ class PCMInit(Initializer): 'seed': self.seed } -def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True): +def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False): pcm = Input(shape=(None, 3)) feat = Input(shape=(None, nb_used_features)) pitch = Input(shape=(None, 1)) @@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)])) ulaw_prob = md(gru_out2) - rnn.trainable=False - rnn2.trainable=False - md.trainable=False - embed.Trainable=False + if adaptation: + rnn.trainable=False + rnn2.trainable=False + md.trainable=False + embed.Trainable=False model = Model([pcm, feat, pitch], ulaw_prob) model.rnn_units1 = rnn_units1 -- GitLab