From fd1fc693aa0fc7c4f0d7e0d8868541a4bc16ee4a Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@jmvalin.ca>
Date: Mon, 1 Apr 2019 15:22:00 -0400
Subject: [PATCH] adaptation flag to avoid training the sample rate network

---
 dnn/lpcnet.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/dnn/lpcnet.py b/dnn/lpcnet.py
index f7a982351..effd6398d 100644
--- a/dnn/lpcnet.py
+++ b/dnn/lpcnet.py
@@ -113,7 +113,7 @@ class PCMInit(Initializer):
             'seed': self.seed
         }
 
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True):
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False):
     pcm = Input(shape=(None, 3))
     feat = Input(shape=(None, nb_used_features))
     pitch = Input(shape=(None, 1))
@@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
     gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
     ulaw_prob = md(gru_out2)
     
-    rnn.trainable=False
-    rnn2.trainable=False
-    md.trainable=False
-    embed.Trainable=False
+    if adaptation:
+        rnn.trainable=False
+        rnn2.trainable=False
+        md.trainable=False
+        embed.Trainable=False
     
     model = Model([pcm, feat, pitch], ulaw_prob)
     model.rnn_units1 = rnn_units1
-- 
GitLab