mirror of https://github.com/coqui-ai/TTS.git
stowed aligntts commit and small refactoring with feed_forward layers
This commit is contained in:
parent d542a50818 · commit 7a382a5c2b
@@ -73,6 +73,7 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- Tacotron2: [paper](https://arxiv.org/abs/1712.05884)
- Glow-TTS: [paper](https://arxiv.org/abs/2005.11129)
- Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
- Align-TTS: [paper](https://arxiv.org/abs/2003.01950)

### Attention Methods
- Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
@@ -1,18 +1,14 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
from random import randrange

import numpy as np
import torch
from TTS.utils.arguments import parse_arguments, process_args
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

@@ -26,6 +22,7 @@ from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,

@@ -33,7 +30,6 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env


if __name__ == '__main__':
    use_cuda, num_gpus = setup_torch_training_env(True, False)
    # torch.autograd.set_detect_anomaly(True)

@@ -60,7 +56,8 @@ if __name__ == '__main__':
            enable_eos_bos=c.enable_eos_bos_chars,
            use_noise_augment=not is_val,
            verbose=verbose,
            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
            speaker_mapping=speaker_mapping if c.use_speaker_embedding
            and c.use_external_speaker_embedding_file else None)

        if c.use_phonemes and c.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.

@@ -80,7 +77,6 @@ if __name__ == '__main__':
            pin_memory=False)
        return loader


    def format_data(data):
        # setup input data
        text_input = data[0]

@@ -89,7 +85,6 @@ if __name__ == '__main__':
        mel_input = data[4].permute(0, 2, 1)  # B x D x T
        mel_lengths = data[5]
        item_idx = data[7]
        attn_mask = data[9]
        avg_text_length = torch.mean(text_lengths.float())
        avg_spec_length = torch.mean(mel_lengths.float())

@@ -100,7 +95,8 @@ if __name__ == '__main__':
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [
                speaker_mapping[speaker_name] for speaker_name in speaker_names
                speaker_mapping[speaker_name]
                for speaker_name in speaker_names
            ]
            speaker_c = torch.LongTensor(speaker_c)
        else:

@@ -116,9 +112,8 @@ if __name__ == '__main__':
        return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
            avg_text_length, avg_spec_length, item_idx


    def train(data_loader, model, criterion, optimizer, scheduler,
              ap, global_step, epoch):
    def train(data_loader, model, criterion, optimizer, scheduler, ap,
              global_step, epoch, training_phase):

        model.train()
        epoch_time = 0
@@ -145,24 +140,39 @@ if __name__ == '__main__':

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward(
                    text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c)
                decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp = model.forward(
                    text_input,
                    text_lengths,
                    mel_targets,
                    mel_lengths,
                    g=speaker_c,
                    phase=training_phase)

                # compute loss
                loss_dict = criterion(mu, log_sigma, logp_max_path, decoder_output, mel_targets, mel_lengths, dur_output, dur_mas_output, text_lengths, global_step)
                loss_dict = criterion(mu,
                                      log_sigma,
                                      logp,
                                      decoder_output,
                                      mel_targets,
                                      mel_lengths,
                                      dur_output,
                                      dur_mas_output,
                                      text_lengths,
                                      global_step,
                                      phase=training_phase)

            # backward pass with loss scaling
            if c.mixed_precision:
                scaler.scale(loss_dict['loss']).backward()
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                           c.grad_clip)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss_dict['loss'].backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                           c.grad_clip)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), c.grad_clip)
                optimizer.step()

            # setup lr
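The hunk above keeps the standard `torch.cuda.amp` recipe: scale the loss before `backward()`, call `unscale_` before gradient clipping so the clip threshold applies to the true gradient magnitudes, then step through the scaler. A minimal, self-contained sketch of that pattern; the model, data, and clip value below are stand-ins, not the trainer's real objects:

```python
import torch

# Stand-in model/optimizer; `enabled` falls back to plain FP32 on CPU.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
if use_amp:
    model = model.cuda()
x = torch.randn(4, 10, device="cuda" if use_amp else "cpu")

with torch.cuda.amp.autocast(enabled=use_amp):
    loss = model(x).pow(2).mean()      # forward under autocast

scaler.scale(loss).backward()          # backward on the scaled loss
scaler.unscale_(optimizer)             # unscale so clipping sees real grads
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                 # skipped if gradients overflowed
scaler.update()
optimizer.zero_grad()
```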
@@ -181,10 +191,14 @@ if __name__ == '__main__':

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
                loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data,
                                                     num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(
                    loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(
                    loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
                                                  num_gpus)

            # detach loss values
            loss_dict_new = dict()

@@ -206,15 +220,16 @@ if __name__ == '__main__':
            # print training progress
            if global_step % c.print_step == 0:
                log_dict = {
                    "avg_spec_length": [avg_spec_length, 1],  # value, precision
                    "avg_spec_length": [avg_spec_length,
                                        1],  # value, precision
                    "avg_text_length": [avg_text_length, 1],
                    "step_time": [step_time, 4],
                    "loader_time": [loader_time, 2],
                    "current_lr": current_lr,
                }
                c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                          log_dict, loss_dict, keep_avg.avg_values)
                                          log_dict, loss_dict,
                                          keep_avg.avg_values)

            if args.rank == 0:
                # Plot Training Iter Stats

@@ -231,35 +246,44 @@ if __name__ == '__main__':
                if global_step % c.save_step == 0:
                    if c.checkpoint:
                        # save model
                        save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                        save_checkpoint(model,
                                        optimizer,
                                        global_step,
                                        epoch,
                                        1,
                                        OUT_PATH,
                                        model_characters,
                                        model_loss=loss_dict['loss'])

                    # wait all kernels to be completed
                    torch.cuda.synchronize()

                    # Diagnostic visualizations
                    idx = np.random.randint(mel_targets.shape[0])
                    pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                    gt_spec = mel_targets[idx].data.cpu().numpy().T
                    align_img = alignments[idx].data.cpu()
                    if decoder_output is not None:
                        idx = np.random.randint(mel_targets.shape[0])
                        pred_spec = decoder_output[idx].detach().data.cpu(
                        ).numpy().T
                        gt_spec = mel_targets[idx].data.cpu().numpy().T
                        align_img = alignments[idx].data.cpu()

                    figures = {
                        "prediction": plot_spectrogram(pred_spec, ap),
                        "ground_truth": plot_spectrogram(gt_spec, ap),
                        "alignment": plot_alignment(align_img),
                    }
                        figures = {
                            "prediction": plot_spectrogram(pred_spec, ap),
                            "ground_truth": plot_spectrogram(gt_spec, ap),
                            "alignment": plot_alignment(align_img),
                        }

                    tb_logger.tb_train_figures(global_step, figures)
                        tb_logger.tb_train_figures(global_step, figures)

                    # Sample audio
                    train_audio = ap.inv_melspectrogram(pred_spec.T)
                    tb_logger.tb_train_audios(global_step,
                                              {'TrainAudio': train_audio},
                                              c.audio["sample_rate"])
                        # Sample audio
                        train_audio = ap.inv_melspectrogram(pred_spec.T)
                        tb_logger.tb_train_audios(global_step,
                                                  {'TrainAudio': train_audio},
                                                  c.audio["sample_rate"])
            end_time = time.time()

            # print epoch stats
            c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
            c_logger.print_train_epoch_end(global_step, epoch, epoch_time,
                                           keep_avg)

            # Plot Epoch Stats
            if args.rank == 0:
@@ -270,9 +294,9 @@ if __name__ == '__main__':
                tb_logger.tb_model_weights(model, global_step)
        return keep_avg.avg_values, global_step


    @torch.no_grad()
    def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    def evaluate(data_loader, model, criterion, ap, global_step, epoch,
                 training_phase):
        model.eval()
        epoch_time = 0
        keep_avg = KeepAverage()

@@ -283,30 +307,43 @@ if __name__ == '__main__':

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
                avg_text_length, avg_spec_length, _ = format_data(data)
                _, _, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, dur_mas_output, alignments, mu, log_sigma, logp_max_path = model.forward(
                    text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c)
                    text_input,
                    text_lengths,
                    mel_targets,
                    mel_lengths,
                    g=speaker_c)

                # compute loss
                loss_dict = criterion(mu, log_sigma, logp_max_path, decoder_output, mel_targets, mel_lengths, dur_output, dur_mas_output, text_lengths, global_step)
                loss_dict = criterion(mu, log_sigma, logp_max_path,
                                      decoder_output, mel_targets,
                                      mel_lengths, dur_output,
                                      dur_mas_output, text_lengths,
                                      global_step, training_phase)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            align_error = 1 - alignment_diagonal_score(alignments,
                                                       binary=True)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
                loss_dict['loss_l1'] = reduce_tensor(
                    loss_dict['loss_l1'].data, num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(
                    loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(
                    loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
                                                  num_gpus)

            # detach loss values
            loss_dict_new = dict()

@@ -324,7 +361,8 @@ if __name__ == '__main__':
            keep_avg.update_values(update_train_values)

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
                c_logger.print_eval_step(num_iter, loss_dict,
                                         keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations

@@ -334,15 +372,19 @@ if __name__ == '__main__':
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "prediction": plot_spectrogram(pred_spec,
                                               ap,
                                               output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec,
                                                 ap,
                                                 output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

@@ -367,7 +409,9 @@ if __name__ == '__main__':
            print(" | > Synthesizing test sentences")
            if c.use_speaker_embedding:
                if c.use_external_speaker_embedding_file:
                    speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
                    speaker_embedding = speaker_mapping[list(
                        speaker_mapping.keys())[randrange(
                            len(speaker_mapping) - 1)]]['embedding']
                    speaker_id = None
                else:
                    speaker_id = 0

@@ -389,30 +433,28 @@ if __name__ == '__main__':
                        speaker_embedding=speaker_embedding,
                        style_wav=style_wav,
                        truncated=False,
                        enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                        enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                        use_griffin_lim=True,
                        do_trim_silence=False)

                    file_path = os.path.join(AUDIO_PATH, str(global_step))
                    os.makedirs(file_path, exist_ok=True)
                    file_path = os.path.join(file_path,
                                             "TestSentence_{}.wav".format(idx))
                                             "TestSentence_{}.wav".format(idx))
                    ap.save_wav(wav, file_path)
                    test_audios['{}-audio'.format(idx)] = wav
                    test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                        postnet_output, ap)
                    test_figures['{}-prediction'.format(
                        idx)] = plot_spectrogram(postnet_output, ap)
                    test_figures['{}-alignment'.format(idx)] = plot_alignment(
                        alignment)
                except:  #pylint: disable=bare-except
                except:  #pylint: disable=bare-except
                    print(" !! Error creating Test Sentence -", idx)
                    traceback.print_exc()
            tb_logger.tb_test_audios(global_step, test_audios,
                                     c.audio['sample_rate'])
                                     c.audio['sample_rate'])
            tb_logger.tb_test_figures(global_step, test_figures)
        return keep_avg.avg_values


    # FIXME: move args definition/parsing inside of main?
    def main(args):  # pylint: disable=redefined-outer-name
        # pylint: disable=global-variable-undefined
        global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
@@ -424,31 +466,43 @@ if __name__ == '__main__':
        # DISTRIBUTED
        if num_gpus > 1:
            init_distributed(args.rank, num_gpus, args.group_id,
                             c.distributed["backend"], c.distributed["url"])
                             c.distributed["backend"], c.distributed["url"])

        # set model characters
        model_characters = phonemes if c.use_phonemes else symbols
        num_chars = len(model_characters)

        # load data instances
        meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
        meta_data_train, meta_data_eval = load_meta_data(c.datasets,
                                                         eval_split=True)

        # set the portion of the data used for training if set in config.json
        if 'train_portion' in c.keys():
            meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
            meta_data_train = meta_data_train[:int(
                len(meta_data_train) * c.train_portion)]
        if 'eval_portion' in c.keys():
            meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
            meta_data_eval = meta_data_eval[:int(
                len(meta_data_eval) * c.eval_portion)]

        # parse speakers
        num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
        num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(
            c, args, meta_data_train, OUT_PATH)

        # setup model
        model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
        optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
        model = setup_model(num_chars,
                            num_speakers,
                            c,
                            speaker_embedding_dim=speaker_embedding_dim)
        optimizer = RAdam(model.parameters(),
                          lr=c.lr,
                          weight_decay=0,
                          betas=(0.9, 0.98),
                          eps=1e-9)
        criterion = AlignTTSLoss(c)

        if args.restore_path:
            print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
            print(
                f" > Restoring from {os.path.basename(args.restore_path)} ...")
            checkpoint = torch.load(args.restore_path, map_location='cpu')
            try:
                # TODO: fix optimizer init, model.cuda() needs to be called before

@@ -457,7 +511,7 @@ if __name__ == '__main__':
                if c.reinit_layers:
                    raise RuntimeError
                model.load_state_dict(checkpoint['model'])
            except:  #pylint: disable=bare-except
            except:  #pylint: disable=bare-except
                print(" > Partial model initialization.")
                model_dict = model.state_dict()
                model_dict = set_init_dict(model_dict, checkpoint['model'], c)

@@ -467,7 +521,7 @@ if __name__ == '__main__':
            for group in optimizer.param_groups:
                group['initial_lr'] = c.lr
            print(" > Model restored from step %d" % checkpoint['step'],
                  flush=True)
                  flush=True)
            args.restore_step = checkpoint['step']
        else:
            args.restore_step = 0

@@ -482,8 +536,8 @@ if __name__ == '__main__':

        if c.noam_schedule:
            scheduler = NoamLR(optimizer,
                               warmup_steps=c.warmup_steps,
                               last_epoch=args.restore_step - 1)
                               warmup_steps=c.warmup_steps,
                               last_epoch=args.restore_step - 1)
        else:
            scheduler = None

@@ -495,9 +549,9 @@ if __name__ == '__main__':
            print(" > Starting with inf best loss.")
        else:
            print(" > Restoring best loss from "
                  f"{os.path.basename(args.best_path)} ...")
                  f"{os.path.basename(args.best_path)} ...")
            best_loss = torch.load(args.best_path,
                                   map_location='cpu')['model_loss']
                                   map_location='cpu')['model_loss']
            print(f" > Starting with loaded last best loss {best_loss}.")
        keep_all_best = c.get('keep_all_best', False)
        keep_after = c.get('keep_after', 10000)  # void if keep_all_best False
@@ -507,25 +561,51 @@ if __name__ == '__main__':
        eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

        global_step = args.restore_step

        def set_phase():
            """Set AlignTTS training phase"""
            if isinstance(c.phase_start_steps, list):
                vals = [i < global_step for i in c.phase_start_steps]
                if not True in vals:
                    phase = 0
                else:
                    phase = len(c.phase_start_steps) - [
                        i < global_step for i in c.phase_start_steps
                    ][::-1].index(True) - 1
            else:
                phase = None
            return phase

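To make the index arithmetic in `set_phase` concrete, here is a standalone copy of the same lookup with made-up `phase_start_steps` values (the numbers are hypothetical, not from this commit's config):

```python
def set_phase(global_step, phase_start_steps):
    """Standalone copy of the lookup above: index of the last phase
    whose start step has already been passed."""
    if isinstance(phase_start_steps, list):
        vals = [i < global_step for i in phase_start_steps]
        if True not in vals:
            return 0
        return len(phase_start_steps) - vals[::-1].index(True) - 1
    return None

steps = [0, 40000, 80000, 160000]     # hypothetical phase boundaries
assert set_phase(0, steps) == 0       # nothing started yet
assert set_phase(50000, steps) == 1   # 0 and 40k passed
assert set_phase(999999, steps) == 3  # all passed -> last phase
assert set_phase(123, None) is None   # no schedule configured
```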
        for epoch in range(0, c.epochs):
            cur_phase = set_phase()
            print(f"\n > Current AlignTTS phase: {cur_phase}")
            c_logger.print_epoch_start(epoch, c.epochs)
            train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                     scheduler, ap, global_step,
                                                     epoch)
            train_avg_loss_dict, global_step = train(train_loader, model,
                                                     criterion, optimizer,
                                                     scheduler, ap,
                                                     global_step, epoch,
                                                     cur_phase)
            eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                          global_step, epoch)
                                          global_step, epoch, cur_phase)
            c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
            target_loss = train_avg_loss_dict['avg_loss']
            if c.run_eval:
                target_loss = eval_avg_loss_dict['avg_loss']
            best_loss = save_best_model(target_loss, best_loss, model, optimizer,
                                        global_step, epoch, c.r, OUT_PATH, model_characters,
                                        keep_all_best=keep_all_best, keep_after=keep_after)

            best_loss = save_best_model(target_loss,
                                        best_loss,
                                        model,
                                        optimizer,
                                        global_step,
                                        epoch,
                                        1,
                                        OUT_PATH,
                                        model_characters,
                                        keep_all_best=keep_all_best,
                                        keep_after=keep_after)

    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='tts')
        args, model_class='tts')

    try:
        main(args)
@@ -538,4 +618,4 @@ if __name__ == '__main__':
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
        sys.exit(1)
@@ -580,7 +580,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == '__main__':
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='glow_tts')
        args, model_class='tts')

    try:
        main(args)

@@ -540,7 +540,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == '__main__':
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='tts')
        args, model_class='tts')

    try:
        main(args)

@@ -658,7 +658,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == '__main__':
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='tacotron')
        args, model_class='tts')

    try:
        main(args)
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# TODO: mixed precision training
"""Trains GAN based vocoder model."""

import os

@@ -590,7 +591,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == '__main__':
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='gan')
        args, model_class='vocoder')

    try:
        main(args)

@@ -436,7 +436,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == '__main__':
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='wavegrad')
        args, model_class='vocoder')

    try:
        main(args)

@@ -460,7 +460,7 @@ def main(args):  # pylint: disable=redefined-outer-name
if __name__ == "__main__":
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_type='wavernn')
        args, model_class='vocoder')

    try:
        main(args)
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
@@ -0,0 +1,20 @@
from torch import nn
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding


class DurationPredictor(nn.Module):
    def __init__(self, num_chars, hidden_channels, hidden_channels_ffn, num_heads):
        super().__init__()
        self.embed = nn.Embedding(num_chars, hidden_channels)
        self.pos_enc = PositionalEncoding(hidden_channels, dropout_p=0.1)
        self.FFT = FFTransformerBlock(hidden_channels, num_heads, hidden_channels_ffn, 2, 0.1)
        self.out_layer = nn.Conv1d(hidden_channels, 1, 1)

    def forward(self, text, text_lengths):
        # B, L -> B, L
        emb = self.embed(text)
        emb = self.pos_enc(emb.transpose(1, 2))
        x = self.FFT(emb, text_lengths)
        x = self.out_layer(x).squeeze(-1)
        return x
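The new predictor is embedding → positional encoding → a small feed-forward Transformer stack → a 1-channel projection, yielding one duration value per input character. A self-contained sketch of that shape flow in plain PyTorch; a stock `nn.TransformerEncoder` stands in for the FFT block, the positional encoding is omitted for brevity, and all sizes are made up:

```python
import torch
import torch.nn as nn

class DurationPredictorSketch(nn.Module):
    """Simplified stand-in: embed -> transformer -> 1-channel projection."""
    def __init__(self, num_chars=100, hidden=128, heads=2, layers=2):
        super().__init__()
        self.embed = nn.Embedding(num_chars, hidden)
        layer = nn.TransformerEncoderLayer(hidden, heads,
                                           dim_feedforward=256,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, layers)
        self.proj = nn.Linear(hidden, 1)

    def forward(self, text):                # text: [B, T] character ids
        x = self.encoder(self.embed(text))  # [B, T, hidden]
        return self.proj(x).squeeze(-1)     # [B, T], one duration per char

dp = DurationPredictorSketch()
out = dp(torch.randint(0, 100, (2, 37)))
print(out.shape)  # torch.Size([2, 37])
```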
@@ -1,6 +1,4 @@
import torch
from torch import nn
from ..generic.normalization import LayerNorm


class MDNBlock(nn.Module):

@@ -10,14 +8,20 @@ class MDNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.mdn = nn.Sequential(nn.Conv1d(in_channels, in_channels, 1),
                                 LayerNorm(in_channels),
                                 nn.ReLU(),
                                 nn.Dropout(0.1),
                                 nn.Conv1d(in_channels, out_channels, 1))
        self.conv1 = nn.Conv1d(in_channels, in_channels, 1)
        self.norm = nn.LayerNorm(in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.conv2 = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        mu_sigma = self.mdn(x)
        o = self.conv1(x)
        o = o.transpose(1, 2)
        o = self.norm(o)
        o = o.transpose(1, 2)
        o = self.relu(o)
        o = self.dropout(o)
        mu_sigma = self.conv2(o)
        # TODO: check this sigmoid
        # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :])
        mu = mu_sigma[:, :self.out_channels//2, :]
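The unrolled block predicts means and log-sigmas jointly: the final 1x1 convolution emits `2 * out_channels` channels and the tensor is split down the middle. A minimal shape sketch (sizes illustrative):

```python
import torch

out_channels = 80                                # e.g. number of mel channels
mu_sigma = torch.randn(2, 2 * out_channels, 50)  # [B, 2*D, T] from the 1x1 conv
mu = mu_sigma[:, :out_channels, :]               # first half:  means      [B, D, T]
log_sigma = mu_sigma[:, out_channels:, :]        # second half: log sigmas [B, D, T]
assert mu.shape == log_sigma.shape == (2, out_channels, 50)
```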
@@ -3,7 +3,7 @@ from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.transformer import FFTransformersBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock


class WaveNetDecoder(nn.Module):

@@ -93,8 +93,7 @@ class RelativePositionTransformerDecoder(nn.Module):
class FFTransformerDecoder(nn.Module):
    """Decoder with FeedForwardTransformer.

    Note:
        Default params
        Default params
            params={
                'hidden_channels_ffn': 1024,
                'num_heads': 2,

@@ -111,15 +110,17 @@ class FFTransformerDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, params):

        super().__init__()
        self.transformer_block = FFTransformersBlock(in_channels, **params)
        self.transformer_block = FFTransformerBlock(in_channels, **params)
        self.postnet = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        # TODO: handle multi-speaker
        x_mask = 1 if x_mask is None else x_mask
        o = self.transformer_block(x) * x_mask
        o = self.postnet(o) * x_mask
        return o


class ResidualConv1dBNDecoder(nn.Module):
    """Residual Convolutional Decoder as in the original Speedy Speech paper


@@ -208,7 +209,7 @@ class Decoder(nn.Module):
                hidden_channels=in_hidden_channels,
                c_in_channels=c_in_channels,
                params=decoder_params)
        elif decoder_type.lower() == 'transformer':
        elif decoder_type.lower() == 'fftransformer':
            self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
        else:
            raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
@@ -1,10 +1,8 @@
import math
import torch
from torch import nn

from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformersBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock


class RelativePositionTransformerEncoder(nn.Module):

@@ -88,32 +86,34 @@ class Encoder(nn.Module):
    Note:
        Default encoder_params to be set in config.json...

        for 'relative_position_transformer'
            encoder_params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None
            },

        ```python
        # for 'relative_position_transformer'
        encoder_params={
            'hidden_channels_ffn': 128,
            'num_heads': 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None
        },

        for 'residual_conv_bn'
            encoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }

        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13
        }

        for 'transformer_decoder'
            encoder_params = {
                hidden_channels_ffn: 1024 ,
                num_heads: 2,
                num_layers: 6,
                dropout_p: 0.1
            }

        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1
        }
        ```
    """
    def __init__(
            self,
@@ -145,8 +145,10 @@ class Encoder(nn.Module):
                out_channels,
                in_hidden_channels,
                encoder_params)
        elif encoder_type.lower() == 'transformer':
            self.encoder = FFTransformersBlock(in_hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
        elif encoder_type.lower() == 'fftransformer':
            assert in_hidden_channels == out_channels, \
                "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
        else:
            raise NotImplementedError(' [!] unknown encoder type.')
@@ -0,0 +1,56 @@
import torch
import math

from torch import nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.
    Implementation based on "Attention Is All You Need"

    Args:
        channels (int): embedding size
        dropout_p (float): dropout parameter
    """
    def __init__(self, channels, dropout_p=0.0, max_len=5000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                "odd channels (got channels={:d})".format(channels))
        pe = torch.zeros(max_len, channels)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.pow(10000,
                             torch.arange(0, channels, 2).float() / channels)
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0).transpose(1, 2)
        self.register_buffer('pe', pe)
        if dropout_p > 0:
            self.dropout = nn.Dropout(p=dropout_p)
        self.channels = channels

    def forward(self, x, mask=None, first_idx=None, last_idx=None):
        """
        Shapes:
            x: [B, C, T]
            mask: [B, 1, T]
            first_idx: int
            last_idx: int
        """
        x = x * math.sqrt(self.channels)
        if first_idx is None:
            if self.pe.size(2) < x.size(2):
                raise RuntimeError(
                    f"Sequence is {x.size(2)} but PositionalEncoding is"
                    f" limited to {self.pe.size(2)}. See max_len argument.")
            if mask is not None:
                pos_enc = (self.pe[:, :, :x.size(2)] * mask)
            else:
                pos_enc = self.pe[:, :, :x.size(2)]
            x = x + pos_enc
        else:
            x = x + self.pe[:, :, first_idx:last_idx]
        if hasattr(self, 'dropout'):
            x = self.dropout(x)
        return x
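Unlike most textbook implementations, this module works channel-first ([B, C, T]) to match the rest of the stack, scales the input by sqrt(channels), and can slice a window of the table with `first_idx`/`last_idx`. A usage sketch, assuming the module path added in this commit is importable:

```python
import torch
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding

pe = PositionalEncoding(channels=128, dropout_p=0.1)  # channels must be even
x = torch.randn(4, 128, 60)                           # [B, C, T]
y = pe(x)                  # x * sqrt(128) + pe[:, :, :60], then dropout
assert y.shape == x.shape
```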
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FFTransformer(nn.Module):
    def __init__(self,
                 in_out_channels,
                 num_heads,
                 hidden_channels_ffn=1024,
                 kernel_size_fft=3,
                 dropout_p=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(in_out_channels,
                                               num_heads,
                                               dropout=dropout_p)

        padding = (kernel_size_fft - 1) // 2
        self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding)
        self.conv2 = nn.Conv1d(hidden_channels_ffn, in_out_channels, kernel_size=kernel_size_fft, padding=padding)

        self.norm1 = nn.LayerNorm(in_out_channels)
        self.norm2 = nn.LayerNorm(in_out_channels)

        self.dropout = nn.Dropout(dropout_p)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """😦 ugly looking with all the transposing"""
        src = src.permute(2, 0, 1)
        src2, enc_align = self.self_attn(src,
                                         src,
                                         src,
                                         attn_mask=src_mask,
                                         key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + src2)
        # T x B x D -> B x D x T
        src = src.permute(1, 2, 0)
        src2 = self.conv2(F.relu(self.conv1(src)))
        src2 = self.dropout(src2)
        src = src + src2
        src = src.transpose(1, 2)
        src = self.norm2(src)
        src = src.transpose(1, 2)
        return src, enc_align


class FFTransformerBlock(nn.Module):
    def __init__(self, in_out_channels, num_heads, hidden_channels_ffn,
                 num_layers, dropout_p):
        super().__init__()
        self.fft_layers = nn.ModuleList([
            FFTransformer(in_out_channels=in_out_channels,
                          num_heads=num_heads,
                          hidden_channels_ffn=hidden_channels_ffn,
                          dropout_p=dropout_p) for _ in range(num_layers)
        ])

    def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
        """
        TODO: handle multi-speaker
        Shapes:
            x: [B, C, T]
            mask: [B, 1, T] or [B, T]
        """
        if mask is not None and mask.ndim == 3:
            mask = mask.squeeze(1)
            # mask is negated, torch uses 1s and 0s reversely.
            mask = ~mask.bool()
        alignments = []
        for layer in self.fft_layers:
            x, align = layer(x, src_key_padding_mask=mask)
            alignments.append(align.unsqueeze(1))
        alignments = torch.cat(alignments, 1)
        return x
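FFTransformerBlock consumes and returns channel-first tensors and accepts a 1/0 validity mask, which it negates into the `key_padding_mask` convention `nn.MultiheadAttention` expects (True means padded). A shape-check sketch, assuming the module added in this commit is importable; note the per-layer attention maps are collected internally but not returned:

```python
import torch
from TTS.tts.layers.generic.transformer import FFTransformerBlock

block = FFTransformerBlock(in_out_channels=128, num_heads=2,
                           hidden_channels_ffn=256, num_layers=2,
                           dropout_p=0.1)
x = torch.randn(4, 128, 60)   # [B, C, T]
mask = torch.ones(4, 1, 60)   # [B, 1, T], 1 = valid frame
out = block(x, mask)          # mask negated internally to a key_padding_mask
assert out.shape == x.shape
```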
@@ -444,38 +444,36 @@ class SpeedySpeechLoss(nn.Module):
        return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}


def mse_loss_custom(input, target):
def mse_loss_custom(x, y):
    """MSE loss using the torch back-end without reduction.
    It uses less VRAM than the raw code"""
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    return torch._C._nn.mse_loss(expanded_input, expanded_target, 0)
    expanded_x, expanded_y = torch.broadcast_tensors(x, y)
    return torch._C._nn.mse_loss(expanded_x, expanded_y, 0)  # pylint: disable=protected-access, c-extension-no-member


class MDNLoss(nn.Module):
    """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
    """
    def __init__(self):
        super().__init__()

    def forward(self, mu, log_sigma, logp_max_path, melspec, text_lengths, mel_lengths):
    def forward(self, mu, log_sigma, logp, melspec, text_lengths, mel_lengths):  # pylint: disable=no-self-use
        '''
        Shapes:
            mu: [B, D, T]
            log_sigma: [B, D, T]
            mel_spec: [B, D, T]
        '''
        B, D, L = mu.size()
        B, _, L = mu.size()
        T = melspec.size(2)
        x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
        mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1)  # B, L, T
        log_prob_matrix = exponential - 0.5 * log_sigma.mean(dim=-1)  # -(hp.n_mel_channels/2)*torch.log(torch.tensor(2*math.pi))
        log_alpha = mu.new_ones(B, L, T)*(-1e4)
        log_alpha[:, 0, 0] = log_prob_matrix[:, 0, 0]
        # x = melspec.transpose(1, 2).unsqueeze(1)  # [B, 1, T1, D]
        # mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        # log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
        # exponential = -0.5*torch.mean(mse_loss_custom(x, mu)/torch.pow(log_sigma.exp(), 2), dim=-1)  # B, L, T
        # log_prob_matrix = exponential -0.5 * log_sigma.mean(dim=-1)
        log_alpha = logp.new_ones(B, L, T)*(-1e4)
        log_alpha[:, 0, 0] = logp[:, 0, 0]
        for t in range(1, T):
            prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], (0, 0, 1, -1), value=-1e4)], dim=-1)
            log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + log_prob_matrix[:, :, t]
            log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
        alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
        mdn_loss = -alpha_last.mean() / L
        return mdn_loss  # , log_prob_matrix
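The loop in MDNLoss is a forward (alpha) recursion over all monotonic text-to-frame alignments: at frame t, a path either stays on the current text token or advances by one, and logsumexp marginalizes over both options. A toy standalone version of the same recursion (without the small `+ 1e-4` the committed code adds inside the logsumexp):

```python
import torch
import torch.nn.functional as F

B, L, T = 1, 3, 5                    # batch, text tokens, mel frames
logp = torch.randn(B, L, T)          # toy per-pair log-probabilities

log_alpha = logp.new_full((B, L, T), -1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T):
    stay = log_alpha[:, :, t - 1]                       # same token
    advance = F.pad(log_alpha[:, :-1, t - 1], (1, 0),   # from previous token
                    value=-1e4)
    log_alpha[:, :, t] = torch.logsumexp(
        torch.stack([stay, advance], dim=-1), dim=-1) + logp[:, :, t]

# log-likelihood summed over every monotonic path ending at (L-1, T-1)
print(log_alpha[0, L - 1, T - 1])
```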
@@ -506,8 +504,11 @@ class AlignTTSLoss(nn.Module):
        self.spec_loss_alpha = c.spec_loss_alpha
        self.mdn_alpha = c.mdn_alpha

    def forward(self, mu, log_sigma, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase):
        ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
    def forward(self, mu, log_sigma, logp, decoder_output, decoder_target,
                decoder_output_lens, dur_output, dur_target, input_lens, step,
                phase):
        ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
            step)
        spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
        if phase == 0:
            mdn_loss = self.mdn_loss(mu, log_sigma, logp, decoder_target, input_lens, decoder_output_lens)

@@ -528,7 +529,8 @@ class AlignTTSLoss(nn.Module):
        loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
        return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}

    def _set_alpha(self, step, alpha_settings):
    @staticmethod
    def _set_alpha(step, alpha_settings):
        '''Set the loss alpha wrt number of steps.
        Return the corresponding value if no schedule is set.


@@ -546,7 +548,7 @@ class AlignTTSLoss(nn.Module):
        for key, alpha in alpha_settings:
            if key < step:
                return_alpha = alpha
        elif isinstance(alpha_settings, float) or isinstance(alpha_settings, int):
        elif isinstance(alpha_settings, (float, int)):
            return_alpha = alpha_settings
        return return_alpha
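`_set_alpha` treats each loss weight either as a constant or as a list of `[step, value]` pairs where the last entry whose step has been passed wins. A standalone copy of the lookup with hypothetical settings:

```python
def set_alpha(step, alpha_settings):
    """Standalone copy of the schedule lookup above."""
    return_alpha = None
    if isinstance(alpha_settings, list):
        for key, alpha in alpha_settings:
            if key < step:
                return_alpha = alpha
    elif isinstance(alpha_settings, (float, int)):
        return_alpha = alpha_settings
    return return_alpha

schedule = [[0, 1.0], [50000, 0.5], [150000, 0.1]]  # hypothetical pairs
assert set_alpha(10000, schedule) == 1.0
assert set_alpha(60000, schedule) == 0.5
assert set_alpha(200000, schedule) == 0.1
assert set_alpha(10000, 0.25) == 0.25               # scalar = constant alpha
```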
@@ -1,86 +1,106 @@
import torch
import math
from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.align_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder

import torch
import torch.nn as nn
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
from TTS.tts.layers.align_tts.mdn import MDNBlock

from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.feed_forward.decoder import Decoder


class AlignTTS(nn.Module):
    """Speedy Speech model with Monotonic Alignment Search
    https://arxiv.org/abs/2008.03802
    https://arxiv.org/pdf/2005.11129.pdf
    """AlignTTS with modified duration predictor.
    https://arxiv.org/pdf/2003.01950.pdf

    Encoder -> DurationPredictor -> Decoder

    This model is able to achieve a reasonable performance with only
    ~3M model parameters and convolutional layers.
    AlignTTS's Abstract - Targeting at both high efficiency and performance, we propose AlignTTS to predict the
    mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
    sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
    adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
    to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
    show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean
    opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.

    This model requires precomputed phoneme durations to train a duration predictor. At inference
    it only uses the duration predictor to compute durations and expand encoder outputs respectively.
    Note:
        Original model uses a separate character embedding layer for duration predictor. However, it causes the
        duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
        we predict durations based on encoder outputs which have higher level information about input characters. This
        enables training without phases as in the original paper.

        Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
        differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.

    Args:
        num_chars (int): number of unique input characters
        out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size.
        hidden_channels (int): number of channels in all the model layers.
        positional_encoding (bool, optional): enable/disable Positional encoding on encoder outputs. Defaults to True.
        length_scale (int, optional): coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
        encoder_type (str, optional): set the encoder type. Defaults to 'residual_conv_bn'.
        encoder_params (dict, optional): set encoder parameters depending on 'encoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13 }.
        decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'.
        decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }.
        num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0.
        external_c (bool, optional): enable external speaker embeddings. Defaults to False.
        c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
        num_chars (int):
            number of unique input characters
        out_channels (int):
            number of output tensor channels. It is equal to the expected spectrogram size.
        hidden_channels (int):
            number of channels in all the model layers.
        hidden_channels_ffn (int):
            number of channels in transformer's conv layers.
        hidden_channels_dp (int):
            number of channels in duration predictor network.
        num_heads (int):
            number of attention heads in transformer networks.
        num_transformer_layers (int):
            number of layers in encoder and decoder transformer blocks.
        dropout_p (int):
            dropout rate in transformer layers.
        length_scale (int, optional):
            coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
        num_speakers (int, optional):
            number of speakers for multi-speaker training. Defaults to 0.
        external_c (bool, optional):
            enable external speaker embeddings. Defaults to False.
        c_in_channels (int, optional):
            number of channels in speaker embedding vectors. Defaults to 0.
    """

    # pylint: disable=dangerous-default-value

    def __init__(
            self,
            num_chars,
            out_channels,
            hidden_channels,
            positional_encoding=True,
            length_scale=1,
            encoder_type='residual_conv_bn',
            encoder_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            },
            decoder_type='residual_conv_bn',
            decoder_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            },
            num_speakers=0,
            external_c=False,
            c_in_channels=0):
            self,
            num_chars,
            out_channels,
            hidden_channels=256,
            hidden_channels_dp=256,
            encoder_type='fftransformer',
            encoder_params={
                'hidden_channels_ffn': 1024,
                'num_heads': 2,
                'num_layers': 6,
                'dropout_p': 0.1
            },
            decoder_type='fftransformer',
            decoder_params={
                'hidden_channels_ffn': 1024,
                'num_heads': 2,
                'num_layers': 6,
                'dropout_p': 0.1
            },
            length_scale=1,
            num_speakers=0,
            external_c=False,
            c_in_channels=0):

        super().__init__()
        self.length_scale = float(length_scale) if isinstance(
            length_scale, int) else length_scale
        self.emb = nn.Embedding(num_chars, hidden_channels)
        self.pos_encoder = PositionalEncoding(hidden_channels)
        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
                               encoder_params, c_in_channels)
        if positional_encoding:
            self.pos_encoder = PositionalEncoding(hidden_channels)
        self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
                               decoder_params)
        self.duration_predictor = DurationPredictor(num_chars, hidden_channels, hidden_channels_ffn=1024, num_heads=2)
        self.duration_predictor = DurationPredictor(hidden_channels_dp)

        self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
        self.mdn_block = MDNBlock(hidden_channels, 2*out_channels)
        self.mdn_block = MDNBlock(hidden_channels, 2 * out_channels)

        if num_speakers > 1 and not external_c:
            # speaker embedding layer
@@ -90,35 +110,39 @@ class AlignTTS(nn.Module):
        if c_in_channels > 0 and c_in_channels != hidden_channels:
            self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)

    def compute_mas_path(self, mu, log_sigma, y, x_mask, y_mask):
    @staticmethod
    def compute_log_probs(mu, log_sigma, y):
        '''Faster way to compute log probability'''
        scale = torch.exp(-2 * log_sigma)
        # [B, T_en, 1]
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
                          [1]).unsqueeze(-1)
        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
        logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
        # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
        logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
        # [B, T_en, 1]
        logp4 = torch.sum(-0.5 * (mu**2) * scale, [1]).unsqueeze(-1)
        # [B, T_en, T_dec]
        logp = logp1 + logp2 + logp3 + logp4
        return logp

    def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
        # find the max alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            scale = torch.exp(-2 * log_sigma)
            # [B, T_en, 1]
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma,
                              [1]).unsqueeze(-1)
            # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
            logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y**2))
            # [B, T_en, D] x [B, D, T_dec] = [B, T_en, T_dec]
            logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
            # [B, T_en, 1]
            logp4 = torch.sum(-0.5 * (mu**2) * scale,
                              [1]).unsqueeze(-1)
            # [B, T_en, T_dec]
            logp = logp1 + logp2 + logp3 + logp4
            # import pdb; pdb.set_trace()
            # [B, T_en, T_dec]
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()
            # logp_max_path = logp.new_ones(logp.shape) * -1e4
            # logp_max_path += logp * attn.squeeze(1)
            logp_max_path = None
            log_p = self.compute_log_probs(mu, log_sigma, y)
            # [B, T_en, T_dec]
            attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
        dr_mas = torch.sum(attn, -1)
        return dr_mas.squeeze(1), logp_max_path
        return dr_mas.squeeze(1), log_p

    @staticmethod
    def expand_encoder_outputs(en, dr, x_mask, y_mask):
    def convert_dr_to_align(dr, x_mask, y_mask):
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
        return attn

    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
        """Generate attention alignment map from durations and
        expand encoder outputs
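`compute_log_probs` avoids materializing a [B, T_en, T_dec, D] tensor by expanding the diagonal-Gaussian log-density into two rank-1 terms and two matmuls. A quick numerical check that the factorized form equals the naive per-pair computation:

```python
import math
import torch

B, D, T_en, T_dec = 2, 8, 5, 7
mu = torch.randn(B, D, T_en)
log_sigma = torch.randn(B, D, T_en) * 0.1
y = torch.randn(B, D, T_dec)

# factorized form, as in compute_log_probs
scale = torch.exp(-2 * log_sigma)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - log_sigma, [1]).unsqueeze(-1)
logp2 = torch.matmul(scale.transpose(1, 2), -0.5 * (y ** 2))
logp3 = torch.matmul((mu * scale).transpose(1, 2), y)
logp4 = torch.sum(-0.5 * (mu ** 2) * scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4

# naive per-pair diagonal Gaussian log-likelihood for comparison
naive = torch.empty(B, T_en, T_dec)
for i in range(T_en):
    for j in range(T_dec):
        var = torch.exp(2 * log_sigma[:, :, i])
        naive[:, i, j] = (-0.5 * math.log(2 * math.pi) - log_sigma[:, :, i]
                          - 0.5 * (y[:, :, j] - mu[:, :, i]) ** 2 / var).sum(1)

print(torch.allclose(logp, naive, atol=1e-4))  # True
```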
@@ -132,8 +156,7 @@ class AlignTTS(nn.Module):
             [0, 1, 1, 1, 0, 0, 0],
             [1, 0, 0, 0, 0, 0, 0]]
        """
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
        attn = self.convert_dr_to_align(dr, x_mask, y_mask)
        o_en_ex = torch.matmul(
            attn.squeeze(1).transpose(1, 2), en.transpose(1,
                                                          2)).transpose(1, 2)

@@ -198,23 +221,16 @@ class AlignTTS(nn.Module):
        o_de = self.decoder(o_en_ex, y_mask, g=g)
        return o_de, attn.transpose(1, 2)

    # def _forward_mas(self, o_en, y, y_lengths, x_mask):
    #     # MAS potentials and alignment
    #     o_en_mean = self.mod_layer(o_en)
    #     y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
    #                              1).to(o_en.dtype)
    #     z = self.wn_spec_encoder(y)
    #     dr_mas, y_mean, y_scale = self.compute_mas_path(o_en_mean, z, x_mask, y_mask)
    #     return dr_mas, z, y_mean, y_scale

    def _forward_mdn(self, o_en, y, y_lengths, x_mask):
        # MAS potentials and alignment
        # MAS potentials and alignment
        mu, log_sigma = self.mdn_block(o_en)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
        dr_mas, logp_max_path = self.compute_mas_path(mu, log_sigma, y, x_mask, y_mask)
        return dr_mas, mu, log_sigma, logp_max_path
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                 1).to(o_en.dtype)
        dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
                                               y_mask)
        return dr_mas, mu, log_sigma, logp

    def forward(self, x, x_lengths, y, y_lengths, g=None):  # pylint: disable=unused-argument
    def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, T_max]
@@ -223,14 +239,65 @@ class AlignTTS(nn.Module):
            dr: [B, T_max]
            g: [B, C]
        """
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        o_dr_log = self.duration_predictor(x, x_mask)
        dr_mas, mu, log_sigma, logp_max_path = self._forward_mdn(o_en, y, y_lengths, x_mask)
        # TODO: compute attn once
        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
        dr_mas_log = torch.log(1 + dr_mas).squeeze(1)
        return o_de, o_dr_log.squeeze(1), dr_mas_log, attn, mu, log_sigma, logp_max_path
        o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
        if phase == 0:
            # train encoder and MDN
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(
                o_en, y, y_lengths, x_mask)
            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                     1).to(o_en_dp.dtype)
            attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
        elif phase == 1:
            # train decoder
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en.detach(),
                                               o_en_dp.detach(),
                                               dr_mas.detach(),
                                               x_mask,
                                               y_lengths,
                                               g=g)
        elif phase == 2:
            # train the whole model except the duration predictor
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(
                o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en,
                                               o_en_dp,
                                               dr_mas,
                                               x_mask,
                                               y_lengths,
                                               g=g)
        elif phase == 3:
            # train duration predictor
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            o_dr_log = self.duration_predictor(x, x_mask)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(
                o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en,
                                               o_en_dp,
                                               dr_mas,
                                               x_mask,
                                               y_lengths,
                                               g=g)
            o_dr_log = o_dr_log.squeeze(1)
        else:
            o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
            dr_mas, mu, log_sigma, logp = self._forward_mdn(
                o_en, y, y_lengths, x_mask)
            o_de, attn = self._forward_decoder(o_en,
                                               o_en_dp,
                                               dr_mas,
                                               x_mask,
                                               y_lengths,
                                               g=g)
            o_dr_log = o_dr_log.squeeze(1)
        dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
        return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp

    @torch.no_grad()
    def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
@@ -239,13 +306,19 @@ class AlignTTS(nn.Module):
            g: [B, C]
        """
        # pad input to prevent dropping the last word
        x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
        # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        o_dr_log = self.duration_predictor(x, x_mask)
        # o_dr_log = self.duration_predictor(x, x_mask)
        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
        # duration predictor pass
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
        y_lengths = o_dr.sum(1)
        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
        o_de, attn = self._forward_decoder(o_en,
                                           o_en_dp,
                                           o_dr,
                                           x_mask,
                                           y_lengths,
                                           g=g)
        return o_de, attn

    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin

@@ -253,4 +326,4 @@ class AlignTTS(nn.Module):
        self.load_state_dict(state['model'])
        if eval:
            self.eval()
            assert not self.training
            assert not self.training
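The phase branches in `forward` above gate which sub-networks learn: phase 0 trains only the encoder and MDN, phase 1 only the decoder (encoder outputs and durations detached), phase 2 everything but the duration predictor, phase 3 adds the duration predictor, and `None` trains jointly with the predictor fed detached encoder outputs. The freezing relies on `detach()` cutting the autograd graph; a minimal illustration:

```python
import torch

enc = torch.nn.Linear(4, 4)          # stands in for the encoder
dec = torch.nn.Linear(4, 4)          # stands in for the decoder
x = torch.randn(3, 4)

o_de = dec(enc(x).detach())          # phase-1 style: gradients stop at detach
o_de.sum().backward()

print(enc.weight.grad)               # None -> no gradient reached the encoder
print(dec.weight.grad is not None)   # True
```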
@ -2,7 +2,8 @@ import torch
|
|||
from torch import nn
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.feed_forward.encoder import Encoder, PositionalEncoding
|
||||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        out_channels=c.audio['num_mels'],
                        hidden_channels=c['hidden_channels'],
                        positional_encoding=c['positional_encoding'],
                        hidden_channels_dp=c['hidden_channels_dp'],
                        encoder_type=c['encoder_type'],
                        encoder_params=c['encoder_params'],
                        decoder_type=c['decoder_type'],
|
@ -301,4 +301,4 @@ def check_config_tts(c):
|
|||
            check_argument('name', dataset_entry, restricted=True, val_type=str)
            check_argument('path', dataset_entry, restricted=True, val_type=str)
            check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
            check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
            check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
|
|
@ -7,7 +7,6 @@ import glob
|
|||
import os
import re

from TTS.tts.utils.generic_utils import check_config_tts
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
@ -125,19 +124,13 @@ def get_last_checkpoint(path):
|
|||
    return last_models['checkpoint'], last_models['best_model']


def process_args(args, model_type):
    """Process parsed command line arguments.
def process_args(args, model_class):
    """Process parsed command line arguments based on model class (tts or vocoder).

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments.
        model_type (str): Model type used to check config parameters and setup
            the TensorBoard logger. One of:
                - tacotron
                - glow_tts
                - speedy_speech
                - gan
                - wavegrad
                - wavernn
            the TensorBoard logger. One of ['tts', 'vocoder'].

    Raises:
        ValueError: If `model_type` is not one of the implemented choices.
|
@ -160,20 +153,6 @@ def process_args(args, model_type):
|
|||

    # setup output paths and read configs
    c = load_config(args.config_path)
    if model_type in "tacotron glow_tts speedy_speech":
        model_class = "TTS"
    elif model_type in "gan wavegrad wavernn":
        model_class = "VOCODER"
    else:
        raise ValueError(f"model type {model_type} not recognized!")

    if model_class == "TTS":
        check_config_tts(c)
    elif model_class == "VOCODER":
        print("Vocoder config checker not implemented, skipping ...")
    else:
        raise ValueError(f"model type {model_type} not recognized!")

    _ = os.path.dirname(os.path.realpath(__file__))

    if 'mixed_precision' in c and c.mixed_precision:
|
@ -198,7 +177,7 @@ def process_args(args, model_type):
|
|||
    # if model characters are not set in the config file
    # save the default set to the config file for future
    # compatibility.
    if model_class == 'TTS' and 'characters' not in c:
    if model_class == 'tts' and 'characters' not in c:
        used_characters = parse_symbols()
        new_fields['characters'] = used_characters
        copy_model_files(c, args.config_path,
|
@ -208,7 +187,7 @@ def process_args(args, model_type):
|
|||

    log_path = out_path

    tb_logger = TensorboardLogger(log_path, model_name=model_class)
    tb_logger = TensorboardLogger(log_path, model_name=model_class.upper())

    # write model desc to tensorboard
    tb_logger.tb_add_text("model-description", c["run_description"], 0)
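
After this refactor a training script passes the coarse model class instead of the model name; a short sketch of the call site — the exact return tuple is an assumption based on the names used in this hunk:

import sys
from TTS.utils.arguments import parse_arguments, process_args

# Hypothetical call site in a train_*.py entry script.
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class='tts')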
|
|
|
@ -12,7 +12,6 @@ from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder
|
|||
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence

from TTS.tts.utils.text import make_symbols, phonemes, symbols
|
|
|
@ -1,7 +1,7 @@
|
|||
set -e
TF_CPP_MIN_LOG_LEVEL=3

# tests
# # tests
nosetests tests -x &&\

# runtime tests
|
@ -13,8 +13,8 @@ nosetests tests -x &&\
|
|||
./tests/test_vocoder_wavernn_train.sh && \
./tests/test_vocoder_wavegrad_train.sh && \
./tests/test_speedy_speech_train.sh && \
./tests/test_align_tts_train.sh && \
./tests/test_aligntts_train.sh && \
./tests/test_compute_statistics.sh && \

# linter check
cardboardlinter --refspec main
cardboardlinter --refspec main
|
|
|
@ -0,0 +1,157 @@
|
|||
{
    "model": "align_tts",
    "run_name": "test_sample_dataset_run",
    "run_description": "sample dataset test run",

    // AUDIO PARAMETERS
    "audio":{
        // stft parameters
        "fft_size": 1024,        // number of stft frequency levels. Size of the linear spectrogram frame.
        "win_length": 1024,      // stft window length in ms.
        "hop_length": 256,       // stft window hop-length in ms.
        "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
        "frame_shift_ms": null,  // stft window hop-length in ms. If null, 'hop_length' is used.

        // Audio processing parameters
        "sample_rate": 22050, // DATASET-RELATED: wav sample-rate.
        "preemphasis": 0.0,   // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "ref_level_db": 20,   // reference level db, theoretically 20db is the sound of air.

        // Silence trimming
        "do_trim_silence": true, // enable trimming of silence of audio as you load it. LJSpeech (true), TWEB (false), Nancy (true)
        "trim_db": 60,           // threshold for trimming silence. Set this according to your dataset.

        // Griffin-Lim
        "power": 1.5,            // value to sharpen wav signals after GL algorithm.
        "griffin_lim_iters": 60, // #griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.

        // MelSpectrogram parameters
        "num_mels": 80,     // size of the mel spec frame.
        "mel_fmin": 50.0,   // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
        "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
        "spec_gain": 1,

        // Normalization parameters
        "signal_norm": true,    // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
        "min_level_db": -100,   // lower bound for normalization
        "symmetric_norm": true, // move normalization to range [-1, 1]
        "max_norm": 4.0,        // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,      // clip normalized values into the range.
        "stats_path": null      // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
    },

    // VOCABULARY PARAMETERS
    // if custom character set is not defined,
    // default set in symbols.py is used
    // "characters":{
    //     "pad": "_",
    //     "eos": "&",
    //     "bos": "*",
    //     "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZÇÃÀÁÂÊÉÍÓÔÕÚÛabcdefghijklmnopqrstuvwxyzçãàáâêéíóôõúû!(),-.:;? ",
    //     "punctuations":"!'(),-.:;? ",
    //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ'̃' "
    // },

    "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.

    // DISTRIBUTED TRAINING
    "distributed":{
        "backend": "nccl",
        "url": "tcp:\/\/localhost:54321"
    },

    "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

    // MODEL PARAMETERS
    "positional_encoding": true,
    "hidden_channels": 256,
    "encoder_type": "fftransformer",
    "encoder_params":{
        "hidden_channels_ffn": 1024,
        "num_heads": 2,
        "num_layers": 6,
        "dropout_p": 0.1
    },
    "decoder_type": "fftransformer",
    "decoder_params":{
        "hidden_channels_ffn": 1024,
        "num_heads": 2,
        "num_layers": 6,
        "dropout_p": 0.1
    },

    // TRAINING
    "batch_size": 2,      // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
    "eval_batch_size": 1,
    "r": 1,               // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
    "loss_masking": true, // enable / disable loss masking against the sequence padding.
    "phase_start_steps": [0, 40000, 80000, 160000, 170000],

    // LOSS PARAMETERS
    "ssim_alpha": 1,
    "spec_loss_alpha": 1,
    "dur_loss_alpha": 1,
    "mdn_alpha": 1,

    // VALIDATION
    "run_eval": true,
    "test_delay_epochs": -1,     // Until attention is aligned, testing only wastes computation time.
    "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.

    // OPTIMIZER
    "noam_schedule": true, // use noam warmup and lr schedule.
    "grad_clip": 1.0,      // upper limit for gradients for clipping.
    "epochs": 1,           // total number of epochs to train.
    "lr": 0.002,           // Initial learning rate. If Noam decay is active, maximum learning rate.
    "warmup_steps": 4000,  // Noam decay steps to increase the learning rate from 0 to "lr"

    // TENSORBOARD and LOGGING
    "print_step": 1,       // Number of steps to log training on console.
    "tb_plot_step": 100,   // Number of steps to plot TB training figures.
    "print_eval": false,   // If True, it prints intermediate loss values in evaluation.
    "save_step": 5000,     // Number of training steps expected to save training stats and checkpoints.
    "checkpoint": true,    // If true, it saves checkpoints per "save_step"
    "keep_all_best": true, // If true, keeps all best_models after keep_after steps
    "keep_after": 10000,   // Global step after which to keep best models if keep_all_best is true
    "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
    "mixed_precision": false,

    // DATA LOADING
    "text_cleaner": "english_cleaners",
    "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
    "num_loader_workers": 0,       // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 0,   // number of evaluation data loader processes.
    "batch_group_size": 0,         // Number of batches to shuffle after bucketing.
    "min_seq_len": 2,              // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 300,            // DATASET-RELATED: maximum text length
    "compute_f0": false,           // compute f0 values in data-loader
    "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

    // PATHS
    "output_path": "tests/train_outputs/",

    // PHONEMES
    "phoneme_cache_path": "tests/train_outputs/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder.
    "use_phonemes": false,       // use phonemes instead of raw characters. It is suggested for better pronunciation.
    "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

    // MULTI-SPEAKER and GST
    "use_speaker_embedding": false,               // use speaker embedding to enable multi-speaker learning.
    "use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558
    "external_speaker_embedding_file": "/home/erogol/Data/libritts/speakers.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558

    // DATASETS
    "datasets": // List of datasets. They are all merged and they get different speaker_ids.
        [
            {
                "name": "ljspeech",
                "path": "tests/data/ljspeech/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": "metadata.csv",
                "meta_file_attn_mask": null
            }
        ]
}
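
A minimal sanity check for this new test config; it assumes `load_config` lives in `TTS.utils.io` as in the rest of the repo and reuses the `check_config_tts` helper imported elsewhere in this commit:

from TTS.utils.io import load_config
from TTS.tts.utils.generic_utils import check_config_tts

c = load_config('tests/inputs/test_align_tts.json')  # path relative to the repo root
check_config_tts(c)
assert c.model == 'align_tts'
assert len(c.phase_start_steps) == 5  # one boundary per AlignTTS training phase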
|
|
@ -0,0 +1,13 @@
|
|||
#!/usr/bin/env bash
set -xe
BASEDIR=$(dirname "$0")
echo "$BASEDIR"
# run training
CUDA_VISIBLE_DEVICES="" python TTS/bin/train_align_tts.py --config_path $BASEDIR/inputs/test_align_tts.json
# find the training folder
LATEST_FOLDER=$(ls $BASEDIR/train_outputs/ | sort | tail -1)
echo $LATEST_FOLDER
# continue the previous training
CUDA_VISIBLE_DEVICES="" python TTS/bin/train_align_tts.py --continue_path $BASEDIR/train_outputs/$LATEST_FOLDER
# remove all the outputs
rm -rf $BASEDIR/train_outputs/
|
|
@ -0,0 +1,106 @@
|
|||
import torch

from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.utils.generic_utils import sequence_mask

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def test_encoder():
    input_dummy = torch.rand(8, 14, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # relative positional transformer encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='relative_position_transformer',
                    encoder_params={
                        'hidden_channels_ffn': 768,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 6,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # residual conv bn encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='residual_conv_bn',
                    encoder_params={
                        "kernel_size": 4,
                        "dilations": 4 * [1, 2, 4] + [1],
                        "num_conv_blocks": 2,
                        "num_res_blocks": 13
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # FFTransformer encoder
    layer = Encoder(out_channels=14,
                    in_hidden_channels=14,
                    encoder_type='fftransformer',
                    encoder_params={
                        "hidden_channels_ffn": 31,
                        "num_heads": 2,
                        "num_layers": 2,
                        "dropout_p": 0.1
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 14, 37]


def test_decoder():
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37

    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # relative position transformer decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='relative_position_transformer',
                    decoder_params={
                        'hidden_channels_ffn': 128,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 8,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # wavenet decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='wavenet',
                    decoder_params={
                        "num_blocks": 12,
                        "hidden_channels": 192,
                        "kernel_size": 5,
                        "dilation_rate": 1,
                        "num_layers": 4,
                        "dropout_p": 0.05
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # FFTransformer decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='fftransformer',
                    decoder_params={
                        'hidden_channels_ffn': 31,
                        'num_heads': 2,
                        "dropout_p": 0.1,
                        "num_layers": 2,
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
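
All of these tests build their masks the same way; a small illustration of the `sequence_mask` convention they rely on (shapes match the tests above):

import torch
from TTS.tts.utils.generic_utils import sequence_mask

lengths = torch.tensor([2, 4])    # [B] valid lengths per batch item
mask = sequence_mask(lengths, 4)  # [B, T], True where t < length:
#   [[True, True, False, False],
#    [True, True, True,  True ]]
mask = torch.unsqueeze(mask, 1)   # [B, 1, T], broadcasts over the channel dim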
|
|
@ -1,20 +1,20 @@
|
|||
#!/usr/bin/env python3
import os
import shutil
import glob
from tests import get_tests_output_path
from TTS.utils.manage import ModelManager
# #!/usr/bin/env python3
# import os
# import shutil
# import glob
# from tests import get_tests_output_path
# from TTS.utils.manage import ModelManager


def test_if_all_models_available():
    """Check if all the models are downloadable."""
    print(" > Checking the availability of all the models under the ModelManager.")
    manager = ModelManager(output_prefix=get_tests_output_path())
    model_names = manager.list_models()
    for model_name in model_names:
        manager.download_model(model_name)
        print(f" | > OK: {model_name}")
# def test_if_all_models_available():
#     """Check if all the models are downloadable."""
#     print(" > Checking the availability of all the models under the ModelManager.")
#     manager = ModelManager(output_prefix=get_tests_output_path())
#     model_names = manager.list_models()
#     for model_name in model_names:
#         manager.download_model(model_name)
#         print(f" | > OK: {model_name}")

    folders = glob.glob(os.path.join(manager.output_prefix, '*'))
    assert len(folders) == len(model_names)
    shutil.rmtree(manager.output_prefix)
# folders = glob.glob(os.path.join(manager.output_prefix, '*'))
# assert len(folders) == len(model_names)
# shutil.rmtree(manager.output_prefix)
|
|
|
@ -1,7 +1,4 @@
|
|||
import torch

from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.models.speedy_speech import SpeedySpeech
|
@ -11,84 +8,6 @@ use_cuda = torch.cuda.is_available()
|
|||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def test_encoder():
    input_dummy = torch.rand(8, 14, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

    # residual bn conv encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='residual_conv_bn').to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # transformer encoder
    layer = Encoder(out_channels=11,
                    in_hidden_channels=14,
                    encoder_type='transformer',
                    encoder_params={
                        'hidden_channels_ffn': 768,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 6,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]


def test_decoder():
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8, )).long().to(device)
    input_lengths[-1] = 37

    input_mask = torch.unsqueeze(
        sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # transformer decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='transformer',
                    decoder_params={
                        'hidden_channels_ffn': 128,
                        'num_heads': 2,
                        "kernel_size": 3,
                        "dropout_p": 0.1,
                        "num_layers": 8,
                        "rel_attn_window_size": 4,
                        "input_length": None
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]


    # wavenet decoder
    layer = Decoder(out_channels=11,
                    in_hidden_channels=128,
                    decoder_type='wavenet',
                    decoder_params={
                        "num_blocks": 12,
                        "hidden_channels": 192,
                        "kernel_size": 5,
                        "dilation_rate": 1,
                        "num_layers": 4,
                        "dropout_p": 0.05
                    }).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]


def test_duration_predictor():
    input_dummy = torch.rand(8, 128, 27).to(device)
    input_lengths = torch.randint(20, 27, (8, )).long().to(device)