mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'save_characters' into dev
commit 7f58fa365b
@@ -25,7 +25,7 @@ jobs:
       - checkout
       - run: |
           sudo apt update
-          sudo apt install espeak git
+          sudo apt install espeak-ng git
       - run: sudo pip install --upgrade pip
       - run: sudo pip install -e .
       - run: |

@@ -7,16 +7,15 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name


 def main():
     # pylint: disable=bad-continuation
     parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''

                                      '''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\

                                      '''
     Example runs:

     python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
-    ''',
-    formatter_class=RawTextHelpFormatter)
+    ''', formatter_class=RawTextHelpFormatter)

     parser.add_argument(
         '--dataset',
@@ -36,13 +35,13 @@ def main():

     preprocessor = get_preprocessor_by_name(args.dataset)
     items = preprocessor(os.path.dirname(args.meta_file), os.path.basename(args.meta_file))
-    texts = " ".join([item[0] for item in items])
+    texts = "".join(item[0] for item in items)
     chars = set(texts)
-    lower_chars = set(texts.lower())
+    lower_chars = filter(lambda c: c.islower(), chars)
     print(f" > Number of unique characters: {len(chars)}")
     print(f" > Unique characters: {''.join(sorted(chars))}")
     print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")


 if __name__ == "__main__":
-    main()
+    main()

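Note on the change above: switching from set(texts.lower()) to filter(lambda c: c.islower(), chars) changes what the last report means. It now lists only the characters that actually occur in lowercase form in the data, instead of lowercasing everything first. A minimal sketch of the difference (illustrative only, not part of the commit):

    # Illustrative sketch: old vs. new meaning of "unique lower characters".
    texts = "Hello World"
    chars = set(texts)

    old_lower_chars = set(texts.lower())                    # lowercases first, so 'H'/'W' show up as 'h'/'w'
    new_lower_chars = filter(lambda c: c.islower(), chars)  # keeps only chars already lowercase in the data

    print(''.join(sorted(old_lower_chars)))  # " dehlorw"
    print(''.join(sorted(new_lower_chars)))  # "delor"
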
@@ -268,7 +268,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
         if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
+                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                                 model_loss=loss_dict['loss'])

         # wait all kernels to be completed
@@ -467,7 +467,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -477,7 +477,10 @@ def main(args): # pylint: disable=redefined-outer-name
     if num_gpus > 1:
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
-    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -559,7 +562,7 @@ def main(args): # pylint: disable=redefined-outer-name
     if c.run_eval:
         target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, model_characters)


 if __name__ == '__main__':

@@ -247,7 +247,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
         if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
+                save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                                 model_loss=loss_dict['loss'])

         # wait all kernels to be completed
@@ -431,7 +431,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
 # FIXME: move args definition/parsing inside of main?
 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
     # Audio processor
     ap = AudioProcessor(**c.audio)
     if 'characters' in c.keys():
@@ -441,7 +441,10 @@ def main(args): # pylint: disable=redefined-outer-name
     if num_gpus > 1:
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
-    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+
+    # set model characters
+    model_characters = phonemes if c.use_phonemes else symbols
+    num_chars = len(model_characters)

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
@@ -523,7 +526,7 @@ def main(args): # pylint: disable=redefined-outer-name
         target_loss = eval_avg_loss_dict['avg_loss']
         best_loss = save_best_model(target_loss, best_loss, model, optimizer,
                                     global_step, epoch, c.r,
-                                    OUT_PATH)
+                                    OUT_PATH, model_characters)


 if __name__ == '__main__':

@@ -284,6 +284,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                 save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                 optimizer_st=optimizer_st,
                                 model_loss=loss_dict['postnet_loss'],
+                                characters=model_characters,
                                 scaler=scaler.state_dict() if c.mixed_precision else None)

         # Diagnostic visualizations
@@ -492,9 +493,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
+    global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
     # Audio processor
     ap = AudioProcessor(**c.audio)
+
+    # setup custom characters if set in config file.
     if 'characters' in c.keys():
         symbols, phonemes = make_symbols(**c.characters)

@@ -503,6 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+    model_characters = phonemes if c.use_phonemes else symbols

     # load data instances
     meta_data_train, meta_data_eval = load_meta_data(c.datasets)
@@ -634,6 +638,7 @@ def main(args): # pylint: disable=redefined-outer-name
                                     epoch,
                                     c.r,
                                     OUT_PATH,
+                                    model_characters,
                                     scaler=scaler.state_dict() if c.mixed_precision else None
                                     )

@ -39,6 +39,7 @@ class MyDataset(Dataset):
|
|||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
tp (dict): dict of custom text characters used for converting texts to sequences.
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
sequences by length.
|
||||
min_seq_len (int): (0) minimum sequence length to be processed
|
||||
|
|
|
@@ -38,7 +38,15 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
     return model, state


-def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
+def save_model(model,
+               optimizer,
+               current_step,
+               epoch,
+               r,
+               output_path,
+               characters,
+               amp_state_dict=None,
+               **kwargs):
     """Save ```TTS.tts.models``` states with extra fields.

     Args:
@@ -48,6 +56,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.
         amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
     """
     if hasattr(model, 'module'):
@@ -60,7 +69,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
         'step': current_step,
         'epoch': epoch,
         'date': datetime.date.today().strftime("%B %d, %Y"),
-        'r': r
+        'r': r,
+        'characters': characters
     }
     if amp_state_dict:
         state['amp'] = amp_state_dict
@@ -68,7 +78,8 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
     torch.save(state, output_path)


-def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
+                    characters, **kwargs):
     """Save model checkpoint, intended for saving checkpoints at training.

     Args:
@@ -78,14 +89,16 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.
     """
     file_name = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs)
+    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)


-def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+def save_best_model(target_loss, best_loss, model, optimizer, current_step,
+                    epoch, r, output_folder, characters, **kwargs):
     """Save model checkpoint, intended for saving the best model after each epoch.
     It compares the current model loss with the best loss so far and saves the
     model if the current loss is better.
@@ -99,6 +112,7 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
         epoch (int): current number of training epochs.
         r (int): model reduction rate for Tacotron models.
         output_path (str): output path to save the model file.
+        characters (list): list of characters used in the model.

     Returns:
         float: updated current best loss.
@@ -107,6 +121,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoc
     file_name = 'best_model.pth.tar'
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" >> BEST MODEL : {}".format(checkpoint_path))
-    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs)
+    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
     best_loss = target_loss
     return best_loss

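With a 'characters' field now stored in the checkpoint dict, a restored model can recover the symbol set it was trained with even if the config lacks a 'characters' entry. A minimal sketch of the idea (illustrative; the file name is hypothetical, only the 'characters' key is what this commit adds):

    # Illustrative sketch: reading the character set back from a saved checkpoint.
    import torch

    state = torch.load("checkpoint_10000.pth.tar", map_location="cpu")  # hypothetical path
    print(state['r'])            # reduction rate, saved as before
    print(state['characters'])   # new field: the symbols/phonemes used at training time
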
@@ -6,7 +6,7 @@ import phonemizer
 from packaging import version
 from phonemizer.phonemize import phonemize
 from TTS.tts.utils.text import cleaners
-from TTS.tts.utils.text.symbols import (_bos, _eos, _phoneme_punctuations,
+from TTS.tts.utils.text.symbols import (_bos, _eos, _punctuations,
                                         make_symbols, phonemes, symbols)


@@ -24,7 +24,7 @@ _phonemes = phonemes
 _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')

 # Regular expression matching punctuations, ignoring empty space
-PHONEME_PUNCTUATION_PATTERN = r'['+_phoneme_punctuations+']+'
+PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+'


 def text2phone(text, language):

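Since the dedicated _phoneme_punctuations constant is dropped, the punctuation regex is now derived from the general _punctuations string with the space removed. A quick sketch of the resulting pattern (the value of _punctuations is taken from symbols.py as shown below; the comparison itself is illustrative):

    # Illustrative sketch: what the derived punctuation pattern expands to.
    _punctuations = '!\'(),-.:;? '
    PHONEME_PUNCTUATION_PATTERN = r'[' + _punctuations.replace(' ', '') + ']+'
    print(PHONEME_PUNCTUATION_PATTERN)   # [!'(),-.:;?]+  -- broader than the old '.!;:,?' set
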
@@ -5,6 +5,8 @@ Defines the set of symbols used in text input to the model.
 The default is a set of ASCII characters that works well for English or text that has been run
 through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
 '''
+
+
 def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name
     ''' Function to create symbols and phonemes '''
     _symbols = [pad, eos, bos] + list(characters)
@@ -18,15 +20,13 @@ def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_'
         _symbols += _arpabet
     return _symbols, _phonemes


 _pad = '_'
 _eos = '~'
 _bos = '^'
 _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
 _punctuations = '!\'(),-.:;? '
-_phoneme_punctuations = '.!;:,?'

-# Phonemes definition
+# Phonemes definition (All IPA characters)
 _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
 _non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
 _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
@@ -41,6 +41,16 @@ symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _e
 # from random import shuffle
 # shuffle(phonemes)

+
+def parse_symbols():
+    return {'pad': _pad,
+            'eos': _eos,
+            'bos': _bos,
+            'characters': _characters,
+            'punctuations': _punctuations,
+            'phonemes': _phonemes}
+
+
 if __name__ == '__main__':
     print(" > TTS symbols {}".format(len(symbols)))
     print(symbols)

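The new parse_symbols() helper simply packages the module-level defaults into a dict, which process_args later writes into the copied training config. A rough sketch of how that dict is used (illustrative; the actual call site is in TTS/utils/arguments.py, shown further down):

    # Illustrative sketch: what parse_symbols() returns and how it seeds the config copy.
    from TTS.tts.utils.text.symbols import parse_symbols

    used_characters = parse_symbols()
    # keys: 'pad', 'eos', 'bos', 'characters', 'punctuations', 'phonemes'
    new_fields = {'characters': used_characters}   # persisted for future compatibility
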
@@ -3,17 +3,16 @@
 """Argument parser for training scripts."""

 import argparse
+import re
 import glob
 import os

-from TTS.utils.generic_utils import (
-    create_experiment_folder, get_git_branch)
-from TTS.utils.console_logger import ConsoleLogger
-from TTS.utils.io import copy_model_files, load_config
-from TTS.utils.tensorboard_logger import TensorboardLogger
-import re
-
 from TTS.tts.utils.generic_utils import check_config_tts
+from TTS.tts.utils.text.symbols import parse_symbols
+from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
+from TTS.utils.io import copy_model_files, load_config
+from TTS.utils.tensorboard_logger import TensorboardLogger


 def parse_arguments(argv):

@@ -110,38 +109,27 @@ def get_last_checkpoint(path):
 def process_args(args, model_type):
     """Process parsed comand line arguments.

-    Parameters
-    ----------
-    args : argparse.Namespace or dict like
-        Parsed input arguments.
-    model_type : str
-        Model type used to check config parameters and setup the TensorBoard
-        logger. One of:
-            - tacotron
-            - glow_tts
-            - speedy_speech
-            - gan
-            - wavegrad
-            - wavernn
+    Args:
+        args (argparse.Namespace or dict like): Parsed input arguments.
+        model_type (str): Model type used to check config parameters and setup the TensorBoard
+            logger. One of:
+                - tacotron
+                - glow_tts
+                - speedy_speech
+                - gan
+                - wavegrad
+                - wavernn

-    Raises
-    ------
-    ValueError
-        If `model_type` is not one of implemented choices.
+    Raises:
+        ValueError
+            If `model_type` is not one of implemented choices.

-    Returns
-    -------
-    c : TTS.utils.io.AttrDict
-        Config paramaters.
-    out_path : str
-        Path to save models and logging.
-    audio_path : str
-        Path to save generated test audios.
-    c_logger : TTS.utils.console_logger.ConsoleLogger
-        Class that does logging to the console.
-    tb_logger : TTS.utils.tensorboard.TensorboardLogger
-        Class that does the TensorBoard loggind.
+    Returns:
+        c (TTS.utils.io.AttrDict): Config paramaters.
+        out_path (str): Path to save models and logging.
+        audio_path (str): Path to save generated test audios.
+        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console.
+        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind.
     """
     if args.continue_path != "":
         args.output_path = args.continue_path
@@ -156,7 +144,6 @@ def process_args(args, model_type):

     # setup output paths and read configs
     c = load_config(args.config_path)
-
     if model_type in "tacotron glow_tts speedy_speech":
         model_class = "TTS"
     elif model_type in "gan wavegrad wavernn":
@@ -192,6 +179,12 @@ def process_args(args, model_type):
         if args.restore_path:
             new_fields["restore_path"] = args.restore_path
         new_fields["github_branch"] = get_git_branch()
+        # if model characters are not set in the config file
+        # save the default set to the config file for future
+        # compatibility.
+        if model_class == 'TTS' and not 'characters' in c:
+            used_characters = parse_symbols()
+            new_fields['characters'] = used_characters
         copy_model_files(c, args.config_path,
                          out_path, new_fields)
         os.chmod(audio_path, 0o775)

@@ -67,7 +67,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
         if isinstance(value, str):
             new_line = '"{}":"{}",\n'.format(key, value)
         else:
-            new_line = '"{}":{},\n'.format(key, value)
+            new_line = '"{}":{},\n'.format(key, json.dumps(value, ensure_ascii=False))
         config_lines.insert(1, new_line)
     config_out_file = open(copy_config_path, "w")
     config_out_file.writelines(config_lines)

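Formatting non-string values with plain str.format writes Python repr output (single quotes) into the copied config, which is not valid JSON; json.dumps fixes that, and ensure_ascii=False keeps the IPA phoneme characters readable. A small sketch of the difference with made-up values:

    # Illustrative sketch: plain formatting vs. json.dumps for the new 'characters' field.
    import json

    value = {'pad': '_', 'phonemes': 'ɑɐɒæ'}
    print('"{}":{},'.format('characters', value))
    # "characters":{'pad': '_', 'phonemes': 'ɑɐɒæ'},       <- single quotes, not valid JSON
    print('"{}":{},'.format('characters', json.dumps(value, ensure_ascii=False)))
    # "characters":{"pad": "_", "phonemes": "ɑɐɒæ"},       <- valid JSON, IPA kept as-is
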
@@ -128,7 +128,8 @@ class ModelManager(object):
         """Download files from GDrive using their file ids"""
         gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)

-    def _download_zip_file(self, file_url, output):
+    @staticmethod
+    def _download_zip_file(file_url, output):
         """Download the target zip file and extract the files
         to a folder with the same name as the zip file."""
         r = requests.get(file_url)

hubconf.py
@@ -1,11 +1,11 @@
-dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak
+dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak-ng
 import torch

 from TTS.utils.synthesizer import Synthesizer
 from TTS.utils.manage import ModelManager


-def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', use_cuda=False):
+def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name=None, use_cuda=False):
     """TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text.

     Example:
@@ -15,7 +15,7 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder

     Args:
         model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'.
-        vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'.
+        vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/multiband-melgan'.
         pretrained (bool, optional): [description]. Defaults to True.

     Returns:
@@ -23,8 +23,9 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder
     """
     manager = ModelManager()

-    model_path, config_path = manager.download_model(model_name)
-    vocoder_path, vocoder_config_path = manager.download_model(vocoder_name)
+    model_path, config_path, model_item = manager.download_model(model_name)
+    vocoder_name = model_item['default_vocoder'] if vocoder_name is None else vocoder_name
+    vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

     # create synthesizer
     synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, use_cuda)

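With vocoder_name defaulting to None, the vocoder is now resolved from the model's default_vocoder entry in the models file instead of being hard-coded. A hedged usage sketch (the 'tts' entry point follows the hubconf above; the exact repo/branch string and output handling are assumptions, not part of this commit):

    # Illustrative sketch: let the hub entry point pick the model's default vocoder.
    import torch

    synthesizer = torch.hub.load('coqui-ai/TTS', 'tts',
                                 model_name='tts_models/en/ljspeech/tacotron2-DCA')  # vocoder_name omitted
    wav = synthesizer.tts("This is a test.")
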
@@ -21,7 +21,7 @@ class DemoServerTest(unittest.TestCase):
         num_chars = len(phonemes) if config.use_phonemes else len(symbols)
         model = setup_model(num_chars, 0, config)
         output_path = os.path.join(get_tests_output_path())
-        save_checkpoint(model, None, 10, 10, 1, output_path)
+        save_checkpoint(model, None, 10, 10, 1, output_path, None)

     def test_in_out(self):
         self._create_random_model()

@@ -19,7 +19,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
+    gt = 'ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!'
     assert text_hat == text_hat_with_params == gt

     # multiple punctuations
@@ -28,7 +28,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -39,7 +39,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -61,7 +61,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ."
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -72,7 +72,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
+    gt = "^biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.~"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -83,7 +83,7 @@ def test_phoneme_to_sequence():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -97,7 +97,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
+    gt = "ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!"
     assert text_hat == text_hat_with_params == gt

     # multiple punctuations
@@ -106,7 +106,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
+    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?'
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -117,7 +117,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ'
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -128,7 +128,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
+    gt = 'biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!'
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -139,7 +139,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
+    gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.'
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -150,7 +150,7 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
+    gt = "^biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.~"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
@@ -161,14 +161,14 @@ def test_phoneme_to_sequence_with_blank_token():
     text_hat = sequence_to_phoneme(sequence)
     _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
-    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt

 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
-    gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
+    gt = 'ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!'
     lang = "en-us"
     ph = text2phone(text, lang)
     assert gt == ph