diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py
index 4608ea5b894f02d78b4ca58bc0e432fee90695a7..bc4a97bff02a59fbcde0debc6979961818d3c89e 100644
--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -26,8 +26,10 @@
 '''
 
 import math
+import tensorflow as tf
 from tensorflow.keras.models import Model
 from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+from tensorflow.compat.v1.keras.layers import CuDNNGRU
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.initializers import Initializer
@@ -42,6 +44,12 @@ pcm_bits = 8
 embed_size = 128
 pcm_levels = 2**pcm_bits
 
+def quant_regularizer(x):
+    Q = 128
+    Q_1 = 1./Q
+    #return .01 * tf.reduce_mean(1 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))
+    return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
+
 class Sparsify(Callback):
     def __init__(self, t_start, t_end, interval, density):
         super(Sparsify, self).__init__()
@@ -129,9 +137,9 @@ class WeightClip(Constraint):
         return {'name': self.__class__.__name__,
             'c': self.c}
 
-constraint = WeightClip(0.999)
+constraint = WeightClip(0.992)
 
-def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, adaptation=False):
+def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, training=False, adaptation=False, quantize=False):
     pcm = Input(shape=(None, 3))
     feat = Input(shape=(None, nb_used_features))
     pitch = Input(shape=(None, 1))
@@ -158,10 +166,18 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 38, train
     
     rep = Lambda(lambda x: K.repeat_elements(x, frame_size, 1))
 
-    rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
-              recurrent_constraint = constraint)
-    rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
-               kernel_constraint=constraint)
+    quant = quant_regularizer if quantize else None
+
+    if training:
+        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a',
+              recurrent_constraint = constraint, recurrent_regularizer=quant)
+        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b',
+               kernel_constraint=constraint, kernel_regularizer=quant)
+    else:
+        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
+              recurrent_constraint = constraint, recurrent_regularizer=quant)
+        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
+               kernel_constraint=constraint, kernel_regularizer=quant)
 
     rnn_in = Concatenate()([cpcm, rep(cfeat)])
     md = MDense(pcm_levels, activation='softmax', name='dual_fc')
diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py
index 0e90a28fa65277e29c334a062d01da41083c46d7..96ef11fbeddb17e6cafb8eb723f77321e797d5f0 100755
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -49,10 +49,23 @@ nb_epochs = 120
 # Try reducing batch_size if you run out of memory on your GPU
 batch_size = 128
 
-model, _, _ = lpcnet.new_lpcnet_model(training=True)
+#Set this to True to adapt an existing model (e.g. on new data)
+adaptation = False
+
+if adaptation:
+    lr = 0.0001
+    decay = 0
+else:
+    lr = 0.001
+    decay = 2.5e-5
 
-model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
-model.summary()
+opt = Adam(lr, decay=decay, beta_2=0.99)
+strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
+with strategy.scope():
+    model, _, _ = lpcnet.new_lpcnet_model(training=True)
+    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
+    model.summary()
 
 feature_file = sys.argv[1]
 pcm_file = sys.argv[2]     # 16 bit unsigned short PCM samples
@@ -65,7 +78,7 @@ pcm_chunk_size = frame_size*feature_chunk_size
 # u for unquantised, load 16 bit PCM samples and convert to mu-law
 
 data = np.fromfile(pcm_file, dtype='uint8')
-nb_frames = len(data)//(4*pcm_chunk_size)
+nb_frames = len(data)//(4*pcm_chunk_size)//batch_size*batch_size
 
 features = np.fromfile(feature_file, dtype='float32')
 
@@ -102,23 +115,15 @@ del pred
 del in_exc
 
 # dump models to disk as we go
-checkpoint = ModelCheckpoint('lpcnet33_384_{epoch:02d}.h5')
-
-#Set this to True to adapt an existing model (e.g. on new data)
-adaptation = False
+checkpoint = ModelCheckpoint('lpcnet33e_384_{epoch:02d}.h5')
 
 if adaptation:
     #Adapting from an existing model
-    model.load_weights('lpcnet32v_384_100.h5')
+    model.load_weights('lpcnet33a_384_100.h5')
     sparsify = lpcnet.Sparsify(0, 0, 1, (0.05, 0.05, 0.2))
-    lr = 0.0001
-    decay = 0
 else:
     #Training from scratch
     sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
-    lr = 0.001
-    decay = 5e-5
 
-model.compile(optimizer=Adam(lr, decay=decay, beta_2=0.99), loss='sparse_categorical_crossentropy')
-model.save_weights('lpcnet33_384_00.h5');
+model.save_weights('lpcnet33e_384_00.h5');
 model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify])