fix glow-tts `inference()`

Eren Gölge 2021-06-06 13:42:35 +02:00
parent 7ea71c7586
commit 4f29725eb6
5 changed files with 36 additions and 8 deletions

View File

@@ -290,7 +290,10 @@ class GlowTTS(nn.Module):
         return outputs

     @torch.no_grad()
-    def inference(self, x, x_lengths, cond_input={"d_vectors": None}):  # pylint: disable=dangerous-default-value
+    def inference(
+        self, x, cond_input={"x_lengths": None, "d_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
+        x_lengths = cond_input["x_lengths"]
         g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
         if g is not None:
             if self.d_vector_dim:
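
For reference, a minimal sketch of how the reworked signature is called after this change; the model setup and tensor values below are illustrative assumptions, not part of the commit:

import torch

# `model` is assumed to be an initialized GlowTTS instance in eval mode
token_ids = torch.LongTensor([[12, 7, 33, 5, 21]])       # (1, T_text) character/phoneme ids
token_lengths = torch.LongTensor([token_ids.shape[1]])   # lengths now travel inside cond_input
outputs = model.inference(
    token_ids,
    cond_input={"x_lengths": token_lengths, "d_vectors": None},
)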

View File

@@ -1,8 +1,10 @@
 import os
+from typing import Dict

 import numpy as np
 import pkg_resources
 import torch
+from torch import nn

 from .text import phoneme_to_sequence, text_to_sequence
@@ -64,9 +66,34 @@ def compute_style_mel(style_wav, ap, cuda=False):
     return style_mel


-def run_model_torch(model, inputs, speaker_id=None, style_mel=None, d_vector=None):
+def run_model_torch(
+    model: nn.Module,
+    inputs: torch.Tensor,
+    speaker_id: int = None,
+    style_mel: torch.Tensor = None,
+    d_vector: torch.Tensor = None,
+) -> Dict:
+    """Run a torch model for inference. It does not support batch inference.
+
+    Args:
+        model (nn.Module): The model to run inference.
+        inputs (torch.Tensor): Input tensor with character ids.
+        speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None.
+        style_mel (torch.Tensor, optional): Spectrograms used for voice styling. Defaults to None.
+        d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None.
+
+    Returns:
+        Dict: model outputs.
+    """
+    input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
     outputs = model.inference(
-        inputs, cond_input={"speaker_ids": speaker_id, "d_vector": d_vector, "style_mel": style_mel}
+        inputs,
+        cond_input={
+            "x_lengths": input_lengths,
+            "speaker_ids": speaker_id,
+            "d_vectors": d_vector,
+            "style_mel": style_mel,
+        },
     )
     return outputs
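
A quick usage sketch for the updated helper, assuming the import path matches this file (TTS/tts/utils/synthesis.py); the model and token ids are illustrative assumptions, only the run_model_torch call itself reflects the diff:

import torch
from TTS.tts.utils.synthesis import run_model_torch

# `model` is assumed to be a TTS model whose inference() accepts the cond_input dict above
token_ids = torch.LongTensor([[12, 7, 33, 5]])   # (1, T_text) character ids
outputs = run_model_torch(model, token_ids)      # x_lengths is built internally from token_ids
# multi-speaker variant (illustrative d-vector size):
# outputs = run_model_torch(model, token_ids, d_vector=torch.randn(1, 256))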

View File

@@ -1,13 +1,14 @@
import os
import unittest
from tests import get_tests_output_path
from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.synthesizer import Synthesizer
from .. import get_tests_output_path
class SynthesizerTest(unittest.TestCase):
# pylint: disable=R0201

View File

@@ -259,9 +259,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
         criterion = MSELossMasked(seq_len_norm=False).to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(
-            device
-        )
+        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(device)
         model.train()
         model_ref = copy.deepcopy(model)
         count = 0

View File

@@ -8,7 +8,6 @@ from TTS.tts.configs import Tacotron2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = Tacotron2Config(
r=5,
batch_size=8,