From c1532559a2fc7ccdf2e442d4257779849f59a52d Mon Sep 17 00:00:00 2001 From: Krishna Subramani <subramani.krishna97@gmail.com> Date: Thu, 29 Jul 2021 03:36:13 -0400 Subject: [PATCH] Adds end-to-end LPC training Making LPC computation and prediction differentiable --- dnn/dump_data.c | 4 ++ dnn/lpcnet.c | 51 +++++++++++++++++++ dnn/lpcnet_private.h | 7 +++ dnn/training_tf2/diffembed.py | 49 ++++++++++++++++++ dnn/training_tf2/difflpc.py | 27 ++++++++++ dnn/training_tf2/dump_lpcnet.py | 11 ++++- dnn/training_tf2/lossfuncs.py | 85 ++++++++++++++++++++++++++++++++ dnn/training_tf2/lpcnet.py | 39 ++++++++++++--- dnn/training_tf2/test_lpcnet.py | 14 ++++-- dnn/training_tf2/tf_funcs.py | 69 ++++++++++++++++++++++++++ dnn/training_tf2/train_lpcnet.py | 18 +++++-- 11 files changed, 357 insertions(+), 17 deletions(-) create mode 100644 dnn/training_tf2/diffembed.py create mode 100644 dnn/training_tf2/difflpc.py create mode 100644 dnn/training_tf2/lossfuncs.py create mode 100644 dnn/training_tf2/tf_funcs.py diff --git a/dnn/dump_data.c b/dnn/dump_data.c index 9bfcbfedc..7a8055a1f 100644 --- a/dnn/dump_data.c +++ b/dnn/dump_data.c @@ -92,7 +92,11 @@ void write_audio(LPCNetEncState *st, const short *pcm, const int *noise, FILE *f /* Excitation in. */ data[4*i+2] = st->exc_mem; /* Excitation out. */ +#ifdef END2END + data[4*i+3] = lin2ulaw(pcm[k*FRAME_SIZE+i]); +#else data[4*i+3] = e; +#endif /* Simulate error on excitation. 
*/ e += noise[k*FRAME_SIZE+i]; e = IMIN(255, IMAX(0, e)); diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c index 0020eb75c..c740a0778 100644 --- a/dnn/lpcnet.c +++ b/dnn/lpcnet.c @@ -131,6 +131,51 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *lpcnet) free(lpcnet); } +#ifdef END2END +void rc2lpc(float *lpc, const float *rc) +{ + float tmp[LPC_ORDER]; + float ntmp[LPC_ORDER] = {0.0}; + RNN_COPY(tmp, rc, LPC_ORDER); + for(int i = 0; i < LPC_ORDER ; i++) + { + for(int j = 0; j <= i-1; j++) + { + ntmp[j] = tmp[j] + tmp[i]*tmp[i - j - 1]; + } + for(int k = 0; k <= i-1; k++) + { + tmp[k] = ntmp[k]; + } + } + for(int i = 0; i < LPC_ORDER ; i++) + { + lpc[i] = tmp[i]; + } +} + +void lpc_from_features(LPCNetState *lpcnet,const float *features) +{ + NNetState *net; + float in[NB_FEATURES]; + float conv1_out[F2RC_CONV1_OUT_SIZE]; + float conv2_out[F2RC_CONV2_OUT_SIZE]; + float dense1_out[F2RC_DENSE3_OUT_SIZE]; + float rc[LPC_ORDER]; + net = &lpcnet->nnet; + RNN_COPY(in, features, NB_FEATURES); + compute_conv1d(&f2rc_conv1, conv1_out, net->f2rc_conv1_state, in); + if (lpcnet->frame_count < F2RC_CONV1_DELAY + 1) RNN_CLEAR(conv1_out, F2RC_CONV1_OUT_SIZE); + compute_conv1d(&f2rc_conv2, conv2_out, net->f2rc_conv2_state, conv1_out); + if (lpcnet->frame_count < (FEATURES_DELAY_F2RC + 1)) RNN_CLEAR(conv2_out, F2RC_CONV2_OUT_SIZE); + memmove(lpcnet->old_input_f2rc[1], lpcnet->old_input_f2rc[0], (FEATURES_DELAY_F2RC-1)*NB_FEATURES*sizeof(in[0])); + memcpy(lpcnet->old_input_f2rc[0], in, NB_FEATURES*sizeof(in[0])); + compute_dense(&f2rc_dense3, dense1_out, conv2_out); + compute_dense(&f2rc_dense4_outp_rc, rc, dense1_out); + rc2lpc(lpcnet->old_lpc[0], rc); +} +#endif + LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, short *output, int N) { int i; @@ -144,9 +189,15 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, memmove(&lpcnet->old_gain[1], &lpcnet->old_gain[0], (FEATURES_DELAY-1)*sizeof(lpcnet->old_gain[0])); 
lpcnet->old_gain[0] = features[PITCH_GAIN_FEATURE]; run_frame_network(lpcnet, gru_a_condition, gru_b_condition, features, pitch); +#ifdef END2END + lpc_from_features(lpcnet,features); + memcpy(lpc, lpcnet->old_lpc[0], LPC_ORDER*sizeof(lpc[0])); +#else memcpy(lpc, lpcnet->old_lpc[FEATURES_DELAY-1], LPC_ORDER*sizeof(lpc[0])); memmove(lpcnet->old_lpc[1], lpcnet->old_lpc[0], (FEATURES_DELAY-1)*LPC_ORDER*sizeof(lpc[0])); lpc_from_cepstrum(lpcnet->old_lpc[0], features); +#endif + if (lpcnet->frame_count <= FEATURES_DELAY) { RNN_CLEAR(output, N); diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h index fedcd58ed..9ec8e50f4 100644 --- a/dnn/lpcnet_private.h +++ b/dnn/lpcnet_private.h @@ -22,11 +22,18 @@ #define FEATURES_DELAY (FEATURE_CONV1_DELAY + FEATURE_CONV2_DELAY) +#ifdef END2END + #define FEATURES_DELAY_F2RC (F2RC_CONV1_DELAY + F2RC_CONV2_DELAY) +#endif + struct LPCNetState { NNetState nnet; int last_exc; float last_sig[LPC_ORDER]; float old_input[FEATURES_DELAY][FEATURE_CONV2_OUT_SIZE]; +#ifdef END2END + float old_input_f2rc[FEATURES_DELAY_F2RC][F2RC_CONV2_OUT_SIZE]; +#endif float old_lpc[FEATURES_DELAY][LPC_ORDER]; float old_gain[FEATURES_DELAY]; float sampling_logit_table[256]; diff --git a/dnn/training_tf2/diffembed.py b/dnn/training_tf2/diffembed.py new file mode 100644 index 000000000..64f098e21 --- /dev/null +++ b/dnn/training_tf2/diffembed.py @@ -0,0 +1,49 @@ +""" +Modification of Tensorflow's Embedding Layer: + 1. Not restricted to be the first layer of a model + 2. 
Differentiable (allows non-integer lookups) + - For non-integer lookup, this layer linearly interpolates between the adjacent embeddings in the following way to preserve gradient flow + - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x)) +""" + +import tensorflow as tf +from tensorflow.keras.layers import Layer + +class diff_Embed(Layer): + """ + Parameters: + - units: int + Dimension of the Embedding + - dict_size: int + Number of Embeddings to lookup + - pcm_init: boolean + If True, use `initializer` to initialize the embedding matrix + """ + def __init__(self, units=128, dict_size = 256, pcm_init = True, initializer = None, **kwargs): + super(diff_Embed, self).__init__(**kwargs) + self.units = units + self.dict_size = dict_size + self.pcm_init = pcm_init + self.initializer = initializer + + def build(self, input_shape): + w_init = tf.random_normal_initializer() + if self.pcm_init: + w_init = self.initializer + self.w = tf.Variable(initial_value=w_init(shape=(self.dict_size, self.units),dtype='float32'),trainable=True) + + def call(self, inputs): + alpha = inputs - tf.math.floor(inputs) + alpha = tf.expand_dims(alpha,axis = -1) + alpha = tf.tile(alpha,[1,1,1,self.units]) + inputs = tf.cast(inputs,'int32') + M = (1 - alpha)*tf.gather(self.w,inputs) + alpha*tf.gather(self.w,tf.clip_by_value(inputs + 1, 0, 255)) + return M + + def get_config(self): + config = super(diff_Embed, self).get_config() + config.update({"units": self.units}) + config.update({"dict_size" : self.dict_size}) + config.update({"pcm_init" : self.pcm_init}) + config.update({"initializer" : self.initializer}) + return config \ No newline at end of file diff --git a/dnn/training_tf2/difflpc.py b/dnn/training_tf2/difflpc.py new file mode 100644 index 000000000..efa5e21c0 --- /dev/null +++ b/dnn/training_tf2/difflpc.py @@ -0,0 +1,27 @@ +""" +Tensorflow model (differentiable lpc) to learn the lpcs from the features +""" + +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input, Dense, 
Concatenate, Lambda, Conv1D, Multiply, Layer, LeakyReLU +from tensorflow.keras import backend as K +from tf_funcs import diff_rc2lpc + +frame_size = 160 +lpcoeffs_N = 16 + +def difflpc(nb_used_features = 20, training=False): + feat = Input(shape=(None, nb_used_features)) # BFCC + padding = 'valid' if training else 'same' + L1 = Conv1D(100, 3, padding=padding, activation='tanh', name='f2rc_conv1') + L2 = Conv1D(75, 3, padding=padding, activation='tanh', name='f2rc_conv2') + L3 = Dense(50, activation='tanh',name = 'f2rc_dense3') + L4 = Dense(lpcoeffs_N, activation='tanh',name = "f2rc_dense4_outp_rc") + rc = L4(L3(L2(L1(feat)))) + # Differentiable RC 2 LPC + lpcoeffs = diff_rc2lpc(name = "rc2lpc")(rc) + + model = Model(feat,lpcoeffs,name = 'f2lpc') + model.nb_used_features = nb_used_features + model.frame_size = frame_size + return model diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py index 106426f39..102b021d8 100755 --- a/dnn/training_tf2/dump_lpcnet.py +++ b/dnn/training_tf2/dump_lpcnet.py @@ -35,6 +35,9 @@ from mdense import MDense import h5py import re +# Flag for dumping e2e (differentiable lpc) network weights +flag_e2e = False + max_rnn_neurons = 1 max_conv_inputs = 1 max_mdense_tmp = 1 @@ -237,7 +240,7 @@ with h5py.File(filename, "r") as f: units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape) units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape) -model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2) +model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2, flag_e2e = flag_e2e) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) #model.summary() @@ -288,6 +291,12 @@ for i, layer in enumerate(model.layers): if layer.dump_layer(f, hf): layer_list.append(layer.name) +if flag_e2e: + print("-- Weight Dumping for the Differentiable LPC Block --") + for i, layer in 
enumerate(model.get_layer("f2lpc").layers): + if layer.dump_layer(f, hf): + layer_list.append(layer.name) + dump_sparse_gru(model.get_layer('gru_a'), f, hf) hf.write('#define MAX_RNN_NEURONS {}\n\n'.format(max_rnn_neurons)) diff --git a/dnn/training_tf2/lossfuncs.py b/dnn/training_tf2/lossfuncs.py new file mode 100644 index 000000000..8a627eadb --- /dev/null +++ b/dnn/training_tf2/lossfuncs.py @@ -0,0 +1,85 @@ +""" +Custom Loss functions and metrics for training/analysis +""" + +from tf_funcs import * +import tensorflow as tf + +# The following loss functions all expect the lpcnet model to output the lpc prediction + +# Computing the excitation by subtracting the lpc prediction from the target, followed by minimizing the cross entropy +def res_from_sigloss(): + def loss(y_true,y_pred): + p = y_pred[:,:,0:1] + model_out = y_pred[:,:,1:] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + e_gt = tf.round(e_gt) + e_gt = tf.cast(e_gt,'int32') + sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out) + return sparse_cel + return loss + +# Interpolated and Compensated Loss (In case of end to end lpcnet) +# Interpolates between adjacent embeddings based on the fractional value of the excitation computed (similar to the embedding interpolation) +# Also adds a probability compensation (to account for matching cross entropy in the linear domain), weighted by gamma +def interp_mulaw(gamma = 1): + def loss(y_true,y_pred): + p = y_pred[:,:,0:1] + model_out = y_pred[:,:,1:] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0)) + alpha = e_gt - tf.math.floor(e_gt) + alpha = tf.tile(alpha,[1,1,256]) + e_gt = tf.cast(e_gt,'int32') + e_gt = tf.clip_by_value(e_gt,0,254) + interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) + sparse_cel = 
tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) + loss_mod = sparse_cel + gamma*prob_compensation + return loss_mod + return loss + +# Same as interp_mulaw with gamma fixed to 1, exposed as a metric +def metric_oginterploss(y_true,y_pred): + p = y_pred[:,:,0:1] + model_out = y_pred[:,:,1:] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0)) + alpha = e_gt - tf.math.floor(e_gt) + alpha = tf.tile(alpha,[1,1,256]) + e_gt = tf.cast(e_gt,'int32') + e_gt = tf.clip_by_value(e_gt,0,254) + interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) + sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) + loss_mod = sparse_cel + prob_compensation + return loss_mod + +# Interpolated cross entropy loss metric +def metric_icel(y_true, y_pred): + p = y_pred[:,:,0:1] + model_out = y_pred[:,:,1:] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + alpha = e_gt - tf.math.floor(e_gt) + alpha = tf.tile(alpha,[1,1,256]) + e_gt = tf.cast(e_gt,'int32') + e_gt = tf.clip_by_value(e_gt,0,254) #Check direction + interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1) + sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab) + return sparse_cel + +# Non-interpolated (rounded) cross entropy loss metric +def metric_cel(y_true, y_pred): + p = y_pred[:,:,0:1] + model_out = y_pred[:,:,1:] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + e_gt = tf.round(e_gt) + e_gt = tf.cast(e_gt,'int32') + e_gt = tf.clip_by_value(e_gt,0,255) + sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out) + return sparse_cel + +# Mean squared deviation of the output excitation from 128 (the mu-law code of zero) +def metric_exc_sd(y_true,y_pred): + p = y_pred[:,:,0:1] + e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p)) + sd_egt = 
tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt,128) + return sd_egt diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py index 11d5f329e..e08b809ad 100644 --- a/dnn/training_tf2/lpcnet.py +++ b/dnn/training_tf2/lpcnet.py @@ -38,6 +38,9 @@ from mdense import MDense import numpy as np import h5py import sys +from tf_funcs import * +from diffembed import diff_Embed +import difflpc frame_size = 160 pcm_bits = 8 @@ -186,7 +189,7 @@ class PCMInit(Initializer): #a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows #a[:,1] = .5*a[:,0]*a[:,0]*a[:,0] a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1)) - return self.gain * a + return self.gain * a.astype("float32") def get_config(self): return { @@ -212,7 +215,7 @@ class WeightClip(Constraint): constraint = WeightClip(0.992) -def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False): +def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False): pcm = Input(shape=(None, 3)) feat = Input(shape=(None, nb_used_features)) pitch = Input(shape=(None, 1)) @@ -224,8 +227,21 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1') fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2') - embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig') - cpcm = Reshape((-1, embed_size*3))(embed(pcm)) + if not flag_e2e: + embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig') + cpcm = Reshape((-1, embed_size*3))(embed(pcm)) + else: + Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1)) + error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis 
= 1))) + feat2lpc = difflpc.difflpc(training = training) + lpcoeffs = feat2lpc(feat) + tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs]) + past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds]) + embed = diff_Embed(name='embed_sig',initializer = PCMInit()) + cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors]) + cpcm = Reshape((-1, embed_size*3))(embed(cpcm)) + cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])]) + cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder)) pembed = Embedding(256, 64, name='embed_pitch') cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))]) @@ -264,15 +280,22 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train md.trainable=False embed.Trainable=False - model = Model([pcm, feat, pitch], ulaw_prob) + if not flag_e2e: + model = Model([pcm, feat, pitch], ulaw_prob) + else: + m_out = Concatenate()([tensor_preds,ulaw_prob]) + model = Model([pcm, feat, pitch], m_out) model.rnn_units1 = rnn_units1 model.rnn_units2 = rnn_units2 model.nb_used_features = nb_used_features model.frame_size = frame_size - - encoder = Model([feat, pitch], cfeat) - dec_rnn_in = Concatenate()([cpcm, dec_feat]) + if not flag_e2e: + encoder = Model([feat, pitch], cfeat) + dec_rnn_in = Concatenate()([cpcm, dec_feat]) + else: + encoder = Model([feat, pitch], [cfeat,lpcoeffs]) + dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat]) dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1) dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2) dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2)) diff --git a/dnn/training_tf2/test_lpcnet.py b/dnn/training_tf2/test_lpcnet.py index 9a48d5667..88439cf17 100755 --- a/dnn/training_tf2/test_lpcnet.py +++ b/dnn/training_tf2/test_lpcnet.py @@ -31,8 +31,10 @@ import numpy as np from ulaw import ulaw2lin, lin2ulaw 
import h5py +# Flag for synthesizing e2e (differentiable lpc) model +flag_e2e = False -model, enc, dec = lpcnet.new_lpcnet_model() +model, enc, dec = lpcnet.new_lpcnet_model(training = False, flag_e2e = flag_e2e) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) #model.summary() @@ -70,10 +72,16 @@ fout = open(out_file, 'wb') skip = order + 1 for c in range(0, nb_frames): - cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]]) + if not flag_e2e: + cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]]) + else: + cfeat,lpcs = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]]) for fr in range(0, feature_chunk_size): f = c*feature_chunk_size + fr - a = features[c, fr, nb_features-order:] + if not flag_e2e: + a = features[c, fr, nb_features-order:] + else: + a = lpcs[c,fr] for i in range(skip, frame_size): pred = -sum(a*pcm[f*frame_size + i - 1:f*frame_size + i - order-1:-1]) fexc[0, 0, 1] = lin2ulaw(pred) diff --git a/dnn/training_tf2/tf_funcs.py b/dnn/training_tf2/tf_funcs.py new file mode 100644 index 000000000..cf593184b --- /dev/null +++ b/dnn/training_tf2/tf_funcs.py @@ -0,0 +1,69 @@ +""" +Tensorflow/Keras helper functions to do the following: + 1. \mu law <-> Linear domain conversion + 2. Differentiable prediction from the input signal and LP coefficients + 3. 
Differentiable transformations Reflection Coefficients (RCs) <-> LP Coefficients +""" +from tensorflow.keras.layers import Lambda, Multiply, Layer, Concatenate +from tensorflow.keras import backend as K +import tensorflow as tf + +# \mu law <-> Linear conversion functions +scale = 255.0/32768.0 +scale_1 = 32768.0/255.0 +def tf_l2u(x): + s = K.sign(x) + x = K.abs(x) + u = (s*(128*K.log(1+scale*x)/K.log(256.0))) + u = K.clip(128 + u, 0, 255) + return u + +def tf_u2l(u): + u = tf.cast(u,"float32") + u = u - 128.0 + s = K.sign(u) + u = K.abs(u) + return s*scale_1*(K.exp(u/128.*K.log(256.0))-1) + +# Differentiable Prediction Layer +# Computes the LP prediction from the input lag signal and the LP coefficients +# The inputs xt and lpc conform with the shapes in lpcnet.py (the '2400' is coded keeping this in mind) +class diff_pred(Layer): + def call(self, inputs, lpcoeffs_N = 16, frame_size = 160): + xt = tf_u2l(inputs[0]) + lpc = inputs[1] + + rept = Lambda(lambda x: K.repeat_elements(x , frame_size, 1)) + zpX = Lambda(lambda x: K.concatenate([0*x[:,0:lpcoeffs_N,:], x],axis = 1)) + cX = Lambda(lambda x: K.concatenate([x[:,(lpcoeffs_N - i):(lpcoeffs_N - i + 2400),:] for i in range(lpcoeffs_N)],axis = 2)) + + pred = -Multiply()([rept(lpc),cX(zpX(xt))]) + + return tf_l2u(K.sum(pred,axis = 2,keepdims = True)) + +# Differentiable Transformations (RC <-> LPC) computed using the Levinson Durbin Recursion +class diff_rc2lpc(Layer): + def call(self, inputs, lpcoeffs_N = 16): + def pred_lpc_recursive(input): + temp = (input[0] + K.repeat_elements(input[1],input[0].shape[2],2)*K.reverse(input[0],axes = 2)) + temp = Concatenate(axis = 2)([temp,input[1]]) + return temp + Llpc = Lambda(pred_lpc_recursive) + lpc_init = inputs + for i in range(1,lpcoeffs_N): + lpc_init = Llpc([lpc_init[:,:,:i],K.expand_dims(inputs[:,:,i],axis = -1)]) + return lpc_init + +class diff_lpc2rc(Layer): + def call(self, inputs, lpcoeffs_N = 16): + def pred_rc_recursive(input): + ki = 
K.repeat_elements(K.expand_dims(input[1][:,:,0],axis = -1),input[0].shape[2],2) + temp = (input[0] - ki*K.reverse(input[0],axes = 2))/(1 - ki*ki) + temp = Concatenate(axis = 2)([temp,input[1]]) + return temp + Lrc = Lambda(pred_rc_recursive) + rc_init = inputs + for i in range(1,lpcoeffs_N): + j = (lpcoeffs_N - i + 1) + rc_init = Lrc([rc_init[:,:,:(j - 1)],rc_init[:,:,(j - 1):]]) + return rc_init \ No newline at end of file diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py index 59214e5dc..aeaf98d9a 100755 --- a/dnn/training_tf2/train_lpcnet.py +++ b/dnn/training_tf2/train_lpcnet.py @@ -44,7 +44,7 @@ parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, hel parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)') parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)') parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)') - +parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation') args = parser.parse_args() @@ -66,12 +66,14 @@ lpcnet = importlib.import_module(args.model) import sys import numpy as np from tensorflow.keras.optimizers import Adam -from tensorflow.keras.callbacks import ModelCheckpoint +from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger from ulaw import ulaw2lin, lin2ulaw import tensorflow.keras.backend as K import h5py import tensorflow as tf +from tf_funcs import * +from lossfuncs import * #gpus = tf.config.experimental.list_physical_devices('GPU') #if gpus: # try: @@ -93,12 +95,17 @@ else: lr = 0.001 decay = 2.5e-5 +flag_e2e = args.flag_e2e + opt = Adam(lr, decay=decay, beta_2=0.99) strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() with strategy.scope(): - model, _, _ = 
lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize) - model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy') + model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e) + if not flag_e2e: + model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy') + else: + model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss]) model.summary() feature_file = args.features @@ -150,4 +157,5 @@ else: grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density) model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size)) -model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify]) +csv_logger = CSVLogger('training_vals.log') +model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger]) -- GitLab