diff --git a/utils/synthesis.py b/utils/synthesis.py
index 9158ef02..0c68dbf2 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -1,3 +1,7 @@
+import pkg_resources
+installed = {pkg.key for pkg in pkg_resources.working_set}
+if 'tensorflow' in installed:
+    import tensorflow as tf
 import torch
 import numpy as np
 from .text import text_to_sequence, phoneme_to_sequence
@@ -14,23 +18,32 @@ def text_to_seqvec(text, CONFIG, use_cuda):
             dtype=np.int32)
     else:
         seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
-    # torch tensor
-    chars_var = torch.from_numpy(seq).unsqueeze(0)
-    if use_cuda:
-        chars_var = chars_var.cuda()
-    return chars_var.long()
+    return seq
+
+
+def numpy_to_torch(np_array, dtype, cuda=False):
+    if np_array is None:
+        return None
+    tensor = torch.as_tensor(np_array, dtype=dtype)
+    if cuda:
+        return tensor.cuda()
+    return tensor
+
+
+def numpy_to_tf(np_array, dtype):
+    if np_array is None:
+        return None
+    tensor = tf.convert_to_tensor(np_array, dtype=dtype)
+    return tensor
 
 
 def compute_style_mel(style_wav, ap, use_cuda):
-    print(style_wav)
-    style_mel = torch.FloatTensor(ap.melspectrogram(
-        ap.load_wav(style_wav))).unsqueeze(0)
-    if use_cuda:
-        return style_mel.cuda()
+    style_mel = np.expand_dims(ap.melspectrogram(
+        ap.load_wav(style_wav)), 0)
     return style_mel
 
 
-def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
     if CONFIG.use_gst:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
             inputs, style_mel=style_mel, speaker_ids=speaker_id)
@@ -44,11 +57,31 @@ def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None)
     return decoder_output, postnet_output, alignments, stop_tokens
 
 
-def parse_outputs(postnet_output, decoder_output, alignments):
+def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+    if CONFIG.use_gst:
+        raise NotImplementedError(' [!] GST inference not implemented for TF')
+    if truncated:
+        raise NotImplementedError(' [!] Truncated inference not implemented for TF')
+    # TODO: handle multispeaker case
+    decoder_output, postnet_output, alignments, stop_tokens = model(
+        inputs, training=False)
+    return decoder_output, postnet_output, alignments, stop_tokens
+
+
+def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
-    return postnet_output, decoder_output, alignment
+    stop_tokens = stop_tokens[0].data.cpu().numpy()
+    return postnet_output, decoder_output, alignment, stop_tokens
+
+
+def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
+    postnet_output = postnet_output[0].numpy()
+    decoder_output = decoder_output[0].numpy()
+    alignment = alignments[0].numpy()
+    stop_tokens = stop_tokens[0].numpy()
+    return postnet_output, decoder_output, alignment, stop_tokens
 
 
 def trim_silence(wav, ap):
@@ -98,7 +131,8 @@ def synthesis(model,
               truncated=False,
               enable_eos_bos_chars=False,  #pylint: disable=unused-argument
               use_griffin_lim=False,
-              do_trim_silence=False):
+              do_trim_silence=False,
+              backend='torch'):
     """Synthesize voice for the given text.
 
     Args:
@@ -114,6 +148,7 @@ def synthesis(model,
             for continuous inference at long texts.
         enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
         do_trim_silence (bool): trim silence after synthesis.
+        backend (str): inference backend, 'torch' (default) or 'tf'.
     """
     # GST processing
     style_mel = None
@@ -121,15 +156,28 @@ def synthesis(model,
         style_mel = compute_style_mel(style_wav, ap, use_cuda)
     # preprocess the given text
    inputs = text_to_seqvec(text, CONFIG, use_cuda)
-    speaker_id = id_to_torch(speaker_id)
-    if speaker_id is not None and use_cuda:
-        speaker_id = speaker_id.cuda()
+    # pass tensors to backend
+    if backend == 'torch':
+        speaker_id = id_to_torch(speaker_id)
+        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
+        inputs = inputs.unsqueeze(0)
+    else:
+        # TODO: handle speaker id for tf model
+        style_mel = numpy_to_tf(style_mel, tf.float32)
+        inputs = numpy_to_tf(inputs, tf.int32)
+        inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
-    decoder_output, postnet_output, alignments, stop_tokens = run_model(
-        model, inputs, CONFIG, truncated, speaker_id, style_mel)
+    if backend == 'torch':
+        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
+            model, inputs, CONFIG, truncated, speaker_id, style_mel)
+        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
+            postnet_output, decoder_output, alignments, stop_tokens)
+    else:
+        decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
+            model, inputs, CONFIG, truncated, speaker_id, style_mel)
+        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+            postnet_output, decoder_output, alignments, stop_tokens)
-    # convert outputs to numpy
-    postnet_output, decoder_output, alignment = parse_outputs(
-        postnet_output, decoder_output, alignments)
     # plot results
     wav = None
     if use_griffin_lim:
diff --git a/utils/visual.py b/utils/visual.py
index 8789cf8f..87fbc8e4 100644
--- a/utils/visual.py
+++ b/utils/visual.py
@@ -61,7 +61,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()
     # plot stopnet predictions
-    stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
     plt.subplot(num_plot, 1, 2)
     plt.plot(range(len(stop_tokens)), list(stop_tokens))
     # plot postnet spectrogram
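
A minimal sketch of how the new conversion helpers behave, assuming both backends are installed; the array values and the round-trip checks are illustrative only, not part of the patch:

    seq = np.asarray([1, 2, 3], dtype=np.int32)
    pt_tensor = numpy_to_torch(seq, torch.long)            # CPU torch tensor, shape (3,)
    pt_cuda = numpy_to_torch(seq, torch.long, cuda=True)   # same tensor, moved to GPU
    tf_tensor = numpy_to_tf(seq, tf.int32)                 # tf.Tensor, shape (3,)
    assert numpy_to_torch(None, torch.long) is None        # None passes through untouched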
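And a hedged end-to-end sketch of the new `backend` switch in `synthesis`; here `model`, `tf_model`, `CONFIG`, and `ap` are assumed to be loaded elsewhere, and the return value of `synthesis` is unchanged by this patch:

    # PyTorch backend (default); CUDA placement is handled by numpy_to_torch.
    outputs = synthesis(model, text="Hello world.", CONFIG=CONFIG,
                        use_cuda=True, ap=ap, use_griffin_lim=True,
                        backend='torch')

    # TF backend; requires tensorflow (see the guarded import at the top of
    # utils/synthesis.py). GST and truncated decoding raise
    # NotImplementedError on this path, and speaker ids are still a TODO.
    outputs = synthesis(tf_model, text="Hello world.", CONFIG=CONFIG,
                        use_cuda=False, ap=ap, use_griffin_lim=True,
                        backend='tf')

The `utils/visual.py` change follows from the new parsing functions: `parse_outputs_torch`/`parse_outputs_tf` already return `stop_tokens` as a numpy array, so `visualize` no longer needs to detach and convert it.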