GST inference

Eren Golge 2019-06-12 12:12:01 +02:00
parent dfa974c9d1
commit 0f8936d744
2 changed files with 19 additions and 8 deletions

@@ -55,10 +55,14 @@ class TacotronGST(nn.Module):
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens

-    def inference(self, characters):
+    def inference(self, characters, style_mel=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
+        if style_mel is not None:
+            gst_outputs = self.gst(style_mel)
+            gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
+            encoder_outputs = encoder_outputs + gst_outputs
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
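The added branch conditions inference on a reference style: `self.gst(style_mel)` produces one style embedding per utterance, which is broadcast across all encoder timesteps and summed into the encoder outputs. A minimal shape sketch of that broadcast-add, with assumed dimensions (batch 1, 50 input characters, 256-dim embeddings; the `(B, 1, D)` GST output shape is inferred from the `expand` call):

import torch

# Assumed shapes for illustration only.
encoder_outputs = torch.randn(1, 50, 256)  # (B, T_in, encoder_dim)
gst_outputs = torch.randn(1, 1, 256)       # (B, 1, gst_dim): one style vector per utterance

# expand repeats the size-1 time axis across all T_in encoder steps
# (a view, no copy), so every input position receives the same style vector.
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
encoder_outputs = encoder_outputs + gst_outputs
print(encoder_outputs.shape)  # torch.Size([1, 50, 256])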

@@ -18,11 +18,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
         use_cuda (bool): enable cuda.
         ap (TTS.utils.audio.AudioProcessor): audio processor to process
             model outputs.
+        style_wav (str): wav file path used to compute the GST style embedding.
         truncated (bool): keep model states after inference. It can be used
             for continuous inference at long texts.
         enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
         trim_silence (bool): trim silence after synthesis.
     """
+    # GST processing
+    if CONFIG.model == "TacotronGST" and style_wav is not None:
+        style_mel = compute_style_mel(style_wav, ap)
     # preprocess the given text
     text_cleaner = [CONFIG.text_cleaner]
     if CONFIG.use_phonemes:
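`compute_style_mel` itself is not part of this diff. A hedged sketch of what the call implies, assuming the `AudioProcessor` methods `load_wav` and `melspectrogram`, and assuming the GST reference encoder expects a `(1, T, n_mels)` float tensor; the real helper elsewhere in the repo may differ:

import torch

def compute_style_mel(style_wav, ap):
    # Hypothetical sketch: turn a reference wav into a batched mel tensor.
    # Uses the same AudioProcessor as training so feature settings match.
    wav = ap.load_wav(style_wav)   # 1-D float waveform
    mel = ap.melspectrogram(wav)   # (n_mels, T_mel) per AudioProcessor convention
    # Transpose to time-major and add a batch axis -> (1, T_mel, n_mels),
    # the layout the reference encoder is assumed to take.
    return torch.FloatTensor(mel.T).unsqueeze(0)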
@@ -33,14 +38,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
     seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     # synthesize voice
     if use_cuda:
         chars_var = chars_var.cuda()
-    if truncated:
-        decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-            chars_var.long())
-    else:
+    if CONFIG.model == "TacotronGST" and style_wav is not None:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-            chars_var.long())
+            chars_var.long(), style_mel)
+    else:
+        if truncated:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                chars_var.long())
+        else:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                chars_var.long())
     # convert outputs to numpy
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
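End to end, the new path can be exercised as in the sketch below. `model`, `CONFIG`, and `ap` are assumed to be a loaded TacotronGST checkpoint, its config, and the matching `AudioProcessor`; `reference.wav` is a hypothetical style clip. `style_wav` is passed by keyword since its position in the updated signature is not shown in this hunk, and the return value is left generic because it is not part of the diff:

# Hypothetical driver; construction of model, CONFIG, and ap is omitted.
outputs = synthesis(model,
                    "Hello there, this is a styled sentence.",
                    CONFIG,
                    use_cuda=False,
                    ap=ap,
                    style_wav="reference.wav",  # hypothetical path; enables the GST branch
                    trim_silence=True)

With `style_wav` left as `None`, the code falls through to the unchanged `truncated` / plain `inference` branches, so non-GST models behave exactly as before this commit.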