mirror of https://github.com/coqui-ai/TTS.git
tf conversion fixes
parent dc166b42e3
commit 1835628335
@@ -13,9 +13,9 @@ from fuzzywuzzy import fuzz
 from TTS.utils.text.symbols import phonemes, symbols
 from TTS.utils.generic_utils import setup_model
 from TTS.utils.io import load_config
-from TTS_tf.models.tacotron2 import Tacotron2
-from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
-from TTS_tf.utils.generic_utils import save_checkpoint
+from TTS.tf.models.tacotron2 import Tacotron2
+from TTS.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
+from TTS.tf.utils.generic_utils import save_checkpoint

 parser = argparse.ArgumentParser()
 parser.add_argument('--torch_model_path',
@@ -147,21 +147,22 @@ output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
                                                          training=False)
 assert compare_torch_tf(output, output_tf).mean() < 1e-5

-# compare decoder.attention
 query = output
 inputs = torch.rand([1, 128, 512])
 query_tf = query.detach().numpy()
 inputs_tf = inputs.numpy()

+# compare decoder.attention
 model.decoder.attention.init_states(inputs)
 processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
 loc_attn, proc_query = model.decoder.attention.get_location_attention(
     query, processes_inputs)
 context = model.decoder.attention(query, inputs, processes_inputs, None)

+attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
 model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
-loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf)
-context_tf = model_tf.decoder.attention(query_tf, training=False)
+loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
+context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)

 assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
 assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
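Note on the helper these asserts rely on: compare_torch_tf (imported from TTS.tf.utils.convert_torch_to_tf_utils in the hunk above) only has to reduce the torch/TF discrepancy to a single scalar that can be checked against the 1e-5 tolerance. A minimal sketch of such a helper, assuming both tensors have the same shape and the TF side is an eager tensor; the repo's actual implementation may differ in detail:

import numpy as np

def compare_torch_tf(torch_tensor, tf_tensor):
    # Mean absolute element-wise difference between a torch tensor and a
    # TF eager tensor; the asserts above require this to stay below 1e-5.
    return np.abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()

The attention_states value threaded through get_loc_attn and the attention call suggests the TF attention layer no longer keeps state between calls but receives and returns it explicitly, which is why build_decoder_initial_states is queried before the comparison.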
@@ -72,6 +72,7 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
         numpy_weight = torch_weight.detach().cpu().numpy()
+        assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
         tf.keras.backend.set_value(tf_var, numpy_weight)
     return tf_vars


 def load_tf_vars(model_tf, tf_vars):
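The assert added in this hunk sits inside the per-variable copy loop of transfer_weights_torch_to_tf, so a shape mismatch now fails immediately with both variable names and both shapes instead of surfacing later from tf.keras.backend.set_value. A rough sketch of the surrounding loop, with the tf-to-torch name mapping simplified to a plain dict lookup and any kernel transposes omitted (the real helper is more involved):

import numpy as np
import tensorflow as tf

def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
    # Copy every torch weight into the TF variable it is mapped to.
    # Assumption for this sketch: var_map_dict maps TF variable names
    # to torch state_dict keys.
    for tf_var in tf_vars:
        torch_var_name = var_map_dict[tf_var.name]
        torch_weight = state_dict[torch_var_name]
        numpy_weight = torch_weight.detach().cpu().numpy()
        # Shape guard added by this commit: abort the conversion with a
        # descriptive message rather than writing a mismatched weight.
        assert np.all(tf_var.shape == numpy_weight.shape), (
            f" [!] weight shapes does not match: {tf_var.name} vs "
            f"{torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}")
        tf.keras.backend.set_value(tf_var, numpy_weight)
    return tf_vars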