diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py
index 30ead9ad..613adf1c 100644
--- a/datasets/TTSDataset.py
+++ b/datasets/TTSDataset.py
@@ -26,6 +26,7 @@ class MyDataset(Dataset):
                  use_phonemes=True,
                  phoneme_cache_path=None,
                  phoneme_language="en-us",
+                 enable_eos_bos=False,
                  verbose=False):
         """
         Args:
@@ -48,6 +49,7 @@ class MyDataset(Dataset):
             phoneme_cache_path (str): path to cache phoneme features.
             phoneme_language (str): one the languages from
                 https://github.com/bootphon/phonemizer#languages
+            enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
             verbose (bool): print diagnostic information.
         """
         self.root_path = root_path
@@ -63,6 +65,7 @@ class MyDataset(Dataset):
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
+        self.enable_eos_bos = enable_eos_bos
         self.verbose = verbose
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)
@@ -98,13 +101,13 @@ class MyDataset(Dataset):
                 print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
                 text = np.asarray(
                     phoneme_to_sequence(
-                        text, [self.cleaners], language=self.phoneme_language),
+                        text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
                     dtype=np.int32)
                 np.save(tmp_path, text)
         else:
             text = np.asarray(
                 phoneme_to_sequence(
-                    text, [self.cleaners], language=self.phoneme_language),
+                    text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
                 dtype=np.int32)
             np.save(tmp_path, text)
         return text
diff --git a/tests/text_processing_tests.py b/tests/text_processing_tests.py
index 82052323..d991cee9 100644
--- a/tests/text_processing_tests.py
+++ b/tests/text_processing_tests.py
@@ -9,14 +9,14 @@ def test_phoneme_to_sequence():
     lang = "en-us"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    gt = "^ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
+    gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
     assert text_hat == gt

     # multiple punctuations
     text = "Be a voice, not an! echo?"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
     print(text_hat)
     print(len(sequence))
     assert text_hat == gt
@@ -25,7 +25,7 @@
     text = "Be a voice, not an! echo"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
     assert text_hat == gt
@@ -34,7 +34,7 @@
     text = "Be a voice, not an echo!"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
+    gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
     print(text_hat)
     print(len(sequence))
     assert text_hat == gt
@@ -43,7 +43,16 @@
     text = "Be a voice, not an! echo. "
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == gt
+
+    # extra space after the sentence
+    text = "Be a voice, not an! echo. "
+    sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
+    text_hat = sequence_to_phoneme(sequence)
+    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
     print(text_hat)
     print(len(sequence))
     assert text_hat == gt
diff --git a/train.py b/train.py
index 62b8f074..90427fbf 100644
--- a/train.py
+++ b/train.py
@@ -57,6 +57,7 @@ def setup_loader(is_val=False, verbose=False):
             phoneme_cache_path=c.phoneme_cache_path,
             use_phonemes=c.use_phonemes,
             phoneme_language=c.phoneme_language,
+            enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose)
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
diff --git a/utils/synthesis.py b/utils/synthesis.py
index 601d2a48..4f4386ff 100644
--- a/utils/synthesis.py
+++ b/utils/synthesis.py
@@ -8,7 +8,7 @@ from .visual import visualize
 from matplotlib import pylab as plt


-def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
+def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
     """Synthesize voice for the given text.

     Args:
@@ -20,30 +20,37 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
             model outputs.
         truncated (bool): keep model states after inference. It can be used
             for continuous inference at long texts.
+        enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
+        trim_silence (bool): trim silence after synthesis.
     """
+    # preprocess the given text
     text_cleaner = [CONFIG.text_cleaner]
     if CONFIG.use_phonemes:
         seq = np.asarray(
-            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
+            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
             dtype=np.int32)
     else:
         seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
     chars_var = torch.from_numpy(seq).unsqueeze(0)
+    # synthesize voice
     if use_cuda:
         chars_var = chars_var.cuda()
-    # chars_var = chars_var[:-1]
     if truncated:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
             chars_var.long())
     else:
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
             chars_var.long())
+    # convert outputs to numpy
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
+    # plot results
     if CONFIG.model == "Tacotron":
         wav = ap.inv_spectrogram(postnet_output.T)
     else:
         wav = ap.inv_mel_spectrogram(postnet_output.T)
-    # wav = wav[:ap.find_endpoint(wav)]
+    # trim silence
+    if trim_silence:
+        wav = wav[:ap.find_endpoint(wav)]
     return wav, alignment, decoder_output, postnet_output, stop_tokens
\ No newline at end of file
diff --git a/utils/text/__init__.py b/utils/text/__init__.py
index 297b4514..9c0e3f47 100644
--- a/utils/text/__init__.py
+++ b/utils/text/__init__.py
@@ -45,9 +45,11 @@ def text2phone(text, language):
     return ph


-def phoneme_to_sequence(text, cleaner_names, language):
-    # sequence = [_phonemes_to_id['^']]
-    sequence = []
+def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
+    if enable_eos_bos:
+        sequence = [_phonemes_to_id['^']]
+    else:
+        sequence = []
     clean_text = _clean_text(text, cleaner_names)
     phonemes = text2phone(clean_text, language)
     if phonemes is None:
@@ -56,7 +58,8 @@ def phoneme_to_sequence(text, cleaner_names, language):
     for phoneme in filter(None, phonemes.split('|')):
         sequence += _phoneme_to_sequence(phoneme)
     # Append EOS char
-    # sequence.append(_phonemes_to_id['~'])
+    if enable_eos_bos:
+        sequence.append(_phonemes_to_id['~'])
     return sequence


@@ -84,7 +87,6 @@
         List of integers corresponding to the symbols in the text
     '''
     sequence = []
-    # sequence = [_phonemes_to_id['^']]
     # Check for curly braces and treat their contents as ARPAbet:
     while len(text):
         m = _curly_re.match(text)
@@ -95,9 +97,6 @@
             _clean_text(m.group(1), cleaner_names))
         sequence += _arpabet_to_sequence(m.group(2))
         text = m.group(3)
-
-    # Append EOS token
-    # sequence.append(_symbol_to_id['~'])
     return sequence


diff --git a/utils/visual.py b/utils/visual.py
index 0df37815..b259bdd9 100644
--- a/utils/visual.py
+++ b/utils/visual.py
@@ -44,7 +44,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
     plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language)
+        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
         text = sequence_to_phoneme(seq)
         plt.yticks(range(len(text)), list(text))
     plt.colorbar()
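Below is a brief usage sketch (not part of the patch) showing the effect of the new enable_eos_bos flag on phoneme_to_sequence / sequence_to_phoneme from the patched utils/text/__init__.py. It assumes the phonemizer backend used by text2phone is installed; the "phoneme_cleaners" cleaner name is an assumption, substitute the text_cleaner from your own config. With the flag on, the decoded string should begin with '^' (BOS) and end with '~' (EOS), mirroring the new test case added above.

# Usage sketch only -- assumes the patched utils/text API from this diff.
# "phoneme_cleaners" is an assumed cleaner name; use your config's text_cleaner instead.
from utils.text import phoneme_to_sequence, sequence_to_phoneme

seq = phoneme_to_sequence("Be a voice, not an echo!",
                          ["phoneme_cleaners"],
                          language="en-us",
                          enable_eos_bos=True)
print(sequence_to_phoneme(seq))  # expected to begin with '^' and end with '~'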