Skip to content
Snippets Groups Projects
Commit 61459c24 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Change decoder architecture to be like the encoder

parent 79d1a916
No related branches found
No related tags found
No related merge requests found
......@@ -82,8 +82,8 @@ bits = bits[:nb_sequences*sequence_size*40]
bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
print(bits.shape)
lambda_val = 0.0007 * np.ones((nb_sequences, sequence_size//2, 1))
quant_id = np.round(10*np.log(lambda_val/.0002)).astype('int16')
lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
quant_id = np.round(3.8*np.log(lambda_val/.0002)).astype('int16')
quant_id = quant_id[:,:,0]
quant_embed = qembedding(quant_id)
quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
......@@ -98,7 +98,7 @@ state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
state = state[:,-1,:]
state = pvq_quantize(state, 30)
state = pvq_quantize(state, 82)
#state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
print("shapes are:")
......
......@@ -105,7 +105,7 @@ nbits=80
bits.astype('float32').tofile(args.output + "-syms.f32")
lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
quant_id = np.round(10*np.log(lambda_val/.0002)).astype('int16')
quant_id = np.round(3.8*np.log(lambda_val/.0002)).astype('int16')
quant_id = quant_id[:,:,0]
quant_embed = qembedding(quant_id)
quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
......@@ -115,7 +115,7 @@ bits = bits*quant_scale
bits = np.round(apply_dead_zone([bits, dead_zone]).numpy())
bits = bits/quant_scale
gru_state_dec = pvq_quantize(gru_state_dec, 30)
gru_state_dec = pvq_quantize(gru_state_dec, 82)
#gru_state_dec = gru_state_dec/(1e-15+tf.norm(gru_state_dec, axis=-1,keepdims=True))
gru_state_dec = gru_state_dec[:,-1,:]
dec_out = decoder([bits[:,1::2,:], gru_state_dec])
......
......@@ -238,15 +238,15 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
gru = CuDNNGRU if training else GRU
dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1')
dec_dense2 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense2')
dec_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense2')
dec_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense3')
gru = CuDNNGRU if training else GRU
dec_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense4')
dec_dense5 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense5')
dec_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense5')
dec_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='dec_dense6')
dec_dense7 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
dec_dense8 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
dec_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense7')
dec_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='dec_dense8')
dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')
......@@ -260,10 +260,10 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba
gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
dec1 = dec_dense1(time_reverse(bits_input))
dec2 = dec_dense2(dec1)
dec2 = dec_dense2(dec1, initial_state=gru_state1)
dec3 = dec_dense3(dec2)
dec4 = dec_dense4(dec3, initial_state=gru_state1)
dec5 = dec_dense5(dec4, initial_state=gru_state2)
dec4 = dec_dense4(dec3, initial_state=gru_state2)
dec5 = dec_dense5(dec4)
dec6 = dec_dense6(dec5, initial_state=gru_state3)
dec7 = dec_dense7(dec6)
dec8 = dec_dense8(dec7)
......@@ -340,7 +340,7 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
ndze_unquant = div([ndze,quant_scale])
mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 82))(gru_state_dec)
combined_output = []
unquantized_output = []
cat = Concatenate(name="out_cat")
......
......@@ -124,8 +124,8 @@ features = features[:, :, :nb_used_features]
#lambda_val = np.repeat(np.random.uniform(.0007, .002, (features.shape[0], 1, 1)), features.shape[1]//2, axis=1)
#quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
#quant_id = quant_id[:,:,0]
quant_id = np.repeat(np.random.randint(39, size=(features.shape[0], 1, 1), dtype='int16'), features.shape[1]//2, axis=1)
lambda_val = .0002*np.exp(quant_id/10.)
quant_id = np.repeat(np.random.randint(16, size=(features.shape[0], 1, 1), dtype='int16'), features.shape[1]//2, axis=1)
lambda_val = .0002*np.exp(quant_id/3.8)
quant_id = quant_id[:,:,0]
# dump models to disk as we go
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment