diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 928a2a46..8054b181 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -25,6 +25,8 @@ class Synthesizer(object): tts_speakers_file: str = "", vocoder_checkpoint: str = "", vocoder_config: str = "", + encoder_checkpoint: str = "", + encoder_config: str = "", use_cuda: bool = False, ) -> None: """General 🐸 TTS interface for inference. It takes a tts and a vocoder @@ -41,6 +43,8 @@ class Synthesizer(object): tts_config_path (str): path to the tts config file. vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. vocoder_config (str, optional): path to the vocoder config file. Defaults to None. + encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, + encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`, use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ self.tts_checkpoint = tts_checkpoint @@ -48,6 +52,8 @@ class Synthesizer(object): self.tts_speakers_file = tts_speakers_file self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_config = vocoder_config + self.encoder_checkpoint = encoder_checkpoint + self.encoder_config = encoder_config self.use_cuda = use_cuda self.tts_model = None @@ -69,16 +75,37 @@ class Synthesizer(object): @staticmethod def _get_segmenter(lang: str): + """get the sentence segmenter for the given language. + + Args: + lang (str): target language code. + + Returns: + [type]: [description] + """ return pysbd.Segmenter(language=lang, clean=True) def _load_speakers(self, speaker_file: str) -> None: + """Load the SpeakerManager to organize multi-speaker TTS. It loads the speakers meta-data and the speaker + encoder if it is defined. + + Args: + speaker_file (str): path to the speakers meta-data file. + """ print("Loading speakers ...") - self.speaker_manager = SpeakerManager() + self.speaker_manager = SpeakerManager(encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config) self.speaker_manager.load_x_vectors_file(self.tts_config.get("external_speaker_embedding_file", speaker_file)) self.num_speakers = self.speaker_manager.num_speakers self.speaker_embedding_dim = self.speaker_manager.x_vector_dim def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: + """Load the TTS model. + + Args: + tts_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ # pylint: disable=global-statement global symbols, phonemes @@ -109,6 +136,13 @@ class Synthesizer(object): self.tts_model.cuda() def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: + """Load the vocoder model. + + Args: + model_file (str): path to the model checkpoint. + model_config (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ self.vocoder_config = load_config(model_config) self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"]) self.vocoder_model = setup_generator(self.vocoder_config) @@ -117,24 +151,54 @@ class Synthesizer(object): self.vocoder_model.cuda() def _split_into_sentences(self, text) -> List[str]: + """Split give text into sentences. + + Args: + text (str): input text in string format. + + Returns: + List[str]: list of sentences. + """ return self.seg.segment(text) def save_wav(self, wav: List[int], path: str) -> None: + """Save the waveform as a file. + + Args: + wav (List[int]): waveform as a list of values. + path (str): output path to save the waveform. + """ wav = np.array(wav) self.ap.save_wav(wav, path, self.output_sample_rate) - def tts(self, text: str, speaker_idx: str = "", style_wav=None) -> List[int]: + def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]: + """🐸 TTS magic. Run all the models and generate speech. + + Args: + text (str): input text. + speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "". + speaker_wav (): + style_wav ([type], optional): style waveform for GST. Defaults to None. + + Returns: + List[int]: [description] + """ start_time = time.time() wavs = [] sens = self._split_into_sentences(text) print(" > Text splitted to sentences.") print(sens) + # get the speaker embedding from the saved x_vectors. if speaker_idx and isinstance(speaker_idx, str): speaker_embedding = self.speaker_manager.get_x_vectors_by_speaker(speaker_idx)[0] else: speaker_embedding = None + # compute a new x_vector from the given clip. + if speaker_wav is not None: + speaker_embedding = self.speaker_manager.compute_x_vector_from_clip(speaker_wav) + use_gl = self.vocoder_model is None for sen in sens: