Make eos bos chars optional

This commit is contained in:
Eren Golge 2019-04-12 16:12:15 +02:00
parent b44aa872c7
commit 9466505f27
6 changed files with 39 additions and 20 deletions

View File

@ -26,6 +26,7 @@ class MyDataset(Dataset):
use_phonemes=True, use_phonemes=True,
phoneme_cache_path=None, phoneme_cache_path=None,
phoneme_language="en-us", phoneme_language="en-us",
enable_eos_bos=False,
verbose=False): verbose=False):
""" """
Args: Args:
@ -48,6 +49,7 @@ class MyDataset(Dataset):
phoneme_cache_path (str): path to cache phoneme features. phoneme_cache_path (str): path to cache phoneme features.
phoneme_language (str): one the languages from phoneme_language (str): one the languages from
https://github.com/bootphon/phonemizer#languages https://github.com/bootphon/phonemizer#languages
            enable_eos_bos (bool): enable end of sentence and beginning of sentence characters.
verbose (bool): print diagnostic information. verbose (bool): print diagnostic information.
""" """
self.root_path = root_path self.root_path = root_path
@ -63,6 +65,7 @@ class MyDataset(Dataset):
self.use_phonemes = use_phonemes self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.verbose = verbose self.verbose = verbose
if use_phonemes and not os.path.isdir(phoneme_cache_path): if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True) os.makedirs(phoneme_cache_path, exist_ok=True)
@ -98,13 +101,13 @@ class MyDataset(Dataset):
                print(" > ERROR: phoneme cannot be loaded for {}. Recomputing.".format(wav_file)) print(" > ERROR: phoneme cannot be loaded for {}. Recomputing.".format(wav_file))
text = np.asarray( text = np.asarray(
phoneme_to_sequence( phoneme_to_sequence(
text, [self.cleaners], language=self.phoneme_language), text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
dtype=np.int32) dtype=np.int32)
np.save(tmp_path, text) np.save(tmp_path, text)
else: else:
text = np.asarray( text = np.asarray(
phoneme_to_sequence( phoneme_to_sequence(
text, [self.cleaners], language=self.phoneme_language), text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
dtype=np.int32) dtype=np.int32)
np.save(tmp_path, text) np.save(tmp_path, text)
return text return text

View File

@ -9,14 +9,14 @@ def test_phoneme_to_sequence():
lang = "en-us" lang = "en-us"
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence) text_hat = sequence_to_phoneme(sequence)
gt = "^ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
assert text_hat == gt assert text_hat == gt
# multiple punctuations # multiple punctuations
text = "Be a voice, not an! echo?" text = "Be a voice, not an! echo?"
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence) text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
print(text_hat) print(text_hat)
print(len(sequence)) print(len(sequence))
assert text_hat == gt assert text_hat == gt
@ -25,7 +25,7 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an! echo" text = "Be a voice, not an! echo"
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence) text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
print(text_hat) print(text_hat)
print(len(sequence)) print(len(sequence))
assert text_hat == gt assert text_hat == gt
@ -34,7 +34,7 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an echo!" text = "Be a voice, not an echo!"
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence) text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
print(text_hat) print(text_hat)
print(len(sequence)) print(len(sequence))
assert text_hat == gt assert text_hat == gt
@ -43,7 +43,16 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an! echo. " text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang) sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence) text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
print(text_hat)
print(len(sequence))
assert text_hat == gt
    # enable beginning and end of sentence chars
text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
print(text_hat) print(text_hat)
print(len(sequence)) print(len(sequence))
assert text_hat == gt assert text_hat == gt

View File

@ -57,6 +57,7 @@ def setup_loader(is_val=False, verbose=False):
phoneme_cache_path=c.phoneme_cache_path, phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes, use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language, phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose) verbose=verbose)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader( loader = DataLoader(

View File

@ -8,7 +8,7 @@ from .visual import visualize
from matplotlib import pylab as plt from matplotlib import pylab as plt
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False): def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
"""Synthesize voice for the given text. """Synthesize voice for the given text.
Args: Args:
@ -20,30 +20,37 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
model outputs. model outputs.
truncated (bool): keep model states after inference. It can be used truncated (bool): keep model states after inference. It can be used
for continuous inference at long texts. for continuous inference at long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
trim_silence (bool): trim silence after synthesis.
""" """
# preprocess the given text
text_cleaner = [CONFIG.text_cleaner] text_cleaner = [CONFIG.text_cleaner]
if CONFIG.use_phonemes: if CONFIG.use_phonemes:
seq = np.asarray( seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language), phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
dtype=np.int32) dtype=np.int32)
else: else:
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
chars_var = torch.from_numpy(seq).unsqueeze(0) chars_var = torch.from_numpy(seq).unsqueeze(0)
# synthesize voice
if use_cuda: if use_cuda:
chars_var = chars_var.cuda() chars_var = chars_var.cuda()
# chars_var = chars_var[:-1]
if truncated: if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
chars_var.long()) chars_var.long())
else: else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference( decoder_output, postnet_output, alignments, stop_tokens = model.inference(
chars_var.long()) chars_var.long())
# convert outputs to numpy
postnet_output = postnet_output[0].data.cpu().numpy() postnet_output = postnet_output[0].data.cpu().numpy()
decoder_output = decoder_output[0].data.cpu().numpy() decoder_output = decoder_output[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy() alignment = alignments[0].cpu().data.numpy()
# plot results
if CONFIG.model == "Tacotron": if CONFIG.model == "Tacotron":
wav = ap.inv_spectrogram(postnet_output.T) wav = ap.inv_spectrogram(postnet_output.T)
else: else:
wav = ap.inv_mel_spectrogram(postnet_output.T) wav = ap.inv_mel_spectrogram(postnet_output.T)
# wav = wav[:ap.find_endpoint(wav)] # trim silence
if trim_silence:
wav = wav[:ap.find_endpoint(wav)]
return wav, alignment, decoder_output, postnet_output, stop_tokens return wav, alignment, decoder_output, postnet_output, stop_tokens

View File

@ -45,8 +45,10 @@ def text2phone(text, language):
return ph return ph
def phoneme_to_sequence(text, cleaner_names, language): def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
# sequence = [_phonemes_to_id['^']] if enable_eos_bos:
sequence = [_phonemes_to_id['^']]
else:
sequence = [] sequence = []
clean_text = _clean_text(text, cleaner_names) clean_text = _clean_text(text, cleaner_names)
phonemes = text2phone(clean_text, language) phonemes = text2phone(clean_text, language)
@ -56,7 +58,8 @@ def phoneme_to_sequence(text, cleaner_names, language):
for phoneme in filter(None, phonemes.split('|')): for phoneme in filter(None, phonemes.split('|')):
sequence += _phoneme_to_sequence(phoneme) sequence += _phoneme_to_sequence(phoneme)
# Append EOS char # Append EOS char
# sequence.append(_phonemes_to_id['~']) if enable_eos_bos:
sequence.append(_phonemes_to_id['~'])
return sequence return sequence
@ -84,7 +87,6 @@ def text_to_sequence(text, cleaner_names):
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
''' '''
sequence = [] sequence = []
# sequence = [_phonemes_to_id['^']]
# Check for curly braces and treat their contents as ARPAbet: # Check for curly braces and treat their contents as ARPAbet:
while len(text): while len(text):
m = _curly_re.match(text) m = _curly_re.match(text)
@ -95,9 +97,6 @@ def text_to_sequence(text, cleaner_names):
_clean_text(m.group(1), cleaner_names)) _clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2)) sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3) text = m.group(3)
# Append EOS token
# sequence.append(_symbol_to_id['~'])
return sequence return sequence

View File

@ -44,7 +44,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
plt.xlabel("Decoder timestamp", fontsize=label_fontsize) plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
plt.ylabel("Encoder timestamp", fontsize=label_fontsize) plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
if CONFIG.use_phonemes: if CONFIG.use_phonemes:
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language) seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
text = sequence_to_phoneme(seq) text = sequence_to_phoneme(seq)
plt.yticks(range(len(text)), list(text)) plt.yticks(range(len(text)), list(text))
plt.colorbar() plt.colorbar()