Skip to content
Snippets Groups Projects
Commit c5364153 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Add more training options

parent ab9a0926
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,9 @@ parser.add_argument('features', metavar='<features file>', help='binary features
parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
parser.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the recurrent weights (default 1.0)')
......@@ -45,6 +47,10 @@ parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
args = parser.parse_args()
......@@ -60,6 +66,8 @@ if args.grub_density_split is not None:
elif args.grub_density is not None:
grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density];
gamma = 2.0 if args.gamma is None else args.gamma
import importlib
lpcnet = importlib.import_module(args.model)
......@@ -87,14 +95,25 @@ nb_epochs = args.epochs
batch_size = args.batch_size
quantize = args.quantize is not None
retrain = args.retrain is not None
if quantize:
lr = 0.00003
decay = 0
input_model = args.quantize
else:
lr = 0.001
decay = 2.5e-5
if args.lr is not None:
lr = args.lr
if args.decay is not None:
decay = args.decay
if retrain:
input_model = args.retrain
flag_e2e = args.flag_e2e
opt = Adam(lr, decay=decay, beta_2=0.99)
......@@ -105,7 +124,7 @@ with strategy.scope():
if not flag_e2e:
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics='sparse_categorical_crossentropy')
else:
model.compile(optimizer=opt, loss = interp_mulaw(gamma = 2),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
model.compile(optimizer=opt, loss = interp_mulaw(gamma=gamma),metrics=[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss])
model.summary()
feature_file = args.features
......@@ -146,9 +165,12 @@ periods = (.1 + 50*features[:,:,18:19]+100).astype('int16')
# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))
if quantize:
if args.retrain is not None:
model.load_weights(args.retrain)
if quantize or retrain:
#Adapting from an existing model
model.load_weights(args.quantize)
model.load_weights(input_model)
sparsify = lpcnet.Sparsify(0, 0, 1, density)
grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment