diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 01f4a1de..493c8869 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -14,7 +14,7 @@ from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import Graphemes, make_symbols +from TTS.tts.utils.text.characters import Graphemes, make_symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file @@ -33,7 +33,9 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ - def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None): + def __init__( + self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None + ): super().__init__(config) self.config = config self.ap = ap @@ -71,9 +73,6 @@ class BaseTTS(BaseModel): else: raise ValueError("config must be either a *Config or *Args") - def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: - return get_speaker_manager(config, restore_path, data, out_path) - def init_multispeaker(self, config: Coqpit, data: List = None): """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining `in_channels` size of the connected layers. @@ -291,7 +290,7 @@ class BaseTTS(BaseModel): verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, - tokenizer=self.tokenizer + tokenizer=self.tokenizer, ) # pre-compute phonemes diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index fe4a9d9b..4762a77a 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -71,12 +71,7 @@ model = GlowTTS(config, ap, tokenizer, speaker_manager=None) # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training, # distributed training, etc. trainer = Trainer( - TrainingArgs(), - config, - output_path, - model=model, - train_samples=train_samples, - eval_samples=eval_samples + TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples ) # AND... 3,2,1... 🚀