Skip to content
Snippets Groups Projects
Commit 405aa7cf authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

WIP: training with different alignment

parent 981d06ee
No related branches found
No related tags found
No related merge requests found
......@@ -94,9 +94,9 @@ def safelog2(x):
return log2_e*tf.math.log(eps+x)
def feat_dist_loss(y_true,y_pred):
ceps = y_pred[:,:,:18] - y_true[:,:,:18]
pitch = 2*(y_pred[:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
corr = y_pred[:,:,19:] - y_true[:,:,19:]
ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
......@@ -300,6 +300,18 @@ def new_split_decoder(decoder):
split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
return split
def tensor_concat(x):
#n = x[1]//2
#x = x[0]
n=2
y = []
for i in range(n-1):
offset = n-1-i
tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
y.append(tf.expand_dims(tmp, axis=0))
y.append(tf.expand_dims(x[-1], axis=0))
return Concatenate(axis=0)(y)
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
......@@ -315,8 +327,8 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
split_decoder = new_split_decoder(decoder)
dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))
soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))
hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))
noisequant = UniformNoise()
hardquant = Lambda(hard_quantize)
......@@ -326,19 +338,24 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
ndze = noisequant(dze)
for i in [1]:
combined_output = []
unquantized_output = []
for i in range(bunch//2):
dze_select = mod_select([dze, i])
ndze_select = mod_select([ndze, i])
state_select = mod_select([gru_state_dec, i])
combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
combined_output.append(split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select]))
unquantized_output.append(split_decoder([ndze_select, quant_embed_dec, state_select]))
e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
concat = Lambda(tensor_concat, name="output")
combined_output = concat(combined_output)
unquantized_output = concat(unquantized_output)
e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])
model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
model.nb_used_features = nb_used_features
return model, encoder, decoder
......
......@@ -100,7 +100,7 @@ opt = Adam(lr, decay=decay, beta_2=0.99)
with strategy.scope():
model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 0., 1., .1], metrics={'split':'mse', 'hard_bits':rdovae.sq_rate_metric})
model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
model.summary()
lpc_order = 16
......@@ -147,4 +147,4 @@ if args.logdir is not None:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
callbacks.append(tensorboard_callback)
model.fit([features, quant_id, lambda_val], [features, features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
model.fit([features, quant_id, lambda_val], [features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment