mirror of https://github.com/coqui-ai/TTS.git
Make eos bos chars optional
This commit is contained in:
parent
b44aa872c7
commit
9466505f27
|
@ -26,6 +26,7 @@ class MyDataset(Dataset):
|
||||||
use_phonemes=True,
|
use_phonemes=True,
|
||||||
phoneme_cache_path=None,
|
phoneme_cache_path=None,
|
||||||
phoneme_language="en-us",
|
phoneme_language="en-us",
|
||||||
|
enable_eos_bos=False,
|
||||||
verbose=False):
|
verbose=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -48,6 +49,7 @@ class MyDataset(Dataset):
|
||||||
phoneme_cache_path (str): path to cache phoneme features.
|
phoneme_cache_path (str): path to cache phoneme features.
|
||||||
phoneme_language (str): one the languages from
|
phoneme_language (str): one the languages from
|
||||||
https://github.com/bootphon/phonemizer#languages
|
https://github.com/bootphon/phonemizer#languages
|
||||||
|
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
|
||||||
verbose (bool): print diagnostic information.
|
verbose (bool): print diagnostic information.
|
||||||
"""
|
"""
|
||||||
self.root_path = root_path
|
self.root_path = root_path
|
||||||
|
@ -63,6 +65,7 @@ class MyDataset(Dataset):
|
||||||
self.use_phonemes = use_phonemes
|
self.use_phonemes = use_phonemes
|
||||||
self.phoneme_cache_path = phoneme_cache_path
|
self.phoneme_cache_path = phoneme_cache_path
|
||||||
self.phoneme_language = phoneme_language
|
self.phoneme_language = phoneme_language
|
||||||
|
self.enable_eos_bos = enable_eos_bos
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||||
|
@ -98,13 +101,13 @@ class MyDataset(Dataset):
|
||||||
print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
|
print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
|
||||||
text = np.asarray(
|
text = np.asarray(
|
||||||
phoneme_to_sequence(
|
phoneme_to_sequence(
|
||||||
text, [self.cleaners], language=self.phoneme_language),
|
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
np.save(tmp_path, text)
|
np.save(tmp_path, text)
|
||||||
else:
|
else:
|
||||||
text = np.asarray(
|
text = np.asarray(
|
||||||
phoneme_to_sequence(
|
phoneme_to_sequence(
|
||||||
text, [self.cleaners], language=self.phoneme_language),
|
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
np.save(tmp_path, text)
|
np.save(tmp_path, text)
|
||||||
return text
|
return text
|
||||||
|
|
|
@ -9,14 +9,14 @@ def test_phoneme_to_sequence():
|
||||||
lang = "en-us"
|
lang = "en-us"
|
||||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||||
text_hat = sequence_to_phoneme(sequence)
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
gt = "^ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
|
gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
|
||||||
assert text_hat == gt
|
assert text_hat == gt
|
||||||
|
|
||||||
# multiple punctuations
|
# multiple punctuations
|
||||||
text = "Be a voice, not an! echo?"
|
text = "Be a voice, not an! echo?"
|
||||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||||
text_hat = sequence_to_phoneme(sequence)
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
|
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
|
||||||
print(text_hat)
|
print(text_hat)
|
||||||
print(len(sequence))
|
print(len(sequence))
|
||||||
assert text_hat == gt
|
assert text_hat == gt
|
||||||
|
@ -25,7 +25,7 @@ def test_phoneme_to_sequence():
|
||||||
text = "Be a voice, not an! echo"
|
text = "Be a voice, not an! echo"
|
||||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||||
text_hat = sequence_to_phoneme(sequence)
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
|
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
|
||||||
print(text_hat)
|
print(text_hat)
|
||||||
print(len(sequence))
|
print(len(sequence))
|
||||||
assert text_hat == gt
|
assert text_hat == gt
|
||||||
|
@ -34,7 +34,7 @@ def test_phoneme_to_sequence():
|
||||||
text = "Be a voice, not an echo!"
|
text = "Be a voice, not an echo!"
|
||||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||||
text_hat = sequence_to_phoneme(sequence)
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
|
gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
|
||||||
print(text_hat)
|
print(text_hat)
|
||||||
print(len(sequence))
|
print(len(sequence))
|
||||||
assert text_hat == gt
|
assert text_hat == gt
|
||||||
|
@ -43,7 +43,16 @@ def test_phoneme_to_sequence():
|
||||||
text = "Be a voice, not an! echo. "
|
text = "Be a voice, not an! echo. "
|
||||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||||
text_hat = sequence_to_phoneme(sequence)
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
|
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
|
||||||
|
print(text_hat)
|
||||||
|
print(len(sequence))
|
||||||
|
assert text_hat == gt
|
||||||
|
|
||||||
|
# extra space after the sentence
|
||||||
|
text = "Be a voice, not an! echo. "
|
||||||
|
sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
|
||||||
|
text_hat = sequence_to_phoneme(sequence)
|
||||||
|
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
|
||||||
print(text_hat)
|
print(text_hat)
|
||||||
print(len(sequence))
|
print(len(sequence))
|
||||||
assert text_hat == gt
|
assert text_hat == gt
|
||||||
|
|
1
train.py
1
train.py
|
@ -57,6 +57,7 @@ def setup_loader(is_val=False, verbose=False):
|
||||||
phoneme_cache_path=c.phoneme_cache_path,
|
phoneme_cache_path=c.phoneme_cache_path,
|
||||||
use_phonemes=c.use_phonemes,
|
use_phonemes=c.use_phonemes,
|
||||||
phoneme_language=c.phoneme_language,
|
phoneme_language=c.phoneme_language,
|
||||||
|
enable_eos_bos=c.enable_eos_bos_chars,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
|
|
|
@ -8,7 +8,7 @@ from .visual import visualize
|
||||||
from matplotlib import pylab as plt
|
from matplotlib import pylab as plt
|
||||||
|
|
||||||
|
|
||||||
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
|
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
|
||||||
"""Synthesize voice for the given text.
|
"""Synthesize voice for the given text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -20,30 +20,37 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
|
||||||
model outputs.
|
model outputs.
|
||||||
truncated (bool): keep model states after inference. It can be used
|
truncated (bool): keep model states after inference. It can be used
|
||||||
for continuous inference at long texts.
|
for continuous inference at long texts.
|
||||||
|
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||||
|
trim_silence (bool): trim silence after synthesis.
|
||||||
"""
|
"""
|
||||||
|
# preprocess the given text
|
||||||
text_cleaner = [CONFIG.text_cleaner]
|
text_cleaner = [CONFIG.text_cleaner]
|
||||||
if CONFIG.use_phonemes:
|
if CONFIG.use_phonemes:
|
||||||
seq = np.asarray(
|
seq = np.asarray(
|
||||||
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
|
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
||||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||||
|
# synthesize voice
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
chars_var = chars_var.cuda()
|
chars_var = chars_var.cuda()
|
||||||
# chars_var = chars_var[:-1]
|
|
||||||
if truncated:
|
if truncated:
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||||
chars_var.long())
|
chars_var.long())
|
||||||
else:
|
else:
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
chars_var.long())
|
chars_var.long())
|
||||||
|
# convert outputs to numpy
|
||||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||||
alignment = alignments[0].cpu().data.numpy()
|
alignment = alignments[0].cpu().data.numpy()
|
||||||
|
# plot results
|
||||||
if CONFIG.model == "Tacotron":
|
if CONFIG.model == "Tacotron":
|
||||||
wav = ap.inv_spectrogram(postnet_output.T)
|
wav = ap.inv_spectrogram(postnet_output.T)
|
||||||
else:
|
else:
|
||||||
wav = ap.inv_mel_spectrogram(postnet_output.T)
|
wav = ap.inv_mel_spectrogram(postnet_output.T)
|
||||||
# wav = wav[:ap.find_endpoint(wav)]
|
# trim silence
|
||||||
|
if trim_silence:
|
||||||
|
wav = wav[:ap.find_endpoint(wav)]
|
||||||
return wav, alignment, decoder_output, postnet_output, stop_tokens
|
return wav, alignment, decoder_output, postnet_output, stop_tokens
|
|
@ -45,8 +45,10 @@ def text2phone(text, language):
|
||||||
return ph
|
return ph
|
||||||
|
|
||||||
|
|
||||||
def phoneme_to_sequence(text, cleaner_names, language):
|
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
|
||||||
# sequence = [_phonemes_to_id['^']]
|
if enable_eos_bos:
|
||||||
|
sequence = [_phonemes_to_id['^']]
|
||||||
|
else:
|
||||||
sequence = []
|
sequence = []
|
||||||
clean_text = _clean_text(text, cleaner_names)
|
clean_text = _clean_text(text, cleaner_names)
|
||||||
phonemes = text2phone(clean_text, language)
|
phonemes = text2phone(clean_text, language)
|
||||||
|
@ -56,7 +58,8 @@ def phoneme_to_sequence(text, cleaner_names, language):
|
||||||
for phoneme in filter(None, phonemes.split('|')):
|
for phoneme in filter(None, phonemes.split('|')):
|
||||||
sequence += _phoneme_to_sequence(phoneme)
|
sequence += _phoneme_to_sequence(phoneme)
|
||||||
# Append EOS char
|
# Append EOS char
|
||||||
# sequence.append(_phonemes_to_id['~'])
|
if enable_eos_bos:
|
||||||
|
sequence.append(_phonemes_to_id['~'])
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,7 +87,6 @@ def text_to_sequence(text, cleaner_names):
|
||||||
List of integers corresponding to the symbols in the text
|
List of integers corresponding to the symbols in the text
|
||||||
'''
|
'''
|
||||||
sequence = []
|
sequence = []
|
||||||
# sequence = [_phonemes_to_id['^']]
|
|
||||||
# Check for curly braces and treat their contents as ARPAbet:
|
# Check for curly braces and treat their contents as ARPAbet:
|
||||||
while len(text):
|
while len(text):
|
||||||
m = _curly_re.match(text)
|
m = _curly_re.match(text)
|
||||||
|
@ -95,9 +97,6 @@ def text_to_sequence(text, cleaner_names):
|
||||||
_clean_text(m.group(1), cleaner_names))
|
_clean_text(m.group(1), cleaner_names))
|
||||||
sequence += _arpabet_to_sequence(m.group(2))
|
sequence += _arpabet_to_sequence(m.group(2))
|
||||||
text = m.group(3)
|
text = m.group(3)
|
||||||
|
|
||||||
# Append EOS token
|
|
||||||
# sequence.append(_symbol_to_id['~'])
|
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
|
||||||
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
|
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
|
||||||
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
|
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
|
||||||
if CONFIG.use_phonemes:
|
if CONFIG.use_phonemes:
|
||||||
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language)
|
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
|
||||||
text = sequence_to_phoneme(seq)
|
text = sequence_to_phoneme(seq)
|
||||||
plt.yticks(range(len(text)), list(text))
|
plt.yticks(range(len(text)), list(text))
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
|
|
Loading…
Reference in New Issue