mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'glow-tts-amp-time_depth_conv' into dev
This commit is contained in:
commit a6df617eb1

@@ -129,3 +129,4 @@ TODO.txt
.vscode/*
data/*
notebooks/data/*
TTS/tts/layers/glow_tts/monotonic_align/core.c
@@ -0,0 +1,649 @@
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
||||
reduce_tensor)
|
||||
from TTS.tts.utils.generic_utils import check_config, setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
|
||||
save_speaker_mapping)
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
|
||||
gradual_training_scheduler, set_weight_decay,
|
||||
setup_torch_training_env)
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||
|
||||
|
||||
def setup_loader(ap, r, is_val=False, verbose=False):
|
||||
if is_val and not c.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
dataset = MyDataset(
|
||||
r,
|
||||
c.text_cleaner,
|
||||
compute_linear_spec=True if c.model.lower() == 'tacotron' else False,
|
||||
meta_data=meta_data_eval if is_val else meta_data_train,
|
||||
ap=ap,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
batch_group_size=0 if is_val else c.batch_group_size *
|
||||
c.batch_size,
|
||||
min_seq_len=c.min_seq_len,
|
||||
max_seq_len=c.max_seq_len,
|
||||
phoneme_cache_path=c.phoneme_cache_path,
|
||||
use_phonemes=c.use_phonemes,
|
||||
phoneme_language=c.phoneme_language,
|
||||
enable_eos_bos=c.enable_eos_bos_chars,
|
||||
verbose=verbose)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=c.num_val_loader_workers
|
||||
if is_val else c.num_loader_workers,
|
||||
pin_memory=False)
|
||||
return loader
|
||||
|
||||
|
||||
def format_data(data):
|
||||
if c.use_speaker_embedding:
|
||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||
|
||||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_names = data[2]
|
||||
mel_input = data[4].permute(0, 2, 1) # B x D x T
|
||||
mel_lengths = data[5]
|
||||
attn_mask = data[8]
|
||||
avg_text_length = torch.mean(text_lengths.float())
|
||||
avg_spec_length = torch.mean(mel_lengths.float())
|
||||
|
||||
if c.use_speaker_embedding:
|
||||
speaker_ids = [
|
||||
speaker_mapping[speaker_name] for speaker_name in speaker_names
|
||||
]
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
else:
|
||||
speaker_ids = None
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda(non_blocking=True)
|
||||
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||
mel_input = mel_input.cuda(non_blocking=True)
|
||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
if speaker_ids is not None:
|
||||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.cuda(non_blocking=True)
|
||||
return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
|
||||
avg_text_length, avg_spec_length, attn_mask
|
||||
|
||||
|
||||
def data_depended_init(model, ap):
|
||||
"""Data depended initialization for activation normalization."""
|
||||
if hasattr(model, 'module'):
|
||||
for f in model.module.decoder.flows:
|
||||
if getattr(f, "set_ddi", False):
|
||||
f.set_ddi(True)
|
||||
else:
|
||||
for f in model.decoder.flows:
|
||||
if getattr(f, "set_ddi", False):
|
||||
f.set_ddi(True)
|
||||
|
||||
data_loader = setup_loader(ap, 1, is_val=False)
|
||||
model.train()
|
||||
print(" > Data depended initialization ... ")
|
||||
with torch.no_grad():
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
|
||||
# format data
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
|
||||
avg_text_length, avg_spec_length, attn_mask = format_data(data)
|
||||
|
||||
# forward pass model
|
||||
_ = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
|
||||
break
|
||||
|
||||
if hasattr(model, 'module'):
|
||||
for f in model.module.decoder.flows:
|
||||
if getattr(f, "set_ddi", False):
|
||||
f.set_ddi(False)
|
||||
else:
|
||||
for f in model.decoder.flows:
|
||||
if getattr(f, "set_ddi", False):
|
||||
f.set_ddi(False)
|
||||
return model
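For reference, a minimal sketch of what this data-dependent pass achieves: each ActNorm layer (defined later in this diff) estimates per-channel statistics from the first batch and inverts them, so that its initial output is roughly zero-mean and unit-variance. Assuming the ActNorm.initialize logic shown further below:

    import torch

    # x: (B, C, T) activations from the first batch, x_mask: (B, 1, T) padding mask
    def ddi_stats(x, x_mask):
        denom = torch.sum(x_mask, [0, 2])
        m = torch.sum(x * x_mask, [0, 2]) / denom               # per-channel mean
        v = torch.sum(x * x * x_mask, [0, 2]) / denom - m ** 2  # per-channel variance
        logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
        # the forward transform is z = bias + exp(logs) * x, so invert the stats
        return -m * torch.exp(-logs), -logs                      # bias_init, logs_init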
|
||||
|
||||
|
||||
def train(model, criterion, optimizer, scheduler,
|
||||
ap, global_step, epoch, amp):
|
||||
data_loader = setup_loader(ap, 1, is_val=False,
|
||||
verbose=(epoch == 0))
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
if use_cuda:
|
||||
batch_n_iter = int(
|
||||
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
c_logger.print_train_start()
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
|
||||
avg_text_length, avg_spec_length, attn_mask = format_data(data)
|
||||
|
||||
loader_time = time.time() - end_time
|
||||
|
||||
global_step += 1
|
||||
|
||||
# setup lr
|
||||
if c.noam_schedule:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
o_dur_log, o_total_dur, text_lengths)
|
||||
|
||||
# backward pass
|
||||
if amp is not None:
|
||||
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss_dict['loss'].backward()
|
||||
|
||||
if amp:
|
||||
amp_opt_params = amp.master_params(optimizer)
|
||||
else:
|
||||
amp_opt_params = None
|
||||
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
|
||||
optimizer.step()
|
||||
|
||||
# current_lr
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
# compute alignment error (the lower the better)
|
||||
align_error = 1 - alignment_diagonal_score(alignments)
|
||||
loss_dict['align_error'] = align_error
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# aggregate losses from processes
|
||||
if num_gpus > 1:
|
||||
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
|
||||
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
||||
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
|
||||
|
||||
# detach loss values
|
||||
loss_dict_new = dict()
|
||||
for key, value in loss_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict_new[key] = value
|
||||
else:
|
||||
loss_dict_new[key] = value.item()
|
||||
loss_dict = loss_dict_new
|
||||
|
||||
# update avg stats
|
||||
update_train_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_train_values['avg_' + key] = value
|
||||
update_train_values['avg_loader_time'] = loader_time
|
||||
update_train_values['avg_step_time'] = step_time
|
||||
keep_avg.update_values(update_train_values)
|
||||
|
||||
# print training progress
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {
|
||||
"avg_spec_length": [avg_spec_length, 1], # value, precision
|
||||
"avg_text_length": [avg_text_length, 1],
|
||||
"step_time": [step_time, 4],
|
||||
"loader_time": [loader_time, 2],
|
||||
"current_lr": current_lr,
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||
log_dict, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
# Plot Training Iter Stats
|
||||
# reduce TB load
|
||||
if global_step % c.tb_plot_step == 0:
|
||||
iter_stats = {
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"step_time": step_time
|
||||
}
|
||||
iter_stats.update(loss_dict)
|
||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||
|
||||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||
model_loss=loss_dict['loss'],
|
||||
amp_state_dict=amp.state_dict() if amp else None)
|
||||
|
||||
# Diagnostic visualizations
|
||||
# direct pass on model for spec predictions
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
|
||||
spec_pred = spec_pred.permute(0, 2, 1)
|
||||
gt_spec = mel_input.permute(0, 2, 1)
|
||||
const_spec = spec_pred[0].data.cpu().numpy()
|
||||
gt_spec = gt_spec[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(const_spec, ap),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap),
|
||||
"alignment": plot_alignment(align_img),
|
||||
}
|
||||
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(const_spec.T)
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'TrainAudio': train_audio},
|
||||
c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
# print epoch stats
|
||||
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
||||
|
||||
# Plot Epoch Stats
|
||||
if args.rank == 0:
|
||||
epoch_stats = {"epoch_time": epoch_time}
|
||||
epoch_stats.update(keep_avg.avg_values)
|
||||
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||
if c.tb_model_param_stats:
|
||||
tb_logger.tb_model_weights(model, global_step)
|
||||
return keep_avg.avg_values, global_step
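check_update is imported from TTS.utils.training and is not part of this diff; a hedged sketch of the gradient-clipping pattern it presumably wraps, where amp_opt_params routes clipping to apex's fp32 master parameters instead of the model's fp16 weights (the real helper presumably also filters out stopnet parameters when ignore_stopnet is set):

    import torch

    def clip_and_check(model, grad_clip, amp_opt_params=None):
        # clip either the model's own parameters or apex's fp32 master copies
        params = amp_opt_params if amp_opt_params is not None else model.parameters()
        grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)  # returns the pre-clip norm
        skip_step = not torch.isfinite(torch.as_tensor(float(grad_norm)))  # flag inf/nan gradients
        return grad_norm, skip_step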
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, criterion, ap, global_step, epoch):
|
||||
data_loader = setup_loader(ap, 1, is_val=True)
|
||||
model.eval()
|
||||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
c_logger.print_eval_start()
|
||||
if data_loader is not None:
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
|
||||
avg_text_length, avg_spec_length, attn_mask = format_data(data)
|
||||
|
||||
# forward pass model
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
o_dur_log, o_total_dur, text_lengths)
|
||||
|
||||
# step time
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# compute alignment score
|
||||
align_error = 1 - alignment_diagonal_score(alignments)
|
||||
loss_dict['align_error'] = align_error
|
||||
|
||||
# aggregate losses from processes
|
||||
if num_gpus > 1:
|
||||
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
|
||||
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
|
||||
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
|
||||
|
||||
# detach loss values
|
||||
loss_dict_new = dict()
|
||||
for key, value in loss_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict_new[key] = value
|
||||
else:
|
||||
loss_dict_new[key] = value.item()
|
||||
loss_dict = loss_dict_new
|
||||
|
||||
# update avg stats
|
||||
update_train_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_train_values['avg_' + key] = value
|
||||
keep_avg.update_values(update_train_values)
|
||||
|
||||
if c.print_eval:
|
||||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
# Diagnostic visualizations
|
||||
# direct pass on model for spec predictions
|
||||
if hasattr(model, 'module'):
|
||||
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
|
||||
else:
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
|
||||
spec_pred = spec_pred.permute(0, 2, 1)
|
||||
gt_spec = mel_input.permute(0, 2, 1)
|
||||
|
||||
const_spec = spec_pred[0].data.cpu().numpy()
|
||||
gt_spec = gt_spec[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
eval_figures = {
|
||||
"prediction": plot_spectrogram(const_spec, ap),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap),
|
||||
"alignment": plot_alignment(align_img)
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
eval_audio = ap.inv_melspectrogram(const_spec.T)
|
||||
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
||||
c.audio["sample_rate"])
|
||||
|
||||
# Plot Validation Stats
|
||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||
tb_logger.tb_eval_figures(global_step, eval_figures)
|
||||
|
||||
if args.rank == 0 and epoch >= c.test_delay_epochs:
|
||||
if c.test_sentences_file is None:
|
||||
test_sentences = [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963."
|
||||
]
|
||||
else:
|
||||
with open(c.test_sentences_file, "r") as f:
|
||||
test_sentences = [s.strip() for s in f.readlines()]
|
||||
|
||||
# test sentences
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
print(" | > Synthesizing test sentences")
|
||||
speaker_id = 0 if c.use_speaker_embedding else None
|
||||
style_wav = c.get("style_wav_for_test")
|
||||
for idx, test_sentence in enumerate(test_sentences):
|
||||
try:
|
||||
wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
|
||||
model,
|
||||
test_sentence,
|
||||
c,
|
||||
use_cuda,
|
||||
ap,
|
||||
speaker_id=speaker_id,
|
||||
style_wav=style_wav,
|
||||
truncated=False,
|
||||
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False)
|
||||
|
||||
file_path = os.path.join(AUDIO_PATH, str(global_step))
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
file_path = os.path.join(file_path,
|
||||
"TestSentence_{}.wav".format(idx))
|
||||
ap.save_wav(wav, file_path)
|
||||
test_audios['{}-audio'.format(idx)] = wav
|
||||
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
||||
postnet_output, ap)
|
||||
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
||||
alignment)
|
||||
except:
|
||||
print(" !! Error creating Test Sentence -", idx)
|
||||
traceback.print_exc()
|
||||
tb_logger.tb_test_audios(global_step, test_audios,
|
||||
c.audio['sample_rate'])
|
||||
tb_logger.tb_test_figures(global_step, test_figures)
|
||||
return keep_avg.avg_values
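GlowTTSLoss is imported from TTS.tts.layers.losses and is not shown in this diff; a hedged sketch of its flow MLE term, i.e. the negative log-likelihood of z under the encoder-predicted Gaussian prior minus the flow log-determinant, normalized by the number of spectrogram values (the duration term is presumably an additional squared error between the predicted and attention-derived log-durations, normalized by text length):

    import math
    import torch

    def mle_loss_sketch(z, y_mean, y_log_scale, logdet, mel_lengths, n_mel=80):
        # Gaussian NLL of the latent z minus log|det J| of the decoder flow
        return 0.5 * math.log(2 * math.pi) + (
            torch.sum(y_log_scale)
            + 0.5 * torch.sum(torch.exp(-2 * y_log_scale) * (z - y_mean) ** 2)
            - torch.sum(logdet)) / (torch.sum(mel_lengths) * n_mel)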
|
||||
|
||||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
if 'characters' in c.keys():
|
||||
symbols, phonemes = make_symbols(**c.characters)
|
||||
|
||||
# DISTRIBUTED
|
||||
if num_gpus > 1:
|
||||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
|
||||
# load data instances
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||
|
||||
# set the portion of the data used for training
|
||||
if 'train_portion' in c.keys():
|
||||
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
|
||||
if 'eval_portion' in c.keys():
|
||||
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
|
||||
|
||||
# parse speakers
|
||||
if c.use_speaker_embedding:
|
||||
speakers = get_speakers(meta_data_train)
|
||||
if args.restore_path:
|
||||
prev_out_path = os.path.dirname(args.restore_path)
|
||||
speaker_mapping = load_speaker_mapping(prev_out_path)
|
||||
assert all([speaker in speaker_mapping
|
||||
for speaker in speakers]), "As of now, you cannot " \
|
||||
"introduce new speakers to " \
|
||||
"a previously trained model."
|
||||
else:
|
||||
speaker_mapping = {name: i for i, name in enumerate(speakers)}
|
||||
save_speaker_mapping(OUT_PATH, speaker_mapping)
|
||||
num_speakers = len(speaker_mapping)
|
||||
print("Training with {} speakers: {}".format(num_speakers,
|
||||
", ".join(speakers)))
|
||||
else:
|
||||
num_speakers = 0
|
||||
|
||||
# setup model
|
||||
model = setup_model(num_chars, num_speakers, c)
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
||||
criterion = GlowTTSLoss()
|
||||
|
||||
if c.apex_amp_level:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from apex import amp
|
||||
from apex.parallel import DistributedDataParallel as DDP
|
||||
model.cuda()
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
||||
else:
|
||||
amp = None
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||
try:
|
||||
# TODO: fix optimizer init, model.cuda() needs to be called before
|
||||
# optimizer restore
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
if c.reinit_layers:
|
||||
raise RuntimeError
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
except:
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
if amp and 'amp' in checkpoint:
|
||||
amp.load_state_dict(checkpoint['amp'])
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group['initial_lr'] = c.lr
|
||||
print(" > Model restored from step %d" % checkpoint['step'],
|
||||
flush=True)
|
||||
args.restore_step = checkpoint['step']
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
criterion.cuda()
|
||||
|
||||
# DISTRIBUTED
|
||||
if num_gpus > 1:
|
||||
model = DDP(model)
|
||||
|
||||
if c.noam_schedule:
|
||||
scheduler = NoamLR(optimizer,
|
||||
warmup_steps=c.warmup_steps,
|
||||
last_epoch=args.restore_step - 1)
|
||||
else:
|
||||
scheduler = None
|
||||
|
||||
num_params = count_parameters(model)
|
||||
print("\n > Model has {} parameters".format(num_params), flush=True)
|
||||
|
||||
if 'best_loss' not in locals():
|
||||
best_loss = float('inf')
|
||||
|
||||
global_step = args.restore_step
|
||||
model = data_depended_init(model, ap)
|
||||
for epoch in range(0, c.epochs):
|
||||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
||||
scheduler, ap, global_step,
|
||||
epoch, amp)
|
||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = train_avg_loss_dict['avg_loss']
|
||||
if c.run_eval:
|
||||
target_loss = eval_avg_loss_dict['avg_loss']
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to continue a previous training run. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv
|
||||
)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRIBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # pick up every checkpoint file in the folder
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.apex_amp_level:
|
||||
print(" > apex AMP level: ", c.apex_amp_level)
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_config_file(args.config_path,
|
||||
os.path.join(OUT_PATH, 'config.json'), new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
|
||||
try:
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
except Exception: # pylint: disable=broad-except
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
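For reference, a typical invocation of this script (assuming it is added as TTS/bin/train_glow_tts.py; the target path is not visible in this view) is python TTS/bin/train_glow_tts.py --config_path config.json, or --continue_path <experiment_folder> to resume from the latest checkpoint of an existing run.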
|
|
@@ -0,0 +1,132 @@
{
|
||||
"model": "glow_tts",
|
||||
"run_name": "glow-tts-gatedconv",
|
||||
"run_description": "glow-tts model training with gated conv.",
|
||||
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"win_length": 1024, // stft window length in ms.
|
||||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
|
||||
// Audio processing parameters
|
||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"ref_level_db": 0, // reference level db, theoretically 20db is the sound of air.
|
||||
|
||||
// Griffin-Lim
|
||||
"power": 1.1, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
|
||||
// Silence trimming
|
||||
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
|
||||
// MelSpectrogram parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"spec_gain": 1.0, // scaler value appplied after log transform of spectrogram.
|
||||
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
||||
"min_level_db": -100, // lower bound for normalization
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
|
||||
// VOCABULARY PARAMETERS
|
||||
// if custom character set is not defined,
|
||||
// default set in symbols.py is used
|
||||
// "characters":{
|
||||
// "pad": "_",
|
||||
// "eos": "~",
|
||||
// "bos": "^",
|
||||
// "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
|
||||
// "punctuations":"!'(),-.:;? ",
|
||||
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
|
||||
// },
|
||||
|
||||
// DISTRIBUTED TRAINING
|
||||
"distributed":{
|
||||
"backend": "nccl",
|
||||
"url": "tcp:\/\/localhost:54321"
|
||||
},
|
||||
|
||||
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||
|
||||
// MODEL PARAMETERS
|
||||
"use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":16,
|
||||
"r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
|
||||
// VALIDATION
|
||||
"run_eval": true,
|
||||
"test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time.
|
||||
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
|
||||
|
||||
// OPTIMIZER
|
||||
"noam_schedule": true, // use noam warmup and lr schedule.
|
||||
"grad_clip": 5.0, // upper limit for gradients for clipping.
|
||||
"epochs": 10000, // total number of epochs to train.
|
||||
"lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths.
|
||||
|
||||
"encoder_type": "gatedconv",
|
||||
|
||||
// TENSORBOARD and LOGGING
|
||||
"print_step": 25, // Number of steps to log training on console.
|
||||
"tb_plot_step": 100, // Number of steps to plot TB training figures.
|
||||
"print_eval": false, // If True, it prints intermediate loss values in evalulation.
|
||||
"save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"apex_amp_level": null,
|
||||
|
||||
// DATA LOADING
|
||||
"text_cleaner": "phoneme_cleaners",
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
|
||||
"min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
|
||||
"max_seq_len": 500, // DATASET-RELATED: maximum text length
|
||||
"compute_f0": false, // compute f0 values in data-loader
|
||||
|
||||
// PATHS
|
||||
"output_path": "/home/erogol/Models/LJSpeech/",
|
||||
|
||||
// PHONEMES
|
||||
"phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder.
|
||||
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
|
||||
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
|
||||
|
||||
// MULTI-SPEAKER and GST
|
||||
"use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
|
||||
"style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference.
|
||||
"use_gst": false, // TACOTRON ONLY: use global style tokens
|
||||
|
||||
// DATASETS
|
||||
"datasets": // List of datasets. They all merged and they get different speaker_ids.
|
||||
[
|
||||
{
|
||||
"name": "ljspeech",
|
||||
"path": "/home/erogol/Data/LJSpeech-1.1/",
|
||||
"meta_file_train": "metadata.csv",
|
||||
"meta_file_val": null
|
||||
// "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/"
|
||||
}
|
||||
]
|
||||
}
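This config enables noam_schedule with warmup_steps set to 4000. NoamLR (imported in the training script from TTS.utils.training) is expected to follow the standard Transformer warmup rule; a hedged sketch of the scale it presumably applies to the base lr:

    def noam_lr_scale(step, warmup_steps=4000):
        # linear warmup to the base lr, then decay proportional to step ** -0.5
        step = max(step, 1)
        return warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)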
|
||||
|
||||
|
|
@@ -0,0 +1,132 @@
{
|
||||
"model": "glow_tts",
|
||||
"run_name": "glow-tts-tdsep-conv",
|
||||
"run_description": "glow-tts model training with time-depth separable conv encoder.",
|
||||
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"win_length": 1024, // stft window length in ms.
|
||||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
|
||||
// Audio processing parameters
|
||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"ref_level_db": 0, // reference level db, theoretically 20db is the sound of air.
|
||||
|
||||
// Griffin-Lim
|
||||
"power": 1.1, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
|
||||
// Silence trimming
|
||||
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
|
||||
// MelSpectrogram parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"spec_gain": 1.0, // scaler value appplied after log transform of spectrogram.
|
||||
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
||||
"min_level_db": -100, // lower bound for normalization
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
|
||||
// VOCABULARY PARAMETERS
|
||||
// if custom character set is not defined,
|
||||
// default set in symbols.py is used
|
||||
// "characters":{
|
||||
// "pad": "_",
|
||||
// "eos": "~",
|
||||
// "bos": "^",
|
||||
// "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
|
||||
// "punctuations":"!'(),-.:;? ",
|
||||
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
|
||||
// },
|
||||
|
||||
// DISTRIBUTED TRAINING
|
||||
"distributed":{
|
||||
"backend": "nccl",
|
||||
"url": "tcp:\/\/localhost:54321"
|
||||
},
|
||||
|
||||
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||
|
||||
// MODEL PARAMETERS
|
||||
"use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"eval_batch_size":16,
|
||||
"r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
|
||||
// VALIDATION
|
||||
"run_eval": true,
|
||||
"test_delay_epochs": 0, //Until attention is aligned, testing only wastes computation time.
|
||||
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
|
||||
|
||||
// OPTIMIZER
|
||||
"noam_schedule": true, // use noam warmup and lr schedule.
|
||||
"grad_clip": 5.0, // upper limit for gradients for clipping.
|
||||
"epochs": 10000, // total number of epochs to train.
|
||||
"lr": 1e-3, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths.
|
||||
|
||||
"encoder_type": "time-depth-separable",
|
||||
|
||||
// TENSORBOARD and LOGGING
|
||||
"print_step": 25, // Number of steps to log training on console.
|
||||
"tb_plot_step": 100, // Number of steps to plot TB training figures.
|
||||
"print_eval": false, // If True, it prints intermediate loss values in evalulation.
|
||||
"save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"apex_amp_level": null,
|
||||
|
||||
// DATA LOADING
|
||||
"text_cleaner": "phoneme_cleaners",
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||
"batch_group_size": 0, //Number of batches to shuffle after bucketing.
|
||||
"min_seq_len": 3, // DATASET-RELATED: minimum text length to use in training
|
||||
"max_seq_len": 500, // DATASET-RELATED: maximum text length
|
||||
"compute_f0": false, // compute f0 values in data-loader
|
||||
|
||||
// PATHS
|
||||
"output_path": "/home/erogol/Models/LJSpeech/",
|
||||
|
||||
// PHONEMES
|
||||
"phoneme_cache_path": "/home/erogol/Models/phoneme_cache/", // phoneme computation is slow, therefore, it caches results in the given folder.
|
||||
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
|
||||
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
|
||||
|
||||
// MULTI-SPEAKER and GST
|
||||
"use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
|
||||
"style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference.
|
||||
"use_gst": false, // TACOTRON ONLY: use global style tokens
|
||||
|
||||
// DATASETS
|
||||
"datasets": // List of datasets. They all merged and they get different speaker_ids.
|
||||
[
|
||||
{
|
||||
"name": "ljspeech",
|
||||
"path": "/home/erogol/Data/LJSpeech-1.1/",
|
||||
"meta_file_train": "metadata.csv",
|
||||
"meta_file_val": null
|
||||
// "path_for_attn": "/home/erogol/Data/LJSpeech-1.1/alignments/"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
|
@@ -113,7 +113,14 @@ class MyDataset(Dataset):
return phonemes
|
||||
|
||||
def load_data(self, idx):
|
||||
text, wav_file, speaker_name = self.items[idx]
|
||||
item = self.items[idx]
|
||||
|
||||
if len(item) == 4:
|
||||
text, wav_file, speaker_name, attn_file = item
|
||||
else:
|
||||
text, wav_file, speaker_name = item
|
||||
attn = None
|
||||
|
||||
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
||||
|
||||
if self.use_phonemes:
|
||||
|
@@ -125,9 +132,13 @@ class MyDataset(Dataset):
assert text.size > 0, self.items[idx][1]
|
||||
assert wav.size > 0, self.items[idx][1]
|
||||
|
||||
if "attn_file" in locals():
|
||||
attn = np.load(attn_file)
|
||||
|
||||
sample = {
|
||||
'text': text,
|
||||
'wav': wav,
|
||||
'attn': attn,
|
||||
'item_idx': self.items[idx][1],
|
||||
'speaker_name': speaker_name,
|
||||
'wav_file_name': os.path.basename(wav_file)
|
||||
|
@@ -245,8 +256,21 @@ class MyDataset(Dataset):
linear = torch.FloatTensor(linear).contiguous()
|
||||
else:
|
||||
linear = None
|
||||
|
||||
# collate attention alignments
|
||||
if batch[0]['attn'] is not None:
|
||||
attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
|
||||
for idx, attn in enumerate(attns):
|
||||
pad2 = mel.shape[1] - attn.shape[1]
|
||||
pad1 = text.shape[1] - attn.shape[0]
|
||||
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
||||
attns[idx] = attn
|
||||
attns = prepare_tensor(attns, self.outputs_per_step)
|
||||
attns = torch.FloatTensor(attns).unsqueeze(1)
|
||||
else:
|
||||
attns = None
|
||||
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
|
||||
stop_targets, item_idxs, speaker_embedding
|
||||
stop_targets, item_idxs, speaker_embedding, attns
|
||||
|
||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||
found {}".format(type(batch[0]))))
|
||||
|
|
|
@@ -0,0 +1,111 @@
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock
|
||||
from TTS.tts.layers.glow_tts.normalization import ActNorm
|
||||
|
||||
|
||||
def squeeze(x, x_mask=None, num_sqz=2):
|
||||
b, c, t = x.size()
|
||||
|
||||
t = (t // num_sqz) * num_sqz
|
||||
x = x[:, :, :t]
|
||||
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
|
||||
x_sqz = x_sqz.permute(0, 3, 1,
|
||||
2).contiguous().view(b, c * num_sqz, t // num_sqz)
|
||||
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device,
|
||||
dtype=x.dtype)
|
||||
return x_sqz * x_mask, x_mask
|
||||
|
||||
|
||||
def unsqueeze(x, x_mask=None, num_sqz=2):
|
||||
b, c, t = x.size()
|
||||
|
||||
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
|
||||
x_unsqz = x_unsqz.permute(0, 2, 3,
|
||||
1).contiguous().view(b, c // num_sqz,
|
||||
t * num_sqz)
|
||||
|
||||
if x_mask is not None:
|
||||
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1,
|
||||
num_sqz).view(b, 1, t * num_sqz)
|
||||
else:
|
||||
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device,
|
||||
dtype=x.dtype)
|
||||
return x_unsqz * x_mask, x_mask
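A quick shape sketch of the squeeze/unsqueeze pair above (assuming num_sqz=2): time is halved while channels are doubled, and an odd trailing frame is dropped.

    import torch

    x = torch.randn(2, 80, 101)                   # (B, C, T)
    x_sqz, mask = squeeze(x, num_sqz=2)           # -> (2, 160, 50)
    x_rec, _ = unsqueeze(x_sqz, mask, num_sqz=2)  # -> (2, 80, 100), equal to x[:, :, :100]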
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Stack of Glow Modules"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
num_flow_blocks,
|
||||
num_coupling_layers,
|
||||
dropout_p=0.,
|
||||
num_splits=4,
|
||||
num_sqz=2,
|
||||
sigmoid_scale=False,
|
||||
c_in_channels=0,
|
||||
feat_channels=None):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_flow_blocks = num_flow_blocks
|
||||
self.num_coupling_layers = num_coupling_layers
|
||||
self.dropout_p = dropout_p
|
||||
self.num_splits = num_splits
|
||||
self.num_sqz = num_sqz
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
self.c_in_channels = c_in_channels
|
||||
|
||||
self.flows = nn.ModuleList()
|
||||
for _ in range(num_flow_blocks):
|
||||
self.flows.append(ActNorm(channels=in_channels * num_sqz))
|
||||
self.flows.append(
|
||||
InvConvNear(channels=in_channels * num_sqz,
|
||||
num_splits=num_splits))
|
||||
self.flows.append(
|
||||
CouplingBlock(in_channels * num_sqz,
|
||||
hidden_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilation_rate=dilation_rate,
|
||||
num_layers=num_coupling_layers,
|
||||
c_in_channels=c_in_channels,
|
||||
dropout_p=dropout_p,
|
||||
sigmoid_scale=sigmoid_scale))
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
logdet_tot = 0
|
||||
else:
|
||||
flows = reversed(self.flows)
|
||||
logdet_tot = None
|
||||
|
||||
if self.num_sqz > 1:
|
||||
x, x_mask = squeeze(x, x_mask, self.num_sqz)
|
||||
for f in flows:
|
||||
if not reverse:
|
||||
x, logdet = f(x, x_mask, g=g, reverse=reverse)
|
||||
logdet_tot += logdet
|
||||
else:
|
||||
x, logdet = f(x, x_mask, g=g, reverse=reverse)
|
||||
if self.num_sqz > 1:
|
||||
x, x_mask = unsqueeze(x, x_mask, self.num_sqz)
|
||||
return x, logdet_tot
|
||||
|
||||
def store_inverse(self):
|
||||
for f in self.flows:
|
||||
f.store_inverse()
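A hedged construction sketch of the decoder stack (hyperparameter values are illustrative and not taken from this diff); each flow block contributes an ActNorm, an invertible near-neighbour convolution and an affine coupling layer:

    import torch

    dec = Decoder(in_channels=80, hidden_channels=192, kernel_size=5,
                  dilation_rate=1, num_flow_blocks=12, num_coupling_layers=4)
    y = torch.randn(2, 80, 100)         # (B, n_mel, T) target mel frames
    y_mask = torch.ones(2, 1, 100)
    z, logdet = dec(y, y_mask)          # training direction: mel -> latent plus log-determinant
    # inference runs the same stack with reverse=True to map prior samples back to mels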
|
|
@@ -0,0 +1,40 @@
import torch
|
||||
from torch import nn
|
||||
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dropout_p = dropout_p
|
||||
# layers
|
||||
self.drop = nn.Dropout(dropout_p)
|
||||
self.conv_1 = nn.Conv1d(in_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.norm_1 = LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.norm_2 = LayerNorm(filter_channels)
|
||||
# output layer
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
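A minimal usage sketch (hypothetical channel sizes): the predictor maps masked, detached encoder features to one log-duration per input token.

    import torch

    dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, dropout_p=0.1)
    x = torch.randn(4, 192, 37)        # (B, C, T) detached encoder outputs
    x_mask = torch.ones(4, 1, 37)      # (B, 1, T) token mask
    logw = dp(x, x_mask)               # (B, 1, T) predicted log-durations, zeroed on padding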
|
|
@@ -0,0 +1,145 @@
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.glow_tts.transformer import Transformer
|
||||
from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding
|
||||
as in the original paper or GatedConvBlock as a faster alternative.
|
||||
|
||||
Args:
|
||||
num_chars (int): number of characters.
|
||||
out_channels (int): number of output channels.
|
||||
hidden_channels (int): encoder's embedding size.
|
||||
filter_channels (int): transformer's feed-forward channels.
|
||||
num_heads (int): number of attention heads in transformer.
|
||||
num_layers (int): number of transformer encoder stack.
|
||||
kernel_size (int): kernel size for conv layers and duration predictor.
|
||||
dropout_p (float): dropout rate for any dropout layer.
|
||||
mean_only (bool): if True, output only mean values and use constant std.
|
||||
use_prenet (bool): if True, use pre-convolutional layers before transformer layers.
|
||||
c_in_channels (int): number of channels in conditional input.
|
||||
|
||||
Shapes:
|
||||
- input: (B, T)
|
||||
"""
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
filter_channels_dp,
|
||||
encoder_type,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size,
|
||||
dropout_p,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
use_prenet=True,
|
||||
c_in_channels=0):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
self.num_chars = num_chars
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.filter_channels_dp = filter_channels_dp
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.dropout_p = dropout_p
|
||||
self.mean_only = mean_only
|
||||
self.use_prenet = use_prenet
|
||||
self.c_in_channels = c_in_channels
|
||||
self.encoder_type = encoder_type
|
||||
# embedding layer
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
# init encoder
|
||||
if encoder_type.lower() == "transformer":
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
# text encoder
|
||||
self.encoder = Transformer(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
input_length=input_length)
|
||||
elif encoder_type.lower() == 'gatedconv':
|
||||
self.encoder = GatedConvBlock(hidden_channels,
|
||||
kernel_size=5,
|
||||
dropout_p=dropout_p,
|
||||
num_layers=3 + num_layers)
|
||||
elif encoder_type.lower() == 'time-depth-separable':
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3 + num_layers)
|
||||
|
||||
# final projection layers
|
||||
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
if not mean_only:
|
||||
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
# duration predictor
|
||||
self.duration_predictor = DurationPredictor(
|
||||
hidden_channels + c_in_channels, filter_channels_dp, kernel_size,
|
||||
dropout_p)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
# embedding layer
|
||||
# [B ,T, D]
|
||||
x = self.emb(x) * math.sqrt(self.hidden_channels)
|
||||
# [B, D, T]
|
||||
x = torch.transpose(x, 1, -1)
|
||||
# compute input sequence mask
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
||||
1).to(x.dtype)
|
||||
# pre-conv layers
|
||||
if self.encoder_type in ['transformer', 'time-depth-separable']:
|
||||
if self.use_prenet:
|
||||
x = self.pre(x, x_mask)
|
||||
# encoder
|
||||
x = self.encoder(x, x_mask)
|
||||
# set duration predictor input
|
||||
if g is not None:
|
||||
g_exp = g.expand(-1, -1, x.size(-1))
|
||||
x_dp = torch.cat([torch.detach(x), g_exp], 1)
|
||||
else:
|
||||
x_dp = torch.detach(x)
|
||||
# final projection layer
|
||||
x_m = self.proj_m(x) * x_mask
|
||||
if not self.mean_only:
|
||||
x_logs = self.proj_s(x) * x_mask
|
||||
else:
|
||||
x_logs = torch.zeros_like(x_m)
|
||||
# duration predictor
|
||||
logw = self.duration_predictor(x_dp, x_mask)
|
||||
return x_m, x_logs, logw, x_mask
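A hedged construction sketch (argument values are illustrative only) showing how encoder_type switches between the transformer, gated-conv and time-depth-separable stacks defined above:

    import torch

    enc = Encoder(num_chars=130, out_channels=80, hidden_channels=192,
                  filter_channels=768, filter_channels_dp=256,
                  encoder_type="time-depth-separable",   # or "transformer" / "gatedconv"
                  num_heads=2, num_layers=6, kernel_size=3,
                  dropout_p=0.1, mean_only=True)
    x = torch.randint(0, 130, (4, 37))                   # (B, T) character ids
    x_lengths = torch.tensor([37, 35, 30, 22])
    x_m, x_logs, logw, x_mask = enc(x, x_lengths)        # prior mean/log-scale, log-durations, mask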
|
|
@@ -0,0 +1,44 @@
import torch
|
||||
from torch import nn
|
||||
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class GatedConvBlock(nn.Module):
|
||||
"""Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf
|
||||
Args:
|
||||
in_out_channels (int): number of input/output channels.
|
||||
kernel_size (int): convolution kernel size.
|
||||
dropout_p (float): dropout rate.
num_layers (int): number of gated convolutional layers.
|
||||
"""
|
||||
def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
self.dropout_p = dropout_p
|
||||
self.num_layers = num_layers
|
||||
# define layers
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.conv_layers += [
|
||||
nn.Conv1d(in_out_channels,
|
||||
2 * in_out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
]
|
||||
self.norm_layers += [LayerNorm(2 * in_out_channels)]
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
o = x
|
||||
res = x
|
||||
for idx in range(self.num_layers):
|
||||
o = nn.functional.dropout(o,
|
||||
p=self.dropout_p,
|
||||
training=self.training)
|
||||
o = self.conv_layers[idx](o * x_mask)
|
||||
o = self.norm_layers[idx](o)
|
||||
o = nn.functional.glu(o, dim=1)
|
||||
o = res + o
|
||||
res = o
|
||||
return o
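nn.functional.glu above halves the doubled channel dimension and gates one half with the other; an equivalent unfused form for reference:

    import torch

    def glu_equivalent(o):
        # what nn.functional.glu(o, dim=1) computes for o of shape (B, 2*C, T)
        a, b = o.chunk(2, dim=1)
        return a * torch.sigmoid(b)     # (B, C, T)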
|
|
@@ -0,0 +1,334 @@
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class ConvLayerNorm(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
|
||||
num_layers, dropout_p):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout_p = dropout_p
|
||||
assert num_layers > 1, " [!] number of layers should be > 1."
|
||||
assert kernel_size % 2 == 1, " [!] kernel size should be odd number."
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
|
||||
self.conv_layers.append(
|
||||
nn.Conv1d(in_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
|
||||
for _ in range(num_layers - 1):
|
||||
self.conv_layers.append(
|
||||
nn.Conv1d(hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x_res = x
|
||||
for i in range(self.num_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x * x_mask)
|
||||
x = F.dropout(F.relu(x), self.dropout_p, training=self.training)
|
||||
x = x_res + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
num_layers,
|
||||
c_in_channels=0,
|
||||
dropout_p=0):
|
||||
super(WN, self).__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
assert hidden_channels % 2 == 0
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_layers = num_layers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.dropout_p = dropout_p
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
if c_in_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(c_in_channels,
|
||||
2 * hidden_channels * num_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
|
||||
name='weight')
|
||||
|
||||
for i in range(num_layers):
|
||||
dilation = dilation_rate**i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(hidden_channels,
|
||||
2 * hidden_channels,
|
||||
kernel_size,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
if i < num_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels,
|
||||
res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
|
||||
name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
x_in = self.dropout(x_in)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:,
|
||||
cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
|
||||
n_channels_tensor)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.num_layers - 1:
|
||||
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
|
||||
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.c_in_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
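# Example (sketch): exercising the WN block above on dummy tensors. All sizes
# here are assumptions; hidden_channels must be even and kernel_size odd.
wn = WN(in_channels=192, hidden_channels=192, kernel_size=5,
        dilation_rate=1, num_layers=4)
x = torch.randn(2, 192, 40)        # [B, hidden_channels, T]
x_mask = torch.ones(2, 1, 40)      # all frames valid
out = wn(x, x_mask)                # [B, hidden_channels, T] sum of skip connections
wn.remove_weight_norm()            # e.g. before exporting for inference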
|
||||
|
||||
|
||||
class ActNorm(nn.Module):
|
||||
"""Activation Normalization bijector as an alternative to Batch Norm. It computes
|
||||
mean and std from sample data in advance and uses these values
|
||||
for normalization at training.
|
||||
|
||||
Args:
|
||||
channels (int): input channels.
|
||||
ddi (bool): data-dependent initialization flag (default: False).
|
||||
|
||||
Shapes:
|
||||
- inputs: (B, C, T)
|
||||
- outputs: (B, C, T)
|
||||
"""
|
||||
|
||||
def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.initialized = not ddi
|
||||
|
||||
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
|
||||
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
if x_mask is None:
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
if not self.initialized:
|
||||
self.initialize(x, x_mask)
|
||||
self.initialized = True
|
||||
|
||||
if reverse:
|
||||
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
|
||||
logdet = None
|
||||
else:
|
||||
z = (self.bias + torch.exp(self.logs) * x) * x_mask
|
||||
logdet = torch.sum(self.logs) * x_len # [b]
|
||||
|
||||
return z, logdet
|
||||
|
||||
def store_inverse(self):
|
||||
pass
|
||||
|
||||
def set_ddi(self, ddi):
|
||||
self.initialized = not ddi
|
||||
|
||||
def initialize(self, x, x_mask):
|
||||
with torch.no_grad():
|
||||
denom = torch.sum(x_mask, [0, 2])
|
||||
m = torch.sum(x * x_mask, [0, 2]) / denom
|
||||
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
|
||||
v = m_sq - (m**2)
|
||||
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
|
||||
|
||||
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
|
||||
dtype=self.bias.dtype)
|
||||
logs_init = (-logs).view(*self.logs.shape).to(
|
||||
dtype=self.logs.dtype)
|
||||
|
||||
self.bias.data.copy_(bias_init)
|
||||
self.logs.data.copy_(logs_init)
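# Example (sketch): with ddi=True the first forward pass estimates bias/logs
# from the batch statistics; the reverse pass then inverts the transform.
# Shapes are illustrative assumptions.
actnorm = ActNorm(channels=80, ddi=True)
x = torch.randn(4, 80, 60)
x_mask = torch.ones(4, 1, 60)
z, logdet = actnorm(x, x_mask)                 # initializes on this first call
x_rec, _ = actnorm(z, x_mask, reverse=True)    # logdet is None in reverse mode
print(torch.allclose(x, x_rec, atol=1e-5))     # True: the bijector round-trips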
|
||||
|
||||
|
||||
class InvConvNear(nn.Module):
|
||||
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
|
||||
super().__init__()
|
||||
assert num_splits % 2 == 0
|
||||
self.channels = channels
|
||||
self.num_splits = num_splits
|
||||
self.no_jacobian = no_jacobian
|
||||
self.weight_inv = None
|
||||
|
||||
w_init = torch.qr(
|
||||
torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||
if torch.det(w_init) < 0:
|
||||
w_init[:, 0] = -1 * w_init[:, 0]
|
||||
self.weight = nn.Parameter(w_init)
|
||||
|
||||
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
"""Split the input into groups of size self.num_splits and
|
||||
perform 1x1 convolution separately. Cast 1x1 conv operation
|
||||
to 2d by reshaping the input for efficiency.
|
||||
"""
|
||||
|
||||
b, c, t = x.size()
|
||||
assert c % self.num_splits == 0
|
||||
if x_mask is None:
|
||||
x_mask = 1
|
||||
x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t
|
||||
else:
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
|
||||
x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
|
||||
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits,
|
||||
c // self.num_splits, t)
|
||||
|
||||
if reverse:
|
||||
if self.weight_inv is not None:
|
||||
weight = self.weight_inv
|
||||
else:
|
||||
weight = torch.inverse(
|
||||
self.weight.float()).to(dtype=self.weight.dtype)
|
||||
logdet = None
|
||||
else:
|
||||
weight = self.weight
|
||||
if self.no_jacobian:
|
||||
logdet = 0
|
||||
else:
|
||||
logdet = torch.logdet(
|
||||
self.weight) * (c / self.num_splits) * x_len # [b]
|
||||
|
||||
weight = weight.view(self.num_splits, self.num_splits, 1, 1)
|
||||
z = F.conv2d(x, weight)
|
||||
|
||||
z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t)
|
||||
z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
|
||||
return z, logdet
|
||||
|
||||
def store_inverse(self):
|
||||
self.weight_inv = torch.inverse(
|
||||
self.weight.float()).to(dtype=self.weight.dtype)
|
||||
|
||||
|
||||
class CouplingBlock(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
num_layers,
|
||||
c_in_channels=0,
|
||||
dropout_p=0,
|
||||
sigmoid_scale=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_layers = num_layers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.dropout_p = dropout_p
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
|
||||
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
|
||||
start = torch.nn.utils.weight_norm(start)
|
||||
self.start = start
|
||||
# Initializing last layer to 0 makes the affine coupling layers
|
||||
# do nothing at first. This helps with training stability
|
||||
end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
|
||||
end.weight.data.zero_()
|
||||
end.bias.data.zero_()
|
||||
self.end = end
|
||||
|
||||
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
|
||||
num_layers, c_in_channels, dropout_p)
|
||||
|
||||
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
if x_mask is None:
|
||||
x_mask = 1
|
||||
x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
|
||||
|
||||
x = self.start(x_0) * x_mask
|
||||
x = self.wn(x, x_mask, g)
|
||||
out = self.end(x)
|
||||
|
||||
z_0 = x_0
|
||||
m = out[:, :self.in_channels // 2, :]
|
||||
logs = out[:, self.in_channels // 2:, :]
|
||||
if self.sigmoid_scale:
|
||||
logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
|
||||
|
||||
if reverse:
|
||||
z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
|
||||
logdet = None
|
||||
else:
|
||||
z_1 = (m + torch.exp(logs) * x_1) * x_mask
|
||||
logdet = torch.sum(logs * x_mask, [1, 2])
|
||||
|
||||
z = torch.cat([z_0, z_1], 1)
|
||||
return z, logdet
|
||||
|
||||
def store_inverse(self):
|
||||
self.wn.remove_weight_norm()
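# Example (sketch): the affine coupling round trip. Because the last conv is
# zero-initialized, the block starts out as an identity map; sizes below are
# illustrative assumptions (in_channels must be even).
coupling = CouplingBlock(in_channels=80, hidden_channels=192, kernel_size=5,
                         dilation_rate=1, num_layers=4)
x = torch.randn(2, 80, 60)
x_mask = torch.ones(2, 1, 60)
z, logdet = coupling(x, x_mask)                # forward: z_1 = m + exp(logs) * x_1
x_rec, _ = coupling(z, x_mask, reverse=True)   # reverse reuses the same m, logs
print(torch.allclose(x, x_rec, atol=1e-4))     # True: exactly invertible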
|
|
@ -0,0 +1,49 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, t_x]
|
||||
mask: [b, t_x, t_y]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, t_x, t_y = mask.shape
|
||||
cum_duration = torch.cumsum(duration, 1)
|
||||
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
|
||||
]))[:, :-1]
|
||||
path = path * mask
|
||||
return path
|
||||
|
||||
|
||||
def maximum_path(value, mask):
|
||||
""" Cython optimised version.
|
||||
value: [b, t_x, t_y]
|
||||
mask: [b, t_x, t_y]
|
||||
"""
|
||||
value = value * mask
|
||||
device = value.device
|
||||
dtype = value.dtype
|
||||
value = value.data.cpu().numpy().astype(np.float32)
|
||||
path = np.zeros_like(value).astype(np.int32)
|
||||
mask = mask.data.cpu().numpy()
|
||||
|
||||
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
|
||||
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
|
||||
maximum_path_c(path, value, t_x_max, t_y_max)
|
||||
return torch.from_numpy(path).to(device=device, dtype=dtype)
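# Example (sketch): typical use of the two helpers above. maximum_path needs
# the compiled Cython extension; shapes follow the docstrings and the values
# here are random, purely for illustration.
b, t_x, t_y = 2, 5, 9
logp = torch.randn(b, t_x, t_y)                # log-likelihood map
mask = torch.ones(b, t_x, t_y)
attn = maximum_path(logp, mask)                # hard monotonic alignment [b, t_x, t_y]
durations = attn.sum(-1)                       # [b, t_x] frames assigned per token
attn_from_dur = generate_path(durations, mask) # rebuilds the same alignment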
|
|
@ -0,0 +1,45 @@
|
|||
import numpy as np
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
from cython.parallel import prange
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
|
||||
cdef int x
|
||||
cdef int y
|
||||
cdef float v_prev
|
||||
cdef float v_cur
|
||||
cdef float tmp
|
||||
cdef int index = t_x - 1
|
||||
|
||||
for y in range(t_y):
|
||||
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||
if x == y:
|
||||
v_cur = max_neg_val
|
||||
else:
|
||||
v_cur = value[x, y-1]
|
||||
if x == 0:
|
||||
if y == 0:
|
||||
v_prev = 0.
|
||||
else:
|
||||
v_prev = max_neg_val
|
||||
else:
|
||||
v_prev = value[x-1, y-1]
|
||||
value[x, y] = max(v_cur, v_prev) + value[x, y]
|
||||
|
||||
for y in range(t_y - 1, -1, -1):
|
||||
path[index, y] = 1
|
||||
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
|
||||
index = index - 1
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
|
||||
cdef int b = values.shape[0]
|
||||
|
||||
cdef int i
|
||||
for i in prange(b, nogil=True):
|
||||
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
|
|
@ -0,0 +1,9 @@
|
|||
from distutils.core import setup
|
||||
from Cython.Build import cythonize
|
||||
import numpy
|
||||
|
||||
setup(
|
||||
name = 'monotonic_align',
|
||||
ext_modules = cythonize("core.pyx"),
|
||||
include_dirs=[numpy.get_include()]
|
||||
)
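# The extension above is typically compiled in place before training, e.g.
#   python setup.py build_ext --inplace
# so that maximum_path_c can be imported by the alignment helpers.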
|
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-4):
|
||||
"""Layer norm for the 2nd dimension of the input.
|
||||
Args:
|
||||
channels (int): number of channels (2nd dimension) of the input.
|
||||
eps (float): small constant to prevent division by zero.
|
||||
|
||||
Shapes:
|
||||
- input: (B, C, T)
|
||||
- output: (B, C, T)
|
||||
"""
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1)
|
||||
self.beta = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
mean = torch.mean(x, 1, keepdim=True)
|
||||
variance = torch.mean((x - mean)**2, 1, keepdim=True)
|
||||
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
||||
x = x * self.gamma + self.beta
|
||||
return x
|
||||
|
||||
|
||||
class TemporalBatchNorm1d(nn.BatchNorm1d):
|
||||
"""Normalize each channel separately over time and batch.
|
||||
"""
|
||||
def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
|
||||
super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.transpose(2,1)).transpose(2,1)
|
||||
|
||||
|
||||
class ActNorm(nn.Module):
|
||||
"""Activation Normalization bijector as an alternative to Batch Norm. It computes
|
||||
mean and std from sample data in advance and uses these values
|
||||
for normalization at training.
|
||||
|
||||
Args:
|
||||
channels (int): input channels.
|
||||
ddi (bool): data-dependent initialization flag (default: False).
|
||||
|
||||
Shapes:
|
||||
- inputs: (B, C, T)
|
||||
- outputs: (B, C, T)
|
||||
"""
|
||||
|
||||
def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.initialized = not ddi
|
||||
|
||||
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
|
||||
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
if x_mask is None:
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
if not self.initialized:
|
||||
self.initialize(x, x_mask)
|
||||
self.initialized = True
|
||||
|
||||
if reverse:
|
||||
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
|
||||
logdet = None
|
||||
else:
|
||||
z = (self.bias + torch.exp(self.logs) * x) * x_mask
|
||||
logdet = torch.sum(self.logs) * x_len # [b]
|
||||
|
||||
return z, logdet
|
||||
|
||||
def store_inverse(self):
|
||||
pass
|
||||
|
||||
def set_ddi(self, ddi):
|
||||
self.initialized = not ddi
|
||||
|
||||
def initialize(self, x, x_mask):
|
||||
with torch.no_grad():
|
||||
denom = torch.sum(x_mask, [0, 2])
|
||||
m = torch.sum(x * x_mask, [0, 2]) / denom
|
||||
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
|
||||
v = m_sq - (m**2)
|
||||
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
|
||||
|
||||
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
|
||||
dtype=self.bias.dtype)
|
||||
logs_init = (-logs).view(*self.logs.shape).to(
|
||||
dtype=self.logs.dtype)
|
||||
|
||||
self.bias.data.copy_(bias_init)
|
||||
self.logs.data.copy_(logs_init)
|
|
@ -0,0 +1,94 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class TimeDepthSeparableConv(nn.Module):
|
||||
"""Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf
|
||||
It achieves competitive results with less computation and a smaller memory footprint."""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hid_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
bias=True):
|
||||
super(TimeDepthSeparableConv, self).__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hid_channels = hid_channels
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.time_conv = nn.Conv1d(
|
||||
in_channels,
|
||||
2 * hid_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm1 = nn.BatchNorm1d(2 * hid_channels)
|
||||
self.depth_conv = nn.Conv1d(
|
||||
hid_channels,
|
||||
hid_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=hid_channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm2 = nn.BatchNorm1d(hid_channels)
|
||||
self.time_conv2 = nn.Conv1d(
|
||||
hid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm3 = nn.BatchNorm1d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x_res = x
|
||||
x = self.time_conv(x)
|
||||
x = self.norm1(x)
|
||||
x = nn.functional.glu(x, dim=1)
|
||||
x = self.depth_conv(x)
|
||||
x = self.norm2(x)
|
||||
x = x * torch.sigmoid(x)
|
||||
x = self.time_conv2(x)
|
||||
x = self.norm3(x)
|
||||
x = x_res + x
|
||||
return x
|
||||
|
||||
|
||||
class TimeDepthSeparableConvBlock(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hid_channels,
|
||||
out_channels,
|
||||
num_layers,
|
||||
kernel_size,
|
||||
bias=True):
|
||||
super(TimeDepthSeparableConvBlock, self).__init__()
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
assert num_layers > 1
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
layer = TimeDepthSeparableConv(
|
||||
in_channels, hid_channels,
|
||||
out_channels if num_layers == 1 else hid_channels, kernel_size,
|
||||
bias)
|
||||
self.layers.append(layer)
|
||||
for idx in range(num_layers - 1):
|
||||
layer = TimeDepthSeparableConv(
|
||||
hid_channels, hid_channels, out_channels if
|
||||
(idx + 1) == (num_layers - 1) else hid_channels, kernel_size,
|
||||
bias)
|
||||
self.layers.append(layer)
|
||||
|
||||
def forward(self, x, mask):
|
||||
for layer in self.layers:
|
||||
x = layer(x * mask)
|
||||
return x
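# Example (sketch): stacking the block on a masked batch. Note that the
# residual connection inside TimeDepthSeparableConv assumes matching input
# and output channel counts, so equal sizes are used here (all assumptions).
tds = TimeDepthSeparableConvBlock(in_channels=192, hid_channels=192,
                                  out_channels=192, num_layers=3, kernel_size=5)
x = torch.randn(2, 192, 40)
mask = torch.ones(2, 1, 40)
y = tds(x, mask)                   # [2, 192, 40]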
|
|
@ -0,0 +1,320 @@
|
|||
import copy
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.glow_tts.glow import LayerNorm
|
||||
|
||||
|
||||
class RelativePositionMultiHeadAttention(nn.Module):
|
||||
"""Implementation of Relative Position Encoding based on
|
||||
https://arxiv.org/pdf/1809.04281.pdf
|
||||
"""
|
||||
def __init__(self,
|
||||
channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_attn_window_size=None,
|
||||
heads_share=True,
|
||||
dropout_p=0.,
|
||||
input_length=None,
|
||||
proximal_bias=False,
|
||||
proximal_init=False):
|
||||
super().__init__()
|
||||
assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
|
||||
# class attributes
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels
|
||||
self.num_heads = num_heads
|
||||
self.rel_attn_window_size = rel_attn_window_size
|
||||
self.heads_share = heads_share
|
||||
self.input_length = input_length
|
||||
self.proximal_bias = proximal_bias
|
||||
self.dropout_p = dropout_p
|
||||
self.attn = None
|
||||
# query, key, value layers
|
||||
self.k_channels = channels // num_heads
|
||||
self.conv_q = nn.Conv1d(channels, channels, 1)
|
||||
self.conv_k = nn.Conv1d(channels, channels, 1)
|
||||
self.conv_v = nn.Conv1d(channels, channels, 1)
|
||||
# output layers
|
||||
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
# relative positional encoding layers
|
||||
if rel_attn_window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else num_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
emb_rel_k = nn.Parameter(
|
||||
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
|
||||
self.k_channels) * rel_stddev)
|
||||
emb_rel_v = nn.Parameter(
|
||||
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
|
||||
self.k_channels) * rel_stddev)
|
||||
self.register_parameter('emb_rel_k', emb_rel_k)
|
||||
self.register_parameter('emb_rel_v', emb_rel_v)
|
||||
|
||||
# init layers
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
# proximal bias
|
||||
if proximal_init:
|
||||
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
||||
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
||||
nn.init.xavier_uniform_(self.conv_v.weight)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, t_t = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.num_heads, self.k_channels,
|
||||
t_t).transpose(2, 3)
|
||||
key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
value = value.view(b, self.num_heads, self.k_channels,
|
||||
t_s).transpose(2, 3)
|
||||
# compute raw attention scores
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
|
||||
self.k_channels)
|
||||
# relative positional encoding
|
||||
if self.rel_attn_window_size is not None:
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
# get relative key embeddings
|
||||
key_relative_embeddings = self._get_relative_embeddings(
|
||||
self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(
|
||||
query, key_relative_embeddings)
|
||||
rel_logits = self._relative_position_to_absolute_position(
|
||||
rel_logits)
|
||||
scores_local = rel_logits / math.sqrt(self.k_channels)
|
||||
scores = scores + scores_local
|
||||
# proximal bias
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attn_proximity_bias(t_s).to(
|
||||
device=scores.device, dtype=scores.dtype)
|
||||
# attention score masking
|
||||
if mask is not None:
|
||||
# mask with a finite negative value (instead of -inf) to avoid out-of-range errors.
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
if self.input_length is not None:
|
||||
block_mask = torch.ones_like(scores).triu(
|
||||
-self.input_length).tril(self.input_length)
|
||||
scores = scores * block_mask + -1e4 * (1 - block_mask)
|
||||
# attention score normalization
|
||||
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
||||
# apply dropout to attention weights
|
||||
p_attn = self.dropout(p_attn)
|
||||
# compute output
|
||||
output = torch.matmul(p_attn, value)
|
||||
# relative positional encoding for values
|
||||
if self.rel_attn_window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(
|
||||
p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(
|
||||
self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(
|
||||
relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().view(
|
||||
b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, p_attn, re):
|
||||
"""
|
||||
Args:
|
||||
p_attn (Tensor): attention weights.
|
||||
re (Tensor): relative value embedding vector. (a_(i,j)^V)
|
||||
|
||||
Shapes:
|
||||
p_attn: [B, H, T, V]
|
||||
re: [H or 1, V, D]
|
||||
logits: [B, H, T, D]
|
||||
"""
|
||||
logits = torch.matmul(p_attn, re.unsqueeze(0))
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _matmul_with_relative_keys(query, re):
|
||||
"""
|
||||
Args:
|
||||
query (Tensor): batch of query vectors. (x*W^Q)
|
||||
re (Tensor): relative key embedding vector. (a_(i,j)^K)
|
||||
|
||||
Shapes:
|
||||
query: [B, H, T, D]
|
||||
re: [H or 1, V, D]
|
||||
logits: [B, H, T, V]
|
||||
"""
|
||||
# logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)])
|
||||
logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1))
|
||||
return logits
|
||||
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
"""Convert embedding vestors to a tensor of embeddings
|
||||
"""
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_length = max(length - (self.rel_attn_window_size + 1), 0)
|
||||
slice_start_position = max((self.rel_attn_window_size + 1) - length, 0)
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
if pad_length > 0:
|
||||
padded_relative_embeddings = F.pad(
|
||||
relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
used_relative_embeddings = padded_relative_embeddings[:,
|
||||
slice_start_position:
|
||||
slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_to_absolute_position(x):
|
||||
"""Converts tensor from relative to absolute indexing for local attention.
|
||||
Args:
|
||||
x: [B, D, length, 2 * length - 1]
|
||||
Returns:
|
||||
A Tensor of shape [B, D, length, length]
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# Pad to shift from relative to absolute indexing.
|
||||
x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0])
|
||||
# Pad extra elements so the total adds up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0])
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length + 1,
|
||||
2 * length - 1])[:, :, :length, length - 1:]
|
||||
return x_final
|
||||
|
||||
@staticmethod
|
||||
def _absolute_position_to_relative_position(x):
|
||||
"""
|
||||
x: [B, H, T, T]
|
||||
ret: [B, H, T, 2*T-1]
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# pad along the column dimension
|
||||
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
# add zeros at the beginning that skew the elements into place after the reshape
|
||||
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
|
||||
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
||||
return x_final
|
||||
|
||||
@staticmethod
|
||||
def _attn_proximity_bias(length):
|
||||
"""Produce an attention mask that discourages distant
|
||||
attention values.
|
||||
Args:
|
||||
length (int): an integer scalar.
|
||||
Returns:
|
||||
a Tensor with shape [1, 1, length, length]
|
||||
"""
|
||||
# L
|
||||
r = torch.arange(length, dtype=torch.float32)
|
||||
# L x L
|
||||
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
||||
# scale mask values
|
||||
diff = -torch.log1p(torch.abs(diff))
|
||||
# 1 x 1 x L x L
|
||||
return diff.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
dropout_p=0.,
|
||||
activation=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dropout_p = dropout_p
|
||||
self.activation = activation
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.conv_2 = nn.Conv1d(filter_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
if self.activation == "gelu":
|
||||
x = x * torch.sigmoid(1.702 * x)
|
||||
else:
|
||||
x = torch.relu(x)
|
||||
x = self.dropout(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size=1,
|
||||
dropout_p=0.,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.dropout_p = dropout_p
|
||||
self.rel_attn_window_size = rel_attn_window_size
|
||||
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
self.attn_layers = nn.ModuleList()
|
||||
self.norm_layers_1 = nn.ModuleList()
|
||||
self.ffn_layers = nn.ModuleList()
|
||||
self.norm_layers_2 = nn.ModuleList()
|
||||
for _ in range(self.num_layers):
|
||||
self.attn_layers.append(
|
||||
RelativePositionMultiHeadAttention(
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
num_heads,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
dropout_p=dropout_p,
|
||||
input_length=input_length))
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(
|
||||
FFN(hidden_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
dropout_p=dropout_p))
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
for i in range(self.num_layers):
|
||||
x = x * x_mask
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.dropout(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = self.dropout(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = x * x_mask
|
||||
return x
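# Example (sketch): running the relative-position Transformer encoder on a
# masked batch; the hyper-parameters are illustrative assumptions.
encoder = Transformer(hidden_channels=192, filter_channels=768, num_heads=2,
                      num_layers=6, kernel_size=3, dropout_p=0.1,
                      rel_attn_window_size=4)
x = torch.randn(2, 192, 50)        # [B, C, T]
x_mask = torch.ones(2, 1, 50)      # [B, 1, T]
y = encoder(x, x_mask)             # [B, C, T], padded frames stay masked out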
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -150,7 +151,7 @@ class GuidedAttentionLoss(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _make_ga_mask(ilen, olen, sigma):
|
||||
grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=ilen.device))
|
||||
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
|
||||
grid_x, grid_y = grid_x.float(), grid_y.float()
|
||||
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
|
||||
|
||||
|
@ -243,3 +244,27 @@ class TacotronLoss(torch.nn.Module):
|
|||
|
||||
return_dict['loss'] = loss
|
||||
return return_dict
|
||||
|
||||
|
||||
class GlowTTSLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(GlowTTSLoss, self).__init__()
|
||||
self.constant_factor = 0.5 * math.log(2 * math.pi)
|
||||
|
||||
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,
|
||||
o_attn_dur, x_lengths):
|
||||
return_dict = {}
|
||||
# flow loss - neg log likelihood
|
||||
pz = torch.sum(scales) + 0.5 * torch.sum(
|
||||
torch.exp(-2 * scales) * (z - means)**2)
|
||||
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
|
||||
torch.sum(y_lengths // 2) * 2 * z.shape[1])
|
||||
# duration loss - MSE
|
||||
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
|
||||
# duration loss - huber loss
|
||||
loss_dur = torch.nn.functional.smooth_l1_loss(
|
||||
o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths)
|
||||
return_dict['loss'] = log_mle + loss_dur
|
||||
return_dict['log_mle'] = log_mle
|
||||
return_dict['loss_dur'] = loss_dur
|
||||
return return_dict
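# Example (sketch): calling the loss with dummy tensors of the expected shapes.
# Shapes are assumptions inferred from the model outputs; values are random.
criterion = GlowTTSLoss()
B, C, T_de, T_en = 2, 80, 60, 20
z = torch.randn(B, C, T_de)                      # decoder latents
means = torch.randn(B, C, T_de)                  # expanded encoder means
scales = torch.randn(B, C, T_de)                 # expanded encoder log-scales
log_det = torch.randn(B)                         # flow log-determinants
y_lengths = torch.full((B,), T_de, dtype=torch.long)
o_dur_log = torch.randn(B, 1, T_en)              # predicted log-durations
o_attn_dur = torch.randn(B, 1, T_en)             # durations from the alignment
x_lengths = torch.full((B,), T_en, dtype=torch.long)
loss_dict = criterion(z, means, scales, log_det, y_lengths,
                      o_dur_log, o_attn_dur, x_lengths)
print(loss_dict['loss'], loss_dict['log_mle'], loss_dict['loss_dur'])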
|
|
@ -0,0 +1,185 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.layers.glow_tts.decoder import Decoder
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
||||
|
||||
|
||||
class GlowTts(nn.Module):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129"""
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
filter_channels_dp,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=1,
|
||||
sigmoid_scale=False,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer"):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.filter_channels_dp = filter_channels_dp
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.num_heads = num_heads
|
||||
self.num_layers_enc = num_layers_enc
|
||||
self.dropout_p = dropout_p
|
||||
self.num_flow_blocks_dec = num_flow_blocks_dec
|
||||
self.kernel_size_dec = kernel_size_dec
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_block_layers = num_block_layers
|
||||
self.dropout_p_dec = dropout_p_dec
|
||||
self.num_speakers = num_speakers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.num_splits = num_splits
|
||||
self.num_sqz = num_sqz
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
self.rel_attn_window_size = rel_attn_window_size
|
||||
self.input_length = input_length
|
||||
self.mean_only = mean_only
|
||||
self.hidden_channels_enc = hidden_channels_enc
|
||||
self.hidden_channels_dec = hidden_channels_dec
|
||||
self.use_encoder_prenet = use_encoder_prenet
|
||||
self.noise_scale=0.66
|
||||
self.length_scale=1.
|
||||
|
||||
self.encoder = Encoder(num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
filter_channels=filter_channels,
|
||||
filter_channels_dp=filter_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers_enc,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
self.decoder = Decoder(out_channels,
|
||||
hidden_channels_dec or hidden_channels,
|
||||
kernel_size_dec,
|
||||
dilation_rate,
|
||||
num_flow_blocks_dec,
|
||||
num_block_layers,
|
||||
dropout_p=dropout_p_dec,
|
||||
num_splits=num_splits,
|
||||
num_sqz=num_sqz,
|
||||
sigmoid_scale=sigmoid_scale,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
if num_speakers > 1:
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
def compute_outputs(self, attn, o_mean, o_log_scale, x_mask):
|
||||
# compute final values with the computed alignment
|
||||
y_mean = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
y_log_scale = torch.matmul(
|
||||
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
|
||||
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
# compute total duration with adjustment
|
||||
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
||||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: B x T
|
||||
x_lengths: B
|
||||
y: B x C x T
|
||||
y_lengths: B
|
||||
"""
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
if g is not None:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# format feature vectors and feature vector lengths
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
||||
# create masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
# decoder pass
|
||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||
# find the alignment path
|
||||
with torch.no_grad():
|
||||
o_scale = torch.exp(-2 * o_log_scale)
|
||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp2 = torch.matmul(o_scale.transpose(1,2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1,2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||
attn = maximum_path(logp,
|
||||
attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, x_lengths, g=None):
|
||||
if g is not None:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# compute output durations
|
||||
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_max_length = None
|
||||
# compute masks
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||
# compute attention mask
|
||||
attn = generate_path(w_ceil.squeeze(1),
|
||||
attn_mask.squeeze(1)).unsqueeze(1)
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
|
||||
self.noise_scale) * y_mask
|
||||
# decoder pass
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
||||
|
||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||
if y_max_length is not None:
|
||||
y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
|
||||
y = y[:, :, :y_max_length]
|
||||
if attn is not None:
|
||||
attn = attn[:, :, :, :y_max_length]
|
||||
y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
|
||||
return y, y_lengths, y_max_length, attn
|
||||
|
||||
def store_inverse(self):
|
||||
self.decoder.store_inverse()
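# Example (sketch): synthesizing with the model above on random weights, just
# to illustrate the inference API. Hyper-parameters are assumptions (roughly
# mirroring the glow_tts defaults in setup_model) and the Encoder/Decoder
# modules imported at the top of this file are assumed to match the shapes
# used in forward().
model = GlowTts(num_chars=32, hidden_channels=192, filter_channels=768,
                filter_channels_dp=256, out_channels=80)
model.eval()
model.store_inverse()                       # cache decoder inverses before synthesis
x = torch.randint(0, 32, (1, 40))           # [B, T_in] token ids
x_lengths = torch.tensor([40])
y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur = model.inference(
    x, x_lengths)
# y: [B, 80, T_out] mel frames, attn: [B, T_out, T_in] hard alignment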
|
|
@ -1,3 +1,4 @@
|
|||
import re
|
||||
import torch
|
||||
import importlib
|
||||
import numpy as np
|
||||
|
@ -31,21 +32,22 @@ def split_dataset(items):
|
|||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.to(sequence_length.device)
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
seq_range = torch.arange(max_len,
|
||||
dtype=sequence_length.dtype,
|
||||
device=sequence_length.device)
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
MyModel = getattr(MyModel, to_camel(c.model))
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
|
@ -97,6 +99,31 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r,
|
||||
speaker_embedding_dim=speaker_embedding_dim)
|
||||
elif c.model.lower() == "glow_tts":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
hidden_channels=192,
|
||||
filter_channels=768,
|
||||
filter_channels_dp=256,
|
||||
out_channels=80,
|
||||
kernel_size=3,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
encoder_type=c.encoder_type,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=1,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.05,
|
||||
num_speakers=num_speakers,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=2,
|
||||
sigmoid_scale=False,
|
||||
mean_only=True,
|
||||
hidden_channels_enc=192,
|
||||
hidden_channels_dec=192,
|
||||
use_encoder_prenet=True)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
|
|||
if use_cuda:
|
||||
model.cuda()
|
||||
# set model stepsize
|
||||
if 'r' in state.keys():
|
||||
if hasattr(model.decoder, 'r'):
|
||||
model.decoder.set_r(state['r'])
|
||||
return model, state
|
||||
|
||||
|
|
|
@ -46,16 +46,24 @@ def compute_style_mel(style_wav, ap, cuda=False):
|
|||
|
||||
|
||||
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
|
||||
if CONFIG.use_gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
else:
|
||||
if 'tacotron' in CONFIG.model.lower():
|
||||
if CONFIG.use_gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
elif 'glow' in CONFIG.model.lower():
|
||||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
||||
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
|
||||
postnet_output = postnet_output.permute(0, 2, 1)
|
||||
# these only belong to tacotron models.
|
||||
decoder_output = None
|
||||
stop_tokens = None
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
|
@ -99,9 +107,9 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
|
|||
|
||||
def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
|
||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||
decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
|
||||
alignment = alignments[0].cpu().data.numpy()
|
||||
stop_tokens = stop_tokens[0].cpu().numpy()
|
||||
stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
|
||||
return postnet_output, decoder_output, alignment, stop_tokens
|
||||
|
||||
|
||||
|
|
|
@ -6,14 +6,20 @@ import matplotlib.pyplot as plt
|
|||
from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
|
||||
|
||||
|
||||
def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
|
||||
def plot_alignment(alignment,
|
||||
info=None,
|
||||
fig_size=(16, 10),
|
||||
title=None,
|
||||
output_fig=False):
|
||||
if isinstance(alignment, torch.Tensor):
|
||||
alignment_ = alignment.detach().cpu().numpy().squeeze()
|
||||
else:
|
||||
alignment_ = alignment
|
||||
fig, ax = plt.subplots(figsize=fig_size)
|
||||
im = ax.imshow(
|
||||
alignment_.T, aspect='auto', origin='lower', interpolation='none')
|
||||
im = ax.imshow(alignment_.T,
|
||||
aspect='auto',
|
||||
origin='lower',
|
||||
interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
if info is not None:
|
||||
|
@ -29,7 +35,10 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_f
|
|||
return fig
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
|
||||
def plot_spectrogram(spectrogram,
|
||||
ap=None,
|
||||
fig_size=(16, 10),
|
||||
output_fig=False):
|
||||
if isinstance(spectrogram, torch.Tensor):
|
||||
spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
|
||||
else:
|
||||
|
@ -45,7 +54,17 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
|
|||
return fig
|
||||
|
||||
|
||||
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
|
||||
def visualize(alignment,
|
||||
postnet_output,
|
||||
text,
|
||||
hop_length,
|
||||
CONFIG,
|
||||
stop_tokens=None,
|
||||
decoder_output=None,
|
||||
output_path=None,
|
||||
figsize=(8, 24),
|
||||
output_fig=False):
|
||||
|
||||
if decoder_output is not None:
|
||||
num_plot = 4
|
||||
else:
|
||||
|
@ -60,18 +79,30 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
|
|||
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
|
||||
# compute phoneme representation and back
|
||||
if CONFIG.use_phonemes:
|
||||
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
|
||||
text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
|
||||
seq = phoneme_to_sequence(
|
||||
text, [CONFIG.text_cleaner],
|
||||
CONFIG.phoneme_language,
|
||||
CONFIG.enable_eos_bos_chars,
|
||||
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
|
||||
text = sequence_to_phoneme(
|
||||
seq,
|
||||
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
|
||||
print(text)
|
||||
plt.yticks(range(len(text)), list(text))
|
||||
plt.colorbar()
|
||||
# plot stopnet predictions
|
||||
plt.subplot(num_plot, 1, 2)
|
||||
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
||||
|
||||
if stop_tokens is not None:
|
||||
# plot stopnet predictions
|
||||
plt.subplot(num_plot, 1, 2)
|
||||
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
||||
|
||||
# plot postnet spectrogram
|
||||
plt.subplot(num_plot, 1, 3)
|
||||
librosa.display.specshow(postnet_output.T, sr=CONFIG.audio['sample_rate'],
|
||||
hop_length=hop_length, x_axis="time", y_axis="linear",
|
||||
librosa.display.specshow(postnet_output.T,
|
||||
sr=CONFIG.audio['sample_rate'],
|
||||
hop_length=hop_length,
|
||||
x_axis="time",
|
||||
y_axis="linear",
|
||||
fmin=CONFIG.audio['mel_fmin'],
|
||||
fmax=CONFIG.audio['mel_fmax'])
|
||||
|
||||
|
@ -82,8 +113,11 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
|
|||
|
||||
if decoder_output is not None:
|
||||
plt.subplot(num_plot, 1, 4)
|
||||
librosa.display.specshow(decoder_output.T, sr=CONFIG.audio['sample_rate'],
|
||||
hop_length=hop_length, x_axis="time", y_axis="linear",
|
||||
librosa.display.specshow(decoder_output.T,
|
||||
sr=CONFIG.audio['sample_rate'],
|
||||
hop_length=hop_length,
|
||||
x_axis="time",
|
||||
y_axis="linear",
|
||||
fmin=CONFIG.audio['mel_fmin'],
|
||||
fmax=CONFIG.audio['mel_fmax'])
|
||||
plt.xlabel("Time", fontsize=label_fontsize)
|
||||
|
|
|
@ -174,8 +174,9 @@ class AudioProcessor(object):
|
|||
for key in stats_config.keys():
|
||||
if key in skip_parameters:
|
||||
continue
|
||||
assert stats_config[key] == self.__dict__[key],\
|
||||
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||
if key != 'sample_rate':
|
||||
assert stats_config[key] == self.__dict__[key],\
|
||||
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||
return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
@ -322,6 +323,7 @@ class AudioProcessor(object):
|
|||
def load_wav(self, filename, sr=None):
|
||||
if sr is None:
|
||||
x, sr = sf.read(filename)
|
||||
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
|
||||
else:
|
||||
x, sr = librosa.load(filename, sr=sr)
|
||||
if self.do_trim_silence:
|
||||
|
@ -329,7 +331,6 @@ class AudioProcessor(object):
|
|||
x = self.trim_silence(x)
|
||||
except ValueError:
|
||||
print(f' [!] File cannot be trimmed for silence - {filename}')
|
||||
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
|
||||
if self.do_sound_norm:
|
||||
x = self.sound_norm(x)
|
||||
return x
|
||||
|
|
|
@ -5,8 +5,8 @@ import pickle as pickle_tts
|
|||
class RenamingUnpickler(pickle_tts.Unpickler):
|
||||
"""Overload default pickler to solve module renaming problem"""
|
||||
def find_class(self, module, name):
|
||||
if 'mozilla_voice_tts' in module :
|
||||
module = module.replace('mozilla_voice_tts', 'TTS')
|
||||
if 'TTS' in module :
|
||||
module = module.replace('TTS', 'TTS')
|
||||
return super().find_class(module, name)
|
||||
|
||||
class AttrDict(dict):
|
||||
|
|
|
@ -22,7 +22,7 @@ class PQMF(torch.nn.Module):
|
|||
for k in range(N):
|
||||
constant_factor = (2 * k + 1) * (np.pi /
|
||||
(2 * N)) * (np.arange(taps + 1) -
|
||||
((taps - 1) / 2))
|
||||
((taps - 1) / 2)) # TODO: (taps - 1) -> taps
|
||||
phase = (-1)**k * np.pi / 4
|
||||
H[k] = 2 * QMF * np.cos(constant_factor + phase)
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class FullbandMelganGenerator(MelganGenerator):
|
||||
def __init__(self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=(2, 8, 2, 2),
|
||||
res_kernel=3,
|
||||
num_res_blocks=4):
|
||||
super(FullbandMelganGenerator,
|
||||
self).__init__(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
proj_kernel=proj_kernel,
|
||||
base_channels=base_channels,
|
||||
upsample_factors=upsample_factors,
|
||||
res_kernel=res_kernel,
|
||||
num_res_blocks=num_res_blocks)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, cond_features):
|
||||
cond_features = cond_features.to(self.layers[1].weight.device)
|
||||
cond_features = torch.nn.functional.pad(
|
||||
cond_features,
|
||||
(self.inference_padding, self.inference_padding),
|
||||
'replicate')
|
||||
return self.layers(cond_features)
|
|
@ -48,7 +48,6 @@ class ResidualStack(tf.keras.layers.Layer):
|
|||
]
|
||||
|
||||
def call(self, x):
|
||||
# breakpoint()
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
res = shortcut(x)
|
||||
for layer in block:
|
||||
|
|
|
@ -67,6 +67,15 @@ def setup_generator(c):
|
|||
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||
if c.generator_model in 'fullband_melgan_generator':
|
||||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||
if c.generator_model in 'parallel_wavegan_generator':
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
|
|
setup.py
|
@ -5,11 +5,21 @@ import os
|
|||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import numpy
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
import setuptools.command.develop
|
||||
import setuptools.command.build_py
|
||||
|
||||
# handle import if cython is not already installed.
|
||||
try:
|
||||
from Cython.Build import cythonize
|
||||
except ImportError:
|
||||
# create closure for deferred import
|
||||
def cythonize(*args, **kwargs):
|
||||
from Cython.Build import cythonize
|
||||
return cythonize(*args, **kwargs)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
|
||||
parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file to embed in wheel.')
|
||||
|
@ -36,6 +46,16 @@ else:
|
|||
pass
|
||||
|
||||
|
||||
# Handle Cython code
|
||||
def find_pyx(path='.'):
|
||||
pyx_files = []
|
||||
for root, dirs, filenames in os.walk(path):
|
||||
for fname in filenames:
|
||||
if fname.endswith('.pyx'):
|
||||
pyx_files.append(os.path.join(root, fname))
|
||||
return pyx_files
|
||||
|
||||
|
||||
class build_py(setuptools.command.build_py.build_py): # pylint: disable=too-many-ancestors
|
||||
def run(self):
|
||||
self.create_version_file()
|
||||
|
@ -99,6 +119,8 @@ setup(
|
|||
'tts-server = TTS.server.server:main'
|
||||
]
|
||||
},
|
||||
include_dirs=[numpy.get_include()],
|
||||
ext_modules=cythonize(find_pyx(), language_level=3),
|
||||
packages=find_packages(include=['TTS*']),
|
||||
project_urls={
|
||||
'Documentation': 'https://github.com/mozilla/TTS/wiki',
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
import copy
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from tests import get_tests_input_path
|
||||
from torch import nn, optim
|
||||
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.models.glow_tts import GlowTts
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
#pylint: disable=unused-variable
|
||||
|
||||
torch.manual_seed(1)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
c = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
|
||||
|
||||
ap = AudioProcessor(**c.audio)
|
||||
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
r"""Count number of trainable parameters in a network"""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class GlowTTSTrainTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def test_train_step():
|
||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
|
||||
linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
criterion = GlowTTSLoss()
|
||||
|
||||
# model to train
|
||||
model = GlowTts(
|
||||
num_chars=32,
|
||||
hidden_channels=128,
|
||||
filter_channels=32,
|
||||
filter_channels_dp=32,
|
||||
out_channels=80,
|
||||
kernel_size=3,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=1,
|
||||
sigmoid_scale=False,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer"
|
||||
).to(device)
|
||||
|
||||
# reference model to compare model weights
|
||||
model_ref = GlowTts(
|
||||
num_chars=32,
|
||||
hidden_channels=128,
|
||||
filter_channels=32,
|
||||
filter_channels_dp=32,
|
||||
out_channels=80,
|
||||
kernel_size=3,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=1,
|
||||
sigmoid_scale=False,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer"
|
||||
).to(device)
|
||||
|
||||
model.train()
|
||||
print(" > Num parameters for GlowTTS model:%s" %
|
||||
(count_parameters(model)))
|
||||
|
||||
# pass the state to ref model
|
||||
model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
|
||||
|
||||
count = 0
|
||||
for param, param_ref in zip(model.parameters(),
|
||||
model_ref.parameters()):
|
||||
assert (param - param_ref).sum() == 0, param
|
||||
count += 1
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, None)
|
||||
optimizer.zero_grad()
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
o_dur_log, o_total_dur, input_lengths)
|
||||
loss = loss_dict['loss']
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# check parameter changes
|
||||
count = 0
|
||||
for param, param_ref in zip(model.parameters(),
|
||||
model_ref.parameters()):
|
||||
assert (param != param_ref).any(
|
||||
), "param {} with shape {} not updated!! \n{}\n{}".format(
|
||||
count, param.shape, param, param_ref)
|
||||
count += 1
|