mirror of https://github.com/coqui-ai/TTS.git
fix formatting + pylint
parent 64adfbf4a5
commit 24d18d20e3
@@ -15,23 +15,17 @@ from TTS.utils.audio import AudioProcessor
 def main():
     """Run preprocessing process."""
     parser = argparse.ArgumentParser(
-        description="Compute mean and variance of spectrogtram features."
-    )
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        required=True,
-        help="TTS config file path to define audio processin parameters.",
-    )
-    parser.add_argument(
-        "--out_path", default=None, type=str, help="directory to save the output file."
-    )
+        description="Compute mean and variance of spectrogtram features.")
+    parser.add_argument("--config_path", type=str, required=True,
+                        help="TTS config file path to define audio processin parameters.")
+    parser.add_argument("--out_path", default=None, type=str,
+                        help="directory to save the output file.")
     args = parser.parse_args()

     # load config
     CONFIG = load_config(args.config_path)
-    CONFIG.audio["signal_norm"] = False  # do not apply earlier normalization
-    CONFIG.audio["stats_path"] = None  # discard pre-defined stats
+    CONFIG.audio['signal_norm'] = False  # do not apply earlier normalization
+    CONFIG.audio['stats_path'] = None  # discard pre-defined stats

     # load audio processor
     ap = AudioProcessor(**CONFIG.audio)
@@ -65,27 +59,27 @@ def main():

     output_file_path = os.path.join(args.out_path, "scale_stats.npy")
     stats = {}
-    stats["mel_mean"] = mel_mean
-    stats["mel_std"] = mel_scale
-    stats["linear_mean"] = linear_mean
-    stats["linear_std"] = linear_scale
+    stats['mel_mean'] = mel_mean
+    stats['mel_std'] = mel_scale
+    stats['linear_mean'] = linear_mean
+    stats['linear_std'] = linear_scale

-    print(f" > Avg mel spec mean: {mel_mean.mean()}")
-    print(f" > Avg mel spec scale: {mel_scale.mean()}")
-    print(f" > Avg linear spec mean: {linear_mean.mean()}")
-    print(f" > Avg lienar spec scale: {linear_scale.mean()}")
+    print(f' > Avg mel spec mean: {mel_mean.mean()}')
+    print(f' > Avg mel spec scale: {mel_scale.mean()}')
+    print(f' > Avg linear spec mean: {linear_mean.mean()}')
+    print(f' > Avg lienar spec scale: {linear_scale.mean()}')

     # set default config values for mean-var scaling
-    CONFIG.audio["stats_path"] = output_file_path
-    CONFIG.audio["signal_norm"] = True
+    CONFIG.audio['stats_path'] = output_file_path
+    CONFIG.audio['signal_norm'] = True
     # remove redundant values
-    del CONFIG.audio["max_norm"]
-    del CONFIG.audio["min_level_db"]
-    del CONFIG.audio["symmetric_norm"]
-    del CONFIG.audio["clip_norm"]
-    stats["audio_config"] = CONFIG.audio
+    del CONFIG.audio['max_norm']
+    del CONFIG.audio['min_level_db']
+    del CONFIG.audio['symmetric_norm']
+    del CONFIG.audio['clip_norm']
+    stats['audio_config'] = CONFIG.audio
     np.save(output_file_path, stats, allow_pickle=True)
-    print(f" > scale_stats.npy is saved to {output_file_path}")
+    print(f' > scale_stats.npy is saved to {output_file_path}')


 if __name__ == "__main__":
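Usage note (a minimal sketch, not from this commit): the stats dictionary written above is typically loaded back with numpy and applied as per-feature mean-variance scaling. The key names below come from the script; the array shapes (one value per mel bin) are an assumption.

    import numpy as np

    # scale_stats.npy holds a pickled dict, so allow_pickle + .item() are needed
    stats = np.load("scale_stats.npy", allow_pickle=True).item()
    mel_mean, mel_std = stats["mel_mean"], stats["mel_std"]

    def normalize_mel(mel):
        # mean-variance scaling per mel bin; mel assumed shaped (num_mels, num_frames)
        return (mel - mel_mean[:, None]) / mel_std[:, None]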
@@ -10,29 +10,20 @@ import torch
 from torch.utils.data import DataLoader
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.console_logger import ConsoleLogger
-from TTS.utils.generic_utils import (
-    KeepAverage,
-    count_parameters,
-    create_experiment_folder,
-    get_git_branch,
-    remove_experiment_folder,
-    set_init_dict,
-)
+from TTS.utils.generic_utils import (KeepAverage, count_parameters,
+                                     create_experiment_folder, get_git_branch,
+                                     remove_experiment_folder, set_init_dict)
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.training import setup_torch_training_env
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-
 # from distribute import (DistributedSampler, apply_gradient_allreduce,
 #                         init_distributed, reduce_tensor)
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
-from TTS.vocoder.utils.generic_utils import (
-    plot_results,
-    setup_discriminator,
-    setup_generator,
-)
+from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
+                                             setup_generator)
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint

 use_cuda, num_gpus = setup_torch_training_env(True, True)
@@ -42,30 +33,27 @@ def setup_loader(ap, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
-        dataset = GANDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=c.seq_len,
-            hop_len=ap.hop_length,
-            pad_short=c.pad_short,
-            conv_pad=c.conv_pad,
-            is_training=not is_val,
-            return_segments=not is_val,
-            use_noise_augment=c.use_noise_augment,
-            use_cache=c.use_cache,
-            verbose=verbose,
-        )
+        dataset = GANDataset(ap=ap,
+                             items=eval_data if is_val else train_data,
+                             seq_len=c.seq_len,
+                             hop_len=ap.hop_length,
+                             pad_short=c.pad_short,
+                             conv_pad=c.conv_pad,
+                             is_training=not is_val,
+                             return_segments=not is_val,
+                             use_noise_augment=c.use_noise_augment,
+                             use_cache=c.use_cache,
+                             verbose=verbose)
         dataset.shuffle_mapping()
         # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            batch_size=1 if is_val else c.batch_size,
-            shuffle=True,
-            drop_last=False,
-            sampler=None,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
-            pin_memory=False,
-        )
+        loader = DataLoader(dataset,
+                            batch_size=1 if is_val else c.batch_size,
+                            shuffle=True,
+                            drop_last=False,
+                            sampler=None,
+                            num_workers=c.num_val_loader_workers
+                            if is_val else c.num_loader_workers,
+                            pin_memory=False)
     return loader


@@ -92,26 +80,16 @@ def format_data(data):
     return co, x, None, None


-def train(
-    model_G,
-    criterion_G,
-    optimizer_G,
-    model_D,
-    criterion_D,
-    optimizer_D,
-    scheduler_G,
-    scheduler_D,
-    ap,
-    global_step,
-    epoch,
-):
+def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
+          scheduler_G, scheduler_D, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model_G.train()
     model_D.train()
     epoch_time = 0
     keep_avg = KeepAverage()
     if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+        batch_n_iter = int(
+            len(data_loader.dataset) / (c.batch_size * num_gpus))
     else:
         batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
@@ -167,16 +145,16 @@ def train(
             scores_fake = D_out_fake

         # compute losses
-        loss_G_dict = criterion_G(
-            y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub
-        )
-        loss_G = loss_G_dict["G_loss"]
+        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
+                                  feats_real, y_hat_sub, y_G_sub)
+        loss_G = loss_G_dict['G_loss']

         # optimizer generator
         optimizer_G.zero_grad()
         loss_G.backward()
         if c.gen_clip_grad > 0:
-            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
+            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
+                                           c.gen_clip_grad)
         optimizer_G.step()
         if scheduler_G is not None:
             scheduler_G.step()
@@ -221,13 +199,14 @@ def train(

             # compute losses
             loss_D_dict = criterion_D(scores_fake, scores_real)
-            loss_D = loss_D_dict["D_loss"]
+            loss_D = loss_D_dict['D_loss']

             # optimizer discriminator
             optimizer_D.zero_grad()
             loss_D.backward()
             if c.disc_clip_grad > 0:
-                torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
+                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
+                                               c.disc_clip_grad)
             optimizer_D.step()
             if scheduler_D is not None:
                 scheduler_D.step()
@@ -242,40 +221,34 @@ def train(
         epoch_time += step_time

         # get current learning rates
-        current_lr_G = list(optimizer_G.param_groups)[0]["lr"]
-        current_lr_D = list(optimizer_D.param_groups)[0]["lr"]
+        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
+        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

         # update avg stats
         update_train_values = dict()
         for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
+            update_train_values['avg_' + key] = value
+        update_train_values['avg_loader_time'] = loader_time
+        update_train_values['avg_step_time'] = step_time
         keep_avg.update_values(update_train_values)

         # print training stats
         if global_step % c.print_step == 0:
             log_dict = {
-                "step_time": [step_time, 2],
-                "loader_time": [loader_time, 4],
+                'step_time': [step_time, 2],
+                'loader_time': [loader_time, 4],
                 "current_lr_G": current_lr_G,
-                "current_lr_D": current_lr_D,
+                "current_lr_D": current_lr_D
             }
-            c_logger.print_train_step(
-                batch_n_iter,
-                num_iter,
-                global_step,
-                log_dict,
-                loss_dict,
-                keep_avg.avg_values,
-            )
+            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
+                                      log_dict, loss_dict, keep_avg.avg_values)

         # plot step stats
         if global_step % 10 == 0:
             iter_stats = {
                 "lr_G": current_lr_G,
                 "lr_D": current_lr_D,
-                "step_time": step_time,
+                "step_time": step_time
             }
             iter_stats.update(loss_dict)
             tb_logger.tb_train_iter_stats(global_step, iter_stats)
@@ -284,28 +257,27 @@ def train(
         if global_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(
-                    model_G,
-                    optimizer_G,
-                    scheduler_G,
-                    model_D,
-                    optimizer_D,
-                    scheduler_D,
-                    global_step,
-                    epoch,
-                    OUT_PATH,
-                    model_losses=loss_dict,
-                )
+                save_checkpoint(model_G,
+                                optimizer_G,
+                                scheduler_G,
+                                model_D,
+                                optimizer_D,
+                                scheduler_D,
+                                global_step,
+                                epoch,
+                                OUT_PATH,
+                                model_losses=loss_dict)

             # compute spectrograms
-            figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
+            figures = plot_results(y_hat_vis, y_G, ap, global_step,
+                                   'train')
             tb_logger.tb_train_figures(global_step, figures)

             # Sample audio
             sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
-            tb_logger.tb_train_audios(
-                global_step, {"train/audio": sample_voice}, c.audio["sample_rate"]
-            )
+            tb_logger.tb_train_audios(global_step,
+                                      {'train/audio': sample_voice},
+                                      c.audio["sample_rate"])
         end_time = time.time()

         # print epoch stats
@@ -379,9 +351,8 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
                 feats_fake, feats_real = None, None

         # compute losses
-        loss_G_dict = criterion_G(
-            y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub
-        )
+        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
+                                  feats_real, y_hat_sub, y_G_sub)

         loss_dict = dict()
         for key, value in loss_G_dict.items():
@@ -437,9 +408,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
         # update avg stats
         update_eval_values = dict()
         for key, value in loss_dict.items():
-            update_eval_values["avg_" + key] = value
-        update_eval_values["avg_loader_time"] = loader_time
-        update_eval_values["avg_step_time"] = step_time
+            update_eval_values['avg_' + key] = value
+        update_eval_values['avg_loader_time'] = loader_time
+        update_eval_values['avg_step_time'] = step_time
         keep_avg.update_values(update_eval_values)

         # print eval stats
@@ -447,14 +418,13 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
             c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    # compute spectrograms
-    figures = plot_results(y_hat, y_G, ap, global_step, "eval")
+    figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
     tb_logger.tb_eval_figures(global_step, figures)

     # Sample audio
     sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
-    tb_logger.tb_eval_audios(
-        global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]
-    )
+    tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
+                             c.audio["sample_rate"])

     # synthesize a full voice
     data_loader.return_segments = False
@@ -472,8 +442,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     if c.feature_path is not None:
         print(f" > Loading features from: {c.feature_path}")
         eval_data, train_data = load_wav_feat_data(
-            c.data_path, c.feature_path, c.eval_split_size
-        )
+            c.data_path, c.feature_path, c.eval_split_size)
     else:
         eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

@@ -491,63 +460,68 @@ def main(args):  # pylint: disable=redefined-outer-name

     # setup optimizers
     optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
-    optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0)
+    optimizer_disc = RAdam(model_disc.parameters(),
+                           lr=c.lr_disc,
+                           weight_decay=0)

     # schedulers
     scheduler_gen = None
     scheduler_disc = None
-    if "lr_scheduler_gen" in c:
+    if 'lr_scheduler_gen' in c:
         scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
-        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
-    if "lr_scheduler_disc" in c:
+        scheduler_gen = scheduler_gen(
+            optimizer_gen, **c.lr_scheduler_gen_params)
+    if 'lr_scheduler_disc' in c:
         scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
-        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
+        scheduler_disc = scheduler_disc(
+            optimizer_disc, **c.lr_scheduler_disc_params)

     # setup criterion
     criterion_gen = GeneratorLoss(c)
     criterion_disc = DiscriminatorLoss(c)

     if args.restore_path:
-        checkpoint = torch.load(args.restore_path, map_location="cpu")
+        checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
             print(" > Restoring Generator Model...")
-            model_gen.load_state_dict(checkpoint["model"])
+            model_gen.load_state_dict(checkpoint['model'])
             print(" > Restoring Generator Optimizer...")
-            optimizer_gen.load_state_dict(checkpoint["optimizer"])
+            optimizer_gen.load_state_dict(checkpoint['optimizer'])
             print(" > Restoring Discriminator Model...")
-            model_disc.load_state_dict(checkpoint["model_disc"])
+            model_disc.load_state_dict(checkpoint['model_disc'])
             print(" > Restoring Discriminator Optimizer...")
-            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
-            if "scheduler" in checkpoint:
+            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
+            if 'scheduler' in checkpoint:
                 print(" > Restoring Generator LR Scheduler...")
-                scheduler_gen.load_state_dict(checkpoint["scheduler"])
+                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                 # NOTE: Not sure if necessary
                 scheduler_gen.optimizer = optimizer_gen
-            if "scheduler_disc" in checkpoint:
+            if 'scheduler_disc' in checkpoint:
                 print(" > Restoring Discriminator LR Scheduler...")
-                scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
+                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                 scheduler_disc.optimizer = optimizer_disc
         except RuntimeError:
             # retore only matching layers.
             print(" > Partial model initialization...")
             model_dict = model_gen.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
+            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model_gen.load_state_dict(model_dict)

             model_dict = model_disc.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
+            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
             model_disc.load_state_dict(model_dict)
             del model_dict

             # reset lr if not countinuining training.
             for group in optimizer_gen.param_groups:
-                group["lr"] = c.lr_gen
+                group['lr'] = c.lr_gen

             for group in optimizer_disc.param_groups:
-                group["lr"] = c.lr_disc
+                group['lr'] = c.lr_disc

-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
+        print(" > Model restored from step %d" % checkpoint['step'],
+              flush=True)
+        args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0

@@ -566,92 +540,74 @@ def main(args):  # pylint: disable=redefined-outer-name
     num_params = count_parameters(model_disc)
     print(" > Discriminator has {} parameters".format(num_params), flush=True)

-    if "best_loss" not in locals():
-        best_loss = float("inf")
+    if 'best_loss' not in locals():
+        best_loss = float('inf')

     global_step = args.restore_step
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
-        _, global_step = train(
-            model_gen,
-            criterion_gen,
-            optimizer_gen,
-            model_disc,
-            criterion_disc,
-            optimizer_disc,
-            scheduler_gen,
-            scheduler_disc,
-            ap,
-            global_step,
-            epoch,
-        )
-        eval_avg_loss_dict = evaluate(
-            model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch
-        )
+        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
+                               model_disc, criterion_disc, optimizer_disc,
+                               scheduler_gen, scheduler_disc, ap, global_step,
+                               epoch)
+        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
+                                      global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = eval_avg_loss_dict[c.target_loss]
-        best_loss = save_best_model(
-            target_loss,
-            best_loss,
-            model_gen,
-            optimizer_gen,
-            scheduler_gen,
-            model_disc,
-            optimizer_disc,
-            scheduler_disc,
-            global_step,
-            epoch,
-            OUT_PATH,
-            model_losses=eval_avg_loss_dict,
-        )
+        best_loss = save_best_model(target_loss,
+                                    best_loss,
+                                    model_gen,
+                                    optimizer_gen,
+                                    scheduler_gen,
+                                    model_disc,
+                                    optimizer_disc,
+                                    scheduler_disc,
+                                    global_step,
+                                    epoch,
+                                    OUT_PATH,
+                                    model_losses=eval_avg_loss_dict)


-if __name__ == "__main__":
+if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--continue_path",
+        '--continue_path',
         type=str,
         help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
-        default="",
-        required="--config_path" not in sys.argv,
-    )
+        default='',
+        required='--config_path' not in sys.argv)
     parser.add_argument(
-        "--restore_path",
+        '--restore_path',
         type=str,
-        help="Model file to be restored. Use to finetune a model.",
-        default="",
-    )
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        help="Path to config file for training.",
-        required="--continue_path" not in sys.argv,
-    )
-    parser.add_argument(
-        "--debug",
-        type=bool,
-        default=False,
-        help="Do not verify commit integrity to run training.",
-    )
+        help='Model file to be restored. Use to finetune a model.',
+        default='')
+    parser.add_argument('--config_path',
+                        type=str,
+                        help='Path to config file for training.',
+                        required='--continue_path' not in sys.argv)
+    parser.add_argument('--debug',
+                        type=bool,
+                        default=False,
+                        help='Do not verify commit integrity to run training.')

     # DISTRUBUTED
     parser.add_argument(
-        "--rank",
+        '--rank',
         type=int,
         default=0,
-        help="DISTRIBUTED: process rank for distributed training.",
-    )
-    parser.add_argument(
-        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
-    )
+        help='DISTRIBUTED: process rank for distributed training.')
+    parser.add_argument('--group_id',
+                        type=str,
+                        default="",
+                        help='DISTRIBUTED: process group id.')
     args = parser.parse_args()

-    if args.continue_path != "":
+    if args.continue_path != '':
         args.output_path = args.continue_path
-        args.config_path = os.path.join(args.continue_path, "config.json")
+        args.config_path = os.path.join(args.continue_path, 'config.json')
         list_of_files = glob.glob(
-            args.continue_path + "/*.pth.tar"
-        )  # * means all if need specific format then *.csv
+            args.continue_path +
+            "/*.pth.tar")  # * means all if need specific format then *.csv
         latest_model_file = max(list_of_files, key=os.path.getctime)
         args.restore_path = latest_model_file
         print(f" > Training continues for {args.restore_path}")
@@ -662,10 +618,11 @@ if __name__ == "__main__":
     _ = os.path.dirname(os.path.realpath(__file__))

     OUT_PATH = args.continue_path
-    if args.continue_path == "":
-        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
+    if args.continue_path == '':
+        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
+                                            args.debug)

-    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
+    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

     c_logger = ConsoleLogger()

@@ -675,17 +632,16 @@ if __name__ == "__main__":
     if args.restore_path:
         new_fields["restore_path"] = args.restore_path
     new_fields["github_branch"] = get_git_branch()
-    copy_config_file(
-        args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
-    )
+    copy_config_file(args.config_path,
+                     os.path.join(OUT_PATH, 'config.json'), new_fields)
     os.chmod(AUDIO_PATH, 0o775)
     os.chmod(OUT_PATH, 0o775)

     LOG_DIR = OUT_PATH
-    tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
+    tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", c["run_description"], 0)
+    tb_logger.tb_add_text('model-description', c['run_description'], 0)

     try:
         main(args)
@@ -698,4 +654,4 @@ if __name__ == "__main__":
     except Exception:  # pylint: disable=broad-except
         remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)
@@ -365,28 +365,6 @@ class WaveRNN(nn.Module):
                     (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
                 )

-    @staticmethod
-    def get_gru_cell(gru):
-        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
-        gru_cell.weight_hh.data = gru.weight_hh_l0.data
-        gru_cell.weight_ih.data = gru.weight_ih_l0.data
-        gru_cell.bias_hh.data = gru.bias_hh_l0.data
-        gru_cell.bias_ih.data = gru.bias_ih_l0.data
-        return gru_cell
-
-    @staticmethod
-    def pad_tensor(x, pad, side="both"):
-        # NB - this is just a quick method i need right now
-        # i.e., it won't generalise to other shapes/dims
-        b, t, c = x.size()
-        total = t + 2 * pad if side == "both" else t + pad
-        padded = torch.zeros(b, total, c).cuda()
-        if side in ("before", "both"):
-            padded[:, pad : pad + t, :] = x
-        elif side == "after":
-            padded[:, :t, :] = x
-        return padded
-
     def fold_with_overlap(self, x, target, overlap):

         """Fold the tensor with overlap for quick batched inference.
@@ -430,7 +408,30 @@ class WaveRNN(nn.Module):

         return folded

-    def xfade_and_unfold(self, y, target, overlap):
+    @staticmethod
+    def get_gru_cell(gru):
+        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
+        gru_cell.weight_hh.data = gru.weight_hh_l0.data
+        gru_cell.weight_ih.data = gru.weight_ih_l0.data
+        gru_cell.bias_hh.data = gru.bias_hh_l0.data
+        gru_cell.bias_ih.data = gru.bias_ih_l0.data
+        return gru_cell
+
+    @staticmethod
+    def pad_tensor(x, pad, side="both"):
+        # NB - this is just a quick method i need right now
+        # i.e., it won't generalise to other shapes/dims
+        b, t, c = x.size()
+        total = t + 2 * pad if side == "both" else t + pad
+        padded = torch.zeros(b, total, c).cuda()
+        if side in ("before", "both"):
+            padded[:, pad : pad + t, :] = x
+        elif side == "after":
+            padded[:, :t, :] = x
+        return padded
+
+    @staticmethod
+    def xfade_and_unfold(y, target, overlap):

         """Applies a crossfade and unfolds into a 1d array.
         Args:
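For orientation, a small shape-level sketch of the relocated pad_tensor helper (an illustration, not part of the commit; the import path is assumed, and the method allocates on CUDA, so a GPU is required):

    import torch
    from TTS.vocoder.models.wavernn import WaveRNN  # assumed module path

    x = torch.zeros(1, 100, 80).cuda()      # (batch, frames, channels)
    padded = WaveRNN.pad_tensor(x, pad=2)   # side="both" adds 2 frames on each end
    print(padded.shape)                     # torch.Size([1, 104, 80])
    assert torch.equal(padded[:, 2:102, :], x)  # original frames sit between the pads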
@@ -28,7 +28,8 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
         torch.exp(log_std),
     )
     sample = dist.sample()
-    sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
+    sample = torch.clamp(torch.clamp(
+        sample, min=-scale_factor), max=scale_factor)
     del dist
     return sample

@@ -58,8 +59,9 @@ def discretized_mix_logistic_loss(

     # unpack parameters. (B, T, num_mixtures) x 3
     logit_probs = y_hat[:, :, :nr_mix]
-    means = y_hat[:, :, nr_mix : 2 * nr_mix]
-    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
+    means = y_hat[:, :, nr_mix: 2 * nr_mix]
+    log_scales = torch.clamp(
+        y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min)

     # B x T x 1 -> B x T x num_mixtures
     y = y.expand_as(means)
@@ -104,7 +106,8 @@ def discretized_mix_logistic_loss(
     ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
     inner_cond = (y > 0.999).float()
     inner_out = (
-        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
+        inner_cond * log_one_minus_cdf_min +
+        (1.0 - inner_cond) * inner_inner_out
     )
     cond = (y < -0.999).float()
     log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
@@ -142,9 +145,9 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
     # (B, T) -> (B, T, nr_mix)
     one_hot = to_one_hot(argmax, nr_mix)
     # select logistic parameters
-    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
+    means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1)
     log_scales = torch.clamp(
-        torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
+        torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
     )
     # sample from logistic & clip to interval
     # we don't actually round to the nearest 8bit value when sampling
@@ -39,7 +39,7 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):

 def to_camel(text):
     text = text.capitalize()
-    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
+    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


 def setup_wavernn(c):
@@ -67,101 +67,92 @@ def setup_wavernn(c):

 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
+    MyModel = importlib.import_module('TTS.vocoder.models.' +
+                                      c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model in "melgan_generator":
+    if c.generator_model in 'melgan_generator':
         model = MyModel(
-            in_channels=c.audio["num_mels"],
+            in_channels=c.audio['num_mels'],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params["upsample_factors"],
+            upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    if c.generator_model in "melgan_fb_generator":
+            num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model in 'melgan_fb_generator':
         pass
-    if c.generator_model in "multiband_melgan_generator":
+    if c.generator_model in 'multiband_melgan_generator':
         model = MyModel(
-            in_channels=c.audio["num_mels"],
+            in_channels=c.audio['num_mels'],
             out_channels=4,
             proj_kernel=7,
             base_channels=384,
-            upsample_factors=c.generator_model_params["upsample_factors"],
+            upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    if c.generator_model in "fullband_melgan_generator":
+            num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model in 'fullband_melgan_generator':
         model = MyModel(
-            in_channels=c.audio["num_mels"],
+            in_channels=c.audio['num_mels'],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params["upsample_factors"],
+            upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    if c.generator_model in "parallel_wavegan_generator":
+            num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model in 'parallel_wavegan_generator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-            stacks=c.generator_model_params["stacks"],
+            num_res_blocks=c.generator_model_params['num_res_blocks'],
+            stacks=c.generator_model_params['stacks'],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
-            aux_channels=c.audio["num_mels"],
+            aux_channels=c.audio['num_mels'],
             dropout=0.0,
             bias=True,
             use_weight_norm=True,
-            upsample_factors=c.generator_model_params["upsample_factors"],
-        )
+            upsample_factors=c.generator_model_params['upsample_factors'])
     return model


 def setup_discriminator(c):
     print(" > Discriminator Model: {}".format(c.discriminator_model))
-    if "parallel_wavegan" in c.discriminator_model:
+    if 'parallel_wavegan' in c.discriminator_model:
         MyModel = importlib.import_module(
-            "TTS.vocoder.models.parallel_wavegan_discriminator"
-        )
+            'TTS.vocoder.models.parallel_wavegan_discriminator')
     else:
-        MyModel = importlib.import_module(
-            "TTS.vocoder.models." + c.discriminator_model.lower()
-        )
+        MyModel = importlib.import_module('TTS.vocoder.models.' +
+                                          c.discriminator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
-    if c.discriminator_model in "random_window_discriminator":
+    if c.discriminator_model in 'random_window_discriminator':
         model = MyModel(
-            cond_channels=c.audio["num_mels"],
-            hop_length=c.audio["hop_length"],
-            uncond_disc_donwsample_factors=c.discriminator_model_params[
-                "uncond_disc_donwsample_factors"
-            ],
-            cond_disc_downsample_factors=c.discriminator_model_params[
-                "cond_disc_downsample_factors"
-            ],
-            cond_disc_out_channels=c.discriminator_model_params[
-                "cond_disc_out_channels"
-            ],
-            window_sizes=c.discriminator_model_params["window_sizes"],
-        )
-    if c.discriminator_model in "melgan_multiscale_discriminator":
+            cond_channels=c.audio['num_mels'],
+            hop_length=c.audio['hop_length'],
+            uncond_disc_donwsample_factors=c.
+            discriminator_model_params['uncond_disc_donwsample_factors'],
+            cond_disc_downsample_factors=c.
+            discriminator_model_params['cond_disc_downsample_factors'],
+            cond_disc_out_channels=c.
+            discriminator_model_params['cond_disc_out_channels'],
+            window_sizes=c.discriminator_model_params['window_sizes'])
+    if c.discriminator_model in 'melgan_multiscale_discriminator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_sizes=(5, 3),
-            base_channels=c.discriminator_model_params["base_channels"],
-            max_channels=c.discriminator_model_params["max_channels"],
-            downsample_factors=c.discriminator_model_params["downsample_factors"],
-        )
-    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
+            base_channels=c.discriminator_model_params['base_channels'],
+            max_channels=c.discriminator_model_params['max_channels'],
+            downsample_factors=c.
+            discriminator_model_params['downsample_factors'])
+    if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params["num_layers"],
-            stacks=c.discriminator_model_params["stacks"],
+            num_layers=c.discriminator_model_params['num_layers'],
+            stacks=c.discriminator_model_params['stacks'],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
@@ -170,17 +161,17 @@ def setup_discriminator(c):
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
         )
-    if c.discriminator_model == "parallel_wavegan_discriminator":
+    if c.discriminator_model == 'parallel_wavegan_discriminator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params["num_layers"],
+            num_layers=c.discriminator_model_params['num_layers'],
             conv_channels=64,
             dilation_factor=1,
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
-            bias=True,
+            bias=True
         )
     return model

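For reference, a sketch of the config entries these factory functions read; the key names come from the lookups above, the concrete values are illustrative assumptions only, and the config object is assumed to be the attribute-style dict returned by load_config:

    # Illustrative fragment of a vocoder config, in Python dict form.
    vocoder_config = {
        "generator_model": "multiband_melgan_generator",
        "generator_model_params": {"upsample_factors": [2, 4, 8], "num_res_blocks": 4},
        "discriminator_model": "melgan_multiscale_discriminator",
        "discriminator_model_params": {
            "base_channels": 16,
            "max_channels": 512,
            "downsample_factors": [4, 4, 4],
        },
        "audio": {"num_mels": 80, "hop_length": 256},
    }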