From 1835628335aaf5636f96798e2e3dd2b7c7347129 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 20 May 2020 12:25:24 +0200 Subject: [PATCH] tf conversion fixes --- tf/convert_tacotron2_torch_to_tf.py | 13 +++++++------ tf/utils/convert_torch_to_tf_utils.py | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tf/convert_tacotron2_torch_to_tf.py b/tf/convert_tacotron2_torch_to_tf.py index 3b57782e..b1878343 100644 --- a/tf/convert_tacotron2_torch_to_tf.py +++ b/tf/convert_tacotron2_torch_to_tf.py @@ -13,9 +13,9 @@ from fuzzywuzzy import fuzz from TTS.utils.text.symbols import phonemes, symbols from TTS.utils.generic_utils import setup_model from TTS.utils.io import load_config -from TTS_tf.models.tacotron2 import Tacotron2 -from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name -from TTS_tf.utils.generic_utils import save_checkpoint +from TTS.tf.models.tacotron2 import Tacotron2 +from TTS.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name +from TTS.tf.utils.generic_utils import save_checkpoint parser = argparse.ArgumentParser() parser.add_argument('--torch_model_path', @@ -147,21 +147,22 @@ output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, training=False) assert compare_torch_tf(output, output_tf).mean() < 1e-5 -# compare decoder.attention query = output inputs = torch.rand([1, 128, 512]) query_tf = query.detach().numpy() inputs_tf = inputs.numpy() +# compare decoder.attention model.decoder.attention.init_states(inputs) processes_inputs = model.decoder.attention.preprocess_inputs(inputs) loc_attn, proc_query = model.decoder.attention.get_location_attention( query, processes_inputs) context = model.decoder.attention(query, inputs, processes_inputs, None) +attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1] model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf)) -loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf) -context_tf = model_tf.decoder.attention(query_tf, training=False) +loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states) +context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False) assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5 assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5 diff --git a/tf/utils/convert_torch_to_tf_utils.py b/tf/utils/convert_torch_to_tf_utils.py index ba7e629b..e9e1e8a3 100644 --- a/tf/utils/convert_torch_to_tf_utils.py +++ b/tf/utils/convert_torch_to_tf_utils.py @@ -72,6 +72,7 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): numpy_weight = torch_weight.detach().cpu().numpy() assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" tf.keras.backend.set_value(tf_var, numpy_weight) + return tf_vars def load_tf_vars(model_tf, tf_vars):