mirror of https://github.com/coqui-ai/TTS.git
Extend glow_tts model tests
This commit is contained in:
parent 8e248913d6
commit 235f7d9b02
@@ -40,11 +40,20 @@ class GlowTTS(BaseTTS):
    Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.

    Examples:
        Init only model layers.

        >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
        >>> from TTS.tts.models.glow_tts import GlowTTS
        >>> config = GlowTTSConfig(num_chars=2)
        >>> model = GlowTTS(config)

        Fully init a model ready for action. All the class attributes and class members
        (e.g. Tokenizer, AudioProcessor, etc.) are initialized internally based on config values.

        >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
        >>> from TTS.tts.models.glow_tts import GlowTTS
        >>> config = GlowTTSConfig()
        >>> model = GlowTTS(config)

        >>> model = GlowTTS.init_from_config(config, verbose=False)
    """

    def __init__(

@@ -98,25 +107,23 @@ class GlowTTS(BaseTTS):

    def init_multispeaker(self, config: Coqpit):
        """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
-        vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+        vector dimension to the encoder layer channel size. If the model uses d-vectors, then it only sets
+        the speaker embedding vector dimension to the d-vector dimension from the config.

        Args:
            config (Coqpit): Model configuration.
        """
        self.embedded_speaker_dim = 0
        # init speaker manager
        if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
            raise ValueError(
                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
            )
        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
        if self.speaker_manager is not None:
            self.num_speakers = self.speaker_manager.num_speakers
        # set ultimate speaker embedding size
        if config.use_speaker_embedding or config.use_d_vector_file:
            if config.use_d_vector_file:
                self.embedded_speaker_dim = (
                    config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
                )
                if self.speaker_manager is not None:
                    assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager."
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            print(" > Init speaker_embedding layer.")
|
@@ -186,12 +193,33 @@ class GlowTTS(BaseTTS):
        self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
    ):  # pylint: disable=dangerous-default-value
        """
        Shapes:
            - x: :math:`[B, T]`
            - x_lengths: :math:`B`
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C]` or :math:`B`
        Args:
            x (torch.Tensor):
                Input text sequence ids. :math:`[B, T_en]`

            x_lengths (torch.Tensor):
                Lengths of input text sequences. :math:`[B]`

            y (torch.Tensor):
                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`

            y_lengths (torch.Tensor):
                Lengths of target mel-spectrogram frames. :math:`[B]`

            aux_input (Dict):
                Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model.
                :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model using a
                speaker-embedding layer. :math:`B`

        Returns:
            Dict:
                - z: :math:`[B, T_de, C]`
                - logdet: :math:`B`
                - y_mean: :math:`[B, T_de, C]`
                - y_log_scale: :math:`[B, T_de, C]`
                - alignments: :math:`[B, T_en, T_de]`
                - durations_log: :math:`[B, T_en, 1]`
                - total_durations_log: :math:`[B, T_en, 1]`
        """
        # [B, T, C] -> [B, C, T]
        y = y.transpose(1, 2)
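
To make the documented shapes concrete, here is a hedged sketch of a dummy forward pass, closely mirroring `_create_inputs` and `test_forward` in the test file below (the 80 mel channels and `num_chars=32` are assumed example values matching the default config):

    import torch
    from TTS.tts.configs.glow_tts_config import GlowTTSConfig
    from TTS.tts.models.glow_tts import GlowTTS

    B, T_en, T_de, C_mel = 8, 128, 30, 80              # batch, text length, mel frames, mel channels (assumed)
    x = torch.randint(0, 32, (B, T_en)).long()         # token ids          [B, T_en]
    x_lengths = torch.randint(100, 129, (B,)).long()   # text lengths       [B]
    x_lengths[-1] = T_en                               # longest item spans the full tensor
    y = torch.rand(B, T_de, C_mel)                     # target mel frames  [B, T_de, C_mel]
    y_lengths = torch.randint(20, 30, (B,)).long()     # mel lengths        [B]

    model = GlowTTS(GlowTTSConfig(num_chars=32))
    outputs = model.forward(x, x_lengths, y, y_lengths)
    assert outputs["z"].shape == y.shape               # [B, T_de, C_mel]
    assert outputs["durations_log"].shape == (B, T_en, 1)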
|
@@ -510,17 +538,18 @@ class GlowTTS(BaseTTS):
        self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps

    @staticmethod
-    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
        """Initialize the model from a config.

        Args:
            config (GlowTTSConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
+            verbose (bool): If True, print init messages. Defaults to True.
        """
        from TTS.utils.audio import AudioProcessor

-        ap = AudioProcessor.init_from_config(config)
+        ap = AudioProcessor.init_from_config(config, verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        return GlowTTS(new_config, ap, tokenizer, speaker_manager)
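
A minimal usage sketch of the extended signature, mirroring the doctest in the class docstring above and `test_init_from_config` in the test file below:

    from TTS.tts.configs.glow_tts_config import GlowTTSConfig
    from TTS.tts.models.glow_tts import GlowTTS

    config = GlowTTSConfig(num_chars=32)
    # Wires up AudioProcessor, TTSTokenizer and SpeakerManager from the config;
    # verbose=False silences the AudioProcessor init messages (the new argument in this hunk).
    model = GlowTTS.init_from_config(config, verbose=False)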
|
|
|
@@ -1,11 +1,13 @@
import copy
import os
import unittest
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.logging.tensorboard_logger import TensorboardLogger

import torch
from torch import optim

-from tests import get_tests_input_path
+from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTTS

@@ -28,36 +30,211 @@ def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


-class GlowTTSTrainTest(unittest.TestCase):
-    @staticmethod
-    def test_train_step():
+class TestGlowTTS(unittest.TestCase):
+    def _create_inputs(self):
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
        return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids

    def _check_parameter_changes(self, model, model_ref):
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1

    def test_init_multispeaker(self):
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config)
        # speaker embedding with default speaker_embedding_dim
        config.use_speaker_embedding = True
        config.num_speakers = 5
        config.d_vector_dim = None
        model.init_multispeaker(config)
        self.assertEqual(model.c_in_channels, model.hidden_channels_enc)
        # use external speaker embeddings with speaker_embedding_dim = 301
        config = GlowTTSConfig(num_chars=32)
        config.use_d_vector_file = True
        config.d_vector_dim = 301
        model = GlowTTS(config)
        model.init_multispeaker(config)
        self.assertEqual(model.c_in_channels, 301)
        # use speaker embeddings by the provided speaker_manager
        config = GlowTTSConfig(num_chars=32)
        config.use_speaker_embedding = True
        config.speakers_file = os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")
        speaker_manager = SpeakerManager.init_from_config(config)
        model = GlowTTS(config)
        model.speaker_manager = speaker_manager
        model.init_multispeaker(config)
        self.assertEqual(model.c_in_channels, model.hidden_channels_enc)
        self.assertEqual(model.num_speakers, speaker_manager.num_speakers)
        # use external speaker embeddings by the provided speaker_manager
        config = GlowTTSConfig(num_chars=32)
        config.use_d_vector_file = True
        config.d_vector_dim = 256
        config.d_vector_file = os.path.join(get_tests_data_path(), "dummy_speakers.json")
        speaker_manager = SpeakerManager.init_from_config(config)
        model = GlowTTS(config)
        model.speaker_manager = speaker_manager
        model.init_multispeaker(config)
        self.assertEqual(model.c_in_channels, speaker_manager.d_vector_dim)
        self.assertEqual(model.num_speakers, speaker_manager.num_speakers)

    def test_unlock_act_norm_layers(self):
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)
        model.unlock_act_norm_layers()
        for f in model.decoder.flows:
            if getattr(f, "set_ddi", False):
                self.assertFalse(f.initialized)

    def test_lock_act_norm_layers(self):
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)
        model.lock_act_norm_layers()
        for f in model.decoder.flows:
            if getattr(f, "set_ddi", False):
                self.assertTrue(f.initialized)

    def test_forward(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        # create model
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)
        model.train()
        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
        # inference encoder and decoder with MAS
        y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
        self.assertEqual(y["z"].shape, mel_spec.shape)
        self.assertEqual(y["logdet"].shape, torch.Size([8]))
        self.assertEqual(y["y_mean"].shape, mel_spec.shape)
        self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
        self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
        self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
        self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))

    def test_forward_with_d_vector(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        d_vector = torch.rand(8, 256).to(device)
        # create model
        config = GlowTTSConfig(
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
        )
        model = GlowTTS.init_from_config(config, verbose=False).to(device)
        model.train()
        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
        # inference encoder and decoder with MAS
        y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector})
        self.assertEqual(y["z"].shape, mel_spec.shape)
        self.assertEqual(y["logdet"].shape, torch.Size([8]))
        self.assertEqual(y["y_mean"].shape, mel_spec.shape)
        self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
        self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
        self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
        self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))

    def test_forward_with_speaker_id(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
        # create model
        config = GlowTTSConfig(
            num_chars=32,
            use_speaker_embedding=True,
            num_speakers=24,
        )
        model = GlowTTS.init_from_config(config, verbose=False).to(device)
        model.train()
        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
        # inference encoder and decoder with MAS
        y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids})
        self.assertEqual(y["z"].shape, mel_spec.shape)
        self.assertEqual(y["logdet"].shape, torch.Size([8]))
        self.assertEqual(y["y_mean"].shape, mel_spec.shape)
        self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
        self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
        self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
        self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))

    def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
        output_shape = outputs["model_outputs"].shape
        self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
        self.assertEqual(outputs["logdet"], None)
        self.assertEqual(outputs["y_mean"].shape, output_shape)
        self.assertEqual(outputs["y_log_scale"].shape, output_shape)
        self.assertEqual(outputs["alignments"].shape, output_shape[:2] + (input_dummy.shape[1],))
        self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,))
        self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,))

    def test_inference(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)
        model.eval()
        outputs = model.inference(input_dummy, {"x_lengths": input_lengths})
        self._assert_inference_outputs(outputs, input_dummy, mel_spec)

    def test_inference_with_d_vector(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        d_vector = torch.rand(8, 256).to(device)
        config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
        model = GlowTTS.init_from_config(config, verbose=False).to(device)
        model.eval()
        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
        self._assert_inference_outputs(outputs, input_dummy, mel_spec)

    def test_inference_with_speaker_ids(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
        # create model
        config = GlowTTSConfig(
            num_chars=32,
            use_speaker_embedding=True,
            num_speakers=24,
        )
        model = GlowTTS.init_from_config(config, verbose=False).to(device)
        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
        self._assert_inference_outputs(outputs, input_dummy, mel_spec)

    def test_inference_with_MAS(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        # create model
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)
        model.eval()
        # inference encoder and decoder with MAS
        y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
        y2 = model.decoder_inference(mel_spec, mel_lengths)
        assert (
            y2["model_outputs"].shape == y["model_outputs"].shape
        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
            y["model_outputs"].shape, y2["model_outputs"].shape
        )

    def test_train_step(self):
        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
        criterion = GlowTTSLoss()

        # model to train
        config = GlowTTSConfig(num_chars=32)
        model = GlowTTS(config).to(device)

        # reference model to compare model weights
        model_ref = GlowTTS(config).to(device)

        model.train()
        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))

        # pass the state to ref model
        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))

        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1

        optimizer = optim.Adam(model.parameters(), lr=0.001)
        for _ in range(5):
            optimizer.zero_grad()

@@ -75,40 +252,78 @@ class GlowTTSTrainTest(unittest.TestCase):
            loss = loss_dict["loss"]
            loss.backward()
            optimizer.step()

        # check parameter changes
-        count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
-            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
-                count, param.shape, param, param_ref
-            )
-            count += 1
+        self._check_parameter_changes(model, model_ref)


-class GlowTTSInferenceTest(unittest.TestCase):
-    @staticmethod
-    def test_inference():
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
-        input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

-        # create model
+    def test_train_eval_log(self):
+        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs()
+        batch = {}
+        batch["text_input"] = input_dummy
+        batch["text_lengths"] = input_lengths
+        batch["mel_lengths"] = mel_lengths
+        batch["mel_input"] = mel_spec
+        batch["d_vectors"] = None
+        batch["speaker_ids"] = None
        config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS(config).to(device)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.train()
+        logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log")
+        criterion = model.get_criterion()
+        outputs, _ = model.train_step(batch, criterion)
+        model.train_log(batch, outputs, logger, None, 1)
+        model.eval_log(batch, outputs, logger, None, 1)
+        logger.finish()

+    def test_test_run(self):
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.eval()
+        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
+        test_figures, test_audios = model.test_run(None)
+        self.assertTrue(test_figures is not None)
+        self.assertTrue(test_audios is not None)

-        # inference encoder and decoder with MAS
-        y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
+    def test_load_checkpoint(self):
+        chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        chkp = {}
+        chkp["model"] = model.state_dict()
+        torch.save(chkp, chkp_path)
+        model.load_checkpoint(config, chkp_path)
+        self.assertTrue(model.training)
+        model.load_checkpoint(config, chkp_path, eval=True)
+        self.assertFalse(model.training)

-        y2 = model.decoder_inference(mel_spec, mel_lengths)
+    def test_get_criterion(self):
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        criterion = model.get_criterion()
+        self.assertTrue(criterion is not None)

+    def test_init_from_config(self):
+        config = GlowTTSConfig(num_chars=32)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)

+        config = GlowTTSConfig(num_chars=32, num_speakers=2)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 2)
+        self.assertTrue(not hasattr(model, "emb_g"))

+        config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 2)
+        self.assertTrue(hasattr(model, "emb_g"))

+        config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"))
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 10)
+        self.assertTrue(hasattr(model, "emb_g"))

+        config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"))
+        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 1)
+        self.assertTrue(not hasattr(model, "emb_g"))
+        self.assertTrue(model.c_in_channels == config.d_vector_dim)

-        assert (
-            y2["model_outputs"].shape == y["model_outputs"].shape
-        ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
-            y["model_outputs"].shape, y2["model_outputs"].shape
-        )