mirror of https://github.com/coqui-ai/TTS.git
Make eos bos chars optional
This commit is contained in:
parent
b44aa872c7
commit
9466505f27
|
@ -26,6 +26,7 @@ class MyDataset(Dataset):
|
|||
use_phonemes=True,
|
||||
phoneme_cache_path=None,
|
||||
phoneme_language="en-us",
|
||||
enable_eos_bos=False,
|
||||
verbose=False):
|
||||
"""
|
||||
Args:
|
||||
|
@ -48,6 +49,7 @@ class MyDataset(Dataset):
|
|||
phoneme_cache_path (str): path to cache phoneme features.
|
||||
phoneme_language (str): one of the languages from
|
||||
https://github.com/bootphon/phonemizer#languages
|
||||
enable_eos_bos (bool): enable end-of-sentence and beginning-of-sentence characters.
|
||||
verbose (bool): print diagnostic information.
|
||||
"""
|
||||
self.root_path = root_path
|
||||
|
@ -63,6 +65,7 @@ class MyDataset(Dataset):
|
|||
self.use_phonemes = use_phonemes
|
||||
self.phoneme_cache_path = phoneme_cache_path
|
||||
self.phoneme_language = phoneme_language
|
||||
self.enable_eos_bos = enable_eos_bos
|
||||
self.verbose = verbose
|
||||
if use_phonemes and not os.path.isdir(phoneme_cache_path):
|
||||
os.makedirs(phoneme_cache_path, exist_ok=True)
|
||||
|
@ -98,13 +101,13 @@ class MyDataset(Dataset):
|
|||
print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
|
||||
text = np.asarray(
|
||||
phoneme_to_sequence(
|
||||
text, [self.cleaners], language=self.phoneme_language),
|
||||
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
|
||||
dtype=np.int32)
|
||||
np.save(tmp_path, text)
|
||||
else:
|
||||
text = np.asarray(
|
||||
phoneme_to_sequence(
|
||||
text, [self.cleaners], language=self.phoneme_language),
|
||||
text, [self.cleaners], language=self.phoneme_language, enable_eos_bos=self.enable_eos_bos),
|
||||
dtype=np.int32)
|
||||
np.save(tmp_path, text)
|
||||
return text
|
||||
|
|
|
@ -9,14 +9,14 @@ def test_phoneme_to_sequence():
|
|||
lang = "en-us"
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
|
||||
gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
|
||||
assert text_hat == gt
|
||||
|
||||
# multiple punctuations
|
||||
text = "Be a voice, not an! echo?"
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
|
||||
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
|
||||
print(text_hat)
|
||||
print(len(sequence))
|
||||
assert text_hat == gt
|
||||
|
@ -25,7 +25,7 @@ def test_phoneme_to_sequence():
|
|||
text = "Be a voice, not an! echo"
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
|
||||
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
|
||||
print(text_hat)
|
||||
print(len(sequence))
|
||||
assert text_hat == gt
|
||||
|
@ -34,7 +34,7 @@ def test_phoneme_to_sequence():
|
|||
text = "Be a voice, not an echo!"
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
|
||||
gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
|
||||
print(text_hat)
|
||||
print(len(sequence))
|
||||
assert text_hat == gt
|
||||
|
@ -43,7 +43,16 @@ def test_phoneme_to_sequence():
|
|||
text = "Be a voice, not an! echo. "
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
|
||||
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
|
||||
print(text_hat)
|
||||
print(len(sequence))
|
||||
assert text_hat == gt
|
||||
|
||||
# extra space after the sentence
|
||||
text = "Be a voice, not an! echo. "
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
|
||||
text_hat = sequence_to_phoneme(sequence)
|
||||
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
|
||||
print(text_hat)
|
||||
print(len(sequence))
|
||||
assert text_hat == gt
|
||||
|
|
1
train.py
1
train.py
|
@ -57,6 +57,7 @@ def setup_loader(is_val=False, verbose=False):
|
|||
phoneme_cache_path=c.phoneme_cache_path,
|
||||
use_phonemes=c.use_phonemes,
|
||||
phoneme_language=c.phoneme_language,
|
||||
enable_eos_bos=c.enable_eos_bos_chars,
|
||||
verbose=verbose)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
|
|
|
@ -8,7 +8,7 @@ from .visual import visualize
|
|||
from matplotlib import pylab as plt
|
||||
|
||||
|
||||
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
|
||||
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False, enable_eos_bos_chars=False, trim_silence=False):
|
||||
"""Synthesize voice for the given text.
|
||||
|
||||
Args:
|
||||
|
@ -20,30 +20,37 @@ def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
|
|||
model outputs.
|
||||
truncated (bool): keep model states after inference. It can be used
|
||||
for continuous inference at long texts.
|
||||
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
|
||||
trim_silence (bool): trim silence after synthesis.
|
||||
"""
|
||||
# preprocess the given text
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
if CONFIG.use_phonemes:
|
||||
seq = np.asarray(
|
||||
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
|
||||
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, enable_eos_bos_chars),
|
||||
dtype=np.int32)
|
||||
else:
|
||||
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||
# synthesize voice
|
||||
if use_cuda:
|
||||
chars_var = chars_var.cuda()
|
||||
# chars_var = chars_var[:-1]
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
chars_var.long())
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
chars_var.long())
|
||||
# convert outputs to numpy
|
||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||
alignment = alignments[0].cpu().data.numpy()
|
||||
# plot results
|
||||
if CONFIG.model == "Tacotron":
|
||||
wav = ap.inv_spectrogram(postnet_output.T)
|
||||
else:
|
||||
wav = ap.inv_mel_spectrogram(postnet_output.T)
|
||||
# wav = wav[:ap.find_endpoint(wav)]
|
||||
# trim silence
|
||||
if trim_silence:
|
||||
wav = wav[:ap.find_endpoint(wav)]
|
||||
return wav, alignment, decoder_output, postnet_output, stop_tokens
|
|
@ -45,9 +45,11 @@ def text2phone(text, language):
|
|||
return ph
|
||||
|
||||
|
||||
def phoneme_to_sequence(text, cleaner_names, language):
|
||||
# sequence = [_phonemes_to_id['^']]
|
||||
sequence = []
|
||||
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
|
||||
if enable_eos_bos:
|
||||
sequence = [_phonemes_to_id['^']]
|
||||
else:
|
||||
sequence = []
|
||||
clean_text = _clean_text(text, cleaner_names)
|
||||
phonemes = text2phone(clean_text, language)
|
||||
if phonemes is None:
|
||||
|
@ -56,7 +58,8 @@ def phoneme_to_sequence(text, cleaner_names, language):
|
|||
for phoneme in filter(None, phonemes.split('|')):
|
||||
sequence += _phoneme_to_sequence(phoneme)
|
||||
# Append EOS char
|
||||
# sequence.append(_phonemes_to_id['~'])
|
||||
if enable_eos_bos:
|
||||
sequence.append(_phonemes_to_id['~'])
|
||||
return sequence
|
||||
|
||||
|
||||
|
@ -84,7 +87,6 @@ def text_to_sequence(text, cleaner_names):
|
|||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = []
|
||||
# sequence = [_phonemes_to_id['^']]
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
while len(text):
|
||||
m = _curly_re.match(text)
|
||||
|
@ -95,9 +97,6 @@ def text_to_sequence(text, cleaner_names):
|
|||
_clean_text(m.group(1), cleaner_names))
|
||||
sequence += _arpabet_to_sequence(m.group(2))
|
||||
text = m.group(3)
|
||||
|
||||
# Append EOS token
|
||||
# sequence.append(_symbol_to_id['~'])
|
||||
return sequence
|
||||
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
|
|||
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
|
||||
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
|
||||
if CONFIG.use_phonemes:
|
||||
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language)
|
||||
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars)
|
||||
text = sequence_to_phoneme(seq)
|
||||
plt.yticks(range(len(text)), list(text))
|
||||
plt.colorbar()
|
||||
|
|
Loading…
Reference in New Issue