From 7505c0ba27d622fd006be9ea8f903ff6a5d300dd Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 7 Dec 2020 11:29:41 +0100
Subject: [PATCH] muliprocess phoneme computation

---
 TTS/tts/datasets/TTSDataset.py | 47 ++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 5 deletions(-)

diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 65ceb7dd..88545d45 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -134,11 +134,13 @@ class MyDataset(Dataset):
 
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
 
-        if self.use_phonemes:
-            text = self._load_or_generate_phoneme_sequence(wav_file, text)
-        else:
-            text = np.asarray(text_to_sequence(text, [self.cleaners],
-                                               tp=self.tp, add_blank=self.add_blank),
+        if not self.input_seq_computed:
+            if self.use_phonemes:
+                text = self._load_or_generate_phoneme_sequence(wav_file, text, self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, self.tp, self.add_blank)
+
+            else:
+                text = np.asarray(text_to_sequence(text, [self.cleaners],
+                                                   tp=self.tp, add_blank=self.add_blank),
                               dtype=np.int32)
 
         assert text.size > 0, self.items[idx][1]
@@ -163,6 +165,41 @@ class MyDataset(Dataset):
         }
         return sample
 
+    @staticmethod
+    def _phoneme_worker(args):
+        item = args[0]
+        func_args = args[1]
+        text, wav_file, *_ = item
+        phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
+        return phonemes
+
+    def compute_input_seq(self, num_workers=0):
+        """compute input sequences separately. Call it before
+        passing dataset to data loader."""
+        if not self.use_phonemes:
+            if self.verbose:
+                print(" | > Computing input sequences ...")
+            for idx, item in enumerate(tqdm.tqdm(self.items)):
+                text, *_ = item
+                sequence = np.asarray(text_to_sequence(text, [self.cleaners],
+                                                       tp=self.tp, add_blank=self.add_blank),
+                                      dtype=np.int32)
+                self.items[idx][0] = sequence
+
+        else:
+            func_args = [self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, self.tp, self.add_blank]
+            if self.verbose:
+                print(" | > Computing phonemes ...")
+            if num_workers == 0:
+                for idx, item in enumerate(tqdm.tqdm(self.items)):
+                    phonemes = self._phoneme_worker([item, func_args])
+                    self.items[idx][0] = phonemes
+            else:
+                with Pool(num_workers) as p:
+                    phonemes = list(tqdm.tqdm(p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]), total=len(self.items)))
+                    for idx, p in enumerate(phonemes):
+                        self.items[idx][0] = p
+
     def sort_items(self):
         r"""Sort instances based on text length in ascending order"""
         lengths = np.array([len(ins[0]) for ins in self.items])