fix formatting + pylint

This commit is contained in:
sanjaesc 2020-10-19 16:20:15 +02:00
parent 64adfbf4a5
commit 24d18d20e3
5 changed files with 252 additions and 307 deletions

View File

@ -15,23 +15,17 @@ from TTS.utils.audio import AudioProcessor
def main(): def main():
"""Run preprocessing process.""" """Run preprocessing process."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Compute mean and variance of spectrogtram features." description="Compute mean and variance of spectrogtram features.")
) parser.add_argument("--config_path", type=str, required=True,
parser.add_argument( help="TTS config file path to define audio processin parameters.")
"--config_path", parser.add_argument("--out_path", default=None, type=str,
type=str, help="directory to save the output file.")
required=True,
help="TTS config file path to define audio processin parameters.",
)
parser.add_argument(
"--out_path", default=None, type=str, help="directory to save the output file."
)
args = parser.parse_args() args = parser.parse_args()
# load config # load config
CONFIG = load_config(args.config_path) CONFIG = load_config(args.config_path)
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization CONFIG.audio['signal_norm'] = False # do not apply earlier normalization
CONFIG.audio["stats_path"] = None # discard pre-defined stats CONFIG.audio['stats_path'] = None # discard pre-defined stats
# load audio processor # load audio processor
ap = AudioProcessor(**CONFIG.audio) ap = AudioProcessor(**CONFIG.audio)
@ -65,27 +59,27 @@ def main():
output_file_path = os.path.join(args.out_path, "scale_stats.npy") output_file_path = os.path.join(args.out_path, "scale_stats.npy")
stats = {} stats = {}
stats["mel_mean"] = mel_mean stats['mel_mean'] = mel_mean
stats["mel_std"] = mel_scale stats['mel_std'] = mel_scale
stats["linear_mean"] = linear_mean stats['linear_mean'] = linear_mean
stats["linear_std"] = linear_scale stats['linear_std'] = linear_scale
print(f" > Avg mel spec mean: {mel_mean.mean()}") print(f' > Avg mel spec mean: {mel_mean.mean()}')
print(f" > Avg mel spec scale: {mel_scale.mean()}") print(f' > Avg mel spec scale: {mel_scale.mean()}')
print(f" > Avg linear spec mean: {linear_mean.mean()}") print(f' > Avg linear spec mean: {linear_mean.mean()}')
print(f" > Avg lienar spec scale: {linear_scale.mean()}") print(f' > Avg lienar spec scale: {linear_scale.mean()}')
# set default config values for mean-var scaling # set default config values for mean-var scaling
CONFIG.audio["stats_path"] = output_file_path CONFIG.audio['stats_path'] = output_file_path
CONFIG.audio["signal_norm"] = True CONFIG.audio['signal_norm'] = True
# remove redundant values # remove redundant values
del CONFIG.audio["max_norm"] del CONFIG.audio['max_norm']
del CONFIG.audio["min_level_db"] del CONFIG.audio['min_level_db']
del CONFIG.audio["symmetric_norm"] del CONFIG.audio['symmetric_norm']
del CONFIG.audio["clip_norm"] del CONFIG.audio['clip_norm']
stats["audio_config"] = CONFIG.audio stats['audio_config'] = CONFIG.audio
np.save(output_file_path, stats, allow_pickle=True) np.save(output_file_path, stats, allow_pickle=True)
print(f" > scale_stats.npy is saved to {output_file_path}") print(f' > scale_stats.npy is saved to {output_file_path}')
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -10,29 +10,20 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import ( from TTS.utils.generic_utils import (KeepAverage, count_parameters,
KeepAverage, create_experiment_folder, get_git_branch,
count_parameters, remove_experiment_folder, set_init_dict)
create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict,
)
from TTS.utils.io import copy_config_file, load_config from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
# from distribute import (DistributedSampler, apply_gradient_allreduce, # from distribute import (DistributedSampler, apply_gradient_allreduce,
# init_distributed, reduce_tensor) # init_distributed, reduce_tensor)
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import ( from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
plot_results, setup_generator)
setup_discriminator,
setup_generator,
)
from TTS.vocoder.utils.io import save_best_model, save_checkpoint from TTS.vocoder.utils.io import save_best_model, save_checkpoint
use_cuda, num_gpus = setup_torch_training_env(True, True) use_cuda, num_gpus = setup_torch_training_env(True, True)
@ -42,30 +33,27 @@ def setup_loader(ap, is_val=False, verbose=False):
if is_val and not c.run_eval: if is_val and not c.run_eval:
loader = None loader = None
else: else:
dataset = GANDataset( dataset = GANDataset(ap=ap,
ap=ap, items=eval_data if is_val else train_data,
items=eval_data if is_val else train_data, seq_len=c.seq_len,
seq_len=c.seq_len, hop_len=ap.hop_length,
hop_len=ap.hop_length, pad_short=c.pad_short,
pad_short=c.pad_short, conv_pad=c.conv_pad,
conv_pad=c.conv_pad, is_training=not is_val,
is_training=not is_val, return_segments=not is_val,
return_segments=not is_val, use_noise_augment=c.use_noise_augment,
use_noise_augment=c.use_noise_augment, use_cache=c.use_cache,
use_cache=c.use_cache, verbose=verbose)
verbose=verbose,
)
dataset.shuffle_mapping() dataset.shuffle_mapping()
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader( loader = DataLoader(dataset,
dataset, batch_size=1 if is_val else c.batch_size,
batch_size=1 if is_val else c.batch_size, shuffle=True,
shuffle=True, drop_last=False,
drop_last=False, sampler=None,
sampler=None, num_workers=c.num_val_loader_workers
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, if is_val else c.num_loader_workers,
pin_memory=False, pin_memory=False)
)
return loader return loader
@ -92,26 +80,16 @@ def format_data(data):
return co, x, None, None return co, x, None, None
def train( def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
model_G, scheduler_G, scheduler_D, ap, global_step, epoch):
criterion_G,
optimizer_G,
model_D,
criterion_D,
optimizer_D,
scheduler_G,
scheduler_D,
ap,
global_step,
epoch,
):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model_G.train() model_G.train()
model_D.train() model_D.train()
epoch_time = 0 epoch_time = 0
keep_avg = KeepAverage() keep_avg = KeepAverage()
if use_cuda: if use_cuda:
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
else: else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size) batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time() end_time = time.time()
@ -167,16 +145,16 @@ def train(
scores_fake = D_out_fake scores_fake = D_out_fake
# compute losses # compute losses
loss_G_dict = criterion_G( loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub feats_real, y_hat_sub, y_G_sub)
) loss_G = loss_G_dict['G_loss']
loss_G = loss_G_dict["G_loss"]
# optimizer generator # optimizer generator
optimizer_G.zero_grad() optimizer_G.zero_grad()
loss_G.backward() loss_G.backward()
if c.gen_clip_grad > 0: if c.gen_clip_grad > 0:
torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad) torch.nn.utils.clip_grad_norm_(model_G.parameters(),
c.gen_clip_grad)
optimizer_G.step() optimizer_G.step()
if scheduler_G is not None: if scheduler_G is not None:
scheduler_G.step() scheduler_G.step()
@ -221,13 +199,14 @@ def train(
# compute losses # compute losses
loss_D_dict = criterion_D(scores_fake, scores_real) loss_D_dict = criterion_D(scores_fake, scores_real)
loss_D = loss_D_dict["D_loss"] loss_D = loss_D_dict['D_loss']
# optimizer discriminator # optimizer discriminator
optimizer_D.zero_grad() optimizer_D.zero_grad()
loss_D.backward() loss_D.backward()
if c.disc_clip_grad > 0: if c.disc_clip_grad > 0:
torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad) torch.nn.utils.clip_grad_norm_(model_D.parameters(),
c.disc_clip_grad)
optimizer_D.step() optimizer_D.step()
if scheduler_D is not None: if scheduler_D is not None:
scheduler_D.step() scheduler_D.step()
@ -242,40 +221,34 @@ def train(
epoch_time += step_time epoch_time += step_time
# get current learning rates # get current learning rates
current_lr_G = list(optimizer_G.param_groups)[0]["lr"] current_lr_G = list(optimizer_G.param_groups)[0]['lr']
current_lr_D = list(optimizer_D.param_groups)[0]["lr"] current_lr_D = list(optimizer_D.param_groups)[0]['lr']
# update avg stats # update avg stats
update_train_values = dict() update_train_values = dict()
for key, value in loss_dict.items(): for key, value in loss_dict.items():
update_train_values["avg_" + key] = value update_train_values['avg_' + key] = value
update_train_values["avg_loader_time"] = loader_time update_train_values['avg_loader_time'] = loader_time
update_train_values["avg_step_time"] = step_time update_train_values['avg_step_time'] = step_time
keep_avg.update_values(update_train_values) keep_avg.update_values(update_train_values)
# print training stats # print training stats
if global_step % c.print_step == 0: if global_step % c.print_step == 0:
log_dict = { log_dict = {
"step_time": [step_time, 2], 'step_time': [step_time, 2],
"loader_time": [loader_time, 4], 'loader_time': [loader_time, 4],
"current_lr_G": current_lr_G, "current_lr_G": current_lr_G,
"current_lr_D": current_lr_D, "current_lr_D": current_lr_D
} }
c_logger.print_train_step( c_logger.print_train_step(batch_n_iter, num_iter, global_step,
batch_n_iter, log_dict, loss_dict, keep_avg.avg_values)
num_iter,
global_step,
log_dict,
loss_dict,
keep_avg.avg_values,
)
# plot step stats # plot step stats
if global_step % 10 == 0: if global_step % 10 == 0:
iter_stats = { iter_stats = {
"lr_G": current_lr_G, "lr_G": current_lr_G,
"lr_D": current_lr_D, "lr_D": current_lr_D,
"step_time": step_time, "step_time": step_time
} }
iter_stats.update(loss_dict) iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats) tb_logger.tb_train_iter_stats(global_step, iter_stats)
@ -284,28 +257,27 @@ def train(
if global_step % c.save_step == 0: if global_step % c.save_step == 0:
if c.checkpoint: if c.checkpoint:
# save model # save model
save_checkpoint( save_checkpoint(model_G,
model_G, optimizer_G,
optimizer_G, scheduler_G,
scheduler_G, model_D,
model_D, optimizer_D,
optimizer_D, scheduler_D,
scheduler_D, global_step,
global_step, epoch,
epoch, OUT_PATH,
OUT_PATH, model_losses=loss_dict)
model_losses=loss_dict,
)
# compute spectrograms # compute spectrograms
figures = plot_results(y_hat_vis, y_G, ap, global_step, "train") figures = plot_results(y_hat_vis, y_G, ap, global_step,
'train')
tb_logger.tb_train_figures(global_step, figures) tb_logger.tb_train_figures(global_step, figures)
# Sample audio # Sample audio
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy() sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios( tb_logger.tb_train_audios(global_step,
global_step, {"train/audio": sample_voice}, c.audio["sample_rate"] {'train/audio': sample_voice},
) c.audio["sample_rate"])
end_time = time.time() end_time = time.time()
# print epoch stats # print epoch stats
@ -379,9 +351,8 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
feats_fake, feats_real = None, None feats_fake, feats_real = None, None
# compute losses # compute losses
loss_G_dict = criterion_G( loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub feats_real, y_hat_sub, y_G_sub)
)
loss_dict = dict() loss_dict = dict()
for key, value in loss_G_dict.items(): for key, value in loss_G_dict.items():
@ -437,9 +408,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
# update avg stats # update avg stats
update_eval_values = dict() update_eval_values = dict()
for key, value in loss_dict.items(): for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value update_eval_values['avg_' + key] = value
update_eval_values["avg_loader_time"] = loader_time update_eval_values['avg_loader_time'] = loader_time
update_eval_values["avg_step_time"] = step_time update_eval_values['avg_step_time'] = step_time
keep_avg.update_values(update_eval_values) keep_avg.update_values(update_eval_values)
# print eval stats # print eval stats
@ -447,14 +418,13 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
# compute spectrograms # compute spectrograms
figures = plot_results(y_hat, y_G, ap, global_step, "eval") figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
tb_logger.tb_eval_figures(global_step, figures) tb_logger.tb_eval_figures(global_step, figures)
# Sample audio # Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios( tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"] c.audio["sample_rate"])
)
# synthesize a full voice # synthesize a full voice
data_loader.return_segments = False data_loader.return_segments = False
@ -472,8 +442,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.feature_path is not None: if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}") print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data( eval_data, train_data = load_wav_feat_data(
c.data_path, c.feature_path, c.eval_split_size c.data_path, c.feature_path, c.eval_split_size)
)
else: else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@ -491,63 +460,68 @@ def main(args): # pylint: disable=redefined-outer-name
# setup optimizers # setup optimizers
optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0) optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0) optimizer_disc = RAdam(model_disc.parameters(),
lr=c.lr_disc,
weight_decay=0)
# schedulers # schedulers
scheduler_gen = None scheduler_gen = None
scheduler_disc = None scheduler_disc = None
if "lr_scheduler_gen" in c: if 'lr_scheduler_gen' in c:
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) scheduler_gen = scheduler_gen(
if "lr_scheduler_disc" in c: optimizer_gen, **c.lr_scheduler_gen_params)
if 'lr_scheduler_disc' in c:
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) scheduler_disc = scheduler_disc(
optimizer_disc, **c.lr_scheduler_disc_params)
# setup criterion # setup criterion
criterion_gen = GeneratorLoss(c) criterion_gen = GeneratorLoss(c)
criterion_disc = DiscriminatorLoss(c) criterion_disc = DiscriminatorLoss(c)
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location="cpu") checkpoint = torch.load(args.restore_path, map_location='cpu')
try: try:
print(" > Restoring Generator Model...") print(" > Restoring Generator Model...")
model_gen.load_state_dict(checkpoint["model"]) model_gen.load_state_dict(checkpoint['model'])
print(" > Restoring Generator Optimizer...") print(" > Restoring Generator Optimizer...")
optimizer_gen.load_state_dict(checkpoint["optimizer"]) optimizer_gen.load_state_dict(checkpoint['optimizer'])
print(" > Restoring Discriminator Model...") print(" > Restoring Discriminator Model...")
model_disc.load_state_dict(checkpoint["model_disc"]) model_disc.load_state_dict(checkpoint['model_disc'])
print(" > Restoring Discriminator Optimizer...") print(" > Restoring Discriminator Optimizer...")
optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
if "scheduler" in checkpoint: if 'scheduler' in checkpoint:
print(" > Restoring Generator LR Scheduler...") print(" > Restoring Generator LR Scheduler...")
scheduler_gen.load_state_dict(checkpoint["scheduler"]) scheduler_gen.load_state_dict(checkpoint['scheduler'])
# NOTE: Not sure if necessary # NOTE: Not sure if necessary
scheduler_gen.optimizer = optimizer_gen scheduler_gen.optimizer = optimizer_gen
if "scheduler_disc" in checkpoint: if 'scheduler_disc' in checkpoint:
print(" > Restoring Discriminator LR Scheduler...") print(" > Restoring Discriminator LR Scheduler...")
scheduler_disc.load_state_dict(checkpoint["scheduler_disc"]) scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
scheduler_disc.optimizer = optimizer_disc scheduler_disc.optimizer = optimizer_disc
except RuntimeError: except RuntimeError:
# retore only matching layers. # retore only matching layers.
print(" > Partial model initialization...") print(" > Partial model initialization...")
model_dict = model_gen.state_dict() model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], c) model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_gen.load_state_dict(model_dict) model_gen.load_state_dict(model_dict)
model_dict = model_disc.state_dict() model_dict = model_disc.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c) model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
model_disc.load_state_dict(model_dict) model_disc.load_state_dict(model_dict)
del model_dict del model_dict
# reset lr if not countinuining training. # reset lr if not countinuining training.
for group in optimizer_gen.param_groups: for group in optimizer_gen.param_groups:
group["lr"] = c.lr_gen group['lr'] = c.lr_gen
for group in optimizer_disc.param_groups: for group in optimizer_disc.param_groups:
group["lr"] = c.lr_disc group['lr'] = c.lr_disc
print(" > Model restored from step %d" % checkpoint["step"], flush=True) print(" > Model restored from step %d" % checkpoint['step'],
args.restore_step = checkpoint["step"] flush=True)
args.restore_step = checkpoint['step']
else: else:
args.restore_step = 0 args.restore_step = 0
@ -566,92 +540,74 @@ def main(args): # pylint: disable=redefined-outer-name
num_params = count_parameters(model_disc) num_params = count_parameters(model_disc)
print(" > Discriminator has {} parameters".format(num_params), flush=True) print(" > Discriminator has {} parameters".format(num_params), flush=True)
if "best_loss" not in locals(): if 'best_loss' not in locals():
best_loss = float("inf") best_loss = float('inf')
global_step = args.restore_step global_step = args.restore_step
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs) c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train( _, global_step = train(model_gen, criterion_gen, optimizer_gen,
model_gen, model_disc, criterion_disc, optimizer_disc,
criterion_gen, scheduler_gen, scheduler_disc, ap, global_step,
optimizer_gen, epoch)
model_disc, eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
criterion_disc, global_step, epoch)
optimizer_disc,
scheduler_gen,
scheduler_disc,
ap,
global_step,
epoch,
)
eval_avg_loss_dict = evaluate(
model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch
)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss] target_loss = eval_avg_loss_dict[c.target_loss]
best_loss = save_best_model( best_loss = save_best_model(target_loss,
target_loss, best_loss,
best_loss, model_gen,
model_gen, optimizer_gen,
optimizer_gen, scheduler_gen,
scheduler_gen, model_disc,
model_disc, optimizer_disc,
optimizer_disc, scheduler_disc,
scheduler_disc, global_step,
global_step, epoch,
epoch, OUT_PATH,
OUT_PATH, model_losses=eval_avg_loss_dict)
model_losses=eval_avg_loss_dict,
)
if __name__ == "__main__": if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--continue_path", '--continue_path',
type=str, type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default="", default='',
required="--config_path" not in sys.argv, required='--config_path' not in sys.argv)
)
parser.add_argument( parser.add_argument(
"--restore_path", '--restore_path',
type=str, type=str,
help="Model file to be restored. Use to finetune a model.", help='Model file to be restored. Use to finetune a model.',
default="", default='')
) parser.add_argument('--config_path',
parser.add_argument( type=str,
"--config_path", help='Path to config file for training.',
type=str, required='--continue_path' not in sys.argv)
help="Path to config file for training.", parser.add_argument('--debug',
required="--continue_path" not in sys.argv, type=bool,
) default=False,
parser.add_argument( help='Do not verify commit integrity to run training.')
"--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.",
)
# DISTRUBUTED # DISTRUBUTED
parser.add_argument( parser.add_argument(
"--rank", '--rank',
type=int, type=int,
default=0, default=0,
help="DISTRIBUTED: process rank for distributed training.", help='DISTRIBUTED: process rank for distributed training.')
) parser.add_argument('--group_id',
parser.add_argument( type=str,
"--group_id", type=str, default="", help="DISTRIBUTED: process group id." default="",
) help='DISTRIBUTED: process group id.')
args = parser.parse_args() args = parser.parse_args()
if args.continue_path != "": if args.continue_path != '':
args.output_path = args.continue_path args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, "config.json") args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob( list_of_files = glob.glob(
args.continue_path + "/*.pth.tar" args.continue_path +
) # * means all if need specific format then *.csv "/*.pth.tar") # * means all if need specific format then *.csv
latest_model_file = max(list_of_files, key=os.path.getctime) latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}") print(f" > Training continues for {args.restore_path}")
@ -662,10 +618,11 @@ if __name__ == "__main__":
_ = os.path.dirname(os.path.realpath(__file__)) _ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = args.continue_path OUT_PATH = args.continue_path
if args.continue_path == "": if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios") AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
@ -675,17 +632,16 @@ if __name__ == "__main__":
if args.restore_path: if args.restore_path:
new_fields["restore_path"] = args.restore_path new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch() new_fields["github_branch"] = get_git_branch()
copy_config_file( copy_config_file(args.config_path,
args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields os.path.join(OUT_PATH, 'config.json'), new_fields)
)
os.chmod(AUDIO_PATH, 0o775) os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775) os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER") tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-description", c["run_description"], 0) tb_logger.tb_add_text('model-description', c['run_description'], 0)
try: try:
main(args) main(args)
@ -698,4 +654,4 @@ if __name__ == "__main__":
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH) remove_experiment_folder(OUT_PATH)
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)

View File

@ -365,28 +365,6 @@ class WaveRNN(nn.Module):
(i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio), (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
) )
@staticmethod
def get_gru_cell(gru):
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh.data = gru.weight_hh_l0.data
gru_cell.weight_ih.data = gru.weight_ih_l0.data
gru_cell.bias_hh.data = gru.bias_hh_l0.data
gru_cell.bias_ih.data = gru.bias_ih_l0.data
return gru_cell
@staticmethod
def pad_tensor(x, pad, side="both"):
# NB - this is just a quick method i need right now
# i.e., it won't generalise to other shapes/dims
b, t, c = x.size()
total = t + 2 * pad if side == "both" else t + pad
padded = torch.zeros(b, total, c).cuda()
if side in ("before", "both"):
padded[:, pad : pad + t, :] = x
elif side == "after":
padded[:, :t, :] = x
return padded
def fold_with_overlap(self, x, target, overlap): def fold_with_overlap(self, x, target, overlap):
"""Fold the tensor with overlap for quick batched inference. """Fold the tensor with overlap for quick batched inference.
@ -430,7 +408,30 @@ class WaveRNN(nn.Module):
return folded return folded
def xfade_and_unfold(self, y, target, overlap): @staticmethod
def get_gru_cell(gru):
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
gru_cell.weight_hh.data = gru.weight_hh_l0.data
gru_cell.weight_ih.data = gru.weight_ih_l0.data
gru_cell.bias_hh.data = gru.bias_hh_l0.data
gru_cell.bias_ih.data = gru.bias_ih_l0.data
return gru_cell
@staticmethod
def pad_tensor(x, pad, side="both"):
# NB - this is just a quick method i need right now
# i.e., it won't generalise to other shapes/dims
b, t, c = x.size()
total = t + 2 * pad if side == "both" else t + pad
padded = torch.zeros(b, total, c).cuda()
if side in ("before", "both"):
padded[:, pad : pad + t, :] = x
elif side == "after":
padded[:, :t, :] = x
return padded
@staticmethod
def xfade_and_unfold(y, target, overlap):
"""Applies a crossfade and unfolds into a 1d array. """Applies a crossfade and unfolds into a 1d array.
Args: Args:

View File

@ -28,7 +28,8 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
torch.exp(log_std), torch.exp(log_std),
) )
sample = dist.sample() sample = dist.sample()
sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) sample = torch.clamp(torch.clamp(
sample, min=-scale_factor), max=scale_factor)
del dist del dist
return sample return sample
@ -58,8 +59,9 @@ def discretized_mix_logistic_loss(
# unpack parameters. (B, T, num_mixtures) x 3 # unpack parameters. (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix] logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix : 2 * nr_mix] means = y_hat[:, :, nr_mix: 2 * nr_mix]
log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) log_scales = torch.clamp(
y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min)
# B x T x 1 -> B x T x num_mixtures # B x T x 1 -> B x T x num_mixtures
y = y.expand_as(means) y = y.expand_as(means)
@ -104,7 +106,8 @@ def discretized_mix_logistic_loss(
) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
inner_cond = (y > 0.999).float() inner_cond = (y > 0.999).float()
inner_out = ( inner_out = (
inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out inner_cond * log_one_minus_cdf_min +
(1.0 - inner_cond) * inner_inner_out
) )
cond = (y < -0.999).float() cond = (y < -0.999).float()
log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
@ -142,9 +145,9 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
# (B, T) -> (B, T, nr_mix) # (B, T) -> (B, T, nr_mix)
one_hot = to_one_hot(argmax, nr_mix) one_hot = to_one_hot(argmax, nr_mix)
# select logistic parameters # select logistic parameters
means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1)
log_scales = torch.clamp( log_scales = torch.clamp(
torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
) )
# sample from logistic & clip to interval # sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling # we don't actually round to the nearest 8bit value when sampling

View File

@ -39,7 +39,7 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
def setup_wavernn(c): def setup_wavernn(c):
@ -67,101 +67,92 @@ def setup_wavernn(c):
def setup_generator(c): def setup_generator(c):
print(" > Generator Model: {}".format(c.generator_model)) print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) MyModel = importlib.import_module('TTS.vocoder.models.' +
c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model)) MyModel = getattr(MyModel, to_camel(c.generator_model))
if c.generator_model in "melgan_generator": if c.generator_model in 'melgan_generator':
model = MyModel( model = MyModel(
in_channels=c.audio["num_mels"], in_channels=c.audio['num_mels'],
out_channels=1, out_channels=1,
proj_kernel=7, proj_kernel=7,
base_channels=512, base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"], upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3, res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"], num_res_blocks=c.generator_model_params['num_res_blocks'])
) if c.generator_model in 'melgan_fb_generator':
if c.generator_model in "melgan_fb_generator":
pass pass
if c.generator_model in "multiband_melgan_generator": if c.generator_model in 'multiband_melgan_generator':
model = MyModel( model = MyModel(
in_channels=c.audio["num_mels"], in_channels=c.audio['num_mels'],
out_channels=4, out_channels=4,
proj_kernel=7, proj_kernel=7,
base_channels=384, base_channels=384,
upsample_factors=c.generator_model_params["upsample_factors"], upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3, res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"], num_res_blocks=c.generator_model_params['num_res_blocks'])
) if c.generator_model in 'fullband_melgan_generator':
if c.generator_model in "fullband_melgan_generator":
model = MyModel( model = MyModel(
in_channels=c.audio["num_mels"], in_channels=c.audio['num_mels'],
out_channels=1, out_channels=1,
proj_kernel=7, proj_kernel=7,
base_channels=512, base_channels=512,
upsample_factors=c.generator_model_params["upsample_factors"], upsample_factors=c.generator_model_params['upsample_factors'],
res_kernel=3, res_kernel=3,
num_res_blocks=c.generator_model_params["num_res_blocks"], num_res_blocks=c.generator_model_params['num_res_blocks'])
) if c.generator_model in 'parallel_wavegan_generator':
if c.generator_model in "parallel_wavegan_generator":
model = MyModel( model = MyModel(
in_channels=1, in_channels=1,
out_channels=1, out_channels=1,
kernel_size=3, kernel_size=3,
num_res_blocks=c.generator_model_params["num_res_blocks"], num_res_blocks=c.generator_model_params['num_res_blocks'],
stacks=c.generator_model_params["stacks"], stacks=c.generator_model_params['stacks'],
res_channels=64, res_channels=64,
gate_channels=128, gate_channels=128,
skip_channels=64, skip_channels=64,
aux_channels=c.audio["num_mels"], aux_channels=c.audio['num_mels'],
dropout=0.0, dropout=0.0,
bias=True, bias=True,
use_weight_norm=True, use_weight_norm=True,
upsample_factors=c.generator_model_params["upsample_factors"], upsample_factors=c.generator_model_params['upsample_factors'])
)
return model return model
def setup_discriminator(c): def setup_discriminator(c):
print(" > Discriminator Model: {}".format(c.discriminator_model)) print(" > Discriminator Model: {}".format(c.discriminator_model))
if "parallel_wavegan" in c.discriminator_model: if 'parallel_wavegan' in c.discriminator_model:
MyModel = importlib.import_module( MyModel = importlib.import_module(
"TTS.vocoder.models.parallel_wavegan_discriminator" 'TTS.vocoder.models.parallel_wavegan_discriminator')
)
else: else:
MyModel = importlib.import_module( MyModel = importlib.import_module('TTS.vocoder.models.' +
"TTS.vocoder.models." + c.discriminator_model.lower() c.discriminator_model.lower())
)
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in "random_window_discriminator": if c.discriminator_model in 'random_window_discriminator':
model = MyModel( model = MyModel(
cond_channels=c.audio["num_mels"], cond_channels=c.audio['num_mels'],
hop_length=c.audio["hop_length"], hop_length=c.audio['hop_length'],
uncond_disc_donwsample_factors=c.discriminator_model_params[ uncond_disc_donwsample_factors=c.
"uncond_disc_donwsample_factors" discriminator_model_params['uncond_disc_donwsample_factors'],
], cond_disc_downsample_factors=c.
cond_disc_downsample_factors=c.discriminator_model_params[ discriminator_model_params['cond_disc_downsample_factors'],
"cond_disc_downsample_factors" cond_disc_out_channels=c.
], discriminator_model_params['cond_disc_out_channels'],
cond_disc_out_channels=c.discriminator_model_params[ window_sizes=c.discriminator_model_params['window_sizes'])
"cond_disc_out_channels" if c.discriminator_model in 'melgan_multiscale_discriminator':
],
window_sizes=c.discriminator_model_params["window_sizes"],
)
if c.discriminator_model in "melgan_multiscale_discriminator":
model = MyModel( model = MyModel(
in_channels=1, in_channels=1,
out_channels=1, out_channels=1,
kernel_sizes=(5, 3), kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params["base_channels"], base_channels=c.discriminator_model_params['base_channels'],
max_channels=c.discriminator_model_params["max_channels"], max_channels=c.discriminator_model_params['max_channels'],
downsample_factors=c.discriminator_model_params["downsample_factors"], downsample_factors=c.
) discriminator_model_params['downsample_factors'])
if c.discriminator_model == "residual_parallel_wavegan_discriminator": if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
model = MyModel( model = MyModel(
in_channels=1, in_channels=1,
out_channels=1, out_channels=1,
kernel_size=3, kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"], num_layers=c.discriminator_model_params['num_layers'],
stacks=c.discriminator_model_params["stacks"], stacks=c.discriminator_model_params['stacks'],
res_channels=64, res_channels=64,
gate_channels=128, gate_channels=128,
skip_channels=64, skip_channels=64,
@ -170,17 +161,17 @@ def setup_discriminator(c):
nonlinear_activation="LeakyReLU", nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2}, nonlinear_activation_params={"negative_slope": 0.2},
) )
if c.discriminator_model == "parallel_wavegan_discriminator": if c.discriminator_model == 'parallel_wavegan_discriminator':
model = MyModel( model = MyModel(
in_channels=1, in_channels=1,
out_channels=1, out_channels=1,
kernel_size=3, kernel_size=3,
num_layers=c.discriminator_model_params["num_layers"], num_layers=c.discriminator_model_params['num_layers'],
conv_channels=64, conv_channels=64,
dilation_factor=1, dilation_factor=1,
nonlinear_activation="LeakyReLU", nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2}, nonlinear_activation_params={"negative_slope": 0.2},
bias=True, bias=True
) )
return model return model