changed train scripts

This commit is contained in:
gerazov 2021-02-06 22:29:30 +01:00 committed by Eren Gölge
parent 2daca15802
commit 6f06e31541
6 changed files with 131 additions and 572 deletions

View File

@ -1,8 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- """Train Glow TTS model."""
import argparse
import glob
import os import os
import sys import sys
import time import time
@ -14,10 +12,12 @@ import torch
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.utils.arguments import parse_arguments, process_args
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.generic_utils import check_config_tts, setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers from TTS.tts.utils.speakers import parse_speakers
@ -25,18 +25,15 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import NoamLR, setup_torch_training_env from TTS.utils.training import NoamLR, setup_torch_training_env
use_cuda, num_gpus = setup_torch_training_env(True, False) use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False): def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not c.run_eval: if is_val and not c.run_eval:
loader = None loader = None
@ -119,7 +116,7 @@ def format_data(data):
avg_text_length, avg_spec_length, attn_mask, item_idx avg_text_length, avg_spec_length, attn_mask, item_idx
def data_depended_init(data_loader, model): def data_depended_init(data_loader, model, ap):
"""Data depended initialization for activation normalization.""" """Data depended initialization for activation normalization."""
if hasattr(model, 'module'): if hasattr(model, 'module'):
for f in model.module.decoder.flows: for f in model.module.decoder.flows:
@ -138,7 +135,7 @@ def data_depended_init(data_loader, model):
# format data # format data
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\ text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
_, _, attn_mask, _ = format_data(data) _, _, attn_mask, item_idx = format_data(data)
# forward pass model # forward pass model
_ = model.forward( _ = model.forward(
@ -177,7 +174,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# format data # format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, _ = format_data(data) avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
loader_time = time.time() - end_time loader_time = time.time() - end_time
@ -332,7 +329,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# format data # format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
_, _, attn_mask, _ = format_data(data) _, _, attn_mask, item_idx = format_data(data)
# forward pass model # forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@ -468,7 +465,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
return keep_avg.avg_values return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
@ -550,14 +546,13 @@ def main(args): # pylint: disable=redefined-outer-name
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
global_step = args.restore_step global_step = args.restore_step
model = data_depended_init(train_loader, model) model = data_depended_init(train_loader, model, ap)
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs) c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
scheduler, ap, global_step, scheduler, ap, global_step,
epoch) epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss'] target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval: if c.run_eval:
@ -567,81 +562,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
'--continue_path', args, model_type='glow_tts')
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRUBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.mixed_precision:
print(" > Mixed precision enabled.")
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)

View File

@ -11,6 +11,7 @@ import numpy as np
from random import randrange from random import randrange
import torch import torch
from TTS.utils.arguments import parse_arguments, process_args
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -18,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import SpeedySpeechLoss from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import check_config_tts, setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers from TTS.tts.utils.speakers import parse_speakers
@ -26,14 +27,10 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import NoamLR, setup_torch_training_env from TTS.utils.training import NoamLR, setup_torch_training_env
use_cuda, num_gpus = setup_torch_training_env(True, False) use_cuda, num_gpus = setup_torch_training_env(True, False)
@ -518,8 +515,7 @@ def main(args): # pylint: disable=redefined-outer-name
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
scheduler, ap, global_step, scheduler, ap, global_step,
epoch) epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss'] target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval: if c.run_eval:
@ -529,81 +525,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
'--continue_path', args, model_type='tts')
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRUBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.mixed_precision:
print(" > Mixed precision enabled.")
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)

View File

@ -1,8 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- """Trains Tacotron based TTS models."""
import argparse
import glob
import os import os
import sys import sys
import time import time
@ -11,11 +9,12 @@ from random import randrange
import numpy as np import numpy as np
import torch import torch
from TTS.utils.arguments import parse_arguments, process_args
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.generic_utils import check_config_tts, setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers from TTS.tts.utils.speakers import parse_speakers
@ -23,15 +22,11 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce, from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor) init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
gradual_training_scheduler, set_weight_decay, gradual_training_scheduler, set_weight_decay,
setup_torch_training_env) setup_torch_training_env)
@ -61,7 +56,13 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
phoneme_language=c.phoneme_language, phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars, enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose, verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) speaker_mapping=(
speaker_mapping if (
c.use_speaker_embedding and
c.use_external_speaker_embedding_file
) else None
)
)
if c.use_phonemes and c.compute_input_seq_cache: if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths. # precompute phonemes to have a better estimate of sequence lengths.
@ -180,8 +181,8 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
loss_dict = criterion(postnet_output, decoder_output, mel_input, loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets, linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output, mel_lengths, decoder_backward_output,
alignments, alignment_lengths, alignments, alignment_lengths, alignments_backward,
alignments_backward, text_lengths) text_lengths)
# check nan loss # check nan loss
if torch.isnan(loss_dict['loss']).any(): if torch.isnan(loss_dict['loss']).any():
@ -491,7 +492,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
return keep_avg.avg_values return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
@ -534,7 +534,8 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_st = None optimizer_st = None
# setup criterion # setup criterion
criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4) criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu') checkpoint = torch.load(args.restore_path, map_location='cpu')
try: try:
@ -640,80 +641,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
'--continue_path', args, model_type='tacotron')
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRUBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
print(f" > Training continues for {args.continue_path}")
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
# setup output paths and read configs
c = load_config(args.config_path)
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.mixed_precision:
print(" > Mixed precision mode is ON")
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)

View File

@ -1,5 +1,6 @@
import argparse #!/usr/bin/env python3
import glob """Trains GAN based vocoder model."""
import os import os
import sys import sys
import time import time
@ -7,15 +8,14 @@ import traceback
from inspect import signature from inspect import signature
import torch import torch
from TTS.utils.arguments import parse_arguments, process_args
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@ -33,8 +33,9 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False): def setup_loader(ap, is_val=False, verbose=False):
if is_val and not c.run_eval:
loader = None loader = None
if not is_val or c.run_eval: else:
dataset = GANDataset(ap=ap, dataset = GANDataset(ap=ap,
items=eval_data if is_val else train_data, items=eval_data if is_val else train_data,
seq_len=c.seq_len, seq_len=c.seq_len,
@ -113,7 +114,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
y_hat = model_G(c_G) y_hat = model_G(c_G)
y_hat_sub = None y_hat_sub = None
y_G_sub = None y_G_sub = None
y_hat_vis = y_hat # for visualization y_hat_vis = y_hat # for visualization # FIXME! .clone().detach()
# PQMF formatting # PQMF formatting
if y_hat.shape[1] > 1: if y_hat.shape[1] > 1:
@ -439,7 +440,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
return keep_avg.avg_values return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global train_data, eval_data global train_data, eval_data
@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
scheduler_disc.optimizer = optimizer_disc scheduler_disc.optimizer = optimizer_disc
except RuntimeError: except RuntimeError:
# retore only matching layers. # restore only matching layers.
print(" > Partial model initialization...") print(" > Partial model initialization...")
model_dict = model_gen.state_dict() model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c) model_dict = set_init_dict(model_dict, checkpoint['model'], c)
@ -556,7 +556,8 @@ def main(args): # pylint: disable=redefined-outer-name
model_disc, criterion_disc, optimizer_disc, model_disc, criterion_disc, optimizer_disc,
scheduler_gen, scheduler_disc, ap, global_step, scheduler_gen, scheduler_disc, ap, global_step,
epoch) epoch)
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
criterion_disc, ap,
global_step, epoch) global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss] target_loss = eval_avg_loss_dict[c.target_loss]
@ -575,78 +576,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
'--continue_path', args, model_type='gan')
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument('--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRUBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(
args.continue_path +
"/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)

View File

@ -1,5 +1,6 @@
import argparse #!/usr/bin/env python3
import glob """Trains WaveGrad vocoder models."""
import os import os
import sys import sys
import time import time
@ -7,19 +8,16 @@ import traceback
import numpy as np import numpy as np
import torch import torch
from TTS.utils.arguments import parse_arguments, process_args
# DISTRIBUTED # DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict) remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
@ -54,6 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
if is_val else c.num_loader_workers, if is_val else c.num_loader_workers,
pin_memory=False) pin_memory=False)
return loader return loader
@ -78,8 +77,8 @@ def format_test_data(data):
return m, x return m, x
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, def train(model, criterion, optimizer,
epoch): scheduler, scaler, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train() model.train()
epoch_time = 0 epoch_time = 0
@ -93,8 +92,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
c_logger.print_train_start() c_logger.print_train_start()
# setup noise schedule # setup noise schedule
noise_schedule = c['train_noise_schedule'] noise_schedule = c['train_noise_schedule']
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
noise_schedule['num_steps'])
if hasattr(model, 'module'): if hasattr(model, 'module'):
model.module.compute_noise_level(betas) model.module.compute_noise_level(betas)
else: else:
@ -205,8 +203,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
epoch, epoch,
OUT_PATH, OUT_PATH,
model_losses=loss_dict, model_losses=loss_dict,
scaler=scaler.state_dict() scaler=scaler.state_dict() if c.mixed_precision else None)
if c.mixed_precision else None)
end_time = time.time() end_time = time.time()
@ -247,6 +244,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
else: else:
noise, x_noisy, noise_scale = model.compute_y_n(x) noise, x_noisy, noise_scale = model.compute_y_n(x)
# forward pass # forward pass
noise_hat = model(x_noisy, m, noise_scale) noise_hat = model(x_noisy, m, noise_scale)
@ -254,6 +252,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
loss = criterion(noise, noise_hat) loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss} loss_wavegrad_dict = {'wavegrad_loss':loss}
loss_dict = dict() loss_dict = dict()
for key, value in loss_wavegrad_dict.items(): for key, value in loss_wavegrad_dict.items():
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
@ -283,9 +282,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
# setup noise schedule and inference # setup noise schedule and inference
noise_schedule = c['test_noise_schedule'] noise_schedule = c['test_noise_schedule']
betas = np.linspace(noise_schedule['min_val'], betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
noise_schedule['max_val'],
noise_schedule['num_steps'])
if hasattr(model, 'module'): if hasattr(model, 'module'):
model.module.compute_noise_level(betas) model.module.compute_noise_level(betas)
# compute voice # compute voice
@ -316,8 +313,7 @@ def main(args): # pylint: disable=redefined-outer-name
print(f" > Loading wavs from: {c.data_path}") print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None: if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}") print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
c.eval_split_size)
else: else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@ -347,10 +343,6 @@ def main(args): # pylint: disable=redefined-outer-name
# setup criterion # setup criterion
criterion = torch.nn.L1Loss().cuda() criterion = torch.nn.L1Loss().cuda()
if use_cuda:
model.cuda()
criterion.cuda()
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu') checkpoint = torch.load(args.restore_path, map_location='cpu')
try: try:
@ -384,6 +376,10 @@ def main(args): # pylint: disable=redefined-outer-name
else: else:
args.restore_step = 0 args.restore_step = 0
if use_cuda:
model.cuda()
criterion.cuda()
# DISTRUBUTED # DISTRUBUTED
if num_gpus > 1: if num_gpus > 1:
model = DDP_th(model, device_ids=[args.rank]) model = DDP_th(model, device_ids=[args.rank])
@ -397,13 +393,14 @@ def main(args): # pylint: disable=redefined-outer-name
global_step = args.restore_step global_step = args.restore_step
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs) c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer, scheduler, scaler, _, global_step = train(model, criterion, optimizer,
ap, global_step, epoch) scheduler, scaler, ap, global_step,
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap,
global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss] target_loss = eval_avg_loss_dict[c.target_loss]
best_loss = save_best_model( best_loss = save_best_model(target_loss,
target_loss,
best_loss, best_loss,
model, model,
optimizer, optimizer,
@ -419,83 +416,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
'--continue_path', args, model_type='wavegrad')
type=str,
help=
'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument('--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRUBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(
args.continue_path +
"/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
# DISTRIBUTED
if c.mixed_precision:
print(" > Mixed precision is enabled")
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)

View File

@ -1,9 +1,10 @@
import argparse #!/usr/bin/env python3
"""Train WaveRNN vocoder model."""
import os import os
import sys import sys
import traceback import traceback
import time import time
import glob
import random import random
import torch import torch
@ -11,18 +12,14 @@ from torch.utils.data import DataLoader
# from torch.utils.data.distributed import DistributedSampler # from torch.utils.data.distributed import DistributedSampler
from TTS.utils.arguments import parse_arguments, process_args
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.training import setup_torch_training_env from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import ( from TTS.utils.generic_utils import (
KeepAverage, KeepAverage,
count_parameters, count_parameters,
create_experiment_folder,
get_git_branch,
remove_experiment_folder, remove_experiment_folder,
set_init_dict, set_init_dict,
) )
@ -207,7 +204,14 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
c.batched, c.batched,
c.target_samples, c.target_samples,
c.overlap_samples, c.overlap_samples,
# use_cuda
) )
# sample_wav = model.generate(ground_mel,
# c.batched,
# c.target_samples,
# c.overlap_samples,
# use_cuda
# )
predict_mel = ap.melspectrogram(sample_wav) predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms # compute spectrograms
@ -296,6 +300,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
c.batched, c.batched,
c.target_samples, c.target_samples,
c.overlap_samples, c.overlap_samples,
# use_cuda
) )
predict_mel = ap.melspectrogram(sample_wav) predict_mel = ap.melspectrogram(sample_wav)
@ -447,87 +452,9 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() args = parse_arguments(sys.argv)
parser.add_argument( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
"--continue_path", args, model_type='wavernn')
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default="",
required="--config_path" not in sys.argv,
)
parser.add_argument(
"--restore_path",
type=str,
help="Model file to be restored. Use to finetune a model.",
default="",
)
parser.add_argument(
"--config_path",
type=str,
help="Path to config file for training.",
required="--continue_path" not in sys.argv,
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.",
)
# DISTRUBUTED
parser.add_argument(
"--rank",
type=int,
default=0,
help="DISTRIBUTED: process rank for distributed training.",
)
parser.add_argument(
"--group_id", type=str, default="", help="DISTRIBUTED: process group id."
)
args = parser.parse_args()
if args.continue_path != "":
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, "config.json")
list_of_files = glob.glob(
args.continue_path + "/*.pth.tar"
) # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = args.continue_path
if args.continue_path == "":
OUT_PATH = create_experiment_folder(
c.output_path, c.run_name, args.debug
)
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_model_files(
c, args.config_path, OUT_PATH, new_fields
)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
# write model desc to tensorboard
tb_logger.tb_add_text("model-description", c["run_description"], 0)
try: try:
main(args) main(args)