some minor changes to wavernn

sanjaesc 2020-10-16 21:19:51 +02:00
parent da60f35c14
commit 9a120f28ed
4 changed files with 82 additions and 80 deletions

View File

@@ -13,17 +13,13 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.radam import RAdam
 from TTS.utils.io import copy_config_file, load_config
-from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.utils.training import setup_torch_training_env
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.generic_utils import (
     KeepAverage,
     count_parameters,
@@ -32,6 +28,10 @@ from TTS.utils.generic_utils import (
     remove_experiment_folder,
     set_init_dict,
 )
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
+from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
+from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@@ -105,9 +105,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         # MODEL TRAINING #
         ##################
         y_hat = model(x, m)
-        y_hat_vis = y_hat  # for visualization
-        # y_hat = y_hat.transpose(1, 2)
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
@@ -200,8 +198,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         )
         # compute spectrograms
         figures = {
-            "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-            "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+            "prediction": plot_spectrogram(predict_mel.T),
+            "ground_truth": plot_spectrogram(ground_mel.T),
         }
         tb_logger.tb_train_figures(global_step, figures)
         end_time = time.time()
@@ -237,6 +235,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
         global_step += 1
         y_hat = model(x, m)
+        y_hat_viz = y_hat  # for vizualization
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
@@ -266,7 +265,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
     if epoch > CONFIG.test_delay_epochs:
         # synthesize a full voice
-        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
+        wav_path = train_data[random.randrange(0, len(train_data))][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
         sample_wav = model.generate(
@@ -283,8 +282,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
         )
         # compute spectrograms
         figures = {
-            "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-            "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+            "eval/prediction": plot_spectrogram(predict_mel.T),
+            "eval/ground_truth": plot_spectrogram(ground_mel.T),
         }
         tb_logger.tb_eval_figures(global_step, figures)
@@ -303,7 +302,6 @@ def main(args):  # pylint: disable=redefined-outer-name
         eval_data, train_data = load_wav_feat_data(
             CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
         )
-        eval_data, train_data = eval_data, train_data
     else:
         eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
@@ -326,7 +324,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     if isinstance(CONFIG.mode, int):
         criterion.cuda()
-    optimizer = optim.Adam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
+    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
     scheduler = None
     if "lr_scheduler" in CONFIG:
         scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
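For context, a minimal, self-contained sketch of the optimizer change above (Adam replaced by RAdam). The nn.Linear placeholder and the scheduler name stand in for model_wavernn and CONFIG.lr_scheduler, which are defined elsewhere in the training script; RAdam comes from the TTS.utils.radam import added by this commit.

    # Sketch only: illustrates the RAdam swap and the optional scheduler lookup.
    import torch
    from torch import nn
    from TTS.utils.radam import RAdam  # import added in the hunk above

    model = nn.Linear(10, 10)  # placeholder for model_wavernn
    optimizer = RAdam(model.parameters(), lr=1e-4, weight_decay=0)  # lr comes from CONFIG.lr in the script

    scheduler = None
    lr_scheduler_name = "StepLR"  # in the script this is CONFIG.lr_scheduler, if present
    scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_name)
    scheduler = scheduler_cls(optimizer, step_size=1000)  # kwargs here are illustrative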

View File

@@ -1,5 +1,4 @@
 {
-    "model": "wavernn",
     "run_name": "wavernn_test",
     "run_description": "wavernn_test training",
@@ -56,11 +55,12 @@
     "padding": 2, // pad the input for resnet to see wider input length
     // DATASET
+    "use_gta": true, // use computed gta features from the tts model
     "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech/", // path containing training wav files
-    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing extracted features .npy (mels / quant)
+    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing computed features .npy (mels / quant)
     // TRAINING
-    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "epochs": 10000, // total number of epochs to train.
     "warmup_steps": 10,

View File

@@ -7,8 +7,7 @@ from torch.utils.data import Dataset
 class WaveRNNDataset(Dataset):
     """
-    WaveRNN Dataset searchs for all the wav files under root path
-    and converts them to acoustic features on the fly.
+    WaveRNN Dataset searchs for all the wav files under root path.
     """
     def __init__(
@@ -20,8 +19,6 @@ class WaveRNNDataset(Dataset):
         pad,
         mode,
         is_training=True,
-        return_segments=True,
-        use_cache=False,
         verbose=False,
     ):
@@ -32,14 +29,8 @@ class WaveRNNDataset(Dataset):
         self.pad = pad
         self.mode = mode
         self.is_training = is_training
-        self.return_segments = return_segments
-        self.use_cache = use_cache
         self.verbose = verbose
-        # wav_files = [f"{self.path}wavs/{file}.wav" for file in self.metadata]
-        # with Pool(4) as pool:
-        #     self.wav_cache = pool.map(self.ap.load_wav, wav_files)

     def __len__(self):
         return len(self.item_list)

View File

@@ -39,11 +39,12 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
 def to_camel(text):
     text = text.capitalize()
-    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 def setup_wavernn(c):
-    print(" > Model: {}".format(c.model))
-    MyModel = importlib.import_module('TTS.vocoder.models.wavernn')
+    print(" > Model: WaveRNN")
+    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
     MyModel = getattr(MyModel, "WaveRNN")
     model = MyModel(
         rnn_dims=512,
@@ -58,98 +59,109 @@ def setup_wavernn(c):
         compute_dims=128,
         res_out_dims=128,
         res_blocks=10,
-        hop_length=c.audio['hop_length'],
-        sample_rate=c.audio['sample_rate'])
+        hop_length=c.audio["hop_length"],
+        sample_rate=c.audio["sample_rate"],
+    )
     return model
 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                      c.generator_model.lower())
+    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model in 'melgan_generator':
+    if c.generator_model in "melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "melgan_fb_generator":
         pass
-    if c.generator_model in 'multiband_melgan_generator':
+    if c.generator_model in "multiband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=4,
             proj_kernel=7,
             base_channels=384,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'fullband_melgan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "fullband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'parallel_wavegan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "parallel_wavegan_generator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'],
-            stacks=c.generator_model_params['stacks'],
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+            stacks=c.generator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
-            aux_channels=c.audio['num_mels'],
+            aux_channels=c.audio["num_mels"],
             dropout=0.0,
             bias=True,
             use_weight_norm=True,
-            upsample_factors=c.generator_model_params['upsample_factors'])
+            upsample_factors=c.generator_model_params["upsample_factors"],
+        )
     return model
 def setup_discriminator(c):
     print(" > Discriminator Model: {}".format(c.discriminator_model))
-    if 'parallel_wavegan' in c.discriminator_model:
+    if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module(
-            'TTS.vocoder.models.parallel_wavegan_discriminator')
+            "TTS.vocoder.models.parallel_wavegan_discriminator"
+        )
     else:
-        MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                          c.discriminator_model.lower())
+        MyModel = importlib.import_module(
+            "TTS.vocoder.models." + c.discriminator_model.lower()
+        )
     MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
-    if c.discriminator_model in 'random_window_discriminator':
+    if c.discriminator_model in "random_window_discriminator":
         model = MyModel(
-            cond_channels=c.audio['num_mels'],
-            hop_length=c.audio['hop_length'],
-            uncond_disc_donwsample_factors=c.
-            discriminator_model_params['uncond_disc_donwsample_factors'],
-            cond_disc_downsample_factors=c.
-            discriminator_model_params['cond_disc_downsample_factors'],
-            cond_disc_out_channels=c.
-            discriminator_model_params['cond_disc_out_channels'],
-            window_sizes=c.discriminator_model_params['window_sizes'])
-    if c.discriminator_model in 'melgan_multiscale_discriminator':
+            cond_channels=c.audio["num_mels"],
+            hop_length=c.audio["hop_length"],
+            uncond_disc_donwsample_factors=c.discriminator_model_params[
+                "uncond_disc_donwsample_factors"
+            ],
+            cond_disc_downsample_factors=c.discriminator_model_params[
+                "cond_disc_downsample_factors"
+            ],
+            cond_disc_out_channels=c.discriminator_model_params[
+                "cond_disc_out_channels"
+            ],
+            window_sizes=c.discriminator_model_params["window_sizes"],
+        )
+    if c.discriminator_model in "melgan_multiscale_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_sizes=(5, 3),
-            base_channels=c.discriminator_model_params['base_channels'],
-            max_channels=c.discriminator_model_params['max_channels'],
-            downsample_factors=c.
-            discriminator_model_params['downsample_factors'])
-    if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
+            base_channels=c.discriminator_model_params["base_channels"],
+            max_channels=c.discriminator_model_params["max_channels"],
+            downsample_factors=c.discriminator_model_params["downsample_factors"],
+        )
+    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
-            stacks=c.discriminator_model_params['stacks'],
+            num_layers=c.discriminator_model_params["num_layers"],
+            stacks=c.discriminator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
@@ -158,17 +170,17 @@ def setup_discriminator(c):
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
         )
-    if c.discriminator_model == 'parallel_wavegan_discriminator':
+    if c.discriminator_model == "parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
+            num_layers=c.discriminator_model_params["num_layers"],
             conv_channels=64,
             dilation_factor=1,
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
-            bias=True
+            bias=True,
         )
     return model
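For reference, a small usage sketch of the reformatted setup_wavernn factory; the config filename is hypothetical, and the audio keys it reads (hop_length, sample_rate) follow the c.audio lookups visible above.

    from TTS.utils.io import load_config
    from TTS.vocoder.utils.generic_utils import setup_wavernn

    CONFIG = load_config("config_wavernn.json")  # hypothetical path
    model_wavernn = setup_wavernn(CONFIG)  # builds WaveRNN from CONFIG.audio["hop_length"] / ["sample_rate"]
    print(model_wavernn)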