From a8170986ecc6259b54d1ae30404b833a48acbc26 Mon Sep 17 00:00:00 2001 From: Jan Buethe <jbuethe@amazon.de> Date: Mon, 31 Oct 2022 15:21:12 +0100 Subject: [PATCH] updated rdovae_exchange --- dnn/training_tf2/rdovae_exchange.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/dnn/training_tf2/rdovae_exchange.py b/dnn/training_tf2/rdovae_exchange.py index 7dcc579a5..f1c888cb0 100644 --- a/dnn/training_tf2/rdovae_exchange.py +++ b/dnn/training_tf2/rdovae_exchange.py @@ -29,10 +29,8 @@ import argparse -from ftplib import parse150 import os import sys -sys.path.append('/Users/jbuethe/Projects/DRED') os.environ['CUDA_VISIBLE_DEVICES'] = "" @@ -46,11 +44,8 @@ parser.add_argument('--latent-dim', type=int, help="dimension of latent space (d args = parser.parse_args() # now import the heavy stuff -import tensorflow as tf -import numpy as np from rdovae import new_rdovae_model -from exchange.tf import dump_tf_gru_weights, dump_tf_conv1d_weights, dump_tf_dense_weights, dump_tf_embedding_weights - +from wexchange.tf import dump_tf_weights, load_tf_weights exchange_name = { @@ -109,21 +104,14 @@ if __name__ == "__main__": 'bits_dense' ] - for name in encoder_dense_names: - print(f"writing layer {exchange_name[name]}...") - dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name)) - - for name in encoder_gru_names: - print(f"writing layer {exchange_name[name]}...") - dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name)) - for name in encoder_conv1d_names: + for name in encoder_dense_names + encoder_gru_names + encoder_conv1d_names: print(f"writing layer {exchange_name[name]}...") - dump_tf_conv1d_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name)) + dump_tf_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name)) # qembedding print(f"writing layer {exchange_name['qembedding']}...") - dump_tf_embedding_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding) + dump_tf_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding) # decoder decoder_dense_names = [ @@ -144,10 +132,6 @@ if __name__ == "__main__": 'dec_dense6' ] - for name in decoder_dense_names: - print(f"writing layer {exchange_name[name]}...") - dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name)) - - for name in decoder_gru_names: + for name in decoder_dense_names + decoder_gru_names: print(f"writing layer {exchange_name[name]}...") - dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name)) \ No newline at end of file + dump_tf_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name)) -- GitLab