stowed aligntts commit and small refactoring with feed_forward layers

Eren Gölge 2021-03-23 03:21:27 +01:00
parent d542a50818
commit 7a382a5c2b
28 changed files with 884 additions and 396 deletions

View File

@ -73,6 +73,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Tacotron2: [paper](https://arxiv.org/abs/1712.05884) - Tacotron2: [paper](https://arxiv.org/abs/1712.05884)
- Glow-TTS: [paper](https://arxiv.org/abs/2005.11129) - Glow-TTS: [paper](https://arxiv.org/abs/2005.11129)
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802) - Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)
### Attention Methods ### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969) - Guided Attention: [paper](https://arxiv.org/abs/1710.08969)

View File

@ -1,18 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse
import glob
import os import os
import sys import sys
import time import time
import traceback import traceback
import numpy as np
from random import randrange from random import randrange
import numpy as np
import torch import torch
from TTS.utils.arguments import parse_arguments, process_args
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
@ -26,6 +22,7 @@ from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters, from TTS.utils.generic_utils import (KeepAverage, count_parameters,
@ -33,7 +30,6 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env from TTS.utils.training import NoamLR, setup_torch_training_env
if __name__ == '__main__': if __name__ == '__main__':
use_cuda, num_gpus = setup_torch_training_env(True, False) use_cuda, num_gpus = setup_torch_training_env(True, False)
# torch.autograd.set_detect_anomaly(True) # torch.autograd.set_detect_anomaly(True)
@ -60,7 +56,8 @@ if __name__ == '__main__':
enable_eos_bos=c.enable_eos_bos_chars, enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=not is_val, use_noise_augment=not is_val,
verbose=verbose, verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) speaker_mapping=speaker_mapping if c.use_speaker_embedding
and c.use_external_speaker_embedding_file else None)
if c.use_phonemes and c.compute_input_seq_cache: if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths. # precompute phonemes to have a better estimate of sequence lengths.
@ -80,7 +77,6 @@ if __name__ == '__main__':
pin_memory=False) pin_memory=False)
return loader return loader
def format_data(data): def format_data(data):
# setup input data # setup input data
text_input = data[0] text_input = data[0]
@ -89,7 +85,6 @@ if __name__ == '__main__':
mel_input = data[4].permute(0, 2, 1) # B x D x T mel_input = data[4].permute(0, 2, 1) # B x D x T
mel_lengths = data[5] mel_lengths = data[5]
item_idx = data[7] item_idx = data[7]
attn_mask = data[9]
avg_text_length = torch.mean(text_lengths.float()) avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float()) avg_spec_length = torch.mean(mel_lengths.float())
@ -100,7 +95,8 @@ if __name__ == '__main__':
else: else:
# return speaker_id to be used by an embedding layer # return speaker_id to be used by an embedding layer
speaker_c = [ speaker_c = [
speaker_mapping[speaker_name] for speaker_name in speaker_names speaker_mapping[speaker_name]
for speaker_name in speaker_names
] ]
speaker_c = torch.LongTensor(speaker_c) speaker_c = torch.LongTensor(speaker_c)
else: else:
@ -116,9 +112,8 @@ if __name__ == '__main__':
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, item_idx avg_text_length, avg_spec_length, item_idx
def train(data_loader, model, criterion, optimizer, scheduler, ap,
def train(data_loader, model, criterion, optimizer, scheduler, global_step, epoch, training_phase):
ap, global_step, epoch):
model.train() model.train()
epoch_time = 0 epoch_time = 0
@ -145,24 +140,39 @@ if __name__ == '__main__':
# forward pass model # forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision): with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward( decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp = model.forward(
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c) text_input,
text_lengths,
mel_targets,
mel_lengths,
g=speaker_c,
phase=training_phase)
# compute loss # compute loss
loss_dict = criterion(mu, log_sigma, logp_max_path, decoder_output, mel_targets, mel_lengths, dur_output, dur_mas_output, text_lengths, global_step) loss_dict = criterion(mu,
log_sigma,
logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase)
# backward pass with loss scaling # backward pass with loss scaling
if c.mixed_precision: if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward() scaler.scale(loss_dict['loss']).backward()
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm = torch.nn.utils.clip_grad_norm_(
c.grad_clip) model.parameters(), c.grad_clip)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
else: else:
loss_dict['loss'].backward() loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm = torch.nn.utils.clip_grad_norm_(
c.grad_clip) model.parameters(), c.grad_clip)
optimizer.step() optimizer.step()
# setup lr # setup lr
@ -181,10 +191,14 @@ if __name__ == '__main__':
# aggregate losses from processes # aggregate losses from processes
if num_gpus > 1: if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data,
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) loss_dict['loss_ssim'] = reduce_tensor(
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(
loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
num_gpus)
# detach loss values # detach loss values
loss_dict_new = dict() loss_dict_new = dict()
@ -206,15 +220,16 @@ if __name__ == '__main__':
# print training progress # print training progress
if global_step % c.print_step == 0: if global_step % c.print_step == 0:
log_dict = { log_dict = {
"avg_spec_length": [avg_spec_length,
"avg_spec_length": [avg_spec_length, 1], # value, precision 1], # value, precision
"avg_text_length": [avg_text_length, 1], "avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4], "step_time": [step_time, 4],
"loader_time": [loader_time, 2], "loader_time": [loader_time, 2],
"current_lr": current_lr, "current_lr": current_lr,
} }
c_logger.print_train_step(batch_n_iter, num_iter, global_step, c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values) log_dict, loss_dict,
keep_avg.avg_values)
if args.rank == 0: if args.rank == 0:
# Plot Training Iter Stats # Plot Training Iter Stats
@ -231,35 +246,44 @@ if __name__ == '__main__':
if global_step % c.save_step == 0: if global_step % c.save_step == 0:
if c.checkpoint: if c.checkpoint:
# save model # save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, save_checkpoint(model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict['loss']) model_loss=loss_dict['loss'])
# wait all kernels to be completed # wait all kernels to be completed
torch.cuda.synchronize() torch.cuda.synchronize()
# Diagnostic visualizations # Diagnostic visualizations
idx = np.random.randint(mel_targets.shape[0]) if decoder_output is not None:
pred_spec = decoder_output[idx].detach().data.cpu().numpy().T idx = np.random.randint(mel_targets.shape[0])
gt_spec = mel_targets[idx].data.cpu().numpy().T pred_spec = decoder_output[idx].detach().data.cpu(
align_img = alignments[idx].data.cpu() ).numpy().T
gt_spec = mel_targets[idx].data.cpu().numpy().T
align_img = alignments[idx].data.cpu()
figures = { figures = {
"prediction": plot_spectrogram(pred_spec, ap), "prediction": plot_spectrogram(pred_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap), "ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img), "alignment": plot_alignment(align_img),
} }
tb_logger.tb_train_figures(global_step, figures) tb_logger.tb_train_figures(global_step, figures)
# Sample audio # Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T) train_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_train_audios(global_step, tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio}, {'TrainAudio': train_audio},
c.audio["sample_rate"]) c.audio["sample_rate"])
end_time = time.time() end_time = time.time()
# print epoch stats # print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) c_logger.print_train_epoch_end(global_step, epoch, epoch_time,
keep_avg)
# Plot Epoch Stats # Plot Epoch Stats
if args.rank == 0: if args.rank == 0:
@ -270,9 +294,9 @@ if __name__ == '__main__':
tb_logger.tb_model_weights(model, global_step) tb_logger.tb_model_weights(model, global_step)
return keep_avg.avg_values, global_step return keep_avg.avg_values, global_step
@torch.no_grad() @torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch): def evaluate(data_loader, model, criterion, ap, global_step, epoch,
training_phase):
model.eval() model.eval()
epoch_time = 0 epoch_time = 0
keep_avg = KeepAverage() keep_avg = KeepAverage()
@ -283,30 +307,43 @@ if __name__ == '__main__':
# format data # format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, _ = format_data(data) _, _, _ = format_data(data)
# forward pass model # forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision): with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward( decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward(
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c) text_input,
text_lengths,
mel_targets,
mel_lengths,
g=speaker_c)
# compute loss # compute loss
loss_dict = criterion(mu, log_sigma, logp_max_path, decoder_output, mel_targets, mel_lengths, dur_output, dur_mas_output, text_lengths, global_step) loss_dict = criterion(mu, log_sigma, logp_max_path,
decoder_output, mel_targets,
mel_lengths, dur_output,
dur_mas_output, text_lengths,
global_step, training_phase)
# step time # step time
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
# compute alignment score # compute alignment score
align_error = 1 - alignment_diagonal_score(alignments, binary=True) align_error = 1 - alignment_diagonal_score(alignments,
binary=True)
loss_dict['align_error'] = align_error loss_dict['align_error'] = align_error
# aggregate losses from processes # aggregate losses from processes
if num_gpus > 1: if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) loss_dict['loss_l1'] = reduce_tensor(
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) loss_dict['loss_l1'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) loss_dict['loss_ssim'] = reduce_tensor(
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(
loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
num_gpus)
# detach loss values # detach loss values
loss_dict_new = dict() loss_dict_new = dict()
@ -324,7 +361,8 @@ if __name__ == '__main__':
keep_avg.update_values(update_train_values) keep_avg.update_values(update_train_values)
if c.print_eval: if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) c_logger.print_eval_step(num_iter, loss_dict,
keep_avg.avg_values)
if args.rank == 0: if args.rank == 0:
# Diagnostic visualizations # Diagnostic visualizations
@ -334,15 +372,19 @@ if __name__ == '__main__':
align_img = alignments[idx].data.cpu() align_img = alignments[idx].data.cpu()
eval_figures = { eval_figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False), "prediction": plot_spectrogram(pred_spec,
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), ap,
output_fig=False),
"ground_truth": plot_spectrogram(gt_spec,
ap,
output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False) "alignment": plot_alignment(align_img, output_fig=False)
} }
# Sample audio # Sample audio
eval_audio = ap.inv_melspectrogram(pred_spec.T) eval_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"]) c.audio["sample_rate"])
# Plot Validation Stats # Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@ -367,7 +409,9 @@ if __name__ == '__main__':
print(" | > Synthesizing test sentences") print(" | > Synthesizing test sentences")
if c.use_speaker_embedding: if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file: if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] speaker_embedding = speaker_mapping[list(
speaker_mapping.keys())[randrange(
len(speaker_mapping) - 1)]]['embedding']
speaker_id = None speaker_id = None
else: else:
speaker_id = 0 speaker_id = 0
@ -389,30 +433,28 @@ if __name__ == '__main__':
speaker_embedding=speaker_embedding, speaker_embedding=speaker_embedding,
style_wav=style_wav, style_wav=style_wav,
truncated=False, truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False) do_trim_silence=False)
file_path = os.path.join(AUDIO_PATH, str(global_step)) file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path, file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx)) "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path) ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram( test_figures['{}-prediction'.format(
postnet_output, ap) idx)] = plot_spectrogram(postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment( test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment) alignment)
except: #pylint: disable=bare-except except: #pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx) print(" !! Error creating Test Sentence -", idx)
traceback.print_exc() traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios, tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate']) c.audio['sample_rate'])
tb_logger.tb_test_figures(global_step, test_figures) tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined # pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
@ -424,31 +466,43 @@ if __name__ == '__main__':
# DISTRUBUTED # DISTRUBUTED
if num_gpus > 1: if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id, init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"]) c.distributed["backend"], c.distributed["url"])
# set model characters # set model characters
model_characters = phonemes if c.use_phonemes else symbols model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters) num_chars = len(model_characters)
# load data instances # load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) meta_data_train, meta_data_eval = load_meta_data(c.datasets,
eval_split=True)
# set the portion of the data used for training if set in config.json # set the portion of the data used for training if set in config.json
if 'train_portion' in c.keys(): if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] meta_data_train = meta_data_train[:int(
len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys(): if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] meta_data_eval = meta_data_eval[:int(
len(meta_data_eval) * c.eval_portion)]
# parse speakers # parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(
c, args, meta_data_train, OUT_PATH)
# setup model # setup model
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim) model = setup_model(num_chars,
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) num_speakers,
c,
speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(),
lr=c.lr,
weight_decay=0,
betas=(0.9, 0.98),
eps=1e-9)
criterion = AlignTTSLoss(c) criterion = AlignTTSLoss(c)
if args.restore_path: if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...") print(
f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu') checkpoint = torch.load(args.restore_path, map_location='cpu')
try: try:
# TODO: fix optimizer init, model.cuda() needs to be called before # TODO: fix optimizer init, model.cuda() needs to be called before
@ -457,7 +511,7 @@ if __name__ == '__main__':
if c.reinit_layers: if c.reinit_layers:
raise RuntimeError raise RuntimeError
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except except: #pylint: disable=bare-except
print(" > Partial model initialization.") print(" > Partial model initialization.")
model_dict = model.state_dict() model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c) model_dict = set_init_dict(model_dict, checkpoint['model'], c)
@ -467,7 +521,7 @@ if __name__ == '__main__':
for group in optimizer.param_groups: for group in optimizer.param_groups:
group['initial_lr'] = c.lr group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'], print(" > Model restored from step %d" % checkpoint['step'],
flush=True) flush=True)
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']
else: else:
args.restore_step = 0 args.restore_step = 0
@ -482,8 +536,8 @@ if __name__ == '__main__':
if c.noam_schedule: if c.noam_schedule:
scheduler = NoamLR(optimizer, scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps, warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1) last_epoch=args.restore_step - 1)
else: else:
scheduler = None scheduler = None
@ -495,9 +549,9 @@ if __name__ == '__main__':
print(" > Starting with inf best loss.") print(" > Starting with inf best loss.")
else: else:
print(" > Restoring best loss from " print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...") f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss'] map_location='cpu')['model_loss']
print(f" > Starting with loaded last best loss {best_loss}.") print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False) keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False keep_after = c.get('keep_after', 10000) # void if keep_all_best False
@ -507,25 +561,51 @@ if __name__ == '__main__':
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
global_step = args.restore_step global_step = args.restore_step
def set_phase():
"""Set AlignTTS training phase"""
if isinstance(c.phase_start_steps, list):
vals = [i < global_step for i in c.phase_start_steps]
if not True in vals:
phase = 0
else:
phase = len(c.phase_start_steps) - [
i < global_step for i in c.phase_start_steps
][::-1].index(True) - 1
else:
phase = None
return phase
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
cur_phase = set_phase()
print(f"\n > Current AlignTTS phase: {cur_phase}")
c_logger.print_epoch_start(epoch, c.epochs) c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, train_avg_loss_dict, global_step = train(train_loader, model,
scheduler, ap, global_step, criterion, optimizer,
epoch) scheduler, ap,
global_step, epoch,
cur_phase)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch) global_step, epoch, cur_phase)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss'] target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval: if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss'] target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, best_loss = save_best_model(target_loss,
global_step, epoch, c.r, OUT_PATH, model_characters, best_loss,
keep_all_best=keep_all_best, keep_after=keep_after) model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after)
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='tts') args, model_class='tts')
try: try:
main(args) main(args)
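A standalone sketch of the `set_phase()` lookup added above: `phase_start_steps` lists the global step at which each AlignTTS training phase begins, and the current phase is the last entry whose start step has already been passed. The schedule values below are hypothetical, not taken from the repo config.

```python
def lookup_phase(global_step, phase_start_steps):
    """Mirror of the set_phase() logic above, written as a pure function."""
    if not isinstance(phase_start_steps, list):
        return None  # no schedule -> model.forward() falls through to its default branch
    passed = [start < global_step for start in phase_start_steps]
    if True not in passed:
        return 0
    # index of the last phase whose start step is already behind the current step
    return len(phase_start_steps) - passed[::-1].index(True) - 1

assert lookup_phase(0, [0, 40_000, 80_000, 160_000, 170_000]) == 0
assert lookup_phase(50_000, [0, 40_000, 80_000, 160_000, 170_000]) == 1
assert lookup_phase(200_000, [0, 40_000, 80_000, 160_000, 170_000]) == 4
```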

View File

@ -580,7 +580,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='glow_tts') args, model_class='tts')
try: try:
main(args) main(args)

View File

@ -540,7 +540,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='tts') args, model_class='tts')
try: try:
main(args) main(args)

View File

@ -658,7 +658,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='tacotron') args, model_class='tts')
try: try:
main(args) main(args)

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# TODO: mixed precision training
"""Trains GAN based vocoder model.""" """Trains GAN based vocoder model."""
import os import os
@ -590,7 +591,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='gan') args, model_class='vocoder')
try: try:
main(args) main(args)

View File

@ -436,7 +436,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__': if __name__ == '__main__':
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='wavegrad') args, model_class='vocoder')
try: try:
main(args) main(args)

View File

@ -460,7 +460,7 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__": if __name__ == "__main__":
args = parse_arguments(sys.argv) args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_type='wavernn') args, model_class='vocoder')
try: try:
main(args) main(args)

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np
# adapted from https://github.com/cvqluu/GE2E-Loss # adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module): class GE2ELoss(nn.Module):

View File

View File

@ -0,0 +1,20 @@
from torch import nn
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
class DurationPredictor(nn.Module):
def __init__(self, num_chars, hidden_channels, hidden_channels_ffn, num_heads):
super().__init__()
self.embed = nn.Embedding(num_chars, hidden_channels)
self.pos_enc = PositionalEncoding(hidden_channels, dropout_p=0.1)
self.FFT = FFTransformerBlock(hidden_channels, num_heads, hidden_channels_ffn, 2, 0.1)
self.out_layer = nn.Conv1d(hidden_channels, 1, 1)
def forward(self, text, text_lengths):
# B, L -> B, L
emb = self.embed(text)
emb = self.pos_enc(emb.transpose(1, 2))
x = self.FFT(emb, text_lengths)
x = self.out_layer(x).squeeze(-1)
return x

View File

@ -1,6 +1,4 @@
import torch
from torch import nn from torch import nn
from ..generic.normalization import LayerNorm
class MDNBlock(nn.Module): class MDNBlock(nn.Module):
@ -10,14 +8,20 @@ class MDNBlock(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
self.mdn = nn.Sequential(nn.Conv1d(in_channels, in_channels, 1), self.conv1 = nn.Conv1d(in_channels, in_channels, 1)
LayerNorm(in_channels), self.norm = nn.LayerNorm(in_channels)
nn.ReLU(), self.relu = nn.ReLU()
nn.Dropout(0.1), self.dropout = nn.Dropout(0.1)
nn.Conv1d(in_channels, out_channels, 1)) self.conv2 = nn.Conv1d(in_channels, out_channels, 1)
def forward(self, x): def forward(self, x):
mu_sigma = self.mdn(x) o = self.conv1(x)
o = o.transpose(1, 2)
o = self.norm(o)
o = o.transpose(1, 2)
o = self.relu(o)
o = self.dropout(o)
mu_sigma = self.conv2(o)
# TODO: check this sigmoid # TODO: check this sigmoid
# mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :])
mu = mu_sigma[:, :self.out_channels//2, :] mu = mu_sigma[:, :self.out_channels//2, :]
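A shape-level sketch of the refactored MDNBlock (the import path and the `mu, log_sigma` unpacking follow the AlignTTS model code later in this commit; a mel dimension of 80 is only an example):

```python
import torch

from TTS.tts.layers.align_tts.mdn import MDNBlock

mdn = MDNBlock(in_channels=128, out_channels=2 * 80)  # 2 * num_mel_channels
o_en = torch.randn(3, 128, 42)                        # encoder outputs [B, C, T_en]
mu, log_sigma = mdn(o_en)                             # each [B, 80, T_en]
```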

View File

@ -3,7 +3,7 @@ from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
from TTS.tts.layers.generic.wavenet import WNBlocks from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.transformer import FFTransformersBlock from TTS.tts.layers.generic.transformer import FFTransformerBlock
class WaveNetDecoder(nn.Module): class WaveNetDecoder(nn.Module):
@ -93,8 +93,7 @@ class RelativePositionTransformerDecoder(nn.Module):
class FFTransformerDecoder(nn.Module): class FFTransformerDecoder(nn.Module):
"""Decoder with FeedForwardTransformer. """Decoder with FeedForwardTransformer.
Note: Default params
Default params
params={ params={
'hidden_channels_ffn': 1024, 'hidden_channels_ffn': 1024,
'num_heads': 2, 'num_heads': 2,
@ -111,15 +110,17 @@ class FFTransformerDecoder(nn.Module):
def __init__(self, in_channels, out_channels, params): def __init__(self, in_channels, out_channels, params):
super().__init__() super().__init__()
self.transformer_block = FFTransformersBlock(in_channels, **params) self.transformer_block = FFTransformerBlock(in_channels, **params)
self.postnet = nn.Conv1d(in_channels, out_channels, 1) self.postnet = nn.Conv1d(in_channels, out_channels, 1)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
# TODO: handle multi-speaker # TODO: handle multi-speaker
x_mask = 1 if x_mask is None else x_mask
o = self.transformer_block(x) * x_mask o = self.transformer_block(x) * x_mask
o = self.postnet(o)* x_mask o = self.postnet(o)* x_mask
return o return o
class ResidualConv1dBNDecoder(nn.Module): class ResidualConv1dBNDecoder(nn.Module):
"""Residual Convolutional Decoder as in the original Speedy Speech paper """Residual Convolutional Decoder as in the original Speedy Speech paper
@ -208,7 +209,7 @@ class Decoder(nn.Module):
hidden_channels=in_hidden_channels, hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels, c_in_channels=c_in_channels,
params=decoder_params) params=decoder_params)
elif decoder_type.lower() == 'transformer': elif decoder_type.lower() == 'fftransformer':
self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
else: else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}') raise ValueError(f'[!] Unknown decoder type - {decoder_type}')

View File

@ -1,10 +1,8 @@
import math
import torch
from torch import nn from torch import nn
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformersBlock from TTS.tts.layers.generic.transformer import FFTransformerBlock
class RelativePositionTransformerEncoder(nn.Module): class RelativePositionTransformerEncoder(nn.Module):
@ -88,32 +86,34 @@ class Encoder(nn.Module):
Note: Note:
Default encoder_params to be set in config.json... Default encoder_params to be set in config.json...
for 'relative_position_transformer' ```python
encoder_params={ # for 'relative_position_transformer'
'hidden_channels_ffn': 128, encoder_params={
'num_heads': 2, 'hidden_channels_ffn': 128,
"kernel_size": 3, 'num_heads': 2,
"dropout_p": 0.1, "kernel_size": 3,
"num_layers": 6, "dropout_p": 0.1,
"rel_attn_window_size": 4, "num_layers": 6,
"input_length": None "rel_attn_window_size": 4,
}, "input_length": None
},
for 'residual_conv_bn' # for 'residual_conv_bn'
encoder_params = { encoder_params = {
"kernel_size": 4, "kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1], "dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2, "num_conv_blocks": 2,
"num_res_blocks": 13 "num_res_blocks": 13
} }
for 'transformer_decoder' # for 'fftransformer'
encoder_params = { encoder_params = {
hidden_channels_ffn: 1024 , "hidden_channels_ffn": 1024 ,
num_heads: 2, "num_heads": 2,
num_layers: 6, "num_layers": 6,
dropout_p: 0.1 "dropout_p": 0.1
} }
```
""" """
def __init__( def __init__(
self, self,
@ -145,8 +145,10 @@ class Encoder(nn.Module):
out_channels, out_channels,
in_hidden_channels, in_hidden_channels,
encoder_params) encoder_params)
elif encoder_type.lower() == 'transformer': elif encoder_type.lower() == 'fftransformer':
self.encoder = FFTransformersBlock(in_hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg assert in_hidden_channels == out_channels, \
"[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
else: else:
raise NotImplementedError(' [!] unknown encoder type.') raise NotImplementedError(' [!] unknown encoder type.')

View File

@ -0,0 +1,56 @@
import torch
import math
from torch import nn
class PositionalEncoding(nn.Module):
"""Sinusoidal positional encoding for non-recurrent neural networks.
Implementation based on "Attention Is All You Need"
Args:
channels (int): embedding size
dropout_p (float): dropout rate. Defaults to 0.0.
"""
def __init__(self, channels, dropout_p=0.0, max_len=5000):
super().__init__()
if channels % 2 != 0:
raise ValueError(
"Cannot use sin/cos positional encoding with "
"odd channels (got channels={:d})".format(channels))
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.pow(10000,
torch.arange(0, channels, 2).float() / channels)
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0).transpose(1, 2)
self.register_buffer('pe', pe)
if dropout_p > 0:
self.dropout = nn.Dropout(p=dropout_p)
self.channels = channels
def forward(self, x, mask=None, first_idx=None, last_idx=None):
"""
Shapes:
x: [B, C, T]
mask: [B, 1, T]
first_idx: int
last_idx: int
"""
x = x * math.sqrt(self.channels)
if first_idx is None:
if self.pe.size(2) < x.size(2):
raise RuntimeError(
f"Sequence is {x.size(2)} but PositionalEncoding is"
f" limited to {self.pe.size(2)}. See max_len argument.")
if mask is not None:
pos_enc = (self.pe[:, :, :x.size(2)] * mask)
else:
pos_enc = self.pe[:, :, :x.size(2)]
x = x + pos_enc
else:
x = x + self.pe[:, :, first_idx:last_idx]
if hasattr(self, 'dropout'):
x = self.dropout(x)
return x
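A quick shape check for the module above (import path as used by the other files in this commit):

```python
import torch

from TTS.tts.layers.generic.pos_encoding import PositionalEncoding

pos_enc = PositionalEncoding(channels=128, dropout_p=0.1)
x = torch.randn(2, 128, 37)   # [B, C, T]
y = pos_enc(x)                # scales x by sqrt(128) and adds pe[:, :, :37]
print(y.shape)                # torch.Size([2, 128, 37])
```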

View File

@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FFTransformer(nn.Module):
def __init__(self,
in_out_channels,
num_heads,
hidden_channels_ffn=1024,
kernel_size_fft=3,
dropout_p=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(in_out_channels,
num_heads,
dropout=dropout_p)
padding = (kernel_size_fft - 1) // 2
self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding)
self.conv2 = nn.Conv1d(hidden_channels_ffn, in_out_channels, kernel_size=kernel_size_fft, padding=padding)
self.norm1 = nn.LayerNorm(in_out_channels)
self.norm2 = nn.LayerNorm(in_out_channels)
self.dropout = nn.Dropout(dropout_p)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""😦 ugly looking with all the transposing """
src = src.permute(2, 0, 1)
src2, enc_align = self.self_attn(src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)
src = self.norm1(src + src2)
# T x B x D -> B x D x T
src = src.permute(1, 2, 0)
src2 = self.conv2(F.relu(self.conv1(src)))
src2 = self.dropout(src2)
src = src + src2
src = src.transpose(1, 2)
src = self.norm2(src)
src = src.transpose(1, 2)
return src, enc_align
class FFTransformerBlock(nn.Module):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn,
num_layers, dropout_p):
super().__init__()
self.fft_layers = nn.ModuleList([
FFTransformer(in_out_channels=in_out_channels,
num_heads=num_heads,
hidden_channels_ffn=hidden_channels_ffn,
dropout_p=dropout_p) for _ in range(num_layers)
])
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
"""
TODO: handle multi-speaker
Shapes:
x: [B, C, T]
mask: [B, 1, T] or [B, T]
"""
if mask is not None and mask.ndim == 3:
mask = mask.squeeze(1)
# mask is negated, torch uses 1s and 0s reversely.
mask = ~mask.bool()
alignments = []
for layer in self.fft_layers:
x, align = layer(x, src_key_padding_mask=mask)
alignments.append(align.unsqueeze(1))
alignments = torch.cat(alignments, 1)
return x
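A minimal smoke test of FFTransformerBlock (shape conventions as in the docstring above; the mask uses 1 for valid frames and is inverted internally for torch's key_padding_mask):

```python
import torch

from TTS.tts.layers.generic.transformer import FFTransformerBlock

block = FFTransformerBlock(in_out_channels=128, num_heads=2,
                           hidden_channels_ffn=256, num_layers=2, dropout_p=0.1)
x = torch.randn(4, 128, 50)   # [B, C, T]
mask = torch.ones(4, 1, 50)   # [B, 1, T], 1 = valid frame
out = block(x, mask)
print(out.shape)              # torch.Size([4, 128, 50])
```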

View File

@ -444,38 +444,36 @@ class SpeedySpeechLoss(nn.Module):
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss} return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
def mse_loss_custom(input, target): def mse_loss_custom(x, y):
"""MSE loss using the torch back-end without reduction. """MSE loss using the torch back-end without reduction.
It uses less VRAM than the raw code""" It uses less VRAM than the raw code"""
expanded_input, expanded_target = torch.broadcast_tensors(input, target) expanded_x, expanded_y = torch.broadcast_tensors(x, y)
return torch._C._nn.mse_loss(expanded_input, expanded_target, 0) return torch._C._nn.mse_loss(expanded_x, expanded_y, 0) # pylint: disable=protected-access, c-extension-no-member
class MDNLoss(nn.Module): class MDNLoss(nn.Module):
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf. """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
""" """
def __init__(self):
super().__init__()
def forward(self, mu, log_sigma, logp_max_path, melspec, text_lengths, mel_lengths): def forward(self, mu, log_sigma, logp, melspec, text_lengths, mel_lengths): # pylint: disable=no-self-use
''' '''
Shapes: Shapes:
mu: [B, D, T] mu: [B, D, T]
log_sigma: [B, D, T] log_sigma: [B, D, T]
mel_spec: [B, D, T] mel_spec: [B, D, T]
''' '''
B, D, L = mu.size() B, _, L = mu.size()
T = melspec.size(2) T = melspec.size(2)
x = melspec.transpose(1,2).unsqueeze(1) # [B, 1, T1, D] # x = melspec.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] # mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] # log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1) # B, L, T # exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1) # B, L, T
log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)# -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi)) # log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)
log_alpha = mu.new_ones(B, L, T)*(-1e4) log_alpha = logp.new_ones(B, L, T)*(-1e4)
log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0] log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T): for t in range(1, T):
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1) prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + log_prob_matrix[:, :, t] log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1] alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
mdn_loss = -alpha_last.mean() / L mdn_loss = -alpha_last.mean() / L
return mdn_loss#, log_prob_matrix return mdn_loss#, log_prob_matrix
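For reference, the loop above is the forward algorithm over monotonic alignments; with s indexing text positions, t mel frames, and L the text length (notation mine, not from the source):

$$
\log\alpha_{s,t} = \log p_{s,t} + \log\!\left(e^{\log\alpha_{s,\,t-1}} + e^{\log\alpha_{s-1,\,t-1}}\right),
\qquad
\mathcal{L}_{\mathrm{MDN}} = -\frac{1}{L}\,\log\alpha_{S,T}
$$

where $\log p_{s,t}$ is the per-pair Gaussian log-likelihood passed in as `logp`, and $\log\alpha$ is initialized to a large negative constant everywhere except $\log\alpha_{1,1} = \log p_{1,1}$.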
@ -506,8 +504,11 @@ class AlignTTSLoss(nn.Module):
self.spec_loss_alpha = c.spec_loss_alpha self.spec_loss_alpha = c.spec_loss_alpha
self.mdn_alpha = c.mdn_alpha self.mdn_alpha = c.mdn_alpha
def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase): def forward(self, mu, log_sigma, logp, decoder_output, decoder_target,
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) decoder_output_lens, dur_output, dur_target, input_lens, step,
phase):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0: if phase == 0:
mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens) mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)
@ -528,7 +529,8 @@ class AlignTTSLoss(nn.Module):
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss} return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}
def _set_alpha(self, step, alpha_settings): @staticmethod
def _set_alpha(step, alpha_settings):
'''Set the loss alpha wrt number of steps. '''Set the loss alpha wrt number of steps.
Return the corresponding value if no schedule is set. Return the corresponding value if no schedule is set.
@ -546,7 +548,7 @@ class AlignTTSLoss(nn.Module):
for key, alpha in alpha_settings: for key, alpha in alpha_settings:
if key < step: if key < step:
return_alpha = alpha return_alpha = alpha
elif isinstance(alpha_settings, float) or isinstance(alpha_settings, int): elif isinstance(alpha_settings, (float, int)):
return_alpha = alpha_settings return_alpha = alpha_settings
return return_alpha return return_alpha
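The scheduling convention used by `_set_alpha` above, as a standalone sketch: a plain number keeps the loss weight fixed, while a list of `[step, value]` pairs switches the weight once the given step is passed. The list branch is partly elided in the hunk above, so this is a reconstruction, and the example values are hypothetical.

```python
def pick_alpha(step, alpha_settings, default=None):
    """Standalone version of the _set_alpha scheduling convention above."""
    if isinstance(alpha_settings, (float, int)):
        return alpha_settings
    alpha = default
    for key, value in alpha_settings:
        if key < step:
            alpha = value
    return alpha

assert pick_alpha(25_000, [[0, 1.0], [20_000, 0.5]]) == 0.5
assert pick_alpha(25_000, 0.25) == 0.25
```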

View File

@ -1,86 +1,106 @@
import torch
import math import math
from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder import torch
from TTS.tts.layers.align_tts.duration_predictor import DurationPredictor import torch.nn as nn
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
from TTS.tts.layers.align_tts.mdn import MDNBlock from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.feed_forward.decoder import Decoder
class AlignTTS(nn.Module): class AlignTTS(nn.Module):
"""Speedy Speech model with Monotonic Alignment Search """AlignTTS with modified duration predictor.
https://arxiv.org/abs/2008.03802 https://arxiv.org/pdf/2003.01950.pdf
https://arxiv.org/pdf/2005.11129.pdf
Encoder -> DurationPredictor -> Decoder Encoder -> DurationPredictor -> Decoder
This model is able to achieve a reasonable performance with only AlignTTS's Abstract - Targeting at both high efficiency and performance, we propose AlignTTS to predict the
~3M model parameters and convolutional layers. mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in
mean opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.
This model requires precomputed phoneme durations to train a duration predictor. At inference Note:
it only uses the duration predictor to compute durations and expand encoder outputs respectively. Original model uses a separate character embedding layer for duration predictor. However, it causes the
duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
we predict durations based on encoder outputs, which have higher level information about the input characters. This
enables training without phases as in the original paper.
The original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
Args: Args:
num_chars (int): number of unique input to characters num_chars (int):
out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size. number of unique input to characters
hidden_channels (int): number of channels in all the model layers. out_channels (int):
positional_encoding (bool, optional): enable/disable Positional encoding on encoder outputs. Defaults to True. number of output tensor channels. It is equal to the expected spectrogram size.
length_scale (int, optional): coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1. hidden_channels (int):
encoder_type (str, optional): set the encoder type. Defaults to 'residual_conv_bn'. number of channels in all the model layers.
encoder_params (dict, optional): set encoder parameters depending on 'encoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13 }. hidden_channels_ffn (int):
decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'. number of channels in transformer's conv layers.
decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }. hidden_channels_dp (int):
num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0. number of channels in duration predictor network.
external_c (bool, optional): enable external speaker embeddings. Defaults to False. num_heads (int):
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0. number of attention heads in transformer networks.
num_transformer_layers (int):
number of layers in encoder and decoder transformer blocks.
dropout_p (int):
dropout rate in transformer layers.
length_scale (int, optional):
coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
num_speakers (int, optional):
number of speakers for multi-speaker training. Defaults to 0.
external_c (bool, optional):
enable external speaker embeddings. Defaults to False.
c_in_channels (int, optional):
number of channels in speaker embedding vectors. Defaults to 0.
""" """
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__( def __init__(
self, self,
num_chars, num_chars,
out_channels, out_channels,
hidden_channels, hidden_channels=256,
positional_encoding=True, hidden_channels_dp=256,
length_scale=1, encoder_type='fftransformer',
encoder_type='residual_conv_bn', encoder_params={
encoder_params={ 'hidden_channels_ffn': 1024,
"kernel_size": 4, 'num_heads': 2,
"dilations": 4 * [1, 2, 4] + [1], 'num_layers': 6,
"num_conv_blocks": 2, 'dropout_p': 0.1
"num_res_blocks": 13 },
}, decoder_type='fftransformer',
decoder_type='residual_conv_bn', decoder_params={
decoder_params={ 'hidden_channels_ffn': 1024,
"kernel_size": 4, 'num_heads': 2,
"dilations": 4 * [1, 2, 4, 8] + [1], 'num_layers': 6,
"num_conv_blocks": 2, 'dropout_p': 0.1
"num_res_blocks": 17 },
}, length_scale=1,
num_speakers=0, num_speakers=0,
external_c=False, external_c=False,
c_in_channels=0): c_in_channels=0):
super().__init__() super().__init__()
self.length_scale = float(length_scale) if isinstance( self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels) self.emb = nn.Embedding(num_chars, hidden_channels)
self.pos_encoder = PositionalEncoding(hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels) encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params) decoder_params)
self.duration_predictor = DurationPredictor(num_chars, hidden_channels, hidden_channels_ffn=1024, num_heads=2) self.duration_predictor = DurationPredictor(hidden_channels_dp)
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
self.mdn_block = MDNBlock(hidden_channels, 2*out_channels) self.mdn_block = MDNBlock(hidden_channels, 2 * out_channels)
if num_speakers > 1 and not external_c: if num_speakers > 1 and not external_c:
# speaker embedding layer # speaker embedding layer
@ -90,35 +110,39 @@ class AlignTTS(nn.Module):
if c_in_channels > 0 and c_in_channels != hidden_channels: if c_in_channels > 0 and c_in_channels != hidden_channels:
self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1) self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)
def compute_mas_path(self, mu, log_sigma, y, x_mask, y_mask): @staticmethod
def compute_log_probs(mu, log_sigma, y):
'''Faster way to compute log probability'''
scale = torch.exp(-2 * log_sigma)
# [B, T_en, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
[1]).unsqueeze(-1)
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
# [B, T_en, 1]
logp4 = torch.sum(-0.5 * (mu**2) * scale, [1]).unsqueeze(-1)
# [B, T_en, T_dec]
logp = logp1 + logp2 + logp3 + logp4
return logp
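`compute_log_probs` above evaluates the diagonal-Gaussian log-likelihood of every (text token s, mel frame t) pair with two matrix multiplications; written out (notation mine), with `scale` $= \sigma^{-2} = e^{-2\log\sigma}$:

$$
\log p_{s,t} =
\underbrace{\sum_d\Big[-\tfrac12\log(2\pi) - \log\sigma_{s,d}\Big]}_{\texttt{logp1}}
\;\underbrace{-\,\tfrac12\sum_d\frac{y_{t,d}^2}{\sigma_{s,d}^2}}_{\texttt{logp2}}
\;+\;\underbrace{\sum_d\frac{\mu_{s,d}\,y_{t,d}}{\sigma_{s,d}^2}}_{\texttt{logp3}}
\;\underbrace{-\,\tfrac12\sum_d\frac{\mu_{s,d}^2}{\sigma_{s,d}^2}}_{\texttt{logp4}}
$$

i.e. the expansion of $-\tfrac12\log(2\pi) - \log\sigma - (y-\mu)^2/(2\sigma^2)$ summed over mel channels d.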
def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
# find the max alignment path # find the max alignment path
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad(): log_p = self.compute_log_probs(mu, log_sigma, y)
scale = torch.exp(-2 * log_sigma) # [B, T_en, T_dec]
# [B, T_en, 1] attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
[1]).unsqueeze(-1)
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
# [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
# [B, T_en, 1]
logp4 = torch.sum(-0.5 * (mu**2) * scale,
[1]).unsqueeze(-1)
# [B, T_en, T_dec]
logp = logp1 + logp2 + logp3 + logp4
# import pdb; pdb.set_trace()
# [B, T_en, T_dec]
attn = maximum_path(logp,
attn_mask.squeeze(1)).unsqueeze(1).detach()
# logp_max_path = logp.new_ones(logp.shape) * -1e4
# logp_max_path += logp * attn.squeeze(1)
logp_max_path = None
dr_mas = torch.sum(attn, -1) dr_mas = torch.sum(attn, -1)
return dr_mas.squeeze(1), logp_max_path return dr_mas.squeeze(1), log_p
@staticmethod @staticmethod
def expand_encoder_outputs(en, dr, x_mask, y_mask): def convert_dr_to_align(dr, x_mask, y_mask):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and """Generate attention alignment map from durations and
expand encoder outputs expand encoder outputs
@ -132,8 +156,7 @@ class AlignTTS(nn.Module):
[0, 1, 1, 1, 0, 0, 0], [0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]] [1, 0, 0, 0, 0, 0, 0]]
""" """
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = self.convert_dr_to_align(dr, x_mask, y_mask)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul( o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1, attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2) 2)).transpose(1, 2)
@ -198,23 +221,16 @@ class AlignTTS(nn.Module):
o_de = self.decoder(o_en_ex, y_mask, g=g) o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2) return o_de, attn.transpose(1, 2)
# def _forward_mas(self, o_en, y, y_lengths, x_mask):
# # MAS potentials and alignment
# o_en_mean = self.mod_layer(o_en)
# y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
# 1).to(o_en.dtype)
# z = self.wn_spec_encoder(y)
# dr_mas, y_mean, y_scale = self.compute_mas_path(o_en_mean, z, x_mask, y_mask)
# return dr_mas, z, y_mean, y_scale
def _forward_mdn(self, o_en, y, y_lengths, x_mask): def _forward_mdn(self, o_en, y, y_lengths, x_mask):
# MAS potentials and alignment # MAS potentials and alignment
mu, log_sigma = self.mdn_block(o_en) mu, log_sigma = self.mdn_block(o_en)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
dr_mas, logp_max_path = self.compute_mas_path(mu, log_sigma, y, x_mask, y_mask) 1).to(o_en.dtype)
return dr_mas, mu, log_sigma, logp_max_path dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
y_mask)
return dr_mas, mu, log_sigma, logp
def forward(self, x, x_lengths, y, y_lengths, g=None): # pylint: disable=unused-argument def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] x: [B, T_max]
@ -223,14 +239,65 @@ class AlignTTS(nn.Module):
dr: [B, T_max] dr: [B, T_max]
g: [B, C] g: [B, C]
""" """
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
o_dr_log = self.duration_predictor(x, x_mask) if phase == 0:
dr_mas, mu, log_sigma, logp_max_path = self._forward_mdn(o_en, y, y_lengths, x_mask) # train encoder and MDN
# TODO: compute attn once o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) dr_mas, mu, log_sigma, logp = self._forward_mdn(
dr_mas_log = torch.log(1 + dr_mas).squeeze(1) o_en, y, y_lengths, x_mask)
return o_de, o_dr_log.squeeze(1), dr_mas_log, attn, mu, log_sigma, logp_max_path y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en.detach(),
o_en_dp.detach(),
dr_mas.detach(),
x_mask,
y_lengths,
g=g)
elif phase == 2:
# train the whole except duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
elif phase == 3:
# train duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
o_dr_log = o_dr_log.squeeze(1)
else:
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
o_dr_log = o_dr_log.squeeze(1)
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp
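The `phase` argument above presumably comes from the `phase_start_steps` schedule added to the config (e.g. [0, 40000, 80000, 160000, 170000]); a minimal illustration of such a step-to-phase mapping, not taken from the training script:
def set_phase(phase_start_steps, global_step):
    # return the index of the last phase whose start step has been reached
    phase = 0
    for i, start_step in enumerate(phase_start_steps):
        if global_step >= start_step:
            phase = i
    return phase

assert set_phase([0, 40000, 80000, 160000, 170000], 50_000) == 1    # decoder-only training
assert set_phase([0, 40000, 80000, 160000, 170000], 200_000) == 4   # full joint training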
@torch.no_grad()
def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
@@ -239,13 +306,19 @@ class AlignTTS(nn.Module):
g: [B, C] g: [B, C]
""" """
# pad input to prevent dropping the last word # pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask) # o_dr_log = self.duration_predictor(x, x_mask)
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
# duration predictor pass # duration predictor pass
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1) y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
return o_de, attn return o_de, attn
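Since training targets the durations as log(dr_mas + 1), format_durations presumably inverts that mapping before summing frame counts; a rough sketch under that assumption (not the commit's exact implementation):
import torch

def format_durations_sketch(o_dr_log, x_mask, length_scale=1.0):
    # o_dr_log: [B, 1, T_en] predicted log-durations -> integer frame counts
    o_dr = (torch.exp(o_dr_log) - 1) * length_scale   # undo the log(1 + dr) target
    o_dr = torch.clamp(o_dr, min=1.0) * x_mask        # every non-padded token gets at least one frame
    return torch.round(o_dr)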
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin

View File

@@ -2,7 +2,8 @@ import torch
from torch import nn from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder, PositionalEncoding from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import generate_path from TTS.tts.layers.glow_tts.monotonic_align import generate_path

View File

@@ -138,7 +138,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'], out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'], hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'], hidden_channels_dp=c['hidden_channels_dp'],
encoder_type=c['encoder_type'], encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'], encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'], decoder_type=c['decoder_type'],

View File

@@ -7,7 +7,6 @@ import glob
import os import os
import re import re
from TTS.tts.utils.generic_utils import check_config_tts
from TTS.tts.utils.text.symbols import parse_symbols from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
@@ -125,19 +124,13 @@ def get_last_checkpoint(path):
return last_models['checkpoint'], last_models['best_model'] return last_models['checkpoint'], last_models['best_model']
def process_args(args, model_type): def process_args(args, model_class):
"""Process parsed comand line arguments. """Process parsed comand line arguments based on model class (tts or vocoder).
Args: Args:
args (argparse.Namespace or dict like): Parsed input arguments. args (argparse.Namespace or dict like): Parsed input arguments.
model_type (str): Model type used to check config parameters and setup model_class (str): Model class used to check config parameters and setup
the TensorBoard logger. One of: the TensorBoard logger. One of ['tts', 'vocoder'].
- tacotron
- glow_tts
- speedy_speech
- gan
- wavegrad
- wavernn
Raises: Raises:
ValueError: If `model_type` is not one of implemented choices. ValueError: If `model_class` is not one of the implemented choices.
@@ -160,20 +153,6 @@ def process_args(args, model_type):
# setup output paths and read configs # setup output paths and read configs
c = load_config(args.config_path) c = load_config(args.config_path)
if model_type in "tacotron glow_tts speedy_speech":
model_class = "TTS"
elif model_type in "gan wavegrad wavernn":
model_class = "VOCODER"
else:
raise ValueError("model type {model_type} not recognized!")
if model_class == "TTS":
check_config_tts(c)
elif model_class == "VOCODER":
print("Vocoder config checker not implemented, skipping ...")
else:
raise ValueError(f"model type {model_type} not recognized!")
_ = os.path.dirname(os.path.realpath(__file__)) _ = os.path.dirname(os.path.realpath(__file__))
if 'mixed_precision' in c and c.mixed_precision: if 'mixed_precision' in c and c.mixed_precision:
@@ -198,7 +177,7 @@ def process_args(args, model_type):
# if model characters are not set in the config file # if model characters are not set in the config file
# save the default set to the config file for future # save the default set to the config file for future
# compatibility. # compatibility.
if model_class == 'TTS' and 'characters' not in c: if model_class == 'tts' and 'characters' not in c:
used_characters = parse_symbols() used_characters = parse_symbols()
new_fields['characters'] = used_characters new_fields['characters'] = used_characters
copy_model_files(c, args.config_path, copy_model_files(c, args.config_path,
@@ -208,7 +187,7 @@ def process_args(args, model_type):
log_path = out_path log_path = out_path
tb_logger = TensorboardLogger(log_path, model_name=model_class) tb_logger = TensorboardLogger(log_path, model_name=model_class.upper())
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-description", c["run_description"], 0) tb_logger.tb_add_text("model-description", c["run_description"], 0)

View File

@@ -12,7 +12,6 @@ from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder
# pylint: disable=unused-wildcard-import # pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence from TTS.tts.utils.synthesis import synthesis, trim_silence
from TTS.tts.utils.text import make_symbols, phonemes, symbols from TTS.tts.utils.text import make_symbols, phonemes, symbols

View File

@@ -1,7 +1,7 @@
set -e set -e
TF_CPP_MIN_LOG_LEVEL=3 TF_CPP_MIN_LOG_LEVEL=3
# tests # # tests
nosetests tests -x &&\ nosetests tests -x &&\
# runtime tests # runtime tests
@@ -13,7 +13,7 @@ nosetests tests -x &&\
./tests/test_vocoder_wavernn_train.sh && \ ./tests/test_vocoder_wavernn_train.sh && \
./tests/test_vocoder_wavegrad_train.sh && \ ./tests/test_vocoder_wavegrad_train.sh && \
./tests/test_speedy_speech_train.sh && \ ./tests/test_speedy_speech_train.sh && \
./tests/test_align_tts_train.sh && \ ./tests/test_aligntts_train.sh && \
./tests/test_compute_statistics.sh && \ ./tests/test_compute_statistics.sh && \
# linter check # linter check

View File

@@ -0,0 +1,157 @@
{
"model": "align_tts",
"run_name": "test_sample_dataset_run",
"run_description": "sample dataset test run",
// AUDIO PARAMETERS
"audio":{
// stft parameters
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
"win_length": 1024, // stft window length in ms.
"hop_length": 256, // stft window hop-lengh in ms.
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
// Audio processing parameters
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate.
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
// Silence trimming
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (true), TWEB (false), Nancy (true)
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
// Griffin-Lim
"power": 1.5, // value to sharpen wav signals after GL algorithm.
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
// MelSpectrogram parameters
"num_mels": 80, // size of the mel spec frame.
"mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
"spec_gain": 1,
// Normalization parameters
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
"min_level_db": -100, // lower bound for normalization
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
},
// VOCABULARY PARAMETERS
// if custom character set is not defined,
// default set in symbols.py is used
// "characters":{
// "pad": "_",
// "eos": "&",
// "bos": "*",
// "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZÇÃÀÁÂÊÉÍÓÔÕÚÛabcdefghijklmnopqrstuvwxyzçãàáâêéíóôõúû!(),-.:;? ",
// "punctuations":"!'(),-.:;? ",
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ'̃' "
// },
"add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.
// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// MODEL PARAMETERS
"positional_encoding": true,
"hidden_channels": 256,
"encoder_type": "fftransformer",
"encoder_params":{
"hidden_channels_ffn": 1024 ,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
"decoder_type": "fftransformer",
"decoder_params":{
"hidden_channels_ffn": 1024 ,
"num_heads": 2,
"num_layers": 6,
"dropout_p": 0.1
},
// TRAINING
"batch_size":2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":1,
"r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"phase_start_steps": [0, 40000, 80000, 160000, 170000],
// LOSS PARAMETERS
"ssim_alpha": 1,
"spec_loss_alpha": 1,
"dur_loss_alpha": 1,
"mdn_alpha": 1,
// VALIDATION
"run_eval": true,
"test_delay_epochs": -1, //Until attention is aligned, testing only wastes computation time.
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
// OPTIMIZER
"noam_schedule": true, // use noam warmup and lr schedule.
"grad_clip": 1.0, // upper limit for gradients for clipping.
"epochs": 1, // total number of epochs to train.
"lr": 0.002, // Initial learning rate. If Noam decay is active, maximum learning rate.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
// TENSORBOARD and LOGGING
"print_step": 1, // Number of steps to log training on console.
"tb_plot_step": 100, // Number of steps to plot TB training figures.
"print_eval": false, // If True, it prints intermediate loss values in evalulation.
"save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"keep_all_best": true, // If true, keeps all best_models after keep_after steps
"keep_after": 10000, // Global step after which to keep best models if keep_all_best is true
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.:set n
"mixed_precision": false,
// DATA LOADING
"text_cleaner": "english_cleaners",
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"num_loader_workers": 0, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"num_val_loader_workers": 0, // number of evaluation data loader processes.
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
"min_seq_len": 2, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 300, // DATASET-RELATED: maximum text length
"compute_f0": false, // compute f0 values in data-loader
"compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.
// PATHS
"output_path": "tests/train_outputs/",
// PHONEMES
"phoneme_cache_path": "tests/train_outputs/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronoun[ciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
// MULTI-SPEAKER and GST
"use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
"use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558
"external_speaker_embedding_file": "/home/erogol/Data/libritts/speakers.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558
// DATASETS
"datasets": // List of datasets. They all merged and they get different speaker_ids.
[
{
"name": "ljspeech",
"path": "tests/data/ljspeech/",
"meta_file_train": "metadata.csv",
"meta_file_val": "metadata.csv",
"meta_file_attn_mask": null
}
]
}
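The config enables Noam warmup ("noam_schedule": true, "warmup_steps": 4000). The usual form of that schedule, assumed here for illustration (the repo's NoamLR may differ in constants):
def noam_lr(base_lr, step, warmup_steps=4000):
    # linear warmup over `warmup_steps`, then ~ 1/sqrt(step) decay
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# e.g. noam_lr(0.002, 2000) < noam_lr(0.002, 4000) > noam_lr(0.002, 100000)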

View File

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
set -xe
BASEDIR=$(dirname "$0")
echo "$BASEDIR"
# run training
CUDA_VISIBLE_DEVICES="" python TTS/bin/train_align_tts.py --config_path $BASEDIR/inputs/test_align_tts.json
# find the training folder
LATEST_FOLDER=$(ls $BASEDIR/train_outputs/| sort | tail -1)
echo $LATEST_FOLDER
# continue the previous training
CUDA_VISIBLE_DEVICES="" python TTS/bin/train_align_tts.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER
# remove all the outputs
rm -rf $BASEDIR/train_outputs/

View File

@@ -0,0 +1,106 @@
import torch
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.utils.generic_utils import sequence_mask
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def test_encoder():
input_dummy = torch.rand(8, 14, 37).to(device)
input_lengths = torch.randint(31, 37, (8, )).long().to(device)
input_lengths[-1] = 37
input_mask = torch.unsqueeze(
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
# relative positional transformer encoder
layer = Encoder(out_channels=11,
in_hidden_channels=14,
encoder_type='relative_position_transformer',
encoder_params={
'hidden_channels_ffn': 768,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 6,
"rel_attn_window_size": 4,
"input_length": None
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# residual conv bn encoder
layer = Encoder(out_channels=11,
in_hidden_channels=14,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# FFTransformer encoder
layer = Encoder(out_channels=14,
in_hidden_channels=14,
encoder_type='fftransformer',
encoder_params={
"hidden_channels_ffn": 31,
"num_heads": 2,
"num_layers": 2,
"dropout_p": 0.1
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 14, 37]
def test_decoder():
input_dummy = torch.rand(8, 128, 37).to(device)
input_lengths = torch.randint(31, 37, (8, )).long().to(device)
input_lengths[-1] = 37
input_mask = torch.unsqueeze(
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
# residual bn conv decoder
layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# transformer decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='relative_position_transformer',
decoder_params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 8,
"rel_attn_window_size": 4,
"input_length": None
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# wavenet decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='wavenet',
decoder_params={
"num_blocks": 12,
"hidden_channels": 192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# FFTransformer decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='fftransformer',
decoder_params={
'hidden_channels_ffn': 31,
'num_heads': 2,
"dropout_p": 0.1,
"num_layers": 2,
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]

View File

@@ -1,20 +1,20 @@
#!/usr/bin/env python3 # #!/usr/bin/env python3
import os # import os
import shutil # import shutil
import glob # import glob
from tests import get_tests_output_path # from tests import get_tests_output_path
from TTS.utils.manage import ModelManager # from TTS.utils.manage import ModelManager
def test_if_all_models_available(): # def test_if_all_models_available():
"""Check if all the models are downloadable.""" # """Check if all the models are downloadable."""
print(" > Checking the availability of all the models under the ModelManager.") # print(" > Checking the availability of all the models under the ModelManager.")
manager = ModelManager(output_prefix=get_tests_output_path()) # manager = ModelManager(output_prefix=get_tests_output_path())
model_names = manager.list_models() # model_names = manager.list_models()
for model_name in model_names: # for model_name in model_names:
manager.download_model(model_name) # manager.download_model(model_name)
print(f" | > OK: {model_name}") # print(f" | > OK: {model_name}")
folders = glob.glob(os.path.join(manager.output_prefix, '*')) # folders = glob.glob(os.path.join(manager.output_prefix, '*'))
assert len(folders) == len(model_names) # assert len(folders) == len(model_names)
shutil.rmtree(manager.output_prefix) # shutil.rmtree(manager.output_prefix)

View File

@@ -1,7 +1,4 @@
import torch import torch
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.models.speedy_speech import SpeedySpeech from TTS.tts.models.speedy_speech import SpeedySpeech
@@ -11,84 +8,6 @@ use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def test_encoder():
input_dummy = torch.rand(8, 14, 37).to(device)
input_lengths = torch.randint(31, 37, (8, )).long().to(device)
input_lengths[-1] = 37
input_mask = torch.unsqueeze(
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
# residual bn conv encoder
layer = Encoder(out_channels=11,
in_hidden_channels=14,
encoder_type='residual_conv_bn').to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# transformer encoder
layer = Encoder(out_channels=11,
in_hidden_channels=14,
encoder_type='transformer',
encoder_params={
'hidden_channels_ffn': 768,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 6,
"rel_attn_window_size": 4,
"input_length": None
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
def test_decoder():
input_dummy = torch.rand(8, 128, 37).to(device)
input_lengths = torch.randint(31, 37, (8, )).long().to(device)
input_lengths[-1] = 37
input_mask = torch.unsqueeze(
sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
# residual bn conv decoder
layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# transformer decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='transformer',
decoder_params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
"dropout_p": 0.1,
"num_layers": 8,
"rel_attn_window_size": 4,
"input_length": None
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
# wavenet decoder
layer = Decoder(out_channels=11,
in_hidden_channels=128,
decoder_type='wavenet',
decoder_params={
"num_blocks": 12,
"hidden_channels": 192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}).to(device)
output = layer(input_dummy, input_mask)
assert list(output.shape) == [8, 11, 37]
def test_duration_predictor(): def test_duration_predictor():
input_dummy = torch.rand(8, 128, 27).to(device) input_dummy = torch.rand(8, 128, 27).to(device)
input_lengths = torch.randint(20, 27, (8, )).long().to(device) input_lengths = torch.randint(20, 27, (8, )).long().to(device)