muliprocess phoneme computation

This commit is contained in:
erogol 2020-12-07 11:29:41 +01:00
parent 20c86489d7
commit 7505c0ba27
1 changed files with 42 additions and 5 deletions

View File

@ -134,11 +134,13 @@ class MyDataset(Dataset):
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(wav_file, text)
else:
text = np.asarray(text_to_sequence(text, [self.cleaners],
tp=self.tp, add_blank=self.add_blank),
if not self.input_seq_computed:
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(wav_file, text, self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, self.tp, self.add_blank)
else:
text = np.asarray(text_to_sequence(text, [self.cleaners],
tp=self.tp, add_blank=self.add_blank),
dtype=np.int32)
assert text.size > 0, self.items[idx][1]
@ -163,6 +165,41 @@ class MyDataset(Dataset):
}
return sample
@staticmethod
def _phoneme_worker(args):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
def compute_input_seq(self, num_workers=0):
"""compute input sequences separately. Call it before
passing dataset to data loader."""
if not self.use_phonemes:
if self.verbose:
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(text_to_sequence(text, [self.cleaners],
tp=self.tp, add_blank=self.add_blank),
dtype=np.int32)
self.items[idx][0] = sequence
else:
func_args = [self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, self.phoneme_language, self.tp, self.add_blank]
if self.verbose:
print(" | > Computing phonemes ...")
if num_workers == 0:
for idx, item in enumerate(tqdm.tqdm(self.items)):
phonemes = self._phoneme_worker([item, func_args])
self.items[idx][0] = phonemes
else:
with Pool(num_workers) as p:
phonemes = list(tqdm.tqdm(p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]), total=len(self.items)))
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
def sort_items(self):
r"""Sort instances based on text length in ascending order"""
lengths = np.array([len(ins[0]) for ins in self.items])