mirror of https://github.com/coqui-ai/TTS.git
update `synthesis.py` to be more generic
parent f121b0ff5d
commit fb9289d365
@@ -152,16 +152,6 @@ def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, None, None
 
 
-def parse_outputs_torch(postnet_output, decoder_output, alignments,
-                        stop_tokens):
-    postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = None if decoder_output is None else decoder_output[
-        0].data.cpu().numpy()
-    alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
-    return postnet_output, decoder_output, alignment, stop_tokens
-
-
 def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].numpy()
     decoder_output = decoder_output[0].numpy()
@@ -200,8 +190,8 @@ def speaker_id_to_torch(speaker_id, cuda=False):
 def embedding_to_torch(x_vector, cuda=False):
     if x_vector is not None:
         x_vector = np.asarray(x_vector)
-        x_vector = torch.from_numpy(x_vector).unsqueeze(
-            0).type(torch.FloatTensor)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
+            torch.FloatTensor)
         if cuda:
             return x_vector.cuda()
     return x_vector
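The rewrapped `embedding_to_torch` is behavior-preserving: it still turns an x-vector (speaker embedding) into a batched float32 tensor. A quick usage sketch; the 256-dim embedding is illustrative, not from the diff:

```python
import numpy as np
import torch

def embedding_to_torch(x_vector, cuda=False):
    if x_vector is not None:
        x_vector = np.asarray(x_vector)
        x_vector = torch.from_numpy(x_vector).unsqueeze(0).type(
            torch.FloatTensor)
        if cuda:
            return x_vector.cuda()
    return x_vector

x_vector = [0.1] * 256           # dummy speaker embedding
t = embedding_to_torch(x_vector)
print(t.shape, t.dtype)          # torch.Size([1, 256]) torch.float32
```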
@@ -263,57 +253,59 @@ def synthesis(
     else:
         style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
-    inputs = text_to_seq(text, CONFIG)
+    text_inputs = text_to_seq(text, CONFIG)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
             speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
 
         if x_vector is not None:
-            x_vector = embedding_to_torch(x_vector,
-                                          cuda=use_cuda)
+            x_vector = embedding_to_torch(x_vector, cuda=use_cuda)
 
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
-        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
-        inputs = inputs.unsqueeze(0)
+        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+        text_inputs = text_inputs.unsqueeze(0)
     elif backend == "tf":
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     elif backend == "tflite":
         style_mel = numpy_to_tf(style_mel, tf.float32)
-        inputs = numpy_to_tf(inputs, tf.int32)
-        inputs = tf.expand_dims(inputs, 0)
+        text_inputs = numpy_to_tf(text_inputs, tf.int32)
+        text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
         outputs = run_model_torch(model,
-                                  inputs,
+                                  text_inputs,
                                   speaker_id,
                                   style_mel,
                                   x_vector=x_vector)
-        postnet_output, decoder_output, alignments, stop_tokens = \
-            outputs['postnet_outputs'], outputs['decoder_outputs'],\
-            outputs['alignments'], outputs['stop_tokens']
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
-            postnet_output, decoder_output, alignments, stop_tokens)
+        model_outputs = outputs['model_outputs']
+        model_outputs = model_outputs[0].data.cpu().numpy()
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, inputs, CONFIG, speaker_id, style_mel)
-        postnet_output, decoder_output = parse_outputs_tflite(
+            model, text_inputs, CONFIG, speaker_id, style_mel)
+        model_outputs, decoder_output = parse_outputs_tflite(
             postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
     if use_griffin_lim:
-        wav = inv_spectrogram(postnet_output, ap, CONFIG)
+        wav = inv_spectrogram(model_outputs, ap, CONFIG)
         # trim silence
         if do_trim_silence:
             wav = trim_silence(wav, ap)
-    return wav, alignment, decoder_output, postnet_output, stop_tokens, inputs
+    return_dict = {
+        'wav': wav,
+        'alignments': outputs['alignments'],
+        'model_outputs': model_outputs,
+        'text_inputs': text_inputs
+    }
+    return return_dict
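Callers of `synthesis` now unpack a dict instead of the old six-element tuple. A sketch of the migration; the positional arguments shown are assumptions based on this diff, not the full signature:

```python
# Before this commit:
# wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = \
#     synthesis(model, text, CONFIG, use_cuda, ap)

# After this commit (argument list is illustrative):
outputs = synthesis(model, text, CONFIG, use_cuda, ap)
wav = outputs['wav']                        # waveform (None unless Griffin-Lim is used)
alignments = outputs['alignments']          # attention alignments
model_outputs = outputs['model_outputs']    # spectrogram as a numpy array
text_inputs = outputs['text_inputs']        # batched input character/phoneme ids
```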