mirror of https://github.com/coqui-ai/TTS.git
tf bacend for synthesis
This commit is contained in:
parent
d99fda8e42
commit
b3ec50b5c4
|
@ -1,3 +1,7 @@
|
||||||
|
import pkg_resources
|
||||||
|
installed = {pkg.key for pkg in pkg_resources.working_set}
|
||||||
|
if 'tensorflow' in installed:
|
||||||
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .text import text_to_sequence, phoneme_to_sequence
|
from .text import text_to_sequence, phoneme_to_sequence
|
||||||
|
@ -14,23 +18,32 @@ def text_to_seqvec(text, CONFIG, use_cuda):
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
|
seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
|
||||||
# torch tensor
|
return seq
|
||||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
|
||||||
if use_cuda:
|
|
||||||
chars_var = chars_var.cuda()
|
def numpy_to_torch(np_array, dtype, cuda=False):
|
||||||
return chars_var.long()
|
if np_array is None:
|
||||||
|
return None
|
||||||
|
tensor = torch.Tensor(np_array, dtype=dtype)
|
||||||
|
if cuda:
|
||||||
|
return tensor.cuda()
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def numpy_to_tf(np_array, dtype):
|
||||||
|
if np_array is None:
|
||||||
|
return None
|
||||||
|
tensor = tf.convert_to_tensor(np_array, dtype=dtype)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def compute_style_mel(style_wav, ap, use_cuda):
|
def compute_style_mel(style_wav, ap, use_cuda):
|
||||||
print(style_wav)
|
style_mel = ap.melspectrogram(
|
||||||
style_mel = torch.FloatTensor(ap.melspectrogram(
|
ap.load_wav(style_wav)).expand_dims(0)
|
||||||
ap.load_wav(style_wav))).unsqueeze(0)
|
|
||||||
if use_cuda:
|
|
||||||
return style_mel.cuda()
|
|
||||||
return style_mel
|
return style_mel
|
||||||
|
|
||||||
|
|
||||||
def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||||
if CONFIG.use_gst:
|
if CONFIG.use_gst:
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
inputs, style_mel=style_mel, speaker_ids=speaker_id)
|
inputs, style_mel=style_mel, speaker_ids=speaker_id)
|
||||||
|
@ -44,11 +57,31 @@ def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None)
|
||||||
return decoder_output, postnet_output, alignments, stop_tokens
|
return decoder_output, postnet_output, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def parse_outputs(postnet_output, decoder_output, alignments):
|
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
|
||||||
|
if CONFIG.use_gst:
|
||||||
|
raise NotImplemented(' [!] GST inference not implemented for TF')
|
||||||
|
if truncated:
|
||||||
|
raise NotImplemented(' [!] Truncated inference not implemented for TF')
|
||||||
|
# TODO: handle multispeaker case
|
||||||
|
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||||
|
inputs, training=False)
|
||||||
|
return decoder_output, postnet_output, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
|
||||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||||
alignment = alignments[0].cpu().data.numpy()
|
alignment = alignments[0].cpu().data.numpy()
|
||||||
return postnet_output, decoder_output, alignment
|
stop_tokens = stop_tokens[0].cpu().numpy()
|
||||||
|
return postnet_output, decoder_output, alignment, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
|
||||||
|
postnet_output = postnet_output[0].numpy()
|
||||||
|
decoder_output = decoder_output[0].numpy()
|
||||||
|
alignment = alignments[0].numpy()
|
||||||
|
stop_tokens = stop_tokens[0].numpy()
|
||||||
|
return postnet_output, decoder_output, alignment, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def trim_silence(wav, ap):
|
def trim_silence(wav, ap):
|
||||||
|
@ -98,7 +131,8 @@ def synthesis(model,
|
||||||
truncated=False,
|
truncated=False,
|
||||||
enable_eos_bos_chars=False, #pylint: disable=unused-argument
|
enable_eos_bos_chars=False, #pylint: disable=unused-argument
|
||||||
use_griffin_lim=False,
|
use_griffin_lim=False,
|
||||||
do_trim_silence=False):
|
do_trim_silence=False,
|
||||||
|
backend='torch'):
|
||||||
"""Synthesize voice for the given text.
|
"""Synthesize voice for the given text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -114,6 +148,7 @@ def synthesis(model,
|
||||||
for continuous inference at long texts.
|
for continuous inference at long texts.
|
||||||
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||||
do_trim_silence (bool): trim silence after synthesis.
|
do_trim_silence (bool): trim silence after synthesis.
|
||||||
|
backend (str): tf or torch
|
||||||
"""
|
"""
|
||||||
# GST processing
|
# GST processing
|
||||||
style_mel = None
|
style_mel = None
|
||||||
|
@ -121,15 +156,29 @@ def synthesis(model,
|
||||||
style_mel = compute_style_mel(style_wav, ap, use_cuda)
|
style_mel = compute_style_mel(style_wav, ap, use_cuda)
|
||||||
# preprocess the given text
|
# preprocess the given text
|
||||||
inputs = text_to_seqvec(text, CONFIG, use_cuda)
|
inputs = text_to_seqvec(text, CONFIG, use_cuda)
|
||||||
speaker_id = id_to_torch(speaker_id)
|
# pass tensors to backend
|
||||||
if speaker_id is not None and use_cuda:
|
if backend == 'torch':
|
||||||
speaker_id = speaker_id.cuda()
|
speaker_id = id_to_torch(speaker_id)
|
||||||
|
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||||
|
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
|
||||||
|
inputs = inputs.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# TODO: handle speaker id for tf model
|
||||||
|
style_mel = numpy_to_tf(style_mel, tf.float32)
|
||||||
|
inputs = numpy_to_tf(inputs, tf.int32)
|
||||||
|
inputs = tf.expand_dims(inputs, 0)
|
||||||
# synthesize voice
|
# synthesize voice
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = run_model(
|
if backend == 'torch':
|
||||||
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
|
||||||
|
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
||||||
|
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
|
||||||
|
postnet_output, decoder_output, alignments, stop_tokens)
|
||||||
|
else:
|
||||||
|
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
|
||||||
|
model, inputs, CONFIG, truncated, speaker_id, style_mel)
|
||||||
|
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
|
||||||
|
postnet_output, decoder_output, alignments, stop_tokens)
|
||||||
# convert outputs to numpy
|
# convert outputs to numpy
|
||||||
postnet_output, decoder_output, alignment = parse_outputs(
|
|
||||||
postnet_output, decoder_output, alignments)
|
|
||||||
# plot results
|
# plot results
|
||||||
wav = None
|
wav = None
|
||||||
if use_griffin_lim:
|
if use_griffin_lim:
|
||||||
|
|
|
@ -61,7 +61,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
|
||||||
plt.yticks(range(len(text)), list(text))
|
plt.yticks(range(len(text)), list(text))
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
# plot stopnet predictions
|
# plot stopnet predictions
|
||||||
stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy()
|
|
||||||
plt.subplot(num_plot, 1, 2)
|
plt.subplot(num_plot, 1, 2)
|
||||||
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
||||||
# plot postnet spectrogram
|
# plot postnet spectrogram
|
||||||
|
|
Loading…
Reference in New Issue