diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py index 619f37e228249d5b6db70b1bfe3b79005455a6c8..c49b6250cb97cba95710cc48b5f39d118e296f9e 100644 --- a/dnn/training_tf2/rdovae.py +++ b/dnn/training_tf2/rdovae.py @@ -202,9 +202,8 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba lambda_val = Input(shape=(None, 1), batch_size=batch_size) qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros') quant_embed = qembedding(quant_id) - quant_embed_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(quant_embed) - quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_bunched)) + quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed)) enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1') enc_dense2 = CuDNNGRU(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2') @@ -230,19 +229,19 @@ def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba d7 = enc_dense7(d6) d8 = enc_dense8(d7) enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])) - enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out) + #enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out) bits = Multiply()([enc_out, quant_scale]) global_dense1 = Dense(128, activation='tanh', name='gdense1') global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2') global_bits = global_dense2(global_dense1(d6)) - encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed_bunched, global_bits], name='encoder') + encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder') return encoder def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256): - bits_input = Input(shape=(None, nb_bits), batch_size=batch_size) - quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size) - gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size) + bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits") + quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed") + gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state") dec_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='dec_dense1') @@ -282,23 +281,28 @@ def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, ba def new_split_decoder(decoder): nb_bits = decoder.nb_bits bunch = decoder.bunch - bits_input = Input(shape=(None, nb_bits)) - quant_embed_input = Input(shape=(None, 6*nb_bits)) - gru_state_input = Input(shape=(None,nb_state_dim)) + bits_input = Input(shape=(None, nb_bits), name="split_bits") + uqbits_input = Input(shape=(None, nb_bits), name="split_uqbits") + quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed") + gru_state_input = Input(shape=(None,nb_state_dim), name="split_state") - range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:]) + range_select = Lambda(lambda x: x[0][:,x[1]+bunch//2-1:x[2]:bunch//2,:]) elem_select = Lambda(lambda x: x[0][:,x[1],:]) points = [0, 64, 128, 192, 256] outputs = [] + uqbits = [] for i in range(len(points)-1): - begin = points[i]//bunch - end = points[i+1]//bunch - state = elem_select([gru_state_input, 2*end-1]) + begin = points[i]//2 + end = points[i+1]//2 + state = elem_select([gru_state_input, end-1]) bits = range_select([bits_input, begin, end]) + uq = range_select([uqbits_input, begin, end]) + uqbits.append(uq) embed = range_select([quant_embed_input, begin, end]) outputs.append(decoder([bits, embed, state])) output = Concatenate(axis=1)(outputs) - split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split") + uqbits = Concatenate(axis=1)(uqbits) + split = Model([bits_input, uqbits_input, quant_embed_input, gru_state_input], [output, uqbits], name="split") return split @@ -316,21 +320,21 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc split_decoder = new_split_decoder(decoder) dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec)) - soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec)) - hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec)) + soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec)) + hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec)) noisequant = UniformNoise() hardquant = Lambda(hard_quantize) dzone = Lambda(apply_dead_zone) dze = dzone([ze,dead_zone]) gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec) - combined_output = split_decoder([hardquant(dze), tf.stop_gradient(quant_embed_dec), gru_state_dec]) + combined_output, uqbits = split_decoder([hardquant(dze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec]) ndze = noisequant(dze) - unquantized_output = split_decoder([ndze, quant_embed_dec, gru_state_dec]) - unquantized_output_dec = split_decoder([tf.stop_gradient(ndze), tf.stop_gradient(quant_embed_dec), gru_state_dec]) + unquantized_output, uqbits = split_decoder([ndze, dze, quant_embed_dec, gru_state_dec]) + unquantized_output_dec, uqbits = split_decoder([tf.stop_gradient(ndze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec]) - e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_bunched]) - e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_bunched]) + e2 = Concatenate(name="hard_bits")([uqbits, hard_distr_embed, lambda_bunched]) + e = Concatenate(name="soft_bits")([uqbits, soft_distr_embed, lambda_bunched]) model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")