mirror of https://github.com/coqui-ai/TTS.git
compute audio feat on dataload
This commit is contained in:
parent
7c72562fe7
commit
bef3f2020b
|
@ -29,8 +29,8 @@ from TTS.utils.generic_utils import (
|
||||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||||
from TTS.vocoder.datasets.preprocess import (
|
from TTS.vocoder.datasets.preprocess import (
|
||||||
find_feat_files,
|
find_feat_files,
|
||||||
load_wav_feat_data,
|
load_wav_data,
|
||||||
preprocess_wav_files,
|
load_wav_feat_data
|
||||||
)
|
)
|
||||||
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
||||||
from TTS.vocoder.utils.generic_utils import setup_wavernn
|
from TTS.vocoder.utils.generic_utils import setup_wavernn
|
||||||
|
@ -41,15 +41,16 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||||
|
|
||||||
|
|
||||||
def setup_loader(ap, is_val=False, verbose=False):
|
def setup_loader(ap, is_val=False, verbose=False):
|
||||||
if is_val and not CONFIG.run_eval:
|
if is_val and not c.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
dataset = WaveRNNDataset(ap=ap,
|
dataset = WaveRNNDataset(ap=ap,
|
||||||
items=eval_data if is_val else train_data,
|
items=eval_data if is_val else train_data,
|
||||||
seq_len=CONFIG.seq_len,
|
seq_len=c.seq_len,
|
||||||
hop_len=ap.hop_length,
|
hop_len=ap.hop_length,
|
||||||
pad=CONFIG.padding,
|
pad=c.padding,
|
||||||
mode=CONFIG.mode,
|
mode=c.mode,
|
||||||
|
mulaw=c.mulaw,
|
||||||
is_training=not is_val,
|
is_training=not is_val,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
@ -57,10 +58,10 @@ def setup_loader(ap, is_val=False, verbose=False):
|
||||||
loader = DataLoader(dataset,
|
loader = DataLoader(dataset,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=dataset.collate,
|
collate_fn=dataset.collate,
|
||||||
batch_size=CONFIG.batch_size,
|
batch_size=c.batch_size,
|
||||||
num_workers=CONFIG.num_val_loader_workers
|
num_workers=c.num_val_loader_workers
|
||||||
if is_val
|
if is_val
|
||||||
else CONFIG.num_loader_workers,
|
else c.num_loader_workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
return loader
|
return loader
|
||||||
|
@ -89,9 +90,9 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
keep_avg = KeepAverage()
|
keep_avg = KeepAverage()
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
batch_n_iter = int(len(data_loader.dataset) /
|
batch_n_iter = int(len(data_loader.dataset) /
|
||||||
(CONFIG.batch_size * num_gpus))
|
(c.batch_size * num_gpus))
|
||||||
else:
|
else:
|
||||||
batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
c_logger.print_train_start()
|
c_logger.print_train_start()
|
||||||
# train loop
|
# train loop
|
||||||
|
@ -102,9 +103,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
loader_time = time.time() - end_time
|
loader_time = time.time() - end_time
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
##################
|
|
||||||
# MODEL TRAINING #
|
|
||||||
##################
|
|
||||||
y_hat = model(x_input, mels)
|
y_hat = model(x_input, mels)
|
||||||
|
|
||||||
if isinstance(model.mode, int):
|
if isinstance(model.mode, int):
|
||||||
|
@ -112,7 +110,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
else:
|
else:
|
||||||
y_coarse = y_coarse.float()
|
y_coarse = y_coarse.float()
|
||||||
y_coarse = y_coarse.unsqueeze(-1)
|
y_coarse = y_coarse.unsqueeze(-1)
|
||||||
# m_scaled, _ = model.upsample(m)
|
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
loss = criterion(y_hat, y_coarse)
|
loss = criterion(y_hat, y_coarse)
|
||||||
|
@ -120,11 +117,11 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
raise RuntimeError(" [!] None loss. Exiting ...")
|
raise RuntimeError(" [!] None loss. Exiting ...")
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if CONFIG.grad_clip > 0:
|
if c.grad_clip > 0:
|
||||||
torch.nn.utils.clip_grad_norm_(
|
torch.nn.utils.clip_grad_norm_(
|
||||||
model.parameters(), CONFIG.grad_clip)
|
model.parameters(), c.grad_clip)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
|
@ -144,7 +141,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
keep_avg.update_values(update_train_values)
|
keep_avg.update_values(update_train_values)
|
||||||
|
|
||||||
# print training stats
|
# print training stats
|
||||||
if global_step % CONFIG.print_step == 0:
|
if global_step % c.print_step == 0:
|
||||||
log_dict = {"step_time": [step_time, 2],
|
log_dict = {"step_time": [step_time, 2],
|
||||||
"loader_time": [loader_time, 4],
|
"loader_time": [loader_time, 4],
|
||||||
"current_lr": cur_lr,
|
"current_lr": cur_lr,
|
||||||
|
@ -164,8 +161,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||||
|
|
||||||
# save checkpoint
|
# save checkpoint
|
||||||
if global_step % CONFIG.save_step == 0:
|
if global_step % c.save_step == 0:
|
||||||
if CONFIG.checkpoint:
|
if c.checkpoint:
|
||||||
# save model
|
# save model
|
||||||
save_checkpoint(model,
|
save_checkpoint(model,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -180,28 +177,30 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
|
||||||
)
|
)
|
||||||
|
|
||||||
# synthesize a full voice
|
# synthesize a full voice
|
||||||
wav_path = train_data[random.randrange(0, len(train_data))][0]
|
rand_idx = random.randrange(0, len(train_data))
|
||||||
|
wav_path = train_data[rand_idx] if not isinstance(
|
||||||
|
train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
|
||||||
wav = ap.load_wav(wav_path)
|
wav = ap.load_wav(wav_path)
|
||||||
ground_mel = ap.melspectrogram(wav)
|
ground_mel = ap.melspectrogram(wav)
|
||||||
sample_wav = model.generate(ground_mel,
|
sample_wav = model.generate(ground_mel,
|
||||||
CONFIG.batched,
|
c.batched,
|
||||||
CONFIG.target_samples,
|
c.target_samples,
|
||||||
CONFIG.overlap_samples,
|
c.overlap_samples,
|
||||||
|
use_cuda
|
||||||
)
|
)
|
||||||
predict_mel = ap.melspectrogram(sample_wav)
|
predict_mel = ap.melspectrogram(sample_wav)
|
||||||
|
|
||||||
# compute spectrograms
|
# compute spectrograms
|
||||||
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
|
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
|
||||||
"train/prediction": plot_spectrogram(predict_mel.T),
|
"train/prediction": plot_spectrogram(predict_mel.T)
|
||||||
}
|
}
|
||||||
|
tb_logger.tb_train_figures(global_step, figures)
|
||||||
|
|
||||||
# Sample audio
|
# Sample audio
|
||||||
tb_logger.tb_train_audios(
|
tb_logger.tb_train_audios(
|
||||||
global_step, {
|
global_step, {
|
||||||
"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
"train/audio": sample_wav}, c.audio["sample_rate"]
|
||||||
)
|
)
|
||||||
|
|
||||||
tb_logger.tb_train_figures(global_step, figures)
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
|
@ -259,34 +258,35 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
keep_avg.update_values(update_eval_values)
|
keep_avg.update_values(update_eval_values)
|
||||||
|
|
||||||
# print eval stats
|
# print eval stats
|
||||||
if CONFIG.print_eval:
|
if c.print_eval:
|
||||||
c_logger.print_eval_step(
|
c_logger.print_eval_step(
|
||||||
num_iter, loss_dict, keep_avg.avg_values)
|
num_iter, loss_dict, keep_avg.avg_values)
|
||||||
|
|
||||||
if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
|
if epoch % c.test_every_epochs == 0 and epoch != 0:
|
||||||
# synthesize a part of data
|
# synthesize a full voice
|
||||||
wav_path = eval_data[random.randrange(0, len(eval_data))][0]
|
rand_idx = random.randrange(0, len(eval_data))
|
||||||
|
wav_path = eval_data[rand_idx] if not isinstance(
|
||||||
|
eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
|
||||||
wav = ap.load_wav(wav_path)
|
wav = ap.load_wav(wav_path)
|
||||||
ground_mel = ap.melspectrogram(wav[:22000])
|
ground_mel = ap.melspectrogram(wav)
|
||||||
sample_wav = model.generate(ground_mel,
|
sample_wav = model.generate(ground_mel,
|
||||||
CONFIG.batched,
|
c.batched,
|
||||||
CONFIG.target_samples,
|
c.target_samples,
|
||||||
CONFIG.overlap_samples,
|
c.overlap_samples,
|
||||||
use_cuda
|
use_cuda
|
||||||
)
|
)
|
||||||
predict_mel = ap.melspectrogram(sample_wav)
|
predict_mel = ap.melspectrogram(sample_wav)
|
||||||
|
|
||||||
# compute spectrograms
|
|
||||||
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
|
||||||
"eval/prediction": plot_spectrogram(predict_mel.T),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Sample audio
|
# Sample audio
|
||||||
tb_logger.tb_eval_audios(
|
tb_logger.tb_eval_audios(
|
||||||
global_step, {
|
global_step, {
|
||||||
"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
|
"eval/audio": sample_wav}, c.audio["sample_rate"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# compute spectrograms
|
||||||
|
figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
||||||
|
"eval/prediction": plot_spectrogram(predict_mel.T)
|
||||||
|
}
|
||||||
tb_logger.tb_eval_figures(global_step, figures)
|
tb_logger.tb_eval_figures(global_step, figures)
|
||||||
|
|
||||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||||
|
@ -299,53 +299,62 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
global train_data, eval_data
|
global train_data, eval_data
|
||||||
|
|
||||||
# setup audio processor
|
# setup audio processor
|
||||||
ap = AudioProcessor(**CONFIG.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
print(f" > Loading wavs from: {CONFIG.data_path}")
|
# print(f" > Loading wavs from: {c.data_path}")
|
||||||
if CONFIG.feature_path is not None:
|
# if c.feature_path is not None:
|
||||||
print(f" > Loading features from: {CONFIG.feature_path}")
|
# print(f" > Loading features from: {c.feature_path}")
|
||||||
|
# eval_data, train_data = load_wav_feat_data(
|
||||||
|
# c.data_path, c.feature_path, c.eval_split_size
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# mel_feat_path = os.path.join(OUT_PATH, "mel")
|
||||||
|
# feat_data = find_feat_files(mel_feat_path)
|
||||||
|
# if feat_data:
|
||||||
|
# print(f" > Loading features from: {mel_feat_path}")
|
||||||
|
# eval_data, train_data = load_wav_feat_data(
|
||||||
|
# c.data_path, mel_feat_path, c.eval_split_size
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# print(" > No feature data found. Preprocessing...")
|
||||||
|
# # preprocessing feature data from given wav files
|
||||||
|
# preprocess_wav_files(OUT_PATH, CONFIG, ap)
|
||||||
|
# eval_data, train_data = load_wav_feat_data(
|
||||||
|
# c.data_path, mel_feat_path, c.eval_split_size
|
||||||
|
# )
|
||||||
|
|
||||||
|
print(f" > Loading wavs from: {c.data_path}")
|
||||||
|
if c.feature_path is not None:
|
||||||
|
print(f" > Loading features from: {c.feature_path}")
|
||||||
eval_data, train_data = load_wav_feat_data(
|
eval_data, train_data = load_wav_feat_data(
|
||||||
CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
|
c.data_path, c.feature_path, c.eval_split_size)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
mel_feat_path = os.path.join(OUT_PATH, "mel")
|
eval_data, train_data = load_wav_data(
|
||||||
feat_data = find_feat_files(mel_feat_path)
|
c.data_path, c.eval_split_size)
|
||||||
if feat_data:
|
|
||||||
print(f" > Loading features from: {mel_feat_path}")
|
|
||||||
eval_data, train_data = load_wav_feat_data(
|
|
||||||
CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(" > No feature data found. Preprocessing...")
|
|
||||||
# preprocessing feature data from given wav files
|
|
||||||
preprocess_wav_files(OUT_PATH, CONFIG, ap)
|
|
||||||
eval_data, train_data = load_wav_feat_data(
|
|
||||||
CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
|
|
||||||
)
|
|
||||||
# setup model
|
# setup model
|
||||||
model_wavernn = setup_wavernn(CONFIG)
|
model_wavernn = setup_wavernn(c)
|
||||||
|
|
||||||
# define train functions
|
# define train functions
|
||||||
if CONFIG.mode == "mold":
|
if c.mode == "mold":
|
||||||
criterion = discretized_mix_logistic_loss
|
criterion = discretized_mix_logistic_loss
|
||||||
elif CONFIG.mode == "gauss":
|
elif c.mode == "gauss":
|
||||||
criterion = gaussian_loss
|
criterion = gaussian_loss
|
||||||
elif isinstance(CONFIG.mode, int):
|
elif isinstance(c.mode, int):
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
model_wavernn.cuda()
|
model_wavernn.cuda()
|
||||||
if isinstance(CONFIG.mode, int):
|
if isinstance(c.mode, int):
|
||||||
criterion.cuda()
|
criterion.cuda()
|
||||||
|
|
||||||
optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)
|
optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
|
||||||
|
|
||||||
scheduler = None
|
scheduler = None
|
||||||
if "lr_scheduler" in CONFIG:
|
if "lr_scheduler" in c:
|
||||||
scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
|
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
|
||||||
scheduler = scheduler(optimizer, **CONFIG.lr_scheduler_params)
|
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
|
||||||
# slow start for the first 5 epochs
|
# slow start for the first 5 epochs
|
||||||
# lr_lambda = lambda epoch: min(epoch / CONFIG.warmup_steps, 1)
|
# lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
|
||||||
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||||
|
|
||||||
# restore any checkpoint
|
# restore any checkpoint
|
||||||
|
@ -366,7 +375,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
# restore only matching layers.
|
# restore only matching layers.
|
||||||
print(" > Partial model initialization...")
|
print(" > Partial model initialization...")
|
||||||
model_dict = model_wavernn.state_dict()
|
model_dict = model_wavernn.state_dict()
|
||||||
model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
|
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||||
model_wavernn.load_state_dict(model_dict)
|
model_wavernn.load_state_dict(model_dict)
|
||||||
|
|
||||||
print(" > Model restored from step %d" %
|
print(" > Model restored from step %d" %
|
||||||
|
@ -386,11 +395,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
best_loss = float("inf")
|
best_loss = float("inf")
|
||||||
|
|
||||||
global_step = args.restore_step
|
global_step = args.restore_step
|
||||||
for epoch in range(0, CONFIG.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
c_logger.print_epoch_start(epoch, CONFIG.epochs)
|
c_logger.print_epoch_start(epoch, c.epochs)
|
||||||
_, global_step = train(
|
_, global_step = train(model_wavernn, optimizer,
|
||||||
model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
|
criterion, scheduler, ap, global_step, epoch)
|
||||||
)
|
|
||||||
eval_avg_loss_dict = evaluate(
|
eval_avg_loss_dict = evaluate(
|
||||||
model_wavernn, criterion, ap, global_step, epoch)
|
model_wavernn, criterion, ap, global_step, epoch)
|
||||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
|
@ -462,14 +470,14 @@ if __name__ == "__main__":
|
||||||
print(f" > Training continues for {args.restore_path}")
|
print(f" > Training continues for {args.restore_path}")
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
CONFIG = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
# check_config(c)
|
# check_config(c)
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
OUT_PATH = args.continue_path
|
OUT_PATH = args.continue_path
|
||||||
if args.continue_path == "":
|
if args.continue_path == "":
|
||||||
OUT_PATH = create_experiment_folder(
|
OUT_PATH = create_experiment_folder(
|
||||||
CONFIG.output_path, CONFIG.run_name, args.debug
|
c.output_path, c.run_name, args.debug
|
||||||
)
|
)
|
||||||
|
|
||||||
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
|
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
|
||||||
|
@ -483,7 +491,7 @@ if __name__ == "__main__":
|
||||||
new_fields["restore_path"] = args.restore_path
|
new_fields["restore_path"] = args.restore_path
|
||||||
new_fields["github_branch"] = get_git_branch()
|
new_fields["github_branch"] = get_git_branch()
|
||||||
copy_config_file(
|
copy_config_file(
|
||||||
args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
|
args.config_path, os.path.join(OUT_PATH, "c.json"), new_fields
|
||||||
)
|
)
|
||||||
os.chmod(AUDIO_PATH, 0o775)
|
os.chmod(AUDIO_PATH, 0o775)
|
||||||
os.chmod(OUT_PATH, 0o775)
|
os.chmod(OUT_PATH, 0o775)
|
||||||
|
@ -492,8 +500,7 @@ if __name__ == "__main__":
|
||||||
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
|
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
|
||||||
|
|
||||||
# write model desc to tensorboard
|
# write model desc to tensorboard
|
||||||
tb_logger.tb_add_text("model-description",
|
tb_logger.tb_add_text("model-description", c["run_description"], 0)
|
||||||
CONFIG["run_description"], 0)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(args)
|
main(args)
|
||||||
|
|
|
@ -2,29 +2,25 @@
|
||||||
"run_name": "wavernn_test",
|
"run_name": "wavernn_test",
|
||||||
"run_description": "wavernn_test training",
|
"run_description": "wavernn_test training",
|
||||||
|
|
||||||
// AUDIO PARAMETERS
|
// AUDIO PARAMETERS
|
||||||
"audio":{
|
"audio": {
|
||||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectrogram frame.
|
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectrogram frame.
|
||||||
"win_length": 1024, // stft window length in samples.
|
"win_length": 1024, // stft window length in samples.
|
||||||
"hop_length": 256, // stft window hop-length in samples.
|
"hop_length": 256, // stft window hop-length in samples.
|
||||||
"frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
|
"frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
|
||||||
"frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
|
"frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
|
||||||
|
|
||||||
// Audio processing parameters
|
// Audio processing parameters
|
||||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||||
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
|
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
|
||||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||||
|
|
||||||
// Silence trimming
|
// Silence trimming
|
||||||
"do_trim_silence": false,// enable trimming of silence of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
"do_trim_silence": false, // enable trimming of silence of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||||
"trim_db": 60, // threshold for trimming silence. Set this according to your dataset.
|
"trim_db": 60, // threshold for trimming silence. Set this according to your dataset.
|
||||||
|
|
||||||
// MelSpectrogram parameters
|
// MelSpectrogram parameters
|
||||||
"num_mels": 80, // size of the mel spec frame.
|
"num_mels": 80, // size of the mel spec frame.
|
||||||
"mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
"mel_fmin": 40.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||||
"spec_gain": 20.0, // scalar value applied after log transform of spectrogram.
|
"spec_gain": 20.0, // scalar value applied after log transform of spectrogram.
|
||||||
|
|
||||||
// Normalization parameters
|
// Normalization parameters
|
||||||
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
||||||
"min_level_db": -100, // lower bound for normalization
|
"min_level_db": -100, // lower bound for normalization
|
||||||
|
@ -34,40 +30,48 @@
|
||||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
|
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
|
||||||
},
|
},
|
||||||
|
|
||||||
// Generating / Synthesizing
|
// Generating / Synthesizing
|
||||||
"batched": true,
|
"batched": true,
|
||||||
"target_samples": 11000, // target number of samples to be generated in each batch entry
|
"target_samples": 11000, // target number of samples to be generated in each batch entry
|
||||||
"overlap_samples": 550, // number of samples for crossfading between batches
|
"overlap_samples": 550, // number of samples for crossfading between batches
|
||||||
|
|
||||||
// DISTRIBUTED TRAINING
|
// DISTRIBUTED TRAINING
|
||||||
// "distributed":{
|
// "distributed":{
|
||||||
// "backend": "nccl",
|
// "backend": "nccl",
|
||||||
// "url": "tcp:\/\/localhost:54321"
|
// "url": "tcp:\/\/localhost:54321"
|
||||||
// },
|
// },
|
||||||
|
|
||||||
// MODEL PARAMETERS
|
// MODEL MODE
|
||||||
|
"mode": 10, // mold [string], gauss [string], bits [int]
|
||||||
|
"mulaw": true, // apply mulaw if mode is bits
|
||||||
|
|
||||||
|
// MODEL PARAMETERS
|
||||||
|
"wavernn_model_params": {
|
||||||
|
"rnn_dims": 512,
|
||||||
|
"fc_dims": 512,
|
||||||
|
"compute_dims": 128,
|
||||||
|
"res_out_dims": 128,
|
||||||
|
"num_res_blocks": 10,
|
||||||
"use_aux_net": true,
|
"use_aux_net": true,
|
||||||
"use_upsample_net": true,
|
"use_upsample_net": true,
|
||||||
"upsample_factors": [4, 8, 8], // this needs to correctly factorise hop_length
|
"upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
|
||||||
|
},
|
||||||
|
|
||||||
|
// DATASET
|
||||||
|
//"use_gta": true, // use computed gta features from the tts model
|
||||||
|
"data_path": "/media/alexander/LinuxFS/SpeechData/GothicSpeech/NPC_Speech", // path containing training wav files
|
||||||
|
"feature_path": null, // path containing computed features from wav files if null compute them
|
||||||
"seq_len": 1280, // has to be divisible by hop_length
|
"seq_len": 1280, // has to be divisible by hop_length
|
||||||
"mode": "mold", // mold [string], gauss [string], bits [int]
|
|
||||||
"mulaw": false, // apply mulaw if mode is bits
|
|
||||||
"padding": 2, // pad the input for resnet to see wider input length
|
"padding": 2, // pad the input for resnet to see wider input length
|
||||||
|
|
||||||
// DATASET
|
// TRAINING
|
||||||
//"use_gta": true, // use computed gta features from the tts model
|
"batch_size": 64, // Batch size for training.
|
||||||
"data_path": "path/to/wav/files", // path containing training wav files
|
|
||||||
"feature_path": null, // path containing computed features from wav files if null compute them
|
|
||||||
|
|
||||||
// TRAINING
|
|
||||||
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention.
|
|
||||||
"epochs": 10000, // total number of epochs to train.
|
"epochs": 10000, // total number of epochs to train.
|
||||||
|
|
||||||
// VALIDATION
|
// VALIDATION
|
||||||
"run_eval": true,
|
"run_eval": true,
|
||||||
"test_every_epochs": 10, // Test after set number of epochs (Test every 20 epochs for example)
|
"test_every_epochs": 10, // Test after set number of epochs (Test every 10 epochs for example)
|
||||||
|
|
||||||
// OPTIMIZER
|
// OPTIMIZER
|
||||||
"grad_clip": 4, // apply gradient clipping if > 0
|
"grad_clip": 4, // apply gradient clipping if > 0
|
||||||
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
|
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
|
||||||
"lr_scheduler_params": {
|
"lr_scheduler_params": {
|
||||||
|
@ -76,19 +80,18 @@
|
||||||
},
|
},
|
||||||
"lr": 1e-4, // initial learning rate
|
"lr": 1e-4, // initial learning rate
|
||||||
|
|
||||||
// TENSORBOARD and LOGGING
|
// TENSORBOARD and LOGGING
|
||||||
"print_step": 25, // Number of steps between training log prints on the console.
|
"print_step": 25, // Number of steps between training log prints on the console.
|
||||||
"print_eval": false, // If True, it prints loss values for each step in eval run.
|
"print_eval": false, // If True, it prints loss values for each step in eval run.
|
||||||
"save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
|
"save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
|
||||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||||
|
|
||||||
// DATA LOADING
|
// DATA LOADING
|
||||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||||
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||||
"eval_split_size": 50, // number of samples for testing
|
"eval_split_size": 50, // number of samples for testing
|
||||||
|
|
||||||
// PATHS
|
// PATHS
|
||||||
"output_path": "output/training/path"
|
"output_path": "output/training/path"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
|
|
||||||
class WaveRNNDataset(Dataset):
|
class WaveRNNDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
WaveRNN Dataset searches for all the wav files under root path.
|
WaveRNN Dataset searches for all the wav files under root path
|
||||||
|
and converts them to acoustic features on the fly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -15,16 +17,19 @@ class WaveRNNDataset(Dataset):
|
||||||
hop_len,
|
hop_len,
|
||||||
pad,
|
pad,
|
||||||
mode,
|
mode,
|
||||||
|
mulaw,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||||
self.item_list = items
|
self.item_list = items
|
||||||
self.seq_len = seq_len
|
self.seq_len = seq_len
|
||||||
self.hop_len = hop_len
|
self.hop_len = hop_len
|
||||||
self.pad = pad
|
self.pad = pad
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
self.mulaw = mulaw
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
|
@ -36,22 +41,47 @@ class WaveRNNDataset(Dataset):
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def load_item(self, index):
|
def load_item(self, index):
|
||||||
|
"""
|
||||||
|
load (audio, feat) couple if feature_path is set
|
||||||
|
else compute it on the fly
|
||||||
|
"""
|
||||||
|
if self.compute_feat:
|
||||||
|
|
||||||
|
wavpath = self.item_list[index]
|
||||||
|
audio = self.ap.load_wav(wavpath)
|
||||||
|
mel = self.ap.melspectrogram(audio)
|
||||||
|
|
||||||
|
if mel.shape[-1] < 5:
|
||||||
|
print(" [!] Instance is too short! : {}".format(wavpath))
|
||||||
|
self.item_list[index] = self.item_list[index + 1]
|
||||||
|
audio = self.ap.load_wav(wavpath)
|
||||||
|
mel = self.ap.melspectrogram(audio)
|
||||||
|
if self.mode in ["gauss", "mold"]:
|
||||||
|
x_input = audio
|
||||||
|
elif isinstance(self.mode, int):
|
||||||
|
x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
|
||||||
|
if self.mulaw else self.ap.quantize(audio, bits=self.mode))
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
wavpath, feat_path = self.item_list[index]
|
wavpath, feat_path = self.item_list[index]
|
||||||
m = np.load(feat_path.replace("/quant/", "/mel/"))
|
mel = np.load(feat_path.replace("/quant/", "/mel/"))
|
||||||
# x = self.wav_cache[index]
|
|
||||||
if m.shape[-1] < 5:
|
if mel.shape[-1] < 5:
|
||||||
print(" [!] Instance is too short! : {}".format(wavpath))
|
print(" [!] Instance is too short! : {}".format(wavpath))
|
||||||
self.item_list[index] = self.item_list[index + 1]
|
self.item_list[index] = self.item_list[index + 1]
|
||||||
feat_path = self.item_list[index]
|
feat_path = self.item_list[index]
|
||||||
m = np.load(feat_path.replace("/quant/", "/mel/"))
|
mel = np.load(feat_path.replace("/quant/", "/mel/"))
|
||||||
if self.mode in ["gauss", "mold"]:
|
if self.mode in ["gauss", "mold"]:
|
||||||
# x = np.load(feat_path.replace("/mel/", "/quant/"))
|
x_input = self.ap.load_wav(wavpath)
|
||||||
x = self.ap.load_wav(wavpath)
|
|
||||||
elif isinstance(self.mode, int):
|
elif isinstance(self.mode, int):
|
||||||
x = np.load(feat_path.replace("/mel/", "/quant/"))
|
x_input = np.load(feat_path.replace("/mel/", "/quant/"))
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
raise RuntimeError("Unknown dataset mode - ", self.mode)
|
||||||
return m, x
|
|
||||||
|
return mel, x_input
|
||||||
|
|
||||||
def collate(self, batch):
|
def collate(self, batch):
|
||||||
mel_win = self.seq_len // self.hop_len + 2 * self.pad
|
mel_win = self.seq_len // self.hop_len + 2 * self.pad
|
||||||
|
@ -79,10 +109,8 @@ class WaveRNNDataset(Dataset):
|
||||||
elif isinstance(self.mode, int):
|
elif isinstance(self.mode, int):
|
||||||
coarse = np.stack(coarse).astype(np.int64)
|
coarse = np.stack(coarse).astype(np.int64)
|
||||||
coarse = torch.LongTensor(coarse)
|
coarse = torch.LongTensor(coarse)
|
||||||
x_input = (
|
x_input = (2 * coarse[:, : self.seq_len].float() /
|
||||||
2 * coarse[:, : self.seq_len].float() /
|
(2 ** self.mode - 1.0) - 1.0)
|
||||||
(2 ** self.mode - 1.0) - 1.0
|
|
||||||
)
|
|
||||||
y_coarse = coarse[:, 1:]
|
y_coarse = coarse[:, 1:]
|
||||||
mels = torch.FloatTensor(mels)
|
mels = torch.FloatTensor(mels)
|
||||||
return x_input, mels, y_coarse
|
return x_input, mels, y_coarse
|
||||||
|
|
|
@ -36,14 +36,14 @@ class ResBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class MelResNet(nn.Module):
|
class MelResNet(nn.Module):
|
||||||
def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
k_size = pad * 2 + 1
|
k_size = pad * 2 + 1
|
||||||
self.conv_in = nn.Conv1d(
|
self.conv_in = nn.Conv1d(
|
||||||
in_dims, compute_dims, kernel_size=k_size, bias=False)
|
in_dims, compute_dims, kernel_size=k_size, bias=False)
|
||||||
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
for _ in range(res_blocks):
|
for _ in range(num_res_blocks):
|
||||||
self.layers.append(ResBlock(compute_dims))
|
self.layers.append(ResBlock(compute_dims))
|
||||||
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class UpsampleNetwork(nn.Module):
|
||||||
feat_dims,
|
feat_dims,
|
||||||
upsample_scales,
|
upsample_scales,
|
||||||
compute_dims,
|
compute_dims,
|
||||||
res_blocks,
|
num_res_blocks,
|
||||||
res_out_dims,
|
res_out_dims,
|
||||||
pad,
|
pad,
|
||||||
use_aux_net,
|
use_aux_net,
|
||||||
|
@ -87,7 +87,7 @@ class UpsampleNetwork(nn.Module):
|
||||||
self.use_aux_net = use_aux_net
|
self.use_aux_net = use_aux_net
|
||||||
if use_aux_net:
|
if use_aux_net:
|
||||||
self.resnet = MelResNet(
|
self.resnet = MelResNet(
|
||||||
res_blocks, feat_dims, compute_dims, res_out_dims, pad
|
num_res_blocks, feat_dims, compute_dims, res_out_dims, pad
|
||||||
)
|
)
|
||||||
self.resnet_stretch = Stretch2d(self.total_scale, 1)
|
self.resnet_stretch = Stretch2d(self.total_scale, 1)
|
||||||
self.up_layers = nn.ModuleList()
|
self.up_layers = nn.ModuleList()
|
||||||
|
@ -118,14 +118,14 @@ class UpsampleNetwork(nn.Module):
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
class Upsample(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scale, pad, res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
|
self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.pad = pad
|
self.pad = pad
|
||||||
self.indent = pad * scale
|
self.indent = pad * scale
|
||||||
self.use_aux_net = use_aux_net
|
self.use_aux_net = use_aux_net
|
||||||
self.resnet = MelResNet(res_blocks, feat_dims,
|
self.resnet = MelResNet(num_res_blocks, feat_dims,
|
||||||
compute_dims, res_out_dims, pad)
|
compute_dims, res_out_dims, pad)
|
||||||
|
|
||||||
def forward(self, m):
|
def forward(self, m):
|
||||||
|
@ -147,8 +147,7 @@ class Upsample(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class WaveRNN(nn.Module):
|
class WaveRNN(nn.Module):
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
|
||||||
rnn_dims,
|
rnn_dims,
|
||||||
fc_dims,
|
fc_dims,
|
||||||
mode,
|
mode,
|
||||||
|
@ -160,7 +159,7 @@ class WaveRNN(nn.Module):
|
||||||
feat_dims,
|
feat_dims,
|
||||||
compute_dims,
|
compute_dims,
|
||||||
res_out_dims,
|
res_out_dims,
|
||||||
res_blocks,
|
num_res_blocks,
|
||||||
hop_length,
|
hop_length,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
):
|
):
|
||||||
|
@ -177,7 +176,7 @@ class WaveRNN(nn.Module):
|
||||||
elif self.mode == "gauss":
|
elif self.mode == "gauss":
|
||||||
self.n_classes = 2
|
self.n_classes = 2
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(" > Unknown training mode")
|
raise RuntimeError("Unknown model mode value - ", self.mode)
|
||||||
|
|
||||||
self.rnn_dims = rnn_dims
|
self.rnn_dims = rnn_dims
|
||||||
self.aux_dims = res_out_dims // 4
|
self.aux_dims = res_out_dims // 4
|
||||||
|
@ -192,7 +191,7 @@ class WaveRNN(nn.Module):
|
||||||
feat_dims,
|
feat_dims,
|
||||||
upsample_factors,
|
upsample_factors,
|
||||||
compute_dims,
|
compute_dims,
|
||||||
res_blocks,
|
num_res_blocks,
|
||||||
res_out_dims,
|
res_out_dims,
|
||||||
pad,
|
pad,
|
||||||
use_aux_net,
|
use_aux_net,
|
||||||
|
@ -201,7 +200,7 @@ class WaveRNN(nn.Module):
|
||||||
self.upsample = Upsample(
|
self.upsample = Upsample(
|
||||||
hop_length,
|
hop_length,
|
||||||
pad,
|
pad,
|
||||||
res_blocks,
|
num_res_blocks,
|
||||||
feat_dims,
|
feat_dims,
|
||||||
compute_dims,
|
compute_dims,
|
||||||
res_out_dims,
|
res_out_dims,
|
||||||
|
@ -260,7 +259,7 @@ class WaveRNN(nn.Module):
|
||||||
x = F.relu(self.fc2(x))
|
x = F.relu(self.fc2(x))
|
||||||
return self.fc3(x)
|
return self.fc3(x)
|
||||||
|
|
||||||
def generate(self, mels, batched, target, overlap, use_cuda):
|
def generate(self, mels, batched, target, overlap, use_cuda=False):
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
device = 'cuda' if use_cuda else 'cpu'
|
device = 'cuda' if use_cuda else 'cpu'
|
||||||
|
@ -360,6 +359,8 @@ class WaveRNN(nn.Module):
|
||||||
# Fade-out at the end to avoid signal cutting out suddenly
|
# Fade-out at the end to avoid signal cutting out suddenly
|
||||||
fade_out = np.linspace(1, 0, 20 * self.hop_length)
|
fade_out = np.linspace(1, 0, 20 * self.hop_length)
|
||||||
output = output[:wave_len]
|
output = output[:wave_len]
|
||||||
|
|
||||||
|
if wave_len > len(fade_out):
|
||||||
output[-20 * self.hop_length:] *= fade_out
|
output[-20 * self.hop_length:] *= fade_out
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
@ -405,7 +406,8 @@ class WaveRNN(nn.Module):
|
||||||
padding = target + 2 * overlap - remaining
|
padding = target + 2 * overlap - remaining
|
||||||
x = self.pad_tensor(x, padding, side="after")
|
x = self.pad_tensor(x, padding, side="after")
|
||||||
|
|
||||||
folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)
|
folded = torch.zeros(num_folds, target + 2 *
|
||||||
|
overlap, features).to(x.device)
|
||||||
|
|
||||||
# Get the values for the folded tensor
|
# Get the values for the folded tensor
|
||||||
for i in range(num_folds):
|
for i in range(num_folds):
|
||||||
|
|
Loading…
Reference in New Issue