diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py index 1f2616358198d14d22adafad942f825d3c5ae98d..45b3efb01b4ac5bf63e27ee4c7c81c48b76bf21e 100644 --- a/dnn/training_tf2/rdovae.py +++ b/dnn/training_tf2/rdovae.py @@ -334,9 +334,10 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc dzone = Lambda(apply_dead_zone) dze = dzone([ze,dead_zone]) ndze = noisequant(dze) + dze_quant = hardquant(dze) div = Lambda(lambda x: x[0]/x[1]) - dze_unquant = div([dze,quant_scale]) + dze_quant = div([dze_quant,quant_scale]) ndze_unquant = div([ndze,quant_scale]) mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:]) @@ -345,11 +346,11 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc unquantized_output = [] cat = Concatenate(name="out_cat") for i in range(bunch//2): - dze_select = mod_select([dze_unquant, i]) + dze_select = mod_select([dze_quant, i]) ndze_select = mod_select([ndze_unquant, i]) state_select = mod_select([gru_state_dec, i]) - tmp = split_decoder([hardquant(dze_select), state_select]) + tmp = split_decoder([dze_select, state_select]) tmp = cat([tmp, lambda_up]) combined_output.append(tmp)