Make eos bos chars optional

This commit is contained in:
Eren Golge 2019-04-12 16:12:15 +02:00
parent b44aa872c7
commit 9466505f27
6 changed files with 39 additions and 20 deletions

View File

@ -26,6 +26,7 @@ class MyDataset(Dataset):
use_phonemes=True,
phoneme_cache_path=None,
phoneme_language="en-us",
enable_eos_bos=False,
verbose=False):
"""
Args:
@ -48,6 +49,7 @@ class MyDataset(Dataset):
phoneme_cache_path (str): path to cache phoneme features.
phoneme_language (str): one of the languages from
https://github.com/bootphon/phonemizer#languages
enable_eos_bos (bool): enable end-of-sentence and beginning-of-sentence characters.
verbose (bool): print diagnostic information.
"""
self.root_path = root_path
@ -63,6 +65,7 @@ class MyDataset(Dataset):
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
self.enable_eos_bos = enable_eos_bos
self.verbose = verbose
if use_phonemes and not os.path.isdir(phoneme_cache_path):
os.makedirs(phoneme_cache_path, exist_ok=True)
@ -98,13 +101,13 @@ class MyDataset(Dataset):
print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
text = np.asarray(
phoneme_to_sequence(
text, [self.cleaners], language=self.phoneme_language),
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
dtype=np.int32)
np.save(tmp_path, text)
else:
text = np.asarray(
phoneme_to_sequence(
text, [self.cleaners], language=self.phoneme_language),
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
dtype=np.int32)
np.save(tmp_path, text)
return text

View File

@ -9,14 +9,14 @@ def test_phoneme_to_sequence():
lang = "en-us"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
gt = "^ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
assert text_hat == gt
# multiple punctuations
text = "Be a voice, not an! echo?"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
print(text_hat)
print(len(sequence))
assert text_hat == gt
@ -25,7 +25,7 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an! echo"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == gt
@ -34,7 +34,7 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an echo!"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
print(text_hat)
print(len(sequence))
assert text_hat == gt
@ -43,7 +43,16 @@ def test_phoneme_to_sequence():
text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
print(text_hat)
print(len(sequence))
assert text_hat == gt
# extra space after the sentence, with BOS ('^') and EOS ('~') chars enabled
text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
text_hat = sequence_to_phoneme(sequence)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
print(text_hat)
print(len(sequence))
assert text_hat == gt

View File

@ -57,6 +57,7 @@ def setup_loader(is_val=False, verbose=False):
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(

View File

@ -8,7 +8,7 @@ from .visual import visualize
from matplotlib import pylab as plt
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
"""Synthesize voice for the given text.
Args:
@ -20,30 +20,37 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
model outputs.
truncated (bool): keep model states after inference. It can be used
for continuous inference at long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
trim_silence (bool): trim silence after synthesis.
"""
# preprocess the given text
text_cleaner = [CONFIG.text_cleaner]
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
dtype=np.int32)
else:
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
chars_var = torch.from_numpy(seq).unsqueeze(0)
# synthesize voice
if use_cuda:
chars_var = chars_var.cuda()
# chars_var = chars_var[:-1]
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
chars_var.long())
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
chars_var.long())
# convert outputs to numpy
postnet_output = postnet_output[0].data.cpu().numpy()
decoder_output = decoder_output[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
# convert the predicted spectrogram to a waveform
if CONFIG.model == "Tacotron":
wav = ap.inv_spectrogram(postnet_output.T)
else:
wav = ap.inv_mel_spectrogram(postnet_output.T)
# wav = wav[:ap.find_endpoint(wav)]
# trim silence
if trim_silence:
wav = wav[:ap.find_endpoint(wav)]
return wav, alignment, decoder_output, postnet_output, stop_tokens

View File

@ -45,9 +45,11 @@ def text2phone(text, language):
return ph
def phoneme_to_sequence(text, cleaner_names, language):
# sequence = [_phonemes_to_id['^']]
sequence = []
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
if enable_eos_bos:
sequence = [_phonemes_to_id['^']]
else:
sequence = []
clean_text = _clean_text(text, cleaner_names)
phonemes = text2phone(clean_text, language)
if phonemes is None:
@ -56,7 +58,8 @@ def phoneme_to_sequence(text, cleaner_names, language):
for phoneme in filter(None, phonemes.split('|')):
sequence += _phoneme_to_sequence(phoneme)
# Append EOS char
# sequence.append(_phonemes_to_id['~'])
if enable_eos_bos:
sequence.append(_phonemes_to_id['~'])
return sequence
@ -84,7 +87,6 @@ def text_to_sequence(text, cleaner_names):
List of integers corresponding to the symbols in the text
'''
sequence = []
# sequence = [_phonemes_to_id['^']]
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
@ -95,9 +97,6 @@ def text_to_sequence(text, cleaner_names):
_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
# sequence.append(_symbol_to_id['~'])
return sequence

View File

@ -44,7 +44,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
if CONFIG.use_phonemes:
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language)
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
text = sequence_to_phoneme(seq)
plt.yticks(range(len(text)), list(text))
plt.colorbar()