From c1532559a2fc7ccdf2e442d4257779849f59a52d Mon Sep 17 00:00:00 2001
From: Krishna Subramani <subramani.krishna97@gmail.com>
Date: Thu, 29 Jul 2021 03:36:13 -0400
Subject: [PATCH] Adds end-to-end LPC training

Making LPC computation and prediction differentiable
---
 dnn/dump_data.c                  |  4 ++
 dnn/lpcnet.c                     | 51 +++++++++++++++++++
 dnn/lpcnet_private.h             |  7 +++
 dnn/training_tf2/diffembed.py    | 49 ++++++++++++++++++
 dnn/training_tf2/difflpc.py      | 27 ++++++++++
 dnn/training_tf2/dump_lpcnet.py  | 11 ++++-
 dnn/training_tf2/lossfuncs.py    | 85 ++++++++++++++++++++++++++++++++
 dnn/training_tf2/lpcnet.py       | 39 ++++++++++++---
 dnn/training_tf2/test_lpcnet.py  | 14 ++++--
 dnn/training_tf2/tf_funcs.py     | 69 ++++++++++++++++++++++++++
 dnn/training_tf2/train_lpcnet.py | 18 +++++--
 11 files changed, 357 insertions(+), 17 deletions(-)
 create mode 100644 dnn/training_tf2/diffembed.py
 create mode 100644 dnn/training_tf2/difflpc.py
 create mode 100644 dnn/training_tf2/lossfuncs.py
 create mode 100644 dnn/training_tf2/tf_funcs.py

diff --git a/dnn/dump_data.c b/dnn/dump_data.c
index 9bfcbfedc..7a8055a1f 100644
--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -92,7 +92,11 @@ void write_audio(LPCNetEncState *st, const short *pcm, const int *noise, FILE *f
     /* Excitation in. */
     data[4*i+2] = st->exc_mem;
     /* Excitation out. */
+#ifdef END2END
+    data[4*i+3] = lin2ulaw(pcm[k*FRAME_SIZE+i]);
+#else
     data[4*i+3] = e;
+#endif
     /* Simulate error on excitation. */
     e += noise[k*FRAME_SIZE+i];
     e = IMIN(255, IMAX(0, e));
diff --git a/dnn/lpcnet.c b/dnn/lpcnet.c
index 0020eb75c..c740a0778 100644
--- a/dnn/lpcnet.c
+++ b/dnn/lpcnet.c
@@ -131,6 +131,51 @@ LPCNET_EXPORT void lpcnet_destroy(LPCNetState *lpcnet)
     free(lpcnet);
 }
 
+#ifdef END2END
+void rc2lpc(float *lpc, const float *rc)
+{
+  float tmp[LPC_ORDER];
+  float ntmp[LPC_ORDER] = {0.0};
+  RNN_COPY(tmp, rc, LPC_ORDER);
+  for(int i = 0; i < LPC_ORDER ; i++)
+    { 
+        for(int j = 0; j <= i-1; j++)
+        {
+            ntmp[j] = tmp[j] + tmp[i]*tmp[i - j - 1];
+        }
+        for(int k = 0; k <= i-1; k++)
+        {
+            tmp[k] = ntmp[k];
+        }
+    }
+  for(int i = 0; i < LPC_ORDER ; i++)
+  {
+    lpc[i] = tmp[i];
+  }
+}
+
+void lpc_from_features(LPCNetState *lpcnet,const float *features)
+{
+  NNetState *net;
+  float in[NB_FEATURES];
+  float conv1_out[F2RC_CONV1_OUT_SIZE];
+  float conv2_out[F2RC_CONV2_OUT_SIZE];
+  float dense1_out[F2RC_DENSE3_OUT_SIZE];
+  float rc[LPC_ORDER];
+  net = &lpcnet->nnet;
+  RNN_COPY(in, features, NB_FEATURES);
+  compute_conv1d(&f2rc_conv1, conv1_out, net->f2rc_conv1_state, in);
+  if (lpcnet->frame_count < F2RC_CONV1_DELAY + 1) RNN_CLEAR(conv1_out, F2RC_CONV1_OUT_SIZE);
+  compute_conv1d(&f2rc_conv2, conv2_out, net->f2rc_conv2_state, conv1_out);
+  if (lpcnet->frame_count < (FEATURES_DELAY_F2RC + 1)) RNN_CLEAR(conv2_out, F2RC_CONV2_OUT_SIZE);
+  memmove(lpcnet->old_input_f2rc[1], lpcnet->old_input_f2rc[0], (FEATURES_DELAY_F2RC-1)*NB_FEATURES*sizeof(in[0]));
+  memcpy(lpcnet->old_input_f2rc[0], in, NB_FEATURES*sizeof(in[0]));
+  compute_dense(&f2rc_dense3, dense1_out, conv2_out);
+  compute_dense(&f2rc_dense4_outp_rc, rc, dense1_out);
+  rc2lpc(lpcnet->old_lpc[0], rc);
+}
+#endif
+
 LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features, short *output, int N)
 {
     int i;
@@ -144,9 +189,15 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *lpcnet, const float *features,
     memmove(&lpcnet->old_gain[1], &lpcnet->old_gain[0], (FEATURES_DELAY-1)*sizeof(lpcnet->old_gain[0]));
     lpcnet->old_gain[0] = features[PITCH_GAIN_FEATURE];
     run_frame_network(lpcnet, gru_a_condition, gru_b_condition, features, pitch);
+#ifdef END2END
+    lpc_from_features(lpcnet,features);
+    memcpy(lpc, lpcnet->old_lpc[0], LPC_ORDER*sizeof(lpc[0]));
+#else
     memcpy(lpc, lpcnet->old_lpc[FEATURES_DELAY-1], LPC_ORDER*sizeof(lpc[0]));
     memmove(lpcnet->old_lpc[1], lpcnet->old_lpc[0], (FEATURES_DELAY-1)*LPC_ORDER*sizeof(lpc[0]));
     lpc_from_cepstrum(lpcnet->old_lpc[0], features);
+#endif
+
     if (lpcnet->frame_count <= FEATURES_DELAY)
     {
         RNN_CLEAR(output, N);
diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h
index fedcd58ed..9ec8e50f4 100644
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -22,11 +22,18 @@
 
 #define FEATURES_DELAY (FEATURE_CONV1_DELAY + FEATURE_CONV2_DELAY)
 
+#ifdef END2END
+  #define FEATURES_DELAY_F2RC (F2RC_CONV1_DELAY + F2RC_CONV2_DELAY)
+#endif
+
 struct LPCNetState {
     NNetState nnet;
     int last_exc;
     float last_sig[LPC_ORDER];
     float old_input[FEATURES_DELAY][FEATURE_CONV2_OUT_SIZE];
+#ifdef END2END
+    float old_input_f2rc[FEATURES_DELAY_F2RC][F2RC_CONV2_OUT_SIZE];
+#endif
     float old_lpc[FEATURES_DELAY][LPC_ORDER];
     float old_gain[FEATURES_DELAY];
     float sampling_logit_table[256];
diff --git a/dnn/training_tf2/diffembed.py b/dnn/training_tf2/diffembed.py
new file mode 100644
index 000000000..64f098e21
--- /dev/null
+++ b/dnn/training_tf2/diffembed.py
@@ -0,0 +1,49 @@
+"""
+Modification of Tensorflow's Embedding Layer:
+    1. Not restricted to be the first layer of a model
+    2. Differentiable (allows non-integer lookups)
+        - For non integer lookup, this layer linearly interpolates between the adjacent embeddings in the following way to preserver gradient flow
+            - E = (1 - frac(x))*embed(floor(x)) + frac(x)*embed(ceil(x)) 
+"""
+
+import tensorflow as tf
+from tensorflow.keras.layers import Layer
+
+class diff_Embed(Layer):
+    """
+    Parameters:
+        - units: int
+            Dimension of the Embedding
+        - dict_size: int
+            Number of Embeddings to lookup
+        - pcm_init: boolean
+            Initialized for the embedding matrix
+    """
+    def __init__(self, units=128, dict_size = 256, pcm_init = True, initializer = None, **kwargs):
+        super(diff_Embed, self).__init__(**kwargs)
+        self.units = units
+        self.dict_size = dict_size
+        self.pcm_init = pcm_init
+        self.initializer = initializer
+
+    def build(self, input_shape):  
+        w_init = tf.random_normal_initializer()
+        if self.pcm_init:  
+            w_init = self.initializer
+        self.w = tf.Variable(initial_value=w_init(shape=(self.dict_size, self.units),dtype='float32'),trainable=True)
+
+    def call(self, inputs):  
+        alpha = inputs - tf.math.floor(inputs)
+        alpha = tf.expand_dims(alpha,axis = -1)
+        alpha = tf.tile(alpha,[1,1,1,self.units])
+        inputs = tf.cast(inputs,'int32')
+        M = (1 - alpha)*tf.gather(self.w,inputs) + alpha*tf.gather(self.w,tf.clip_by_value(inputs + 1, 0, 255))
+        return M
+
+    def get_config(self):
+        config = super(diff_Embed, self).get_config()
+        config.update({"units": self.units})
+        config.update({"dict_size" : self.dict_size})
+        config.update({"pcm_init" : self.pcm_init})
+        config.update({"initializer" : self.initializer})
+        return config
\ No newline at end of file
diff --git a/dnn/training_tf2/difflpc.py b/dnn/training_tf2/difflpc.py
new file mode 100644
index 000000000..efa5e21c0
--- /dev/null
+++ b/dnn/training_tf2/difflpc.py
@@ -0,0 +1,27 @@
+"""
+Tensorflow model (differentiable lpc) to learn the lpcs from the features
+"""
+
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Input, Dense, Concatenate, Lambda, Conv1D, Multiply, Layer, LeakyReLU
+from tensorflow.keras import backend as K
+from tf_funcs import diff_rc2lpc
+
+frame_size = 160
+lpcoeffs_N = 16
+
+def difflpc(nb_used_features = 20, training=False):
+    feat = Input(shape=(None, nb_used_features)) # BFCC
+    padding = 'valid' if training else 'same'
+    L1 = Conv1D(100, 3, padding=padding, activation='tanh', name='f2rc_conv1')
+    L2 = Conv1D(75, 3, padding=padding, activation='tanh', name='f2rc_conv2')
+    L3 = Dense(50, activation='tanh',name = 'f2rc_dense3')
+    L4 = Dense(lpcoeffs_N, activation='tanh',name = "f2rc_dense4_outp_rc")
+    rc = L4(L3(L2(L1(feat))))
+    # Differentiable RC 2 LPC
+    lpcoeffs = diff_rc2lpc(name = "rc2lpc")(rc)
+
+    model = Model(feat,lpcoeffs,name = 'f2lpc')
+    model.nb_used_features = nb_used_features
+    model.frame_size = frame_size
+    return model
diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py
index 106426f39..102b021d8 100755
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -35,6 +35,9 @@ from mdense import MDense
 import h5py
 import re
 
+# Flag for dumping e2e (differentiable lpc) network weights
+flag_e2e = False
+
 max_rnn_neurons = 1
 max_conv_inputs = 1
 max_mdense_tmp = 1
@@ -237,7 +240,7 @@ with h5py.File(filename, "r") as f:
     units = min(f['model_weights']['gru_a']['gru_a']['recurrent_kernel:0'].shape)
     units2 = min(f['model_weights']['gru_b']['gru_b']['recurrent_kernel:0'].shape)
 
-model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2)
+model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=units, rnn_units2=units2, flag_e2e = flag_e2e)
 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
 #model.summary()
 
@@ -288,6 +291,12 @@ for i, layer in enumerate(model.layers):
     if layer.dump_layer(f, hf):
         layer_list.append(layer.name)
 
+if flag_e2e:
+    print("-- Weight Dumping for the Differentiable LPC Block --")
+    for i, layer in enumerate(model.get_layer("f2lpc").layers):
+        if layer.dump_layer(f, hf):
+            layer_list.append(layer.name)
+
 dump_sparse_gru(model.get_layer('gru_a'), f, hf)
 
 hf.write('#define MAX_RNN_NEURONS {}\n\n'.format(max_rnn_neurons))
diff --git a/dnn/training_tf2/lossfuncs.py b/dnn/training_tf2/lossfuncs.py
new file mode 100644
index 000000000..8a627eadb
--- /dev/null
+++ b/dnn/training_tf2/lossfuncs.py
@@ -0,0 +1,85 @@
+"""
+Custom Loss functions and metrics for training/analysis
+"""
+
+from tf_funcs import *
+import tensorflow as tf
+
+# The following loss functions all expect the lpcnet model to output the lpc prediction
+
+# Computing the excitation by subtracting the lpc prediction from the target, followed by minimizing the cross entropy
+def res_from_sigloss():
+    def loss(y_true,y_pred):
+        p = y_pred[:,:,0:1]
+        model_out = y_pred[:,:,1:]
+        e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+        e_gt = tf.round(e_gt)
+        e_gt = tf.cast(e_gt,'int32')
+        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out)
+        return sparse_cel
+    return loss
+
+# Interpolated and Compensated Loss (In case of end to end lpcnet)
+# Interpolates between adjacent embeddings based on the fractional value of the excitation computed (similar to the embedding interpolation)
+# Also adds a probability compensation (to account for matching cross entropy in the linear domain), weighted by gamma
+def interp_mulaw(gamma = 1):
+    def loss(y_true,y_pred):
+        p = y_pred[:,:,0:1]
+        model_out = y_pred[:,:,1:]
+        e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+        prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0))
+        alpha = e_gt - tf.math.floor(e_gt)
+        alpha = tf.tile(alpha,[1,1,256])
+        e_gt = tf.cast(e_gt,'int32')
+        e_gt = tf.clip_by_value(e_gt,0,254) 
+        interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
+        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
+        loss_mod = sparse_cel + gamma*prob_compensation
+        return loss_mod
+    return loss
+
+# Same as above, except a metric
+def metric_oginterploss(y_true,y_pred):
+    p = y_pred[:,:,0:1]
+    model_out = y_pred[:,:,1:]
+    e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+    prob_compensation = tf.squeeze((K.abs(e_gt - 128)/128.0)*K.log(256.0))
+    alpha = e_gt - tf.math.floor(e_gt)
+    alpha = tf.tile(alpha,[1,1,256])
+    e_gt = tf.cast(e_gt,'int32')
+    e_gt = tf.clip_by_value(e_gt,0,254) 
+    interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
+    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
+    loss_mod = sparse_cel + prob_compensation
+    return loss_mod
+
+# Interpolated cross entropy loss metric
+def metric_icel(y_true, y_pred):
+    p = y_pred[:,:,0:1]
+    model_out = y_pred[:,:,1:]
+    e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+    alpha = e_gt - tf.math.floor(e_gt)
+    alpha = tf.tile(alpha,[1,1,256])
+    e_gt = tf.cast(e_gt,'int32')
+    e_gt = tf.clip_by_value(e_gt,0,254) #Check direction
+    interp_probab = (1 - alpha)*model_out + alpha*tf.roll(model_out,shift = -1,axis = -1)
+    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,interp_probab)
+    return sparse_cel
+
+# Non-interpolated (rounded) cross entropy loss metric
+def metric_cel(y_true, y_pred):
+    p = y_pred[:,:,0:1]
+    model_out = y_pred[:,:,1:]
+    e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+    e_gt = tf.round(e_gt)
+    e_gt = tf.cast(e_gt,'int32')
+    e_gt = tf.clip_by_value(e_gt,0,255) 
+    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt,model_out)
+    return sparse_cel
+
+# Variance metric of the output excitation
+def metric_exc_sd(y_true,y_pred):
+    p = y_pred[:,:,0:1]
+    e_gt = tf_l2u(tf_u2l(y_true) - tf_u2l(p))
+    sd_egt = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt,128)
+    return sd_egt
diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py
index 11d5f329e..e08b809ad 100644
--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -38,6 +38,9 @@ from mdense import MDense
 import numpy as np
 import h5py
 import sys
+from tf_funcs import *
+from diffembed import diff_Embed
+import difflpc
 
 frame_size = 160
 pcm_bits = 8
@@ -186,7 +189,7 @@ class PCMInit(Initializer):
         #a[:,0] = math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows
         #a[:,1] = .5*a[:,0]*a[:,0]*a[:,0]
         a = a + np.reshape(math.sqrt(12)*np.arange(-.5*num_rows+.5,.5*num_rows-.4)/num_rows, (num_rows, 1))
-        return self.gain * a
+        return self.gain * a.astype("float32")
 
     def get_config(self):
         return {
@@ -212,7 +215,7 @@ class WeightClip(Constraint):
 
 constraint = WeightClip(0.992)
 
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False):
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
     pcm = Input(shape=(None, 3))
     feat = Input(shape=(None, nb_used_features))
     pitch = Input(shape=(None, 1))
@@ -224,8 +227,21 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
     fconv1 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv1')
     fconv2 = Conv1D(128, 3, padding=padding, activation='tanh', name='feature_conv2')
 
-    embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
-    cpcm = Reshape((-1, embed_size*3))(embed(pcm))
+    if not flag_e2e:
+        embed = Embedding(256, embed_size, embeddings_initializer=PCMInit(), name='embed_sig')
+        cpcm = Reshape((-1, embed_size*3))(embed(pcm))
+    else:
+        Input_extractor = Lambda(lambda x: K.expand_dims(x[0][:,:,x[1]],axis = -1))
+        error_calc = Lambda(lambda x: tf_l2u(tf_u2l(x[0]) - tf.roll(tf_u2l(x[1]),1,axis = 1)))
+        feat2lpc = difflpc.difflpc(training = training)
+        lpcoeffs = feat2lpc(feat)
+        tensor_preds = diff_pred(name = "lpc2preds")([Input_extractor([pcm,0]),lpcoeffs])
+        past_errors = error_calc([Input_extractor([pcm,0]),tensor_preds])
+        embed = diff_Embed(name='embed_sig',initializer = PCMInit())
+        cpcm = Concatenate()([Input_extractor([pcm,0]),tensor_preds,past_errors])
+        cpcm = Reshape((-1, embed_size*3))(embed(cpcm))
+        cpcm_decoder = Concatenate()([Input_extractor([pcm,0]),Input_extractor([pcm,1]),Input_extractor([pcm,2])])
+        cpcm_decoder = Reshape((-1, embed_size*3))(embed(cpcm_decoder))
 
     pembed = Embedding(256, 64, name='embed_pitch')
     cat_feat = Concatenate()([feat, Reshape((-1, 64))(pembed(pitch))])
@@ -264,15 +280,22 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
         md.trainable=False
         embed.Trainable=False
     
-    model = Model([pcm, feat, pitch], ulaw_prob)
+    if not flag_e2e:
+        model = Model([pcm, feat, pitch], ulaw_prob)
+    else:
+        m_out = Concatenate()([tensor_preds,ulaw_prob])
+        model = Model([pcm, feat, pitch], m_out)
     model.rnn_units1 = rnn_units1
     model.rnn_units2 = rnn_units2
     model.nb_used_features = nb_used_features
     model.frame_size = frame_size
-
-    encoder = Model([feat, pitch], cfeat)
     
-    dec_rnn_in = Concatenate()([cpcm, dec_feat])
+    if not flag_e2e:
+        encoder = Model([feat, pitch], cfeat)
+        dec_rnn_in = Concatenate()([cpcm, dec_feat])
+    else:
+        encoder = Model([feat, pitch], [cfeat,lpcoeffs])
+        dec_rnn_in = Concatenate()([cpcm_decoder, dec_feat])
     dec_gru_out1, state1 = rnn(dec_rnn_in, initial_state=dec_state1)
     dec_gru_out2, state2 = rnn2(Concatenate()([dec_gru_out1, dec_feat]), initial_state=dec_state2)
     dec_ulaw_prob = Lambda(tree_to_pdf_infer)(md(dec_gru_out2))
diff --git a/dnn/training_tf2/test_lpcnet.py b/dnn/training_tf2/test_lpcnet.py
index 9a48d5667..88439cf17 100755
--- a/dnn/training_tf2/test_lpcnet.py
+++ b/dnn/training_tf2/test_lpcnet.py
@@ -31,8 +31,10 @@ import numpy as np
 from ulaw import ulaw2lin, lin2ulaw
 import h5py
 
+# Flag for synthesizing e2e (differentiable lpc) model
+flag_e2e = False
 
-model, enc, dec = lpcnet.new_lpcnet_model()
+model, enc, dec = lpcnet.new_lpcnet_model(training = False, flag_e2e = flag_e2e)
 
 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
 #model.summary()
@@ -70,10 +72,16 @@ fout = open(out_file, 'wb')
 
 skip = order + 1
 for c in range(0, nb_frames):
-    cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]])
+    if not flag_e2e:
+        cfeat = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]])
+    else:
+        cfeat,lpcs = enc.predict([features[c:c+1, :, :nb_used_features], periods[c:c+1, :, :]])
     for fr in range(0, feature_chunk_size):
         f = c*feature_chunk_size + fr
-        a = features[c, fr, nb_features-order:]
+        if not flag_e2e:
+            a = features[c, fr, nb_features-order:]
+        else:
+            a = lpcs[c,fr]
         for i in range(skip, frame_size):
             pred = -sum(a*pcm[f*frame_size + i - 1:f*frame_size + i - order-1:-1])
             fexc[0, 0, 1] = lin2ulaw(pred)
diff --git a/dnn/training_tf2/tf_funcs.py b/dnn/training_tf2/tf_funcs.py
new file mode 100644
index 000000000..cf593184b
--- /dev/null
+++ b/dnn/training_tf2/tf_funcs.py
@@ -0,0 +1,69 @@
+"""
+Tensorflow/Keras helper functions to do the following:
+    1. \mu law <-> Linear domain conversion
+    2. Differentiable prediction from the input signal and LP coefficients
+    3. Differentiable transformations Reflection Coefficients (RCs) <-> LP Coefficients
+"""
+from tensorflow.keras.layers import Lambda, Multiply, Layer, Concatenate
+from tensorflow.keras import backend as K
+import tensorflow as tf
+
+# \mu law <-> Linear conversion functions
+scale = 255.0/32768.0
+scale_1 = 32768.0/255.0
+def tf_l2u(x):
+    s = K.sign(x)
+    x = K.abs(x)
+    u = (s*(128*K.log(1+scale*x)/K.log(256.0)))
+    u = K.clip(128 + u, 0, 255)
+    return u
+
+def tf_u2l(u):
+    u = tf.cast(u,"float32")
+    u = u - 128.0
+    s = K.sign(u)
+    u = K.abs(u)
+    return s*scale_1*(K.exp(u/128.*K.log(256.0))-1)
+
+# Differentiable Prediction Layer
+# Computes the LP prediction from the input lag signal and the LP coefficients
+# The inputs xt and lpc conform with the shapes in lpcnet.py (the '2400' is coded keeping this in mind)
+class diff_pred(Layer):
+    def call(self, inputs, lpcoeffs_N = 16, frame_size = 160):
+        xt = tf_u2l(inputs[0])
+        lpc = inputs[1]
+
+        rept = Lambda(lambda x: K.repeat_elements(x , frame_size, 1))
+        zpX = Lambda(lambda x: K.concatenate([0*x[:,0:lpcoeffs_N,:], x],axis = 1))
+        cX = Lambda(lambda x: K.concatenate([x[:,(lpcoeffs_N - i):(lpcoeffs_N - i + 2400),:] for i in range(lpcoeffs_N)],axis = 2))
+        
+        pred = -Multiply()([rept(lpc),cX(zpX(xt))])
+
+        return tf_l2u(K.sum(pred,axis = 2,keepdims = True))
+
+# Differentiable Transformations (RC <-> LPC) computed using the Levinson Durbin Recursion 
+class diff_rc2lpc(Layer):
+    def call(self, inputs, lpcoeffs_N = 16):
+        def pred_lpc_recursive(input):
+            temp = (input[0] + K.repeat_elements(input[1],input[0].shape[2],2)*K.reverse(input[0],axes = 2))
+            temp = Concatenate(axis = 2)([temp,input[1]])
+            return temp
+        Llpc = Lambda(pred_lpc_recursive)
+        lpc_init = inputs
+        for i in range(1,lpcoeffs_N):
+            lpc_init = Llpc([lpc_init[:,:,:i],K.expand_dims(inputs[:,:,i],axis = -1)])
+        return lpc_init
+
+class diff_lpc2rc(Layer):
+    def call(self, inputs, lpcoeffs_N = 16):
+        def pred_rc_recursive(input):
+            ki = K.repeat_elements(K.expand_dims(input[1][:,:,0],axis = -1),input[0].shape[2],2)
+            temp = (input[0] - ki*K.reverse(input[0],axes = 2))/(1 - ki*ki)
+            temp = Concatenate(axis = 2)([temp,input[1]])
+            return temp
+        Lrc = Lambda(pred_rc_recursive)
+        rc_init = inputs
+        for i in range(1,lpcoeffs_N):
+            j = (lpcoeffs_N - i + 1)
+            rc_init = Lrc([rc_init[:,:,:(j - 1)],rc_init[:,:,(j - 1):]])
+        return rc_init
\ No newline at end of file
diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py
index 59214e5dc..aeaf98d9a 100755
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -44,7 +44,7 @@ parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, hel
 parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
 parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
 parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
-
+parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation')
 
 args = parser.parse_args()
 
@@ -66,12 +66,14 @@ lpcnet = importlib.import_module(args.model)
 import sys
 import numpy as np
 from tensorflow.keras.optimizers import Adam
-from tensorflow.keras.callbacks import ModelCheckpoint
+from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
 from ulaw import ulaw2lin, lin2ulaw
 import tensorflow.keras.backend as K
 import h5py
 
 import tensorflow as tf
+from tf_funcs import *
+from lossfuncs import *
 #gpus = tf.config.experimental.list_physical_devices('GPU')
 #if gpus:
 #  try:
@@ -93,12 +95,17 @@ else:
     lr = 0.001
     decay = 2.5e-5
 
+flag_e2e = args.flag_e2e
+
 opt = Adam(lr, decay=decay, beta_2=0.99)
 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
 with strategy.scope():
-    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize)
-    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
+    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size, rnn_units2=args.grub_size, training=True, quantize=quantize, flag_e2e = flag_e2e)
+    if not flag_e2e:
+        model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
+    else:
+        model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
     model.summary()
 
 feature_file = args.features
@@ -150,4 +157,5 @@ else:
     grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)
 
 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
-model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify])
+csv_logger = CSVLogger('training_vals.log')
+model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
-- 
GitLab