Merge branch 'save_characters' into dev

Eren Gölge 2021-02-15 12:07:28 +00:00
commit 7f58fa365b
15 changed files with 117 additions and 87 deletions

View File

@ -25,7 +25,7 @@ jobs:
- checkout
- run: |
sudo apt update
sudo apt install espeak git
sudo apt install espeak-ng git
- run: sudo pip install --upgrade pip
- run: sudo pip install -e .
- run: |

View File

@ -7,16 +7,15 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name
def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''
'''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\
'''
Example runs:
python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
''',
formatter_class=RawTextHelpFormatter)
''', formatter_class=RawTextHelpFormatter)
parser.add_argument(
'--dataset',
@ -36,13 +35,13 @@ def main():
preprocessor = get_preprocessor_by_name(args.dataset)
items = preprocessor(os.path.dirname(args.meta_file), os.path.basename(args.meta_file))
texts = " ".join([item[0] for item in items])
texts = "".join(item[0] for item in items)
chars = set(texts)
lower_chars = set(texts.lower())
lower_chars = filter(lambda c: c.islower(), chars)
print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}")
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
if __name__ == "__main__":
main()
main()
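The reworked script now joins the texts with no separator and, rather than lowercasing the whole character set, keeps only the characters that are actually lowercase in the data. A minimal sketch of the updated logic on hypothetical toy items (item[0] is assumed to be the text field of a preprocessor item):

# Sketch of the new unique-character logic on hypothetical toy items.
items = [("Hello world!", "wavs/a.wav"), ("Gölge test.", "wavs/b.wav")]
texts = "".join(item[0] for item in items)
chars = set(texts)                                  # every distinct character
lower_chars = filter(lambda c: c.islower(), chars)  # only truly lowercase ones
print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}")
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")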

View File

@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
# wait all kernels to be completed
@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH)
OUT_PATH, model_characters)
if __name__ == '__main__':

View File

@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
# wait all kernels to be completed
@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
global_step, epoch, c.r,
OUT_PATH)
OUT_PATH, model_characters)
if __name__ == '__main__':

View File

@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'],
characters=model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None)
# Diagnostic visualizations
@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
# Audio processor
ap = AudioProcessor(**c.audio)
# setup custom characters if set in config file.
if 'characters' in c.keys():
symbols, phonemes = make_symbols(**c.characters)
@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_characters = phonemes if c.use_phonemes else symbols
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name
epoch,
c.r,
OUT_PATH,
model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None
)

View File

@ -39,6 +39,7 @@ class MyDataset(Dataset):
compute_linear_spec (bool): compute linear spectrogram if True.
ap (TTS.tts.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
tp (dict): dict of custom text characters used for converting texts to sequences.
batch_group_size (int): (0) range of batch randomization after sorting
sequences by length.
min_seq_len (int): (0) minimum sequence length to be processed
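For orientation, a `tp` dict follows the shape that `parse_symbols()` (added later in this commit) returns; the values below are the library defaults from `symbols.py`, with the long IPA phoneme string elided:

# Hedged example of a custom-characters (`tp`) dict; keys mirror parse_symbols().
tp = {
    "pad": "_",
    "eos": "~",
    "bos": "^",
    "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
    "punctuations": "!'(),-.:;? ",
    "phonemes": "iyɨʉɯuɪʏʊ...",  # full IPA phoneme set elided for brevity
}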

View File

@ -38,7 +38,15 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
return model, state
def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
def save_model(model,
optimizer,
current_step,
epoch,
r,
output_path,
characters,
amp_state_dict=None,
**kwargs):
"""Save ```TTS.tts.models``` states with extra fields.
Args:
@ -48,6 +56,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
"""
if hasattr(model, 'module'):
@ -60,7 +69,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r
'r': r,
'characters': characters
}
if amp_state_dict:
state['amp'] = amp_state_dict
@ -68,7 +78,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
torch.save(state, output_path)
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
characters, **kwargs):
"""Save model checkpoint, intended for saving checkpoints at training.
Args:
@ -78,14 +89,16 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
"""
file_name = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print(" > CHECKPOINT : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs)
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
def save_best_model(target_loss, best_loss, model, optimizer, current_step,
epoch, r, output_folder, characters, **kwargs):
"""Save model checkpoint, intended for saving the best model after each epoch.
It compares the current model loss with the best loss so far and saves the
model if the current loss is better.
@ -99,6 +112,7 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
Returns:
float: updated current best loss.
@ -107,6 +121,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
file_name = 'best_model.pth.tar'
checkpoint_path = os.path.join(output_folder, file_name)
print(" >> BEST MODEL : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs)
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
best_loss = target_loss
return best_loss
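Because `characters` is now part of the saved state, the exact symbol set a model was trained with can be read straight back from the checkpoint file. A minimal sketch, assuming a checkpoint written by `save_checkpoint` above (the file name is hypothetical):

import torch

# Inspect a checkpoint produced by save_checkpoint(); the path is hypothetical.
state = torch.load("checkpoint_10000.pth.tar", map_location="cpu")
print(state["r"])            # reduction rate stored with the model
print(state["characters"])   # symbols/phonemes the model was trained with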

View File

@ -6,7 +6,7 @@ import phonemizer
from packaging import version
from phonemizer.phonemize import phonemize
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.symbols import (_bos, _eos, _phoneme_punctuations,
from TTS.tts.utils.text.symbols import (_bos, _eos, _punctuations,
make_symbols, phonemes, symbols)
@ -24,7 +24,7 @@ _phonemes = phonemes
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
# Regular expression matching punctuations, ignoring empty space
PHONEME_PUNCTUATION_PATTERN = r'['+_phoneme_punctuations+']+'
PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+'
def text2phone(text, language):

View File

@ -5,6 +5,8 @@ Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name
''' Function to create symbols and phonemes '''
_symbols = [pad, eos, bos] + list(characters)
@ -18,15 +20,13 @@ def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_'
_symbols += _arpabet
return _symbols, _phonemes
_pad = '_'
_eos = '~'
_bos = '^'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
_punctuations = '!\'(),-.:;? '
_phoneme_punctuations = '.!;:,?'
# Phonemes definition
# Phonemes definition (All IPA characters)
_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
_pulmonic_consonants = 'pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
@ -41,6 +41,16 @@ symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _e
# from random import shuffle
# shuffle(phonemes)
def parse_symbols():
return {'pad': _pad,
'eos': _eos,
'bos': _bos,
'characters': _characters,
'punctuations': _punctuations,
'phonemes': _phonemes}
if __name__ == '__main__':
print(" > TTS symbols {}".format(len(symbols)))
print(symbols)
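A short sketch of how the two helpers fit together: `make_symbols` builds the symbol and phoneme lists from a character set, and `parse_symbols` exports the current defaults as a dict. The character and phoneme strings below are illustrative only:

from TTS.tts.utils.text.symbols import make_symbols, parse_symbols

# Hypothetical custom sets; punctuations, pad, eos and bos fall back to the defaults.
custom_symbols, custom_phonemes = make_symbols("abcdefghijklmnopqrstuvwxyz ", phonemes="aeiou")
print(len(custom_symbols))   # vocabulary size, i.e. what num_chars is set to in training

# Export the default symbol definition, e.g. for storing it in a config file.
print(parse_symbols())       # {'pad': '_', 'eos': '~', 'bos': '^', ...}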

View File

@ -3,17 +3,16 @@
"""Argument parser for training scripts."""
import argparse
import re
import glob
import os
from TTS.utils.generic_utils import (
create_experiment_folder, get_git_branch)
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
import re
from TTS.tts.utils.generic_utils import check_config_tts
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
def parse_arguments(argv):
@ -110,38 +109,27 @@ def get_last_checkpoint(path):
def process_args(args, model_type):
"""Process parsed comand line arguments.
Parameters
----------
args : argparse.Namespace or dict like
Parsed input arguments.
model_type : str
Model type used to check config parameters and setup the TensorBoard
logger. One of:
- tacotron
- glow_tts
- speedy_speech
- gan
- wavegrad
- wavernn
Args:
args (argparse.Namespace or dict like): Parsed input arguments.
model_type (str): Model type used to check config parameters and setup the TensorBoard
logger. One of:
- tacotron
- glow_tts
- speedy_speech
- gan
- wavegrad
- wavernn
Raises
------
ValueError
If `model_type` is not one of implemented choices.
Returns
-------
c : TTS.utils.io.AttrDict
Config paramaters.
out_path : str
Path to save models and logging.
audio_path : str
Path to save generated test audios.
c_logger : TTS.utils.console_logger.ConsoleLogger
Class that does logging to the console.
tb_logger : TTS.utils.tensorboard.TensorboardLogger
Class that does the TensorBoard loggind.
Raises:
ValueError
If `model_type` is not one of the implemented choices.
Returns:
c (TTS.utils.io.AttrDict): Config parameters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard logging.
"""
if args.continue_path != "":
args.output_path = args.continue_path
@ -156,7 +144,6 @@ def process_args(args, model_type):
# setup output paths and read configs
c = load_config(args.config_path)
if model_type in "tacotron glow_tts speedy_speech":
model_class = "TTS"
elif model_type in "gan wavegrad wavernn":
@ -192,6 +179,12 @@ def process_args(args, model_type):
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if model_class == 'TTS' and not 'characters' in c:
used_characters = parse_symbols()
new_fields['characters'] = used_characters
copy_model_files(c, args.config_path,
out_path, new_fields)
os.chmod(audio_path, 0o775)
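This makes later runs reproducible: once the copied config carries a `characters` block, the training scripts shown earlier can rebuild the exact symbol set with `make_symbols(**c.characters)`. A hedged sketch, with a hypothetical path to the copied config:

from TTS.utils.io import load_config
from TTS.tts.utils.text.symbols import make_symbols

# Hypothetical path to a config copied into an experiment folder by process_args().
c = load_config("output/my_run/config.json")
if 'characters' in c.keys():
    # Same call the training scripts use to restore the stored symbol set.
    symbols, phonemes = make_symbols(**c.characters)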

View File

@ -67,7 +67,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
if isinstance(value, str):
new_line = '"{}":"{}",\n'.format(key, value)
else:
new_line = '"{}":{},\n'.format(key, value)
new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
config_lines.insert(1, new_line)
config_out_file = open(copy_config_path, "w")
config_out_file.writelines(config_lines)
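Plain `str.format` on a dict or list would emit Python-style single quotes, which is not valid JSON, and the default `ensure_ascii=True` would escape the IPA phonemes; `json.dumps(value, ensure_ascii=False)` avoids both. A small illustration:

import json

value = {"pad": "_", "phonemes": "ɑɒ"}
print('"{}":{},'.format("characters", value))
# "characters":{'pad': '_', 'phonemes': 'ɑɒ'},   <- single quotes, not valid JSON
print('"{}":{},'.format("characters", json.dumps(value, ensure_ascii=False)))
# "characters":{"pad": "_", "phonemes": "ɑɒ"},   <- valid JSON, IPA kept readable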

View File

@ -128,7 +128,8 @@ class ModelManager(object):
"""Download files from GDrive using their file ids"""
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
def _download_zip_file(self, file_url, output):
@staticmethod
def _download_zip_file(file_url, output):
"""Download the target zip file and extract the files
to a folder with the same name as the zip file."""
r = requests.get(file_url)

View File

@ -1,11 +1,11 @@
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak
dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak-ng
import torch
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.manage import ModelManager
def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', use_cuda=False):
def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name=None, use_cuda=False):
"""TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text.
Example:
@ -15,7 +15,7 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
Args:
model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'.
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'.
vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/multiband-melgan'.
pretrained (bool, optional): [description]. Defaults to True.
Returns:
@ -23,8 +23,9 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
"""
manager = ModelManager()
model_path, config_path = manager.download_model(model_name)
vocoder_path, vocoder_config_path = manager.download_model(vocoder_name)
model_path, config_path, model_item = manager.download_model(model_name)
vocoder_name = model_item['default_vocoder'] if vocoder_name is None else vocoder_name
vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)
# create synthesizer
synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, use_cuda)
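With `vocoder_name=None`, the hub entry point now resolves the matching vocoder from the model's `default_vocoder` entry. A hedged usage sketch via torch.hub; the hub repository string is an assumption, not taken from this commit:

import torch

# Assumed hub repo path; substitute the repository that actually publishes this hubconf.py.
synthesizer = torch.hub.load("coqui-ai/TTS", "tts",
                             model_name="tts_models/en/ljspeech/tacotron2-DCA")
wav = synthesizer.tts("The default vocoder is picked automatically.")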

View File

@ -21,7 +21,7 @@ class DemoServerTest(unittest.TestCase):
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
model = setup_model(num_chars, 0, config)
output_path = os.path.join(get_tests_output_path())
save_checkpoint(model, None, 10, 10, 1, output_path)
save_checkpoint(model, None, 10, 10, 1, output_path, None)
def test_in_out(self):
self._create_random_model()

View File

@ -19,7 +19,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
gt = 'ɹiːsənt ɹᵻːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!'
assert text_hat == text_hat_with_params == gt
# multiple punctuations
@ -28,7 +28,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -39,7 +39,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -61,7 +61,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ."
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -72,7 +72,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
gt = "^biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.~"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -83,7 +83,7 @@ def test_phoneme_to_sequence():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -97,7 +97,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
gt = "ɹiːsənt ɹːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡleɪʃən ænd lɜːnɪŋ!"
assert text_hat == text_hat_with_params == gt
# multiple punctuations
@ -106,7 +106,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?'
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -117,7 +117,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ'
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -128,7 +128,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
gt = 'biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!'
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -139,7 +139,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.'
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -150,7 +150,7 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
gt = "^biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.~"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
@ -161,14 +161,14 @@ def test_phoneme_to_sequence_with_blank_token():
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
def test_text2phone():
text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
gt = 'ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!'
lang = "en-us"
ph = text2phone(text, lang)
assert gt == ph