From 25238e0658697bbcb355bcd99c2a52eaf1910680 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 6 Jun 2021 13:42:35 +0200 Subject: [PATCH] fix glow-tts `inference()` --- TTS/tts/models/glow_tts.py | 5 +++- TTS/tts/utils/synthesis.py | 31 +++++++++++++++++++++-- tests/inference_tests/test_synthesizer.py | 3 ++- tests/tts_tests/test_tacotron2_model.py | 4 +-- tests/tts_tests/test_tacotron2_train.py | 1 - 5 files changed, 36 insertions(+), 8 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 9c928a67..3b3207f0 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -290,7 +290,10 @@ class GlowTTS(nn.Module): return outputs @torch.no_grad() - def inference(self, x, x_lengths, cond_input={"d_vectors": None}): # pylint: disable=dangerous-default-value + def inference( + self, x, cond_input={"x_lengths": None, "d_vectors": None} + ): # pylint: disable=dangerous-default-value + x_lengths = cond_input["x_lengths"] g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None if g is not None: if self.d_vector_dim: diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 04fef715..72eff2e5 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,8 +1,10 @@ import os +from typing import Dict import numpy as np import pkg_resources import torch +from torch import nn from .text import phoneme_to_sequence, text_to_sequence @@ -65,9 +67,34 @@ def compute_style_mel(style_wav, ap, cuda=False): return style_mel -def run_model_torch(model, inputs, speaker_id=None, style_mel=None, d_vector=None): +def run_model_torch( + model: nn.Module, + inputs: torch.Tensor, + speaker_id: int = None, + style_mel: torch.Tensor = None, + d_vector: torch.Tensor = None, +) -> Dict: + """Run a torch model for inference. It does not support batch inference. + + Args: + model (nn.Module): The model to run inference. 
+ inputs (torch.Tensor): Input tensor with character ids. + speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None. + style_mel (torch.Tensor, optional): Spectrograms used for voice styling. Defaults to None. + d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None. + + Returns: + Dict: model outputs. + """ + input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) outputs = model.inference( - inputs, cond_input={"speaker_ids": speaker_id, "d_vector": d_vector, "style_mel": style_mel} + inputs, + cond_input={ + "x_lengths": input_lengths, + "speaker_ids": speaker_id, + "d_vectors": d_vector, + "style_mel": style_mel, + }, ) return outputs diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index b0fa22d3..4379c8ca 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -1,13 +1,14 @@ import os import unittest -from tests import get_tests_output_path from TTS.config import load_config from TTS.tts.models import setup_model from TTS.tts.utils.io import save_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.synthesizer import Synthesizer +from .. 
import get_tests_output_path + class SynthesizerTest(unittest.TestCase): # pylint: disable=R0201 diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index b77f7cc5..66372470 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -259,9 +259,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase): stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to( - device - ) + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(device) model.train() model_ref = copy.deepcopy(model) count = 0 diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 70975490..577de014 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -8,7 +8,6 @@ from TTS.tts.configs import Tacotron2Config config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") - config = Tacotron2Config( r=5, batch_size=8,