mirror of https://github.com/coqui-ai/TTS.git
multiprocess phoneme computation
This commit is contained in:
parent
20c86489d7
commit
7505c0ba27
|
@ -134,11 +134,13 @@ class MyDataset(Dataset):
|
|||
|
||||
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
||||
|
||||
if self.use_phonemes:
|
||||
text = self._load_or_generate_phoneme_sequence(wav_file, text)
|
||||
else:
|
||||
text = np.asarray(text_to_sequence(text, [self.cleaners],
|
||||
tp=self.tp, add_blank=self.add_blank),
|
||||
if not self.input_seq_computed:
|
||||
if self.use_phonemes:
|
||||
text = self._load_or_generate_phoneme_sequence(wav_file, text, self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, self.tp, self.add_blank)
|
||||
|
||||
else:
|
||||
text = np.asarray(text_to_sequence(text, [self.cleaners],
|
||||
tp=self.tp, add_blank=self.add_blank),
|
||||
dtype=np.int32)
|
||||
|
||||
assert text.size > 0, self.items[idx][1]
|
||||
|
@ -163,6 +165,41 @@ class MyDataset(Dataset):
|
|||
}
|
||||
return sample
|
||||
|
||||
@staticmethod
def _phoneme_worker(args):
    """Compute the phoneme sequence for a single dataset item.

    Takes one packed argument so it can be mapped over a list of jobs
    (it is used with ``Pool.imap`` in ``compute_input_seq``).

    Args:
        args: indexable pair ``[item, func_args]``. ``item`` is unpacked as
            ``(text, wav_file, ...)``; ``func_args`` holds the remaining
            positional arguments forwarded to
            ``MyDataset._load_or_generate_phoneme_sequence``.

    Returns:
        Whatever ``MyDataset._load_or_generate_phoneme_sequence`` returns
        for this item.
    """
    item, func_args = args[0], args[1]
    # Only the first two fields of the item are needed; any extra fields
    # are ignored.
    text, wav_file, *_ = item
    return MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
|
||||
|
||||
def compute_input_seq(self, num_workers=0):
    """Precompute the input sequence of every item and store it in place.

    Call this before handing the dataset to a data loader. Each entry of
    ``self.items`` has its first element overwritten with the computed
    sequence: a character/symbol sequence when ``self.use_phonemes`` is
    false, otherwise the result of ``_phoneme_worker``.

    Args:
        num_workers: number of worker processes used for phoneme
            computation. ``0`` runs everything in the current process.
            Ignored when ``self.use_phonemes`` is false.
    """
    if self.use_phonemes:
        # Positional arguments forwarded (after wav_file and text) to the
        # phoneme loader by `_phoneme_worker`.
        worker_args = [
            self.phoneme_cache_path,
            self.enable_eos_bos,
            self.cleaners,
            self.phoneme_language,
            self.tp,
            self.add_blank,
        ]
        if self.verbose:
            print(" | > Computing phonemes ...")

        if num_workers != 0:
            with Pool(num_workers) as pool:
                jobs = [[entry, worker_args] for entry in self.items]
                # imap yields results in input order, so position `pos`
                # in the results matches position `pos` in self.items.
                results = list(
                    tqdm.tqdm(
                        pool.imap(MyDataset._phoneme_worker, jobs),
                        total=len(self.items),
                    )
                )
                for pos, sequence in enumerate(results):
                    self.items[pos][0] = sequence
            return

        # Single-process path.
        for pos, entry in enumerate(tqdm.tqdm(self.items)):
            self.items[pos][0] = self._phoneme_worker([entry, worker_args])
        return

    if self.verbose:
        print(" | > Computing input sequences ...")
    for pos, entry in enumerate(tqdm.tqdm(self.items)):
        raw_text, *_ = entry
        symbol_ids = text_to_sequence(
            raw_text, [self.cleaners], tp=self.tp, add_blank=self.add_blank
        )
        self.items[pos][0] = np.asarray(symbol_ids, dtype=np.int32)
|
||||
|
||||
def sort_items(self):
|
||||
r"""Sort instances based on text length in ascending order"""
|
||||
lengths = np.array([len(ins[0]) for ins in self.items])
|
||||
|
|
Loading…
Reference in New Issue