diff --git a/dnn/training_tf2/rdovae_exchange.py b/dnn/training_tf2/rdovae_exchange.py
index 7dcc579a5ab996362780de11e4435295c4dc4f9c..f1c888cb0a70810b2a523373ba084ca669fe7023 100644
--- a/dnn/training_tf2/rdovae_exchange.py
+++ b/dnn/training_tf2/rdovae_exchange.py
@@ -29,10 +29,8 @@
 
 
 import argparse
-from ftplib import parse150
 import os
 import sys
-sys.path.append('/Users/jbuethe/Projects/DRED')
 
 os.environ['CUDA_VISIBLE_DEVICES'] = ""
 
@@ -46,11 +44,8 @@ parser.add_argument('--latent-dim', type=int, help="dimension of latent space (d
 args = parser.parse_args()
 
 # now import the heavy stuff
-import tensorflow as tf
-import numpy as np
 from rdovae import new_rdovae_model
-from exchange.tf import dump_tf_gru_weights, dump_tf_conv1d_weights, dump_tf_dense_weights, dump_tf_embedding_weights
-
+from wexchange.tf import dump_tf_weights, load_tf_weights
 
 
 exchange_name = {
@@ -109,21 +104,14 @@ if __name__ == "__main__":
         'bits_dense'
     ]
 
-    for name in encoder_dense_names:
-        print(f"writing layer {exchange_name[name]}...")
-        dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
-
-    for name in encoder_gru_names:
-        print(f"writing layer {exchange_name[name]}...")
-        dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
-    for name in encoder_conv1d_names:
+    for name in encoder_dense_names + encoder_gru_names + encoder_conv1d_names:
         print(f"writing layer {exchange_name[name]}...")
-        dump_tf_conv1d_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
+        dump_tf_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
 
 
     # qembedding
     print(f"writing layer {exchange_name['qembedding']}...")
-    dump_tf_embedding_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)
+    dump_tf_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)
 
     # decoder
     decoder_dense_names = [
@@ -144,10 +132,6 @@ if __name__ == "__main__":
         'dec_dense6'
     ]
 
-    for name in decoder_dense_names:
-        print(f"writing layer {exchange_name[name]}...")
-        dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
-
-    for name in decoder_gru_names:
+    for name in decoder_dense_names + decoder_gru_names:
         print(f"writing layer {exchange_name[name]}...")
-        dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
\ No newline at end of file
+        dump_tf_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
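
Note for reviewers: the patch collapses the four type-specific dumpers (dump_tf_dense_weights, dump_tf_gru_weights, dump_tf_conv1d_weights, dump_tf_embedding_weights) into a single dump_tf_weights() imported from wexchange.tf, which is what allows the encoder and decoder layer-name lists to be concatenated into one loop each. The sketch below is not the wexchange.tf implementation; it only illustrates, under assumed Keras layer classes and an assumed one-.npy-file-per-weight-array output format, the kind of type dispatch such a generic dumper can perform.

# Sketch only (assumed layout, not the actual wexchange.tf code): dump a Keras
# layer's weights by dispatching on the layer's type instead of requiring the
# caller to pick a per-type dump function.
import os

import numpy as np
import tensorflow as tf


def dump_weights_generic(prefix, layer):
    # Every supported class is checked explicitly; anything else is rejected
    # rather than silently skipped.
    if isinstance(layer, tf.keras.layers.GRU):
        kind = 'gru'
    elif isinstance(layer, tf.keras.layers.Conv1D):
        kind = 'conv1d'
    elif isinstance(layer, tf.keras.layers.Embedding):
        kind = 'embedding'
    elif isinstance(layer, tf.keras.layers.Dense):
        kind = 'dense'
    else:
        raise ValueError(f"unsupported layer type: {type(layer).__name__}")

    # get_weights() returns the layer's variables as numpy arrays, e.g.
    # [kernel, recurrent_kernel, bias] for a GRU or [kernel, bias] for Dense.
    for i, w in enumerate(layer.get_weights()):
        np.save(f"{prefix}_{kind}_{i}.npy", w)


if __name__ == "__main__":
    # Minimal usage example with a freshly built Dense layer.
    dense = tf.keras.layers.Dense(8, name='demo_dense')
    dense.build((None, 4))
    dump_weights_generic(os.path.join('.', 'demo_dense'), dense)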