From 802d4613890917f400b70ca63b4f9e66560063a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 3 Jun 2021 11:46:53 +0200
Subject: [PATCH] Compute d_vectors and speaker_ids separately in TTSDataset

---
 TTS/bin/extract_tts_spectrograms.py | 19 ++++----------
 TTS/trainer.py                      | 40 +++++++++++------------------
 TTS/tts/datasets/TTSDataset.py      | 35 +++++++++++++++++--------
 3 files changed, 44 insertions(+), 50 deletions(-)

diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py
index 3acf5d02..d17bcb30 100755
--- a/TTS/bin/extract_tts_spectrograms.py
+++ b/TTS/bin/extract_tts_spectrograms.py
@@ -39,7 +39,8 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_manager.speaker_ids
+        speaker_id_mapping=speaker_manager.speaker_ids,
+        d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
     )
@@ -84,22 +85,12 @@ def format_data(data):
     mel_input = data[4]
     mel_lengths = data[5]
     item_idx = data[7]
-    attn_mask = data[9]
+    d_vectors = data[8]
+    speaker_ids = data[9]
+    attn_mask = data[10]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())

-    if c.use_speaker_embedding:
-        if c.use_external_speaker_embedding_file:
-            speaker_embeddings = data[8]
-            speaker_ids = None
-        else:
-            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-            speaker_ids = torch.LongTensor(speaker_ids)
-            speaker_embeddings = None
-    else:
-        speaker_embeddings = None
-        speaker_ids = None
-
     # dispatch data to GPU
     if use_cuda:
         text_input = text_input.cuda(non_blocking=True)
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 55560624..7136e023 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -267,7 +267,8 @@ class TrainerTTS:
         is_eval: bool,
         data_items: List,
         verbose: bool,
-        speaker_mapping: Union[Dict, List],
+        speaker_ids: Union[Dict, List],
+        d_vectors: Union[Dict, List],
     ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
@@ -289,9 +290,10 @@ class TrainerTTS:
                 enable_eos_bos=self.config.enable_eos_bos_chars,
                 use_noise_augment=not is_eval,
                 verbose=verbose,
-                speaker_mapping=speaker_mapping
-                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
-                else None,
+                speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
+                d_vector_mapping=d_vectors
+                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
+                else None,
             )

         if self.config.use_phonemes and self.config.compute_input_seq_cache:
@@ -313,14 +315,14 @@ class TrainerTTS:
         return loader

     def get_train_dataloader(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)

     def get_eval_dataloder(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)

     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -332,24 +334,12 @@ class TrainerTTS:
         mel_lengths = batch[5]
         stop_targets = batch[6]
         item_idx = batch[7]
-        speaker_embeddings = batch[8]
-        attn_mask = batch[9]
+        d_vectors = batch[8]
+        speaker_ids = batch[9]
+        attn_mask = batch[10]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())

-        # convert speaker names to ids
-        if self.config.use_speaker_embedding:
-            if self.config.use_external_speaker_embedding_file:
-                speaker_embeddings = batch[8]
-                speaker_ids = None
-            else:
-                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-                speaker_ids = torch.LongTensor(speaker_ids)
-                speaker_embeddings = None
-        else:
-            speaker_embeddings = None
-            speaker_ids = None
-
         # compute durations from attention masks
         durations = None
         if attn_mask is not None:
@@ -640,11 +630,11 @@ class TrainerTTS:

         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
         )
         self.eval_loader = (
             self.get_eval_dataloder(
-                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
             )
             if self.config.run_eval
             else None
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 76f82c97..2522b55a 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -29,7 +29,8 @@ class TTSDataset(Dataset):
         phoneme_cache_path=None,
         phoneme_language="en-us",
         enable_eos_bos=False,
-        speaker_mapping=None,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
         use_noise_augment=False,
         verbose=False,
     ):
@@ -51,6 +52,8 @@ class TTSDataset(Dataset):
             phoneme_language (str): one the languages from https://github.com/bootphon/phonemizer#languages
             enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
+            speaker_id_mapping (dict): dictionary mapping speaker names to numerical speaker ids.
+            d_vector_mapping (dict): dictionary mapping each audio file name to its pre-computed d-vector.
             use_noise_augment (bool): enable adding random noise to wav for augmentation.
             verbose (bool): print diagnostic information.
""" @@ -70,7 +73,8 @@ class TTSDataset(Dataset): self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language self.enable_eos_bos = enable_eos_bos - self.speaker_mapping = speaker_mapping + self.speaker_id_mapping = speaker_id_mapping + self.d_vector_mapping = d_vector_mapping self.use_noise_augment = use_noise_augment self.verbose = verbose self.input_seq_computed = False @@ -293,13 +297,18 @@ class TTSDataset(Dataset): item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] text = [batch[idx]["text"] for idx in ids_sorted_decreasing] - speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] - # get speaker embeddings - if self.speaker_mapping is not None: + speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] + # get pre-computed d-vectors + if self.d_vector_mapping is not None: wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing] - speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names] + d_vectors = [self.speaker_mapping[w]["embedding"] for w in wav_files_names] else: - speaker_embedding = None + d_vectors = None + # get numerical speaker ids from speaker names + if self.speaker_id_mapping: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in speaker_names] + else: + speaker_ids = None # compute features mel = [self.ap.melspectrogram(w).astype("float32") for w in wav] @@ -327,8 +336,11 @@ class TTSDataset(Dataset): mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - if speaker_embedding is not None: - speaker_embedding = torch.FloatTensor(speaker_embedding) + if d_vectors is not None: + d_vectors = torch.FloatTensor(d_vectors) + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) # compute linear spectrogram if self.compute_linear_spec: @@ -355,13 +367,14 @@ class TTSDataset(Dataset): return ( text, text_lenghts, - speaker_name, + speaker_names, linear, mel, mel_lengths, stop_targets, item_idxs, - speaker_embedding, + d_vectors, + speaker_ids, attns, )