diff --git a/dnn/autogen.sh b/dnn/autogen.sh
index 770daecc94c734ca3620fd67adc5486675372cfa..5ee3cc8ab56af1893728ce47ec3fda6228987893 100755
--- a/dnn/autogen.sh
+++ b/dnn/autogen.sh
@@ -6,7 +6,7 @@ srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
 #SHA1 of the first commit compatible with the current model
-commit=2d22197
+commit=b7d25ac
 
 if [ ! -f lpcnet_data-$commit.tar.gz ]; then
 	echo "Downloading latest model"
diff --git a/dnn/dump_data.c b/dnn/dump_data.c
index 95cb567e487cb54967496226886b28a558951e92..0ed47abc7c29bc14e3b408be664f33f8a5e18d91 100644
--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -75,9 +75,9 @@ void compute_noise(int *noise, float noise_std) {
 }
 
 
-void write_audio(LPCNetEncState *st, const short *pcm, const int *noise, FILE *file) {
+void write_audio(LPCNetEncState *st, const short *pcm, const int *noise, FILE *file, int nframes) {
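+  /* Writes 'nframes' frames of training data, 4 bytes per sample. */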
   int i, k;
-  for (k=0;k<4;k++) {
+  for (k=0;k<nframes;k++) {
   unsigned char data[4*FRAME_SIZE];
   for (i=0;i<FRAME_SIZE;i++) {
     float p=0;
@@ -250,7 +250,7 @@ int main(int argc, char **argv) {
       rand_resp(a_sig, b_sig);
       tmp = (float)rand()/RAND_MAX;
       tmp2 = (float)rand()/RAND_MAX;
-      noise_std = -log(tmp)-log(tmp2);
+      noise_std = ABS16(-1.5*log(1e-4+tmp)-.5*log(1e-4+tmp2));
     }
     biquad(x, mem_hp_x, x, b_hp, a_hp, FRAME_SIZE);
     biquad(x, mem_resp_x, x, b_sig, a_sig, FRAME_SIZE);
@@ -270,12 +270,19 @@ int main(int argc, char **argv) {
     if (fpcm) {
         compute_noise(&noisebuf[st->pcount*FRAME_SIZE], noise_std);
     }
+
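+    /* Without quantization, extract features and write audio one frame at a time. */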
+    if (!quantize) {
+      process_single_frame(st, ffeat);
+      if (fpcm) write_audio(st, pcm, &noisebuf[st->pcount*FRAME_SIZE], fpcm, 1);
+    }
     st->pcount++;
     /* Running on groups of 4 frames. */
     if (st->pcount == 4) {
-      unsigned char buf[8];
-      process_superframe(st, buf, ffeat, encode, quantize);
-      if (fpcm) write_audio(st, pcmbuf, noisebuf, fpcm);
+      if (quantize) {
+        unsigned char buf[8];
+        process_superframe(st, buf, ffeat, encode, quantize);
+        if (fpcm) write_audio(st, pcmbuf, noisebuf, fpcm, 4);
+      }
       st->pcount = 0;
     }
     //if (fpcm) fwrite(pcm, sizeof(short), FRAME_SIZE, fpcm);
diff --git a/dnn/lpcnet_demo.c b/dnn/lpcnet_demo.c
index ae016e19e37620bb0a3f8c7f13ab20de8d31bb2d..e05755b0ad470cf6a3e6337f5b8cc94be87aa88b 100644
--- a/dnn/lpcnet_demo.c
+++ b/dnn/lpcnet_demo.c
@@ -99,13 +99,13 @@ int main(int argc, char **argv) {
         LPCNetEncState *net;
         net = lpcnet_encoder_create();
         while (1) {
-            float features[4][NB_TOTAL_FEATURES];
-            short pcm[LPCNET_PACKET_SAMPLES];
+            float features[NB_TOTAL_FEATURES];
+            short pcm[LPCNET_FRAME_SIZE];
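+            /* One feature vector is computed per LPCNET_FRAME_SIZE samples. */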
             size_t ret;
-            ret = fread(pcm, sizeof(pcm[0]), LPCNET_PACKET_SAMPLES, fin);
-            if (feof(fin) || ret != LPCNET_PACKET_SAMPLES) break;
-            lpcnet_compute_features(net, pcm, features);
-            fwrite(features, sizeof(float), 4*NB_TOTAL_FEATURES, fout);
+            ret = fread(pcm, sizeof(pcm[0]), LPCNET_FRAME_SIZE, fin);
+            if (feof(fin) || ret != LPCNET_FRAME_SIZE) break;
+            lpcnet_compute_single_frame_features(net, pcm, features);
+            fwrite(features, sizeof(float), NB_TOTAL_FEATURES, fout);
         }
         lpcnet_encoder_destroy(net);
     } else if (mode == MODE_SYNTHESIS) {
diff --git a/dnn/lpcnet_enc.c b/dnn/lpcnet_enc.c
index 1196a3e14691a7b52ac54580978d0c05e6de35a3..be9aaf23a7a837bd3bba760fa78717f0d0a73fc6 100644
--- a/dnn/lpcnet_enc.c
+++ b/dnn/lpcnet_enc.c
@@ -710,6 +710,133 @@ void process_superframe(LPCNetEncState *st, unsigned char *buf, FILE *ffeat, int
   }
 }
 
+
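+/* Viterbi-style pitch tracking over 8 sub-frames (plus 2 sub-frames of history
+   kept in st->xc): a forward pass accumulates the best correlation path for
+   every candidate lag, a backward pass traces the winning lag sequence, and
+   the result fills the pitch and correlation features of 4 output frames. */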
+void process_multi_frame(LPCNetEncState *st, FILE *ffeat) {
+  int i;
+  int sub;
+  int best_i;
+  int best[10];
+  int pitch_prev[8][PITCH_MAX_PERIOD];
+  float frame_corr;
+  float frame_weight_sum = 1e-15;
+  for(sub=0;sub<8;sub++) frame_weight_sum += st->frame_weight[2+sub];
+  for(sub=0;sub<8;sub++) st->frame_weight[2+sub] *= (8.f/frame_weight_sum);
+  for(sub=0;sub<8;sub++) {
+    float max_path_all = -1e15;
+    best_i = 0;
+    for (i=0;i<PITCH_MAX_PERIOD-2*PITCH_MIN_PERIOD;i++) {
+      float xc_half = MAX16(MAX16(st->xc[2+sub][(PITCH_MAX_PERIOD+i)/2], st->xc[2+sub][(PITCH_MAX_PERIOD+i+2)/2]), st->xc[2+sub][(PITCH_MAX_PERIOD+i-1)/2]);
+      if (st->xc[2+sub][i] < xc_half*1.1) st->xc[2+sub][i] *= .8;
+    }
+    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) {
+      int j;
+      float max_prev;
+      max_prev = st->pitch_max_path_all - 6.f;
+      pitch_prev[sub][i] = st->best_i;
+      for (j=IMIN(0, 4-i);j<=4 && i+j<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;j++) {
+        if (st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j) > max_prev) {
+          max_prev = st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j);
+          pitch_prev[sub][i] = i+j;
+        }
+      }
+      st->pitch_max_path[1][i] = max_prev + st->frame_weight[2+sub]*st->xc[2+sub][i];
+      if (st->pitch_max_path[1][i] > max_path_all) {
+        max_path_all = st->pitch_max_path[1][i];
+        best_i = i;
+      }
+    }
+    /* Renormalize. */
+    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) st->pitch_max_path[1][i] -= max_path_all;
+    //for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) printf("%f ", st->pitch_max_path[1][i]);
+    //printf("\n");
+    RNN_COPY(&st->pitch_max_path[0][0], &st->pitch_max_path[1][0], PITCH_MAX_PERIOD);
+    st->pitch_max_path_all = max_path_all;
+    st->best_i = best_i;
+  }
+  best_i = st->best_i;
+  frame_corr = 0;
+  /* Backward pass. */
+  for (sub=7;sub>=0;sub--) {
+    best[2+sub] = PITCH_MAX_PERIOD-best_i;
+    frame_corr += st->frame_weight[2+sub]*st->xc[2+sub][best_i];
+    best_i = pitch_prev[sub][best_i];
+  }
+  frame_corr /= 8;
+  for (sub=0;sub<4;sub++) {
+    st->features[sub][NB_BANDS] = .01*(IMAX(66, IMIN(510, best[2+2*sub]+best[2+2*sub+1]))-200);
+    st->features[sub][NB_BANDS + 1] = frame_corr-.5;
+    //printf("%f %d %f\n", st->features[sub][NB_BANDS], best[2+2*sub], frame_corr);
+  }
+  //printf("%d %f %f %f\n", best_period, best_a, best_b, best_corr);
+  RNN_COPY(&st->xc[0][0], &st->xc[8][0], PITCH_MAX_PERIOD);
+  RNN_COPY(&st->xc[1][0], &st->xc[9][0], PITCH_MAX_PERIOD);
+  //printf("\n");
+  RNN_COPY(st->vq_mem, &st->features[3][0], NB_BANDS);
+  if (ffeat) {
+    for (i=0;i<4;i++) {
+      fwrite(st->features[i], sizeof(float), NB_TOTAL_FEATURES, ffeat);
+    }
+  }
+}
+
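+/* Single-frame counterpart of the search above: the same forward/backward
+   passes run on the 2 sub-frames of frame st->pcount, producing one feature
+   frame. */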
+void process_single_frame(LPCNetEncState *st, FILE *ffeat) {
+  int i;
+  int sub;
+  int best_i;
+  int best[4];
+  int pitch_prev[2][PITCH_MAX_PERIOD];
+  float frame_corr;
+  float frame_weight_sum = 1e-15;
+  for(sub=0;sub<2;sub++) frame_weight_sum += st->frame_weight[2+2*st->pcount+sub];
+  for(sub=0;sub<2;sub++) st->frame_weight[2+2*st->pcount+sub] *= (2.f/frame_weight_sum);
+  for(sub=0;sub<2;sub++) {
+    float max_path_all = -1e15;
+    best_i = 0;
+    for (i=0;i<PITCH_MAX_PERIOD-2*PITCH_MIN_PERIOD;i++) {
+      float xc_half = MAX16(MAX16(st->xc[2+2*st->pcount+sub][(PITCH_MAX_PERIOD+i)/2], st->xc[2+2*st->pcount+sub][(PITCH_MAX_PERIOD+i+2)/2]), st->xc[2+2*st->pcount+sub][(PITCH_MAX_PERIOD+i-1)/2]);
+      if (st->xc[2+2*st->pcount+sub][i] < xc_half*1.1) st->xc[2+2*st->pcount+sub][i] *= .8;
+    }
+    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) {
+      int j;
+      float max_prev;
+      max_prev = st->pitch_max_path_all - 6.f;
+      pitch_prev[sub][i] = st->best_i;
+      for (j=IMIN(0, 4-i);j<=4 && i+j<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;j++) {
+        if (st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j) > max_prev) {
+          max_prev = st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j);
+          pitch_prev[sub][i] = i+j;
+        }
+      }
+      st->pitch_max_path[1][i] = max_prev + st->frame_weight[2+2*st->pcount+sub]*st->xc[2+2*st->pcount+sub][i];
+      if (st->pitch_max_path[1][i] > max_path_all) {
+        max_path_all = st->pitch_max_path[1][i];
+        best_i = i;
+      }
+    }
+    /* Renormalize. */
+    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) st->pitch_max_path[1][i] -= max_path_all;
+    //for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) printf("%f ", st->pitch_max_path[1][i]);
+    //printf("\n");
+    RNN_COPY(&st->pitch_max_path[0][0], &st->pitch_max_path[1][0], PITCH_MAX_PERIOD);
+    st->pitch_max_path_all = max_path_all;
+    st->best_i = best_i;
+  }
+  best_i = st->best_i;
+  frame_corr = 0;
+  /* Backward pass. */
+  for (sub=1;sub>=0;sub--) {
+    best[2+sub] = PITCH_MAX_PERIOD-best_i;
+    frame_corr += st->frame_weight[2+2*st->pcount+sub]*st->xc[2+2*st->pcount+sub][best_i];
+    best_i = pitch_prev[sub][best_i];
+  }
+  frame_corr /= 2;
+  st->features[st->pcount][NB_BANDS] = .01*(IMAX(66, IMIN(510, best[2]+best[3]))-200);
+  st->features[st->pcount][NB_BANDS + 1] = frame_corr-.5;
+  if (ffeat) {
+    fwrite(st->features[st->pcount], sizeof(float), NB_TOTAL_FEATURES, ffeat);
+  }
+}
+
 void preemphasis(float *y, float *mem, const float *x, float coef, int N) {
   int i;
   for (i=0;i<N;i++) {
@@ -748,3 +875,14 @@ LPCNET_EXPORT int lpcnet_compute_features(LPCNetEncState *st, const short *pcm,
   }
   return 0;
 }
+
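+/* Streaming single-frame API: converts 16-bit PCM to float, applies
+   preemphasis, computes the frame features, and runs the single-frame pitch
+   search, with no quantization or encoding. */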
+LPCNET_EXPORT int lpcnet_compute_single_frame_features(LPCNetEncState *st, const short *pcm, float features[NB_TOTAL_FEATURES]) {
+  int i;
+  float x[FRAME_SIZE];
+  for (i=0;i<FRAME_SIZE;i++) x[i] = pcm[i];
+  preemphasis(x, &st->mem_preemph, x, PREEMPHASIS, FRAME_SIZE);
+  compute_frame_features(st, x);
+  process_single_frame(st, NULL);
+  RNN_COPY(features, &st->features[0][0], NB_TOTAL_FEATURES);
+  return 0;
+}
diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h
index 8027d32a041fc13f20dea5f828686e2abf55b4dd..ee45a29fbe10123292c0b4ac7afa0ee28fb4bfd4 100644
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -1,6 +1,7 @@
 #ifndef LPCNET_PRIVATE_H
 #define LPCNET_PRIVATE_H
 
+#include <stdio.h>
 #include "common.h"
 #include "freq.h"
 #include "lpcnet.h"
@@ -74,5 +75,7 @@ void compute_frame_features(LPCNetEncState *st, const float *in);
 
 void decode_packet(float features[4][NB_TOTAL_FEATURES], float *vq_mem, const unsigned char buf[8]);
 
+void process_single_frame(LPCNetEncState *st, FILE *ffeat);
+
 void run_frame_network(LPCNetState *lpcnet, float *gru_a_condition, float *gru_b_condition, float *lpc, const float *features);
 #endif
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 3513a60627ca5d7120799033eb71f0a7e460336c..7f4914c400daa19c9a2c8ecb819c7e30c09c1ac3 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -283,7 +283,7 @@ void compute_gru2(const GRULayer *gru, float *state, const float *input)
    sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, stride, input);
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
-   sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
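+   /* The recurrent weights are stored quantized, so use the 8x4 dot-product kernel. */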
+   sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
    for (i=0;i<2*N;i++)
       zrh[i] += recur[i];
    compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
@@ -324,9 +324,14 @@ void compute_gruB(const GRULayer *gru, const float* gru_b_condition, float *stat
       zrh[i] = gru->bias[i] + gru_b_condition[i];
 #endif
    sparse_sgemv_accum8x4(zrh, gru->input_weights, 3*N, M, gru->input_weights_idx, input);
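+   /* With USE_SU_BIAS, start from the corrected bias that compensates the quantized recurrent weights. */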
+#ifdef USE_SU_BIAS
+   for (i=0;i<3*N;i++)
+      recur[i] = gru->subias[3*N + i];
+#else
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
-   sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
+#endif
+   sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
    for (i=0;i<2*N;i++)
       zrh[i] += recur[i];
    compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
@@ -361,7 +366,7 @@ void compute_gru3(const GRULayer *gru, float *state, const float *input)
    RNN_COPY(zrh, input, 3*N);
    for (i=0;i<3*N;i++)
       recur[i] = gru->bias[3*N + i];
-   sgemv_accum(recur, gru->recurrent_weights, 3*N, N, stride, state);
+   sgemv_accum8x4(recur, gru->recurrent_weights, 3*N, N, stride, state);
    for (i=0;i<2*N;i++)
       zrh[i] += recur[i];
    compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID);
diff --git a/dnn/nnet.h b/dnn/nnet.h
index 0c06280dd1a383b2dd23849a034fe0a3e65468fc..e0504e530053a9df7804033a9b747f59ab6213e2 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -59,7 +59,7 @@ typedef struct {
   const float *subias;
   const qweight *input_weights;
   const int *input_weights_idx;
-  const float *recurrent_weights;
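+  /* Quantized to 8 bits when DOT_PROD is defined, float otherwise. */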
+  const qweight *recurrent_weights;
   int nb_inputs;
   int nb_neurons;
   int activation;
diff --git a/dnn/training_tf2/dataloader.py b/dnn/training_tf2/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f1f18645a3eaf2d74b4cd80a4220c95925474f
--- /dev/null
+++ b/dnn/training_tf2/dataloader.py
@@ -0,0 +1,26 @@
+import numpy as np
+from tensorflow.keras.utils import Sequence
+
+class LPCNetLoader(Sequence):
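+    """Keras Sequence that yields shuffled LPCNet training batches, reshuffling between epochs."""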
+    def __init__(self, data, features, periods, batch_size):
+        self.batch_size = batch_size
+        self.nb_batches = np.minimum(np.minimum(data.shape[0], features.shape[0]), periods.shape[0])//self.batch_size
+        self.data = data[:self.nb_batches*self.batch_size, :]
+        self.features = features[:self.nb_batches*self.batch_size, :]
+        self.periods = periods[:self.nb_batches*self.batch_size, :]
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        self.indices = np.arange(self.nb_batches*self.batch_size)
+        np.random.shuffle(self.indices)
+
+    def __getitem__(self, index):
+        data = self.data[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
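+        # The first three channels are the network inputs; the fourth channel is the target.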
+        in_data = data[: , :, :3]
+        out_data = data[: , :, 3:4]
+        features = self.features[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        periods = self.periods[self.indices[index*self.batch_size:(index+1)*self.batch_size], :, :]
+        return ([in_data, features, periods], out_data)
+
+    def __len__(self):
+        return self.nb_batches
diff --git a/dnn/training_tf2/dump_lpcnet.py b/dnn/training_tf2/dump_lpcnet.py
index 083dc3ed8e8d9f7709789476f89ab086cedfe1f2..26108dbd9808f24b14df4fb093c7ff483bea429c 100755
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -138,10 +138,18 @@ def dump_grub(self, f, hf, gru_a_size):
     print("printing layer " + name + " of type " + self.__class__.__name__)
     weights = self.get_weights()
     qweight = printSparseVector(f, weights[0][:gru_a_size, :], name + '_weights', have_diag=False)
+
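+    # Emit the recurrent weights 8-bit quantized (scaled by 128) for DOT_PROD builds, float otherwise.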
+    f.write('#ifdef DOT_PROD\n')
+    qweight2 = np.clip(np.round(128.*weights[1]).astype('int'), -128, 127)
+    printVector(f, qweight2, name + '_recurrent_weights', dotp=True, dtype='qweight')
+    f.write('#else /*DOT_PROD*/\n')
     printVector(f, weights[1], name + '_recurrent_weights')
+    f.write('#endif /*DOT_PROD*/\n')
+
     printVector(f, weights[-1], name + '_bias')
     subias = weights[-1].copy()
     subias[0,:] = subias[0,:] - np.sum(qweight*(1./128.),axis=0)
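+    # Same correction for the recurrent bias now that the recurrent weights are quantized too.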
+    subias[1,:] = subias[1,:] - np.sum(qweight2*(1./128.),axis=0)
     printVector(f, subias, name + '_subias')
     if hasattr(self, 'activation'):
         activation = self.activation.__name__.upper()
diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py
index 2f14ecd39c2eb4634e3c8dc592a5bafaa88bbb87..e453a7de3c65bfcb0bb4e9d42a6602d3127e0ba1 100644
--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -28,7 +28,7 @@
 import math
 import tensorflow as tf
 from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation
+from tensorflow.keras.layers import Input, GRU, Dense, Embedding, Reshape, Concatenate, Lambda, Conv1D, Multiply, Add, Bidirectional, MaxPooling1D, Activation, GaussianNoise
 from tensorflow.compat.v1.keras.layers import CuDNNGRU
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
@@ -70,21 +70,19 @@ def quant_regularizer(x):
     return .01 * tf.reduce_mean(K.sqrt(K.sqrt(1.0001 - tf.math.cos(2*3.1415926535897931*(Q*x-tf.round(Q*x))))))
 
 class Sparsify(Callback):
-    def __init__(self, t_start, t_end, interval, density):
+    def __init__(self, t_start, t_end, interval, density, quantize=False):
         super(Sparsify, self).__init__()
         self.batch = 0
         self.t_start = t_start
         self.t_end = t_end
         self.interval = interval
         self.final_density = density
+        self.quantize = quantize
 
     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
             #print("constrain");
             layer = self.model.get_layer('gru_a')
             w = layer.get_weights()
@@ -96,7 +94,7 @@ class Sparsify(Callback):
             #print ("density = ", density)
             for k in range(nb):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*N:(k+1)*N]
@@ -108,7 +106,7 @@ class Sparsify(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(N*N//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 mask = np.minimum(1, mask + np.diag(np.ones((N,))))
@@ -116,11 +114,21 @@ class Sparsify(Callback):
                 mask = np.transpose(mask, (1, 0))
                 p[:, k*N:(k+1)*N] = p[:, k*N:(k+1)*N]*mask
                 #print(thresh, np.mean(mask))
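+            # Soft quantization: snap weights whose distance to the nearest 1/128
+            # step is within a threshold that ramps from 0 to 0.5 over training.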
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[1] = p
             layer.set_weights(w)
 
 class SparsifyGRUB(Callback):
-    def __init__(self, t_start, t_end, interval, grua_units, density):
+    def __init__(self, t_start, t_end, interval, grua_units, density, quantize=False):
         super(SparsifyGRUB, self).__init__()
         self.batch = 0
         self.t_start = t_start
@@ -128,14 +136,12 @@ class SparsifyGRUB(Callback):
         self.interval = interval
         self.final_density = density
         self.grua_units = grua_units
+        self.quantize = quantize
 
     def on_batch_end(self, batch, logs=None):
         #print("batch number", self.batch)
         self.batch += 1
-        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
-            #print("don't constrain");
-            pass
-        else:
+        if self.quantize or (self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end:
             #print("constrain");
             layer = self.model.get_layer('gru_b')
             w = layer.get_weights()
@@ -144,7 +150,7 @@ class SparsifyGRUB(Callback):
             M = p.shape[1]//3
             for k in range(3):
                 density = self.final_density[k]
-                if self.batch < self.t_end:
+                if self.batch < self.t_end and not self.quantize:
                     r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
                     density = 1 - (1-self.final_density[k])*(1 - r*r*r)
                 A = p[:, k*M:(k+1)*M]
@@ -158,7 +164,7 @@ class SparsifyGRUB(Callback):
                 S=np.sum(S, axis=1)
                 SS=np.sort(np.reshape(S, (-1,)))
                 thresh = SS[round(M*N2//32*(1-density))]
-                mask = (S>=thresh).astype('float32');
+                mask = (S>=thresh).astype('float32')
                 mask = np.repeat(mask, 4, axis=0)
                 mask = np.repeat(mask, 8, axis=1)
                 A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
@@ -167,6 +173,16 @@ class SparsifyGRUB(Callback):
                 A = np.reshape(A, (N, M))
                 p[:, k*M:(k+1)*M] = A
                 #print(thresh, np.mean(mask))
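+            # Same progressive weight snapping as in Sparsify above.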
+            if self.quantize and ((self.batch > self.t_start and (self.batch-self.t_start) % self.interval == 0) or self.batch >= self.t_end):
+                if self.batch < self.t_end:
+                    threshold = .5*(self.batch - self.t_start)/(self.t_end - self.t_start)
+                else:
+                    threshold = .5
+                quant = np.round(p*128.)
+                res = p*128.-quant
+                mask = (np.abs(res) <= threshold).astype('float32')
+                p = mask/128.*quant + (1-mask)*p
+
             w[0] = p
             layer.set_weights(w)
             
@@ -215,9 +231,9 @@ class WeightClip(Constraint):
 constraint = WeightClip(0.992)
 
 def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, training=False, adaptation=False, quantize=False, flag_e2e = False):
-    pcm = Input(shape=(None, 3))
-    feat = Input(shape=(None, nb_used_features))
-    pitch = Input(shape=(None, 1))
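+    # The stateful GRUs below require a fixed batch size, which must match the training batch size.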
+    pcm = Input(shape=(None, 3), batch_size=128)
+    feat = Input(shape=(None, nb_used_features), batch_size=128)
+    pitch = Input(shape=(None, 1), batch_size=128)
     dec_feat = Input(shape=(None, 128))
     dec_state1 = Input(shape=(rnn_units1,))
     dec_state2 = Input(shape=(rnn_units2,))
@@ -256,19 +272,20 @@ def new_lpcnet_model(rnn_units1=384, rnn_units2=16, nb_used_features = 20, train
     quant = quant_regularizer if quantize else None
 
     if training:
-        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a',
+        rnn = CuDNNGRU(rnn_units1, return_sequences=True, return_state=True, name='gru_a', stateful=True,
               recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b',
-               kernel_constraint=constraint, kernel_regularizer=quant)
+        rnn2 = CuDNNGRU(rnn_units2, return_sequences=True, return_state=True, name='gru_b', stateful=True,
+               kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
     else:
-        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a',
+        rnn = GRU(rnn_units1, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_a', stateful=True,
               recurrent_constraint = constraint, recurrent_regularizer=quant)
-        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b',
-               kernel_constraint=constraint, kernel_regularizer=quant)
+        rnn2 = GRU(rnn_units2, return_sequences=True, return_state=True, recurrent_activation="sigmoid", reset_after='true', name='gru_b', stateful=True,
+               kernel_constraint=constraint, recurrent_constraint = constraint, kernel_regularizer=quant, recurrent_regularizer=quant)
 
     rnn_in = Concatenate()([cpcm, rep(cfeat)])
     md = MDense(pcm_levels, activation='sigmoid', name='dual_fc')
     gru_out1, _ = rnn(rnn_in)
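+    # GaussianNoise is only active during training; it regularizes gru_a's output.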
+    gru_out1 = GaussianNoise(.005)(gru_out1)
     gru_out2, _ = rnn2(Concatenate()([gru_out1, rep(cfeat)]))
     ulaw_prob = Lambda(tree_to_pdf_train)(md(gru_out2))
 
diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py
index bbf01bfcacc9f841adc72c2b394b7e72bea14d04..880545f2448e00a59c88c7fc62bae2bb45aeb32c 100755
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -28,6 +28,7 @@
 # Train an LPCNet model
 
 import argparse
+from dataloader import LPCNetLoader
 
 parser = argparse.ArgumentParser(description='Train an LPCNet model')
 
@@ -148,10 +149,10 @@ data = data[:nb_frames*4*pcm_chunk_size]
 
 
 data = np.reshape(data, (nb_frames, pcm_chunk_size, 4))
-in_data = data[:,:,:3]
-out_exc = data[:,:,3:4]
+#in_data = data[:,:,:3]
+#out_exc = data[:,:,3:4]
 
-print("ulaw std = ", np.std(out_exc))
+#print("ulaw std = ", np.std(out_exc))
 
 sizeof = features.strides[-1]
 features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
@@ -171,8 +172,12 @@ if args.retrain is not None:
 if quantize or retrain:
     #Adapting from an existing model
     model.load_weights(input_model)
-    sparsify = lpcnet.Sparsify(0, 0, 1, density)
-    grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
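+    # When quantizing, sparsity stays at the final density while weight snapping ramps between batches 10000 and 30000.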
+    if quantize:
+        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
+        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
+    else:
+        sparsify = lpcnet.Sparsify(0, 0, 1, density)
+        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
 else:
     #Training from scratch
     sparsify = lpcnet.Sparsify(2000, 40000, 400, density)
@@ -180,4 +185,5 @@ else:
 
 model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))
 csv_logger = CSVLogger('training_vals.log')
-model.fit([in_data, features, periods], out_exc, batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])
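+# Batches now come from the shuffling Sequence loader instead of full in-memory arrays passed to fit().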
+loader = LPCNetLoader(data, features, periods, batch_size)
+model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=[checkpoint, sparsify, grub_sparsify, csv_logger])