mirror of https://github.com/coqui-ai/TTS.git
some minor changes to wavernn
This commit is contained in:
parent 9c3c7ce2f8
commit e495e03ea1
@@ -13,17 +13,13 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-
-from TTS.utils.audio import AudioProcessor
+from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.radam import RAdam
 from TTS.utils.io import copy_config_file, load_config
-from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.utils.training import setup_torch_training_env
 from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.generic_utils import (
     KeepAverage,
     count_parameters,
@@ -32,6 +28,10 @@ from TTS.utils.generic_utils import (
     remove_experiment_folder,
     set_init_dict,
 )
+from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
+from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
+from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
+from TTS.vocoder.utils.generic_utils import setup_wavernn
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
 
 
@@ -105,9 +105,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         # MODEL TRAINING #
         ##################
         y_hat = model(x, m)
-        y_hat_vis = y_hat # for visualization
 
-        # y_hat = y_hat.transpose(1, 2)
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
@@ -200,8 +198,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
             )
             # compute spectrograms
             figures = {
-                "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-                "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+                "prediction": plot_spectrogram(predict_mel.T),
+                "ground_truth": plot_spectrogram(ground_mel.T),
             }
             tb_logger.tb_train_figures(global_step, figures)
         end_time = time.time()
@@ -237,6 +235,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
             global_step += 1
 
             y_hat = model(x, m)
+            y_hat_viz = y_hat # for vizualization
             if isinstance(model.mode, int):
                 y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
             else:
@@ -266,7 +265,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
     if epoch > CONFIG.test_delay_epochs:
         # synthesize a full voice
-        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
+        wav_path = train_data[random.randrange(0, len(train_data))][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav)
        sample_wav = model.generate(
@@ -283,8 +282,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
         )
         # compute spectrograms
         figures = {
-            "prediction": plot_spectrogram(predict_mel.T, ap, output_fig=False),
-            "ground_truth": plot_spectrogram(ground_mel.T, ap, output_fig=False),
+            "eval/prediction": plot_spectrogram(predict_mel.T),
+            "eval/ground_truth": plot_spectrogram(ground_mel.T),
         }
         tb_logger.tb_eval_figures(global_step, figures)
 
@@ -303,7 +302,6 @@ def main(args): # pylint: disable=redefined-outer-name
         eval_data, train_data = load_wav_feat_data(
             CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
         )
-        eval_data, train_data = eval_data, train_data
     else:
         eval_data, train_data = load_wav_data(CONFIG.data_path, CONFIG.eval_split_size)
 
@@ -326,7 +324,8 @@ def main(args): # pylint: disable=redefined-outer-name
     if isinstance(CONFIG.mode, int):
         criterion.cuda()
 
-    optimizer = optim.Adam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
+    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
+
     scheduler = None
     if "lr_scheduler" in CONFIG:
         scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)

@@ -1,5 +1,4 @@
 {
-    "model": "wavernn",
     "run_name": "wavernn_test",
     "run_description": "wavernn_test training",
 
@@ -56,11 +55,12 @@
     "padding": 2, // pad the input for resnet to see wider input length
 
     // DATASET
     "use_gta": true, // use computed gta features from the tts model
     "data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech/", // path containing training wav files
-    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing extracted features .npy (mels / quant)
+    "feature_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech_Computed/mel/", // path containing computed features .npy (mels / quant)
 
     // TRAINING
-    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "epochs": 10000, // total number of epochs to train.
+    "warmup_steps": 10,
 

@@ -7,8 +7,7 @@ from torch.utils.data import Dataset
 
 class WaveRNNDataset(Dataset):
     """
-    WaveRNN Dataset searchs for all the wav files under root path
-    and converts them to acoustic features on the fly.
+    WaveRNN Dataset searchs for all the wav files under root path.
     """
 
     def __init__(
@@ -20,8 +19,6 @@ class WaveRNNDataset(Dataset):
         pad,
         mode,
         is_training=True,
-        return_segments=True,
-        use_cache=False,
         verbose=False,
     ):
 
@@ -32,14 +29,8 @@ class WaveRNNDataset(Dataset):
         self.pad = pad
         self.mode = mode
         self.is_training = is_training
-        self.return_segments = return_segments
-        self.use_cache = use_cache
         self.verbose = verbose
 
-        # wav_files = [f"{self.path}wavs/{file}.wav" for file in self.metadata]
-        # with Pool(4) as pool:
-        #     self.wav_cache = pool.map(self.ap.load_wav, wav_files)
-
     def __len__(self):
         return len(self.item_list)
 

@@ -39,11 +39,12 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
 
 def to_camel(text):
     text = text.capitalize()
-    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
 
+
 def setup_wavernn(c):
-    print(" > Model: {}".format(c.model))
-    MyModel = importlib.import_module('TTS.vocoder.models.wavernn')
+    print(" > Model: WaveRNN")
+    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
     MyModel = getattr(MyModel, "WaveRNN")
     model = MyModel(
         rnn_dims=512,
@@ -58,98 +59,109 @@ def setup_wavernn(c):
         compute_dims=128,
         res_out_dims=128,
         res_blocks=10,
-        hop_length=c.audio['hop_length'],
-        sample_rate=c.audio['sample_rate'])
+        hop_length=c.audio["hop_length"],
+        sample_rate=c.audio["sample_rate"],
+    )
     return model
 
+
 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                      c.generator_model.lower())
+    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model in 'melgan_generator':
+    if c.generator_model in "melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "melgan_fb_generator":
         pass
-    if c.generator_model in 'multiband_melgan_generator':
+    if c.generator_model in "multiband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=4,
             proj_kernel=7,
             base_channels=384,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'fullband_melgan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "fullband_melgan_generator":
         model = MyModel(
-            in_channels=c.audio['num_mels'],
+            in_channels=c.audio["num_mels"],
             out_channels=1,
             proj_kernel=7,
             base_channels=512,
-            upsample_factors=c.generator_model_params['upsample_factors'],
+            upsample_factors=c.generator_model_params["upsample_factors"],
             res_kernel=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'parallel_wavegan_generator':
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+        )
+    if c.generator_model in "parallel_wavegan_generator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_res_blocks=c.generator_model_params['num_res_blocks'],
-            stacks=c.generator_model_params['stacks'],
+            num_res_blocks=c.generator_model_params["num_res_blocks"],
+            stacks=c.generator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
-            aux_channels=c.audio['num_mels'],
+            aux_channels=c.audio["num_mels"],
             dropout=0.0,
             bias=True,
             use_weight_norm=True,
-            upsample_factors=c.generator_model_params['upsample_factors'])
+            upsample_factors=c.generator_model_params["upsample_factors"],
+        )
     return model
 
 
 def setup_discriminator(c):
     print(" > Discriminator Model: {}".format(c.discriminator_model))
-    if 'parallel_wavegan' in c.discriminator_model:
+    if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module(
-            'TTS.vocoder.models.parallel_wavegan_discriminator')
+            "TTS.vocoder.models.parallel_wavegan_discriminator"
+        )
     else:
-        MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                          c.discriminator_model.lower())
+        MyModel = importlib.import_module(
+            "TTS.vocoder.models." + c.discriminator_model.lower()
+        )
     MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
-    if c.discriminator_model in 'random_window_discriminator':
+    if c.discriminator_model in "random_window_discriminator":
         model = MyModel(
-            cond_channels=c.audio['num_mels'],
-            hop_length=c.audio['hop_length'],
-            uncond_disc_donwsample_factors=c.
-            discriminator_model_params['uncond_disc_donwsample_factors'],
-            cond_disc_downsample_factors=c.
-            discriminator_model_params['cond_disc_downsample_factors'],
-            cond_disc_out_channels=c.
-            discriminator_model_params['cond_disc_out_channels'],
-            window_sizes=c.discriminator_model_params['window_sizes'])
-    if c.discriminator_model in 'melgan_multiscale_discriminator':
+            cond_channels=c.audio["num_mels"],
+            hop_length=c.audio["hop_length"],
+            uncond_disc_donwsample_factors=c.discriminator_model_params[
+                "uncond_disc_donwsample_factors"
+            ],
+            cond_disc_downsample_factors=c.discriminator_model_params[
+                "cond_disc_downsample_factors"
+            ],
+            cond_disc_out_channels=c.discriminator_model_params[
+                "cond_disc_out_channels"
+            ],
+            window_sizes=c.discriminator_model_params["window_sizes"],
+        )
+    if c.discriminator_model in "melgan_multiscale_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_sizes=(5, 3),
-            base_channels=c.discriminator_model_params['base_channels'],
-            max_channels=c.discriminator_model_params['max_channels'],
-            downsample_factors=c.
-            discriminator_model_params['downsample_factors'])
-    if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
+            base_channels=c.discriminator_model_params["base_channels"],
+            max_channels=c.discriminator_model_params["max_channels"],
+            downsample_factors=c.discriminator_model_params["downsample_factors"],
+        )
+    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
-            stacks=c.discriminator_model_params['stacks'],
+            num_layers=c.discriminator_model_params["num_layers"],
+            stacks=c.discriminator_model_params["stacks"],
             res_channels=64,
             gate_channels=128,
             skip_channels=64,
@@ -158,17 +170,17 @@ def setup_discriminator(c):
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
         )
-    if c.discriminator_model == 'parallel_wavegan_discriminator':
+    if c.discriminator_model == "parallel_wavegan_discriminator":
         model = MyModel(
             in_channels=1,
             out_channels=1,
             kernel_size=3,
-            num_layers=c.discriminator_model_params['num_layers'],
+            num_layers=c.discriminator_model_params["num_layers"],
             conv_channels=64,
             dilation_factor=1,
             nonlinear_activation="LeakyReLU",
             nonlinear_activation_params={"negative_slope": 0.2},
-            bias=True
+            bias=True,
         )
     return model
 