Skip to content
Snippets Groups Projects
Commit 981d06ee authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Refactoring towards multiple offset decoding

parent a4f7c157
No related branches found
No related tags found
No related merge requests found
......@@ -282,27 +282,22 @@ def new_split_decoder(decoder):
nb_bits = decoder.nb_bits
bunch = decoder.bunch
bits_input = Input(shape=(None, nb_bits), name="split_bits")
uqbits_input = Input(shape=(None, nb_bits), name="split_uqbits")
quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed")
gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")
range_select = Lambda(lambda x: x[0][:,x[1]+bunch//2-1:x[2]:bunch//2,:])
range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
elem_select = Lambda(lambda x: x[0][:,x[1],:])
points = [0, 64, 128, 192, 256]
outputs = []
uqbits = []
for i in range(len(points)-1):
begin = points[i]//2
end = points[i+1]//2
begin = points[i]//bunch
end = points[i+1]//bunch
state = elem_select([gru_state_input, end-1])
bits = range_select([bits_input, begin, end])
uq = range_select([uqbits_input, begin, end])
uqbits.append(uq)
embed = range_select([quant_embed_input, begin, end])
outputs.append(decoder([bits, embed, state]))
output = Concatenate(axis=1)(outputs)
uqbits = Concatenate(axis=1)(uqbits)
split = Model([bits_input, uqbits_input, quant_embed_input, gru_state_input], [output, uqbits], name="split")
split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
return split
......@@ -327,14 +322,20 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
hardquant = Lambda(hard_quantize)
dzone = Lambda(apply_dead_zone)
dze = dzone([ze,dead_zone])
mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
combined_output, uqbits = split_decoder([hardquant(dze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
ndze = noisequant(dze)
unquantized_output, uqbits = split_decoder([ndze, dze, quant_embed_dec, gru_state_dec])
unquantized_output_dec, uqbits = split_decoder([tf.stop_gradient(ndze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
e2 = Concatenate(name="hard_bits")([uqbits, hard_distr_embed, lambda_bunched])
e = Concatenate(name="soft_bits")([uqbits, soft_distr_embed, lambda_bunched])
for i in [1]:
dze_select = mod_select([dze, i])
ndze_select = mod_select([ndze, i])
state_select = mod_select([gru_state_dec, i])
combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment