diff --git a/dnn/training_tf2/rdovae.py b/dnn/training_tf2/rdovae.py
index 1f2616358198d14d22adafad942f825d3c5ae98d..45b3efb01b4ac5bf63e27ee4c7c81c48b76bf21e 100644
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -334,9 +334,10 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
     dzone = Lambda(apply_dead_zone)
     dze = dzone([ze,dead_zone])
     ndze = noisequant(dze)
+    dze_quant = hardquant(dze)
     
     div = Lambda(lambda x: x[0]/x[1])
-    dze_unquant = div([dze,quant_scale])
+    dze_quant = div([dze_quant,quant_scale])
     ndze_unquant = div([ndze,quant_scale])
 
     mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
@@ -345,11 +346,11 @@ def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batc
     unquantized_output = []
     cat = Concatenate(name="out_cat")
     for i in range(bunch//2):
-        dze_select = mod_select([dze_unquant, i])
+        dze_select = mod_select([dze_quant, i])
         ndze_select = mod_select([ndze_unquant, i])
         state_select = mod_select([gru_state_dec, i])
 
-        tmp = split_decoder([hardquant(dze_select), state_select])
+        tmp = split_decoder([dze_select, state_select])
         tmp = cat([tmp, lambda_up])
         combined_output.append(tmp)