diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index d4044c7e..cd84bf08 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -1,4 +1,5 @@
 import os
+import random
 from typing import Dict, List, Tuple
 
 import torch
@@ -9,12 +10,12 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.model import BaseModel
+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 
 # pylint: skip-file
 
@@ -64,7 +65,7 @@ class BaseTTS(BaseModel):
         else:
             from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
 
-            config.characters = parse_symbols()
+            config.characters = CharactersConfig(**parse_symbols())
         model_characters = phonemes if config.use_phonemes else symbols
         num_chars = len(model_characters) + getattr(config, "add_blank", False)
         return model_characters, config, num_chars
@@ -80,14 +81,13 @@ class BaseTTS(BaseModel):
             config (Coqpit): Model configuration.
         """
         # init speaker manager
-        if self.speaker_manager is None:
-            raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.")
-
-        print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}")
-
-        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
-        self.num_speakers = self.speaker_manager.num_speakers
-
+        if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
         # set ultimate speaker embedding size
         if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
@@ -99,10 +99,6 @@ class BaseTTS(BaseModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
-        """Prepare and return `aux_input` used by `forward()`"""
-        return {"speaker_id": None, "style_wav": None, "d_vector": None}
-
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
 
@@ -293,6 +289,20 @@ class BaseTTS(BaseModel):
             )
         return loader
 
+    def _get_test_aux_input(
+        self,
+    ) -> Dict:
+        aux_inputs = {
+            "speaker_id": None
+            if not self.config.use_speaker_embedding
+            else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
+            "d_vector": None
+            if not self.config.use_d_vector_file
+            else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1),
+            "style_wav": None,  # TODO: handle GST style input
+        }
+        return aux_inputs
+
     def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
@@ -309,7 +319,7 @@ class BaseTTS(BaseModel):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self.get_aux_input()
+        aux_inputs = self._get_test_aux_input()
        for idx, sen in enumerate(test_sentences):
             outputs_dict = synthesis(
                 self,
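
For context, the new `_get_test_aux_input()` helper replaces the removed `get_aux_input()` stub: rather than always returning `None`, it draws a random speaker id (or d-vector) from the `SpeakerManager` for each test run. A minimal sketch of that selection logic follows, using a hypothetical `speaker_ids` mapping in place of a real `SpeakerManager`:

    import random

    # Hypothetical stand-in for SpeakerManager.speaker_ids (speaker name -> integer id);
    # the real mapping is built by the SpeakerManager from the training data.
    speaker_ids = {"ljspeech": 0, "vctk_p225": 1, "vctk_p226": 2}

    # Mirrors the selection done in _get_test_aux_input(): draw one random speaker id
    # for synthesizing the configured test sentences. Note that random.sample()
    # returns a one-element list, which is what gets passed through as `speaker_id`.
    speaker_id = random.sample(sorted(speaker_ids.values()), 1)
    print(speaker_id)  # e.g. [1]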