diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py
index aeaf98d9a87d693498cc61bf8427f8e6f37265a0..bbf01bfcacc9f841adc72c2b394b7e72bea14d04 100755
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -35,7 +35,9 @@ parser.add_argument('features', metavar='<features file>', help='binary features
 parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
 parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
 parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
-parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
 parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
 parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
 parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the recurrent weights (default 1.0)')
@@ -45,6 +47,9 @@ parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help
 parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
 parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
 parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
+parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate (default 0.001, or 0.00003 when quantizing)')
+parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay (default 2.5e-5, or 0 when quantizing)')
+parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
 
 args = parser.parse_args()
 
@@ -60,6 +65,9 @@ if args.grub_density_split is not None:
 elif args.grub_density is not None:
     grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density];
 
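+# --gamma defaults to 2.0, the value previously hard-coded in the e2e loss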
+gamma = 2.0 if args.gamma is None else args.gamma
+
 import importlib
 lpcnet = importlib.import_module(args.model)
 
@@ -87,14 +95,27 @@ nb_epochs = args.epochs
 batch_size = args.batch_size
 
 quantize = args.quantize is not None
+retrain = args.retrain is not None
 
 if quantize:
     lr = 0.00003
     decay = 0
+    input_model = args.quantize
 else:
     lr = 0.001
     decay = 2.5e-5
 
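+# --lr and --decay override the mode-dependent defaults chosen above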
+if args.lr is not None:
+    lr = args.lr
+
+if args.decay is not None:
+    decay = args.decay
+
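+# --quantize and --retrain are mutually exclusive, so input_model is set at most once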
+if retrain:
+    input_model = args.retrain
+
 flag_e2e = args.flag_e2e
 
 opt = Adam(lr, decay=decay, beta_2=0.99)
@@ -105,7 +126,7 @@ with strategy.scope():
     if not flag_e2e:
         model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
     else:
-        model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
+        model.compile(optimizer=opt, loss=interp_mulaw(gamma=gamma), metrics=[metric_cel, metric_icel, metric_exc_sd, metric_oginterploss])
     model.summary()
 
 feature_file = args.features
@@ -146,9 +167,9 @@ periods = (.1 + 50*features[:,:,18:19]+100).astype('int16')
 # dump models to disk as we go
 checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
 
-if quantize:
+if quantize or retrain:
     #Adapting from an existing model
-    model.load_weights(args.quantize)
+    model.load_weights(input_model)
     sparsify = lpcnet.Sparsify(0, 0, 1, density)
     grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
 else: