diff --git a/dnn/lpcnet.py b/dnn/lpcnet.py index f7a9823518af06d72a55e20c8d2ea5f3d044e5ac..effd6398d8c92919bc7064270c02c17f9d99b60a 100644 --- a/dnn/lpcnet.py +++ b/dnn/lpcnet.py @@ -113,7 +113,7 @@ class PCMInit(Initializer): 'seed': self.seed } -def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True): +def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, use_gpu=True, adaptation=False): pcm = Input(shape=(None, 3)) feat = Input(shape=(None, nb_used_features)) pitch = Input(shape=(None, 1)) @@ -153,10 +153,11 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)])) ulaw_prob = md(gru_out2) - rnn.trainable=False - rnn2.trainable=False - md.trainable=False - embed.Trainable=False + if adaptation: + rnn.trainable=False + rnn2.trainable=False + md.trainable=False + embed.Trainable=False model = Model([pcm, feat, pitch], ulaw_prob) model.rnn_units1 = rnn_units1