@staticmethod
def _parse_sample(item):
    """Split a raw dataset sample into its named fields.

    Args:
        item (list or tuple): Sample holding, in order, ``text``, ``wav_file`` and
            ``speaker_name``, optionally followed by ``language_name`` and then
            ``attn_file``.

    Returns:
        tuple: ``(text, wav_file, speaker_name, language_name, attn_file)``. The
        optional fields are ``None`` when the sample does not provide them.

    Raises:
        ValueError: If the sample does not have 3, 4 or 5 fields.
    """
    num_fields = len(item)
    if not 3 <= num_fields <= 5:
        # Report the offending size so a misbehaving formatter is easy to track down.
        raise ValueError(
            f" [!] Dataset cannot parse the sample. Expected 3, 4 or 5 fields, got {num_fields}."
        )
    text, wav_file, speaker_name, *optional = item
    # Pad the optional tail with None so it always unpacks into exactly two names.
    language_name, attn_file = (*optional, None, None)[:2]
    return text, wav_file, speaker_name, language_name, attn_file