diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py index f450ba5a757a2ed7505fb7fffe36caf4ce39dc9a..3ab4599020b77137372ad9c805fc66c974cb9788 100644 --- a/dnn/training_tf2/lpcnet.py +++ b/dnn/training_tf2/lpcnet.py @@ -258,20 +258,18 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features=20, batch_s cfeat = fdense2(fdense1(cfeat)) - Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1)) error_calc = Lambda(lambda x: tf_l2u(x[0] - tf.roll(x[1],1,axis = 1))) if flag_e2e: lpcoeffs = diff_rc2lpc(name = "rc2lpc")(cfeat) else: lpcoeffs = Input(shape=(None, lpc_order), batch_size=batch_size) - tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs]) - past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds]) + tensor_preds = diff_pred(name = "lpc2preds")([pcm,lpcoeffs]) + past_errors = error_calc([pcm,tensor_preds]) embed = diff_Embed(name='embed_sig',initializer = PCMInit()) - cpcm = Concatenate()([tf_l2u(Input_extractor([pcm,0])),tf_l2u(tensor_preds),past_errors]) + cpcm = Concatenate()([tf_l2u(pcm),tf_l2u(tensor_preds),past_errors]) cpcm = GaussianNoise(.3)(cpcm) cpcm = Reshape((-1, embed_size*3))(embed(cpcm)) - cpcm_decoder = Concatenate()([Input_extractor([dpcm,0]),Input_extractor([dpcm,1]),Input_extractor([dpcm,2])]) - cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder)) + cpcm_decoder = Reshape((-1, embed_size*3))(embed(dpcm)) rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))