mirror of https://github.com/coqui-ai/TTS.git
GST inference
This commit is contained in:
parent
dfa974c9d1
commit
0f8936d744
|
@ -55,10 +55,14 @@ class TacotronGST(nn.Module):
|
|||
linear_outputs = self.last_linear(linear_outputs)
|
||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||
|
||||
def inference(self, characters):
|
||||
def inference(self, characters, style_mel=None):
|
||||
B = characters.size(0)
|
||||
inputs = self.embedding(characters)
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
if style_mel is not None:
|
||||
gst_outputs = self.gst(style_mel)
|
||||
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
|
||||
encoder_outputs = encoder_outputs + gst_outputs
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||
|
|
|
@ -18,11 +18,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
|
|||
use_cuda (bool): enable cuda.
|
||||
ap (TTS.utils.audio.AudioProcessor): audio processor to process
|
||||
model outputs.
|
||||
style_wav (str): Uses for style embedding of GST.
|
||||
truncated (bool): keep model states after inference. It can be used
|
||||
for continuous inference at long texts.
|
||||
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||
trim_silence (bool): trim silence after synthesis.
|
||||
"""
|
||||
# GST processing
|
||||
if CONFIG.model == "TacotronGST" and style_wav is not None:
|
||||
style_mel = compute_style_mel(style_wav, ap)
|
||||
|
||||
# preprocess the given text
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
if CONFIG.use_phonemes:
|
||||
|
@ -33,14 +38,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
|
|||
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||
# synthesize voice
|
||||
if use_cuda:
|
||||
chars_var = chars_var.cuda()
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
chars_var.long())
|
||||
else:
|
||||
if CONFIG.model == "TacotronGST" and style_wav is not None:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
chars_var.long())
|
||||
chars_var.long(), style_mel)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
chars_var.long())
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
chars_var.long())
|
||||
# convert outputs to numpy
|
||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||
|
|
Loading…
Reference in New Issue