mirror of https://github.com/coqui-ai/TTS.git

commit 6f06e31541 — changed train scripts
parent 2daca15802
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
+"""Train Glow TTS model."""

-import argparse
-import glob
 import os
 import sys
 import time
@@ -14,10 +12,12 @@ import torch
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

+from TTS.utils.arguments import parse_arguments, process_args
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
-from TTS.tts.utils.generic_utils import check_config_tts, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import parse_speakers
@@ -25,18 +25,15 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.distribute import init_distributed, reduce_tensor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
-from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.radam import RAdam
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import NoamLR, setup_torch_training_env

 use_cuda, num_gpus = setup_torch_training_env(True, False)


 def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
@@ -119,7 +116,7 @@ def format_data(data):
            avg_text_length, avg_spec_length, attn_mask, item_idx


-def data_depended_init(data_loader, model):
+def data_depended_init(data_loader, model, ap):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -138,7 +135,7 @@ def data_depended_init(data_loader, model):

         # format data
         text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
-            _, _, attn_mask, _ = format_data(data)
+            _, _, attn_mask, item_idx = format_data(data)

         # forward pass model
         _ = model.forward(
@@ -177,7 +174,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,

         # format data
         text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
-            avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)
+            avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)

         loader_time = time.time() - end_time

@@ -332,7 +329,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

             # format data
             text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
-                _, _, attn_mask, _ = format_data(data)
+                _, _, attn_mask, item_idx = format_data(data)

             # forward pass model
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@@ -468,7 +465,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     return keep_avg.avg_values


-# FIXME: move args definition/parsing inside of main?
 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
@@ -550,14 +546,13 @@ def main(args): # pylint: disable=redefined-outer-name
     eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

     global_step = args.restore_step
-    model = data_depended_init(train_loader, model)
+    model = data_depended_init(train_loader, model, ap)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                  scheduler, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
-                                      global_step, epoch)
+        eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
@@ -567,81 +562,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--continue_path',
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='glow_tts')
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default='',
-        required='--config_path' not in sys.argv)
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Model file to be restored. Use to finetune a model.',
-        default='')
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required='--continue_path' not in sys.argv
-    )
-    parser.add_argument('--debug',
-                        type=bool,
-                        default=False,
-                        help='Do not verify commit integrity to run training.')
-
-    # DISTRUBUTED
-    parser.add_argument(
-        '--rank',
-        type=int,
-        default=0,
-        help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument('--group_id',
-                        type=str,
-                        default="",
-                        help='DISTRIBUTED: process group id.')
-    args = parser.parse_args()
-
-    if args.continue_path != '':
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, 'config.json')
-        list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-        print(f" > Training continues for {args.restore_path}")
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    # check_config(c)
-    check_config_tts(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    if c.mixed_precision:
-        print(" > Mixed precision enabled.")
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == '':
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
-
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
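The hunks above are the Glow-TTS trainer; the same refactor repeats in the five scripts that follow (SpeedySpeech, Tacotron, GAN vocoder, WaveGrad, WaveRNN): the per-script argparse block and checkpoint-resume logic in __main__ are replaced by the shared helpers parse_arguments and process_args from TTS.utils.arguments. As a reading aid, here is a minimal sketch of what parse_arguments presumably does, reassembled from the flags deleted above; the actual helper in TTS/utils/arguments.py may differ in details.

    import argparse
    import glob
    import os


    def parse_arguments(argv):
        """Parse the flags the old per-script __main__ blocks defined.

        Sketch only: reassembled from the argparse block deleted above,
        not copied from TTS/utils/arguments.py.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--continue_path', type=str, default='',
                            required='--config_path' not in argv,
                            help='Training output folder to continue training.')
        parser.add_argument('--restore_path', type=str, default='',
                            help='Model file to be restored. Use to finetune a model.')
        parser.add_argument('--config_path', type=str,
                            required='--continue_path' not in argv,
                            help='Path to config file for training.')
        parser.add_argument('--debug', type=bool, default=False,
                            help='Do not verify commit integrity to run training.')
        parser.add_argument('--rank', type=int, default=0,
                            help='DISTRIBUTED: process rank for distributed training.')
        parser.add_argument('--group_id', type=str, default='',
                            help='DISTRIBUTED: process group id.')
        args = parser.parse_args(argv[1:])

        # Continuing a run: reuse its config and pick the newest checkpoint,
        # exactly as the deleted per-script code did.
        if args.continue_path != '':
            args.output_path = args.continue_path
            args.config_path = os.path.join(args.continue_path, 'config.json')
            list_of_files = glob.glob(args.continue_path + '/*.pth.tar')
            args.restore_path = max(list_of_files, key=os.path.getctime)
            print(f" > Training continues for {args.restore_path}")
        return args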
@@ -11,6 +11,7 @@ import numpy as np
 from random import randrange

 import torch
+from TTS.utils.arguments import parse_arguments, process_args
 # DISTRIBUTED
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
@@ -18,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import SpeedySpeechLoss
-from TTS.tts.utils.generic_utils import check_config_tts, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import parse_speakers
@@ -26,14 +27,10 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.distribute import init_distributed, reduce_tensor
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
-from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.radam import RAdam
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import NoamLR, setup_torch_training_env

 use_cuda, num_gpus = setup_torch_training_env(True, False)
@@ -518,8 +515,7 @@ def main(args): # pylint: disable=redefined-outer-name
         train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                  scheduler, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
-                                      global_step, epoch)
+        eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
@@ -529,81 +525,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--continue_path',
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='tts')
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default='',
-        required='--config_path' not in sys.argv)
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Model file to be restored. Use to finetune a model.',
-        default='')
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required='--continue_path' not in sys.argv
-    )
-    parser.add_argument('--debug',
-                        type=bool,
-                        default=False,
-                        help='Do not verify commit integrity to run training.')
-
-    # DISTRUBUTED
-    parser.add_argument(
-        '--rank',
-        type=int,
-        default=0,
-        help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument('--group_id',
-                        type=str,
-                        default="",
-                        help='DISTRIBUTED: process group id.')
-    args = parser.parse_args()
-
-    if args.continue_path != '':
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, 'config.json')
-        list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-        print(f" > Training continues for {args.restore_path}")
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    # check_config(c)
-    check_config_tts(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    if c.mixed_precision:
-        print(" > Mixed precision enabled.")
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == '':
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
-
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
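The SpeedySpeech hunks above end the same way, passing model_type='tts' instead of 'glow_tts'. The second shared helper, process_args, takes over the output-folder and logger setup each script used to repeat. A hedged sketch of its shape, assuming it returns exactly the five objects unpacked above (c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger) and mirrors the deleted setup code; the shipped implementation may additionally validate the config per model_type.

    import os

    from TTS.utils.console_logger import ConsoleLogger
    from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
    from TTS.utils.io import copy_model_files, load_config
    from TTS.utils.tensorboard_logger import TensorboardLogger


    def process_args(args, model_type):
        """Set up output paths and loggers for one training run (sketch only)."""
        c = load_config(args.config_path)

        out_path = args.continue_path
        if out_path == '':
            out_path = create_experiment_folder(c.output_path, c.run_name, args.debug)
        audio_path = os.path.join(out_path, 'test_audios')

        c_logger = ConsoleLogger()
        if args.rank == 0:
            os.makedirs(audio_path, exist_ok=True)
            new_fields = {}
            if args.restore_path:
                new_fields["restore_path"] = args.restore_path
            new_fields["github_branch"] = get_git_branch()
            copy_model_files(c, args.config_path, out_path, new_fields)

        # The old scripts logged vocoder runs under 'VOCODER' and TTS runs under 'TTS'.
        tb_model_name = 'VOCODER' if model_type in ('gan', 'wavegrad', 'wavernn') else 'TTS'
        tb_logger = TensorboardLogger(out_path, model_name=tb_model_name)
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

        return c, out_path, audio_path, c_logger, tb_logger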
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
+"""Trains Tacotron based TTS models."""

-import argparse
-import glob
 import os
 import sys
 import time
@@ -11,11 +9,12 @@ from random import randrange

 import numpy as np
 import torch
+from TTS.utils.arguments import parse_arguments, process_args
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import TacotronLoss
-from TTS.tts.utils.generic_utils import check_config_tts, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import parse_speakers
@@ -23,15 +22,11 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                   init_distributed, reduce_tensor)
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
-from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.radam import RAdam
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
                                 gradual_training_scheduler, set_weight_decay,
                                 setup_torch_training_env)
@@ -61,7 +56,13 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
             phoneme_language=c.phoneme_language,
             enable_eos_bos=c.enable_eos_bos_chars,
             verbose=verbose,
-            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
+            speaker_mapping=(
+                speaker_mapping if (
+                    c.use_speaker_embedding and
+                    c.use_external_speaker_embedding_file
+                ) else None
+            )
+        )

         if c.use_phonemes and c.compute_input_seq_cache:
             # precompute phonemes to have a better estimate of sequence lengths.
@@ -180,8 +181,8 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
             loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                   linear_input, stop_tokens, stop_targets,
                                   mel_lengths, decoder_backward_output,
-                                  alignments, alignment_lengths,
-                                  alignments_backward, text_lengths)
+                                  alignments, alignment_lengths, alignments_backward,
+                                  text_lengths)

         # check nan loss
         if torch.isnan(loss_dict['loss']).any():
@@ -491,7 +492,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
     return keep_avg.avg_values


-# FIXME: move args definition/parsing inside of main?
 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
@@ -534,7 +534,8 @@ def main(args): # pylint: disable=redefined-outer-name
     optimizer_st = None

     # setup criterion
-    criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
+    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
+
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
@@ -640,80 +641,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--continue_path',
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='tacotron')
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default='',
-        required='--config_path' not in sys.argv)
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Model file to be restored. Use to finetune a model.',
-        default='')
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='Path to config file for training.',
-        required='--continue_path' not in sys.argv
-    )
-    parser.add_argument('--debug',
-                        type=bool,
-                        default=False,
-                        help='Do not verify commit integrity to run training.')
-
-    # DISTRUBUTED
-    parser.add_argument(
-        '--rank',
-        type=int,
-        default=0,
-        help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument('--group_id',
-                        type=str,
-                        default="",
-                        help='DISTRIBUTED: process group id.')
-    args = parser.parse_args()
-
-    if args.continue_path != '':
-        print(f" > Training continues for {args.continue_path}")
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, 'config.json')
-        list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    check_config_tts(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    if c.mixed_precision:
-        print(" > Mixed precision mode is ON")
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == '':
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
-
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
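One behavioral change in the Tacotron hunks above is easy to miss: TacotronLoss is now built with stopnet_pos_weight=10.0 instead of reading c.stopnet_pos_weight from the config. As a generic illustration of what such a positive-class weight does for the stop-token objective, here is a hypothetical stand-alone equivalent; this is not the code of TTS.tts.layers.losses.TacotronLoss, which wires the weight up internally.

    import torch


    def stopnet_loss(stop_logits, stop_targets, pos_weight=10.0):
        """Weighted BCE over stop-token logits (illustrative sketch).

        pos_weight > 1 up-weights the rare positive frames (the 'stop here'
        labels at the end of each utterance), which is what
        stopnet_pos_weight controls.
        """
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        return criterion(stop_logits, stop_targets)


    # Dummy example: batch of 2 sequences, 5 decoder frames each,
    # only the last frame carries a positive "stop" label.
    logits = torch.randn(2, 5)
    targets = torch.zeros(2, 5)
    targets[:, -1] = 1.0
    print(stopnet_loss(logits, targets))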
@@ -1,5 +1,6 @@
-import argparse
-import glob
+#!/usr/bin/env python3
+"""Trains GAN based vocoder model."""
+
 import os
 import sys
 import time
@@ -7,15 +8,14 @@ import traceback
 from inspect import signature

 import torch
+from TTS.utils.arguments import parse_arguments, process_args
 from torch.utils.data import DataLoader
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
-from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.radam import RAdam
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@@ -33,8 +33,9 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)


 def setup_loader(ap, is_val=False, verbose=False):
+    if is_val and not c.run_eval:
         loader = None
-    if not is_val or c.run_eval:
+    else:
         dataset = GANDataset(ap=ap,
                              items=eval_data if is_val else train_data,
                              seq_len=c.seq_len,
@@ -113,7 +114,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         y_hat = model_G(c_G)
         y_hat_sub = None
         y_G_sub = None
-        y_hat_vis = y_hat  # for visualization
+        y_hat_vis = y_hat  # for visualization # FIXME! .clone().detach()

         # PQMF formatting
         if y_hat.shape[1] > 1:
@@ -439,7 +440,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
     return keep_avg.avg_values


-# FIXME: move args definition/parsing inside of main?
 def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=global-variable-undefined
     global train_data, eval_data
@@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
             scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
             scheduler_disc.optimizer = optimizer_disc
         except RuntimeError:
-            # retore only matching layers.
+            # restore only matching layers.
             print(" > Partial model initialization...")
             model_dict = model_gen.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
@@ -556,7 +556,8 @@ def main(args): # pylint: disable=redefined-outer-name
                                                  model_disc, criterion_disc, optimizer_disc,
                                                  scheduler_gen, scheduler_disc, ap, global_step,
                                                  epoch)
-        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
+        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
+                                      criterion_disc, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = eval_avg_loss_dict[c.target_loss]
@@ -575,78 +576,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--continue_path',
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='gan')
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default='',
-        required='--config_path' not in sys.argv)
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Model file to be restored. Use to finetune a model.',
-        default='')
-    parser.add_argument('--config_path',
-                        type=str,
-                        help='Path to config file for training.',
-                        required='--continue_path' not in sys.argv)
-    parser.add_argument('--debug',
-                        type=bool,
-                        default=False,
-                        help='Do not verify commit integrity to run training.')
-
-    # DISTRUBUTED
-    parser.add_argument(
-        '--rank',
-        type=int,
-        default=0,
-        help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument('--group_id',
-                        type=str,
-                        default="",
-                        help='DISTRIBUTED: process group id.')
-    args = parser.parse_args()
-
-    if args.continue_path != '':
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, 'config.json')
-        list_of_files = glob.glob(
-            args.continue_path +
-            "/*.pth.tar") # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-        print(f" > Training continues for {args.restore_path}")
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    # check_config(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == '':
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
-                                            args.debug)
-
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
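In the GAN vocoder hunks above, setup_loader is rewritten from "assign loader = None, then overwrite it unless this is a disabled eval pass" into an explicit if/else. A reduced sketch of the resulting control flow, with the GANDataset and DataLoader arguments trimmed to placeholders (the real call keeps all the config-driven arguments shown in the diff):

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def setup_loader(dataset_cls, items, is_val=False, run_eval=True, batch_size=8):
        """Return None for the eval loader when evaluation is disabled (sketch)."""
        if is_val and not run_eval:
            loader = None
        else:
            # GANDataset(ap=ap, items=..., seq_len=c.seq_len, ...) in the script
            dataset = dataset_cls(items)
            loader = DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=not is_val,
                                drop_last=False)
        return loader


    # Toy usage with a stand-in dataset class:
    toy = lambda items: TensorDataset(torch.as_tensor(items, dtype=torch.float32))
    print(setup_loader(toy, [[0.0], [1.0]], is_val=True, run_eval=False))  # None
    print(len(setup_loader(toy, [[0.0], [1.0]], is_val=False)))            # 1 batch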
@@ -1,5 +1,6 @@
-import argparse
-import glob
+#!/usr/bin/env python3
+"""Trains WaveGrad vocoder models."""
+
 import os
 import sys
 import time
@@ -7,19 +8,16 @@ import traceback
 import numpy as np

 import torch
+from TTS.utils.arguments import parse_arguments, process_args
 # DISTRIBUTED
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import (KeepAverage, count_parameters,
-                                     create_experiment_folder, get_git_branch,
                                      remove_experiment_folder, set_init_dict)
-from TTS.utils.io import copy_model_files, load_config
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
@@ -54,6 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
                         if is_val else c.num_loader_workers,
                         pin_memory=False)
+

     return loader


@@ -78,8 +77,8 @@ def format_test_data(data):
     return m, x


-def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
-          epoch):
+def train(model, criterion, optimizer,
+          scheduler, scaler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -93,8 +92,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
     c_logger.print_train_start()
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
-    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
-                        noise_schedule['num_steps'])
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
         model.module.compute_noise_level(betas)
     else:
@@ -205,8 +203,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
                     epoch,
                     OUT_PATH,
                     model_losses=loss_dict,
-                    scaler=scaler.state_dict()
-                    if c.mixed_precision else None)
+                    scaler=scaler.state_dict() if c.mixed_precision else None)

         end_time = time.time()

@@ -247,6 +244,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
             else:
                 noise, x_noisy, noise_scale = model.compute_y_n(x)

+
             # forward pass
             noise_hat = model(x_noisy, m, noise_scale)

@@ -254,6 +252,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
             loss = criterion(noise, noise_hat)
             loss_wavegrad_dict = {'wavegrad_loss':loss}

+
             loss_dict = dict()
             for key, value in loss_wavegrad_dict.items():
                 if isinstance(value, (int, float)):
@@ -283,9 +282,7 @@ def evaluate(model, criterion, ap, global_step, epoch):

     # setup noise schedule and inference
     noise_schedule = c['test_noise_schedule']
-    betas = np.linspace(noise_schedule['min_val'],
-                        noise_schedule['max_val'],
-                        noise_schedule['num_steps'])
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
         model.module.compute_noise_level(betas)
         # compute voice
@@ -316,8 +313,7 @@ def main(args): # pylint: disable=redefined-outer-name
     print(f" > Loading wavs from: {c.data_path}")
     if c.feature_path is not None:
         print(f" > Loading features from: {c.feature_path}")
-        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
-                                                   c.eval_split_size)
+        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
     else:
         eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

@@ -347,10 +343,6 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup criterion
     criterion = torch.nn.L1Loss().cuda()

-    if use_cuda:
-        model.cuda()
-        criterion.cuda()
-
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
@@ -384,6 +376,10 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         args.restore_step = 0

+    if use_cuda:
+        model.cuda()
+        criterion.cuda()
+
     # DISTRUBUTED
     if num_gpus > 1:
         model = DDP_th(model, device_ids=[args.rank])
@@ -397,13 +393,14 @@ def main(args): # pylint: disable=redefined-outer-name
     global_step = args.restore_step
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
-        _, global_step = train(model, criterion, optimizer, scheduler, scaler,
-                               ap, global_step, epoch)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+        _, global_step = train(model, criterion, optimizer,
+                               scheduler, scaler, ap, global_step,
+                               epoch)
+        eval_avg_loss_dict = evaluate(model, criterion, ap,
+                                      global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = eval_avg_loss_dict[c.target_loss]
-        best_loss = save_best_model(
-            target_loss,
+        best_loss = save_best_model(target_loss,
                                     best_loss,
                                     model,
                                     optimizer,
@@ -419,83 +416,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--continue_path',
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='wavegrad')
-        type=str,
-        help=
-        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default='',
-        required='--config_path' not in sys.argv)
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Model file to be restored. Use to finetune a model.',
-        default='')
-    parser.add_argument('--config_path',
-                        type=str,
-                        help='Path to config file for training.',
-                        required='--continue_path' not in sys.argv)
-    parser.add_argument('--debug',
-                        type=bool,
-                        default=False,
-                        help='Do not verify commit integrity to run training.')
-
-    # DISTRUBUTED
-    parser.add_argument(
-        '--rank',
-        type=int,
-        default=0,
-        help='DISTRIBUTED: process rank for distributed training.')
-    parser.add_argument('--group_id',
-                        type=str,
-                        default="",
-                        help='DISTRIBUTED: process group id.')
-    args = parser.parse_args()
-
-    if args.continue_path != '':
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, 'config.json')
-        list_of_files = glob.glob(
-            args.continue_path +
-            "/*.pth.tar") # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-        print(f" > Training continues for {args.restore_path}")
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    # check_config(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    # DISTRIBUTED
-    if c.mixed_precision:
-        print(" > Mixed precision is enabled")
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == '':
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
-                                            args.debug)
-
-    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(c, args.config_path, OUT_PATH, new_fields)
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
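Most of the WaveGrad hunks above only re-wrap long calls onto a single line; the noise schedule itself is unchanged: a linear beta ramp built with np.linspace and handed to the model's compute_noise_level. A small self-contained sketch of that schedule, with illustrative config values and a conventional derivation of per-step noise levels (compute_noise_level is assumed to cache something equivalent; only the np.linspace call is taken from the diff):

    import numpy as np

    # Assumed config shape, mirroring c['train_noise_schedule'] in the script.
    noise_schedule = {'min_val': 1e-6, 'max_val': 1e-2, 'num_steps': 1000}

    # Linear beta ramp, exactly as in the updated train()/evaluate() above.
    betas = np.linspace(noise_schedule['min_val'],
                        noise_schedule['max_val'],
                        noise_schedule['num_steps'])

    # One common way to turn betas into cumulative noise levels (a sketch of
    # what compute_noise_level is expected to precompute on the model).
    alphas = 1.0 - betas
    noise_level = np.sqrt(np.cumprod(alphas))

    print(betas.shape, noise_level[0], noise_level[-1])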
@@ -1,9 +1,10 @@
-import argparse
+#!/usr/bin/env python3
+"""Train WaveRNN vocoder model."""
+
 import os
 import sys
 import traceback
 import time
-import glob
 import random

 import torch
@@ -11,18 +12,14 @@ from torch.utils.data import DataLoader

 # from torch.utils.data.distributed import DistributedSampler

+from TTS.utils.arguments import parse_arguments, process_args
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.radam import RAdam
-from TTS.utils.io import copy_model_files, load_config
 from TTS.utils.training import setup_torch_training_env
-from TTS.utils.console_logger import ConsoleLogger
-from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.generic_utils import (
     KeepAverage,
     count_parameters,
-    create_experiment_folder,
-    get_git_branch,
     remove_experiment_folder,
     set_init_dict,
 )
@@ -207,7 +204,14 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
                     c.batched,
                     c.target_samples,
                     c.overlap_samples,
+                    # use_cuda
                 )
+                # sample_wav = model.generate(ground_mel,
+                #                             c.batched,
+                #                             c.target_samples,
+                #                             c.overlap_samples,
+                #                             use_cuda
+                #                             )
                 predict_mel = ap.melspectrogram(sample_wav)

             # compute spectrograms
@@ -296,6 +300,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                 c.batched,
                 c.target_samples,
                 c.overlap_samples,
+                # use_cuda
             )
             predict_mel = ap.melspectrogram(sample_wav)

@@ -447,87 +452,9 @@ def main(args): # pylint: disable=redefined-outer-name


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--continue_path",
+    args = parse_arguments(sys.argv)
+    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
+        args, model_type='wavernn')
-        type=str,
-        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default="",
-        required="--config_path" not in sys.argv,
-    )
-    parser.add_argument(
-        "--restore_path",
-        type=str,
-        help="Model file to be restored. Use to finetune a model.",
-        default="",
-    )
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        help="Path to config file for training.",
-        required="--continue_path" not in sys.argv,
-    )
-    parser.add_argument(
-        "--debug",
-        type=bool,
-        default=False,
-        help="Do not verify commit integrity to run training.",
-    )
-
-    # DISTRUBUTED
-    parser.add_argument(
-        "--rank",
-        type=int,
-        default=0,
-        help="DISTRIBUTED: process rank for distributed training.",
-    )
-    parser.add_argument(
-        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
-    )
-    args = parser.parse_args()
-
-    if args.continue_path != "":
-        args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, "config.json")
-        list_of_files = glob.glob(
-            args.continue_path + "/*.pth.tar"
-        )  # * means all if need specific format then *.csv
-        latest_model_file = max(list_of_files, key=os.path.getctime)
-        args.restore_path = latest_model_file
-        print(f" > Training continues for {args.restore_path}")
-
-    # setup output paths and read configs
-    c = load_config(args.config_path)
-    # check_config(c)
-    _ = os.path.dirname(os.path.realpath(__file__))
-
-    OUT_PATH = args.continue_path
-    if args.continue_path == "":
-        OUT_PATH = create_experiment_folder(
-            c.output_path, c.run_name, args.debug
-        )
-
-    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
-
-    c_logger = ConsoleLogger()
-
-    if args.rank == 0:
-        os.makedirs(AUDIO_PATH, exist_ok=True)
-        new_fields = {}
-        if args.restore_path:
-            new_fields["restore_path"] = args.restore_path
-        new_fields["github_branch"] = get_git_branch()
-        copy_model_files(
-            c, args.config_path, OUT_PATH, new_fields
-        )
-        os.chmod(AUDIO_PATH, 0o775)
-        os.chmod(OUT_PATH, 0o775)
-
-    LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
-
-    # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", c["run_description"], 0)

     try:
         main(args)
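Taken together, the six scripts now share one __main__ shape and differ only in the model_type string passed to process_args ('glow_tts', 'tts', 'tacotron', 'gan', 'wavegrad', 'wavernn' in the hunks above). A sketch of that common tail follows; the cleanup in the except branches is an assumption based on the scripts' existing pattern, since the diff stops at the main(args) call.

    import sys
    import traceback

    from TTS.utils.arguments import parse_arguments, process_args
    from TTS.utils.generic_utils import remove_experiment_folder


    def run(main_fn, model_type):
        """Common __main__ tail of the refactored training scripts (sketch)."""
        args = parse_arguments(sys.argv)
        c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
            args, model_type=model_type)
        try:
            main_fn(args)  # each script keeps its own main()
        except KeyboardInterrupt:
            remove_experiment_folder(OUT_PATH)
            sys.exit(0)
        except Exception:  # pylint: disable=broad-except
            remove_experiment_folder(OUT_PATH)
            traceback.print_exc()
            sys.exit(1)


    # e.g. in the Glow-TTS script: run(main, model_type='glow_tts');
    # the other scripts pass 'tts', 'tacotron', 'gan', 'wavegrad' or 'wavernn'.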