make static methods for faster multiprocess call

This commit is contained in:
erogol 2020-12-07 11:29:10 +01:00
parent affe1c1138
commit 20c86489d7
1 changed files with 29 additions and 20 deletions

View File

@ -1,12 +1,16 @@
import os
import numpy as np
import collections import collections
import torch import os
import random import random
from torch.utils.data import Dataset from multiprocessing import Manager, Pool
from TTS.tts.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos import numpy as np
from TTS.tts.utils.data import prepare_data, prepare_tensor, prepare_stop_target import torch
import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import (prepare_data, prepare_stop_target,
prepare_tensor)
from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence,
text_to_sequence)
class MyDataset(Dataset): class MyDataset(Dataset):
@ -82,35 +86,40 @@ class MyDataset(Dataset):
data = np.load(filename).astype('float32') data = np.load(filename).astype('float32')
return data return data
def _generate_and_cache_phoneme_sequence(self, text, cache_path): @staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
"""generate a phoneme sequence from text. """generate a phoneme sequence from text.
since the usage is for subsequent caching, we never add bos and since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later; based on the eos chars here. Instead we add those dynamically later; based on the
config option.""" config option."""
phonemes = phoneme_to_sequence(text, [self.cleaners], phonemes = phoneme_to_sequence(text, [cleaners],
language=self.phoneme_language, language=language,
enable_eos_bos=False, enable_eos_bos=False,
tp=self.tp, add_blank=self.add_blank) tp=tp, add_blank=add_blank)
phonemes = np.asarray(phonemes, dtype=np.int32) phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes) np.save(cache_path, phonemes)
return phonemes return phonemes
def _load_or_generate_phoneme_sequence(self, wav_file, text): @staticmethod
def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank):
file_name = os.path.splitext(os.path.basename(wav_file))[0] file_name = os.path.splitext(os.path.basename(wav_file))[0]
cache_path = os.path.join(self.phoneme_cache_path,
file_name + '_phoneme.npy') # different names for normal phonemes and with blank chars.
file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy'
cache_path = os.path.join(phoneme_cache_path,
file_name + file_name_ext)
try: try:
phonemes = np.load(cache_path) phonemes = np.load(cache_path)
except FileNotFoundError: except FileNotFoundError:
phonemes = self._generate_and_cache_phoneme_sequence( phonemes = MyDataset._generate_and_cache_phoneme_sequence(
text, cache_path) text, cache_path, cleaners, language, tp, add_blank)
except (ValueError, IOError): except (ValueError, IOError):
print(" > ERROR: failed loading phonemes for {}. " print(" [!] failed loading phonemes for {}. "
"Recomputing.".format(wav_file)) "Recomputing.".format(wav_file))
phonemes = self._generate_and_cache_phoneme_sequence( phonemes = MyDataset._generate_and_cache_phoneme_sequence(
text, cache_path) text, cache_path, cleaners, language, tp, add_blank)
if self.enable_eos_bos: if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=self.tp) phonemes = pad_with_eos_bos(phonemes, tp=tp)
phonemes = np.asarray(phonemes, dtype=np.int32) phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes return phonemes