mirror of https://github.com/coqui-ai/TTS.git
fix glow-tts `inference()`
This commit is contained in:
parent
82582993cc
commit
25238e0658
|
@ -290,7 +290,10 @@ class GlowTTS(nn.Module):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, x_lengths, cond_input={"d_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
def inference(
|
||||
self, x, cond_input={"x_lengths": None, "d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
x_lengths = cond_input["x_lengths"]
|
||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import pkg_resources
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .text import phoneme_to_sequence, text_to_sequence
|
||||
|
||||
|
@ -65,9 +67,34 @@ def compute_style_mel(style_wav, ap, cuda=False):
|
|||
return style_mel
|
||||
|
||||
|
||||
def run_model_torch(model, inputs, speaker_id=None, style_mel=None, d_vector=None):
|
||||
def run_model_torch(
|
||||
model: nn.Module,
|
||||
inputs: torch.Tensor,
|
||||
speaker_id: int = None,
|
||||
style_mel: torch.Tensor = None,
|
||||
d_vector: torch.Tensor = None,
|
||||
) -> Dict:
|
||||
"""Run a torch model for inference. It does not support batch inference.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to run inference.
|
||||
inputs (torch.Tensor): Input tensor with character ids.
|
||||
speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None.
|
||||
style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None.
|
||||
d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None.
|
||||
|
||||
Returns:
|
||||
Dict: model outputs.
|
||||
"""
|
||||
input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
||||
outputs = model.inference(
|
||||
inputs, cond_input={"speaker_ids": speaker_id, "d_vector": d_vector, "style_mel": style_mel}
|
||||
inputs,
|
||||
cond_input={
|
||||
"x_lengths": input_lengths,
|
||||
"speaker_ids": speaker_id,
|
||||
"d_vectors": d_vector,
|
||||
"style_mel": style_mel,
|
||||
},
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
from .. import get_tests_output_path
|
||||
|
||||
|
||||
class SynthesizerTest(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
|
|
|
@ -259,9 +259,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
|
|||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||
criterion = MSELossMasked(seq_len_norm=False).to(device)
|
||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(
|
||||
device
|
||||
)
|
||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(device)
|
||||
model.train()
|
||||
model_ref = copy.deepcopy(model)
|
||||
count = 0
|
||||
|
|
|
@ -8,7 +8,6 @@ from TTS.tts.configs import Tacotron2Config
|
|||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||
|
||||
|
||||
config = Tacotron2Config(
|
||||
r=5,
|
||||
batch_size=8,
|
||||
|
|
Loading…
Reference in New Issue