diff --git a/dnn/training_tf2/train_lpcnet.py b/dnn/training_tf2/train_lpcnet.py
index aeaf98d9a87d693498cc61bf8427f8e6f37265a0..bbf01bfcacc9f841adc72c2b394b7e72bea14d04 100755
--- a/dnn/training_tf2/train_lpcnet.py
+++ b/dnn/training_tf2/train_lpcnet.py
@@ -35,7 +35,9 @@ parser.add_argument('features', metavar='<features file>', help='binary features
 parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
 parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
 parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
-parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1 = parser.add_mutually_exclusive_group()
+group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
+group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
 parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
 parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
 parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the recurrent weights (default 1.0)')
@@ -45,6 +47,10 @@ parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help
 parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
 parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
 parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation')
+parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
+parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
+parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
+
 
 args = parser.parse_args()
 
@@ -60,6 +66,8 @@ if args.grub_density_split is not None:
 elif args.grub_density is not None:
     grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density];
 
+gamma = 2.0 if args.gamma is None else args.gamma
+
 import importlib
 lpcnet = importlib.import_module(args.model)
 
@@ -87,14 +95,25 @@ nb_epochs = args.epochs
 batch_size = args.batch_size
 
 quantize = args.quantize is not None
+retrain = args.retrain is not None
 
 if quantize:
     lr = 0.00003
     decay = 0
+    input_model = args.quantize
 else:
     lr = 0.001
     decay = 2.5e-5
 
+if args.lr is not None:
+    lr = args.lr
+
+if args.decay is not None:
+    decay = args.decay
+
+if retrain:
+    input_model = args.retrain
+
 flag_e2e = args.flag_e2e
 
 opt = Adam(lr, decay=decay, beta_2=0.99)
@@ -105,7 +124,7 @@ with strategy.scope():
     if not flag_e2e:
         model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
     else:
-        model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
+        model.compile(optimizer=opt, loss = interp_mulaw(gamma=gamma),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
     model.summary()
 
 feature_file = args.features
@@ -146,9 +165,12 @@ periods = (.1 + 50*features[:,:,18:19]+100).astype('int16')
 # dump models to disk as we go
 checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
 
-if quantize:
+if args.retrain is not None:
+    model.load_weights(args.retrain)
+
+if quantize or retrain:
     #Adapting from an existing model
-    model.load_weights(args.quantize)
+    model.load_weights(input_model)
     sparsify = lpcnet.Sparsify(0, 0, 1, density)
     grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
 else:
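
With this patch, --quantize and --retrain become mutually exclusive ways of starting from
existing weights (both set input_model), and --lr/--decay override the defaults picked by
the quantize branch. A rough usage sketch follows; the file names are hypothetical, and the
"_384_" infix assumes the checkpoint naming above with the default GRU A size:

    # continue training from an earlier checkpoint with an explicit schedule
    ./train_lpcnet.py features.f32 data.u8 mymodel --retrain mymodel_384_10.h5 --lr 0.0001 --decay 1e-5

    # sparsify/quantize an existing model instead (cannot be combined with --retrain)
    ./train_lpcnet.py features.f32 data.u8 mymodel_q --quantize mymodel_384_10.h5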