mirror of https://github.com/coqui-ai/TTS.git
fix glow-tts `inference()`
This commit is contained in:
parent
7ea71c7586
commit
4f29725eb6
|
@ -290,7 +290,10 @@ class GlowTTS(nn.Module):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, x_lengths, cond_input={"d_vectors": None}): # pylint: disable=dangerous-default-value
|
def inference(
|
||||||
|
self, x, cond_input={"x_lengths": None, "d_vectors": None}
|
||||||
|
): # pylint: disable=dangerous-default-value
|
||||||
|
x_lengths = cond_input["x_lengths"]
|
||||||
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
g = cond_input["d_vectors"] if cond_input is not None and "d_vectors" in cond_input else None
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.d_vector_dim:
|
if self.d_vector_dim:
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from .text import phoneme_to_sequence, text_to_sequence
|
from .text import phoneme_to_sequence, text_to_sequence
|
||||||
|
|
||||||
|
@ -64,9 +66,34 @@ def compute_style_mel(style_wav, ap, cuda=False):
|
||||||
return style_mel
|
return style_mel
|
||||||
|
|
||||||
|
|
||||||
def run_model_torch(model, inputs, speaker_id=None, style_mel=None, d_vector=None):
|
def run_model_torch(
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
speaker_id: int = None,
|
||||||
|
style_mel: torch.Tensor = None,
|
||||||
|
d_vector: torch.Tensor = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""Run a torch model for inference. It does not support batch inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): The model to run inference.
|
||||||
|
inputs (torch.Tensor): Input tensor with character ids.
|
||||||
|
speaker_id (int, optional): Input speaker ID for multi-speaker models. Defaults to None.
|
||||||
|
style_mel (torch.Tensor, optional): Spectrograms used for voice styling. Defaults to None.
|
||||||
|
d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: model outputs.
|
||||||
|
"""
|
||||||
|
input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
||||||
outputs = model.inference(
|
outputs = model.inference(
|
||||||
inputs, cond_input={"speaker_ids": speaker_id, "d_vector": d_vector, "style_mel": style_mel}
|
inputs,
|
||||||
|
cond_input={
|
||||||
|
"x_lengths": input_lengths,
|
||||||
|
"speaker_ids": speaker_id,
|
||||||
|
"d_vectors": d_vector,
|
||||||
|
"style_mel": style_mel,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tests import get_tests_output_path
|
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.io import save_checkpoint
|
from TTS.tts.utils.io import save_checkpoint
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
|
|
||||||
|
from .. import get_tests_output_path
|
||||||
|
|
||||||
|
|
||||||
class SynthesizerTest(unittest.TestCase):
|
class SynthesizerTest(unittest.TestCase):
|
||||||
# pylint: disable=R0201
|
# pylint: disable=R0201
|
||||||
|
|
|
@ -259,9 +259,7 @@ class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
criterion = MSELossMasked(seq_len_norm=False).to(device)
|
criterion = MSELossMasked(seq_len_norm=False).to(device)
|
||||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(
|
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, d_vector_dim=55, use_gst=True, gst=c.gst).to(device)
|
||||||
device
|
|
||||||
)
|
|
||||||
model.train()
|
model.train()
|
||||||
model_ref = copy.deepcopy(model)
|
model_ref = copy.deepcopy(model)
|
||||||
count = 0
|
count = 0
|
||||||
|
|
|
@ -8,7 +8,6 @@ from TTS.tts.configs import Tacotron2Config
|
||||||
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
||||||
|
|
||||||
config = Tacotron2Config(
|
config = Tacotron2Config(
|
||||||
r=5,
|
r=5,
|
||||||
batch_size=8,
|
batch_size=8,
|
||||||
|
|
Loading…
Reference in New Issue