diff --git a/models/tacotrongst.py b/models/tacotrongst.py index b0d0fe78..1a77cd53 100644 --- a/models/tacotrongst.py +++ b/models/tacotrongst.py @@ -55,10 +55,14 @@ class TacotronGST(nn.Module): linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens - def inference(self, characters): + def inference(self, characters, style_mel=None): B = characters.size(0) inputs = self.embedding(characters) encoder_outputs = self.encoder(inputs) + if style_mel is not None: + gst_outputs = self.gst(style_mel) + gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) + encoder_outputs = encoder_outputs + gst_outputs mel_outputs, alignments, stop_tokens = self.decoder.inference( encoder_outputs) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) diff --git a/utils/synthesis.py b/utils/synthesis.py index 913c84ad..7931f7ab 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -18,11 +18,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos use_cuda (bool): enable cuda. ap (TTS.utils.audio.AudioProcessor): audio processor to process model outputs. + style_wav (str): Used for style embedding of GST. truncated (bool): keep model states after inference. It can be used for continuous inference at long texts. enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. trim_silence (bool): trim silence after synthesis. 
""" + # GST processing + if CONFIG.model == "TacotronGST" and style_wav is not None: + style_mel = compute_style_mel(style_wav, ap) + # preprocess the given text text_cleaner = [CONFIG.text_cleaner] if CONFIG.use_phonemes: @@ -33,14 +38,16 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) chars_var = torch.from_numpy(seq).unsqueeze(0) # synthesize voice - if use_cuda: - chars_var = chars_var.cuda() - if truncated: - decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - chars_var.long()) - else: + if CONFIG.model == "TacotronGST" and style_wav is not None: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - chars_var.long()) + chars_var.long(), style_mel) + else: + if truncated: + decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( + chars_var.long()) + else: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + chars_var.long()) # convert outputs to numpy postnet_output = postnet_output[0].data.cpu().numpy() decoder_output = decoder_output[0].data.cpu().numpy()