diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 5cd23ce4..9db2381e 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- +"""Train Glow TTS model.""" -import argparse -import glob import os import sys import time @@ -14,10 +12,12 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler + +from TTS.utils.arguments import parse_arguments, process_args from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import GlowTTSLoss -from TTS.tts.utils.generic_utils import check_config_tts, setup_model +from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import parse_speakers @@ -25,18 +25,15 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -from TTS.utils.console_logger import ConsoleLogger from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.generic_utils import (KeepAverage, count_parameters, - create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from TTS.utils.io import copy_model_files, load_config from TTS.utils.radam import RAdam -from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import NoamLR, setup_torch_training_env use_cuda, num_gpus = setup_torch_training_env(True, False) + def setup_loader(ap, r, is_val=False, verbose=False): if is_val and not c.run_eval: loader = None @@ -468,7 +465,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, 
epoch): return keep_avg.avg_values -# FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping @@ -567,81 +563,9 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--continue_path', - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) - parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. Use to finetune a model.', - default='') - parser.add_argument( - '--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv - ) - parser.add_argument('--debug', - type=bool, - default=False, - help='Do not verify commit integrity to run training.') - - # DISTRUBUTED - parser.add_argument( - '--rank', - type=int, - default=0, - help='DISTRIBUTED: process rank for distributed training.') - parser.add_argument('--group_id', - type=str, - default="", - help='DISTRIBUTED: process group id.') - args = parser.parse_args() - - if args.continue_path != '': - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, 'config.json') - list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - print(f" > Training continues for {args.restore_path}") - - # setup output paths and read configs - c = load_config(args.config_path) - # check_config(c) - check_config_tts(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - if c.mixed_precision: - print(" > Mixed precision enabled.") - - OUT_PATH = args.continue_path - 
if args.continue_path == '': - OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) - - AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(c, args.config_path, OUT_PATH, new_fields) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS') - - # write model desc to tensorboard - tb_logger.tb_add_text('model-description', c['run_description'], 0) + args = parse_arguments(sys.argv) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, model_type='glow_tts') try: main(args) diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py index 667f5abd..a9a83bbf 100644 --- a/TTS/bin/train_speedy_speech.py +++ b/TTS/bin/train_speedy_speech.py @@ -11,6 +11,7 @@ import numpy as np from random import randrange import torch +from TTS.utils.arguments import parse_arguments, process_args # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader @@ -18,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import SpeedySpeechLoss -from TTS.tts.utils.generic_utils import check_config_tts, setup_model +from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import parse_speakers @@ -26,14 +27,10 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import 
plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -from TTS.utils.console_logger import ConsoleLogger from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.generic_utils import (KeepAverage, count_parameters, - create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from TTS.utils.io import copy_model_files, load_config from TTS.utils.radam import RAdam -from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import NoamLR, setup_torch_training_env use_cuda, num_gpus = setup_torch_training_env(True, False) @@ -524,86 +521,15 @@ def main(args): # pylint: disable=redefined-outer-name target_loss = train_avg_loss_dict['avg_loss'] if c.run_eval: target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, + best_loss = save_best_model(target_loss, best_loss, model, optimizer, + global_step, epoch, c.r, OUT_PATH) if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--continue_path', - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) - parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. 
Use to finetune a model.', - default='') - parser.add_argument( - '--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv - ) - parser.add_argument('--debug', - type=bool, - default=False, - help='Do not verify commit integrity to run training.') - - # DISTRUBUTED - parser.add_argument( - '--rank', - type=int, - default=0, - help='DISTRIBUTED: process rank for distributed training.') - parser.add_argument('--group_id', - type=str, - default="", - help='DISTRIBUTED: process group id.') - args = parser.parse_args() - - if args.continue_path != '': - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, 'config.json') - list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - print(f" > Training continues for {args.restore_path}") - - # setup output paths and read configs - c = load_config(args.config_path) - # check_config(c) - check_config_tts(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - if c.mixed_precision: - print(" > Mixed precision enabled.") - - OUT_PATH = args.continue_path - if args.continue_path == '': - OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) - - AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(c, args.config_path, OUT_PATH, new_fields) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS') - - # write model desc to tensorboard - tb_logger.tb_add_text('model-description', c['run_description'], 0) + args = 
parse_arguments(sys.argv) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, model_type='tts') try: main(args) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index 4640a3eb..0a53f2a1 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- +"""Trains Tacotron based TTS models.""" -import argparse -import glob import os import sys import time @@ -12,10 +10,11 @@ from random import randrange import numpy as np import torch from torch.utils.data import DataLoader +from TTS.utils.arguments import parse_arguments, process_args from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import TacotronLoss -from TTS.tts.utils.generic_utils import check_config_tts, setup_model +from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import parse_speakers @@ -23,15 +22,11 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -from TTS.utils.console_logger import ConsoleLogger from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) from TTS.utils.generic_utils import (KeepAverage, count_parameters, - create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from TTS.utils.io import copy_model_files, load_config from TTS.utils.radam import RAdam -from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, gradual_training_scheduler, set_weight_decay, setup_torch_training_env) @@ -61,7 +56,11 @@ def setup_loader(ap, r, 
is_val=False, verbose=False, dataset=None): phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + speaker_mapping=(speaker_mapping if ( + c.use_speaker_embedding + and c.use_external_speaker_embedding_file + ) else None) + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -491,7 +490,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): return keep_avg.avg_values -# FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping @@ -636,84 +634,14 @@ def main(args): # pylint: disable=redefined-outer-name epoch, c.r, OUT_PATH, - scaler=scaler.state_dict() if c.mixed_precision else None) + scaler=scaler.state_dict() if c.mixed_precision else None + ) if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--continue_path', - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) - parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. 
Use to finetune a model.', - default='') - parser.add_argument( - '--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv - ) - parser.add_argument('--debug', - type=bool, - default=False, - help='Do not verify commit integrity to run training.') - - # DISTRUBUTED - parser.add_argument( - '--rank', - type=int, - default=0, - help='DISTRIBUTED: process rank for distributed training.') - parser.add_argument('--group_id', - type=str, - default="", - help='DISTRIBUTED: process group id.') - args = parser.parse_args() - - if args.continue_path != '': - print(f" > Training continues for {args.continue_path}") - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, 'config.json') - list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - - # setup output paths and read configs - c = load_config(args.config_path) - check_config_tts(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - if c.mixed_precision: - print(" > Mixed precision mode is ON") - - OUT_PATH = args.continue_path - if args.continue_path == '': - OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) - - AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(c, args.config_path, OUT_PATH, new_fields) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS') - - # write model desc to tensorboard - tb_logger.tb_add_text('model-description', c['run_description'], 0) + args = parse_arguments(sys.argv) + c, 
OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, model_type='tacotron') try: main(args) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index a1d1b322..1f2beb70 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -1,5 +1,6 @@ -import argparse -import glob +#!/usr/bin/env python3 +"""Trains GAN based vocoder model.""" + import os import sys import time @@ -8,14 +9,13 @@ from inspect import signature import torch from torch.utils.data import DataLoader +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor -from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import (KeepAverage, count_parameters, - create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from TTS.utils.io import copy_model_files, load_config + from TTS.utils.radam import RAdam -from TTS.utils.tensorboard_logger import TensorboardLogger + from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data @@ -439,7 +439,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) return keep_avg.avg_values -# FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined global train_data, eval_data @@ -506,7 +505,7 @@ def main(args): # pylint: disable=redefined-outer-name scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) scheduler_disc.optimizer = optimizer_disc except RuntimeError: - # retore only matching layers. + # restore only matching layers. 
print(" > Partial model initialization...") model_dict = model_gen.state_dict() model_dict = set_init_dict(model_dict, checkpoint['model'], c) @@ -556,7 +555,8 @@ def main(args): # pylint: disable=redefined-outer-name model_disc, criterion_disc, optimizer_disc, scheduler_gen, scheduler_disc, ap, global_step, epoch) - eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, + eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, + criterion_disc, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict[c.target_loss] @@ -575,78 +575,9 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--continue_path', - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) - parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. 
Use to finetune a model.', - default='') - parser.add_argument('--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv) - parser.add_argument('--debug', - type=bool, - default=False, - help='Do not verify commit integrity to run training.') - - # DISTRUBUTED - parser.add_argument( - '--rank', - type=int, - default=0, - help='DISTRIBUTED: process rank for distributed training.') - parser.add_argument('--group_id', - type=str, - default="", - help='DISTRIBUTED: process group id.') - args = parser.parse_args() - - if args.continue_path != '': - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, 'config.json') - list_of_files = glob.glob( - args.continue_path + - "/*.pth.tar") # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - print(f" > Training continues for {args.restore_path}") - - # setup output paths and read configs - c = load_config(args.config_path) - # check_config(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - OUT_PATH = args.continue_path - if args.continue_path == '': - OUT_PATH = create_experiment_folder(c.output_path, c.run_name, - args.debug) - - AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(c, args.config_path, OUT_PATH, new_fields) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER') - - # write model desc to tensorboard - tb_logger.tb_add_text('model-description', c['run_description'], 0) + args = parse_arguments(sys.argv) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, 
model_type='gan') try: main(args) diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index c53612c2..d8dc88e1 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -1,5 +1,6 @@ -import argparse -import glob +#!/usr/bin/env python3 +"""Trains WaveGrad vocoder models.""" + import os import sys import time @@ -12,14 +13,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor -from TTS.utils.console_logger import ConsoleLogger from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import (KeepAverage, count_parameters, - create_experiment_folder, get_git_branch, remove_experiment_folder, set_init_dict) -from TTS.utils.io import copy_model_files, load_config -from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset @@ -54,6 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False): if is_val else c.num_loader_workers, pin_memory=False) + return loader @@ -195,18 +194,19 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - scaler=scaler.state_dict() - if c.mixed_precision else None) + save_checkpoint( + model, + optimizer, + scheduler, + None, + None, + None, + global_step, + epoch, + OUT_PATH, + model_losses=loss_dict, + scaler=scaler.state_dict() if c.mixed_precision else None + ) end_time = time.time() @@ 
-247,6 +247,7 @@ def evaluate(model, criterion, ap, global_step, epoch): else: noise, x_noisy, noise_scale = model.compute_y_n(x) + # forward pass noise_hat = model(x_noisy, m, noise_scale) @@ -254,6 +255,7 @@ def evaluate(model, criterion, ap, global_step, epoch): loss = criterion(noise, noise_hat) loss_wavegrad_dict = {'wavegrad_loss': loss} + loss_dict = dict() for key, value in loss_wavegrad_dict.items(): if isinstance(value, (int, float)): @@ -415,87 +417,14 @@ def main(args): # pylint: disable=redefined-outer-name epoch, OUT_PATH, model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None) + scaler=scaler.state_dict() if c.mixed_precision else None + ) if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--continue_path', - type=str, - help= - 'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) - parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. 
Use to finetune a model.', - default='') - parser.add_argument('--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv) - parser.add_argument('--debug', - type=bool, - default=False, - help='Do not verify commit integrity to run training.') - - # DISTRUBUTED - parser.add_argument( - '--rank', - type=int, - default=0, - help='DISTRIBUTED: process rank for distributed training.') - parser.add_argument('--group_id', - type=str, - default="", - help='DISTRIBUTED: process group id.') - args = parser.parse_args() - - if args.continue_path != '': - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, 'config.json') - list_of_files = glob.glob( - args.continue_path + - "/*.pth.tar") # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - print(f" > Training continues for {args.restore_path}") - - # setup output paths and read configs - c = load_config(args.config_path) - # check_config(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - # DISTRIBUTED - if c.mixed_precision: - print(" > Mixed precision is enabled") - - OUT_PATH = args.continue_path - if args.continue_path == '': - OUT_PATH = create_experiment_folder(c.output_path, c.run_name, - args.debug) - - AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files(c, args.config_path, OUT_PATH, new_fields) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER') - - # write model desc to tensorboard - tb_logger.tb_add_text('model-description', c['run_description'], 0) + args = 
parse_arguments(sys.argv) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, model_type='wavegrad') try: main(args) diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py index 7056b74a..b4ffe143 100644 --- a/TTS/bin/train_vocoder_wavernn.py +++ b/TTS/bin/train_vocoder_wavernn.py @@ -1,9 +1,10 @@ -import argparse +#!/usr/bin/env python3 +"""Train WaveRNN vocoder model.""" + import os import sys import traceback import time -import glob import random import torch @@ -11,18 +12,14 @@ from torch.utils.data import DataLoader # from torch.utils.data.distributed import DistributedSampler +from TTS.utils.arguments import parse_arguments, process_args from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.radam import RAdam -from TTS.utils.io import copy_model_files, load_config from TTS.utils.training import setup_torch_training_env -from TTS.utils.console_logger import ConsoleLogger -from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.generic_utils import ( KeepAverage, count_parameters, - create_experiment_folder, - get_git_branch, remove_experiment_folder, set_init_dict, ) @@ -181,18 +178,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None - ) + save_checkpoint( + model, + optimizer, + scheduler, + None, + None, + None, + global_step, + epoch, + OUT_PATH, + model_losses=loss_dict, + scaler=scaler.state_dict() if c.mixed_precision else None + ) # synthesize a full voice rand_idx = random.randrange(0, len(train_data)) @@ -448,87 +446,9 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": - parser = argparse.ArgumentParser() - 
parser.add_argument( - "--continue_path", - type=str, - help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default="", - required="--config_path" not in sys.argv, - ) - parser.add_argument( - "--restore_path", - type=str, - help="Model file to be restored. Use to finetune a model.", - default="", - ) - parser.add_argument( - "--config_path", - type=str, - help="Path to config file for training.", - required="--continue_path" not in sys.argv, - ) - parser.add_argument( - "--debug", - type=bool, - default=False, - help="Do not verify commit integrity to run training.", - ) - - # DISTRUBUTED - parser.add_argument( - "--rank", - type=int, - default=0, - help="DISTRIBUTED: process rank for distributed training.", - ) - parser.add_argument( - "--group_id", type=str, default="", help="DISTRIBUTED: process group id." - ) - args = parser.parse_args() - - if args.continue_path != "": - args.output_path = args.continue_path - args.config_path = os.path.join(args.continue_path, "config.json") - list_of_files = glob.glob( - args.continue_path + "/*.pth.tar" - ) # * means all if need specific format then *.csv - latest_model_file = max(list_of_files, key=os.path.getctime) - args.restore_path = latest_model_file - print(f" > Training continues for {args.restore_path}") - - # setup output paths and read configs - c = load_config(args.config_path) - # check_config(c) - _ = os.path.dirname(os.path.realpath(__file__)) - - OUT_PATH = args.continue_path - if args.continue_path == "": - OUT_PATH = create_experiment_folder( - c.output_path, c.run_name, args.debug - ) - - AUDIO_PATH = os.path.join(OUT_PATH, "test_audios") - - c_logger = ConsoleLogger() - - if args.rank == 0: - os.makedirs(AUDIO_PATH, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - copy_model_files( - c, args.config_path, OUT_PATH, new_fields - 
) - os.chmod(AUDIO_PATH, 0o775) - os.chmod(OUT_PATH, 0o775) - - LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER") - - # write model desc to tensorboard - tb_logger.tb_add_text("model-description", c["run_description"], 0) + args = parse_arguments(sys.argv) + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( + args, model_type='wavernn') try: main(args) diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py new file mode 100644 index 00000000..948c90d3 --- /dev/null +++ b/TTS/utils/arguments.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Argument parser for training scripts.""" + +import argparse +import re +import glob +import os + +from TTS.utils.generic_utils import ( + create_experiment_folder, get_git_branch) +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.io import copy_model_files, load_config +from TTS.utils.tensorboard_logger import TensorboardLogger + +from TTS.tts.utils.generic_utils import check_config_tts + + +def parse_arguments(argv): + """Parse command line arguments of training scripts. + + Parameters + ---------- + argv : list + This is a list of input arguments as given by sys.argv + + Returns + ------- + argparse.Namespace + Parsed arguments. + + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--continue_path", + type=str, + help=("Training output folder to continue training. Used to continue " + "a training. If it is used, 'config_path' is ignored."), + default="", + required="--config_path" not in argv) + parser.add_argument( + "--restore_path", + type=str, + help="Model file to be restored. 
Use to finetune a model.", + default="") + parser.add_argument( + "--config_path", + type=str, + help="Path to config file for training.", + required="--continue_path" not in argv) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="Do not verify commit integrity to run training.") + parser.add_argument( + "--rank", + type=int, + default=0, + help="DISTRIBUTED: process rank for distributed training.") + parser.add_argument( + "--group_id", + type=str, + default="", + help="DISTRIBUTED: process group id.") + + return parser.parse_args() + + +def get_last_checkpoint(path): + """Get latest checkpoint from a list of filenames. + + It is based on globbing for `*.pth.tar` and the RegEx + `checkpoint_([0-9]+)`. + + Parameters + ---------- + path : str + Path to the directory to be searched for checkpoint files. + + Raises + ------ + ValueError + If no checkpoint files are found. + + Returns + ------- + last_checkpoint : str + Last checkpoint filename. + + """ + last_checkpoint_num = 0 + last_checkpoint = None + filenames = glob.glob( + os.path.join(path, "*.pth.tar")) + for filename in filenames: + try: + checkpoint_num = int( + re.search(r"checkpoint_([0-9]+)", filename).groups()[0]) + if checkpoint_num > last_checkpoint_num: + last_checkpoint_num = checkpoint_num + last_checkpoint = filename + except AttributeError: # if there's no match in the filename + pass + if last_checkpoint is None: + raise ValueError(f"No checkpoints in {path}!") + return last_checkpoint + + +def process_args(args, model_type): + """Process parsed command line arguments. + + Parameters + ---------- + args : argparse.Namespace or dict like + Parsed input arguments. + model_type : str + Model type used to check config parameters and setup the TensorBoard + logger. One of: + - tacotron + - glow_tts + - speedy_speech + - gan + - wavegrad + - wavernn + + Raises + ------ + ValueError + If `model_type` is not one of implemented choices. + + Returns + ------- + c : TTS.utils.io.AttrDict + Config parameters. 
+ out_path : str + Path to save models and logging. + audio_path : str + Path to save generated test audios. + c_logger : TTS.utils.console_logger.ConsoleLogger + Class that does logging to the console. + tb_logger : TTS.utils.tensorboard.TensorboardLogger + Class that does the TensorBoard logging. + + """ + if args.continue_path != "": + args.output_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + list_of_files = glob.glob( + os.path.join(args.continue_path, "*.pth.tar") + ) # * means all if need specific format then *.csv + args.restore_path = max(list_of_files, key=os.path.getctime) + # checkpoint number based continuing + # args.restore_path = get_last_checkpoint(args.continue_path) + print(f" > Training continues for {args.restore_path}") + + # setup output paths and read configs + c = load_config(args.config_path) + + if model_type in "tacotron glow_tts speedy_speech": + model_class = "TTS" + elif model_type in "gan wavegrad wavernn": + model_class = "VOCODER" + else: + raise ValueError(f"model type {model_type} not recognized!") + + if model_class == "TTS": + check_config_tts(c) + elif model_class == "VOCODER": + print("Vocoder config checker not implemented, " + "skipping ...") + else: + raise ValueError(f"model type {model_type} not recognized!") + + _ = os.path.dirname(os.path.realpath(__file__)) + + if model_type in "tacotron wavegrad wavernn" and c.mixed_precision: + print(" > Mixed precision mode is ON") + + out_path = args.continue_path + if args.continue_path == "": + out_path = create_experiment_folder(c.output_path, c.run_name, + args.debug) + + audio_path = os.path.join(out_path, "test_audios") + + c_logger = ConsoleLogger() + + if args.rank == 0: + os.makedirs(audio_path, exist_ok=True) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + copy_model_files(c, args.config_path, + out_path, new_fields) + 
os.chmod(audio_path, 0o775) + os.chmod(out_path, 0o775) + + log_path = out_path + + tb_logger = TensorboardLogger(log_path, model_name=model_class) + + # write model desc to tensorboard + tb_logger.tb_add_text("model-description", c["run_description"], 0) + + return c, out_path, audio_path, c_logger, tb_logger diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 2f68c6e7..8c075d06 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -21,7 +21,7 @@ def test_phoneme_to_sequence(): text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" assert text_hat == text_hat_with_params == gt - + # multiple punctuations text = "Be a voice, not an! echo?" sequence = phoneme_to_sequence(text, text_cleaner, lang)