mirror of https://github.com/coqui-ai/TTS.git
Some minor changes to WaveRNN
parent da60f35c14
commit 9a120f28ed
@@ -13,17 +13,13 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.radam import RAdam
 from TTS.utils.io import copy_config_file, load_config
-from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.utils.training import setup_torch_training_env
 from TTS.utils.console_logger import ConsoleLogger
+from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.generic_utils import (
     KeepAverage,
     count_parameters,

@@ -32,6 +28,10 @@ from TTS.utils.generic_utils import (
     remove_experiment_folder,
     set_init_dict,
 )
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
+from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
+from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
 
 

@@ -105,9 +105,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         # MODEL TRAINING #
         ##################
         y_hat = model(x, m)
-        y_hat_vis = y_hat # for visualization
 
-        # y_hat = y_hat.transpose(1, 2)
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:

@@ -200,8 +198,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
             )
             # compute spectrograms
             figures = {
-                "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-                "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+                "prediction": plot_spectrogram(predict_mel.T),
+                "ground_truth": plot_spectrogram(ground_mel.T),
             }
             tb_logger.tb_train_figures(global_step, figures)
     end_time = time.time()

@@ -237,6 +235,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
             global_step += 1
 
             y_hat = model(x, m)
+            y_hat_viz = y_hat # for vizualization
             if isinstance(model.mode, int):
                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
             else:

@@ -266,7 +265,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
     if epoch > CONFIG.test_delay_epochs:
         # synthesize a full voice
-        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
+        wav_path = train_data[random.randrange(0, len(train_data))][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
         sample_wav = model.generate(

@@ -283,8 +282,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
         )
         # compute spectrograms
         figures = {
-            "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-            "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+            "eval/prediction": plot_spectrogram(predict_mel.T),
+            "eval/ground_truth": plot_spectrogram(ground_mel.T),
         }
         tb_logger.tb_eval_figures(global_step, figures)
 
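Both the train() and evaluate() hunks above switch plot_spectrogram to the single-argument call and, on the eval side, namespace the figure keys with an "eval/" prefix so TensorBoard groups them separately from the training figures. A minimal sketch of what logging such a figures dict amounts to, assuming a bare torch.utils.tensorboard SummaryWriter in place of the repo's TensorboardLogger wrapper:

```python
# Illustrative sketch only — the repo's tb_logger wraps this pattern; a plain
# SummaryWriter stands in for it here.
from torch.utils.tensorboard import SummaryWriter

def log_figures(writer: SummaryWriter, step: int, figures: dict) -> None:
    for tag, fig in figures.items():
        # A tag like "eval/prediction" is grouped under "eval" in the TensorBoard UI;
        # add_figure also closes the matplotlib figure by default (close=True).
        writer.add_figure(tag, fig, step)
```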
@@ -303,7 +302,6 @@ def main(args): # pylint: disable=redefined-outer-name
         eval_data, train_data = load_wav_feat_data(
             CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
         )
-        eval_data, train_data = eval_data, train_data
     else:
         eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
 

@@ -326,7 +324,8 @@ def main(args): # pylint: disable=redefined-outer-name
     if isinstance(CONFIG.mode, int):
         criterion.cuda()
-    optimizer = optim.Adam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
+
+    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
 
     scheduler = None
     if "lr_scheduler" in CONFIG:
         scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
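The optimizer swap above is what the new `from TTS.utils.radam import RAdam` import in the first hunk exists for. RAdam (Rectified Adam) is a drop-in replacement for torch.optim.Adam that rectifies the variance of the adaptive learning rate during the early steps. In isolation the change looks like the sketch below; the nn.Linear module and the literal learning rate are placeholders only, standing in for the real model and CONFIG.lr:

```python
import torch.nn as nn
from TTS.utils.radam import RAdam  # bundled with the repo

model_wavernn = nn.Linear(80, 80)  # placeholder module, just to make this runnable

# before: optimizer = torch.optim.Adam(model_wavernn.parameters(), lr=1e-4, weight_decay=0)
optimizer = RAdam(model_wavernn.parameters(), lr=1e-4, weight_decay=0)
```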
@@ -1,5 +1,4 @@
 {
-    "model": "wavernn",
     "run_name": "wavernn_test",
     "run_description": "wavernn_test training",
 

@@ -56,11 +55,12 @@
     "padding": 2, // pad the input for resnet to see wider input length
 
     // DATASET
+    "use_gta": true, // use computed gta features from the tts model
     "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech/", // path containing training wav files
-    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing extracted features .npy (mels / quant)
+    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing computed features .npy (mels / quant)
 
     // TRAINING
-    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "epochs": 10000, // total number of epochs to train.
     "warmup_steps": 10,
 
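The new `use_gta` flag pairs with `feature_path`: when ground-truth-aligned (GTA) mel features pre-computed by the TTS model are available, the loader reads them from disk instead of deriving features from the wavs. The guarding condition itself is outside the visible hunks, so the `if CONFIG.use_gta:` below is an assumption; the two loader calls are taken verbatim from the main() hunk above:

```python
# Assumption: use_gta is the switch main() tests; only the two calls are verbatim.
if CONFIG.use_gta:
    eval_data, train_data = load_wav_feat_data(
        CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
    )
else:
    eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
```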
@@ -7,8 +7,7 @@ from torch.utils.data import Dataset
 
 class WaveRNNDataset(Dataset):
     """
-    WaveRNN Dataset searchs for all the wav files under root path
-    and converts them to acoustic features on the fly.
+    WaveRNN Dataset searchs for all the wav files under root path.
     """
 
     def __init__(

@@ -20,8 +19,6 @@ class WaveRNNDataset(Dataset):
         pad,
         mode,
         is_training=True,
-        return_segments=True,
-        use_cache=False,
         verbose=False,
     ):
 

@@ -32,14 +29,8 @@ class WaveRNNDataset(Dataset):
         self.pad = pad
         self.mode = mode
         self.is_training = is_training
-        self.return_segments = return_segments
-        self.use_cache = use_cache
         self.verbose = verbose
 
-        # wav_files = [f"{self.path}wavs/{file}.wav" for file in self.metadata]
-        # with Pool(4) as pool:
-        #     self.wav_cache = pool.map(self.ap.load_wav, wav_files)
-
     def __len__(self):
         return len(self.item_list)
 
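Taken together, the three dataset hunks drop the unused `return_segments`/`use_cache` knobs along with the commented-out wav-cache experiment. A rough sketch of the slimmed class as the hunks imply it; every parameter before `pad` is cut off by the diff context, so `item_list` and `seq_len` are hypothetical stand-ins:

```python
from torch.utils.data import Dataset

class WaveRNNDataset(Dataset):
    """Minimal stand-in for the slimmed dataset."""

    # item_list and seq_len are hypothetical; pad/mode/is_training/verbose and
    # __len__ come directly from the diff.
    def __init__(self, item_list, seq_len, pad, mode, is_training=True, verbose=False):
        self.item_list = item_list
        self.seq_len = seq_len
        self.pad = pad
        self.mode = mode
        self.is_training = is_training
        self.verbose = verbose

    def __len__(self):
        return len(self.item_list)
```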
@@ -39,11 +39,12 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
 
 def to_camel(text):
     text = text.capitalize()
-    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 
+
 def setup_wavernn(c):
-    print(" > Model: {}".format(c.model))
-    MyModel = importlib.import_module('TTS.vocoder.models.wavernn')
+    print(" > Model: WaveRNN")
+    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
     MyModel = getattr(MyModel, "WaveRNN")
     model = MyModel(
         rnn_dims=512,
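The to_camel edit only normalizes quote style; its behavior, which the setup_* factories below rely on to map config names onto class names, is unchanged. A quick self-contained check, which also shows why setup_wavernn hardcodes getattr(MyModel, "WaveRNN") instead of using to_camel:

```python
import re

def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

assert to_camel("multiband_melgan_generator") == "MultibandMelganGenerator"
assert to_camel("wavernn") == "Wavernn"  # not "WaveRNN" — hence the hardcoded getattr
```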
@@ -58,98 +59,109 @@ def setup_wavernn(c):
         compute_dims=128,
         res_out_dims=128,
         res_blocks=10,
-        hop_length=c.audio['hop_length'],
-        sample_rate=c.audio['sample_rate'])
+        hop_length=c.audio["hop_length"],
+        sample_rate=c.audio["sample_rate"],
+    )
     return model
 
 
 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                      c.generator_model.lower())
+    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model in 'melgan_generator':
+    if c.generator_model in "melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "melgan_fb_generator":
         pass
-    if c.generator_model in 'multiband_melgan_generator':
+    if c.generator_model in "multiband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=4,
             proj_kernel=7,
             base_channels=384,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'fullband_melgan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "fullband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'parallel_wavegan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "parallel_wavegan_generator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'],
-            stacks=c.generator_model_params['stacks'],
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+            stacks=c.generator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
-            aux_channels=c.audio['num_mels'],
+            aux_channels=c.audio["num_mels"],
             dropout=0.0,
             bias=True,
             use_weight_norm=True,
-            upsample_factors=c.generator_model_params['upsample_factors'])
+            upsample_factors=c.generator_model_params["upsample_factors"],
+        )
     return model
 
 
 def setup_discriminator(c):
     print(" > Discriminator Model: {}".format(c.discriminator_model))
-    if 'parallel_wavegan' in c.discriminator_model:
+    if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module(
-            'TTS.vocoder.models.parallel_wavegan_discriminator')
+            "TTS.vocoder.models.parallel_wavegan_discriminator"
+        )
     else:
-        MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                          c.discriminator_model.lower())
+        MyModel = importlib.import_module(
+            "TTS.vocoder.models." + c.discriminator_model.lower()
+        )
     MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
-    if c.discriminator_model in 'random_window_discriminator':
+    if c.discriminator_model in "random_window_discriminator":
         model = MyModel(
-            cond_channels=c.audio['num_mels'],
-            hop_length=c.audio['hop_length'],
-            uncond_disc_donwsample_factors=c.
-            discriminator_model_params['uncond_disc_donwsample_factors'],
-            cond_disc_downsample_factors=c.
-            discriminator_model_params['cond_disc_downsample_factors'],
-            cond_disc_out_channels=c.
-            discriminator_model_params['cond_disc_out_channels'],
-            window_sizes=c.discriminator_model_params['window_sizes'])
-    if c.discriminator_model in 'melgan_multiscale_discriminator':
+            cond_channels=c.audio["num_mels"],
+            hop_length=c.audio["hop_length"],
+            uncond_disc_donwsample_factors=c.discriminator_model_params[
+                "uncond_disc_donwsample_factors"
+            ],
+            cond_disc_downsample_factors=c.discriminator_model_params[
+                "cond_disc_downsample_factors"
+            ],
+            cond_disc_out_channels=c.discriminator_model_params[
+                "cond_disc_out_channels"
+            ],
+            window_sizes=c.discriminator_model_params["window_sizes"],
+        )
+    if c.discriminator_model in "melgan_multiscale_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_sizes=(5, 3),
-            base_channels=c.discriminator_model_params['base_channels'],
-            max_channels=c.discriminator_model_params['max_channels'],
-            downsample_factors=c.
-            discriminator_model_params['downsample_factors'])
-    if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
+            base_channels=c.discriminator_model_params["base_channels"],
+            max_channels=c.discriminator_model_params["max_channels"],
+            downsample_factors=c.discriminator_model_params["downsample_factors"],
+        )
+    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
-            stacks=c.discriminator_model_params['stacks'],
+            num_layers=c.discriminator_model_params["num_layers"],
+            stacks=c.discriminator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,

@@ -158,17 +170,17 @@ def setup_discriminator(c):
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
         )
-    if c.discriminator_model == 'parallel_wavegan_discriminator':
+    if c.discriminator_model == "parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
+            num_layers=c.discriminator_model_params["num_layers"],
             conv_channels=64,
             dilation_factor=1,
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
-            bias=True
+            bias=True,
         )
     return model
 
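One caveat survives the reformatting above: most of these factory branches test `c.generator_model in "melgan_generator"`, and `in` between two strings is a substring check, not an equality check, so a shorter config value can match more than one branch. The snippet below demonstrates the semantics; this is an editorial observation about pre-existing behavior, not something this commit changes:

```python
# `in` on strings is a substring test:
assert "melgan_generator" in "melgan_generator"                # exact value matches
assert "melgan" in "melgan_generator"                          # but so does any substring
assert "multiband_melgan_generator" not in "melgan_generator"  # longer names fall through
```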