mirror of https://github.com/coqui-ai/TTS.git
GST inference
This commit is contained in:
parent
dfa974c9d1
commit
0f8936d744
|
@ -55,10 +55,14 @@ class TacotronGST(nn.Module):
|
||||||
linear_outputs = self.last_linear(linear_outputs)
|
linear_outputs = self.last_linear(linear_outputs)
|
||||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||||
|
|
||||||
def inference(self, characters):
|
def inference(self, characters, style_mel=None):
|
||||||
B = characters.size(0)
|
B = characters.size(0)
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(characters)
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
|
if style_mel is not None:
|
||||||
|
gst_outputs = self.gst(style_mel)
|
||||||
|
gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
|
||||||
|
encoder_outputs = encoder_outputs + gst_outputs
|
||||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||||
encoder_outputs)
|
encoder_outputs)
|
||||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||||
|
|
|
@ -18,11 +18,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
|
||||||
use_cuda (bool): enable cuda.
|
use_cuda (bool): enable cuda.
|
||||||
ap (TTS.utils.audio.AudioProcessor): audio processor to process
|
ap (TTS.utils.audio.AudioProcessor): audio processor to process
|
||||||
model outputs.
|
model outputs.
|
||||||
|
style_wav (str): Uses for style embedding of GST.
|
||||||
truncated (bool): keep model states after inference. It can be used
|
truncated (bool): keep model states after inference. It can be used
|
||||||
for continuous inference at long texts.
|
for continuous inference at long texts.
|
||||||
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||||
trim_silence (bool): trim silence after synthesis.
|
trim_silence (bool): trim silence after synthesis.
|
||||||
"""
|
"""
|
||||||
|
# GST processing
|
||||||
|
if CONFIG.model == "TacotronGST" and style_wav is not None:
|
||||||
|
style_mel = compute_style_mel(style_wav, ap)
|
||||||
|
|
||||||
# preprocess the given text
|
# preprocess the given text
|
||||||
text_cleaner = [CONFIG.text_cleaner]
|
text_cleaner = [CONFIG.text_cleaner]
|
||||||
if CONFIG.use_phonemes:
|
if CONFIG.use_phonemes:
|
||||||
|
@ -33,14 +38,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos
|
||||||
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
||||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||||
# synthesize voice
|
# synthesize voice
|
||||||
if use_cuda:
|
if CONFIG.model == "TacotronGST" and style_wav is not None:
|
||||||
chars_var = chars_var.cuda()
|
|
||||||
if truncated:
|
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
|
||||||
chars_var.long())
|
|
||||||
else:
|
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
chars_var.long())
|
chars_var.long(), style_mel)
|
||||||
|
else:
|
||||||
|
if truncated:
|
||||||
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||||
|
chars_var.long())
|
||||||
|
else:
|
||||||
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
|
chars_var.long())
|
||||||
# convert outputs to numpy
|
# convert outputs to numpy
|
||||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||||
|
|
Loading…
Reference in New Issue