tf conversion fixes

This commit is contained in:
erogol 2020-05-20 12:25:24 +02:00
parent dc166b42e3
commit 1835628335
2 changed files with 8 additions and 6 deletions

View File

@ -13,9 +13,9 @@ from fuzzywuzzy import fuzz
from TTS.utils.text.symbols import phonemes, symbols
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS_tf.models.tacotron2 import Tacotron2
from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
from TTS_tf.utils.generic_utils import save_checkpoint
from TTS.tf.models.tacotron2 import Tacotron2
from TTS.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
from TTS.tf.utils.generic_utils import save_checkpoint
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
@ -147,21 +147,22 @@ output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5
# compare decoder.attention
query = output
inputs = torch.rand([1, 128, 512])
query_tf = query.detach().numpy()
inputs_tf = inputs.numpy()
# compare decoder.attention
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(
query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf)
context_tf = model_tf.decoder.attention(query_tf, training=False)
loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5

View File

@ -72,6 +72,7 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
numpy_weight = torch_weight.detach().cpu().numpy()
assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
tf.keras.backend.set_value(tf_var, numpy_weight)
return tf_vars
def load_tf_vars(model_tf, tf_vars):