some minor changes to wavernn

sanjaesc 2020-10-16 21:19:51 +02:00
parent da60f35c14
commit 9a120f28ed
4 changed files with 82 additions and 80 deletions


@@ -13,17 +13,13 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_config_file, load_config
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import (
KeepAverage,
count_parameters,
@@ -32,6 +28,10 @@ from TTS.utils.generic_utils import (
remove_experiment_folder,
set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@@ -105,9 +105,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
# MODEL TRAINING #
##################
y_hat = model(x, m)
y_hat_vis = y_hat # for visualization
# y_hat = y_hat.transpose(1, 2)
if isinstance(model.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
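The branch above moves the class dimension of the bit-mode logits to where torch.nn.CrossEntropyLoss expects it, while the continuous "mold"/"gauss" outputs are left untouched. A minimal sketch of how the matching criterion is usually picked from CONFIG.mode (illustrative helper, not part of this diff):

    import torch.nn as nn
    from TTS.vocoder.utils.distribution import (
        discretized_mix_logistic_loss, gaussian_loss)

    def select_criterion(mode):
        # integer mode -> the network emits 2**mode class logits per sample
        if isinstance(mode, int):
            return nn.CrossEntropyLoss()
        if mode == "mold":
            return discretized_mix_logistic_loss
        if mode == "gauss":
            return gaussian_loss
        raise ValueError(" [!] Unknown WaveRNN mode: {}".format(mode))
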
@@ -200,8 +198,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
)
# compute spectrograms
figures = {
"prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
"ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
"prediction": plot_spectrogram(predict_mel.T),
"ground_truth": plot_spectrogram(ground_mel.T),
}
tb_logger.tb_train_figures(global_step, figures)
end_time = time.time()
@@ -237,6 +235,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
global_step += 1
y_hat = model(x, m)
y_hat_viz = y_hat # for visualization
if isinstance(model.mode, int):
y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
else:
@@ -266,7 +265,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
if epoch > CONFIG.test_delay_epochs:
# synthesize a full voice
wav_path = eval_data[random.randrange(0, len(eval_data))][0]
wav_path = train_data[random.randrange(0, len(train_data))][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav)
sample_wav = model.generate(
@@ -283,8 +282,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
)
# compute spectrograms
figures = {
"prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
"ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
"eval/prediction": plot_spectrogram(predict_mel.T),
"eval/ground_truth": plot_spectrogram(ground_mel.T),
}
tb_logger.tb_eval_figures(global_step, figures)
@@ -303,7 +302,6 @@ def main(args): # pylint: disable=redefined-outer-name
eval_data, train_data = load_wav_feat_data(
CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
)
eval_data, train_data = eval_data, train_data
else:
eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
@@ -326,7 +324,8 @@ def main(args): # pylint: disable=redefined-outer-name
if isinstance(CONFIG.mode, int):
criterion.cuda()
optimizer = optim.Adam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
scheduler = None
if "lr_scheduler" in CONFIG:
scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
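Put together, the optimizer setup this hunk lands on looks roughly like the sketch below. RAdam and the getattr lookup appear in the lines above; the lr_scheduler_params key and its values are assumptions about the config:

    import torch
    from TTS.utils.radam import RAdam

    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
    scheduler = None
    if "lr_scheduler" in CONFIG:
        scheduler_class = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
        # e.g. "lr_scheduler": "MultiStepLR" with a matching (assumed) key
        # "lr_scheduler_params": {"gamma": 0.5, "milestones": [50000, 150000]}
        scheduler = scheduler_class(optimizer, **CONFIG.lr_scheduler_params)
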


@@ -1,5 +1,4 @@
{
"model": "wavernn",
"run_name": "wavernn_test",
"run_description": "wavernn_test training",
@@ -54,13 +53,14 @@
"mode": "mold", // mold [string], gauss [string], bits [int]
"mulaw": false, // apply mulaw if mode is bits
"padding": 2, // pad the input for resnet to see wider input length
// DATASET
"use_gta": true, // use computed gta features from the tts model
"data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech/", // path containing training wav files
"feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing extracted features .npy (mels / quant)
"feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing computed features .npy (mels / quant)
// TRAINING
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"epochs": 10000, // total number of epochs to train.
"warmup_steps": 10,


@@ -7,8 +7,7 @@ from torch.utils.data import Dataset
class WaveRNNDataset(Dataset):
"""
WaveRNN Dataset searchs for all the wav files under root path
and converts them to acoustic features on the fly.
WaveRNN Dataset searches for all the wav files under root path.
"""
def __init__(
@@ -20,8 +19,6 @@ class WaveRNNDataset(Dataset):
pad,
mode,
is_training=True,
return_segments=True,
use_cache=False,
verbose=False,
):
@@ -32,14 +29,8 @@ class WaveRNNDataset(Dataset):
self.pad = pad
self.mode = mode
self.is_training = is_training
self.return_segments = return_segments
self.use_cache = use_cache
self.verbose = verbose
# wav_files = [f"{self.path}wavs/{file}.wav" for file in self.metadata]
# with Pool(4) as pool:
# self.wav_cache = pool.map(self.ap.load_wav, wav_files)
def __len__(self):
return len(self.item_list)
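For context, a usage sketch of the dataset after this cleanup. Only pad, mode, is_training and verbose are visible in the hunks above; the leading constructor arguments (ap, items, seq_len, hop_len), the config keys and the collate method are assumptions based on the training script:

    from torch.utils.data import DataLoader
    from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset

    dataset = WaveRNNDataset(ap=ap, items=train_data,           # leading args assumed
                             seq_len=CONFIG.seq_len, hop_len=ap.hop_length,
                             pad=CONFIG.padding, mode=CONFIG.mode,
                             is_training=True, verbose=True)
    loader = DataLoader(dataset, batch_size=CONFIG.batch_size, shuffle=True,
                        collate_fn=dataset.collate,             # collate method assumed
                        num_workers=CONFIG.num_loader_workers, pin_memory=True)
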


@@ -39,11 +39,12 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
def to_camel(text):
text = text.capitalize()
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_wavernn(c):
print(" > Model: {}".format(c.model))
MyModel = importlib.import_module('TTS.vocoder.models.wavernn')
print(" > Model: WaveRNN")
MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
MyModel = getattr(MyModel, "WaveRNN")
model = MyModel(
rnn_dims=512,
@@ -58,98 +59,109 @@ def setup_wavernn(c):
compute_dims=128,
res_out_dims=128,
res_blocks=10,
hop_length=c.audio['hop_length'],
sample_rate=c.audio['sample_rate'])
hop_length=c.audio["hop_length"],
sample_rate=c.audio["sample_rate"],
)
return model
def setup_generator(c):
print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.generator_model.lower())
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
if c.generator_model in 'melgan_generator':
if c.generator_model in "melgan_generator":
model = MyModel(
in_channels=c.audio['num_mels'],
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params['upsample_factors'],
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'melgan_fb_generator':
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
if c.generator_model in "melgan_fb_generator":
pass
if c.generator_model in 'multiband_melgan_generator':
if c.generator_model in "multiband_melgan_generator":
model = MyModel(
in_channels=c.audio['num_mels'],
in_channels=c.audio["num_mels"],
out_channels=4,
proj_kernel=7,
base_channels=384,
upsample_factors=c.generator_model_params['upsample_factors'],
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'fullband_melgan_generator':
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
if c.generator_model in "fullband_melgan_generator":
model = MyModel(
in_channels=c.audio['num_mels'],
in_channels=c.audio["num_mels"],
out_channels=1,
proj_kernel=7,
base_channels=512,
upsample_factors=c.generator_model_params['upsample_factors'],
upsample_factors=c.generator_model_params["upsample_factors"],
res_kernel=3,
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'parallel_wavegan_generator':
num_res_blocks=c.generator_model_params["num_res_blocks"],
)
if c.generator_model in "parallel_wavegan_generator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=c.generator_model_params['num_res_blocks'],
stacks=c.generator_model_params['stacks'],
num_res_blocks=c.generator_model_params["num_res_blocks"],
stacks=c.generator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=c.audio['num_mels'],
aux_channels=c.audio["num_mels"],
dropout=0.0,
bias=True,
use_weight_norm=True,
upsample_factors=c.generator_model_params['upsample_factors'])
upsample_factors=c.generator_model_params["upsample_factors"],
)
return model
def setup_discriminator(c):
print(" > Discriminator Model: {}".format(c.discriminator_model))
if 'parallel_wavegan' in c.discriminator_model:
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module(
'TTS.vocoder.models.parallel_wavegan_discriminator')
"TTS.vocoder.models.parallel_wavegan_discriminator"
)
else:
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.discriminator_model.lower())
MyModel = importlib.import_module(
"TTS.vocoder.models." + c.discriminator_model.lower()
)
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in 'random_window_discriminator':
if c.discriminator_model in "random_window_discriminator":
model = MyModel(
cond_channels=c.audio['num_mels'],
hop_length=c.audio['hop_length'],
uncond_disc_donwsample_factors=c.
discriminator_model_params['uncond_disc_donwsample_factors'],
cond_disc_downsample_factors=c.
discriminator_model_params['cond_disc_downsample_factors'],
cond_disc_out_channels=c.
discriminator_model_params['cond_disc_out_channels'],
window_sizes=c.discriminator_model_params['window_sizes'])
if c.discriminator_model in 'melgan_multiscale_discriminator':
cond_channels=c.audio["num_mels"],
hop_length=c.audio["hop_length"],
uncond_disc_donwsample_factors=c.discriminator_model_params[
"uncond_disc_donwsample_factors"
],
cond_disc_downsample_factors=c.discriminator_model_params[
"cond_disc_downsample_factors"
],
cond_disc_out_channels=c.discriminator_model_params[
"cond_disc_out_channels"
],
window_sizes=c.discriminator_model_params["window_sizes"],
)
if c.discriminator_model in "melgan_multiscale_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params['base_channels'],
max_channels=c.discriminator_model_params['max_channels'],
downsample_factors=c.
discriminator_model_params['downsample_factors'])
if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
base_channels=c.discriminator_model_params["base_channels"],
max_channels=c.discriminator_model_params["max_channels"],
downsample_factors=c.discriminator_model_params["downsample_factors"],
)
if c.discriminator_model == "residual_parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params['num_layers'],
stacks=c.discriminator_model_params['stacks'],
num_layers=c.discriminator_model_params["num_layers"],
stacks=c.discriminator_model_params["stacks"],
res_channels=64,
gate_channels=128,
skip_channels=64,
@@ -158,17 +170,17 @@ def setup_discriminator(c):
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
)
if c.discriminator_model == 'parallel_wavegan_discriminator':
if c.discriminator_model == "parallel_wavegan_discriminator":
model = MyModel(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=c.discriminator_model_params['num_layers'],
num_layers=c.discriminator_model_params["num_layers"],
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True
bias=True,
)
return model
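All of these factories are driven by the parsed JSON config; a short usage sketch, where the config paths and field values are illustrative examples rather than files from this commit:

    from TTS.utils.io import load_config

    c = load_config("vocoder_config.json")   # example path
    model_gen = setup_generator(c)           # e.g. c.generator_model = "multiband_melgan_generator"
    model_disc = setup_discriminator(c)      # e.g. c.discriminator_model = "melgan_multiscale_discriminator"

    c_wavernn = load_config("wavernn_config.json")   # example path
    model_wavernn = setup_wavernn(c_wavernn)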