diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index c36dc529..824f0128 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,14 +1,9 @@ import os -import torch - -from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config +from TTS.config import load_config, register_config from TTS.trainer import Trainer, TrainingArgs from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model -from TTS.tts.utils.languages import LanguageManager -from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio import AudioProcessor def main(): diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 5168dd06..5a38039b 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -292,7 +292,6 @@ def brspeech(root_path, meta_file, ignored_speakers=None): def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" file_ext = "flac" - test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 869adcad..23eb48da 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -123,7 +123,9 @@ class GlowTTS(BaseTTS): config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 ) if self.speaker_manager is not None: - assert config.d_vector_dim == self.speaker_manager.d_vector_dim, " [!] d-vector dimension mismatch b/w config and speaker manager." + assert ( + config.d_vector_dim == self.speaker_manager.d_vector_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b5551268..2ecd1a07 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,5 +1,4 @@ import math -import random from dataclasses import dataclass, field, replace from itertools import chain from typing import Dict, List, Tuple, Union @@ -269,10 +268,20 @@ class Vits(BaseTTS): Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. Examples: + Init only model layers. + >>> from TTS.tts.configs.vits_config import VitsConfig >>> from TTS.tts.models.vits import Vits >>> config = VitsConfig() >>> model = Vits(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. + + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits.init_from_config(config) """ # pylint: disable=dangerous-default-value @@ -908,13 +917,10 @@ class Vits(BaseTTS): aux_inputs["text"], self.config, "cuda" in str(next(self.parameters()).device), - self.ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], language_id=aux_inputs["language_id"], - language_name=aux_inputs["language_name"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, use_griffin_lim=True, do_trim_silence=False, ).values() diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index f84a51ee..80be368d 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -148,7 +148,7 @@ class TTSTokenizer: # init cleaners text_cleaner = None if isinstance(config.text_cleaner, (str, list)): - text_cleaner = getattr(config, "text_cleaner") + text_cleaner = getattr(cleaners, config.text_cleaner) # init characters if characters is None: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index f6a1ae6a..a1a323e8 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -122,13 +122,9 @@ class Synthesizer(object): speaker_manager = self._init_speaker_encoder(speaker_manager) if language_manager is not None: - self.tts_model = setup_tts_model( - config=self.tts_config, - speaker_manager=speaker_manager, - language_manager=language_manager, - ) + self.tts_model = setup_tts_model(config=self.tts_config) else: - self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) + self.tts_model = setup_tts_model(config=self.tts_config) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -333,7 +329,6 @@ class Synthesizer(object): use_cuda=self.use_cuda, speaker_id=speaker_id, language_id=language_id, - language_name=language_name, style_wav=style_wav, use_griffin_lim=use_gl, d_vector=speaker_embedding, diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index e97b793a..e48977e9 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -1,8 +1,6 @@ import copy import os import unittest -from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.logging.tensorboard_logger import TensorboardLogger import torch from torch import optim @@ -11,7 +9,9 @@ from tests import get_tests_data_path, get_tests_input_path, get_tests_output_pa from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS +from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor +from TTS.utils.logging.tensorboard_logger import TensorboardLogger # pylint: disable=unused-variable @@ -31,7 +31,8 @@ def count_parameters(model): class TestGlowTTS(unittest.TestCase): - def _create_inputs(self): + @staticmethod + def _create_inputs(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 @@ -40,7 +41,8 @@ class TestGlowTTS(unittest.TestCase): speaker_ids = torch.randint(0, 5, (8,)).long().to(device) return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids - def _check_parameter_changes(self, model, model_ref): + @staticmethod + def _check_parameter_changes(model, model_ref): count = 0 for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( @@ -166,7 +168,7 @@ class TestGlowTTS(unittest.TestCase): def _assert_inference_outputs(self, outputs, input_dummy, mel_spec): output_shape = outputs["model_outputs"].shape - self.assertEqual(outputs["model_outputs"].shape[::2] , mel_spec.shape[::2]) + self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2]) self.assertEqual(outputs["logdet"], None) self.assertEqual(outputs["y_mean"].shape, output_shape) self.assertEqual(outputs["y_log_scale"].shape, output_shape) @@ -185,7 +187,12 @@ class TestGlowTTS(unittest.TestCase): def test_inference_with_d_vector(self): input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs() d_vector = torch.rand(8, 256).to(device) - config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) @@ -268,7 +275,9 @@ class TestGlowTTS(unittest.TestCase): model = GlowTTS.init_from_config(config, verbose=False).to(device) model.run_data_dep_init = False model.train() - logger = TensorboardLogger(log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name = "glow_tts_test_train_log") + logger = TensorboardLogger( + log_dir=os.path.join(get_tests_output_path(), "dummy_glow_tts_logs"), model_name="glow_tts_test_train_log" + ) criterion = model.get_criterion() outputs, _ = model.train_step(batch, criterion) model.train_log(batch, outputs, logger, None, 1) @@ -316,14 +325,23 @@ class TestGlowTTS(unittest.TestCase): self.assertTrue(model.num_speakers == 2) self.assertTrue(hasattr(model, "emb_g")) - config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json")) + config = GlowTTSConfig( + num_chars=32, + num_speakers=2, + use_speaker_embedding=True, + speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 10) self.assertTrue(hasattr(model, "emb_g")) - config = GlowTTSConfig(num_chars=32, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json")) + config = GlowTTSConfig( + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), + ) model = GlowTTS.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.c_in_channels == config.d_vector_dim) -